Unnamed repository; edit this file 'description' to name the repository.
refactor the main loop in proc_macro-srv-cli
bit-aloo 4 months ago
parent 28a3b80 · commit 57fdf52
-rw-r--r--crates/proc-macro-srv-cli/src/main_loop.rs458
1 files changed, 247 insertions, 211 deletions
diff --git a/crates/proc-macro-srv-cli/src/main_loop.rs b/crates/proc-macro-srv-cli/src/main_loop.rs
index e543260964..aec971c776 100644
--- a/crates/proc-macro-srv-cli/src/main_loop.rs
+++ b/crates/proc-macro-srv-cli/src/main_loop.rs
@@ -80,236 +80,254 @@ fn run_new<C: Codec>() -> io::Result<()> {
.collect()
});
- let resp_env = Envelope {
- id: req.id,
- kind: Kind::Response,
- payload: Payload::Response(
- proc_macro_api::bidirectional_protocol::msg::Response::ListMacros(res),
- ),
- };
-
- resp_env.write::<_, C>(&mut stdout)?;
+ send_response::<_, C>(
+ &mut stdout,
+ req.id,
+ proc_macro_api::bidirectional_protocol::msg::Response::ListMacros(res),
+ )?;
}
Request::ApiVersionCheck {} => {
- let resp_env = Envelope {
- id: req.id,
- kind: Kind::Response,
- payload: Payload::Response(
- proc_macro_api::bidirectional_protocol::msg::Response::ApiVersionCheck(
- CURRENT_API_VERSION,
- ),
+ send_response::<_, C>(
+ &mut stdout,
+ req.id,
+ proc_macro_api::bidirectional_protocol::msg::Response::ApiVersionCheck(
+ CURRENT_API_VERSION,
),
- };
- resp_env.write::<_, C>(&mut stdout)?;
+ )?;
}
Request::SetConfig(config) => {
span_mode = config.span_mode;
- let resp_env = Envelope {
- id: req.id,
- kind: Kind::Response,
- payload: Payload::Response(
- proc_macro_api::bidirectional_protocol::msg::Response::SetConfig(
- config,
- ),
- ),
- };
- resp_env.write::<_, C>(&mut stdout)?;
+ send_response::<_, C>(
+ &mut stdout,
+ req.id,
+ proc_macro_api::bidirectional_protocol::msg::Response::SetConfig(config),
+ )?;
}
-
Request::ExpandMacro(task) => {
- let proc_macro_api::bidirectional_protocol::msg::ExpandMacro {
+ handle_expand::<_, _, C>(
+ &srv,
+ &mut stdin,
+ &mut stdout,
+ &mut buf,
+ req.id,
+ span_mode,
+ *task,
+ )?;
+ }
+ },
+ _ => continue,
+ }
+ }
+
+ Ok(())
+}
+
+fn handle_expand<W: std::io::Write, R: std::io::BufRead, C: Codec>(
+ srv: &proc_macro_srv::ProcMacroSrv<'_>,
+ stdin: &mut R,
+ stdout: &mut W,
+ buf: &mut C::Buf,
+ req_id: u64,
+ span_mode: SpanMode,
+ task: proc_macro_api::bidirectional_protocol::msg::ExpandMacro,
+) -> io::Result<()> {
+ match span_mode {
+ SpanMode::Id => handle_expand_id::<_, C>(srv, stdout, req_id, task),
+ SpanMode::RustAnalyzer => {
+ handle_expand_ra::<_, _, C>(srv, stdin, stdout, buf, req_id, task)
+ }
+ }
+}
+
+fn handle_expand_id<W: std::io::Write, C: Codec>(
+ srv: &proc_macro_srv::ProcMacroSrv<'_>,
+ stdout: &mut W,
+ req_id: u64,
+ task: proc_macro_api::bidirectional_protocol::msg::ExpandMacro,
+) -> io::Result<()> {
+ let proc_macro_api::bidirectional_protocol::msg::ExpandMacro { lib, env, current_dir, data } =
+ task;
+ let proc_macro_api::bidirectional_protocol::msg::ExpandMacroData {
+ macro_body,
+ macro_name,
+ attributes,
+ has_global_spans:
+ proc_macro_api::bidirectional_protocol::msg::ExpnGlobals {
+ def_site,
+ call_site,
+ mixed_site,
+ ..
+ },
+ ..
+ } = data;
+
+ let def_site = SpanId(def_site as u32);
+ let call_site = SpanId(call_site as u32);
+ let mixed_site = SpanId(mixed_site as u32);
+
+ let macro_body =
+ macro_body.to_tokenstream_unresolved::<SpanTrans>(CURRENT_API_VERSION, |_, b| b);
+ let attributes = attributes
+ .map(|it| it.to_tokenstream_unresolved::<SpanTrans>(CURRENT_API_VERSION, |_, b| b));
+
+ let res = srv
+ .expand(
+ lib,
+ &env,
+ current_dir,
+ &macro_name,
+ macro_body,
+ attributes,
+ def_site,
+ call_site,
+ mixed_site,
+ )
+ .map(|it| {
+ msg::FlatTree::from_tokenstream_raw::<SpanTrans>(it, call_site, CURRENT_API_VERSION)
+ })
+ .map_err(|e| msg::PanicMessage(e.into_string().unwrap_or_default()));
+
+ send_response::<_, C>(
+ stdout,
+ req_id,
+ proc_macro_api::bidirectional_protocol::msg::Response::ExpandMacro(res),
+ )
+}
+
+fn handle_expand_ra<W: std::io::Write, R: std::io::BufRead, C: Codec>(
+ srv: &proc_macro_srv::ProcMacroSrv<'_>,
+ stdin: &mut R,
+ stdout: &mut W,
+ buf: &mut C::Buf,
+ req_id: u64,
+ task: proc_macro_api::bidirectional_protocol::msg::ExpandMacro,
+) -> io::Result<()> {
+ let proc_macro_api::bidirectional_protocol::msg::ExpandMacro {
+ lib,
+ env,
+ current_dir,
+ data:
+ proc_macro_api::bidirectional_protocol::msg::ExpandMacroData {
+ macro_body,
+ macro_name,
+ attributes,
+ has_global_spans:
+ proc_macro_api::bidirectional_protocol::msg::ExpnGlobals {
+ serialize: _,
+ def_site,
+ call_site,
+ mixed_site,
+ },
+ span_data_table,
+ },
+ } = task;
+
+ let mut span_data_table = deserialize_span_data_index_map(&span_data_table);
+
+ let def_site_span = span_data_table[def_site];
+ let call_site_span = span_data_table[call_site];
+ let mixed_site_span = span_data_table[mixed_site];
+
+ let macro_body_ts =
+ macro_body.to_tokenstream_resolved(CURRENT_API_VERSION, &span_data_table, |a, b| {
+ srv.join_spans(a, b).unwrap_or(b)
+ });
+ let attributes_ts = attributes.map(|it| {
+ it.to_tokenstream_resolved(CURRENT_API_VERSION, &span_data_table, |a, b| {
+ srv.join_spans(a, b).unwrap_or(b)
+ })
+ });
+
+ let (subreq_tx, subreq_rx) = unbounded::<proc_macro_srv::SubRequest>();
+ let (subresp_tx, subresp_rx) = unbounded::<proc_macro_srv::SubResponse>();
+ let (result_tx, result_rx) = crossbeam_channel::bounded(1);
+
+ std::thread::scope(|scope| {
+ let srv_ref = &srv;
+
+ scope.spawn({
+ let lib = lib.clone();
+ let env = env.clone();
+ let current_dir = current_dir.clone();
+ let macro_name = macro_name.clone();
+ move || {
+ let res = srv_ref
+ .expand_with_channels(
lib,
- env,
+ &env,
current_dir,
- data:
- proc_macro_api::bidirectional_protocol::msg::ExpandMacroData {
- macro_body,
- macro_name,
- attributes,
- has_global_spans:
- proc_macro_api::bidirectional_protocol::msg::ExpnGlobals {
- serialize: _,
- def_site,
- call_site,
- mixed_site,
- },
- span_data_table,
- },
- } = *task;
-
- match span_mode {
- SpanMode::Id => {
- let def_site = SpanId(def_site as u32);
- let call_site = SpanId(call_site as u32);
- let mixed_site = SpanId(mixed_site as u32);
-
- let macro_body = macro_body.to_tokenstream_unresolved::<SpanTrans>(
+ &macro_name,
+ macro_body_ts,
+ attributes_ts,
+ def_site_span,
+ call_site_span,
+ mixed_site_span,
+ subresp_rx,
+ subreq_tx,
+ )
+ .map(|it| {
+ (
+ msg::FlatTree::from_tokenstream(
+ it,
CURRENT_API_VERSION,
- |_, b| b,
- );
- let attributes = attributes.map(|it| {
- it.to_tokenstream_unresolved::<SpanTrans>(
- CURRENT_API_VERSION,
- |_, b| b,
- )
- });
-
- let res = srv
- .expand(
- lib,
- &env,
- current_dir,
- &macro_name,
- macro_body,
- attributes,
- def_site,
- call_site,
- mixed_site,
- )
- .map(|it| {
- msg::FlatTree::from_tokenstream_raw::<SpanTrans>(
- it,
- call_site,
- CURRENT_API_VERSION,
- )
- })
- .map_err(|e| e.into_string().unwrap_or_default())
- .map_err(msg::PanicMessage);
-
- let resp_env = Envelope {
- id: req.id,
- kind: Kind::Response,
- payload: Payload::Response(proc_macro_api::bidirectional_protocol::msg::Response::ExpandMacro(res)),
- };
-
- resp_env.write::<_, C>(&mut stdout)?;
+ call_site_span,
+ &mut span_data_table,
+ ),
+ serialize_span_data_index_map(&span_data_table),
+ )
+ })
+ .map(|(tree, span_data_table)| {
+ proc_macro_api::bidirectional_protocol::msg::ExpandMacroExtended {
+ tree,
+ span_data_table,
}
+ })
+ .map_err(|e| e.into_string().unwrap_or_default())
+ .map_err(msg::PanicMessage);
+ let _ = result_tx.send(res);
+ }
+ });
+
+ loop {
+ if let Ok(res) = result_rx.try_recv() {
+ send_response::<_, C>(
+ stdout,
+ req_id,
+ proc_macro_api::bidirectional_protocol::msg::Response::ExpandMacroExtended(res),
+ )
+ .unwrap();
+ break;
+ }
- SpanMode::RustAnalyzer => {
- let mut span_data_table =
- deserialize_span_data_index_map(&span_data_table);
+ let subreq = match subreq_rx.recv() {
+ Ok(r) => r,
+ Err(_) => {
+ break;
+ }
+ };
- let def_site_span = span_data_table[def_site];
- let call_site_span = span_data_table[call_site];
- let mixed_site_span = span_data_table[mixed_site];
+ send_subrequest::<_, C>(stdout, req_id, from_srv_req(subreq)).unwrap();
- let macro_body_ts = macro_body.to_tokenstream_resolved(
- CURRENT_API_VERSION,
- &span_data_table,
- |a, b| srv.join_spans(a, b).unwrap_or(b),
- );
- let attributes_ts = attributes.map(|it| {
- it.to_tokenstream_resolved(
- CURRENT_API_VERSION,
- &span_data_table,
- |a, b| srv.join_spans(a, b).unwrap_or(b),
- )
- });
-
- let (subreq_tx, subreq_rx) = unbounded::<proc_macro_srv::SubRequest>();
- let (subresp_tx, subresp_rx) =
- unbounded::<proc_macro_srv::SubResponse>();
- let (result_tx, result_rx) = crossbeam_channel::bounded(1);
-
- std::thread::scope(|scope| {
- let srv_ref = &srv;
-
- scope.spawn({
- let lib = lib.clone();
- let env = env.clone();
- let current_dir = current_dir.clone();
- let macro_name = macro_name.clone();
- move || {
- let res = srv_ref
- .expand_with_channels(
- lib,
- &env,
- current_dir,
- &macro_name,
- macro_body_ts,
- attributes_ts,
- def_site_span,
- call_site_span,
- mixed_site_span,
- subresp_rx,
- subreq_tx,
- )
- .map(|it| {
- (
- msg::FlatTree::from_tokenstream(
- it,
- CURRENT_API_VERSION,
- call_site_span,
- &mut span_data_table,
- ),
- serialize_span_data_index_map(&span_data_table),
- )
- })
- .map(|(tree, span_data_table)| {
- proc_macro_api::bidirectional_protocol::msg::ExpandMacroExtended { tree, span_data_table }
- })
- .map_err(|e| e.into_string().unwrap_or_default())
- .map_err(msg::PanicMessage);
- let _ = result_tx.send(res);
- }
- });
-
- loop {
- if let Ok(res) = result_rx.try_recv() {
- let resp_env = Envelope {
- id: req.id,
- kind: Kind::Response,
- payload: Payload::Response(
- proc_macro_api::bidirectional_protocol::msg::Response::ExpandMacroExtended(res),
- ),
- };
- resp_env.write::<_, C>(&mut stdout).unwrap();
- break;
- }
-
- let subreq = match subreq_rx.recv() {
- Ok(r) => r,
- Err(_) => {
- break;
- }
- };
-
- let sub_env = Envelope {
- id: req.id,
- kind: Kind::SubRequest,
- payload: Payload::SubRequest(from_srv_req(subreq)),
- };
- sub_env.write::<_, C>(&mut stdout).unwrap();
-
- let resp_opt =
- Envelope::read::<_, C>(&mut stdin, &mut buf).unwrap();
- let resp = match resp_opt {
- Some(env) => env,
- None => {
- break;
- }
- };
-
- match (resp.kind, resp.payload) {
- (Kind::SubResponse, Payload::SubResponse(subresp)) => {
- let _ = subresp_tx.send(from_client_res(subresp));
- }
- _ => {
- break;
- }
- }
- }
- });
- }
- }
+ let resp_opt = Envelope::read::<_, C>(stdin, buf).unwrap();
+ let resp = match resp_opt {
+ Some(env) => env,
+ None => {
+ break;
}
- },
- _ => {}
- }
- }
+ };
+ match (resp.kind, resp.payload) {
+ (Kind::SubResponse, Payload::SubResponse(subresp)) => {
+ let _ = subresp_tx.send(from_client_res(subresp));
+ }
+ _ => {
+ break;
+ }
+ }
+ }
+ });
Ok(())
}
@@ -473,3 +491,21 @@ fn from_client_res(
}
}
}
+
+fn send_response<W: std::io::Write, C: Codec>(
+ stdout: &mut W,
+ id: u64,
+ resp: proc_macro_api::bidirectional_protocol::msg::Response,
+) -> io::Result<()> {
+ let resp = Envelope { id, kind: Kind::Response, payload: Payload::Response(resp) };
+ resp.write::<W, C>(stdout)
+}
+
+fn send_subrequest<W: std::io::Write, C: Codec>(
+ stdout: &mut W,
+ id: u64,
+ resp: proc_macro_api::bidirectional_protocol::msg::SubRequest,
+) -> io::Result<()> {
+ let resp = Envelope { id, kind: Kind::SubRequest, payload: Payload::SubRequest(resp) };
+ resp.write::<W, C>(stdout)
+}