Unnamed repository; edit this file 'description' to name the repository.
Abstract proc-macro-srv input and output away
| -rw-r--r-- | crates/proc-macro-api/src/legacy_protocol.rs | 4 | ||||
| -rw-r--r-- | crates/proc-macro-api/src/legacy_protocol/msg.rs | 4 | ||||
| -rw-r--r-- | crates/proc-macro-api/src/lib.rs | 41 | ||||
| -rw-r--r-- | crates/proc-macro-api/src/process.rs | 138 | ||||
| -rw-r--r-- | crates/proc-macro-srv-cli/src/main.rs | 12 | ||||
| -rw-r--r-- | crates/proc-macro-srv-cli/src/main_loop.rs | 72 |
6 files changed, 193 insertions, 78 deletions
diff --git a/crates/proc-macro-api/src/legacy_protocol.rs b/crates/proc-macro-api/src/legacy_protocol.rs index 22a7d9868e..4524d1b66b 100644 --- a/crates/proc-macro-api/src/legacy_protocol.rs +++ b/crates/proc-macro-api/src/legacy_protocol.rs @@ -162,11 +162,11 @@ fn send_request<P: Codec>( req: Request, buf: &mut P::Buf, ) -> Result<Option<Response>, ServerError> { - req.write::<_, P>(&mut writer).map_err(|err| ServerError { + req.write::<P>(&mut writer).map_err(|err| ServerError { message: "failed to write request".into(), io: Some(Arc::new(err)), })?; - let res = Response::read::<_, P>(&mut reader, buf).map_err(|err| ServerError { + let res = Response::read::<P>(&mut reader, buf).map_err(|err| ServerError { message: "failed to read response".into(), io: Some(Arc::new(err)), })?; diff --git a/crates/proc-macro-api/src/legacy_protocol/msg.rs b/crates/proc-macro-api/src/legacy_protocol/msg.rs index 4146b619ec..1b65906933 100644 --- a/crates/proc-macro-api/src/legacy_protocol/msg.rs +++ b/crates/proc-macro-api/src/legacy_protocol/msg.rs @@ -155,13 +155,13 @@ impl ExpnGlobals { } pub trait Message: serde::Serialize + DeserializeOwned { - fn read<R: BufRead, C: Codec>(inp: &mut R, buf: &mut C::Buf) -> io::Result<Option<Self>> { + fn read<C: Codec>(inp: &mut dyn BufRead, buf: &mut C::Buf) -> io::Result<Option<Self>> { Ok(match C::read(inp, buf)? { None => None, Some(buf) => Some(C::decode(buf)?), }) } - fn write<W: Write, C: Codec>(self, out: &mut W) -> io::Result<()> { + fn write<C: Codec>(self, out: &mut dyn Write) -> io::Result<()> { let value = C::encode(&self)?; C::write(out, &value) } diff --git a/crates/proc-macro-api/src/lib.rs b/crates/proc-macro-api/src/lib.rs index f5fcc99f14..98ee6817c2 100644 --- a/crates/proc-macro-api/src/lib.rs +++ b/crates/proc-macro-api/src/lib.rs @@ -18,7 +18,7 @@ extern crate rustc_driver as _; pub mod bidirectional_protocol; pub mod legacy_protocol; -mod process; +pub mod process; pub mod transport; use paths::{AbsPath, AbsPathBuf}; @@ -44,6 +44,25 @@ pub mod version { pub const CURRENT_API_VERSION: u32 = HASHED_AST_ID; } +#[derive(Copy, Clone)] +pub enum ProtocolFormat { + JsonLegacy, + PostcardLegacy, + BidirectionalPostcardPrototype, +} + +impl fmt::Display for ProtocolFormat { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ProtocolFormat::JsonLegacy => write!(f, "json-legacy"), + ProtocolFormat::PostcardLegacy => write!(f, "postcard-legacy"), + ProtocolFormat::BidirectionalPostcardPrototype => { + write!(f, "bidirectional-postcard-prototype") + } + } + } +} + /// Represents different kinds of procedural macros that can be expanded by the external server. #[derive(Copy, Clone, Eq, PartialEq, Debug, serde_derive::Serialize, serde_derive::Deserialize)] pub enum ProcMacroKind { @@ -132,7 +151,25 @@ impl ProcMacroClient { > + Clone, version: Option<&Version>, ) -> io::Result<ProcMacroClient> { - let process = ProcMacroServerProcess::run(process_path, env, version)?; + let process = ProcMacroServerProcess::spawn(process_path, env, version)?; + Ok(ProcMacroClient { process: Arc::new(process), path: process_path.to_owned() }) + } + + /// Invokes `spawn` and returns a client connected to the resulting read and write handles. + /// + /// The `process_path` is used for `Self::server_path`. This function is mainly used for testing. + pub fn with_io_channels( + process_path: &AbsPath, + spawn: impl Fn( + Option<ProtocolFormat>, + ) -> io::Result<( + Box<dyn process::ProcessExit>, + Box<dyn io::Write + Send + Sync>, + Box<dyn io::BufRead + Send + Sync>, + )>, + version: Option<&Version>, + ) -> io::Result<ProcMacroClient> { + let process = ProcMacroServerProcess::run(spawn, version, || "<unknown>".to_owned())?; Ok(ProcMacroClient { process: Arc::new(process), path: process_path.to_owned() }) } diff --git a/crates/proc-macro-api/src/process.rs b/crates/proc-macro-api/src/process.rs index f6a656e3ce..4f87621587 100644 --- a/crates/proc-macro-api/src/process.rs +++ b/crates/proc-macro-api/src/process.rs @@ -13,14 +13,13 @@ use span::Span; use stdx::JodChild; use crate::{ - Codec, ProcMacro, ProcMacroKind, ServerError, + Codec, ProcMacro, ProcMacroKind, ProtocolFormat, ServerError, bidirectional_protocol::{self, SubCallback, msg::BidirectionalMessage, reject_subrequests}, legacy_protocol::{self, SpanMode}, version, }; /// Represents a process handling proc-macro communication. -#[derive(Debug)] pub(crate) struct ProcMacroServerProcess { /// The state of the proc-macro server process, the protocol is currently strictly sequential /// hence the lock on the state. @@ -31,6 +30,16 @@ pub(crate) struct ProcMacroServerProcess { exited: OnceLock<AssertUnwindSafe<ServerError>>, } +impl std::fmt::Debug for ProcMacroServerProcess { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ProcMacroServerProcess") + .field("version", &self.version) + .field("protocol", &self.protocol) + .field("exited", &self.exited) + .finish() + } +} + #[derive(Debug, Clone)] pub(crate) enum Protocol { LegacyJson { mode: SpanMode }, @@ -38,23 +47,84 @@ pub(crate) enum Protocol { BidirectionalPostcardPrototype { mode: SpanMode }, } +pub trait ProcessExit: Send + Sync { + fn exit_err(&mut self) -> Option<ServerError>; +} + +impl ProcessExit for Process { + fn exit_err(&mut self) -> Option<ServerError> { + match self.child.try_wait() { + Ok(None) | Err(_) => None, + Ok(Some(status)) => { + let mut msg = String::new(); + if !status.success() + && let Some(stderr) = self.child.stderr.as_mut() + { + _ = stderr.read_to_string(&mut msg); + } + Some(ServerError { + message: format!( + "proc-macro server exited with {status}{}{msg}", + if msg.is_empty() { "" } else { ": " } + ), + io: None, + }) + } + } + } +} + /// Maintains the state of the proc-macro server process. -#[derive(Debug)] struct ProcessSrvState { - process: Process, - stdin: ChildStdin, - stdout: BufReader<ChildStdout>, + process: Box<dyn ProcessExit>, + stdin: Box<dyn Write + Send + Sync>, + stdout: Box<dyn BufRead + Send + Sync>, } impl ProcMacroServerProcess { /// Starts the proc-macro server and performs a version check - pub(crate) fn run<'a>( + pub(crate) fn spawn<'a>( process_path: &AbsPath, env: impl IntoIterator< Item = (impl AsRef<std::ffi::OsStr>, &'a Option<impl 'a + AsRef<std::ffi::OsStr>>), > + Clone, version: Option<&Version>, ) -> io::Result<ProcMacroServerProcess> { + Self::run( + |format| { + let mut process = Process::run( + process_path, + env.clone(), + format.map(|format| format.to_string()).as_deref(), + )?; + let (stdin, stdout) = process.stdio().expect("couldn't access child stdio"); + + Ok((Box::new(process), Box::new(stdin), Box::new(stdout))) + }, + version, + || { + #[expect(clippy::disallowed_methods)] + Command::new(process_path) + .arg("--version") + .output() + .map(|output| String::from_utf8_lossy(&output.stdout).trim().to_owned()) + .unwrap_or_else(|_| "unknown version".to_owned()) + }, + ) + } + + /// Invokes `spawn` and performs a version check. + pub(crate) fn run( + spawn: impl Fn( + Option<ProtocolFormat>, + ) -> io::Result<( + Box<dyn ProcessExit>, + Box<dyn Write + Send + Sync>, + Box<dyn BufRead + Send + Sync>, + )>, + version: Option<&Version>, + binary_server_version: impl Fn() -> String, + ) -> io::Result<ProcMacroServerProcess> { const VERSION: Version = Version::new(1, 93, 0); // we do `>` for nightly as this started working in the middle of the 1.93 nightly release, so we dont want to break on half of the nightlies let has_working_format_flag = version.map_or(false, |v| { @@ -65,27 +135,33 @@ impl ProcMacroServerProcess { && has_working_format_flag { &[ - ( - Some("bidirectional-postcard-prototype"), - Protocol::BidirectionalPostcardPrototype { mode: SpanMode::Id }, - ), - (Some("postcard-legacy"), Protocol::LegacyPostcard { mode: SpanMode::Id }), - (Some("json-legacy"), Protocol::LegacyJson { mode: SpanMode::Id }), + Some(ProtocolFormat::BidirectionalPostcardPrototype), + Some(ProtocolFormat::PostcardLegacy), + Some(ProtocolFormat::JsonLegacy), ] } else { - &[(None, Protocol::LegacyJson { mode: SpanMode::Id })] + &[None] }; let mut err = None; - for &(format, ref protocol) in formats { + for &format in formats { let create_srv = || { - let mut process = Process::run(process_path, env.clone(), format)?; - let (stdin, stdout) = process.stdio().expect("couldn't access child stdio"); + let (process, stdin, stdout) = spawn(format)?; io::Result::Ok(ProcMacroServerProcess { state: Mutex::new(ProcessSrvState { process, stdin, stdout }), version: 0, - protocol: protocol.clone(), + protocol: match format { + Some(ProtocolFormat::BidirectionalPostcardPrototype) => { + Protocol::BidirectionalPostcardPrototype { mode: SpanMode::Id } + } + Some(ProtocolFormat::PostcardLegacy) => { + Protocol::LegacyPostcard { mode: SpanMode::Id } + } + Some(ProtocolFormat::JsonLegacy) | None => { + Protocol::LegacyJson { mode: SpanMode::Id } + } + }, exited: OnceLock::new(), }) }; @@ -93,12 +169,7 @@ impl ProcMacroServerProcess { tracing::info!("sending proc-macro server version check"); match srv.version_check(Some(&mut reject_subrequests)) { Ok(v) if v > version::CURRENT_API_VERSION => { - #[allow(clippy::disallowed_methods)] - let process_version = Command::new(process_path) - .arg("--version") - .output() - .map(|output| String::from_utf8_lossy(&output.stdout).trim().to_owned()) - .unwrap_or_else(|_| "unknown version".to_owned()); + let process_version = binary_server_version(); err = Some(io::Error::other(format!( "Your installed proc-macro server is too new for your rust-analyzer. API version: {}, server version: {process_version}. \ This will prevent proc-macro expansion from working. Please consider updating your rust-analyzer to ensure compatibility with your current toolchain.", @@ -275,22 +346,9 @@ impl ProcMacroServerProcess { f(&mut state.stdin, &mut state.stdout, &mut buf).map_err(|e| { if e.io.as_ref().map(|it| it.kind()) == Some(io::ErrorKind::BrokenPipe) { - match state.process.child.try_wait() { - Ok(None) | Err(_) => e, - Ok(Some(status)) => { - let mut msg = String::new(); - if !status.success() - && let Some(stderr) = state.process.child.stderr.as_mut() - { - _ = stderr.read_to_string(&mut msg); - } - let server_error = ServerError { - message: format!( - "proc-macro server exited with {status}{}{msg}", - if msg.is_empty() { "" } else { ": " } - ), - io: None, - }; + match state.process.exit_err() { + None => e, + Some(server_error) => { self.exited.get_or_init(|| AssertUnwindSafe(server_error)).0.clone() } } diff --git a/crates/proc-macro-srv-cli/src/main.rs b/crates/proc-macro-srv-cli/src/main.rs index bdfdb50002..189a1eea5c 100644 --- a/crates/proc-macro-srv-cli/src/main.rs +++ b/crates/proc-macro-srv-cli/src/main.rs @@ -45,7 +45,11 @@ fn main() -> std::io::Result<()> { } let &format = matches.get_one::<ProtocolFormat>("format").expect("format value should always be present"); - run(format) + + let mut stdin = std::io::BufReader::new(std::io::stdin()); + let mut stdout = std::io::stdout(); + + run(&mut stdin, &mut stdout, format) } #[derive(Copy, Clone)] @@ -88,7 +92,11 @@ impl ValueEnum for ProtocolFormat { } #[cfg(not(feature = "sysroot-abi"))] -fn run(_: ProtocolFormat) -> std::io::Result<()> { +fn run( + _: &mut std::io::BufReader<std::io::Stdin>, + _: &mut std::io::Stdout, + _: ProtocolFormat, +) -> std::io::Result<()> { Err(std::io::Error::new( std::io::ErrorKind::Unsupported, "proc-macro-srv-cli needs to be compiled with the `sysroot-abi` feature to function" diff --git a/crates/proc-macro-srv-cli/src/main_loop.rs b/crates/proc-macro-srv-cli/src/main_loop.rs index 22536a4e52..0c651d22b4 100644 --- a/crates/proc-macro-srv-cli/src/main_loop.rs +++ b/crates/proc-macro-srv-cli/src/main_loop.rs @@ -6,7 +6,7 @@ use proc_macro_api::{ transport::codec::{json::JsonProtocol, postcard::PostcardProtocol}, version::CURRENT_API_VERSION, }; -use std::io; +use std::io::{self, BufRead, Write}; use legacy::Message; @@ -32,15 +32,24 @@ impl legacy::SpanTransformer for SpanTrans { } } -pub(crate) fn run(format: ProtocolFormat) -> io::Result<()> { +pub(crate) fn run( + stdin: &mut (dyn BufRead + Send + Sync), + stdout: &mut (dyn Write + Send + Sync), + format: ProtocolFormat, +) -> io::Result<()> { match format { - ProtocolFormat::JsonLegacy => run_::<JsonProtocol>(), - ProtocolFormat::PostcardLegacy => run_::<PostcardProtocol>(), - ProtocolFormat::BidirectionalPostcardPrototype => run_new::<PostcardProtocol>(), + ProtocolFormat::JsonLegacy => run_old::<JsonProtocol>(stdin, stdout), + ProtocolFormat::PostcardLegacy => run_old::<PostcardProtocol>(stdin, stdout), + ProtocolFormat::BidirectionalPostcardPrototype => { + run_new::<PostcardProtocol>(stdin, stdout) + } } } -fn run_new<C: Codec>() -> io::Result<()> { +fn run_new<C: Codec>( + stdin: &mut (dyn BufRead + Send + Sync), + stdout: &mut (dyn Write + Send + Sync), +) -> io::Result<()> { fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::ProcMacroKind { match kind { proc_macro_srv::ProcMacroKind::CustomDerive => { @@ -52,8 +61,6 @@ fn run_new<C: Codec>() -> io::Result<()> { } let mut buf = C::Buf::default(); - let mut stdin = io::stdin(); - let mut stdout = io::stdout(); let env_snapshot = EnvSnapshot::default(); let srv = proc_macro_srv::ProcMacroSrv::new(&env_snapshot); @@ -61,8 +68,7 @@ fn run_new<C: Codec>() -> io::Result<()> { let mut span_mode = legacy::SpanMode::Id; 'outer: loop { - let req_opt = - bidirectional::BidirectionalMessage::read::<_, C>(&mut stdin.lock(), &mut buf)?; + let req_opt = bidirectional::BidirectionalMessage::read::<C>(stdin, &mut buf)?; let Some(req) = req_opt else { break 'outer; }; @@ -77,22 +83,22 @@ fn run_new<C: Codec>() -> io::Result<()> { .collect() }); - send_response::<C>(&stdout, bidirectional::Response::ListMacros(res))?; + send_response::<C>(stdout, bidirectional::Response::ListMacros(res))?; } bidirectional::Request::ApiVersionCheck {} => { send_response::<C>( - &stdout, + stdout, bidirectional::Response::ApiVersionCheck(CURRENT_API_VERSION), )?; } bidirectional::Request::SetConfig(config) => { span_mode = config.span_mode; - send_response::<C>(&stdout, bidirectional::Response::SetConfig(config))?; + send_response::<C>(stdout, bidirectional::Response::SetConfig(config))?; } bidirectional::Request::ExpandMacro(task) => { - handle_expand::<C>(&srv, &mut stdin, &mut stdout, &mut buf, span_mode, *task)?; + handle_expand::<C>(&srv, stdin, stdout, &mut buf, span_mode, *task)?; } }, _ => continue, @@ -104,8 +110,8 @@ fn run_new<C: Codec>() -> io::Result<()> { fn handle_expand<C: Codec>( srv: &proc_macro_srv::ProcMacroSrv<'_>, - stdin: &io::Stdin, - stdout: &io::Stdout, + stdin: &mut (dyn BufRead + Send + Sync), + stdout: &mut (dyn Write + Send + Sync), buf: &mut C::Buf, span_mode: legacy::SpanMode, task: bidirectional::ExpandMacro, @@ -118,7 +124,7 @@ fn handle_expand<C: Codec>( fn handle_expand_id<C: Codec>( srv: &proc_macro_srv::ProcMacroSrv<'_>, - stdout: &io::Stdout, + stdout: &mut dyn Write, task: bidirectional::ExpandMacro, ) -> io::Result<()> { let bidirectional::ExpandMacro { lib, env, current_dir, data } = task; @@ -157,12 +163,12 @@ fn handle_expand_id<C: Codec>( }) .map_err(|e| legacy::PanicMessage(e.into_string().unwrap_or_default())); - send_response::<C>(&stdout, bidirectional::Response::ExpandMacro(res)) + send_response::<C>(stdout, bidirectional::Response::ExpandMacro(res)) } struct ProcMacroClientHandle<'a, C: Codec> { - stdin: &'a io::Stdin, - stdout: &'a io::Stdout, + stdin: &'a mut (dyn BufRead + Send + Sync), + stdout: &'a mut (dyn Write + Send + Sync), buf: &'a mut C::Buf, } @@ -173,11 +179,11 @@ impl<'a, C: Codec> ProcMacroClientHandle<'a, C> { ) -> Option<bidirectional::BidirectionalMessage> { let msg = bidirectional::BidirectionalMessage::SubRequest(req); - if msg.write::<_, C>(&mut self.stdout.lock()).is_err() { + if msg.write::<C>(&mut *self.stdout).is_err() { return None; } - match bidirectional::BidirectionalMessage::read::<_, C>(&mut self.stdin.lock(), self.buf) { + match bidirectional::BidirectionalMessage::read::<C>(&mut *self.stdin, self.buf) { Ok(Some(msg)) => Some(msg), _ => None, } @@ -238,8 +244,8 @@ impl<C: Codec> proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandl fn handle_expand_ra<C: Codec>( srv: &proc_macro_srv::ProcMacroSrv<'_>, - stdin: &io::Stdin, - stdout: &io::Stdout, + stdin: &mut (dyn BufRead + Send + Sync), + stdout: &mut (dyn Write + Send + Sync), buf: &mut C::Buf, task: bidirectional::ExpandMacro, ) -> io::Result<()> { @@ -301,10 +307,13 @@ fn handle_expand_ra<C: Codec>( .map(|(tree, span_data_table)| bidirectional::ExpandMacroExtended { tree, span_data_table }) .map_err(|e| legacy::PanicMessage(e.into_string().unwrap_or_default())); - send_response::<C>(&stdout, bidirectional::Response::ExpandMacroExtended(res)) + send_response::<C>(stdout, bidirectional::Response::ExpandMacroExtended(res)) } -fn run_<C: Codec>() -> io::Result<()> { +fn run_old<C: Codec>( + stdin: &mut (dyn BufRead + Send + Sync), + stdout: &mut (dyn Write + Send + Sync), +) -> io::Result<()> { fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::ProcMacroKind { match kind { proc_macro_srv::ProcMacroKind::CustomDerive => { @@ -316,8 +325,8 @@ fn run_<C: Codec>() -> io::Result<()> { } let mut buf = C::Buf::default(); - let mut read_request = || legacy::Request::read::<_, C>(&mut io::stdin().lock(), &mut buf); - let write_response = |msg: legacy::Response| msg.write::<_, C>(&mut io::stdout().lock()); + let mut read_request = || legacy::Request::read::<C>(stdin, &mut buf); + let mut write_response = |msg: legacy::Response| msg.write::<C>(stdout); let env = EnvSnapshot::default(); let srv = proc_macro_srv::ProcMacroSrv::new(&env); @@ -446,7 +455,10 @@ fn run_<C: Codec>() -> io::Result<()> { Ok(()) } -fn send_response<C: Codec>(stdout: &io::Stdout, resp: bidirectional::Response) -> io::Result<()> { +fn send_response<C: Codec>( + stdout: &mut dyn Write, + resp: bidirectional::Response, +) -> io::Result<()> { let resp = bidirectional::BidirectionalMessage::Response(resp); - resp.write::<_, C>(&mut stdout.lock()) + resp.write::<C>(stdout) } |