Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/proc-macro-api/src/bidirectional_protocol.rs')
| -rw-r--r-- | crates/proc-macro-api/src/bidirectional_protocol.rs | 85 |
1 files changed, 22 insertions, 63 deletions
diff --git a/crates/proc-macro-api/src/bidirectional_protocol.rs b/crates/proc-macro-api/src/bidirectional_protocol.rs index 4cb6a1d90f..bd74738bbd 100644 --- a/crates/proc-macro-api/src/bidirectional_protocol.rs +++ b/crates/proc-macro-api/src/bidirectional_protocol.rs @@ -5,9 +5,8 @@ use std::{ sync::Arc, }; -use base_db::SourceDatabase; use paths::AbsPath; -use span::{FileId, Span}; +use span::Span; use crate::{ Codec, ProcMacro, ProcMacroKind, ServerError, @@ -29,16 +28,14 @@ use crate::{ pub mod msg; -pub trait ClientCallbacks { - fn handle_sub_request(&mut self, req: SubRequest) -> Result<SubResponse, ServerError>; -} +pub type SubCallback<'a> = &'a mut dyn FnMut(SubRequest) -> Result<SubResponse, ServerError>; pub fn run_conversation<C: Codec>( writer: &mut dyn Write, reader: &mut dyn BufRead, buf: &mut C::Buf, msg: BidirectionalMessage, - callbacks: &mut dyn ClientCallbacks, + callback: SubCallback<'_>, ) -> Result<BidirectionalMessage, ServerError> { let encoded = C::encode(&msg).map_err(wrap_encode)?; C::write(writer, &encoded).map_err(wrap_io("failed to write initial request"))?; @@ -59,7 +56,7 @@ pub fn run_conversation<C: Codec>( return Ok(BidirectionalMessage::Response(response)); } BidirectionalMessage::SubRequest(sr) => { - let resp = callbacks.handle_sub_request(sr)?; + let resp = callback(sr)?; let reply = BidirectionalMessage::SubResponse(resp); let encoded = C::encode(&reply).map_err(wrap_encode)?; C::write(writer, &encoded).map_err(wrap_io("failed to write sub-response"))?; @@ -86,19 +83,13 @@ fn wrap_decode(err: io::Error) -> ServerError { ServerError { message: "failed to decode message".into(), io: Some(Arc::new(err)) } } -pub(crate) fn version_check(srv: &ProcMacroServerProcess) -> Result<u32, ServerError> { +pub(crate) fn version_check( + srv: &ProcMacroServerProcess, + callback: SubCallback<'_>, +) -> Result<u32, ServerError> { let request = BidirectionalMessage::Request(Request::ApiVersionCheck {}); - struct NoCallbacks; - impl ClientCallbacks for NoCallbacks { - fn handle_sub_request(&mut self, _req: SubRequest) -> Result<SubResponse, ServerError> { - Err(ServerError { message: "sub-request not supported here".into(), io: None }) - } - } - - let mut callbacks = NoCallbacks; - - let response_payload = run_request(srv, request, &mut callbacks)?; + let response_payload = run_request(srv, request, callback)?; match response_payload { BidirectionalMessage::Response(Response::ApiVersionCheck(version)) => Ok(version), @@ -111,21 +102,13 @@ pub(crate) fn version_check(srv: &ProcMacroServerProcess) -> Result<u32, ServerE /// Enable support for rust-analyzer span mode if the server supports it. pub(crate) fn enable_rust_analyzer_spans( srv: &ProcMacroServerProcess, + callback: SubCallback<'_>, ) -> Result<SpanMode, ServerError> { let request = BidirectionalMessage::Request(Request::SetConfig(ServerConfig { span_mode: SpanMode::RustAnalyzer, })); - struct NoCallbacks; - impl ClientCallbacks for NoCallbacks { - fn handle_sub_request(&mut self, _req: SubRequest) -> Result<SubResponse, ServerError> { - Err(ServerError { message: "sub-request not supported here".into(), io: None }) - } - } - - let mut callbacks = NoCallbacks; - - let response_payload = run_request(srv, request, &mut callbacks)?; + let response_payload = run_request(srv, request, callback)?; match response_payload { BidirectionalMessage::Response(Response::SetConfig(ServerConfig { span_mode })) => { @@ -139,21 +122,13 @@ pub(crate) fn enable_rust_analyzer_spans( pub(crate) fn find_proc_macros( srv: &ProcMacroServerProcess, dylib_path: &AbsPath, + callback: SubCallback<'_>, ) -> Result<Result<Vec<(String, ProcMacroKind)>, String>, ServerError> { let request = BidirectionalMessage::Request(Request::ListMacros { dylib_path: dylib_path.to_path_buf().into(), }); - struct NoCallbacks; - impl ClientCallbacks for NoCallbacks { - fn handle_sub_request(&mut self, _req: SubRequest) -> Result<SubResponse, ServerError> { - Err(ServerError { message: "sub-request not supported here".into(), io: None }) - } - } - - let mut callbacks = NoCallbacks; - - let response_payload = run_request(srv, request, &mut callbacks)?; + let response_payload = run_request(srv, request, callback)?; match response_payload { BidirectionalMessage::Response(Response::ListMacros(it)) => Ok(it), @@ -163,7 +138,6 @@ pub(crate) fn find_proc_macros( pub(crate) fn expand( proc_macro: &ProcMacro, - db: &dyn SourceDatabase, subtree: tt::SubtreeView<'_, Span>, attr: Option<tt::SubtreeView<'_, Span>>, env: Vec<(String, String)>, @@ -171,6 +145,7 @@ pub(crate) fn expand( call_site: Span, mixed_site: Span, current_dir: String, + callback: SubCallback<'_>, ) -> Result<Result<tt::TopSubtree<span::SpanData<span::SyntaxContext>>, String>, crate::ServerError> { let version = proc_macro.process.version(); @@ -201,27 +176,7 @@ pub(crate) fn expand( current_dir: Some(current_dir), }))); - struct Callbacks<'de> { - db: &'de dyn SourceDatabase, - } - impl<'db> ClientCallbacks for Callbacks<'db> { - fn handle_sub_request(&mut self, req: SubRequest) -> Result<SubResponse, ServerError> { - match req { - SubRequest::SourceText { file_id, start, end } => { - let file = FileId::from_raw(file_id); - let text = self.db.file_text(file).text(self.db); - - let slice = text.get(start as usize..end as usize).map(|s| s.to_owned()); - - Ok(SubResponse::SourceTextResult { text: slice }) - } - } - } - } - - let mut callbacks = Callbacks { db }; - - let response_payload = run_request(&proc_macro.process, task, &mut callbacks)?; + let response_payload = run_request(&proc_macro.process, task, callback)?; match response_payload { BidirectionalMessage::Response(Response::ExpandMacro(it)) => Ok(it @@ -253,15 +208,19 @@ pub(crate) fn expand( fn run_request( srv: &ProcMacroServerProcess, msg: BidirectionalMessage, - callbacks: &mut dyn ClientCallbacks, + callback: SubCallback<'_>, ) -> Result<BidirectionalMessage, ServerError> { if let Some(server_error) = srv.exited() { return Err(server_error.clone()); } if srv.use_postcard() { - srv.run_bidirectional::<PostcardProtocol>(msg, callbacks) + srv.run_bidirectional::<PostcardProtocol>(msg, callback) } else { - srv.run_bidirectional::<JsonProtocol>(msg, callbacks) + srv.run_bidirectional::<JsonProtocol>(msg, callback) } } + +pub fn reject_subrequests(req: SubRequest) -> Result<SubResponse, ServerError> { + Err(ServerError { message: format!("{req:?} sub-request not supported here"), io: None }) +} |