Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ide-diagnostics/src/handlers/missing_unsafe.rs')
-rw-r--r--crates/ide-diagnostics/src/handlers/missing_unsafe.rs39
1 files changed, 38 insertions, 1 deletions
diff --git a/crates/ide-diagnostics/src/handlers/missing_unsafe.rs b/crates/ide-diagnostics/src/handlers/missing_unsafe.rs
index 029ed18a4d..df1cd76cf7 100644
--- a/crates/ide-diagnostics/src/handlers/missing_unsafe.rs
+++ b/crates/ide-diagnostics/src/handlers/missing_unsafe.rs
@@ -50,7 +50,12 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option<Vec<Ass
let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block(&expr)?;
- let replacement = format!("unsafe {{ {} }}", node_to_add_unsafe_block.text());
+ let mut replacement = format!("unsafe {{ {} }}", node_to_add_unsafe_block.text());
+ if let Some(expr) = ast::Expr::cast(node_to_add_unsafe_block.clone())
+ && needs_parentheses(&expr)
+ {
+ replacement = format!("({replacement})");
+ }
let edit = TextEdit::replace(node_to_add_unsafe_block.text_range(), replacement);
let source_change = SourceChange::from_text_edit(
d.node.file_id.original_file(ctx.sema.db).file_id(ctx.sema.db),
@@ -112,6 +117,17 @@ fn pick_best_node_to_add_unsafe_block(unsafe_expr: &ast::Expr) -> Option<SyntaxN
None
}
+fn needs_parentheses(expr: &ast::Expr) -> bool {
+ let node = expr.syntax();
+ node.ancestors()
+ .skip(1)
+ .take_while(|it| it.text_range().start() == node.text_range().start())
+ .map_while(ast::Expr::cast)
+ .last()
+ .and_then(|it| Some(it.syntax().parent()?.kind()))
+ .is_some_and(|kind| ast::ExprStmt::can_cast(kind) || ast::StmtList::can_cast(kind))
+}
+
#[cfg(test)]
mod tests {
use crate::tests::{check_diagnostics, check_fix, check_no_fix};
@@ -571,6 +587,27 @@ fn main() {
}
#[test]
+ fn needs_parentheses_for_unambiguous() {
+ check_fix(
+ r#"
+//- minicore: copy
+static mut STATIC_MUT: u8 = 0;
+
+fn foo() -> u8 {
+ STATIC_MUT$0 * 2
+}
+"#,
+ r#"
+static mut STATIC_MUT: u8 = 0;
+
+fn foo() -> u8 {
+ (unsafe { STATIC_MUT }) * 2
+}
+"#,
+ )
+ }
+
+ #[test]
fn ref_to_unsafe_expr() {
check_fix(
r#"