Unnamed repository; edit this file 'description' to name the repository.
remove channels with callbacks in proc-macro-srv
bit-aloo 4 months ago
parent 336f025 · commit 1f64a69
-rw-r--r--crates/proc-macro-srv-cli/src/main_loop.rs135
-rw-r--r--crates/proc-macro-srv/src/dylib.rs31
-rw-r--r--crates/proc-macro-srv/src/dylib/proc_macros.rs80
-rw-r--r--crates/proc-macro-srv/src/lib.rs69
-rw-r--r--crates/proc-macro-srv/src/server_impl/rust_analyzer_span.rs31
-rw-r--r--crates/proc-macro-srv/src/server_impl/token_id.rs5
-rw-r--r--crates/proc-macro-srv/src/tests/utils.rs8
7 files changed, 99 insertions, 260 deletions
diff --git a/crates/proc-macro-srv-cli/src/main_loop.rs b/crates/proc-macro-srv-cli/src/main_loop.rs
index 8666c13677..99e3d79ef2 100644
--- a/crates/proc-macro-srv-cli/src/main_loop.rs
+++ b/crates/proc-macro-srv-cli/src/main_loop.rs
@@ -1,7 +1,6 @@
//! The main loop of the proc-macro server.
use std::io;
-use crossbeam_channel::unbounded;
use proc_macro_api::{
Codec,
bidirectional_protocol::msg as bidirectional,
@@ -82,6 +81,7 @@ fn run_new<C: Codec>() -> io::Result<()> {
}
bidirectional::Request::ApiVersionCheck {} => {
+ // bidirectional::Response::ApiVersionCheck(CURRENT_API_VERSION).write::<_, C>(stdout)
send_response::<_, C>(
&mut stdout,
bidirectional::Response::ApiVersionCheck(CURRENT_API_VERSION),
@@ -160,6 +160,7 @@ fn handle_expand_id<W: std::io::Write, C: Codec>(
def_site,
call_site,
mixed_site,
+ None,
)
.map(|it| {
legacy::FlatTree::from_tokenstream_raw::<SpanTrans>(it, call_site, CURRENT_API_VERSION)
@@ -169,7 +170,7 @@ fn handle_expand_id<W: std::io::Write, C: Codec>(
send_response::<_, C>(stdout, bidirectional::Response::ExpandMacro(res))
}
-fn handle_expand_ra<W: std::io::Write, R: std::io::BufRead, C: Codec>(
+fn handle_expand_ra<W: io::Write, R: io::BufRead, C: Codec>(
srv: &proc_macro_srv::ProcMacroSrv<'_>,
stdin: &mut R,
stdout: &mut W,
@@ -185,74 +186,69 @@ fn handle_expand_ra<W: std::io::Write, R: std::io::BufRead, C: Codec>(
macro_body,
macro_name,
attributes,
- has_global_spans:
- bidirectional::ExpnGlobals { serialize: _, def_site, call_site, mixed_site },
+ has_global_spans: bidirectional::ExpnGlobals { def_site, call_site, mixed_site, .. },
span_data_table,
},
} = task;
let mut span_data_table = legacy::deserialize_span_data_index_map(&span_data_table);
- let def_site_span = span_data_table[def_site];
- let call_site_span = span_data_table[call_site];
- let mixed_site_span = span_data_table[mixed_site];
+ let def_site = span_data_table[def_site];
+ let call_site = span_data_table[call_site];
+ let mixed_site = span_data_table[mixed_site];
- let macro_body_ts =
+ let macro_body =
macro_body.to_tokenstream_resolved(CURRENT_API_VERSION, &span_data_table, |a, b| {
srv.join_spans(a, b).unwrap_or(b)
});
- let attributes_ts = attributes.map(|it| {
+ let attributes = attributes.map(|it| {
it.to_tokenstream_resolved(CURRENT_API_VERSION, &span_data_table, |a, b| {
srv.join_spans(a, b).unwrap_or(b)
})
});
- let (subreq_tx, subreq_rx) = unbounded::<proc_macro_srv::SubRequest>();
- let (subresp_tx, subresp_rx) = unbounded::<proc_macro_srv::SubResponse>();
+ let (subreq_tx, subreq_rx) = crossbeam_channel::unbounded();
+ let (subresp_tx, subresp_rx) = crossbeam_channel::unbounded();
let (result_tx, result_rx) = crossbeam_channel::bounded(1);
std::thread::scope(|scope| {
- let srv_ref = &srv;
-
- scope.spawn({
- let lib = lib.clone();
- let env = env.clone();
- let current_dir = current_dir.clone();
- let macro_name = macro_name.clone();
- move || {
- let res = srv_ref
- .expand_with_channels(
- lib,
- &env,
- current_dir,
- &macro_name,
- macro_body_ts,
- attributes_ts,
- def_site_span,
- call_site_span,
- mixed_site_span,
- subresp_rx,
- subreq_tx,
+ scope.spawn(|| {
+ let callback = Box::new(move |req: proc_macro_srv::SubRequest| {
+ subreq_tx.send(req).unwrap();
+ subresp_rx.recv().unwrap()
+ });
+
+ let res = srv
+ .expand(
+ lib,
+ &env,
+ current_dir,
+ &macro_name,
+ macro_body,
+ attributes,
+ def_site,
+ call_site,
+ mixed_site,
+ Some(callback),
+ )
+ .map(|it| {
+ (
+ legacy::FlatTree::from_tokenstream(
+ it,
+ CURRENT_API_VERSION,
+ call_site,
+ &mut span_data_table,
+ ),
+ legacy::serialize_span_data_index_map(&span_data_table),
)
- .map(|it| {
- (
- legacy::FlatTree::from_tokenstream(
- it,
- CURRENT_API_VERSION,
- call_site_span,
- &mut span_data_table,
- ),
- legacy::serialize_span_data_index_map(&span_data_table),
- )
- })
- .map(|(tree, span_data_table)| bidirectional::ExpandMacroExtended {
- tree,
- span_data_table,
- })
- .map_err(|e| e.into_string().unwrap_or_default())
- .map_err(legacy::PanicMessage);
- let _ = result_tx.send(res);
- }
+ })
+ .map(|(tree, span_data_table)| bidirectional::ExpandMacroExtended {
+ tree,
+ span_data_table,
+ })
+ .map_err(|e| legacy::PanicMessage(e.into_string().unwrap_or_default()));
+
+ let _ = result_tx.send(res);
});
loop {
@@ -264,31 +260,26 @@ fn handle_expand_ra<W: std::io::Write, R: std::io::BufRead, C: Codec>(
let subreq = match subreq_rx.recv() {
Ok(r) => r,
- Err(_) => {
- break;
- }
+ Err(_) => break,
};
- send_subrequest::<_, C>(stdout, from_srv_req(subreq)).unwrap();
+ let api_req = from_srv_req(subreq);
+ bidirectional::BidirectionalMessage::SubRequest(api_req).write::<_, C>(stdout).unwrap();
- let resp_opt = bidirectional::BidirectionalMessage::read::<_, C>(stdin, buf).unwrap();
- let resp = match resp_opt {
- Some(env) => env,
- None => {
- break;
- }
- };
+ let resp = bidirectional::BidirectionalMessage::read::<_, C>(stdin, buf)
+ .unwrap()
+ .expect("client closed connection");
match resp {
- bidirectional::BidirectionalMessage::SubResponse(subresp) => {
- let _ = subresp_tx.send(from_client_res(subresp));
- }
- _ => {
- break;
+ bidirectional::BidirectionalMessage::SubResponse(api_resp) => {
+ let srv_resp = from_client_res(api_resp);
+ subresp_tx.send(srv_resp).unwrap();
}
+ other => panic!("expected SubResponse, got {other:?}"),
}
}
});
+
Ok(())
}
@@ -356,6 +347,7 @@ fn run_<C: Codec>() -> io::Result<()> {
def_site,
call_site,
mixed_site,
+ None,
)
.map(|it| {
legacy::FlatTree::from_tokenstream_raw::<SpanTrans>(
@@ -397,6 +389,7 @@ fn run_<C: Codec>() -> io::Result<()> {
def_site,
call_site,
mixed_site,
+ None,
)
.map(|it| {
(
@@ -455,11 +448,3 @@ fn send_response<W: std::io::Write, C: Codec>(
let resp = bidirectional::BidirectionalMessage::Response(resp);
resp.write::<W, C>(stdout)
}
-
-fn send_subrequest<W: std::io::Write, C: Codec>(
- stdout: &mut W,
- resp: bidirectional::SubRequest,
-) -> io::Result<()> {
- let resp = bidirectional::BidirectionalMessage::SubRequest(resp);
- resp.write::<W, C>(stdout)
-}
diff --git a/crates/proc-macro-srv/src/dylib.rs b/crates/proc-macro-srv/src/dylib.rs
index ba089c9549..082a1d77b5 100644
--- a/crates/proc-macro-srv/src/dylib.rs
+++ b/crates/proc-macro-srv/src/dylib.rs
@@ -12,7 +12,7 @@ use object::Object;
use paths::{Utf8Path, Utf8PathBuf};
use crate::{
- PanicMessage, ProcMacroKind, ProcMacroSrvSpan, dylib::proc_macros::ProcMacros,
+ PanicMessage, ProcMacroKind, ProcMacroSrvSpan, SubCallback, dylib::proc_macros::ProcMacros,
token_stream::TokenStream,
};
@@ -45,39 +45,14 @@ impl Expander {
def_site: S,
call_site: S,
mixed_site: S,
+ callback: Option<SubCallback>,
) -> Result<TokenStream<S>, PanicMessage>
where
<S::Server as bridge::server::Types>::TokenStream: Default,
{
self.inner
.proc_macros
- .expand(macro_name, macro_body, attribute, def_site, call_site, mixed_site)
- }
-
- pub(crate) fn expand_with_channels<S: ProcMacroSrvSpan>(
- &self,
- macro_name: &str,
- macro_body: crate::token_stream::TokenStream<S>,
- attribute: Option<crate::token_stream::TokenStream<S>>,
- def_site: S,
- call_site: S,
- mixed_site: S,
- cli_to_server: crossbeam_channel::Receiver<crate::SubResponse>,
- server_to_cli: crossbeam_channel::Sender<crate::SubRequest>,
- ) -> Result<crate::token_stream::TokenStream<S>, crate::PanicMessage>
- where
- <S::Server as proc_macro::bridge::server::Types>::TokenStream: Default,
- {
- self.inner.proc_macros.expand_with_channels(
- macro_name,
- macro_body,
- attribute,
- def_site,
- call_site,
- mixed_site,
- cli_to_server,
- server_to_cli,
- )
+ .expand(macro_name, macro_body, attribute, def_site, call_site, mixed_site, callback)
}
pub(crate) fn list_macros(&self) -> impl Iterator<Item = (&str, ProcMacroKind)> {
diff --git a/crates/proc-macro-srv/src/dylib/proc_macros.rs b/crates/proc-macro-srv/src/dylib/proc_macros.rs
index 5b6f1cf2f3..6f6bd086de 100644
--- a/crates/proc-macro-srv/src/dylib/proc_macros.rs
+++ b/crates/proc-macro-srv/src/dylib/proc_macros.rs
@@ -1,8 +1,7 @@
//! Proc macro ABI
+use crate::{ProcMacroKind, ProcMacroSrvSpan, SubCallback, token_stream::TokenStream};
use proc_macro::bridge;
-use crate::{ProcMacroKind, ProcMacroSrvSpan, token_stream::TokenStream};
-
#[repr(transparent)]
pub(crate) struct ProcMacros([bridge::client::ProcMacro]);
@@ -21,6 +20,7 @@ impl ProcMacros {
def_site: S,
call_site: S,
mixed_site: S,
+ callback: Option<SubCallback>,
) -> Result<TokenStream<S>, crate::PanicMessage> {
let parsed_attributes = attribute.unwrap_or_default();
@@ -31,65 +31,7 @@ impl ProcMacros {
{
let res = client.run(
&bridge::server::SameThread,
- S::make_server(call_site, def_site, mixed_site, None, None),
- macro_body,
- cfg!(debug_assertions),
- );
- return res.map_err(crate::PanicMessage::from);
- }
- bridge::client::ProcMacro::Bang { name, client } if *name == macro_name => {
- let res = client.run(
- &bridge::server::SameThread,
- S::make_server(call_site, def_site, mixed_site, None, None),
- macro_body,
- cfg!(debug_assertions),
- );
- return res.map_err(crate::PanicMessage::from);
- }
- bridge::client::ProcMacro::Attr { name, client } if *name == macro_name => {
- let res = client.run(
- &bridge::server::SameThread,
- S::make_server(call_site, def_site, mixed_site, None, None),
- parsed_attributes,
- macro_body,
- cfg!(debug_assertions),
- );
- return res.map_err(crate::PanicMessage::from);
- }
- _ => continue,
- }
- }
-
- Err(bridge::PanicMessage::String(format!("proc-macro `{macro_name}` is missing")).into())
- }
-
- pub(crate) fn expand_with_channels<S: ProcMacroSrvSpan>(
- &self,
- macro_name: &str,
- macro_body: TokenStream<S>,
- attribute: Option<TokenStream<S>>,
- def_site: S,
- call_site: S,
- mixed_site: S,
- cli_to_server: crossbeam_channel::Receiver<crate::SubResponse>,
- server_to_cli: crossbeam_channel::Sender<crate::SubRequest>,
- ) -> Result<TokenStream<S>, crate::PanicMessage> {
- let parsed_attributes = attribute.unwrap_or_default();
-
- for proc_macro in &self.0 {
- match proc_macro {
- bridge::client::ProcMacro::CustomDerive { trait_name, client, .. }
- if *trait_name == macro_name =>
- {
- let res = client.run(
- &bridge::server::SameThread,
- S::make_server(
- call_site,
- def_site,
- mixed_site,
- Some(cli_to_server),
- Some(server_to_cli),
- ),
+ S::make_server(call_site, def_site, mixed_site, callback),
macro_body,
cfg!(debug_assertions),
);
@@ -98,13 +40,7 @@ impl ProcMacros {
bridge::client::ProcMacro::Bang { name, client } if *name == macro_name => {
let res = client.run(
&bridge::server::SameThread,
- S::make_server(
- call_site,
- def_site,
- mixed_site,
- Some(cli_to_server),
- Some(server_to_cli),
- ),
+ S::make_server(call_site, def_site, mixed_site, callback),
macro_body,
cfg!(debug_assertions),
);
@@ -113,13 +49,7 @@ impl ProcMacros {
bridge::client::ProcMacro::Attr { name, client } if *name == macro_name => {
let res = client.run(
&bridge::server::SameThread,
- S::make_server(
- call_site,
- def_site,
- mixed_site,
- Some(cli_to_server),
- Some(server_to_cli),
- ),
+ S::make_server(call_site, def_site, mixed_site, callback),
parsed_attributes,
macro_body,
cfg!(debug_assertions),
diff --git a/crates/proc-macro-srv/src/lib.rs b/crates/proc-macro-srv/src/lib.rs
index f369ab93a2..705ac930ed 100644
--- a/crates/proc-macro-srv/src/lib.rs
+++ b/crates/proc-macro-srv/src/lib.rs
@@ -91,6 +91,8 @@ impl<'env> ProcMacroSrv<'env> {
}
}
+pub type SubCallback = Box<dyn Fn(SubRequest) -> SubResponse + Send + Sync + 'static>;
+
pub enum SubRequest {
SourceText { file_id: EditionedFileId, start: u32, end: u32 },
}
@@ -113,6 +115,7 @@ impl ProcMacroSrv<'_> {
def_site: S,
call_site: S,
mixed_site: S,
+ callback: Option<SubCallback>,
) -> Result<token_stream::TokenStream<S>, PanicMessage> {
let snapped_env = self.env;
let expander = self.expander(lib.as_ref()).map_err(|err| PanicMessage {
@@ -128,54 +131,9 @@ impl ProcMacroSrv<'_> {
.stack_size(EXPANDER_STACK_SIZE)
.name(macro_name.to_owned())
.spawn_scoped(s, move || {
- expander
- .expand(macro_name, macro_body, attribute, def_site, call_site, mixed_site)
- });
- match thread.unwrap().join() {
- Ok(res) => res,
- Err(e) => std::panic::resume_unwind(e),
- }
- });
- prev_env.rollback();
-
- result
- }
-
- pub fn expand_with_channels<S: ProcMacroSrvSpan>(
- &self,
- lib: impl AsRef<Utf8Path>,
- env: &[(String, String)],
- current_dir: Option<impl AsRef<Path>>,
- macro_name: &str,
- macro_body: token_stream::TokenStream<S>,
- attribute: Option<token_stream::TokenStream<S>>,
- def_site: S,
- call_site: S,
- mixed_site: S,
- cli_to_server: crossbeam_channel::Receiver<SubResponse>,
- server_to_cli: crossbeam_channel::Sender<SubRequest>,
- ) -> Result<token_stream::TokenStream<S>, PanicMessage> {
- let snapped_env = self.env;
- let expander = self.expander(lib.as_ref()).map_err(|err| PanicMessage {
- message: Some(format!("failed to load macro: {err}")),
- })?;
-
- let prev_env = EnvChange::apply(snapped_env, env, current_dir.as_ref().map(<_>::as_ref));
-
- let result = thread::scope(|s| {
- let thread = thread::Builder::new()
- .stack_size(EXPANDER_STACK_SIZE)
- .name(macro_name.to_owned())
- .spawn_scoped(s, move || {
- expander.expand_with_channels(
- macro_name,
- macro_body,
- attribute,
- def_site,
- call_site,
- mixed_site,
- cli_to_server,
- server_to_cli,
+ expander.expand(
+ macro_name, macro_body, attribute, def_site, call_site, mixed_site,
+ callback,
)
});
match thread.unwrap().join() {
@@ -229,8 +187,7 @@ pub trait ProcMacroSrvSpan: Copy + Send + Sync {
call_site: Self,
def_site: Self,
mixed_site: Self,
- cli_to_server: Option<crossbeam_channel::Receiver<SubResponse>>,
- server_to_cli: Option<crossbeam_channel::Sender<SubRequest>>,
+ callback: Option<SubCallback>,
) -> Self::Server;
}
@@ -241,15 +198,13 @@ impl ProcMacroSrvSpan for SpanId {
call_site: Self,
def_site: Self,
mixed_site: Self,
- cli_to_server: Option<crossbeam_channel::Receiver<SubResponse>>,
- server_to_cli: Option<crossbeam_channel::Sender<SubRequest>>,
+ callback: Option<SubCallback>,
) -> Self::Server {
Self::Server {
call_site,
def_site,
mixed_site,
- cli_to_server,
- server_to_cli,
+ callback,
tracked_env_vars: Default::default(),
tracked_paths: Default::default(),
}
@@ -262,17 +217,15 @@ impl ProcMacroSrvSpan for Span {
call_site: Self,
def_site: Self,
mixed_site: Self,
- cli_to_server: Option<crossbeam_channel::Receiver<SubResponse>>,
- server_to_cli: Option<crossbeam_channel::Sender<SubRequest>>,
+ callback: Option<SubCallback>,
) -> Self::Server {
Self::Server {
call_site,
def_site,
mixed_site,
+ callback,
tracked_env_vars: Default::default(),
tracked_paths: Default::default(),
- cli_to_server,
- server_to_cli,
}
}
}
diff --git a/crates/proc-macro-srv/src/server_impl/rust_analyzer_span.rs b/crates/proc-macro-srv/src/server_impl/rust_analyzer_span.rs
index cccb74429d..0bce67fcd9 100644
--- a/crates/proc-macro-srv/src/server_impl/rust_analyzer_span.rs
+++ b/crates/proc-macro-srv/src/server_impl/rust_analyzer_span.rs
@@ -14,7 +14,7 @@ use proc_macro::bridge::server;
use span::{FIXUP_ERASED_FILE_AST_ID_MARKER, Span, TextRange, TextSize};
use crate::{
- SubRequest, SubResponse,
+ SubCallback, SubRequest, SubResponse,
bridge::{Diagnostic, ExpnGlobals, Literal, TokenTree},
server_impl::literal_from_str,
};
@@ -29,8 +29,7 @@ pub struct RaSpanServer {
pub call_site: Span,
pub def_site: Span,
pub mixed_site: Span,
- pub cli_to_server: Option<crossbeam_channel::Receiver<SubResponse>>,
- pub server_to_cli: Option<crossbeam_channel::Sender<SubRequest>>,
+ pub callback: Option<SubCallback>,
}
impl server::Types for RaSpanServer {
@@ -153,21 +152,17 @@ impl server::Span for RaSpanServer {
/// See PR:
/// https://github.com/rust-lang/rust/pull/55780
fn source_text(&mut self, span: Self::Span) -> Option<String> {
- // FIXME requires db, needs special handling wrt fixup spans
- if self.server_to_cli.is_some() && self.cli_to_server.is_some() {
- let file_id = span.anchor.file_id;
- let start: u32 = span.range.start().into();
- let end: u32 = span.range.end().into();
- let _ = self.server_to_cli.clone().unwrap().send(SubRequest::SourceText {
- file_id,
- start,
- end,
- });
- match self.cli_to_server.as_ref()?.recv().ok()? {
- SubResponse::SourceTextResult { text } => text,
- }
- } else {
- None
+ let file_id = span.anchor.file_id;
+ let start: u32 = span.range.start().into();
+ let end: u32 = span.range.end().into();
+
+ let req = SubRequest::SourceText { file_id, start, end };
+
+ let cb = self.callback.as_mut()?;
+ let response = cb(req);
+
+ match response {
+ SubResponse::SourceTextResult { text } => text,
}
}
diff --git a/crates/proc-macro-srv/src/server_impl/token_id.rs b/crates/proc-macro-srv/src/server_impl/token_id.rs
index 268042b3bc..3b12644ec3 100644
--- a/crates/proc-macro-srv/src/server_impl/token_id.rs
+++ b/crates/proc-macro-srv/src/server_impl/token_id.rs
@@ -9,7 +9,7 @@ use intern::Symbol;
use proc_macro::bridge::server;
use crate::{
- SubRequest, SubResponse,
+ SubCallback,
bridge::{Diagnostic, ExpnGlobals, Literal, TokenTree},
server_impl::literal_from_str,
};
@@ -35,8 +35,7 @@ pub struct SpanIdServer {
pub call_site: Span,
pub def_site: Span,
pub mixed_site: Span,
- pub cli_to_server: Option<crossbeam_channel::Receiver<SubResponse>>,
- pub server_to_cli: Option<crossbeam_channel::Sender<SubRequest>>,
+ pub callback: Option<SubCallback>,
}
impl server::Types for SpanIdServer {
diff --git a/crates/proc-macro-srv/src/tests/utils.rs b/crates/proc-macro-srv/src/tests/utils.rs
index 1b12308ad6..61fcd810b1 100644
--- a/crates/proc-macro-srv/src/tests/utils.rs
+++ b/crates/proc-macro-srv/src/tests/utils.rs
@@ -59,8 +59,9 @@ fn assert_expand_impl(
let input_ts_string = format!("{input_ts:?}");
let attr_ts_string = attr_ts.as_ref().map(|it| format!("{it:?}"));
- let res =
- expander.expand(macro_name, input_ts, attr_ts, def_site, call_site, mixed_site).unwrap();
+ let res = expander
+ .expand(macro_name, input_ts, attr_ts, def_site, call_site, mixed_site, None)
+ .unwrap();
expect.assert_eq(&format!(
"{input_ts_string}{}{}{}",
if attr_ts_string.is_some() { "\n\n" } else { "" },
@@ -91,7 +92,8 @@ fn assert_expand_impl(
let fixture_string = format!("{fixture:?}");
let attr_string = attr.as_ref().map(|it| format!("{it:?}"));
- let res = expander.expand(macro_name, fixture, attr, def_site, call_site, mixed_site).unwrap();
+ let res =
+ expander.expand(macro_name, fixture, attr, def_site, call_site, mixed_site, None).unwrap();
expect_spanned.assert_eq(&format!(
"{fixture_string}{}{}{}",
if attr_string.is_some() { "\n\n" } else { "" },