Unnamed repository; edit this file 'description' to name the repository.
-rw-r--r--Cargo.lock3
-rw-r--r--crates/hir-def/src/macro_expansion_tests/mod.rs3
-rw-r--r--crates/hir-expand/src/proc_macro.rs4
-rw-r--r--crates/load-cargo/src/lib.rs6
-rw-r--r--crates/proc-macro-api/Cargo.toml1
-rw-r--r--crates/proc-macro-api/src/bidirectional_protocol.rs296
-rw-r--r--crates/proc-macro-api/src/bidirectional_protocol/msg.rs114
-rw-r--r--crates/proc-macro-api/src/legacy_protocol.rs23
-rw-r--r--crates/proc-macro-api/src/legacy_protocol/msg.rs2
-rw-r--r--crates/proc-macro-api/src/lib.rs11
-rw-r--r--crates/proc-macro-api/src/process.rs173
-rw-r--r--crates/proc-macro-api/src/transport.rs3
-rw-r--r--crates/proc-macro-api/src/transport/codec.rs (renamed from crates/proc-macro-api/src/codec.rs)5
-rw-r--r--crates/proc-macro-api/src/transport/codec/json.rs (renamed from crates/proc-macro-api/src/legacy_protocol/json.rs)6
-rw-r--r--crates/proc-macro-api/src/transport/codec/postcard.rs (renamed from crates/proc-macro-api/src/legacy_protocol/postcard.rs)6
-rw-r--r--crates/proc-macro-api/src/transport/framing.rs (renamed from crates/proc-macro-api/src/framing.rs)4
-rw-r--r--crates/proc-macro-srv-cli/Cargo.toml1
-rw-r--r--crates/proc-macro-srv-cli/src/main.rs6
-rw-r--r--crates/proc-macro-srv-cli/src/main_loop.rs309
-rw-r--r--crates/proc-macro-srv/Cargo.toml1
-rw-r--r--crates/proc-macro-srv/src/dylib.rs26
-rw-r--r--crates/proc-macro-srv/src/dylib/proc_macros.rs77
-rw-r--r--crates/proc-macro-srv/src/lib.rs86
-rw-r--r--crates/proc-macro-srv/src/server_impl/rust_analyzer_span.rs24
-rw-r--r--crates/proc-macro-srv/src/server_impl/token_id.rs4
-rw-r--r--crates/test-fixture/src/lib.rs10
26 files changed, 1112 insertions, 92 deletions
diff --git a/Cargo.lock b/Cargo.lock
index efe56cb7f6..060a62b112 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1836,6 +1836,7 @@ dependencies = [
name = "proc-macro-api"
version = "0.0.0"
dependencies = [
+ "base-db",
"indexmap",
"intern",
"paths",
@@ -1856,6 +1857,7 @@ dependencies = [
name = "proc-macro-srv"
version = "0.0.0"
dependencies = [
+ "crossbeam-channel",
"expect-test",
"intern",
"libc",
@@ -1874,6 +1876,7 @@ name = "proc-macro-srv-cli"
version = "0.0.0"
dependencies = [
"clap",
+ "crossbeam-channel",
"postcard",
"proc-macro-api",
"proc-macro-srv",
diff --git a/crates/hir-def/src/macro_expansion_tests/mod.rs b/crates/hir-def/src/macro_expansion_tests/mod.rs
index 78af976e1b..07cad9695b 100644
--- a/crates/hir-def/src/macro_expansion_tests/mod.rs
+++ b/crates/hir-def/src/macro_expansion_tests/mod.rs
@@ -16,7 +16,7 @@ mod proc_macros;
use std::{any::TypeId, iter, ops::Range, sync};
-use base_db::RootQueryDb;
+use base_db::{RootQueryDb, SourceDatabase};
use expect_test::Expect;
use hir_expand::{
AstId, InFile, MacroCallId, MacroCallKind, MacroKind,
@@ -374,6 +374,7 @@ struct IdentityWhenValidProcMacroExpander;
impl ProcMacroExpander for IdentityWhenValidProcMacroExpander {
fn expand(
&self,
+ _: &dyn SourceDatabase,
subtree: &TopSubtree,
_: Option<&TopSubtree>,
_: &base_db::Env,
diff --git a/crates/hir-expand/src/proc_macro.rs b/crates/hir-expand/src/proc_macro.rs
index f97d721dfa..d2614aa5f1 100644
--- a/crates/hir-expand/src/proc_macro.rs
+++ b/crates/hir-expand/src/proc_macro.rs
@@ -4,7 +4,7 @@ use core::fmt;
use std::any::Any;
use std::{panic::RefUnwindSafe, sync};
-use base_db::{Crate, CrateBuilderId, CratesIdMap, Env, ProcMacroLoadingError};
+use base_db::{Crate, CrateBuilderId, CratesIdMap, Env, ProcMacroLoadingError, SourceDatabase};
use intern::Symbol;
use rustc_hash::FxHashMap;
use span::Span;
@@ -25,6 +25,7 @@ pub trait ProcMacroExpander: fmt::Debug + Send + Sync + RefUnwindSafe + Any {
/// [`ProcMacroKind::Attr`]), environment variables, and span information.
fn expand(
&self,
+ db: &dyn SourceDatabase,
subtree: &tt::TopSubtree,
attrs: Option<&tt::TopSubtree>,
env: &Env,
@@ -309,6 +310,7 @@ impl CustomProcMacroExpander {
let current_dir = calling_crate.data(db).proc_macro_cwd.to_string();
match proc_macro.expander.expand(
+ db,
tt,
attr_arg,
env,
diff --git a/crates/load-cargo/src/lib.rs b/crates/load-cargo/src/lib.rs
index 28fbfecfde..e043e4ac76 100644
--- a/crates/load-cargo/src/lib.rs
+++ b/crates/load-cargo/src/lib.rs
@@ -17,7 +17,9 @@ use hir_expand::proc_macro::{
};
use ide_db::{
ChangeWithProcMacros, FxHashMap, RootDatabase,
- base_db::{CrateGraphBuilder, Env, ProcMacroLoadingError, SourceRoot, SourceRootId},
+ base_db::{
+ CrateGraphBuilder, Env, ProcMacroLoadingError, SourceDatabase, SourceRoot, SourceRootId,
+ },
prime_caches,
};
use itertools::Itertools;
@@ -522,6 +524,7 @@ struct Expander(proc_macro_api::ProcMacro);
impl ProcMacroExpander for Expander {
fn expand(
&self,
+ db: &dyn SourceDatabase,
subtree: &tt::TopSubtree<Span>,
attrs: Option<&tt::TopSubtree<Span>>,
env: &Env,
@@ -531,6 +534,7 @@ impl ProcMacroExpander for Expander {
current_dir: String,
) -> Result<tt::TopSubtree<Span>, ProcMacroExpansionError> {
match self.0.expand(
+ db,
subtree.view(),
attrs.map(|attrs| attrs.view()),
env.clone().into(),
diff --git a/crates/proc-macro-api/Cargo.toml b/crates/proc-macro-api/Cargo.toml
index 4de1a3e5dd..7e56d68964 100644
--- a/crates/proc-macro-api/Cargo.toml
+++ b/crates/proc-macro-api/Cargo.toml
@@ -19,6 +19,7 @@ serde_json = { workspace = true, features = ["unbounded_depth"] }
tracing.workspace = true
rustc-hash.workspace = true
indexmap.workspace = true
+base-db.workspace = true
# local deps
paths = { workspace = true, features = ["serde1"] }
diff --git a/crates/proc-macro-api/src/bidirectional_protocol.rs b/crates/proc-macro-api/src/bidirectional_protocol.rs
new file mode 100644
index 0000000000..246f70a101
--- /dev/null
+++ b/crates/proc-macro-api/src/bidirectional_protocol.rs
@@ -0,0 +1,296 @@
+//! Bidirectional protocol methods
+
+use std::{
+ io::{self, BufRead, Write},
+ sync::Arc,
+};
+
+use base_db::SourceDatabase;
+use paths::AbsPath;
+use span::{FileId, Span};
+
+use crate::{
+ Codec, ProcMacro, ProcMacroKind, ServerError,
+ bidirectional_protocol::msg::{
+ Envelope, ExpandMacro, ExpandMacroData, ExpnGlobals, Kind, Payload, Request, RequestId,
+ Response, SubRequest, SubResponse,
+ },
+ legacy_protocol::{
+ SpanMode,
+ msg::{
+ FlatTree, ServerConfig, SpanDataIndexMap, deserialize_span_data_index_map,
+ serialize_span_data_index_map,
+ },
+ },
+ process::ProcMacroServerProcess,
+ transport::codec::{json::JsonProtocol, postcard::PostcardProtocol},
+ version,
+};
+
+pub mod msg;
+
+pub trait ClientCallbacks {
+ fn handle_sub_request(&mut self, id: u64, req: SubRequest) -> Result<SubResponse, ServerError>;
+}
+
+pub fn run_conversation<C: Codec>(
+ writer: &mut dyn Write,
+ reader: &mut dyn BufRead,
+ buf: &mut C::Buf,
+ id: RequestId,
+ initial: Payload,
+ callbacks: &mut dyn ClientCallbacks,
+) -> Result<Payload, ServerError> {
+ let msg = Envelope { id, kind: Kind::Request, payload: initial };
+ let encoded = C::encode(&msg).map_err(wrap_encode)?;
+ C::write(writer, &encoded).map_err(wrap_io("failed to write initial request"))?;
+
+ loop {
+ let maybe_buf = C::read(reader, buf).map_err(wrap_io("failed to read message"))?;
+ let Some(b) = maybe_buf else {
+ return Err(ServerError {
+ message: "proc-macro server closed the stream".into(),
+ io: Some(Arc::new(io::Error::new(io::ErrorKind::UnexpectedEof, "closed"))),
+ });
+ };
+
+ let msg: Envelope = C::decode(b).map_err(wrap_decode)?;
+
+ if msg.id != id {
+ return Err(ServerError {
+ message: format!("unexpected message id {}, expected {}", msg.id, id),
+ io: None,
+ });
+ }
+
+ match (msg.kind, msg.payload) {
+ (Kind::SubRequest, Payload::SubRequest(sr)) => {
+ let resp = callbacks.handle_sub_request(id, sr)?;
+ let reply =
+ Envelope { id, kind: Kind::SubResponse, payload: Payload::SubResponse(resp) };
+ let encoded = C::encode(&reply).map_err(wrap_encode)?;
+ C::write(writer, &encoded).map_err(wrap_io("failed to write sub-response"))?;
+ }
+ (Kind::Response, payload) => {
+ return Ok(payload);
+ }
+ (kind, payload) => {
+ return Err(ServerError {
+ message: format!(
+ "unexpected message kind {:?} with payload {:?}",
+ kind, payload
+ ),
+ io: None,
+ });
+ }
+ }
+ }
+}
+
+fn wrap_io(msg: &'static str) -> impl Fn(io::Error) -> ServerError {
+ move |err| ServerError { message: msg.into(), io: Some(Arc::new(err)) }
+}
+
+fn wrap_encode(err: io::Error) -> ServerError {
+ ServerError { message: "failed to encode message".into(), io: Some(Arc::new(err)) }
+}
+
+fn wrap_decode(err: io::Error) -> ServerError {
+ ServerError { message: "failed to decode message".into(), io: Some(Arc::new(err)) }
+}
+
+pub(crate) fn version_check(srv: &ProcMacroServerProcess) -> Result<u32, ServerError> {
+ let request = Payload::Request(Request::ApiVersionCheck {});
+
+ struct NoCallbacks;
+ impl ClientCallbacks for NoCallbacks {
+ fn handle_sub_request(
+ &mut self,
+ _id: u64,
+ _req: SubRequest,
+ ) -> Result<SubResponse, ServerError> {
+ Err(ServerError { message: "sub-request not supported here".into(), io: None })
+ }
+ }
+
+ let mut callbacks = NoCallbacks;
+
+ let response_payload =
+ run_bidirectional(srv, (0, Kind::Request, request).into(), &mut callbacks)?;
+
+ match response_payload {
+ Payload::Response(Response::ApiVersionCheck(version)) => Ok(version),
+ other => {
+ Err(ServerError { message: format!("unexpected response: {:?}", other), io: None })
+ }
+ }
+}
+
+/// Enable support for rust-analyzer span mode if the server supports it.
+pub(crate) fn enable_rust_analyzer_spans(
+ srv: &ProcMacroServerProcess,
+) -> Result<SpanMode, ServerError> {
+ let request =
+ Payload::Request(Request::SetConfig(ServerConfig { span_mode: SpanMode::RustAnalyzer }));
+
+ struct NoCallbacks;
+ impl ClientCallbacks for NoCallbacks {
+ fn handle_sub_request(
+ &mut self,
+ _id: u64,
+ _req: SubRequest,
+ ) -> Result<SubResponse, ServerError> {
+ Err(ServerError { message: "sub-request not supported here".into(), io: None })
+ }
+ }
+
+ let mut callbacks = NoCallbacks;
+
+ let response_payload =
+ run_bidirectional(srv, (0, Kind::Request, request).into(), &mut callbacks)?;
+
+ match response_payload {
+ Payload::Response(Response::SetConfig(ServerConfig { span_mode })) => Ok(span_mode),
+ _ => Err(ServerError { message: "unexpected response".to_owned(), io: None }),
+ }
+}
+
+/// Finds proc-macros in a given dynamic library.
+pub(crate) fn find_proc_macros(
+ srv: &ProcMacroServerProcess,
+ dylib_path: &AbsPath,
+) -> Result<Result<Vec<(String, ProcMacroKind)>, String>, ServerError> {
+ let request =
+ Payload::Request(Request::ListMacros { dylib_path: dylib_path.to_path_buf().into() });
+
+ struct NoCallbacks;
+ impl ClientCallbacks for NoCallbacks {
+ fn handle_sub_request(
+ &mut self,
+ _id: u64,
+ _req: SubRequest,
+ ) -> Result<SubResponse, ServerError> {
+ Err(ServerError { message: "sub-request not supported here".into(), io: None })
+ }
+ }
+
+ let mut callbacks = NoCallbacks;
+
+ let response_payload =
+ run_bidirectional(srv, (0, Kind::Request, request).into(), &mut callbacks)?;
+
+ match response_payload {
+ Payload::Response(Response::ListMacros(it)) => Ok(it),
+ _ => Err(ServerError { message: "unexpected response".to_owned(), io: None }),
+ }
+}
+
+pub(crate) fn expand(
+ proc_macro: &ProcMacro,
+ db: &dyn SourceDatabase,
+ subtree: tt::SubtreeView<'_, Span>,
+ attr: Option<tt::SubtreeView<'_, Span>>,
+ env: Vec<(String, String)>,
+ def_site: Span,
+ call_site: Span,
+ mixed_site: Span,
+ current_dir: String,
+) -> Result<Result<tt::TopSubtree<span::SpanData<span::SyntaxContext>>, String>, crate::ServerError>
+{
+ let version = proc_macro.process.version();
+ let mut span_data_table = SpanDataIndexMap::default();
+ let def_site = span_data_table.insert_full(def_site).0;
+ let call_site = span_data_table.insert_full(call_site).0;
+ let mixed_site = span_data_table.insert_full(mixed_site).0;
+ let task = Payload::Request(Request::ExpandMacro(Box::new(ExpandMacro {
+ data: ExpandMacroData {
+ macro_body: FlatTree::from_subtree(subtree, version, &mut span_data_table),
+ macro_name: proc_macro.name.to_string(),
+ attributes: attr
+ .map(|subtree| FlatTree::from_subtree(subtree, version, &mut span_data_table)),
+ has_global_spans: ExpnGlobals {
+ serialize: version >= version::HAS_GLOBAL_SPANS,
+ def_site,
+ call_site,
+ mixed_site,
+ },
+ span_data_table: if proc_macro.process.rust_analyzer_spans() {
+ serialize_span_data_index_map(&span_data_table)
+ } else {
+ Vec::new()
+ },
+ },
+ lib: proc_macro.dylib_path.to_path_buf().into(),
+ env,
+ current_dir: Some(current_dir),
+ })));
+
+ struct Callbacks<'de> {
+ db: &'de dyn SourceDatabase,
+ }
+ impl<'db> ClientCallbacks for Callbacks<'db> {
+ fn handle_sub_request(
+ &mut self,
+ _id: u64,
+ req: SubRequest,
+ ) -> Result<SubResponse, ServerError> {
+ match req {
+ SubRequest::SourceText { file_id, start, end } => {
+ let file = FileId::from_raw(file_id);
+ let text = self.db.file_text(file).text(self.db);
+
+ let slice = text.get(start as usize..end as usize).map(|s| s.to_owned());
+
+ Ok(SubResponse::SourceTextResult { text: slice })
+ }
+ }
+ }
+ }
+
+ let mut callbacks = Callbacks { db };
+
+ let response_payload =
+ run_bidirectional(&proc_macro.process, (0, Kind::Request, task).into(), &mut callbacks)?;
+
+ match response_payload {
+ Payload::Response(Response::ExpandMacro(it)) => Ok(it
+ .map(|tree| {
+ let mut expanded = FlatTree::to_subtree_resolved(tree, version, &span_data_table);
+ if proc_macro.needs_fixup_change() {
+ proc_macro.change_fixup_to_match_old_server(&mut expanded);
+ }
+ expanded
+ })
+ .map_err(|msg| msg.0)),
+ Payload::Response(Response::ExpandMacroExtended(it)) => Ok(it
+ .map(|resp| {
+ let mut expanded = FlatTree::to_subtree_resolved(
+ resp.tree,
+ version,
+ &deserialize_span_data_index_map(&resp.span_data_table),
+ );
+ if proc_macro.needs_fixup_change() {
+ proc_macro.change_fixup_to_match_old_server(&mut expanded);
+ }
+ expanded
+ })
+ .map_err(|msg| msg.0)),
+ _ => Err(ServerError { message: "unexpected response".to_owned(), io: None }),
+ }
+}
+
+fn run_bidirectional(
+ srv: &ProcMacroServerProcess,
+ msg: Envelope,
+ callbacks: &mut dyn ClientCallbacks,
+) -> Result<Payload, ServerError> {
+ if let Some(server_error) = srv.exited() {
+ return Err(server_error.clone());
+ }
+
+ if srv.use_postcard() {
+ srv.run_bidirectional::<PostcardProtocol>(msg.id, msg.payload, callbacks)
+ } else {
+ srv.run_bidirectional::<JsonProtocol>(msg.id, msg.payload, callbacks)
+ }
+}
diff --git a/crates/proc-macro-api/src/bidirectional_protocol/msg.rs b/crates/proc-macro-api/src/bidirectional_protocol/msg.rs
new file mode 100644
index 0000000000..7aed3ae1e6
--- /dev/null
+++ b/crates/proc-macro-api/src/bidirectional_protocol/msg.rs
@@ -0,0 +1,114 @@
+//! Bidirectional protocol messages
+
+use paths::Utf8PathBuf;
+use serde::{Deserialize, Serialize};
+
+use crate::{
+ ProcMacroKind,
+ legacy_protocol::msg::{FlatTree, Message, PanicMessage, ServerConfig},
+};
+
+pub type RequestId = u64;
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Envelope {
+ pub id: RequestId,
+ pub kind: Kind,
+ pub payload: Payload,
+}
+
+impl From<(RequestId, Kind, Payload)> for Envelope {
+ fn from(value: (RequestId, Kind, Payload)) -> Self {
+ Envelope { id: value.0, kind: value.1, payload: value.2 }
+ }
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
+pub enum Kind {
+ Request,
+ Response,
+ SubRequest,
+ SubResponse,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum SubRequest {
+ SourceText { file_id: u32, start: u32, end: u32 },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum SubResponse {
+ SourceTextResult { text: Option<String> },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum Payload {
+ Request(Request),
+ Response(Response),
+ SubRequest(SubRequest),
+ SubResponse(SubResponse),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum Request {
+ ListMacros { dylib_path: Utf8PathBuf },
+ ExpandMacro(Box<ExpandMacro>),
+ ApiVersionCheck {},
+ SetConfig(ServerConfig),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum Response {
+ ListMacros(Result<Vec<(String, ProcMacroKind)>, String>),
+ ExpandMacro(Result<FlatTree, PanicMessage>),
+ ApiVersionCheck(u32),
+ SetConfig(ServerConfig),
+ ExpandMacroExtended(Result<ExpandMacroExtended, PanicMessage>),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ExpandMacro {
+ pub lib: Utf8PathBuf,
+ pub env: Vec<(String, String)>,
+ pub current_dir: Option<String>,
+ #[serde(flatten)]
+ pub data: ExpandMacroData,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ExpandMacroExtended {
+ pub tree: FlatTree,
+ pub span_data_table: Vec<u32>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ExpandMacroData {
+ pub macro_body: FlatTree,
+ pub macro_name: String,
+ pub attributes: Option<FlatTree>,
+ #[serde(skip_serializing_if = "ExpnGlobals::skip_serializing_if")]
+ #[serde(default)]
+ pub has_global_spans: ExpnGlobals,
+
+ #[serde(skip_serializing_if = "Vec::is_empty")]
+ #[serde(default)]
+ pub span_data_table: Vec<u32>,
+}
+
+#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize)]
+pub struct ExpnGlobals {
+ #[serde(skip_serializing)]
+ #[serde(default)]
+ pub serialize: bool,
+ pub def_site: usize,
+ pub call_site: usize,
+ pub mixed_site: usize,
+}
+
+impl ExpnGlobals {
+ fn skip_serializing_if(&self) -> bool {
+ !self.serialize
+ }
+}
+
+impl Message for Envelope {}
diff --git a/crates/proc-macro-api/src/legacy_protocol.rs b/crates/proc-macro-api/src/legacy_protocol.rs
index c2b132ddcc..81a9f39181 100644
--- a/crates/proc-macro-api/src/legacy_protocol.rs
+++ b/crates/proc-macro-api/src/legacy_protocol.rs
@@ -1,30 +1,26 @@
//! The initial proc-macro-srv protocol, soon to be deprecated.
-pub mod json;
pub mod msg;
-pub mod postcard;
use std::{
io::{BufRead, Write},
sync::Arc,
};
+use base_db::SourceDatabase;
use paths::AbsPath;
use span::Span;
use crate::{
ProcMacro, ProcMacroKind, ServerError,
- codec::Codec,
- legacy_protocol::{
- json::JsonProtocol,
- msg::{
- ExpandMacro, ExpandMacroData, ExpnGlobals, FlatTree, Message, Request, Response,
- ServerConfig, SpanDataIndexMap, deserialize_span_data_index_map,
- flat::serialize_span_data_index_map,
- },
- postcard::PostcardProtocol,
+ legacy_protocol::msg::{
+ ExpandMacro, ExpandMacroData, ExpnGlobals, FlatTree, Message, Request, Response,
+ ServerConfig, SpanDataIndexMap, deserialize_span_data_index_map,
+ flat::serialize_span_data_index_map,
},
process::ProcMacroServerProcess,
+ transport::codec::Codec,
+ transport::codec::{json::JsonProtocol, postcard::PostcardProtocol},
version,
};
@@ -82,6 +78,7 @@ pub(crate) fn find_proc_macros(
pub(crate) fn expand(
proc_macro: &ProcMacro,
+ _db: &dyn SourceDatabase,
subtree: tt::SubtreeView<'_, Span>,
attr: Option<tt::SubtreeView<'_, Span>>,
env: Vec<(String, String)>,
@@ -155,9 +152,9 @@ fn send_task(srv: &ProcMacroServerProcess, req: Request) -> Result<Response, Ser
}
if srv.use_postcard() {
- srv.send_task(send_request::<PostcardProtocol>, req)
+ srv.send_task::<_, _, PostcardProtocol>(send_request::<PostcardProtocol>, req)
} else {
- srv.send_task(send_request::<JsonProtocol>, req)
+ srv.send_task::<_, _, JsonProtocol>(send_request::<JsonProtocol>, req)
}
}
diff --git a/crates/proc-macro-api/src/legacy_protocol/msg.rs b/crates/proc-macro-api/src/legacy_protocol/msg.rs
index a6e228d977..0ebb0e9f93 100644
--- a/crates/proc-macro-api/src/legacy_protocol/msg.rs
+++ b/crates/proc-macro-api/src/legacy_protocol/msg.rs
@@ -8,7 +8,7 @@ use paths::Utf8PathBuf;
use serde::de::DeserializeOwned;
use serde_derive::{Deserialize, Serialize};
-use crate::{ProcMacroKind, codec::Codec};
+use crate::{Codec, ProcMacroKind};
/// Represents requests sent from the client to the proc-macro-srv.
#[derive(Debug, Serialize, Deserialize)]
diff --git a/crates/proc-macro-api/src/lib.rs b/crates/proc-macro-api/src/lib.rs
index 85b250eddf..7b9b5b39ab 100644
--- a/crates/proc-macro-api/src/lib.rs
+++ b/crates/proc-macro-api/src/lib.rs
@@ -16,18 +16,19 @@
#[cfg(feature = "in-rust-tree")]
extern crate rustc_driver as _;
-mod codec;
-mod framing;
+pub mod bidirectional_protocol;
pub mod legacy_protocol;
mod process;
+pub mod transport;
+use base_db::SourceDatabase;
use paths::{AbsPath, AbsPathBuf};
use semver::Version;
use span::{ErasedFileAstId, FIXUP_ERASED_FILE_AST_ID_MARKER, Span};
use std::{fmt, io, sync::Arc, time::SystemTime};
-pub use crate::codec::Codec;
use crate::process::ProcMacroServerProcess;
+pub use crate::transport::codec::Codec;
/// The versions of the server protocol
pub mod version {
@@ -218,6 +219,7 @@ impl ProcMacro {
/// This includes span information and environmental context.
pub fn expand(
&self,
+ db: &dyn SourceDatabase,
subtree: tt::SubtreeView<'_, Span>,
attr: Option<tt::SubtreeView<'_, Span>>,
env: Vec<(String, String)>,
@@ -240,7 +242,8 @@ impl ProcMacro {
}
}
- legacy_protocol::expand(
+ self.process.expand(
+ db,
self,
subtree,
attr,
diff --git a/crates/proc-macro-api/src/process.rs b/crates/proc-macro-api/src/process.rs
index d6a8d27bfc..39d9548551 100644
--- a/crates/proc-macro-api/src/process.rs
+++ b/crates/proc-macro-api/src/process.rs
@@ -7,12 +7,18 @@ use std::{
sync::{Arc, Mutex, OnceLock},
};
+use base_db::SourceDatabase;
use paths::AbsPath;
use semver::Version;
+use span::Span;
use stdx::JodChild;
use crate::{
- ProcMacroKind, ServerError,
+ Codec, ProcMacro, ProcMacroKind, ServerError,
+ bidirectional_protocol::{
+ self, ClientCallbacks,
+ msg::{Payload, RequestId},
+ },
legacy_protocol::{self, SpanMode},
version,
};
@@ -33,6 +39,8 @@ pub(crate) struct ProcMacroServerProcess {
pub(crate) enum Protocol {
LegacyJson { mode: SpanMode },
LegacyPostcard { mode: SpanMode },
+ NewPostcard { mode: SpanMode },
+ NewJson { mode: SpanMode },
}
/// Maintains the state of the proc-macro server process.
@@ -62,6 +70,8 @@ impl ProcMacroServerProcess {
&& has_working_format_flag
{
&[
+ (Some("postcard-new"), Protocol::NewPostcard { mode: SpanMode::Id }),
+ (Some("json-new"), Protocol::NewJson { mode: SpanMode::Id }),
(Some("postcard-legacy"), Protocol::LegacyPostcard { mode: SpanMode::Id }),
(Some("json-legacy"), Protocol::LegacyJson { mode: SpanMode::Id }),
]
@@ -105,9 +115,10 @@ impl ProcMacroServerProcess {
&& let Ok(new_mode) = srv.enable_rust_analyzer_spans()
{
match &mut srv.protocol {
- Protocol::LegacyJson { mode } | Protocol::LegacyPostcard { mode } => {
- *mode = new_mode
- }
+ Protocol::LegacyJson { mode }
+ | Protocol::LegacyPostcard { mode }
+ | Protocol::NewJson { mode }
+ | Protocol::NewPostcard { mode } => *mode = new_mode,
}
}
tracing::info!("Proc-macro server protocol: {:?}", srv.protocol);
@@ -143,22 +154,32 @@ impl ProcMacroServerProcess {
match self.protocol {
Protocol::LegacyJson { mode } => mode == SpanMode::RustAnalyzer,
Protocol::LegacyPostcard { mode } => mode == SpanMode::RustAnalyzer,
+ Protocol::NewJson { mode } => mode == SpanMode::RustAnalyzer,
+ Protocol::NewPostcard { mode } => mode == SpanMode::RustAnalyzer,
}
}
/// Checks the API version of the running proc-macro server.
fn version_check(&self) -> Result<u32, ServerError> {
match self.protocol {
- Protocol::LegacyJson { .. } => legacy_protocol::version_check(self),
- Protocol::LegacyPostcard { .. } => legacy_protocol::version_check(self),
+ Protocol::LegacyJson { .. } | Protocol::LegacyPostcard { .. } => {
+ legacy_protocol::version_check(self)
+ }
+ Protocol::NewJson { .. } | Protocol::NewPostcard { .. } => {
+ bidirectional_protocol::version_check(self)
+ }
}
}
/// Enable support for rust-analyzer span mode if the server supports it.
fn enable_rust_analyzer_spans(&self) -> Result<SpanMode, ServerError> {
match self.protocol {
- Protocol::LegacyJson { .. } => legacy_protocol::enable_rust_analyzer_spans(self),
- Protocol::LegacyPostcard { .. } => legacy_protocol::enable_rust_analyzer_spans(self),
+ Protocol::LegacyJson { .. } | Protocol::LegacyPostcard { .. } => {
+ legacy_protocol::enable_rust_analyzer_spans(self)
+ }
+ Protocol::NewJson { .. } | Protocol::NewPostcard { .. } => {
+ bidirectional_protocol::enable_rust_analyzer_spans(self)
+ }
}
}
@@ -168,28 +189,69 @@ impl ProcMacroServerProcess {
dylib_path: &AbsPath,
) -> Result<Result<Vec<(String, ProcMacroKind)>, String>, ServerError> {
match self.protocol {
- Protocol::LegacyJson { .. } => legacy_protocol::find_proc_macros(self, dylib_path),
- Protocol::LegacyPostcard { .. } => legacy_protocol::find_proc_macros(self, dylib_path),
+ Protocol::LegacyJson { .. } | Protocol::LegacyPostcard { .. } => {
+ legacy_protocol::find_proc_macros(self, dylib_path)
+ }
+ Protocol::NewJson { .. } | Protocol::NewPostcard { .. } => {
+ bidirectional_protocol::find_proc_macros(self, dylib_path)
+ }
+ }
+ }
+
+ pub(crate) fn expand(
+ &self,
+ db: &dyn SourceDatabase,
+ proc_macro: &ProcMacro,
+ subtree: tt::SubtreeView<'_, Span>,
+ attr: Option<tt::SubtreeView<'_, Span>>,
+ env: Vec<(String, String)>,
+ def_site: Span,
+ call_site: Span,
+ mixed_site: Span,
+ current_dir: String,
+ ) -> Result<Result<tt::TopSubtree<Span>, String>, ServerError> {
+ match self.protocol {
+ Protocol::LegacyJson { .. } | Protocol::LegacyPostcard { .. } => {
+ legacy_protocol::expand(
+ proc_macro,
+ db,
+ subtree,
+ attr,
+ env,
+ def_site,
+ call_site,
+ mixed_site,
+ current_dir,
+ )
+ }
+ Protocol::NewJson { .. } | Protocol::NewPostcard { .. } => {
+ bidirectional_protocol::expand(
+ proc_macro,
+ db,
+ subtree,
+ attr,
+ env,
+ def_site,
+ call_site,
+ mixed_site,
+ current_dir,
+ )
+ }
}
}
- pub(crate) fn send_task<Request, Response, Buf>(
+ pub(crate) fn send_task<Request, Response, C: Codec>(
&self,
- serialize_req: impl FnOnce(
+ send: impl FnOnce(
&mut dyn Write,
&mut dyn BufRead,
Request,
- &mut Buf,
+ &mut C::Buf,
) -> Result<Option<Response>, ServerError>,
req: Request,
- ) -> Result<Response, ServerError>
- where
- Buf: Default,
- {
- let state = &mut *self.state.lock().unwrap();
- let mut buf = Buf::default();
- serialize_req(&mut state.stdin, &mut state.stdout, req, &mut buf)
- .and_then(|res| {
+ ) -> Result<Response, ServerError> {
+ self.with_locked_io::<C, _>(|writer, reader, buf| {
+ send(writer, reader, req, buf).and_then(|res| {
res.ok_or_else(|| {
let message = "proc-macro server did not respond with data".to_owned();
ServerError {
@@ -201,33 +263,54 @@ impl ProcMacroServerProcess {
}
})
})
- .map_err(|e| {
- if e.io.as_ref().map(|it| it.kind()) == Some(io::ErrorKind::BrokenPipe) {
- match state.process.child.try_wait() {
- Ok(None) | Err(_) => e,
- Ok(Some(status)) => {
- let mut msg = String::new();
- if !status.success()
- && let Some(stderr) = state.process.child.stderr.as_mut()
- {
- _ = stderr.read_to_string(&mut msg);
- }
- let server_error = ServerError {
- message: format!(
- "proc-macro server exited with {status}{}{msg}",
- if msg.is_empty() { "" } else { ": " }
- ),
- io: None,
- };
- // `AssertUnwindSafe` is fine here, we already correct initialized
- // server_error at this point.
- self.exited.get_or_init(|| AssertUnwindSafe(server_error)).0.clone()
+ })
+ }
+
+ pub(crate) fn with_locked_io<C: Codec, R>(
+ &self,
+ f: impl FnOnce(&mut dyn Write, &mut dyn BufRead, &mut C::Buf) -> Result<R, ServerError>,
+ ) -> Result<R, ServerError> {
+ let state = &mut *self.state.lock().unwrap();
+ let mut buf = C::Buf::default();
+
+ f(&mut state.stdin, &mut state.stdout, &mut buf).map_err(|e| {
+ if e.io.as_ref().map(|it| it.kind()) == Some(io::ErrorKind::BrokenPipe) {
+ match state.process.child.try_wait() {
+ Ok(None) | Err(_) => e,
+ Ok(Some(status)) => {
+ let mut msg = String::new();
+ if !status.success()
+ && let Some(stderr) = state.process.child.stderr.as_mut()
+ {
+ _ = stderr.read_to_string(&mut msg);
}
+ let server_error = ServerError {
+ message: format!(
+ "proc-macro server exited with {status}{}{msg}",
+ if msg.is_empty() { "" } else { ": " }
+ ),
+ io: None,
+ };
+ self.exited.get_or_init(|| AssertUnwindSafe(server_error)).0.clone()
}
- } else {
- e
}
- })
+ } else {
+ e
+ }
+ })
+ }
+
+ pub(crate) fn run_bidirectional<C: Codec>(
+ &self,
+ id: RequestId,
+ initial: Payload,
+ callbacks: &mut dyn ClientCallbacks,
+ ) -> Result<Payload, ServerError> {
+ self.with_locked_io::<C, _>(|writer, reader, buf| {
+ bidirectional_protocol::run_conversation::<C>(
+ writer, reader, buf, id, initial, callbacks,
+ )
+ })
}
}
diff --git a/crates/proc-macro-api/src/transport.rs b/crates/proc-macro-api/src/transport.rs
new file mode 100644
index 0000000000..b7a1d8f732
--- /dev/null
+++ b/crates/proc-macro-api/src/transport.rs
@@ -0,0 +1,3 @@
+//! Contains construct for transport of messages.
+pub mod codec;
+pub mod framing;
diff --git a/crates/proc-macro-api/src/codec.rs b/crates/proc-macro-api/src/transport/codec.rs
index baccaa6be4..c9afad260a 100644
--- a/crates/proc-macro-api/src/codec.rs
+++ b/crates/proc-macro-api/src/transport/codec.rs
@@ -4,7 +4,10 @@ use std::io;
use serde::de::DeserializeOwned;
-use crate::framing::Framing;
+use crate::transport::framing::Framing;
+
+pub mod json;
+pub mod postcard;
pub trait Codec: Framing {
fn encode<T: serde::Serialize>(msg: &T) -> io::Result<Self::Buf>;
diff --git a/crates/proc-macro-api/src/legacy_protocol/json.rs b/crates/proc-macro-api/src/transport/codec/json.rs
index 1359c05684..96db802e0b 100644
--- a/crates/proc-macro-api/src/legacy_protocol/json.rs
+++ b/crates/proc-macro-api/src/transport/codec/json.rs
@@ -3,14 +3,14 @@ use std::io::{self, BufRead, Write};
use serde::{Serialize, de::DeserializeOwned};
-use crate::{codec::Codec, framing::Framing};
+use crate::{Codec, transport::framing::Framing};
pub struct JsonProtocol;
impl Framing for JsonProtocol {
type Buf = String;
- fn read<'a, R: BufRead>(
+ fn read<'a, R: BufRead + ?Sized>(
inp: &mut R,
buf: &'a mut String,
) -> io::Result<Option<&'a mut String>> {
@@ -35,7 +35,7 @@ impl Framing for JsonProtocol {
}
}
- fn write<W: Write>(out: &mut W, buf: &String) -> io::Result<()> {
+ fn write<W: Write + ?Sized>(out: &mut W, buf: &String) -> io::Result<()> {
tracing::debug!("> {}", buf);
out.write_all(buf.as_bytes())?;
out.write_all(b"\n")?;
diff --git a/crates/proc-macro-api/src/legacy_protocol/postcard.rs b/crates/proc-macro-api/src/transport/codec/postcard.rs
index c28a9bfe3a..6f5319e75b 100644
--- a/crates/proc-macro-api/src/legacy_protocol/postcard.rs
+++ b/crates/proc-macro-api/src/transport/codec/postcard.rs
@@ -4,14 +4,14 @@ use std::io::{self, BufRead, Write};
use serde::{Serialize, de::DeserializeOwned};
-use crate::{codec::Codec, framing::Framing};
+use crate::{Codec, transport::framing::Framing};
pub struct PostcardProtocol;
impl Framing for PostcardProtocol {
type Buf = Vec<u8>;
- fn read<'a, R: BufRead>(
+ fn read<'a, R: BufRead + ?Sized>(
inp: &mut R,
buf: &'a mut Vec<u8>,
) -> io::Result<Option<&'a mut Vec<u8>>> {
@@ -23,7 +23,7 @@ impl Framing for PostcardProtocol {
Ok(Some(buf))
}
- fn write<W: Write>(out: &mut W, buf: &Vec<u8>) -> io::Result<()> {
+ fn write<W: Write + ?Sized>(out: &mut W, buf: &Vec<u8>) -> io::Result<()> {
out.write_all(buf)?;
out.flush()
}
diff --git a/crates/proc-macro-api/src/framing.rs b/crates/proc-macro-api/src/transport/framing.rs
index a1e6fc05ca..2a11eb19c3 100644
--- a/crates/proc-macro-api/src/framing.rs
+++ b/crates/proc-macro-api/src/transport/framing.rs
@@ -5,10 +5,10 @@ use std::io::{self, BufRead, Write};
pub trait Framing {
type Buf: Default;
- fn read<'a, R: BufRead>(
+ fn read<'a, R: BufRead + ?Sized>(
inp: &mut R,
buf: &'a mut Self::Buf,
) -> io::Result<Option<&'a mut Self::Buf>>;
- fn write<W: Write>(out: &mut W, buf: &Self::Buf) -> io::Result<()>;
+ fn write<W: Write + ?Sized>(out: &mut W, buf: &Self::Buf) -> io::Result<()>;
}
diff --git a/crates/proc-macro-srv-cli/Cargo.toml b/crates/proc-macro-srv-cli/Cargo.toml
index aa153897fa..298592ee47 100644
--- a/crates/proc-macro-srv-cli/Cargo.toml
+++ b/crates/proc-macro-srv-cli/Cargo.toml
@@ -15,6 +15,7 @@ proc-macro-srv.workspace = true
proc-macro-api.workspace = true
tt.workspace = true
postcard.workspace = true
+crossbeam-channel.workspace = true
clap = {version = "4.5.42", default-features = false, features = ["std"]}
[features]
diff --git a/crates/proc-macro-srv-cli/src/main.rs b/crates/proc-macro-srv-cli/src/main.rs
index 813ac339a9..d3dae0494f 100644
--- a/crates/proc-macro-srv-cli/src/main.rs
+++ b/crates/proc-macro-srv-cli/src/main.rs
@@ -52,6 +52,8 @@ fn main() -> std::io::Result<()> {
enum ProtocolFormat {
JsonLegacy,
PostcardLegacy,
+ JsonNew,
+ PostcardNew,
}
impl ValueEnum for ProtocolFormat {
@@ -65,12 +67,16 @@ impl ValueEnum for ProtocolFormat {
ProtocolFormat::PostcardLegacy => {
Some(clap::builder::PossibleValue::new("postcard-legacy"))
}
+ ProtocolFormat::JsonNew => Some(clap::builder::PossibleValue::new("json-new")),
+ ProtocolFormat::PostcardNew => Some(clap::builder::PossibleValue::new("postcard-new")),
}
}
fn from_str(input: &str, _ignore_case: bool) -> Result<Self, String> {
match input {
"json-legacy" => Ok(ProtocolFormat::JsonLegacy),
"postcard-legacy" => Ok(ProtocolFormat::PostcardLegacy),
+ "postcard-new" => Ok(ProtocolFormat::PostcardNew),
+ "json-new" => Ok(ProtocolFormat::JsonNew),
_ => Err(format!("unknown protocol format: {input}")),
}
}
diff --git a/crates/proc-macro-srv-cli/src/main_loop.rs b/crates/proc-macro-srv-cli/src/main_loop.rs
index df54f38cbc..e543260964 100644
--- a/crates/proc-macro-srv-cli/src/main_loop.rs
+++ b/crates/proc-macro-srv-cli/src/main_loop.rs
@@ -1,16 +1,16 @@
//! The main loop of the proc-macro server.
use std::io;
+use crossbeam_channel::unbounded;
+use proc_macro_api::bidirectional_protocol::msg::Request;
use proc_macro_api::{
Codec,
- legacy_protocol::{
- json::JsonProtocol,
- msg::{
- self, ExpandMacroData, ExpnGlobals, Message, SpanMode, SpanTransformer,
- deserialize_span_data_index_map, serialize_span_data_index_map,
- },
- postcard::PostcardProtocol,
+ bidirectional_protocol::msg::{Envelope, Kind, Payload},
+ legacy_protocol::msg::{
+ self, ExpandMacroData, ExpnGlobals, Message, SpanMode, SpanTransformer,
+ deserialize_span_data_index_map, serialize_span_data_index_map,
},
+ transport::codec::{json::JsonProtocol, postcard::PostcardProtocol},
version::CURRENT_API_VERSION,
};
use proc_macro_srv::{EnvSnapshot, SpanId};
@@ -39,9 +39,280 @@ pub(crate) fn run(format: ProtocolFormat) -> io::Result<()> {
match format {
ProtocolFormat::JsonLegacy => run_::<JsonProtocol>(),
ProtocolFormat::PostcardLegacy => run_::<PostcardProtocol>(),
+ ProtocolFormat::JsonNew => run_new::<JsonProtocol>(),
+ ProtocolFormat::PostcardNew => run_new::<PostcardProtocol>(),
}
}
+fn run_new<C: Codec>() -> io::Result<()> {
+ fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::ProcMacroKind {
+ match kind {
+ proc_macro_srv::ProcMacroKind::CustomDerive => {
+ proc_macro_api::ProcMacroKind::CustomDerive
+ }
+ proc_macro_srv::ProcMacroKind::Bang => proc_macro_api::ProcMacroKind::Bang,
+ proc_macro_srv::ProcMacroKind::Attr => proc_macro_api::ProcMacroKind::Attr,
+ }
+ }
+
+ let mut buf = C::Buf::default();
+ let mut stdin = io::stdin().lock();
+ let mut stdout = io::stdout().lock();
+
+ let env_snapshot = EnvSnapshot::default();
+ let srv = proc_macro_srv::ProcMacroSrv::new(&env_snapshot);
+
+ let mut span_mode = SpanMode::Id;
+
+ 'outer: loop {
+ let req_opt = Envelope::read::<_, C>(&mut stdin, &mut buf)?;
+ let Some(req) = req_opt else {
+ break 'outer;
+ };
+
+ match (req.kind, req.payload) {
+ (Kind::Request, Payload::Request(request)) => match request {
+ Request::ListMacros { dylib_path } => {
+ let res = srv.list_macros(&dylib_path).map(|macros| {
+ macros
+ .into_iter()
+ .map(|(name, kind)| (name, macro_kind_to_api(kind)))
+ .collect()
+ });
+
+ let resp_env = Envelope {
+ id: req.id,
+ kind: Kind::Response,
+ payload: Payload::Response(
+ proc_macro_api::bidirectional_protocol::msg::Response::ListMacros(res),
+ ),
+ };
+
+ resp_env.write::<_, C>(&mut stdout)?;
+ }
+
+ Request::ApiVersionCheck {} => {
+ let resp_env = Envelope {
+ id: req.id,
+ kind: Kind::Response,
+ payload: Payload::Response(
+ proc_macro_api::bidirectional_protocol::msg::Response::ApiVersionCheck(
+ CURRENT_API_VERSION,
+ ),
+ ),
+ };
+ resp_env.write::<_, C>(&mut stdout)?;
+ }
+
+ Request::SetConfig(config) => {
+ span_mode = config.span_mode;
+ let resp_env = Envelope {
+ id: req.id,
+ kind: Kind::Response,
+ payload: Payload::Response(
+ proc_macro_api::bidirectional_protocol::msg::Response::SetConfig(
+ config,
+ ),
+ ),
+ };
+ resp_env.write::<_, C>(&mut stdout)?;
+ }
+
+ Request::ExpandMacro(task) => {
+ let proc_macro_api::bidirectional_protocol::msg::ExpandMacro {
+ lib,
+ env,
+ current_dir,
+ data:
+ proc_macro_api::bidirectional_protocol::msg::ExpandMacroData {
+ macro_body,
+ macro_name,
+ attributes,
+ has_global_spans:
+ proc_macro_api::bidirectional_protocol::msg::ExpnGlobals {
+ serialize: _,
+ def_site,
+ call_site,
+ mixed_site,
+ },
+ span_data_table,
+ },
+ } = *task;
+
+ match span_mode {
+ SpanMode::Id => {
+ let def_site = SpanId(def_site as u32);
+ let call_site = SpanId(call_site as u32);
+ let mixed_site = SpanId(mixed_site as u32);
+
+ let macro_body = macro_body.to_tokenstream_unresolved::<SpanTrans>(
+ CURRENT_API_VERSION,
+ |_, b| b,
+ );
+ let attributes = attributes.map(|it| {
+ it.to_tokenstream_unresolved::<SpanTrans>(
+ CURRENT_API_VERSION,
+ |_, b| b,
+ )
+ });
+
+ let res = srv
+ .expand(
+ lib,
+ &env,
+ current_dir,
+ &macro_name,
+ macro_body,
+ attributes,
+ def_site,
+ call_site,
+ mixed_site,
+ )
+ .map(|it| {
+ msg::FlatTree::from_tokenstream_raw::<SpanTrans>(
+ it,
+ call_site,
+ CURRENT_API_VERSION,
+ )
+ })
+ .map_err(|e| e.into_string().unwrap_or_default())
+ .map_err(msg::PanicMessage);
+
+ let resp_env = Envelope {
+ id: req.id,
+ kind: Kind::Response,
+ payload: Payload::Response(proc_macro_api::bidirectional_protocol::msg::Response::ExpandMacro(res)),
+ };
+
+ resp_env.write::<_, C>(&mut stdout)?;
+ }
+
+ SpanMode::RustAnalyzer => {
+ let mut span_data_table =
+ deserialize_span_data_index_map(&span_data_table);
+
+ let def_site_span = span_data_table[def_site];
+ let call_site_span = span_data_table[call_site];
+ let mixed_site_span = span_data_table[mixed_site];
+
+ let macro_body_ts = macro_body.to_tokenstream_resolved(
+ CURRENT_API_VERSION,
+ &span_data_table,
+ |a, b| srv.join_spans(a, b).unwrap_or(b),
+ );
+ let attributes_ts = attributes.map(|it| {
+ it.to_tokenstream_resolved(
+ CURRENT_API_VERSION,
+ &span_data_table,
+ |a, b| srv.join_spans(a, b).unwrap_or(b),
+ )
+ });
+
+ let (subreq_tx, subreq_rx) = unbounded::<proc_macro_srv::SubRequest>();
+ let (subresp_tx, subresp_rx) =
+ unbounded::<proc_macro_srv::SubResponse>();
+ let (result_tx, result_rx) = crossbeam_channel::bounded(1);
+
+ std::thread::scope(|scope| {
+ let srv_ref = &srv;
+
+ scope.spawn({
+ let lib = lib.clone();
+ let env = env.clone();
+ let current_dir = current_dir.clone();
+ let macro_name = macro_name.clone();
+ move || {
+ let res = srv_ref
+ .expand_with_channels(
+ lib,
+ &env,
+ current_dir,
+ &macro_name,
+ macro_body_ts,
+ attributes_ts,
+ def_site_span,
+ call_site_span,
+ mixed_site_span,
+ subresp_rx,
+ subreq_tx,
+ )
+ .map(|it| {
+ (
+ msg::FlatTree::from_tokenstream(
+ it,
+ CURRENT_API_VERSION,
+ call_site_span,
+ &mut span_data_table,
+ ),
+ serialize_span_data_index_map(&span_data_table),
+ )
+ })
+ .map(|(tree, span_data_table)| {
+ proc_macro_api::bidirectional_protocol::msg::ExpandMacroExtended { tree, span_data_table }
+ })
+ .map_err(|e| e.into_string().unwrap_or_default())
+ .map_err(msg::PanicMessage);
+ let _ = result_tx.send(res);
+ }
+ });
+
+ loop {
+ if let Ok(res) = result_rx.try_recv() {
+ let resp_env = Envelope {
+ id: req.id,
+ kind: Kind::Response,
+ payload: Payload::Response(
+ proc_macro_api::bidirectional_protocol::msg::Response::ExpandMacroExtended(res),
+ ),
+ };
+ resp_env.write::<_, C>(&mut stdout).unwrap();
+ break;
+ }
+
+ let subreq = match subreq_rx.recv() {
+ Ok(r) => r,
+ Err(_) => {
+ break;
+ }
+ };
+
+ let sub_env = Envelope {
+ id: req.id,
+ kind: Kind::SubRequest,
+ payload: Payload::SubRequest(from_srv_req(subreq)),
+ };
+ sub_env.write::<_, C>(&mut stdout).unwrap();
+
+ let resp_opt =
+ Envelope::read::<_, C>(&mut stdin, &mut buf).unwrap();
+ let resp = match resp_opt {
+ Some(env) => env,
+ None => {
+ break;
+ }
+ };
+
+ match (resp.kind, resp.payload) {
+ (Kind::SubResponse, Payload::SubResponse(subresp)) => {
+ let _ = subresp_tx.send(from_client_res(subresp));
+ }
+ _ => {
+ break;
+ }
+ }
+ }
+ });
+ }
+ }
+ }
+ },
+ _ => {}
+ }
+ }
+
+ Ok(())
+}
+
fn run_<C: Codec>() -> io::Result<()> {
fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::ProcMacroKind {
match kind {
@@ -178,3 +449,27 @@ fn run_<C: Codec>() -> io::Result<()> {
Ok(())
}
+
+fn from_srv_req(
+ value: proc_macro_srv::SubRequest,
+) -> proc_macro_api::bidirectional_protocol::msg::SubRequest {
+ match value {
+ proc_macro_srv::SubRequest::SourceText { file_id, start, end } => {
+ proc_macro_api::bidirectional_protocol::msg::SubRequest::SourceText {
+ file_id: file_id.file_id().index(),
+ start,
+ end,
+ }
+ }
+ }
+}
+
+fn from_client_res(
+ value: proc_macro_api::bidirectional_protocol::msg::SubResponse,
+) -> proc_macro_srv::SubResponse {
+ match value {
+ proc_macro_api::bidirectional_protocol::msg::SubResponse::SourceTextResult { text } => {
+ proc_macro_srv::SubResponse::SourceTextResult { text }
+ }
+ }
+}
diff --git a/crates/proc-macro-srv/Cargo.toml b/crates/proc-macro-srv/Cargo.toml
index 3610171784..b2abda0bfd 100644
--- a/crates/proc-macro-srv/Cargo.toml
+++ b/crates/proc-macro-srv/Cargo.toml
@@ -22,6 +22,7 @@ paths.workspace = true
# span = {workspace = true, default-features = false} does not work
span = { path = "../span", version = "0.0.0", default-features = false}
intern.workspace = true
+crossbeam-channel.workspace = true
ra-ap-rustc_lexer.workspace = true
diff --git a/crates/proc-macro-srv/src/dylib.rs b/crates/proc-macro-srv/src/dylib.rs
index 03433197b7..ba089c9549 100644
--- a/crates/proc-macro-srv/src/dylib.rs
+++ b/crates/proc-macro-srv/src/dylib.rs
@@ -54,6 +54,32 @@ impl Expander {
.expand(macro_name, macro_body, attribute, def_site, call_site, mixed_site)
}
+ pub(crate) fn expand_with_channels<S: ProcMacroSrvSpan>(
+ &self,
+ macro_name: &str,
+ macro_body: crate::token_stream::TokenStream<S>,
+ attribute: Option<crate::token_stream::TokenStream<S>>,
+ def_site: S,
+ call_site: S,
+ mixed_site: S,
+ cli_to_server: crossbeam_channel::Receiver<crate::SubResponse>,
+ server_to_cli: crossbeam_channel::Sender<crate::SubRequest>,
+ ) -> Result<crate::token_stream::TokenStream<S>, crate::PanicMessage>
+ where
+ <S::Server as proc_macro::bridge::server::Types>::TokenStream: Default,
+ {
+ self.inner.proc_macros.expand_with_channels(
+ macro_name,
+ macro_body,
+ attribute,
+ def_site,
+ call_site,
+ mixed_site,
+ cli_to_server,
+ server_to_cli,
+ )
+ }
+
pub(crate) fn list_macros(&self) -> impl Iterator<Item = (&str, ProcMacroKind)> {
self.inner.proc_macros.list_macros()
}
diff --git a/crates/proc-macro-srv/src/dylib/proc_macros.rs b/crates/proc-macro-srv/src/dylib/proc_macros.rs
index c879c7609d..5b6f1cf2f3 100644
--- a/crates/proc-macro-srv/src/dylib/proc_macros.rs
+++ b/crates/proc-macro-srv/src/dylib/proc_macros.rs
@@ -1,5 +1,4 @@
//! Proc macro ABI
-
use proc_macro::bridge;
use crate::{ProcMacroKind, ProcMacroSrvSpan, token_stream::TokenStream};
@@ -32,7 +31,65 @@ impl ProcMacros {
{
let res = client.run(
&bridge::server::SameThread,
- S::make_server(call_site, def_site, mixed_site),
+ S::make_server(call_site, def_site, mixed_site, None, None),
+ macro_body,
+ cfg!(debug_assertions),
+ );
+ return res.map_err(crate::PanicMessage::from);
+ }
+ bridge::client::ProcMacro::Bang { name, client } if *name == macro_name => {
+ let res = client.run(
+ &bridge::server::SameThread,
+ S::make_server(call_site, def_site, mixed_site, None, None),
+ macro_body,
+ cfg!(debug_assertions),
+ );
+ return res.map_err(crate::PanicMessage::from);
+ }
+ bridge::client::ProcMacro::Attr { name, client } if *name == macro_name => {
+ let res = client.run(
+ &bridge::server::SameThread,
+ S::make_server(call_site, def_site, mixed_site, None, None),
+ parsed_attributes,
+ macro_body,
+ cfg!(debug_assertions),
+ );
+ return res.map_err(crate::PanicMessage::from);
+ }
+ _ => continue,
+ }
+ }
+
+ Err(bridge::PanicMessage::String(format!("proc-macro `{macro_name}` is missing")).into())
+ }
+
+ pub(crate) fn expand_with_channels<S: ProcMacroSrvSpan>(
+ &self,
+ macro_name: &str,
+ macro_body: TokenStream<S>,
+ attribute: Option<TokenStream<S>>,
+ def_site: S,
+ call_site: S,
+ mixed_site: S,
+ cli_to_server: crossbeam_channel::Receiver<crate::SubResponse>,
+ server_to_cli: crossbeam_channel::Sender<crate::SubRequest>,
+ ) -> Result<TokenStream<S>, crate::PanicMessage> {
+ let parsed_attributes = attribute.unwrap_or_default();
+
+ for proc_macro in &self.0 {
+ match proc_macro {
+ bridge::client::ProcMacro::CustomDerive { trait_name, client, .. }
+ if *trait_name == macro_name =>
+ {
+ let res = client.run(
+ &bridge::server::SameThread,
+ S::make_server(
+ call_site,
+ def_site,
+ mixed_site,
+ Some(cli_to_server),
+ Some(server_to_cli),
+ ),
macro_body,
cfg!(debug_assertions),
);
@@ -41,7 +98,13 @@ impl ProcMacros {
bridge::client::ProcMacro::Bang { name, client } if *name == macro_name => {
let res = client.run(
&bridge::server::SameThread,
- S::make_server(call_site, def_site, mixed_site),
+ S::make_server(
+ call_site,
+ def_site,
+ mixed_site,
+ Some(cli_to_server),
+ Some(server_to_cli),
+ ),
macro_body,
cfg!(debug_assertions),
);
@@ -50,7 +113,13 @@ impl ProcMacros {
bridge::client::ProcMacro::Attr { name, client } if *name == macro_name => {
let res = client.run(
&bridge::server::SameThread,
- S::make_server(call_site, def_site, mixed_site),
+ S::make_server(
+ call_site,
+ def_site,
+ mixed_site,
+ Some(cli_to_server),
+ Some(server_to_cli),
+ ),
parsed_attributes,
macro_body,
cfg!(debug_assertions),
diff --git a/crates/proc-macro-srv/src/lib.rs b/crates/proc-macro-srv/src/lib.rs
index 93319df824..f369ab93a2 100644
--- a/crates/proc-macro-srv/src/lib.rs
+++ b/crates/proc-macro-srv/src/lib.rs
@@ -47,7 +47,7 @@ use std::{
};
use paths::{Utf8Path, Utf8PathBuf};
-use span::Span;
+use span::{EditionedFileId, Span};
use temp_dir::TempDir;
pub use crate::server_impl::token_id::SpanId;
@@ -91,6 +91,14 @@ impl<'env> ProcMacroSrv<'env> {
}
}
+pub enum SubRequest {
+ SourceText { file_id: EditionedFileId, start: u32, end: u32 },
+}
+
+pub enum SubResponse {
+ SourceTextResult { text: Option<String> },
+}
+
const EXPANDER_STACK_SIZE: usize = 8 * 1024 * 1024;
impl ProcMacroSrv<'_> {
@@ -133,6 +141,53 @@ impl ProcMacroSrv<'_> {
result
}
+ pub fn expand_with_channels<S: ProcMacroSrvSpan>(
+ &self,
+ lib: impl AsRef<Utf8Path>,
+ env: &[(String, String)],
+ current_dir: Option<impl AsRef<Path>>,
+ macro_name: &str,
+ macro_body: token_stream::TokenStream<S>,
+ attribute: Option<token_stream::TokenStream<S>>,
+ def_site: S,
+ call_site: S,
+ mixed_site: S,
+ cli_to_server: crossbeam_channel::Receiver<SubResponse>,
+ server_to_cli: crossbeam_channel::Sender<SubRequest>,
+ ) -> Result<token_stream::TokenStream<S>, PanicMessage> {
+ let snapped_env = self.env;
+ let expander = self.expander(lib.as_ref()).map_err(|err| PanicMessage {
+ message: Some(format!("failed to load macro: {err}")),
+ })?;
+
+ let prev_env = EnvChange::apply(snapped_env, env, current_dir.as_ref().map(<_>::as_ref));
+
+ let result = thread::scope(|s| {
+ let thread = thread::Builder::new()
+ .stack_size(EXPANDER_STACK_SIZE)
+ .name(macro_name.to_owned())
+ .spawn_scoped(s, move || {
+ expander.expand_with_channels(
+ macro_name,
+ macro_body,
+ attribute,
+ def_site,
+ call_site,
+ mixed_site,
+ cli_to_server,
+ server_to_cli,
+ )
+ });
+ match thread.unwrap().join() {
+ Ok(res) => res,
+ Err(e) => std::panic::resume_unwind(e),
+ }
+ });
+ prev_env.rollback();
+
+ result
+ }
+
pub fn list_macros(
&self,
dylib_path: &Utf8Path,
@@ -170,31 +225,54 @@ impl ProcMacroSrv<'_> {
pub trait ProcMacroSrvSpan: Copy + Send + Sync {
type Server: proc_macro::bridge::server::Server<TokenStream = crate::token_stream::TokenStream<Self>>;
- fn make_server(call_site: Self, def_site: Self, mixed_site: Self) -> Self::Server;
+ fn make_server(
+ call_site: Self,
+ def_site: Self,
+ mixed_site: Self,
+ cli_to_server: Option<crossbeam_channel::Receiver<SubResponse>>,
+ server_to_cli: Option<crossbeam_channel::Sender<SubRequest>>,
+ ) -> Self::Server;
}
impl ProcMacroSrvSpan for SpanId {
type Server = server_impl::token_id::SpanIdServer;
- fn make_server(call_site: Self, def_site: Self, mixed_site: Self) -> Self::Server {
+ fn make_server(
+ call_site: Self,
+ def_site: Self,
+ mixed_site: Self,
+ cli_to_server: Option<crossbeam_channel::Receiver<SubResponse>>,
+ server_to_cli: Option<crossbeam_channel::Sender<SubRequest>>,
+ ) -> Self::Server {
Self::Server {
call_site,
def_site,
mixed_site,
+ cli_to_server,
+ server_to_cli,
tracked_env_vars: Default::default(),
tracked_paths: Default::default(),
}
}
}
+
impl ProcMacroSrvSpan for Span {
type Server = server_impl::rust_analyzer_span::RaSpanServer;
- fn make_server(call_site: Self, def_site: Self, mixed_site: Self) -> Self::Server {
+ fn make_server(
+ call_site: Self,
+ def_site: Self,
+ mixed_site: Self,
+ cli_to_server: Option<crossbeam_channel::Receiver<SubResponse>>,
+ server_to_cli: Option<crossbeam_channel::Sender<SubRequest>>,
+ ) -> Self::Server {
Self::Server {
call_site,
def_site,
mixed_site,
tracked_env_vars: Default::default(),
tracked_paths: Default::default(),
+ cli_to_server,
+ server_to_cli,
}
}
}
diff --git a/crates/proc-macro-srv/src/server_impl/rust_analyzer_span.rs b/crates/proc-macro-srv/src/server_impl/rust_analyzer_span.rs
index 7c685c2da7..1a8f6d6730 100644
--- a/crates/proc-macro-srv/src/server_impl/rust_analyzer_span.rs
+++ b/crates/proc-macro-srv/src/server_impl/rust_analyzer_span.rs
@@ -14,6 +14,7 @@ use proc_macro::bridge::server;
use span::{FIXUP_ERASED_FILE_AST_ID_MARKER, Span, TextRange, TextSize};
use crate::{
+ SubRequest, SubResponse,
bridge::{Diagnostic, ExpnGlobals, Literal, TokenTree},
server_impl::literal_from_str,
};
@@ -28,6 +29,8 @@ pub struct RaSpanServer {
pub call_site: Span,
pub def_site: Span,
pub mixed_site: Span,
+ pub cli_to_server: Option<crossbeam_channel::Receiver<SubResponse>>,
+ pub server_to_cli: Option<crossbeam_channel::Sender<SubRequest>>,
}
impl server::Types for RaSpanServer {
@@ -149,9 +152,26 @@ impl server::Span for RaSpanServer {
///
/// See PR:
/// https://github.com/rust-lang/rust/pull/55780
- fn source_text(&mut self, _span: Self::Span) -> Option<String> {
+ fn source_text(&mut self, span: Self::Span) -> Option<String> {
// FIXME requires db, needs special handling wrt fixup spans
- None
+ if self.server_to_cli.is_some() && self.cli_to_server.is_some() {
+ let file_id = span.anchor.file_id;
+ let start: u32 = span.range.start().into();
+ let end: u32 = span.range.end().into();
+ let _ = self.server_to_cli.clone().unwrap().send(SubRequest::SourceText {
+ file_id,
+ start,
+ end,
+ });
+ self.cli_to_server
+ .clone()
+ .unwrap()
+ .recv()
+ .and_then(|SubResponse::SourceTextResult { text }| Ok(text))
+ .expect("REASON")
+ } else {
+ None
+ }
}
fn parent(&mut self, _span: Self::Span) -> Option<Self::Span> {
diff --git a/crates/proc-macro-srv/src/server_impl/token_id.rs b/crates/proc-macro-srv/src/server_impl/token_id.rs
index 5ac263b9d5..268042b3bc 100644
--- a/crates/proc-macro-srv/src/server_impl/token_id.rs
+++ b/crates/proc-macro-srv/src/server_impl/token_id.rs
@@ -9,6 +9,7 @@ use intern::Symbol;
use proc_macro::bridge::server;
use crate::{
+ SubRequest, SubResponse,
bridge::{Diagnostic, ExpnGlobals, Literal, TokenTree},
server_impl::literal_from_str,
};
@@ -34,6 +35,8 @@ pub struct SpanIdServer {
pub call_site: Span,
pub def_site: Span,
pub mixed_site: Span,
+ pub cli_to_server: Option<crossbeam_channel::Receiver<SubResponse>>,
+ pub server_to_cli: Option<crossbeam_channel::Sender<SubRequest>>,
}
impl server::Types for SpanIdServer {
@@ -139,6 +142,7 @@ impl server::Span for SpanIdServer {
/// See PR:
/// https://github.com/rust-lang/rust/pull/55780
fn source_text(&mut self, _span: Self::Span) -> Option<String> {
+ // FIXME requires db, needs special handling wrt fixup spans
None
}
diff --git a/crates/test-fixture/src/lib.rs b/crates/test-fixture/src/lib.rs
index 5e8b250c24..e08a65c392 100644
--- a/crates/test-fixture/src/lib.rs
+++ b/crates/test-fixture/src/lib.rs
@@ -732,6 +732,7 @@ struct IdentityProcMacroExpander;
impl ProcMacroExpander for IdentityProcMacroExpander {
fn expand(
&self,
+ _: &dyn SourceDatabase,
subtree: &TopSubtree,
_: Option<&TopSubtree>,
_: &Env,
@@ -754,6 +755,7 @@ struct Issue18089ProcMacroExpander;
impl ProcMacroExpander for Issue18089ProcMacroExpander {
fn expand(
&self,
+ _: &dyn SourceDatabase,
subtree: &TopSubtree,
_: Option<&TopSubtree>,
_: &Env,
@@ -789,6 +791,7 @@ struct AttributeInputReplaceProcMacroExpander;
impl ProcMacroExpander for AttributeInputReplaceProcMacroExpander {
fn expand(
&self,
+ _: &dyn SourceDatabase,
_: &TopSubtree,
attrs: Option<&TopSubtree>,
_: &Env,
@@ -812,6 +815,7 @@ struct Issue18840ProcMacroExpander;
impl ProcMacroExpander for Issue18840ProcMacroExpander {
fn expand(
&self,
+ _: &dyn SourceDatabase,
fn_: &TopSubtree,
_: Option<&TopSubtree>,
_: &Env,
@@ -847,6 +851,7 @@ struct MirrorProcMacroExpander;
impl ProcMacroExpander for MirrorProcMacroExpander {
fn expand(
&self,
+ _: &dyn SourceDatabase,
input: &TopSubtree,
_: Option<&TopSubtree>,
_: &Env,
@@ -885,6 +890,7 @@ struct ShortenProcMacroExpander;
impl ProcMacroExpander for ShortenProcMacroExpander {
fn expand(
&self,
+ _: &dyn SourceDatabase,
input: &TopSubtree,
_: Option<&TopSubtree>,
_: &Env,
@@ -927,6 +933,7 @@ struct Issue17479ProcMacroExpander;
impl ProcMacroExpander for Issue17479ProcMacroExpander {
fn expand(
&self,
+ _: &dyn SourceDatabase,
subtree: &TopSubtree,
_: Option<&TopSubtree>,
_: &Env,
@@ -956,6 +963,7 @@ struct Issue18898ProcMacroExpander;
impl ProcMacroExpander for Issue18898ProcMacroExpander {
fn expand(
&self,
+ _: &dyn SourceDatabase,
subtree: &TopSubtree,
_: Option<&TopSubtree>,
_: &Env,
@@ -1011,6 +1019,7 @@ struct DisallowCfgProcMacroExpander;
impl ProcMacroExpander for DisallowCfgProcMacroExpander {
fn expand(
&self,
+ _: &dyn SourceDatabase,
subtree: &TopSubtree,
_: Option<&TopSubtree>,
_: &Env,
@@ -1042,6 +1051,7 @@ struct GenerateSuffixedTypeProcMacroExpander;
impl ProcMacroExpander for GenerateSuffixedTypeProcMacroExpander {
fn expand(
&self,
+ _: &dyn SourceDatabase,
subtree: &TopSubtree,
_attrs: Option<&TopSubtree>,
_env: &Env,