Unnamed repository; edit this file 'description' to name the repository.
refactor the main loop in proc_macro-srv-cli
| -rw-r--r-- | crates/proc-macro-srv-cli/src/main_loop.rs | 458 |
1 files changed, 247 insertions, 211 deletions
diff --git a/crates/proc-macro-srv-cli/src/main_loop.rs b/crates/proc-macro-srv-cli/src/main_loop.rs index e543260964..aec971c776 100644 --- a/crates/proc-macro-srv-cli/src/main_loop.rs +++ b/crates/proc-macro-srv-cli/src/main_loop.rs @@ -80,236 +80,254 @@ fn run_new<C: Codec>() -> io::Result<()> { .collect() }); - let resp_env = Envelope { - id: req.id, - kind: Kind::Response, - payload: Payload::Response( - proc_macro_api::bidirectional_protocol::msg::Response::ListMacros(res), - ), - }; - - resp_env.write::<_, C>(&mut stdout)?; + send_response::<_, C>( + &mut stdout, + req.id, + proc_macro_api::bidirectional_protocol::msg::Response::ListMacros(res), + )?; } Request::ApiVersionCheck {} => { - let resp_env = Envelope { - id: req.id, - kind: Kind::Response, - payload: Payload::Response( - proc_macro_api::bidirectional_protocol::msg::Response::ApiVersionCheck( - CURRENT_API_VERSION, - ), + send_response::<_, C>( + &mut stdout, + req.id, + proc_macro_api::bidirectional_protocol::msg::Response::ApiVersionCheck( + CURRENT_API_VERSION, ), - }; - resp_env.write::<_, C>(&mut stdout)?; + )?; } Request::SetConfig(config) => { span_mode = config.span_mode; - let resp_env = Envelope { - id: req.id, - kind: Kind::Response, - payload: Payload::Response( - proc_macro_api::bidirectional_protocol::msg::Response::SetConfig( - config, - ), - ), - }; - resp_env.write::<_, C>(&mut stdout)?; + send_response::<_, C>( + &mut stdout, + req.id, + proc_macro_api::bidirectional_protocol::msg::Response::SetConfig(config), + )?; } - Request::ExpandMacro(task) => { - let proc_macro_api::bidirectional_protocol::msg::ExpandMacro { + handle_expand::<_, _, C>( + &srv, + &mut stdin, + &mut stdout, + &mut buf, + req.id, + span_mode, + *task, + )?; + } + }, + _ => continue, + } + } + + Ok(()) +} + +fn handle_expand<W: std::io::Write, R: std::io::BufRead, C: Codec>( + srv: &proc_macro_srv::ProcMacroSrv<'_>, + stdin: &mut R, + stdout: &mut W, + buf: &mut C::Buf, + req_id: u64, + span_mode: SpanMode, + task: proc_macro_api::bidirectional_protocol::msg::ExpandMacro, +) -> io::Result<()> { + match span_mode { + SpanMode::Id => handle_expand_id::<_, C>(srv, stdout, req_id, task), + SpanMode::RustAnalyzer => { + handle_expand_ra::<_, _, C>(srv, stdin, stdout, buf, req_id, task) + } + } +} + +fn handle_expand_id<W: std::io::Write, C: Codec>( + srv: &proc_macro_srv::ProcMacroSrv<'_>, + stdout: &mut W, + req_id: u64, + task: proc_macro_api::bidirectional_protocol::msg::ExpandMacro, +) -> io::Result<()> { + let proc_macro_api::bidirectional_protocol::msg::ExpandMacro { lib, env, current_dir, data } = + task; + let proc_macro_api::bidirectional_protocol::msg::ExpandMacroData { + macro_body, + macro_name, + attributes, + has_global_spans: + proc_macro_api::bidirectional_protocol::msg::ExpnGlobals { + def_site, + call_site, + mixed_site, + .. + }, + .. + } = data; + + let def_site = SpanId(def_site as u32); + let call_site = SpanId(call_site as u32); + let mixed_site = SpanId(mixed_site as u32); + + let macro_body = + macro_body.to_tokenstream_unresolved::<SpanTrans>(CURRENT_API_VERSION, |_, b| b); + let attributes = attributes + .map(|it| it.to_tokenstream_unresolved::<SpanTrans>(CURRENT_API_VERSION, |_, b| b)); + + let res = srv + .expand( + lib, + &env, + current_dir, + ¯o_name, + macro_body, + attributes, + def_site, + call_site, + mixed_site, + ) + .map(|it| { + msg::FlatTree::from_tokenstream_raw::<SpanTrans>(it, call_site, CURRENT_API_VERSION) + }) + .map_err(|e| msg::PanicMessage(e.into_string().unwrap_or_default())); + + send_response::<_, C>( + stdout, + req_id, + proc_macro_api::bidirectional_protocol::msg::Response::ExpandMacro(res), + ) +} + +fn handle_expand_ra<W: std::io::Write, R: std::io::BufRead, C: Codec>( + srv: &proc_macro_srv::ProcMacroSrv<'_>, + stdin: &mut R, + stdout: &mut W, + buf: &mut C::Buf, + req_id: u64, + task: proc_macro_api::bidirectional_protocol::msg::ExpandMacro, +) -> io::Result<()> { + let proc_macro_api::bidirectional_protocol::msg::ExpandMacro { + lib, + env, + current_dir, + data: + proc_macro_api::bidirectional_protocol::msg::ExpandMacroData { + macro_body, + macro_name, + attributes, + has_global_spans: + proc_macro_api::bidirectional_protocol::msg::ExpnGlobals { + serialize: _, + def_site, + call_site, + mixed_site, + }, + span_data_table, + }, + } = task; + + let mut span_data_table = deserialize_span_data_index_map(&span_data_table); + + let def_site_span = span_data_table[def_site]; + let call_site_span = span_data_table[call_site]; + let mixed_site_span = span_data_table[mixed_site]; + + let macro_body_ts = + macro_body.to_tokenstream_resolved(CURRENT_API_VERSION, &span_data_table, |a, b| { + srv.join_spans(a, b).unwrap_or(b) + }); + let attributes_ts = attributes.map(|it| { + it.to_tokenstream_resolved(CURRENT_API_VERSION, &span_data_table, |a, b| { + srv.join_spans(a, b).unwrap_or(b) + }) + }); + + let (subreq_tx, subreq_rx) = unbounded::<proc_macro_srv::SubRequest>(); + let (subresp_tx, subresp_rx) = unbounded::<proc_macro_srv::SubResponse>(); + let (result_tx, result_rx) = crossbeam_channel::bounded(1); + + std::thread::scope(|scope| { + let srv_ref = &srv; + + scope.spawn({ + let lib = lib.clone(); + let env = env.clone(); + let current_dir = current_dir.clone(); + let macro_name = macro_name.clone(); + move || { + let res = srv_ref + .expand_with_channels( lib, - env, + &env, current_dir, - data: - proc_macro_api::bidirectional_protocol::msg::ExpandMacroData { - macro_body, - macro_name, - attributes, - has_global_spans: - proc_macro_api::bidirectional_protocol::msg::ExpnGlobals { - serialize: _, - def_site, - call_site, - mixed_site, - }, - span_data_table, - }, - } = *task; - - match span_mode { - SpanMode::Id => { - let def_site = SpanId(def_site as u32); - let call_site = SpanId(call_site as u32); - let mixed_site = SpanId(mixed_site as u32); - - let macro_body = macro_body.to_tokenstream_unresolved::<SpanTrans>( + ¯o_name, + macro_body_ts, + attributes_ts, + def_site_span, + call_site_span, + mixed_site_span, + subresp_rx, + subreq_tx, + ) + .map(|it| { + ( + msg::FlatTree::from_tokenstream( + it, CURRENT_API_VERSION, - |_, b| b, - ); - let attributes = attributes.map(|it| { - it.to_tokenstream_unresolved::<SpanTrans>( - CURRENT_API_VERSION, - |_, b| b, - ) - }); - - let res = srv - .expand( - lib, - &env, - current_dir, - ¯o_name, - macro_body, - attributes, - def_site, - call_site, - mixed_site, - ) - .map(|it| { - msg::FlatTree::from_tokenstream_raw::<SpanTrans>( - it, - call_site, - CURRENT_API_VERSION, - ) - }) - .map_err(|e| e.into_string().unwrap_or_default()) - .map_err(msg::PanicMessage); - - let resp_env = Envelope { - id: req.id, - kind: Kind::Response, - payload: Payload::Response(proc_macro_api::bidirectional_protocol::msg::Response::ExpandMacro(res)), - }; - - resp_env.write::<_, C>(&mut stdout)?; + call_site_span, + &mut span_data_table, + ), + serialize_span_data_index_map(&span_data_table), + ) + }) + .map(|(tree, span_data_table)| { + proc_macro_api::bidirectional_protocol::msg::ExpandMacroExtended { + tree, + span_data_table, } + }) + .map_err(|e| e.into_string().unwrap_or_default()) + .map_err(msg::PanicMessage); + let _ = result_tx.send(res); + } + }); + + loop { + if let Ok(res) = result_rx.try_recv() { + send_response::<_, C>( + stdout, + req_id, + proc_macro_api::bidirectional_protocol::msg::Response::ExpandMacroExtended(res), + ) + .unwrap(); + break; + } - SpanMode::RustAnalyzer => { - let mut span_data_table = - deserialize_span_data_index_map(&span_data_table); + let subreq = match subreq_rx.recv() { + Ok(r) => r, + Err(_) => { + break; + } + }; - let def_site_span = span_data_table[def_site]; - let call_site_span = span_data_table[call_site]; - let mixed_site_span = span_data_table[mixed_site]; + send_subrequest::<_, C>(stdout, req_id, from_srv_req(subreq)).unwrap(); - let macro_body_ts = macro_body.to_tokenstream_resolved( - CURRENT_API_VERSION, - &span_data_table, - |a, b| srv.join_spans(a, b).unwrap_or(b), - ); - let attributes_ts = attributes.map(|it| { - it.to_tokenstream_resolved( - CURRENT_API_VERSION, - &span_data_table, - |a, b| srv.join_spans(a, b).unwrap_or(b), - ) - }); - - let (subreq_tx, subreq_rx) = unbounded::<proc_macro_srv::SubRequest>(); - let (subresp_tx, subresp_rx) = - unbounded::<proc_macro_srv::SubResponse>(); - let (result_tx, result_rx) = crossbeam_channel::bounded(1); - - std::thread::scope(|scope| { - let srv_ref = &srv; - - scope.spawn({ - let lib = lib.clone(); - let env = env.clone(); - let current_dir = current_dir.clone(); - let macro_name = macro_name.clone(); - move || { - let res = srv_ref - .expand_with_channels( - lib, - &env, - current_dir, - ¯o_name, - macro_body_ts, - attributes_ts, - def_site_span, - call_site_span, - mixed_site_span, - subresp_rx, - subreq_tx, - ) - .map(|it| { - ( - msg::FlatTree::from_tokenstream( - it, - CURRENT_API_VERSION, - call_site_span, - &mut span_data_table, - ), - serialize_span_data_index_map(&span_data_table), - ) - }) - .map(|(tree, span_data_table)| { - proc_macro_api::bidirectional_protocol::msg::ExpandMacroExtended { tree, span_data_table } - }) - .map_err(|e| e.into_string().unwrap_or_default()) - .map_err(msg::PanicMessage); - let _ = result_tx.send(res); - } - }); - - loop { - if let Ok(res) = result_rx.try_recv() { - let resp_env = Envelope { - id: req.id, - kind: Kind::Response, - payload: Payload::Response( - proc_macro_api::bidirectional_protocol::msg::Response::ExpandMacroExtended(res), - ), - }; - resp_env.write::<_, C>(&mut stdout).unwrap(); - break; - } - - let subreq = match subreq_rx.recv() { - Ok(r) => r, - Err(_) => { - break; - } - }; - - let sub_env = Envelope { - id: req.id, - kind: Kind::SubRequest, - payload: Payload::SubRequest(from_srv_req(subreq)), - }; - sub_env.write::<_, C>(&mut stdout).unwrap(); - - let resp_opt = - Envelope::read::<_, C>(&mut stdin, &mut buf).unwrap(); - let resp = match resp_opt { - Some(env) => env, - None => { - break; - } - }; - - match (resp.kind, resp.payload) { - (Kind::SubResponse, Payload::SubResponse(subresp)) => { - let _ = subresp_tx.send(from_client_res(subresp)); - } - _ => { - break; - } - } - } - }); - } - } + let resp_opt = Envelope::read::<_, C>(stdin, buf).unwrap(); + let resp = match resp_opt { + Some(env) => env, + None => { + break; } - }, - _ => {} - } - } + }; + match (resp.kind, resp.payload) { + (Kind::SubResponse, Payload::SubResponse(subresp)) => { + let _ = subresp_tx.send(from_client_res(subresp)); + } + _ => { + break; + } + } + } + }); Ok(()) } @@ -473,3 +491,21 @@ fn from_client_res( } } } + +fn send_response<W: std::io::Write, C: Codec>( + stdout: &mut W, + id: u64, + resp: proc_macro_api::bidirectional_protocol::msg::Response, +) -> io::Result<()> { + let resp = Envelope { id, kind: Kind::Response, payload: Payload::Response(resp) }; + resp.write::<W, C>(stdout) +} + +fn send_subrequest<W: std::io::Write, C: Codec>( + stdout: &mut W, + id: u64, + resp: proc_macro_api::bidirectional_protocol::msg::SubRequest, +) -> io::Result<()> { + let resp = Envelope { id, kind: Kind::SubRequest, payload: Payload::SubRequest(resp) }; + resp.write::<W, C>(stdout) +} |