Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'helix-lsp/src/lib.rs')
| -rw-r--r-- | helix-lsp/src/lib.rs | 141 |
1 files changed, 84 insertions, 57 deletions
diff --git a/helix-lsp/src/lib.rs b/helix-lsp/src/lib.rs index 262bc1b9..04f018bc 100644 --- a/helix-lsp/src/lib.rs +++ b/helix-lsp/src/lib.rs @@ -17,6 +17,7 @@ use helix_core::syntax::{ LanguageConfiguration, LanguageServerConfiguration, LanguageServerFeatures, }; use helix_stdx::path; +use slotmap::SlotMap; use tokio::sync::mpsc::UnboundedReceiver; use std::{ @@ -28,8 +29,9 @@ use std::{ use thiserror::Error; use tokio_stream::wrappers::UnboundedReceiverStream; -pub type Result<T> = core::result::Result<T, Error>; +pub type Result<T, E = Error> = core::result::Result<T, E>; pub type LanguageServerName = String; +pub use helix_core::diagnostic::LanguageServerId; #[derive(Error, Debug)] pub enum Error { @@ -651,38 +653,42 @@ impl Notification { #[derive(Debug)] pub struct Registry { - inner: HashMap<LanguageServerName, Vec<Arc<Client>>>, + inner: SlotMap<LanguageServerId, Arc<Client>>, + inner_by_name: HashMap<LanguageServerName, Vec<Arc<Client>>>, syn_loader: Arc<ArcSwap<helix_core::syntax::Loader>>, - counter: usize, - pub incoming: SelectAll<UnboundedReceiverStream<(usize, Call)>>, + pub incoming: SelectAll<UnboundedReceiverStream<(LanguageServerId, Call)>>, pub file_event_handler: file_event::Handler, } impl Registry { pub fn new(syn_loader: Arc<ArcSwap<helix_core::syntax::Loader>>) -> Self { Self { - inner: HashMap::new(), + inner: SlotMap::with_key(), + inner_by_name: HashMap::new(), syn_loader, - counter: 0, incoming: SelectAll::new(), file_event_handler: file_event::Handler::new(), } } - pub fn get_by_id(&self, id: usize) -> Option<&Client> { - self.inner - .values() - .flatten() - .find(|client| client.id() == id) - .map(|client| &**client) + pub fn get_by_id(&self, id: LanguageServerId) -> Option<&Arc<Client>> { + self.inner.get(id) } - pub fn remove_by_id(&mut self, id: usize) { + pub fn remove_by_id(&mut self, id: LanguageServerId) { + let Some(client) = self.inner.remove(id) else { + log::error!("client was already removed"); + return + }; self.file_event_handler.remove_client(id); - self.inner.retain(|_, language_servers| { - language_servers.retain(|ls| id != ls.id()); - !language_servers.is_empty() - }); + let instances = self + .inner_by_name + .get_mut(client.name()) + .expect("inner and inner_by_name must be synced"); + instances.retain(|ls| id != ls.id()); + if instances.is_empty() { + self.inner_by_name.remove(client.name()); + } } fn start_client( @@ -692,28 +698,28 @@ impl Registry { doc_path: Option<&std::path::PathBuf>, root_dirs: &[PathBuf], enable_snippets: bool, - ) -> Result<Option<Arc<Client>>> { + ) -> Result<Arc<Client>, StartupError> { let syn_loader = self.syn_loader.load(); let config = syn_loader .language_server_configs() .get(&name) .ok_or_else(|| anyhow::anyhow!("Language server '{name}' not defined"))?; - let id = self.counter; - self.counter += 1; - if let Some(NewClient(client, incoming)) = start_client( - id, - name, - ls_config, - config, - doc_path, - root_dirs, - enable_snippets, - )? { - self.incoming.push(UnboundedReceiverStream::new(incoming)); - Ok(Some(client)) - } else { - Ok(None) - } + let id = self.inner.try_insert_with_key(|id| { + start_client( + id, + name, + ls_config, + config, + doc_path, + root_dirs, + enable_snippets, + ) + .map(|client| { + self.incoming.push(UnboundedReceiverStream::new(client.1)); + client.0 + }) + })?; + Ok(self.inner[id].clone()) } /// If this method is called, all documents that have a reference to language servers used by the language config have to refresh their language servers, @@ -730,7 +736,7 @@ impl Registry { .language_servers .iter() .filter_map(|LanguageServerFeatures { name, .. }| { - if self.inner.contains_key(name) { + if self.inner_by_name.contains_key(name) { let client = match self.start_client( name.clone(), language_config, @@ -738,16 +744,18 @@ impl Registry { root_dirs, enable_snippets, ) { - Ok(client) => client?, - Err(error) => return Some(Err(error)), + Ok(client) => client, + Err(StartupError::NoRequiredRootFound) => return None, + Err(StartupError::Error(err)) => return Some(Err(err)), }; let old_clients = self - .inner + .inner_by_name .insert(name.clone(), vec![client.clone()]) .unwrap(); for old_client in old_clients { self.file_event_handler.remove_client(old_client.id()); + self.inner.remove(client.id()); tokio::spawn(async move { let _ = old_client.force_shutdown().await; }); @@ -762,9 +770,10 @@ impl Registry { } pub fn stop(&mut self, name: &str) { - if let Some(clients) = self.inner.remove(name) { + if let Some(clients) = self.inner_by_name.remove(name) { for client in clients { self.file_event_handler.remove_client(client.id()); + self.inner.remove(client.id()); tokio::spawn(async move { let _ = client.force_shutdown().await; }); @@ -781,7 +790,7 @@ impl Registry { ) -> impl Iterator<Item = (LanguageServerName, Result<Arc<Client>>)> + 'a { language_config.language_servers.iter().filter_map( move |LanguageServerFeatures { name, .. }| { - if let Some(clients) = self.inner.get(name) { + if let Some(clients) = self.inner_by_name.get(name) { if let Some((_, client)) = clients.iter().enumerate().find(|(i, client)| { client.try_add_doc(&language_config.roots, root_dirs, doc_path, *i == 0) }) { @@ -796,21 +805,21 @@ impl Registry { enable_snippets, ) { Ok(client) => { - let client = client?; - self.inner + self.inner_by_name .entry(name.to_owned()) .or_default() .push(client.clone()); Some((name.clone(), Ok(client))) } - Err(err) => Some((name.to_owned(), Err(err))), + Err(StartupError::NoRequiredRootFound) => None, + Err(StartupError::Error(err)) => Some((name.to_owned(), Err(err))), } }, ) } pub fn iter_clients(&self) -> impl Iterator<Item = &Arc<Client>> { - self.inner.values().flatten() + self.inner.values() } } @@ -833,7 +842,7 @@ impl ProgressStatus { /// Acts as a container for progress reported by language servers. Each server /// has a unique id assigned at creation through [`Registry`]. This id is then used /// to store the progress in this map. -pub struct LspProgressMap(HashMap<usize, HashMap<lsp::ProgressToken, ProgressStatus>>); +pub struct LspProgressMap(HashMap<LanguageServerId, HashMap<lsp::ProgressToken, ProgressStatus>>); impl LspProgressMap { pub fn new() -> Self { @@ -841,28 +850,35 @@ impl LspProgressMap { } /// Returns a map of all tokens corresponding to the language server with `id`. - pub fn progress_map(&self, id: usize) -> Option<&HashMap<lsp::ProgressToken, ProgressStatus>> { + pub fn progress_map( + &self, + id: LanguageServerId, + ) -> Option<&HashMap<lsp::ProgressToken, ProgressStatus>> { self.0.get(&id) } - pub fn is_progressing(&self, id: usize) -> bool { + pub fn is_progressing(&self, id: LanguageServerId) -> bool { self.0.get(&id).map(|it| !it.is_empty()).unwrap_or_default() } /// Returns last progress status for a given server with `id` and `token`. - pub fn progress(&self, id: usize, token: &lsp::ProgressToken) -> Option<&ProgressStatus> { + pub fn progress( + &self, + id: LanguageServerId, + token: &lsp::ProgressToken, + ) -> Option<&ProgressStatus> { self.0.get(&id).and_then(|values| values.get(token)) } /// Checks if progress `token` for server with `id` is created. - pub fn is_created(&mut self, id: usize, token: &lsp::ProgressToken) -> bool { + pub fn is_created(&mut self, id: LanguageServerId, token: &lsp::ProgressToken) -> bool { self.0 .get(&id) .map(|values| values.get(token).is_some()) .unwrap_or_default() } - pub fn create(&mut self, id: usize, token: lsp::ProgressToken) { + pub fn create(&mut self, id: LanguageServerId, token: lsp::ProgressToken) { self.0 .entry(id) .or_default() @@ -872,7 +888,7 @@ impl LspProgressMap { /// Ends the progress by removing the `token` from server with `id`, if removed returns the value. pub fn end_progress( &mut self, - id: usize, + id: LanguageServerId, token: &lsp::ProgressToken, ) -> Option<ProgressStatus> { self.0.get_mut(&id).and_then(|vals| vals.remove(token)) @@ -881,7 +897,7 @@ impl LspProgressMap { /// Updates the progress of `token` for server with `id` to `status`, returns the value replaced or `None`. pub fn update( &mut self, - id: usize, + id: LanguageServerId, token: lsp::ProgressToken, status: lsp::WorkDoneProgress, ) -> Option<ProgressStatus> { @@ -892,19 +908,30 @@ impl LspProgressMap { } } -struct NewClient(Arc<Client>, UnboundedReceiver<(usize, Call)>); +struct NewClient(Arc<Client>, UnboundedReceiver<(LanguageServerId, Call)>); + +enum StartupError { + NoRequiredRootFound, + Error(Error), +} + +impl<T: Into<Error>> From<T> for StartupError { + fn from(value: T) -> Self { + StartupError::Error(value.into()) + } +} /// start_client takes both a LanguageConfiguration and a LanguageServerConfiguration to ensure that /// it is only called when it makes sense. fn start_client( - id: usize, + id: LanguageServerId, name: String, config: &LanguageConfiguration, ls_config: &LanguageServerConfiguration, doc_path: Option<&std::path::PathBuf>, root_dirs: &[PathBuf], enable_snippets: bool, -) -> Result<Option<NewClient>> { +) -> Result<NewClient, StartupError> { let (workspace, workspace_is_cwd) = helix_loader::find_workspace(); let workspace = path::normalize(workspace); let root = find_lsp_workspace( @@ -929,7 +956,7 @@ fn start_client( .map(|entry| entry.file_name()) .any(|entry| globset.is_match(entry)) { - return Ok(None); + return Err(StartupError::NoRequiredRootFound); } } @@ -981,7 +1008,7 @@ fn start_client( initialize_notify.notify_one(); }); - Ok(Some(NewClient(client, incoming))) + Ok(NewClient(client, incoming)) } /// Find an LSP workspace of a file using the following mechanism: |