Unnamed repository; edit this file 'description' to name the repository.
Auto merge of #16226 - Veykril:lsp-server, r=Veykril
internal: Expose whether a channel has been dropped in lsp-server errors
Not the best way to expose this, but this should allow us to give somewhat better errors when the initialization request is malformed, as currently that just results in a channel disconnected error instead of the deserialization error. cc https://github.com/rust-lang/rust-analyzer/issues/15859
| -rw-r--r-- | Cargo.lock | 14 | ||||
| -rw-r--r-- | Cargo.toml | 2 | ||||
| -rw-r--r-- | crates/rust-analyzer/src/bin/main.rs | 17 | ||||
| -rw-r--r-- | lib/lsp-server/Cargo.toml | 4 | ||||
| -rw-r--r-- | lib/lsp-server/examples/goto_def.rs | 10 | ||||
| -rw-r--r-- | lib/lsp-server/src/error.rs | 17 | ||||
| -rw-r--r-- | lib/lsp-server/src/lib.rs | 49 | ||||
| -rw-r--r-- | lib/lsp-server/src/msg.rs | 4 | ||||
| -rw-r--r-- | lib/lsp-server/src/stdio.rs | 3 |
9 files changed, 79 insertions, 41 deletions
diff --git a/Cargo.lock b/Cargo.lock index 7310ecc858..c7d110eafb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -945,24 +945,24 @@ checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" [[package]] name = "lsp-server" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b52dccdf3302eefab8c8a1273047f0a3c3dca4b527c8458d00c09484c8371928" +version = "0.7.6" dependencies = [ "crossbeam-channel", + "ctrlc", "log", + "lsp-types", "serde", "serde_json", ] [[package]] name = "lsp-server" -version = "0.7.5" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248f65b78f6db5d8e1b1604b4098a28b43d21a8eb1deeca22b1c421b276c7095" dependencies = [ "crossbeam-channel", - "ctrlc", "log", - "lsp-types", "serde", "serde_json", ] @@ -1526,7 +1526,7 @@ dependencies = [ "ide-ssr", "itertools", "load-cargo", - "lsp-server 0.7.4", + "lsp-server 0.7.6 (registry+https://github.com/rust-lang/crates.io-index)", "lsp-types", "mbe", "mimalloc", diff --git a/Cargo.toml b/Cargo.toml index d4cff420bc..e82a14d16e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -88,7 +88,7 @@ test-utils = { path = "./crates/test-utils" } # In-tree crates that are published separately and follow semver. See lib/README.md line-index = { version = "0.1.1" } la-arena = { version = "0.3.1" } -lsp-server = { version = "0.7.4" } +lsp-server = { version = "0.7.6" } # non-local crates anyhow = "1.0.75" diff --git a/crates/rust-analyzer/src/bin/main.rs b/crates/rust-analyzer/src/bin/main.rs index 8472e49de9..6f40a4c88e 100644 --- a/crates/rust-analyzer/src/bin/main.rs +++ b/crates/rust-analyzer/src/bin/main.rs @@ -172,7 +172,15 @@ fn run_server() -> anyhow::Result<()> { let (connection, io_threads) = Connection::stdio(); - let (initialize_id, initialize_params) = connection.initialize_start()?; + let (initialize_id, initialize_params) = match connection.initialize_start() { + Ok(it) => it, + Err(e) => { + if e.channel_is_disconnected() { + io_threads.join()?; + } + return Err(e.into()); + } + }; tracing::info!("InitializeParams: {}", initialize_params); let lsp_types::InitializeParams { root_uri, @@ -240,7 +248,12 @@ fn run_server() -> anyhow::Result<()> { let initialize_result = serde_json::to_value(initialize_result).unwrap(); - connection.initialize_finish(initialize_id, initialize_result)?; + if let Err(e) = connection.initialize_finish(initialize_id, initialize_result) { + if e.channel_is_disconnected() { + io_threads.join()?; + } + return Err(e.into()); + } if !config.has_linked_projects() && config.detached_files().is_empty() { config.rediscover_workspaces(); diff --git a/lib/lsp-server/Cargo.toml b/lib/lsp-server/Cargo.toml index e802bf185b..116b376b0b 100644 --- a/lib/lsp-server/Cargo.toml +++ b/lib/lsp-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lsp-server" -version = "0.7.5" +version = "0.7.6" description = "Generic LSP server scaffold." license = "MIT OR Apache-2.0" repository = "https://github.com/rust-lang/rust-analyzer/tree/master/lib/lsp-server" @@ -10,7 +10,7 @@ edition = "2021" log = "0.4.17" serde_json = "1.0.108" serde = { version = "1.0.192", features = ["derive"] } -crossbeam-channel = "0.5.6" +crossbeam-channel = "0.5.8" [dev-dependencies] lsp-types = "=0.95" diff --git a/lib/lsp-server/examples/goto_def.rs b/lib/lsp-server/examples/goto_def.rs index 2f270afbbf..71f6625406 100644 --- a/lib/lsp-server/examples/goto_def.rs +++ b/lib/lsp-server/examples/goto_def.rs @@ -64,7 +64,15 @@ fn main() -> Result<(), Box<dyn Error + Sync + Send>> { ..Default::default() }) .unwrap(); - let initialization_params = connection.initialize(server_capabilities)?; + let initialization_params = match connection.initialize(server_capabilities) { + Ok(it) => it, + Err(e) => { + if e.channel_is_disconnected() { + io_threads.join()?; + } + return Err(e.into()); + } + }; main_loop(connection, initialization_params)?; io_threads.join()?; diff --git a/lib/lsp-server/src/error.rs b/lib/lsp-server/src/error.rs index 755b3fd959..ebdd153b5b 100644 --- a/lib/lsp-server/src/error.rs +++ b/lib/lsp-server/src/error.rs @@ -3,7 +3,22 @@ use std::fmt; use crate::{Notification, Request}; #[derive(Debug, Clone, PartialEq)] -pub struct ProtocolError(pub(crate) String); +pub struct ProtocolError(String, bool); + +impl ProtocolError { + pub(crate) fn new(msg: impl Into<String>) -> Self { + ProtocolError(msg.into(), false) + } + + pub(crate) fn disconnected() -> ProtocolError { + ProtocolError("disconnected channel".into(), true) + } + + /// Whether this error occured due to a disconnected channel. + pub fn channel_is_disconnected(&self) -> bool { + self.1 + } +} impl std::error::Error for ProtocolError {} diff --git a/lib/lsp-server/src/lib.rs b/lib/lsp-server/src/lib.rs index 2797a6b60d..6b732d4702 100644 --- a/lib/lsp-server/src/lib.rs +++ b/lib/lsp-server/src/lib.rs @@ -17,7 +17,7 @@ use std::{ net::{TcpListener, TcpStream, ToSocketAddrs}, }; -use crossbeam_channel::{Receiver, RecvTimeoutError, Sender}; +use crossbeam_channel::{Receiver, RecvError, RecvTimeoutError, Sender}; pub use crate::{ error::{ExtractError, ProtocolError}, @@ -158,11 +158,7 @@ impl Connection { Err(RecvTimeoutError::Timeout) => { continue; } - Err(e) => { - return Err(ProtocolError(format!( - "expected initialize request, got error: {e}" - ))) - } + Err(RecvTimeoutError::Disconnected) => return Err(ProtocolError::disconnected()), }; match msg { @@ -181,12 +177,14 @@ impl Connection { continue; } msg => { - return Err(ProtocolError(format!("expected initialize request, got {msg:?}"))); + return Err(ProtocolError::new(format!( + "expected initialize request, got {msg:?}" + ))); } }; } - return Err(ProtocolError(String::from( + return Err(ProtocolError::new(String::from( "Initialization has been aborted during initialization", ))); } @@ -201,12 +199,10 @@ impl Connection { self.sender.send(resp.into()).unwrap(); match &self.receiver.recv() { Ok(Message::Notification(n)) if n.is_initialized() => Ok(()), - Ok(msg) => { - Err(ProtocolError(format!(r#"expected initialized notification, got: {msg:?}"#))) - } - Err(e) => { - Err(ProtocolError(format!("expected initialized notification, got error: {e}",))) - } + Ok(msg) => Err(ProtocolError::new(format!( + r#"expected initialized notification, got: {msg:?}"# + ))), + Err(RecvError) => Err(ProtocolError::disconnected()), } } @@ -231,10 +227,8 @@ impl Connection { Err(RecvTimeoutError::Timeout) => { continue; } - Err(e) => { - return Err(ProtocolError(format!( - "expected initialized notification, got error: {e}", - ))); + Err(RecvTimeoutError::Disconnected) => { + return Err(ProtocolError::disconnected()); } }; @@ -243,14 +237,14 @@ impl Connection { return Ok(()); } msg => { - return Err(ProtocolError(format!( + return Err(ProtocolError::new(format!( r#"expected initialized notification, got: {msg:?}"# ))); } } } - return Err(ProtocolError(String::from( + return Err(ProtocolError::new(String::from( "Initialization has been aborted during initialization", ))); } @@ -359,9 +353,18 @@ impl Connection { match &self.receiver.recv_timeout(std::time::Duration::from_secs(30)) { Ok(Message::Notification(n)) if n.is_exit() => (), Ok(msg) => { - return Err(ProtocolError(format!("unexpected message during shutdown: {msg:?}"))) + return Err(ProtocolError::new(format!( + "unexpected message during shutdown: {msg:?}" + ))) + } + Err(RecvTimeoutError::Timeout) => { + return Err(ProtocolError::new(format!("timed out waiting for exit notification"))) + } + Err(RecvTimeoutError::Disconnected) => { + return Err(ProtocolError::new(format!( + "channel disconnected waiting for exit notification" + ))) } - Err(e) => return Err(ProtocolError(format!("unexpected error during shutdown: {e}"))), } Ok(true) } @@ -426,7 +429,7 @@ mod tests { initialize_start_test(TestCase { test_messages: vec![notification_msg.clone()], - expected_resp: Err(ProtocolError(format!( + expected_resp: Err(ProtocolError::new(format!( "expected initialize request, got {:?}", notification_msg ))), diff --git a/lib/lsp-server/src/msg.rs b/lib/lsp-server/src/msg.rs index 730ad51f42..ba318dd169 100644 --- a/lib/lsp-server/src/msg.rs +++ b/lib/lsp-server/src/msg.rs @@ -264,12 +264,12 @@ fn read_msg_text(inp: &mut dyn BufRead) -> io::Result<Option<String>> { let mut parts = buf.splitn(2, ": "); let header_name = parts.next().unwrap(); let header_value = - parts.next().ok_or_else(|| invalid_data!("malformed header: {:?}", buf))?; + parts.next().ok_or_else(|| invalid_data(format!("malformed header: {:?}", buf)))?; if header_name.eq_ignore_ascii_case("Content-Length") { size = Some(header_value.parse::<usize>().map_err(invalid_data)?); } } - let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length"))?; + let size: usize = size.ok_or_else(|| invalid_data("no Content-Length".to_string()))?; let mut buf = buf.into_bytes(); buf.resize(size, 0); inp.read_exact(&mut buf)?; diff --git a/lib/lsp-server/src/stdio.rs b/lib/lsp-server/src/stdio.rs index e487b9b462..cea199d029 100644 --- a/lib/lsp-server/src/stdio.rs +++ b/lib/lsp-server/src/stdio.rs @@ -15,8 +15,7 @@ pub(crate) fn stdio_transport() -> (Sender<Message>, Receiver<Message>, IoThread let writer = thread::spawn(move || { let stdout = stdout(); let mut stdout = stdout.lock(); - writer_receiver.into_iter().try_for_each(|it| it.write(&mut stdout))?; - Ok(()) + writer_receiver.into_iter().try_for_each(|it| it.write(&mut stdout)) }); let (reader_sender, reader_receiver) = bounded::<Message>(0); let reader = thread::spawn(move || { |