Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/proc-macro-srv-cli/src/main_loop.rs')
| -rw-r--r-- | crates/proc-macro-srv-cli/src/main_loop.rs | 190 |
1 files changed, 123 insertions, 67 deletions
diff --git a/crates/proc-macro-srv-cli/src/main_loop.rs b/crates/proc-macro-srv-cli/src/main_loop.rs index b2f4b96bd2..9be3199a38 100644 --- a/crates/proc-macro-srv-cli/src/main_loop.rs +++ b/crates/proc-macro-srv-cli/src/main_loop.rs @@ -1,18 +1,18 @@ //! The main loop of the proc-macro server. use proc_macro_api::{ - Codec, - bidirectional_protocol::msg as bidirectional, - legacy_protocol::msg as legacy, - transport::codec::{json::JsonProtocol, postcard::PostcardProtocol}, + ProtocolFormat, bidirectional_protocol::msg as bidirectional, legacy_protocol::msg as legacy, version::CURRENT_API_VERSION, }; -use std::io; +use std::panic::{panic_any, resume_unwind}; +use std::{ + io::{self, BufRead, Write}, + ops::Range, +}; use legacy::Message; -use proc_macro_srv::{EnvSnapshot, SpanId}; +use proc_macro_srv::{EnvSnapshot, ProcMacroClientError, ProcMacroPanicMarker, SpanId}; -use crate::ProtocolFormat; struct SpanTrans; impl legacy::SpanTransformer for SpanTrans { @@ -32,15 +32,21 @@ impl legacy::SpanTransformer for SpanTrans { } } -pub(crate) fn run(format: ProtocolFormat) -> io::Result<()> { +pub fn run( + stdin: &mut (dyn BufRead + Send + Sync), + stdout: &mut (dyn Write + Send + Sync), + format: ProtocolFormat, +) -> io::Result<()> { match format { - ProtocolFormat::JsonLegacy => run_::<JsonProtocol>(), - ProtocolFormat::PostcardLegacy => run_::<PostcardProtocol>(), - ProtocolFormat::BidirectionalPostcardPrototype => run_new::<PostcardProtocol>(), + ProtocolFormat::JsonLegacy => run_old(stdin, stdout), + ProtocolFormat::BidirectionalPostcardPrototype => run_new(stdin, stdout), } } -fn run_new<C: Codec>() -> io::Result<()> { +fn run_new( + stdin: &mut (dyn BufRead + Send + Sync), + stdout: &mut (dyn Write + Send + Sync), +) -> io::Result<()> { fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::ProcMacroKind { match kind { proc_macro_srv::ProcMacroKind::CustomDerive => { @@ -51,9 +57,7 @@ fn run_new<C: Codec>() -> io::Result<()> { } } - let mut buf = C::Buf::default(); - let mut stdin = io::stdin(); - let mut stdout = io::stdout(); + let mut buf = Vec::default(); let env_snapshot = EnvSnapshot::default(); let srv = proc_macro_srv::ProcMacroSrv::new(&env_snapshot); @@ -61,8 +65,7 @@ fn run_new<C: Codec>() -> io::Result<()> { let mut span_mode = legacy::SpanMode::Id; 'outer: loop { - let req_opt = - bidirectional::BidirectionalMessage::read::<_, C>(&mut stdin.lock(), &mut buf)?; + let req_opt = bidirectional::BidirectionalMessage::read(stdin, &mut buf)?; let Some(req) = req_opt else { break 'outer; }; @@ -77,22 +80,22 @@ fn run_new<C: Codec>() -> io::Result<()> { .collect() }); - send_response::<C>(&stdout, bidirectional::Response::ListMacros(res))?; + send_response(stdout, bidirectional::Response::ListMacros(res))?; } bidirectional::Request::ApiVersionCheck {} => { - send_response::<C>( - &stdout, + send_response( + stdout, bidirectional::Response::ApiVersionCheck(CURRENT_API_VERSION), )?; } bidirectional::Request::SetConfig(config) => { span_mode = config.span_mode; - send_response::<C>(&stdout, bidirectional::Response::SetConfig(config))?; + send_response(stdout, bidirectional::Response::SetConfig(config))?; } bidirectional::Request::ExpandMacro(task) => { - handle_expand::<C>(&srv, &mut stdin, &mut stdout, &mut buf, span_mode, *task)?; + handle_expand(&srv, stdin, stdout, &mut buf, span_mode, *task)?; } }, _ => continue, @@ -102,23 +105,23 @@ fn run_new<C: Codec>() -> io::Result<()> { Ok(()) } -fn handle_expand<C: Codec>( +fn handle_expand( srv: &proc_macro_srv::ProcMacroSrv<'_>, - stdin: &io::Stdin, - stdout: &io::Stdout, - buf: &mut C::Buf, + stdin: &mut (dyn BufRead + Send + Sync), + stdout: &mut (dyn Write + Send + Sync), + buf: &mut Vec<u8>, span_mode: legacy::SpanMode, task: bidirectional::ExpandMacro, ) -> io::Result<()> { match span_mode { - legacy::SpanMode::Id => handle_expand_id::<C>(srv, stdout, task), - legacy::SpanMode::RustAnalyzer => handle_expand_ra::<C>(srv, stdin, stdout, buf, task), + legacy::SpanMode::Id => handle_expand_id(srv, stdout, task), + legacy::SpanMode::RustAnalyzer => handle_expand_ra(srv, stdin, stdout, buf, task), } } -fn handle_expand_id<C: Codec>( +fn handle_expand_id( srv: &proc_macro_srv::ProcMacroSrv<'_>, - stdout: &io::Stdout, + stdout: &mut dyn Write, task: bidirectional::ExpandMacro, ) -> io::Result<()> { let bidirectional::ExpandMacro { lib, env, current_dir, data } = task; @@ -157,40 +160,65 @@ fn handle_expand_id<C: Codec>( }) .map_err(|e| legacy::PanicMessage(e.into_string().unwrap_or_default())); - send_response::<C>(&stdout, bidirectional::Response::ExpandMacro(res)) + send_response(stdout, bidirectional::Response::ExpandMacro(res)) } -struct ProcMacroClientHandle<'a, C: Codec> { - stdin: &'a io::Stdin, - stdout: &'a io::Stdout, - buf: &'a mut C::Buf, +struct ProcMacroClientHandle<'a> { + stdin: &'a mut (dyn BufRead + Send + Sync), + stdout: &'a mut (dyn Write + Send + Sync), + buf: &'a mut Vec<u8>, } -impl<'a, C: Codec> ProcMacroClientHandle<'a, C> { +impl<'a> ProcMacroClientHandle<'a> { fn roundtrip( &mut self, req: bidirectional::SubRequest, - ) -> Option<bidirectional::BidirectionalMessage> { + ) -> Result<bidirectional::SubResponse, ProcMacroClientError> { let msg = bidirectional::BidirectionalMessage::SubRequest(req); - if msg.write::<_, C>(&mut self.stdout.lock()).is_err() { - return None; + msg.write(&mut *self.stdout).map_err(ProcMacroClientError::Io)?; + + let msg = bidirectional::BidirectionalMessage::read(&mut *self.stdin, self.buf) + .map_err(ProcMacroClientError::Io)? + .ok_or(ProcMacroClientError::Eof)?; + + match msg { + bidirectional::BidirectionalMessage::SubResponse(resp) => match resp { + bidirectional::SubResponse::Cancel { reason } => { + Err(ProcMacroClientError::Cancelled { reason }) + } + other => Ok(other), + }, + other => { + Err(ProcMacroClientError::Protocol(format!("expected SubResponse, got {other:?}"))) + } } + } +} - match bidirectional::BidirectionalMessage::read::<_, C>(&mut self.stdin.lock(), self.buf) { - Ok(Some(msg)) => Some(msg), - _ => None, +fn handle_failure(failure: Result<bidirectional::SubResponse, ProcMacroClientError>) -> ! { + match failure { + Err(ProcMacroClientError::Cancelled { reason }) => { + resume_unwind(Box::new(ProcMacroPanicMarker::Cancelled { reason })); + } + Err(err) => { + panic_any(ProcMacroPanicMarker::Internal { + reason: format!("proc-macro IPC error: {err:?}"), + }); + } + Ok(other) => { + panic_any(ProcMacroPanicMarker::Internal { + reason: format!("unexpected SubResponse {other:?}"), + }); } } } -impl<C: Codec> proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_, C> { +impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> { fn file(&mut self, file_id: proc_macro_srv::span::FileId) -> String { match self.roundtrip(bidirectional::SubRequest::FilePath { file_id: file_id.index() }) { - Some(bidirectional::BidirectionalMessage::SubResponse( - bidirectional::SubResponse::FilePathResult { name }, - )) => name, - _ => String::new(), + Ok(bidirectional::SubResponse::FilePathResult { name }) => name, + other => handle_failure(other), } } @@ -204,29 +232,54 @@ impl<C: Codec> proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandl start: range.start().into(), end: range.end().into(), }) { - Some(bidirectional::BidirectionalMessage::SubResponse( - bidirectional::SubResponse::SourceTextResult { text }, - )) => text, - _ => None, + Ok(bidirectional::SubResponse::SourceTextResult { text }) => text, + other => handle_failure(other), } } fn local_file(&mut self, file_id: proc_macro_srv::span::FileId) -> Option<String> { match self.roundtrip(bidirectional::SubRequest::LocalFilePath { file_id: file_id.index() }) { - Some(bidirectional::BidirectionalMessage::SubResponse( - bidirectional::SubResponse::LocalFilePathResult { name }, - )) => name, - _ => None, + Ok(bidirectional::SubResponse::LocalFilePathResult { name }) => name, + other => handle_failure(other), + } + } + + fn line_column(&mut self, span: proc_macro_srv::span::Span) -> Option<(u32, u32)> { + let proc_macro_srv::span::Span { range, anchor, ctx: _ } = span; + match self.roundtrip(bidirectional::SubRequest::LineColumn { + file_id: anchor.file_id.as_u32(), + ast_id: anchor.ast_id.into_raw(), + offset: range.start().into(), + }) { + Ok(bidirectional::SubResponse::LineColumnResult { line, column }) => { + Some((line, column)) + } + other => handle_failure(other), + } + } + + fn byte_range( + &mut self, + proc_macro_srv::span::Span { range, anchor, ctx: _ }: proc_macro_srv::span::Span, + ) -> Range<usize> { + match self.roundtrip(bidirectional::SubRequest::ByteRange { + file_id: anchor.file_id.as_u32(), + ast_id: anchor.ast_id.into_raw(), + start: range.start().into(), + end: range.end().into(), + }) { + Ok(bidirectional::SubResponse::ByteRangeResult { range }) => range, + other => handle_failure(other), } } } -fn handle_expand_ra<C: Codec>( +fn handle_expand_ra( srv: &proc_macro_srv::ProcMacroSrv<'_>, - stdin: &io::Stdin, - stdout: &io::Stdout, - buf: &mut C::Buf, + stdin: &mut (dyn BufRead + Send + Sync), + stdout: &mut (dyn Write + Send + Sync), + buf: &mut Vec<u8>, task: bidirectional::ExpandMacro, ) -> io::Result<()> { let bidirectional::ExpandMacro { @@ -271,7 +324,7 @@ fn handle_expand_ra<C: Codec>( def_site, call_site, mixed_site, - Some(&mut ProcMacroClientHandle::<C> { stdin, stdout, buf }), + Some(&mut ProcMacroClientHandle { stdin, stdout, buf }), ) .map(|it| { ( @@ -287,10 +340,13 @@ fn handle_expand_ra<C: Codec>( .map(|(tree, span_data_table)| bidirectional::ExpandMacroExtended { tree, span_data_table }) .map_err(|e| legacy::PanicMessage(e.into_string().unwrap_or_default())); - send_response::<C>(&stdout, bidirectional::Response::ExpandMacroExtended(res)) + send_response(stdout, bidirectional::Response::ExpandMacroExtended(res)) } -fn run_<C: Codec>() -> io::Result<()> { +fn run_old( + stdin: &mut (dyn BufRead + Send + Sync), + stdout: &mut (dyn Write + Send + Sync), +) -> io::Result<()> { fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::ProcMacroKind { match kind { proc_macro_srv::ProcMacroKind::CustomDerive => { @@ -301,9 +357,9 @@ fn run_<C: Codec>() -> io::Result<()> { } } - let mut buf = C::Buf::default(); - let mut read_request = || legacy::Request::read::<_, C>(&mut io::stdin().lock(), &mut buf); - let write_response = |msg: legacy::Response| msg.write::<_, C>(&mut io::stdout().lock()); + let mut buf = String::default(); + let mut read_request = || legacy::Request::read(stdin, &mut buf); + let mut write_response = |msg: legacy::Response| msg.write(stdout); let env = EnvSnapshot::default(); let srv = proc_macro_srv::ProcMacroSrv::new(&env); @@ -432,7 +488,7 @@ fn run_<C: Codec>() -> io::Result<()> { Ok(()) } -fn send_response<C: Codec>(stdout: &io::Stdout, resp: bidirectional::Response) -> io::Result<()> { +fn send_response(stdout: &mut dyn Write, resp: bidirectional::Response) -> io::Result<()> { let resp = bidirectional::BidirectionalMessage::Response(resp); - resp.write::<_, C>(&mut stdout.lock()) + resp.write(stdout) } |