Unnamed repository; edit this file 'description' to name the repository.
add codec and framing to abstract encoding and decoding logic from run
| -rw-r--r-- | crates/proc-macro-api/src/codec.rs | 12 | ||||
| -rw-r--r-- | crates/proc-macro-api/src/framing.rs | 14 | ||||
| -rw-r--r-- | crates/proc-macro-api/src/legacy_protocol.rs | 33 | ||||
| -rw-r--r-- | crates/proc-macro-api/src/legacy_protocol/json.rs | 74 | ||||
| -rw-r--r-- | crates/proc-macro-api/src/legacy_protocol/msg.rs | 56 | ||||
| -rw-r--r-- | crates/proc-macro-api/src/legacy_protocol/postcard.rs | 49 | ||||
| -rw-r--r-- | crates/proc-macro-api/src/lib.rs | 11 | ||||
| -rw-r--r-- | crates/proc-macro-api/src/process.rs | 13 | ||||
| -rw-r--r-- | crates/proc-macro-srv-cli/Cargo.toml | 2 | ||||
| -rw-r--r-- | crates/proc-macro-srv-cli/src/main_loop.rs | 148 |
10 files changed, 142 insertions, 270 deletions
diff --git a/crates/proc-macro-api/src/codec.rs b/crates/proc-macro-api/src/codec.rs new file mode 100644 index 0000000000..baccaa6be4 --- /dev/null +++ b/crates/proc-macro-api/src/codec.rs @@ -0,0 +1,12 @@ +//! Protocol codec + +use std::io; + +use serde::de::DeserializeOwned; + +use crate::framing::Framing; + +pub trait Codec: Framing { + fn encode<T: serde::Serialize>(msg: &T) -> io::Result<Self::Buf>; + fn decode<T: DeserializeOwned>(buf: &mut Self::Buf) -> io::Result<T>; +} diff --git a/crates/proc-macro-api/src/framing.rs b/crates/proc-macro-api/src/framing.rs new file mode 100644 index 0000000000..a1e6fc05ca --- /dev/null +++ b/crates/proc-macro-api/src/framing.rs @@ -0,0 +1,14 @@ +//! Protocol framing + +use std::io::{self, BufRead, Write}; + +pub trait Framing { + type Buf: Default; + + fn read<'a, R: BufRead>( + inp: &mut R, + buf: &'a mut Self::Buf, + ) -> io::Result<Option<&'a mut Self::Buf>>; + + fn write<W: Write>(out: &mut W, buf: &Self::Buf) -> io::Result<()>; +} diff --git a/crates/proc-macro-api/src/legacy_protocol.rs b/crates/proc-macro-api/src/legacy_protocol.rs index 6d521d00cd..c2b132ddcc 100644 --- a/crates/proc-macro-api/src/legacy_protocol.rs +++ b/crates/proc-macro-api/src/legacy_protocol.rs @@ -14,14 +14,15 @@ use span::Span; use crate::{ ProcMacro, ProcMacroKind, ServerError, + codec::Codec, legacy_protocol::{ - json::{read_json, write_json}, + json::JsonProtocol, msg::{ ExpandMacro, ExpandMacroData, ExpnGlobals, FlatTree, Message, Request, Response, ServerConfig, SpanDataIndexMap, deserialize_span_data_index_map, flat::serialize_span_data_index_map, }, - postcard::{read_postcard, write_postcard}, + postcard::PostcardProtocol, }, process::ProcMacroServerProcess, version, @@ -154,42 +155,26 @@ fn send_task(srv: &ProcMacroServerProcess, req: Request) -> Result<Response, Ser } if srv.use_postcard() { - srv.send_task(send_request_postcard, req) + srv.send_task(send_request::<PostcardProtocol>, req) } else { - srv.send_task(send_request, req) + srv.send_task(send_request::<JsonProtocol>, req) } } /// Sends a request to the server and reads the response. -fn send_request( +fn send_request<P: Codec>( mut writer: &mut dyn Write, mut reader: &mut dyn BufRead, req: Request, - buf: &mut String, + buf: &mut P::Buf, ) -> Result<Option<Response>, ServerError> { - req.write(write_json, &mut writer).map_err(|err| ServerError { + req.write::<_, P>(&mut writer).map_err(|err| ServerError { message: "failed to write request".into(), io: Some(Arc::new(err)), })?; - let res = Response::read(read_json, &mut reader, buf).map_err(|err| ServerError { + let res = Response::read::<_, P>(&mut reader, buf).map_err(|err| ServerError { message: "failed to read response".into(), io: Some(Arc::new(err)), })?; Ok(res) } - -fn send_request_postcard( - mut writer: &mut dyn Write, - mut reader: &mut dyn BufRead, - req: Request, - buf: &mut Vec<u8>, -) -> Result<Option<Response>, ServerError> { - req.write_postcard(write_postcard, &mut writer).map_err(|err| ServerError { - message: "failed to write request".into(), - io: Some(Arc::new(err)), - })?; - let res = Response::read_postcard(read_postcard, &mut reader, buf).map_err(|err| { - ServerError { message: "failed to read response".into(), io: Some(Arc::new(err)) } - })?; - Ok(res) -} diff --git a/crates/proc-macro-api/src/legacy_protocol/json.rs b/crates/proc-macro-api/src/legacy_protocol/json.rs index cf8535f77d..1359c05684 100644 --- a/crates/proc-macro-api/src/legacy_protocol/json.rs +++ b/crates/proc-macro-api/src/legacy_protocol/json.rs @@ -1,36 +1,58 @@ //! Protocol functions for json. use std::io::{self, BufRead, Write}; -/// Reads a JSON message from the input stream. -pub fn read_json<'a>( - inp: &mut impl BufRead, - buf: &'a mut String, -) -> io::Result<Option<&'a mut String>> { - loop { - buf.clear(); - - inp.read_line(buf)?; - buf.pop(); // Remove trailing '\n' - - if buf.is_empty() { - return Ok(None); - } +use serde::{Serialize, de::DeserializeOwned}; + +use crate::{codec::Codec, framing::Framing}; + +pub struct JsonProtocol; + +impl Framing for JsonProtocol { + type Buf = String; + + fn read<'a, R: BufRead>( + inp: &mut R, + buf: &'a mut String, + ) -> io::Result<Option<&'a mut String>> { + loop { + buf.clear(); + + inp.read_line(buf)?; + buf.pop(); // Remove trailing '\n' - // Some ill behaved macro try to use stdout for debugging - // We ignore it here - if !buf.starts_with('{') { - tracing::error!("proc-macro tried to print : {}", buf); - continue; + if buf.is_empty() { + return Ok(None); + } + + // Some ill behaved macro try to use stdout for debugging + // We ignore it here + if !buf.starts_with('{') { + tracing::error!("proc-macro tried to print : {}", buf); + continue; + } + + return Ok(Some(buf)); } + } - return Ok(Some(buf)); + fn write<W: Write>(out: &mut W, buf: &String) -> io::Result<()> { + tracing::debug!("> {}", buf); + out.write_all(buf.as_bytes())?; + out.write_all(b"\n")?; + out.flush() } } -/// Writes a JSON message to the output stream. -pub fn write_json(out: &mut impl Write, msg: &String) -> io::Result<()> { - tracing::debug!("> {}", msg); - out.write_all(msg.as_bytes())?; - out.write_all(b"\n")?; - out.flush() +impl Codec for JsonProtocol { + fn encode<T: Serialize>(msg: &T) -> io::Result<String> { + Ok(serde_json::to_string(msg)?) + } + + fn decode<T: DeserializeOwned>(buf: &mut String) -> io::Result<T> { + let mut deserializer = serde_json::Deserializer::from_str(buf); + // Note that some proc-macro generate very deep syntax tree + // We have to disable the current limit of serde here + deserializer.disable_recursion_limit(); + Ok(T::deserialize(&mut deserializer)?) + } } diff --git a/crates/proc-macro-api/src/legacy_protocol/msg.rs b/crates/proc-macro-api/src/legacy_protocol/msg.rs index 6df184630d..1c77863aac 100644 --- a/crates/proc-macro-api/src/legacy_protocol/msg.rs +++ b/crates/proc-macro-api/src/legacy_protocol/msg.rs @@ -8,10 +8,7 @@ use paths::Utf8PathBuf; use serde::de::DeserializeOwned; use serde_derive::{Deserialize, Serialize}; -use crate::{ - ProcMacroKind, - legacy_protocol::postcard::{decode_cobs, encode_cobs}, -}; +use crate::{ProcMacroKind, codec::Codec}; /// Represents requests sent from the client to the proc-macro-srv. #[derive(Debug, Serialize, Deserialize)] @@ -152,60 +149,21 @@ impl ExpnGlobals { } pub trait Message: serde::Serialize + DeserializeOwned { - fn read<R: BufRead>( - from_proto: ProtocolRead<R, String>, - inp: &mut R, - buf: &mut String, - ) -> io::Result<Option<Self>> { - Ok(match from_proto(inp, buf)? { + fn read<R: BufRead, C: Codec>(inp: &mut R, buf: &mut C::Buf) -> io::Result<Option<Self>> { + Ok(match C::read(inp, buf)? { None => None, - Some(text) => { - let mut deserializer = serde_json::Deserializer::from_str(text); - // Note that some proc-macro generate very deep syntax tree - // We have to disable the current limit of serde here - deserializer.disable_recursion_limit(); - Some(Self::deserialize(&mut deserializer)?) - } - }) - } - fn write<W: Write>(self, to_proto: ProtocolWrite<W, String>, out: &mut W) -> io::Result<()> { - let text = serde_json::to_string(&self)?; - to_proto(out, &text) - } - - fn read_postcard<R: BufRead>( - from_proto: ProtocolRead<R, Vec<u8>>, - inp: &mut R, - buf: &mut Vec<u8>, - ) -> io::Result<Option<Self>> { - Ok(match from_proto(inp, buf)? { - None => None, - Some(buf) => Some(decode_cobs(buf)?), + Some(buf) => C::decode(buf)?, }) } - - fn write_postcard<W: Write>( - self, - to_proto: ProtocolWrite<W, Vec<u8>>, - out: &mut W, - ) -> io::Result<()> { - let buf = encode_cobs(&self)?; - to_proto(out, &buf) + fn write<W: Write, C: Codec>(self, out: &mut W) -> io::Result<()> { + let value = C::encode(&self)?; + C::write(out, &value) } } impl Message for Request {} impl Message for Response {} -/// Type alias for a function that reads protocol messages from a buffered input stream. -#[allow(type_alias_bounds)] -type ProtocolRead<R: BufRead, Buf> = - for<'i, 'buf> fn(inp: &'i mut R, buf: &'buf mut Buf) -> io::Result<Option<&'buf mut Buf>>; -/// Type alias for a function that writes protocol messages to an output stream. -#[allow(type_alias_bounds)] -type ProtocolWrite<W: Write, Buf> = - for<'o, 'msg> fn(out: &'o mut W, msg: &'msg Buf) -> io::Result<()>; - #[cfg(test)] mod tests { use intern::{Symbol, sym}; diff --git a/crates/proc-macro-api/src/legacy_protocol/postcard.rs b/crates/proc-macro-api/src/legacy_protocol/postcard.rs index 305e4de934..c28a9bfe3a 100644 --- a/crates/proc-macro-api/src/legacy_protocol/postcard.rs +++ b/crates/proc-macro-api/src/legacy_protocol/postcard.rs @@ -2,28 +2,39 @@ use std::io::{self, BufRead, Write}; -pub fn read_postcard<'a>( - input: &mut impl BufRead, - buf: &'a mut Vec<u8>, -) -> io::Result<Option<&'a mut Vec<u8>>> { - buf.clear(); - let n = input.read_until(0, buf)?; - if n == 0 { - return Ok(None); +use serde::{Serialize, de::DeserializeOwned}; + +use crate::{codec::Codec, framing::Framing}; + +pub struct PostcardProtocol; + +impl Framing for PostcardProtocol { + type Buf = Vec<u8>; + + fn read<'a, R: BufRead>( + inp: &mut R, + buf: &'a mut Vec<u8>, + ) -> io::Result<Option<&'a mut Vec<u8>>> { + buf.clear(); + let n = inp.read_until(0, buf)?; + if n == 0 { + return Ok(None); + } + Ok(Some(buf)) } - Ok(Some(buf)) -} -#[allow(clippy::ptr_arg)] -pub fn write_postcard(out: &mut impl Write, msg: &Vec<u8>) -> io::Result<()> { - out.write_all(msg)?; - out.flush() + fn write<W: Write>(out: &mut W, buf: &Vec<u8>) -> io::Result<()> { + out.write_all(buf)?; + out.flush() + } } -pub fn encode_cobs<T: serde::Serialize>(value: &T) -> io::Result<Vec<u8>> { - postcard::to_allocvec_cobs(value).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) -} +impl Codec for PostcardProtocol { + fn encode<T: Serialize>(msg: &T) -> io::Result<Vec<u8>> { + postcard::to_allocvec_cobs(msg).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) + } -pub fn decode_cobs<T: serde::de::DeserializeOwned>(bytes: &mut [u8]) -> io::Result<T> { - postcard::from_bytes_cobs(bytes).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) + fn decode<T: DeserializeOwned>(buf: &mut Self::Buf) -> io::Result<T> { + postcard::from_bytes_cobs(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) + } } diff --git a/crates/proc-macro-api/src/lib.rs b/crates/proc-macro-api/src/lib.rs index 2cdb33ff81..a725b94f04 100644 --- a/crates/proc-macro-api/src/lib.rs +++ b/crates/proc-macro-api/src/lib.rs @@ -12,6 +12,8 @@ )] #![allow(internal_features)] +mod codec; +mod framing; pub mod legacy_protocol; mod process; @@ -19,7 +21,8 @@ use paths::{AbsPath, AbsPathBuf}; use span::{ErasedFileAstId, FIXUP_ERASED_FILE_AST_ID_MARKER, Span}; use std::{fmt, io, sync::Arc, time::SystemTime}; -use crate::process::ProcMacroServerProcess; +pub use crate::codec::Codec; +use crate::{legacy_protocol::SpanMode, process::ProcMacroServerProcess}; /// The versions of the server protocol pub mod version { @@ -123,7 +126,11 @@ impl ProcMacroClient { Item = (impl AsRef<std::ffi::OsStr>, &'a Option<impl 'a + AsRef<std::ffi::OsStr>>), > + Clone, ) -> io::Result<ProcMacroClient> { - let process = ProcMacroServerProcess::run(process_path, env, process::Protocol::default())?; + let process = ProcMacroServerProcess::run( + process_path, + env, + process::Protocol::Postcard { mode: SpanMode::Id }, + )?; Ok(ProcMacroClient { process: Arc::new(process), path: process_path.to_owned() }) } diff --git a/crates/proc-macro-api/src/process.rs b/crates/proc-macro-api/src/process.rs index 7f0cd05c80..1365245f98 100644 --- a/crates/proc-macro-api/src/process.rs +++ b/crates/proc-macro-api/src/process.rs @@ -34,12 +34,6 @@ pub(crate) enum Protocol { Postcard { mode: SpanMode }, } -impl Default for Protocol { - fn default() -> Self { - Protocol::Postcard { mode: SpanMode::Id } - } -} - /// Maintains the state of the proc-macro server process. #[derive(Debug)] struct ProcessSrvState { @@ -122,11 +116,10 @@ impl ProcMacroServerProcess { srv.version = version; if version >= version::RUST_ANALYZER_SPAN_SUPPORT - && let Ok(mode) = srv.enable_rust_analyzer_spans() + && let Ok(new_mode) = srv.enable_rust_analyzer_spans() { - srv.protocol = match protocol { - Protocol::Postcard { .. } => Protocol::Postcard { mode }, - Protocol::LegacyJson { .. } => Protocol::LegacyJson { mode }, + match &mut srv.protocol { + Protocol::Postcard { mode } | Protocol::LegacyJson { mode } => *mode = new_mode, }; } diff --git a/crates/proc-macro-srv-cli/Cargo.toml b/crates/proc-macro-srv-cli/Cargo.toml index f6022cf2c7..aa153897fa 100644 --- a/crates/proc-macro-srv-cli/Cargo.toml +++ b/crates/proc-macro-srv-cli/Cargo.toml @@ -18,7 +18,7 @@ postcard.workspace = true clap = {version = "4.5.42", default-features = false, features = ["std"]} [features] -default = ["postcard"] +default = [] sysroot-abi = ["proc-macro-srv/sysroot-abi", "proc-macro-api/sysroot-abi"] in-rust-tree = ["proc-macro-srv/in-rust-tree", "sysroot-abi"] diff --git a/crates/proc-macro-srv-cli/src/main_loop.rs b/crates/proc-macro-srv-cli/src/main_loop.rs index b0e7108d20..029ab6eca9 100644 --- a/crates/proc-macro-srv-cli/src/main_loop.rs +++ b/crates/proc-macro-srv-cli/src/main_loop.rs @@ -2,13 +2,14 @@ use std::io; use proc_macro_api::{ + Codec, legacy_protocol::{ - json::{read_json, write_json}, + json::JsonProtocol, msg::{ self, ExpandMacroData, ExpnGlobals, Message, SpanMode, SpanTransformer, deserialize_span_data_index_map, serialize_span_data_index_map, }, - postcard::{read_postcard, write_postcard}, + postcard::PostcardProtocol, }, version::CURRENT_API_VERSION, }; @@ -36,12 +37,12 @@ impl SpanTransformer for SpanTrans { pub(crate) fn run(format: ProtocolFormat) -> io::Result<()> { match format { - ProtocolFormat::Json => run_json(), - ProtocolFormat::Postcard => run_postcard(), + ProtocolFormat::Json => run_::<JsonProtocol>(), + ProtocolFormat::Postcard => run_::<PostcardProtocol>(), } } -fn run_json() -> io::Result<()> { +fn run_<C: Codec>() -> io::Result<()> { fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::ProcMacroKind { match kind { proc_macro_srv::ProcMacroKind::CustomDerive => { @@ -52,9 +53,9 @@ fn run_json() -> io::Result<()> { } } - let mut buf = String::new(); - let mut read_request = || msg::Request::read(read_json, &mut io::stdin().lock(), &mut buf); - let write_response = |msg: msg::Response| msg.write(write_json, &mut io::stdout().lock()); + let mut buf = C::Buf::default(); + let mut read_request = || msg::Request::read::<_, C>(&mut io::stdin().lock(), &mut buf); + let write_response = |msg: msg::Response| msg.write::<_, C>(&mut io::stdout().lock()); let env = EnvSnapshot::default(); let srv = proc_macro_srv::ProcMacroSrv::new(&env); @@ -170,134 +171,3 @@ fn run_json() -> io::Result<()> { Ok(()) } - -fn run_postcard() -> io::Result<()> { - fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::ProcMacroKind { - match kind { - proc_macro_srv::ProcMacroKind::CustomDerive => { - proc_macro_api::ProcMacroKind::CustomDerive - } - proc_macro_srv::ProcMacroKind::Bang => proc_macro_api::ProcMacroKind::Bang, - proc_macro_srv::ProcMacroKind::Attr => proc_macro_api::ProcMacroKind::Attr, - } - } - - let mut buf = Vec::new(); - let mut read_request = - || msg::Request::read_postcard(read_postcard, &mut io::stdin().lock(), &mut buf); - let write_response = - |msg: msg::Response| msg.write_postcard(write_postcard, &mut io::stdout().lock()); - - let env = proc_macro_srv::EnvSnapshot::default(); - let srv = proc_macro_srv::ProcMacroSrv::new(&env); - - let mut span_mode = msg::SpanMode::Id; - - while let Some(req) = read_request()? { - let res = match req { - msg::Request::ListMacros { dylib_path } => { - msg::Response::ListMacros(srv.list_macros(&dylib_path).map(|macros| { - macros.into_iter().map(|(name, kind)| (name, macro_kind_to_api(kind))).collect() - })) - } - msg::Request::ExpandMacro(task) => { - let msg::ExpandMacro { - lib, - env, - current_dir, - data: - msg::ExpandMacroData { - macro_body, - macro_name, - attributes, - has_global_spans: - msg::ExpnGlobals { serialize: _, def_site, call_site, mixed_site }, - span_data_table, - }, - } = *task; - match span_mode { - msg::SpanMode::Id => msg::Response::ExpandMacro({ - let def_site = proc_macro_srv::SpanId(def_site as u32); - let call_site = proc_macro_srv::SpanId(call_site as u32); - let mixed_site = proc_macro_srv::SpanId(mixed_site as u32); - - let macro_body = - macro_body.to_subtree_unresolved::<SpanTrans>(CURRENT_API_VERSION); - let attributes = attributes - .map(|it| it.to_subtree_unresolved::<SpanTrans>(CURRENT_API_VERSION)); - - srv.expand( - lib, - &env, - current_dir, - ¯o_name, - macro_body, - attributes, - def_site, - call_site, - mixed_site, - ) - .map(|it| { - msg::FlatTree::new_raw::<SpanTrans>( - tt::SubtreeView::new(&it), - CURRENT_API_VERSION, - ) - }) - .map_err(|e| e.into_string().unwrap_or_default()) - .map_err(msg::PanicMessage) - }), - msg::SpanMode::RustAnalyzer => msg::Response::ExpandMacroExtended({ - let mut span_data_table = - msg::deserialize_span_data_index_map(&span_data_table); - - let def_site = span_data_table[def_site]; - let call_site = span_data_table[call_site]; - let mixed_site = span_data_table[mixed_site]; - - let macro_body = - macro_body.to_subtree_resolved(CURRENT_API_VERSION, &span_data_table); - let attributes = attributes.map(|it| { - it.to_subtree_resolved(CURRENT_API_VERSION, &span_data_table) - }); - srv.expand( - lib, - &env, - current_dir, - ¯o_name, - macro_body, - attributes, - def_site, - call_site, - mixed_site, - ) - .map(|it| { - ( - msg::FlatTree::new( - tt::SubtreeView::new(&it), - CURRENT_API_VERSION, - &mut span_data_table, - ), - msg::serialize_span_data_index_map(&span_data_table), - ) - }) - .map(|(tree, span_data_table)| msg::ExpandMacroExtended { - tree, - span_data_table, - }) - .map_err(|e| e.into_string().unwrap_or_default()) - .map_err(msg::PanicMessage) - }), - } - } - msg::Request::ApiVersionCheck {} => msg::Response::ApiVersionCheck(CURRENT_API_VERSION), - msg::Request::SetConfig(config) => { - span_mode = config.span_mode; - msg::Response::SetConfig(config) - } - }; - - write_response(res)?; - } - - Ok(()) -} |