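//! Client-side methods for the bidirectional proc-macro server protocol, in which the server
//! may send sub-requests back to the client while one of the client's requests is in flight.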
use std::{
io::{self, BufRead, Write},
sync::Arc,
};
use paths::AbsPath;
use span::Span;
use crate::{
Codec, ProcMacro, ProcMacroKind, ServerError,
bidirectional_protocol::msg::{
BidirectionalMessage, ExpandMacro, ExpandMacroData, ExpnGlobals, Request, Response,
SubRequest, SubResponse,
},
legacy_protocol::{
SpanMode,
msg::{
FlatTree, ServerConfig, SpanDataIndexMap, deserialize_span_data_index_map,
serialize_span_data_index_map,
},
},
process::ProcMacroServerProcess,
transport::codec::postcard::PostcardProtocol,
version,
};
pub mod msg;
/// Handler invoked for each [`SubRequest`] the server sends while a request is in flight.
///
/// It must produce the [`SubResponse`] that is written back to the server; returning an
/// error ends the conversation with that error (see `run_conversation`).
pub type SubCallback<'a> = &'a mut dyn FnMut(SubRequest) -> Result<SubResponse, ServerError>;
/// Runs one request/response exchange over a bidirectional connection.
///
/// `msg` is encoded with codec `C` and written to `writer`. Messages are then read from
/// `reader` (with `buf` as the framing buffer) until the server sends a
/// [`BidirectionalMessage::Response`], which is returned. Every
/// [`BidirectionalMessage::SubRequest`] received in the meantime is handed to `callback`,
/// and its result is written back as a [`BidirectionalMessage::SubResponse`].
///
/// # Errors
///
/// Fails if encoding, decoding, reading or writing fails; if the stream ends before a
/// response arrives; if `callback` returns an error (no reply is sent in that case); or if
/// the server sends any other kind of message.
pub fn run_conversation<C: Codec>(
    writer: &mut dyn Write,
    reader: &mut dyn BufRead,
    buf: &mut C::Buf,
    msg: BidirectionalMessage,
    callback: SubCallback<'_>,
) -> Result<BidirectionalMessage, ServerError> {
    let first_frame = C::encode(&msg).map_err(wrap_encode)?;
    C::write(writer, &first_frame).map_err(wrap_io("failed to write initial request"))?;

    loop {
        let Some(frame) = C::read(reader, buf).map_err(wrap_io("failed to read message"))?
        else {
            // End of stream before a final response: the server went away mid-conversation.
            return Err(ServerError {
                message: "proc-macro server closed the stream".into(),
                io: Some(Arc::new(io::Error::new(io::ErrorKind::UnexpectedEof, "closed"))),
            });
        };

        let incoming: BidirectionalMessage = C::decode(frame).map_err(wrap_decode)?;
        let sub_request = match incoming {
            BidirectionalMessage::Response(response) => {
                return Ok(BidirectionalMessage::Response(response));
            }
            BidirectionalMessage::SubRequest(sub_request) => sub_request,
            other => {
                return Err(ServerError {
                    message: format!("unexpected message {:?}", other),
                    io: None,
                });
            }
        };

        // Answer the server's callback, then keep waiting for the final response.
        let answer = BidirectionalMessage::SubResponse(callback(sub_request)?);
        let answer_frame = C::encode(&answer).map_err(wrap_encode)?;
        C::write(writer, &answer_frame).map_err(wrap_io("failed to write sub-response"))?;
    }
}
/// Returns a `map_err` adapter that wraps an [`io::Error`] in a [`ServerError`] carrying
/// the fixed context message `msg`.
fn wrap_io(msg: &'static str) -> impl Fn(io::Error) -> ServerError {
    move |source: io::Error| {
        let io = Some(Arc::new(source));
        ServerError { message: msg.into(), io }
    }
}
fn wrap_encode(err: io::Error) -> ServerError {
ServerError { message: "failed to encode message".into(), io: Some(Arc::new(err)) }
}
fn wrap_decode(err: io::Error) -> ServerError {
ServerError { message: "failed to decode message".into(), io: Some(Arc::new(err)) }
}
/// Asks the server which API version it speaks.
///
/// Returns the number carried by the server's `ApiVersionCheck` response.
///
/// # Errors
///
/// Propagates any failure from `run_request`. Any other reply becomes a [`ServerError`]
/// whose message includes the unexpected payload.
pub(crate) fn version_check(
    srv: &ProcMacroServerProcess,
    callback: SubCallback<'_>,
) -> Result<u32, ServerError> {
    let reply = run_request(
        srv,
        BidirectionalMessage::Request(Request::ApiVersionCheck {}),
        callback,
    )?;
    match reply {
        BidirectionalMessage::Response(Response::ApiVersionCheck(server_version)) => {
            Ok(server_version)
        }
        unexpected => Err(ServerError {
            message: format!("unexpected response: {:?}", unexpected),
            io: None,
        }),
    }
}
/// Enable support for rust-analyzer span mode if the server supports it.
///
/// Sends a `SetConfig` request asking for [`SpanMode::RustAnalyzer`]. Returns the span mode
/// reported in the server's `SetConfig` reply. Presumably a server without support answers
/// with a different mode; TODO confirm against the server implementation.
///
/// # Errors
///
/// Propagates any failure from `run_request` (e.g. the server has already exited, or it
/// does not use the postcard codec). Any reply other than a `SetConfig` response becomes a
/// [`ServerError`] that includes the unexpected payload, matching `version_check`.
pub(crate) fn enable_rust_analyzer_spans(
    srv: &ProcMacroServerProcess,
    callback: SubCallback<'_>,
) -> Result<SpanMode, ServerError> {
    let request = BidirectionalMessage::Request(Request::SetConfig(ServerConfig {
        span_mode: SpanMode::RustAnalyzer,
    }));
    let response_payload = run_request(srv, request, callback)?;
    match response_payload {
        BidirectionalMessage::Response(Response::SetConfig(ServerConfig { span_mode })) => {
            Ok(span_mode)
        }
        // Report what the server actually sent; a bare "unexpected response" is undebuggable.
        other => {
            Err(ServerError { message: format!("unexpected response: {:?}", other), io: None })
        }
    }
}
/// Finds proc-macros in a given dynamic library.
///
/// Sends a `ListMacros` request for `dylib_path` and returns the payload of the server's
/// `ListMacros` response. On success, the inner result lists each macro's name and kind.
/// Presumably the inner `Err` is the server's description of why the library could not be
/// inspected; verify against the server side.
///
/// # Errors
///
/// Propagates any failure from `run_request`. Any reply other than a `ListMacros` response
/// becomes a [`ServerError`] that includes the unexpected payload, matching `version_check`.
pub(crate) fn find_proc_macros(
    srv: &ProcMacroServerProcess,
    dylib_path: &AbsPath,
    callback: SubCallback<'_>,
) -> Result<Result<Vec<(String, ProcMacroKind)>, String>, ServerError> {
    let request = BidirectionalMessage::Request(Request::ListMacros {
        dylib_path: dylib_path.to_path_buf().into(),
    });
    let response_payload = run_request(srv, request, callback)?;
    match response_payload {
        BidirectionalMessage::Response(Response::ListMacros(it)) => Ok(it),
        // Report what the server actually sent; a bare "unexpected response" is undebuggable.
        other => {
            Err(ServerError { message: format!("unexpected response: {:?}", other), io: None })
        }
    }
}
/// Asks the server to expand `proc_macro` over `subtree`, plus `attr` for attribute macros.
///
/// The three expansion spans are first interned into a local span table and referenced by
/// index. The same table is used to encode the input trees. It is sent along only when the
/// process has rust-analyzer spans enabled. For a plain `ExpandMacro` reply, the table also
/// resolves the spans in the returned tree. An `ExpandMacroExtended` reply carries its own
/// serialized span table, which is used instead.
///
/// The outer `Result` reports transport and protocol failures. The inner one is the server's
/// verdict on the expansion: the resolved tree, or the error text from its reply.
pub(crate) fn expand(
    proc_macro: &ProcMacro,
    subtree: tt::SubtreeView<'_>,
    attr: Option<tt::SubtreeView<'_>>,
    env: Vec<(String, String)>,
    def_site: Span,
    call_site: Span,
    mixed_site: Span,
    current_dir: String,
    callback: SubCallback<'_>,
) -> Result<Result<tt::TopSubtree, String>, crate::ServerError> {
    let version = proc_macro.process.version();

    let mut span_data_table = SpanDataIndexMap::default();
    let (def_site, _) = span_data_table.insert_full(def_site);
    let (call_site, _) = span_data_table.insert_full(call_site);
    let (mixed_site, _) = span_data_table.insert_full(mixed_site);

    // Order matters: encoding each tree appends its spans to `span_data_table`, so the body
    // is encoded before the attribute, and the table is serialized only after both.
    let macro_body = FlatTree::from_subtree(subtree, version, &mut span_data_table);
    let attributes =
        attr.map(|attr_tree| FlatTree::from_subtree(attr_tree, version, &mut span_data_table));
    let serialized_spans = if proc_macro.process.rust_analyzer_spans() {
        serialize_span_data_index_map(&span_data_table)
    } else {
        Vec::new()
    };

    let data = ExpandMacroData {
        macro_body,
        macro_name: proc_macro.name.to_string(),
        attributes,
        has_global_spans: ExpnGlobals {
            serialize: version >= version::HAS_GLOBAL_SPANS,
            def_site,
            call_site,
            mixed_site,
        },
        span_data_table: serialized_spans,
    };
    let task = BidirectionalMessage::Request(Request::ExpandMacro(Box::new(ExpandMacro {
        data,
        lib: proc_macro.dylib_path.to_path_buf().into(),
        env,
        current_dir: Some(current_dir),
    })));

    // Shared post-processing for both reply shapes: rewrite fixup markers when the macro's
    // server needs it (presumably an older server convention).
    let finish = |mut expanded: tt::TopSubtree| {
        if proc_macro.needs_fixup_change() {
            proc_macro.change_fixup_to_match_old_server(&mut expanded);
        }
        expanded
    };

    match run_request(&proc_macro.process, task, callback)? {
        BidirectionalMessage::Response(Response::ExpandMacro(outcome)) => Ok(outcome
            .map(|tree| finish(FlatTree::to_subtree_resolved(tree, version, &span_data_table)))
            .map_err(|panic_msg| panic_msg.0)),
        BidirectionalMessage::Response(Response::ExpandMacroExtended(outcome)) => Ok(outcome
            .map(|reply| {
                let reply_spans = deserialize_span_data_index_map(&reply.span_data_table);
                finish(FlatTree::to_subtree_resolved(reply.tree, version, &reply_spans))
            })
            .map_err(|panic_msg| panic_msg.0)),
        _ => Err(ServerError { message: "unexpected response".to_owned(), io: None }),
    }
}
/// Sends `msg` to `srv` and returns the server's final reply, answering any sub-requests
/// through `callback`.
///
/// If `srv.exited()` reports an error, a clone of it is returned without contacting the
/// server. Bidirectional messaging is only implemented for the postcard codec, so a server
/// that does not use postcard is rejected with an error.
fn run_request(
    srv: &ProcMacroServerProcess,
    msg: BidirectionalMessage,
    callback: SubCallback<'_>,
) -> Result<BidirectionalMessage, ServerError> {
    if let Some(err) = srv.exited() {
        return Err(err.clone());
    }
    if !srv.use_postcard() {
        return Err(ServerError {
            message: "bidirectional messaging does not support JSON".to_owned(),
            io: None,
        });
    }
    srv.run_bidirectional::<PostcardProtocol>(msg, callback)
}
/// Sub-request handler that refuses every request.
///
/// Its signature matches the function type behind [`SubCallback`], so it can be used when
/// the caller has no way to answer server callbacks. The error message names the rejected
/// request via its `Debug` output.
pub fn reject_subrequests(req: SubRequest) -> Result<SubResponse, ServerError> {
    let message = format!("{req:?} sub-request not supported here");
    Err(ServerError { message, io: None })
}