Unnamed repository; edit this file 'description' to name the repository.
Abstract proc-macro-srv input and output away
Lukas Wirth 3 months ago
parent 9c0cfc5 · commit 4341269
-rw-r--r--crates/proc-macro-api/src/legacy_protocol.rs4
-rw-r--r--crates/proc-macro-api/src/legacy_protocol/msg.rs4
-rw-r--r--crates/proc-macro-api/src/lib.rs41
-rw-r--r--crates/proc-macro-api/src/process.rs138
-rw-r--r--crates/proc-macro-srv-cli/src/main.rs12
-rw-r--r--crates/proc-macro-srv-cli/src/main_loop.rs72
6 files changed, 193 insertions, 78 deletions
diff --git a/crates/proc-macro-api/src/legacy_protocol.rs b/crates/proc-macro-api/src/legacy_protocol.rs
index 22a7d9868e..4524d1b66b 100644
--- a/crates/proc-macro-api/src/legacy_protocol.rs
+++ b/crates/proc-macro-api/src/legacy_protocol.rs
@@ -162,11 +162,11 @@ fn send_request<P: Codec>(
req: Request,
buf: &mut P::Buf,
) -> Result<Option<Response>, ServerError> {
- req.write::<_, P>(&mut writer).map_err(|err| ServerError {
+ req.write::<P>(&mut writer).map_err(|err| ServerError {
message: "failed to write request".into(),
io: Some(Arc::new(err)),
})?;
- let res = Response::read::<_, P>(&mut reader, buf).map_err(|err| ServerError {
+ let res = Response::read::<P>(&mut reader, buf).map_err(|err| ServerError {
message: "failed to read response".into(),
io: Some(Arc::new(err)),
})?;
diff --git a/crates/proc-macro-api/src/legacy_protocol/msg.rs b/crates/proc-macro-api/src/legacy_protocol/msg.rs
index 4146b619ec..1b65906933 100644
--- a/crates/proc-macro-api/src/legacy_protocol/msg.rs
+++ b/crates/proc-macro-api/src/legacy_protocol/msg.rs
@@ -155,13 +155,13 @@ impl ExpnGlobals {
}
pub trait Message: serde::Serialize + DeserializeOwned {
- fn read<R: BufRead, C: Codec>(inp: &mut R, buf: &mut C::Buf) -> io::Result<Option<Self>> {
+ fn read<C: Codec>(inp: &mut dyn BufRead, buf: &mut C::Buf) -> io::Result<Option<Self>> {
Ok(match C::read(inp, buf)? {
None => None,
Some(buf) => Some(C::decode(buf)?),
})
}
- fn write<W: Write, C: Codec>(self, out: &mut W) -> io::Result<()> {
+ fn write<C: Codec>(self, out: &mut dyn Write) -> io::Result<()> {
let value = C::encode(&self)?;
C::write(out, &value)
}
diff --git a/crates/proc-macro-api/src/lib.rs b/crates/proc-macro-api/src/lib.rs
index f5fcc99f14..98ee6817c2 100644
--- a/crates/proc-macro-api/src/lib.rs
+++ b/crates/proc-macro-api/src/lib.rs
@@ -18,7 +18,7 @@ extern crate rustc_driver as _;
pub mod bidirectional_protocol;
pub mod legacy_protocol;
-mod process;
+pub mod process;
pub mod transport;
use paths::{AbsPath, AbsPathBuf};
@@ -44,6 +44,25 @@ pub mod version {
pub const CURRENT_API_VERSION: u32 = HASHED_AST_ID;
}
+#[derive(Copy, Clone)]
+pub enum ProtocolFormat {
+ JsonLegacy,
+ PostcardLegacy,
+ BidirectionalPostcardPrototype,
+}
+
+impl fmt::Display for ProtocolFormat {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ ProtocolFormat::JsonLegacy => write!(f, "json-legacy"),
+ ProtocolFormat::PostcardLegacy => write!(f, "postcard-legacy"),
+ ProtocolFormat::BidirectionalPostcardPrototype => {
+ write!(f, "bidirectional-postcard-prototype")
+ }
+ }
+ }
+}
+
/// Represents different kinds of procedural macros that can be expanded by the external server.
#[derive(Copy, Clone, Eq, PartialEq, Debug, serde_derive::Serialize, serde_derive::Deserialize)]
pub enum ProcMacroKind {
@@ -132,7 +151,25 @@ impl ProcMacroClient {
> + Clone,
version: Option<&Version>,
) -> io::Result<ProcMacroClient> {
- let process = ProcMacroServerProcess::run(process_path, env, version)?;
+ let process = ProcMacroServerProcess::spawn(process_path, env, version)?;
+ Ok(ProcMacroClient { process: Arc::new(process), path: process_path.to_owned() })
+ }
+
+ /// Invokes `spawn` and returns a client connected to the resulting read and write handles.
+ ///
+ /// The `process_path` is used for `Self::server_path`. This function is mainly used for testing.
+ pub fn with_io_channels(
+ process_path: &AbsPath,
+ spawn: impl Fn(
+ Option<ProtocolFormat>,
+ ) -> io::Result<(
+ Box<dyn process::ProcessExit>,
+ Box<dyn io::Write + Send + Sync>,
+ Box<dyn io::BufRead + Send + Sync>,
+ )>,
+ version: Option<&Version>,
+ ) -> io::Result<ProcMacroClient> {
+ let process = ProcMacroServerProcess::run(spawn, version, || "<unknown>".to_owned())?;
Ok(ProcMacroClient { process: Arc::new(process), path: process_path.to_owned() })
}
diff --git a/crates/proc-macro-api/src/process.rs b/crates/proc-macro-api/src/process.rs
index f6a656e3ce..4f87621587 100644
--- a/crates/proc-macro-api/src/process.rs
+++ b/crates/proc-macro-api/src/process.rs
@@ -13,14 +13,13 @@ use span::Span;
use stdx::JodChild;
use crate::{
- Codec, ProcMacro, ProcMacroKind, ServerError,
+ Codec, ProcMacro, ProcMacroKind, ProtocolFormat, ServerError,
bidirectional_protocol::{self, SubCallback, msg::BidirectionalMessage, reject_subrequests},
legacy_protocol::{self, SpanMode},
version,
};
/// Represents a process handling proc-macro communication.
-#[derive(Debug)]
pub(crate) struct ProcMacroServerProcess {
/// The state of the proc-macro server process, the protocol is currently strictly sequential
/// hence the lock on the state.
@@ -31,6 +30,16 @@ pub(crate) struct ProcMacroServerProcess {
exited: OnceLock<AssertUnwindSafe<ServerError>>,
}
+impl std::fmt::Debug for ProcMacroServerProcess {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("ProcMacroServerProcess")
+ .field("version", &self.version)
+ .field("protocol", &self.protocol)
+ .field("exited", &self.exited)
+ .finish()
+ }
+}
+
#[derive(Debug, Clone)]
pub(crate) enum Protocol {
LegacyJson { mode: SpanMode },
@@ -38,23 +47,84 @@ pub(crate) enum Protocol {
BidirectionalPostcardPrototype { mode: SpanMode },
}
+pub trait ProcessExit: Send + Sync {
+ fn exit_err(&mut self) -> Option<ServerError>;
+}
+
+impl ProcessExit for Process {
+ fn exit_err(&mut self) -> Option<ServerError> {
+ match self.child.try_wait() {
+ Ok(None) | Err(_) => None,
+ Ok(Some(status)) => {
+ let mut msg = String::new();
+ if !status.success()
+ && let Some(stderr) = self.child.stderr.as_mut()
+ {
+ _ = stderr.read_to_string(&mut msg);
+ }
+ Some(ServerError {
+ message: format!(
+ "proc-macro server exited with {status}{}{msg}",
+ if msg.is_empty() { "" } else { ": " }
+ ),
+ io: None,
+ })
+ }
+ }
+ }
+}
+
/// Maintains the state of the proc-macro server process.
-#[derive(Debug)]
struct ProcessSrvState {
- process: Process,
- stdin: ChildStdin,
- stdout: BufReader<ChildStdout>,
+ process: Box<dyn ProcessExit>,
+ stdin: Box<dyn Write + Send + Sync>,
+ stdout: Box<dyn BufRead + Send + Sync>,
}
impl ProcMacroServerProcess {
/// Starts the proc-macro server and performs a version check
- pub(crate) fn run<'a>(
+ pub(crate) fn spawn<'a>(
process_path: &AbsPath,
env: impl IntoIterator<
Item = (impl AsRef<std::ffi::OsStr>, &'a Option<impl 'a + AsRef<std::ffi::OsStr>>),
> + Clone,
version: Option<&Version>,
) -> io::Result<ProcMacroServerProcess> {
+ Self::run(
+ |format| {
+ let mut process = Process::run(
+ process_path,
+ env.clone(),
+ format.map(|format| format.to_string()).as_deref(),
+ )?;
+ let (stdin, stdout) = process.stdio().expect("couldn't access child stdio");
+
+ Ok((Box::new(process), Box::new(stdin), Box::new(stdout)))
+ },
+ version,
+ || {
+ #[expect(clippy::disallowed_methods)]
+ Command::new(process_path)
+ .arg("--version")
+ .output()
+ .map(|output| String::from_utf8_lossy(&output.stdout).trim().to_owned())
+ .unwrap_or_else(|_| "unknown version".to_owned())
+ },
+ )
+ }
+
+ /// Invokes `spawn` and performs a version check.
+ pub(crate) fn run(
+ spawn: impl Fn(
+ Option<ProtocolFormat>,
+ ) -> io::Result<(
+ Box<dyn ProcessExit>,
+ Box<dyn Write + Send + Sync>,
+ Box<dyn BufRead + Send + Sync>,
+ )>,
+ version: Option<&Version>,
+ binary_server_version: impl Fn() -> String,
+ ) -> io::Result<ProcMacroServerProcess> {
const VERSION: Version = Version::new(1, 93, 0);
// we do `>` for nightly as this started working in the middle of the 1.93 nightly release, so we dont want to break on half of the nightlies
let has_working_format_flag = version.map_or(false, |v| {
@@ -65,27 +135,33 @@ impl ProcMacroServerProcess {
&& has_working_format_flag
{
&[
- (
- Some("bidirectional-postcard-prototype"),
- Protocol::BidirectionalPostcardPrototype { mode: SpanMode::Id },
- ),
- (Some("postcard-legacy"), Protocol::LegacyPostcard { mode: SpanMode::Id }),
- (Some("json-legacy"), Protocol::LegacyJson { mode: SpanMode::Id }),
+ Some(ProtocolFormat::BidirectionalPostcardPrototype),
+ Some(ProtocolFormat::PostcardLegacy),
+ Some(ProtocolFormat::JsonLegacy),
]
} else {
- &[(None, Protocol::LegacyJson { mode: SpanMode::Id })]
+ &[None]
};
let mut err = None;
- for &(format, ref protocol) in formats {
+ for &format in formats {
let create_srv = || {
- let mut process = Process::run(process_path, env.clone(), format)?;
- let (stdin, stdout) = process.stdio().expect("couldn't access child stdio");
+ let (process, stdin, stdout) = spawn(format)?;
io::Result::Ok(ProcMacroServerProcess {
state: Mutex::new(ProcessSrvState { process, stdin, stdout }),
version: 0,
- protocol: protocol.clone(),
+ protocol: match format {
+ Some(ProtocolFormat::BidirectionalPostcardPrototype) => {
+ Protocol::BidirectionalPostcardPrototype { mode: SpanMode::Id }
+ }
+ Some(ProtocolFormat::PostcardLegacy) => {
+ Protocol::LegacyPostcard { mode: SpanMode::Id }
+ }
+ Some(ProtocolFormat::JsonLegacy) | None => {
+ Protocol::LegacyJson { mode: SpanMode::Id }
+ }
+ },
exited: OnceLock::new(),
})
};
@@ -93,12 +169,7 @@ impl ProcMacroServerProcess {
tracing::info!("sending proc-macro server version check");
match srv.version_check(Some(&mut reject_subrequests)) {
Ok(v) if v > version::CURRENT_API_VERSION => {
- #[allow(clippy::disallowed_methods)]
- let process_version = Command::new(process_path)
- .arg("--version")
- .output()
- .map(|output| String::from_utf8_lossy(&output.stdout).trim().to_owned())
- .unwrap_or_else(|_| "unknown version".to_owned());
+ let process_version = binary_server_version();
err = Some(io::Error::other(format!(
"Your installed proc-macro server is too new for your rust-analyzer. API version: {}, server version: {process_version}. \
This will prevent proc-macro expansion from working. Please consider updating your rust-analyzer to ensure compatibility with your current toolchain.",
@@ -275,22 +346,9 @@ impl ProcMacroServerProcess {
f(&mut state.stdin, &mut state.stdout, &mut buf).map_err(|e| {
if e.io.as_ref().map(|it| it.kind()) == Some(io::ErrorKind::BrokenPipe) {
- match state.process.child.try_wait() {
- Ok(None) | Err(_) => e,
- Ok(Some(status)) => {
- let mut msg = String::new();
- if !status.success()
- && let Some(stderr) = state.process.child.stderr.as_mut()
- {
- _ = stderr.read_to_string(&mut msg);
- }
- let server_error = ServerError {
- message: format!(
- "proc-macro server exited with {status}{}{msg}",
- if msg.is_empty() { "" } else { ": " }
- ),
- io: None,
- };
+ match state.process.exit_err() {
+ None => e,
+ Some(server_error) => {
self.exited.get_or_init(|| AssertUnwindSafe(server_error)).0.clone()
}
}
diff --git a/crates/proc-macro-srv-cli/src/main.rs b/crates/proc-macro-srv-cli/src/main.rs
index bdfdb50002..189a1eea5c 100644
--- a/crates/proc-macro-srv-cli/src/main.rs
+++ b/crates/proc-macro-srv-cli/src/main.rs
@@ -45,7 +45,11 @@ fn main() -> std::io::Result<()> {
}
let &format =
matches.get_one::<ProtocolFormat>("format").expect("format value should always be present");
- run(format)
+
+ let mut stdin = std::io::BufReader::new(std::io::stdin());
+ let mut stdout = std::io::stdout();
+
+ run(&mut stdin, &mut stdout, format)
}
#[derive(Copy, Clone)]
@@ -88,7 +92,11 @@ impl ValueEnum for ProtocolFormat {
}
#[cfg(not(feature = "sysroot-abi"))]
-fn run(_: ProtocolFormat) -> std::io::Result<()> {
+fn run(
+ _: &mut std::io::BufReader<std::io::Stdin>,
+ _: &mut std::io::Stdout,
+ _: ProtocolFormat,
+) -> std::io::Result<()> {
Err(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"proc-macro-srv-cli needs to be compiled with the `sysroot-abi` feature to function"
diff --git a/crates/proc-macro-srv-cli/src/main_loop.rs b/crates/proc-macro-srv-cli/src/main_loop.rs
index 22536a4e52..0c651d22b4 100644
--- a/crates/proc-macro-srv-cli/src/main_loop.rs
+++ b/crates/proc-macro-srv-cli/src/main_loop.rs
@@ -6,7 +6,7 @@ use proc_macro_api::{
transport::codec::{json::JsonProtocol, postcard::PostcardProtocol},
version::CURRENT_API_VERSION,
};
-use std::io;
+use std::io::{self, BufRead, Write};
use legacy::Message;
@@ -32,15 +32,24 @@ impl legacy::SpanTransformer for SpanTrans {
}
}
-pub(crate) fn run(format: ProtocolFormat) -> io::Result<()> {
+pub(crate) fn run(
+ stdin: &mut (dyn BufRead + Send + Sync),
+ stdout: &mut (dyn Write + Send + Sync),
+ format: ProtocolFormat,
+) -> io::Result<()> {
match format {
- ProtocolFormat::JsonLegacy => run_::<JsonProtocol>(),
- ProtocolFormat::PostcardLegacy => run_::<PostcardProtocol>(),
- ProtocolFormat::BidirectionalPostcardPrototype => run_new::<PostcardProtocol>(),
+ ProtocolFormat::JsonLegacy => run_old::<JsonProtocol>(stdin, stdout),
+ ProtocolFormat::PostcardLegacy => run_old::<PostcardProtocol>(stdin, stdout),
+ ProtocolFormat::BidirectionalPostcardPrototype => {
+ run_new::<PostcardProtocol>(stdin, stdout)
+ }
}
}
-fn run_new<C: Codec>() -> io::Result<()> {
+fn run_new<C: Codec>(
+ stdin: &mut (dyn BufRead + Send + Sync),
+ stdout: &mut (dyn Write + Send + Sync),
+) -> io::Result<()> {
fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::ProcMacroKind {
match kind {
proc_macro_srv::ProcMacroKind::CustomDerive => {
@@ -52,8 +61,6 @@ fn run_new<C: Codec>() -> io::Result<()> {
}
let mut buf = C::Buf::default();
- let mut stdin = io::stdin();
- let mut stdout = io::stdout();
let env_snapshot = EnvSnapshot::default();
let srv = proc_macro_srv::ProcMacroSrv::new(&env_snapshot);
@@ -61,8 +68,7 @@ fn run_new<C: Codec>() -> io::Result<()> {
let mut span_mode = legacy::SpanMode::Id;
'outer: loop {
- let req_opt =
- bidirectional::BidirectionalMessage::read::<_, C>(&mut stdin.lock(), &mut buf)?;
+ let req_opt = bidirectional::BidirectionalMessage::read::<C>(stdin, &mut buf)?;
let Some(req) = req_opt else {
break 'outer;
};
@@ -77,22 +83,22 @@ fn run_new<C: Codec>() -> io::Result<()> {
.collect()
});
- send_response::<C>(&stdout, bidirectional::Response::ListMacros(res))?;
+ send_response::<C>(stdout, bidirectional::Response::ListMacros(res))?;
}
bidirectional::Request::ApiVersionCheck {} => {
send_response::<C>(
- &stdout,
+ stdout,
bidirectional::Response::ApiVersionCheck(CURRENT_API_VERSION),
)?;
}
bidirectional::Request::SetConfig(config) => {
span_mode = config.span_mode;
- send_response::<C>(&stdout, bidirectional::Response::SetConfig(config))?;
+ send_response::<C>(stdout, bidirectional::Response::SetConfig(config))?;
}
bidirectional::Request::ExpandMacro(task) => {
- handle_expand::<C>(&srv, &mut stdin, &mut stdout, &mut buf, span_mode, *task)?;
+ handle_expand::<C>(&srv, stdin, stdout, &mut buf, span_mode, *task)?;
}
},
_ => continue,
@@ -104,8 +110,8 @@ fn run_new<C: Codec>() -> io::Result<()> {
fn handle_expand<C: Codec>(
srv: &proc_macro_srv::ProcMacroSrv<'_>,
- stdin: &io::Stdin,
- stdout: &io::Stdout,
+ stdin: &mut (dyn BufRead + Send + Sync),
+ stdout: &mut (dyn Write + Send + Sync),
buf: &mut C::Buf,
span_mode: legacy::SpanMode,
task: bidirectional::ExpandMacro,
@@ -118,7 +124,7 @@ fn handle_expand<C: Codec>(
fn handle_expand_id<C: Codec>(
srv: &proc_macro_srv::ProcMacroSrv<'_>,
- stdout: &io::Stdout,
+ stdout: &mut dyn Write,
task: bidirectional::ExpandMacro,
) -> io::Result<()> {
let bidirectional::ExpandMacro { lib, env, current_dir, data } = task;
@@ -157,12 +163,12 @@ fn handle_expand_id<C: Codec>(
})
.map_err(|e| legacy::PanicMessage(e.into_string().unwrap_or_default()));
- send_response::<C>(&stdout, bidirectional::Response::ExpandMacro(res))
+ send_response::<C>(stdout, bidirectional::Response::ExpandMacro(res))
}
struct ProcMacroClientHandle<'a, C: Codec> {
- stdin: &'a io::Stdin,
- stdout: &'a io::Stdout,
+ stdin: &'a mut (dyn BufRead + Send + Sync),
+ stdout: &'a mut (dyn Write + Send + Sync),
buf: &'a mut C::Buf,
}
@@ -173,11 +179,11 @@ impl<'a, C: Codec> ProcMacroClientHandle<'a, C> {
) -> Option<bidirectional::BidirectionalMessage> {
let msg = bidirectional::BidirectionalMessage::SubRequest(req);
- if msg.write::<_, C>(&mut self.stdout.lock()).is_err() {
+ if msg.write::<C>(&mut *self.stdout).is_err() {
return None;
}
- match bidirectional::BidirectionalMessage::read::<_, C>(&mut self.stdin.lock(), self.buf) {
+ match bidirectional::BidirectionalMessage::read::<C>(&mut *self.stdin, self.buf) {
Ok(Some(msg)) => Some(msg),
_ => None,
}
@@ -238,8 +244,8 @@ impl<C: Codec> proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandl
fn handle_expand_ra<C: Codec>(
srv: &proc_macro_srv::ProcMacroSrv<'_>,
- stdin: &io::Stdin,
- stdout: &io::Stdout,
+ stdin: &mut (dyn BufRead + Send + Sync),
+ stdout: &mut (dyn Write + Send + Sync),
buf: &mut C::Buf,
task: bidirectional::ExpandMacro,
) -> io::Result<()> {
@@ -301,10 +307,13 @@ fn handle_expand_ra<C: Codec>(
.map(|(tree, span_data_table)| bidirectional::ExpandMacroExtended { tree, span_data_table })
.map_err(|e| legacy::PanicMessage(e.into_string().unwrap_or_default()));
- send_response::<C>(&stdout, bidirectional::Response::ExpandMacroExtended(res))
+ send_response::<C>(stdout, bidirectional::Response::ExpandMacroExtended(res))
}
-fn run_<C: Codec>() -> io::Result<()> {
+fn run_old<C: Codec>(
+ stdin: &mut (dyn BufRead + Send + Sync),
+ stdout: &mut (dyn Write + Send + Sync),
+) -> io::Result<()> {
fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::ProcMacroKind {
match kind {
proc_macro_srv::ProcMacroKind::CustomDerive => {
@@ -316,8 +325,8 @@ fn run_<C: Codec>() -> io::Result<()> {
}
let mut buf = C::Buf::default();
- let mut read_request = || legacy::Request::read::<_, C>(&mut io::stdin().lock(), &mut buf);
- let write_response = |msg: legacy::Response| msg.write::<_, C>(&mut io::stdout().lock());
+ let mut read_request = || legacy::Request::read::<C>(stdin, &mut buf);
+ let mut write_response = |msg: legacy::Response| msg.write::<C>(stdout);
let env = EnvSnapshot::default();
let srv = proc_macro_srv::ProcMacroSrv::new(&env);
@@ -446,7 +455,10 @@ fn run_<C: Codec>() -> io::Result<()> {
Ok(())
}
-fn send_response<C: Codec>(stdout: &io::Stdout, resp: bidirectional::Response) -> io::Result<()> {
+fn send_response<C: Codec>(
+ stdout: &mut dyn Write,
+ resp: bidirectional::Response,
+) -> io::Result<()> {
let resp = bidirectional::BidirectionalMessage::Response(resp);
- resp.write::<_, C>(&mut stdout.lock())
+ resp.write::<C>(stdout)
}