Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/proc-macro-api/src/bidirectional_protocol.rs')
-rw-r--r--crates/proc-macro-api/src/bidirectional_protocol.rs296
1 files changed, 296 insertions, 0 deletions
diff --git a/crates/proc-macro-api/src/bidirectional_protocol.rs b/crates/proc-macro-api/src/bidirectional_protocol.rs
new file mode 100644
index 0000000000..246f70a101
--- /dev/null
+++ b/crates/proc-macro-api/src/bidirectional_protocol.rs
@@ -0,0 +1,296 @@
+//! Bidirectional protocol methods
+
+use std::{
+ io::{self, BufRead, Write},
+ sync::Arc,
+};
+
+use base_db::SourceDatabase;
+use paths::AbsPath;
+use span::{FileId, Span};
+
+use crate::{
+ Codec, ProcMacro, ProcMacroKind, ServerError,
+ bidirectional_protocol::msg::{
+ Envelope, ExpandMacro, ExpandMacroData, ExpnGlobals, Kind, Payload, Request, RequestId,
+ Response, SubRequest, SubResponse,
+ },
+ legacy_protocol::{
+ SpanMode,
+ msg::{
+ FlatTree, ServerConfig, SpanDataIndexMap, deserialize_span_data_index_map,
+ serialize_span_data_index_map,
+ },
+ },
+ process::ProcMacroServerProcess,
+ transport::codec::{json::JsonProtocol, postcard::PostcardProtocol},
+ version,
+};
+
+pub mod msg;
+
+pub trait ClientCallbacks {
+ fn handle_sub_request(&mut self, id: u64, req: SubRequest) -> Result<SubResponse, ServerError>;
+}
+
+pub fn run_conversation<C: Codec>(
+ writer: &mut dyn Write,
+ reader: &mut dyn BufRead,
+ buf: &mut C::Buf,
+ id: RequestId,
+ initial: Payload,
+ callbacks: &mut dyn ClientCallbacks,
+) -> Result<Payload, ServerError> {
+ let msg = Envelope { id, kind: Kind::Request, payload: initial };
+ let encoded = C::encode(&msg).map_err(wrap_encode)?;
+ C::write(writer, &encoded).map_err(wrap_io("failed to write initial request"))?;
+
+ loop {
+ let maybe_buf = C::read(reader, buf).map_err(wrap_io("failed to read message"))?;
+ let Some(b) = maybe_buf else {
+ return Err(ServerError {
+ message: "proc-macro server closed the stream".into(),
+ io: Some(Arc::new(io::Error::new(io::ErrorKind::UnexpectedEof, "closed"))),
+ });
+ };
+
+ let msg: Envelope = C::decode(b).map_err(wrap_decode)?;
+
+ if msg.id != id {
+ return Err(ServerError {
+ message: format!("unexpected message id {}, expected {}", msg.id, id),
+ io: None,
+ });
+ }
+
+ match (msg.kind, msg.payload) {
+ (Kind::SubRequest, Payload::SubRequest(sr)) => {
+ let resp = callbacks.handle_sub_request(id, sr)?;
+ let reply =
+ Envelope { id, kind: Kind::SubResponse, payload: Payload::SubResponse(resp) };
+ let encoded = C::encode(&reply).map_err(wrap_encode)?;
+ C::write(writer, &encoded).map_err(wrap_io("failed to write sub-response"))?;
+ }
+ (Kind::Response, payload) => {
+ return Ok(payload);
+ }
+ (kind, payload) => {
+ return Err(ServerError {
+ message: format!(
+ "unexpected message kind {:?} with payload {:?}",
+ kind, payload
+ ),
+ io: None,
+ });
+ }
+ }
+ }
+}
+
+fn wrap_io(msg: &'static str) -> impl Fn(io::Error) -> ServerError {
+ move |err| ServerError { message: msg.into(), io: Some(Arc::new(err)) }
+}
+
+fn wrap_encode(err: io::Error) -> ServerError {
+ ServerError { message: "failed to encode message".into(), io: Some(Arc::new(err)) }
+}
+
+fn wrap_decode(err: io::Error) -> ServerError {
+ ServerError { message: "failed to decode message".into(), io: Some(Arc::new(err)) }
+}
+
+pub(crate) fn version_check(srv: &ProcMacroServerProcess) -> Result<u32, ServerError> {
+ let request = Payload::Request(Request::ApiVersionCheck {});
+
+ struct NoCallbacks;
+ impl ClientCallbacks for NoCallbacks {
+ fn handle_sub_request(
+ &mut self,
+ _id: u64,
+ _req: SubRequest,
+ ) -> Result<SubResponse, ServerError> {
+ Err(ServerError { message: "sub-request not supported here".into(), io: None })
+ }
+ }
+
+ let mut callbacks = NoCallbacks;
+
+ let response_payload =
+ run_bidirectional(srv, (0, Kind::Request, request).into(), &mut callbacks)?;
+
+ match response_payload {
+ Payload::Response(Response::ApiVersionCheck(version)) => Ok(version),
+ other => {
+ Err(ServerError { message: format!("unexpected response: {:?}", other), io: None })
+ }
+ }
+}
+
+/// Enable support for rust-analyzer span mode if the server supports it.
+pub(crate) fn enable_rust_analyzer_spans(
+ srv: &ProcMacroServerProcess,
+) -> Result<SpanMode, ServerError> {
+ let request =
+ Payload::Request(Request::SetConfig(ServerConfig { span_mode: SpanMode::RustAnalyzer }));
+
+ struct NoCallbacks;
+ impl ClientCallbacks for NoCallbacks {
+ fn handle_sub_request(
+ &mut self,
+ _id: u64,
+ _req: SubRequest,
+ ) -> Result<SubResponse, ServerError> {
+ Err(ServerError { message: "sub-request not supported here".into(), io: None })
+ }
+ }
+
+ let mut callbacks = NoCallbacks;
+
+ let response_payload =
+ run_bidirectional(srv, (0, Kind::Request, request).into(), &mut callbacks)?;
+
+ match response_payload {
+ Payload::Response(Response::SetConfig(ServerConfig { span_mode })) => Ok(span_mode),
+ _ => Err(ServerError { message: "unexpected response".to_owned(), io: None }),
+ }
+}
+
+/// Finds proc-macros in a given dynamic library.
+pub(crate) fn find_proc_macros(
+ srv: &ProcMacroServerProcess,
+ dylib_path: &AbsPath,
+) -> Result<Result<Vec<(String, ProcMacroKind)>, String>, ServerError> {
+ let request =
+ Payload::Request(Request::ListMacros { dylib_path: dylib_path.to_path_buf().into() });
+
+ struct NoCallbacks;
+ impl ClientCallbacks for NoCallbacks {
+ fn handle_sub_request(
+ &mut self,
+ _id: u64,
+ _req: SubRequest,
+ ) -> Result<SubResponse, ServerError> {
+ Err(ServerError { message: "sub-request not supported here".into(), io: None })
+ }
+ }
+
+ let mut callbacks = NoCallbacks;
+
+ let response_payload =
+ run_bidirectional(srv, (0, Kind::Request, request).into(), &mut callbacks)?;
+
+ match response_payload {
+ Payload::Response(Response::ListMacros(it)) => Ok(it),
+ _ => Err(ServerError { message: "unexpected response".to_owned(), io: None }),
+ }
+}
+
+pub(crate) fn expand(
+ proc_macro: &ProcMacro,
+ db: &dyn SourceDatabase,
+ subtree: tt::SubtreeView<'_, Span>,
+ attr: Option<tt::SubtreeView<'_, Span>>,
+ env: Vec<(String, String)>,
+ def_site: Span,
+ call_site: Span,
+ mixed_site: Span,
+ current_dir: String,
+) -> Result<Result<tt::TopSubtree<span::SpanData<span::SyntaxContext>>, String>, crate::ServerError>
+{
+ let version = proc_macro.process.version();
+ let mut span_data_table = SpanDataIndexMap::default();
+ let def_site = span_data_table.insert_full(def_site).0;
+ let call_site = span_data_table.insert_full(call_site).0;
+ let mixed_site = span_data_table.insert_full(mixed_site).0;
+ let task = Payload::Request(Request::ExpandMacro(Box::new(ExpandMacro {
+ data: ExpandMacroData {
+ macro_body: FlatTree::from_subtree(subtree, version, &mut span_data_table),
+ macro_name: proc_macro.name.to_string(),
+ attributes: attr
+ .map(|subtree| FlatTree::from_subtree(subtree, version, &mut span_data_table)),
+ has_global_spans: ExpnGlobals {
+ serialize: version >= version::HAS_GLOBAL_SPANS,
+ def_site,
+ call_site,
+ mixed_site,
+ },
+ span_data_table: if proc_macro.process.rust_analyzer_spans() {
+ serialize_span_data_index_map(&span_data_table)
+ } else {
+ Vec::new()
+ },
+ },
+ lib: proc_macro.dylib_path.to_path_buf().into(),
+ env,
+ current_dir: Some(current_dir),
+ })));
+
+ struct Callbacks<'de> {
+ db: &'de dyn SourceDatabase,
+ }
+ impl<'db> ClientCallbacks for Callbacks<'db> {
+ fn handle_sub_request(
+ &mut self,
+ _id: u64,
+ req: SubRequest,
+ ) -> Result<SubResponse, ServerError> {
+ match req {
+ SubRequest::SourceText { file_id, start, end } => {
+ let file = FileId::from_raw(file_id);
+ let text = self.db.file_text(file).text(self.db);
+
+ let slice = text.get(start as usize..end as usize).map(|s| s.to_owned());
+
+ Ok(SubResponse::SourceTextResult { text: slice })
+ }
+ }
+ }
+ }
+
+ let mut callbacks = Callbacks { db };
+
+ let response_payload =
+ run_bidirectional(&proc_macro.process, (0, Kind::Request, task).into(), &mut callbacks)?;
+
+ match response_payload {
+ Payload::Response(Response::ExpandMacro(it)) => Ok(it
+ .map(|tree| {
+ let mut expanded = FlatTree::to_subtree_resolved(tree, version, &span_data_table);
+ if proc_macro.needs_fixup_change() {
+ proc_macro.change_fixup_to_match_old_server(&mut expanded);
+ }
+ expanded
+ })
+ .map_err(|msg| msg.0)),
+ Payload::Response(Response::ExpandMacroExtended(it)) => Ok(it
+ .map(|resp| {
+ let mut expanded = FlatTree::to_subtree_resolved(
+ resp.tree,
+ version,
+ &deserialize_span_data_index_map(&resp.span_data_table),
+ );
+ if proc_macro.needs_fixup_change() {
+ proc_macro.change_fixup_to_match_old_server(&mut expanded);
+ }
+ expanded
+ })
+ .map_err(|msg| msg.0)),
+ _ => Err(ServerError { message: "unexpected response".to_owned(), io: None }),
+ }
+}
+
+fn run_bidirectional(
+ srv: &ProcMacroServerProcess,
+ msg: Envelope,
+ callbacks: &mut dyn ClientCallbacks,
+) -> Result<Payload, ServerError> {
+ if let Some(server_error) = srv.exited() {
+ return Err(server_error.clone());
+ }
+
+ if srv.use_postcard() {
+ srv.run_bidirectional::<PostcardProtocol>(msg.id, msg.payload, callbacks)
+ } else {
+ srv.run_bidirectional::<JsonProtocol>(msg.id, msg.payload, callbacks)
+ }
+}