Unnamed repository; edit this file 'description' to name the repository.
-rw-r--r--crates/proc-macro-api/src/bidirectional_protocol.rs17
-rw-r--r--crates/proc-macro-api/src/bidirectional_protocol/msg.rs3
-rw-r--r--crates/proc-macro-srv-cli/src/main_loop.rs74
-rw-r--r--crates/proc-macro-srv/src/lib.rs54
4 files changed, 113 insertions, 35 deletions
diff --git a/crates/proc-macro-api/src/bidirectional_protocol.rs b/crates/proc-macro-api/src/bidirectional_protocol.rs
index 8311df23d7..ba59cb219b 100644
--- a/crates/proc-macro-api/src/bidirectional_protocol.rs
+++ b/crates/proc-macro-api/src/bidirectional_protocol.rs
@@ -2,6 +2,7 @@
use std::{
io::{self, BufRead, Write},
+ panic::{AssertUnwindSafe, catch_unwind},
sync::Arc,
};
@@ -55,9 +56,19 @@ pub fn run_conversation(
return Ok(BidirectionalMessage::Response(response));
}
BidirectionalMessage::SubRequest(sr) => {
- let resp = callback(sr)?;
- let reply = BidirectionalMessage::SubResponse(resp);
- let encoded = postcard::encode(&reply).map_err(wrap_encode)?;
+ // TODO: Avoid `AssertUnwindSafe` by making the callback `UnwindSafe` once `ExpandDatabase`
+ // becomes unwind-safe (currently blocked by `parking_lot::RwLock` in the VFS).
+ let resp = match catch_unwind(AssertUnwindSafe(|| callback(sr))) {
+ Ok(Ok(resp)) => BidirectionalMessage::SubResponse(resp),
+ Ok(Err(err)) => BidirectionalMessage::SubResponse(SubResponse::Cancel {
+ reason: err.to_string(),
+ }),
+ Err(_) => BidirectionalMessage::SubResponse(SubResponse::Cancel {
+ reason: "callback panicked or was cancelled".into(),
+ }),
+ };
+
+ let encoded = postcard::encode(&resp).map_err(wrap_encode)?;
postcard::write(writer, &encoded)
.map_err(wrap_io("failed to write sub-response"))?;
}
diff --git a/crates/proc-macro-api/src/bidirectional_protocol/msg.rs b/crates/proc-macro-api/src/bidirectional_protocol/msg.rs
index 1df0c68379..3f0422dc5b 100644
--- a/crates/proc-macro-api/src/bidirectional_protocol/msg.rs
+++ b/crates/proc-macro-api/src/bidirectional_protocol/msg.rs
@@ -42,6 +42,9 @@ pub enum SubResponse {
ByteRangeResult {
range: Range<usize>,
},
+ Cancel {
+ reason: String,
+ },
}
#[derive(Debug, Serialize, Deserialize)]
diff --git a/crates/proc-macro-srv-cli/src/main_loop.rs b/crates/proc-macro-srv-cli/src/main_loop.rs
index 758629fd1f..9be3199a38 100644
--- a/crates/proc-macro-srv-cli/src/main_loop.rs
+++ b/crates/proc-macro-srv-cli/src/main_loop.rs
@@ -3,6 +3,7 @@ use proc_macro_api::{
ProtocolFormat, bidirectional_protocol::msg as bidirectional, legacy_protocol::msg as legacy,
version::CURRENT_API_VERSION,
};
+use std::panic::{panic_any, resume_unwind};
use std::{
io::{self, BufRead, Write},
ops::Range,
@@ -10,7 +11,7 @@ use std::{
use legacy::Message;
-use proc_macro_srv::{EnvSnapshot, SpanId};
+use proc_macro_srv::{EnvSnapshot, ProcMacroClientError, ProcMacroPanicMarker, SpanId};
struct SpanTrans;
@@ -172,16 +173,43 @@ impl<'a> ProcMacroClientHandle<'a> {
fn roundtrip(
&mut self,
req: bidirectional::SubRequest,
- ) -> Option<bidirectional::BidirectionalMessage> {
+ ) -> Result<bidirectional::SubResponse, ProcMacroClientError> {
let msg = bidirectional::BidirectionalMessage::SubRequest(req);
- if msg.write(&mut *self.stdout).is_err() {
- return None;
+ msg.write(&mut *self.stdout).map_err(ProcMacroClientError::Io)?;
+
+ let msg = bidirectional::BidirectionalMessage::read(&mut *self.stdin, self.buf)
+ .map_err(ProcMacroClientError::Io)?
+ .ok_or(ProcMacroClientError::Eof)?;
+
+ match msg {
+ bidirectional::BidirectionalMessage::SubResponse(resp) => match resp {
+ bidirectional::SubResponse::Cancel { reason } => {
+ Err(ProcMacroClientError::Cancelled { reason })
+ }
+ other => Ok(other),
+ },
+ other => {
+ Err(ProcMacroClientError::Protocol(format!("expected SubResponse, got {other:?}")))
+ }
}
+ }
+}
- match bidirectional::BidirectionalMessage::read(&mut *self.stdin, self.buf) {
- Ok(Some(msg)) => Some(msg),
- _ => None,
+fn handle_failure(failure: Result<bidirectional::SubResponse, ProcMacroClientError>) -> ! {
+ match failure {
+ Err(ProcMacroClientError::Cancelled { reason }) => {
+ resume_unwind(Box::new(ProcMacroPanicMarker::Cancelled { reason }));
+ }
+ Err(err) => {
+ panic_any(ProcMacroPanicMarker::Internal {
+ reason: format!("proc-macro IPC error: {err:?}"),
+ });
+ }
+ Ok(other) => {
+ panic_any(ProcMacroPanicMarker::Internal {
+ reason: format!("unexpected SubResponse {other:?}"),
+ });
}
}
}
@@ -189,10 +217,8 @@ impl<'a> ProcMacroClientHandle<'a> {
impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> {
fn file(&mut self, file_id: proc_macro_srv::span::FileId) -> String {
match self.roundtrip(bidirectional::SubRequest::FilePath { file_id: file_id.index() }) {
- Some(bidirectional::BidirectionalMessage::SubResponse(
- bidirectional::SubResponse::FilePathResult { name },
- )) => name,
- _ => String::new(),
+ Ok(bidirectional::SubResponse::FilePathResult { name }) => name,
+ other => handle_failure(other),
}
}
@@ -206,20 +232,16 @@ impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> {
start: range.start().into(),
end: range.end().into(),
}) {
- Some(bidirectional::BidirectionalMessage::SubResponse(
- bidirectional::SubResponse::SourceTextResult { text },
- )) => text,
- _ => None,
+ Ok(bidirectional::SubResponse::SourceTextResult { text }) => text,
+ other => handle_failure(other),
}
}
fn local_file(&mut self, file_id: proc_macro_srv::span::FileId) -> Option<String> {
match self.roundtrip(bidirectional::SubRequest::LocalFilePath { file_id: file_id.index() })
{
- Some(bidirectional::BidirectionalMessage::SubResponse(
- bidirectional::SubResponse::LocalFilePathResult { name },
- )) => name,
- _ => None,
+ Ok(bidirectional::SubResponse::LocalFilePathResult { name }) => name,
+ other => handle_failure(other),
}
}
@@ -230,10 +252,10 @@ impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> {
ast_id: anchor.ast_id.into_raw(),
offset: range.start().into(),
}) {
- Some(bidirectional::BidirectionalMessage::SubResponse(
- bidirectional::SubResponse::LineColumnResult { line, column },
- )) => Some((line, column)),
- _ => None,
+ Ok(bidirectional::SubResponse::LineColumnResult { line, column }) => {
+ Some((line, column))
+ }
+ other => handle_failure(other),
}
}
@@ -247,10 +269,8 @@ impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> {
start: range.start().into(),
end: range.end().into(),
}) {
- Some(bidirectional::BidirectionalMessage::SubResponse(
- bidirectional::SubResponse::ByteRangeResult { range },
- )) => range,
- _ => Range { start: range.start().into(), end: range.end().into() },
+ Ok(bidirectional::SubResponse::ByteRangeResult { range }) => range,
+ other => handle_failure(other),
}
}
}
diff --git a/crates/proc-macro-srv/src/lib.rs b/crates/proc-macro-srv/src/lib.rs
index e04f744ae2..c548dc620a 100644
--- a/crates/proc-macro-srv/src/lib.rs
+++ b/crates/proc-macro-srv/src/lib.rs
@@ -96,6 +96,20 @@ impl<'env> ProcMacroSrv<'env> {
}
}
+#[derive(Debug)]
+pub enum ProcMacroClientError {
+ Cancelled { reason: String },
+ Io(std::io::Error),
+ Protocol(String),
+ Eof,
+}
+
+#[derive(Debug)]
+pub enum ProcMacroPanicMarker {
+ Cancelled { reason: String },
+ Internal { reason: String },
+}
+
pub type ProcMacroClientHandle<'a> = &'a mut (dyn ProcMacroClientInterface + Sync + Send);
pub trait ProcMacroClientInterface {
@@ -110,6 +124,22 @@ pub trait ProcMacroClientInterface {
const EXPANDER_STACK_SIZE: usize = 8 * 1024 * 1024;
+pub enum ExpandError {
+ Panic(PanicMessage),
+ Cancelled { reason: Option<String> },
+ Internal { reason: Option<String> },
+}
+
+impl ExpandError {
+ pub fn into_string(self) -> Option<String> {
+ match self {
+ ExpandError::Panic(panic_message) => panic_message.into_string(),
+ ExpandError::Cancelled { reason } => reason,
+ ExpandError::Internal { reason } => reason,
+ }
+ }
+}
+
impl ProcMacroSrv<'_> {
pub fn expand<S: ProcMacroSrvSpan>(
&self,
@@ -123,10 +153,10 @@ impl ProcMacroSrv<'_> {
call_site: S,
mixed_site: S,
callback: Option<ProcMacroClientHandle<'_>>,
- ) -> Result<token_stream::TokenStream<S>, PanicMessage> {
+ ) -> Result<token_stream::TokenStream<S>, ExpandError> {
let snapped_env = self.env;
- let expander = self.expander(lib.as_ref()).map_err(|err| PanicMessage {
- message: Some(format!("failed to load macro: {err}")),
+ let expander = self.expander(lib.as_ref()).map_err(|err| ExpandError::Internal {
+ reason: Some(format!("failed to load macro: {err}")),
})?;
let prev_env = EnvChange::apply(snapped_env, env, current_dir.as_ref().map(<_>::as_ref));
@@ -144,8 +174,22 @@ impl ProcMacroSrv<'_> {
)
});
match thread.unwrap().join() {
- Ok(res) => res,
- Err(e) => std::panic::resume_unwind(e),
+ Ok(res) => res.map_err(ExpandError::Panic),
+
+ Err(payload) => {
+ if let Some(marker) = payload.downcast_ref::<ProcMacroPanicMarker>() {
+ return match marker {
+ ProcMacroPanicMarker::Cancelled { reason } => {
+ Err(ExpandError::Cancelled { reason: Some(reason.clone()) })
+ }
+ ProcMacroPanicMarker::Internal { reason } => {
+ Err(ExpandError::Internal { reason: Some(reason.clone()) })
+ }
+ };
+ }
+
+ std::panic::resume_unwind(payload)
+ }
}
});
prev_env.rollback();