Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ide-assists/src/handlers/extract_function.rs')
| -rw-r--r-- | crates/ide-assists/src/handlers/extract_function.rs | 284 |
1 files changed, 253 insertions, 31 deletions
diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs index c1e2f19ab1..e04a1dabb2 100644 --- a/crates/ide-assists/src/handlers/extract_function.rs +++ b/crates/ide-assists/src/handlers/extract_function.rs @@ -11,7 +11,9 @@ use ide_db::{ helpers::mod_path_to_ast, imports::insert_use::{insert_use, ImportScope}, search::{FileReference, ReferenceCategory, SearchScope}, - syntax_helpers::node_ext::{preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr}, + syntax_helpers::node_ext::{ + for_each_tail_expr, preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr, + }, FxIndexSet, RootDatabase, }; use itertools::Itertools; @@ -78,7 +80,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op }; let body = extraction_target(&node, range)?; - let container_info = body.analyze_container(&ctx.sema)?; + let (container_info, contains_tail_expr) = body.analyze_container(&ctx.sema)?; let (locals_used, self_param) = body.analyze(&ctx.sema); @@ -119,6 +121,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op ret_ty, body, outliving_locals, + contains_tail_expr, mods: container_info, }; @@ -245,6 +248,8 @@ struct Function { ret_ty: RetType, body: FunctionBody, outliving_locals: Vec<OutlivedLocal>, + /// Whether at least one of the container's tail expr is contained in the range we're extracting. + contains_tail_expr: bool, mods: ContainerInfo, } @@ -265,7 +270,7 @@ enum ParamKind { MutRef, } -#[derive(Debug, Eq, PartialEq)] +#[derive(Debug)] enum FunType { Unit, Single(hir::Type), @@ -294,7 +299,6 @@ struct ControlFlow { #[derive(Clone, Debug)] struct ContainerInfo { is_const: bool, - is_in_tail: bool, parent_loop: Option<SyntaxNode>, /// The function's return type, const's type etc. ret_type: Option<hir::Type>, @@ -584,7 +588,7 @@ impl FunctionBody { FunctionBody::Expr(expr) => Some(expr.clone()), FunctionBody::Span { parent, text_range } => { let tail_expr = parent.tail_expr()?; - text_range.contains_range(tail_expr.syntax().text_range()).then(|| tail_expr) + text_range.contains_range(tail_expr.syntax().text_range()).then_some(tail_expr) } } } @@ -743,7 +747,10 @@ impl FunctionBody { (res, self_param) } - fn analyze_container(&self, sema: &Semantics<'_, RootDatabase>) -> Option<ContainerInfo> { + fn analyze_container( + &self, + sema: &Semantics<'_, RootDatabase>, + ) -> Option<(ContainerInfo, bool)> { let mut ancestors = self.parent()?.ancestors(); let infer_expr_opt = |expr| sema.type_of_expr(&expr?).map(TypeInfo::adjusted); let mut parent_loop = None; @@ -815,28 +822,36 @@ impl FunctionBody { } }; }; - let container_tail = match expr? { - ast::Expr::BlockExpr(block) => block.tail_expr(), - expr => Some(expr), - }; - let is_in_tail = - container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| { - container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range()) + + let expr = expr?; + let contains_tail_expr = if let Some(body_tail) = self.tail_expr() { + let mut contains_tail_expr = false; + let tail_expr_range = body_tail.syntax().text_range(); + for_each_tail_expr(&expr, &mut |e| { + if tail_expr_range.contains_range(e.syntax().text_range()) { + contains_tail_expr = true; + } }); + contains_tail_expr + } else { + false + }; let parent = self.parent()?; let parents = generic_parents(&parent); let generic_param_lists = parents.iter().filter_map(|it| it.generic_param_list()).collect(); let where_clauses = parents.iter().filter_map(|it| it.where_clause()).collect(); - Some(ContainerInfo { - is_in_tail, - is_const, - parent_loop, - ret_type: ty, - generic_param_lists, - where_clauses, - }) + Some(( + ContainerInfo { + is_const, + parent_loop, + ret_type: ty, + generic_param_lists, + where_clauses, + }, + contains_tail_expr, + )) } fn return_ty(&self, ctx: &AssistContext<'_>) -> Option<RetType> { @@ -1368,7 +1383,7 @@ impl FlowHandler { None => FlowHandler::None, Some(flow_kind) => { let action = flow_kind.clone(); - if *ret_ty == FunType::Unit { + if let FunType::Unit = ret_ty { match flow_kind { FlowKind::Return(None) | FlowKind::Break(_, None) @@ -1633,7 +1648,7 @@ impl Function { fn make_ret_ty(&self, ctx: &AssistContext<'_>, module: hir::Module) -> Option<ast::RetType> { let fun_ty = self.return_type(ctx); - let handler = if self.mods.is_in_tail { + let handler = if self.contains_tail_expr { FlowHandler::None } else { FlowHandler::from_ret_ty(self, &fun_ty) @@ -1707,7 +1722,7 @@ fn make_body( fun: &Function, ) -> ast::BlockExpr { let ret_ty = fun.return_type(ctx); - let handler = if fun.mods.is_in_tail { + let handler = if fun.contains_tail_expr { FlowHandler::None } else { FlowHandler::from_ret_ty(fun, &ret_ty) @@ -1785,7 +1800,7 @@ fn make_body( .collect::<Vec<SyntaxElement>>(); let tail_expr = tail_expr.map(|expr| expr.dedent(old_indent).indent(body_indent)); - make::hacky_block_expr_with_comments(elements, tail_expr) + make::hacky_block_expr(elements, tail_expr) } }; @@ -1845,9 +1860,29 @@ fn with_default_tail_expr(block: ast::BlockExpr, tail_expr: ast::Expr) -> ast::B } fn with_tail_expr(block: ast::BlockExpr, tail_expr: ast::Expr) -> ast::BlockExpr { - let stmt_tail = block.tail_expr().map(|expr| make::expr_stmt(expr).into()); - let stmts = block.statements().chain(stmt_tail); - make::block_expr(stmts, Some(tail_expr)) + let stmt_tail_opt: Option<ast::Stmt> = + block.tail_expr().map(|expr| make::expr_stmt(expr).into()); + + let mut elements: Vec<SyntaxElement> = vec![]; + + block.statements().for_each(|stmt| { + elements.push(syntax::NodeOrToken::Node(stmt.syntax().clone())); + }); + + if let Some(stmt_list) = block.stmt_list() { + stmt_list.syntax().children_with_tokens().for_each(|node_or_token| { + match &node_or_token { + syntax::NodeOrToken::Token(_) => elements.push(node_or_token), + _ => (), + }; + }); + } + + if let Some(stmt_tail) = stmt_tail_opt { + elements.push(syntax::NodeOrToken::Node(stmt_tail.syntax().clone())); + } + + make::hacky_block_expr(elements, Some(tail_expr)) } fn format_type(ty: &hir::Type, ctx: &AssistContext<'_>, module: hir::Module) -> String { @@ -1946,7 +1981,7 @@ fn update_external_control_flow(handler: &FlowHandler, syntax: &SyntaxNode) { if nested_scope.is_none() { if let Some(expr) = ast::Expr::cast(e.clone()) { match expr { - ast::Expr::ReturnExpr(return_expr) if nested_scope.is_none() => { + ast::Expr::ReturnExpr(return_expr) => { let expr = return_expr.expr(); if let Some(replacement) = make_rewritten_flow(handler, expr) { ted::replace(return_expr.syntax(), replacement.syntax()) @@ -4944,9 +4979,8 @@ fn $0fun_name() { ); } - // FIXME: we do want to preserve whitespace #[test] - fn extract_function_does_not_preserve_whitespace() { + fn extract_function_does_preserve_whitespace() { check_assist( extract_function, r#" @@ -4965,6 +4999,7 @@ fn func() { fn $0fun_name() { let a = 0; + let x = 0; } "#, @@ -5585,4 +5620,191 @@ fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> { "#, ); } + + #[test] + fn non_tail_expr_of_tail_expr_loop() { + check_assist( + extract_function, + r#" +pub fn f() { + loop { + $0if true { + continue; + }$0 + + if false { + break; + } + } +} +"#, + r#" +pub fn f() { + loop { + if let ControlFlow::Break(_) = fun_name() { + continue; + } + + if false { + break; + } + } +} + +fn $0fun_name() -> ControlFlow<()> { + if true { + return ControlFlow::Break(()); + } + ControlFlow::Continue(()) +} +"#, + ); + } + + #[test] + fn non_tail_expr_of_tail_if_block() { + // FIXME: double semicolon + check_assist( + extract_function, + r#" +//- minicore: option, try +impl<T> core::ops::Try for Option<T> { + type Output = T; + type Residual = Option<!>; +} +impl<T> core::ops::FromResidual for Option<T> {} + +fn f() -> Option<()> { + if true { + let a = $0if true { + Some(())? + } else { + () + }$0; + Some(a) + } else { + None + } +} +"#, + r#" +impl<T> core::ops::Try for Option<T> { + type Output = T; + type Residual = Option<!>; +} +impl<T> core::ops::FromResidual for Option<T> {} + +fn f() -> Option<()> { + if true { + let a = fun_name()?;; + Some(a) + } else { + None + } +} + +fn $0fun_name() -> Option<()> { + Some(if true { + Some(())? + } else { + () + }) +} +"#, + ); + } + + #[test] + fn tail_expr_of_tail_block_nested() { + check_assist( + extract_function, + r#" +//- minicore: option, try +impl<T> core::ops::Try for Option<T> { + type Output = T; + type Residual = Option<!>; +} +impl<T> core::ops::FromResidual for Option<T> {} + +fn f() -> Option<()> { + if true { + $0{ + let a = if true { + Some(())? + } else { + () + }; + Some(a) + }$0 + } else { + None + } +} +"#, + r#" +impl<T> core::ops::Try for Option<T> { + type Output = T; + type Residual = Option<!>; +} +impl<T> core::ops::FromResidual for Option<T> {} + +fn f() -> Option<()> { + if true { + fun_name()? + } else { + None + } +} + +fn $0fun_name() -> Option<()> { + let a = if true { + Some(())? + } else { + () + }; + Some(a) +} +"#, + ); + } + + #[test] + fn non_tail_expr_with_comment_of_tail_expr_loop() { + check_assist( + extract_function, + r#" +pub fn f() { + loop { + $0// A comment + if true { + continue; + }$0 + if false { + break; + } + } +} +"#, + r#" +pub fn f() { + loop { + if let ControlFlow::Break(_) = fun_name() { + continue; + } + if false { + break; + } + } +} + +fn $0fun_name() -> ControlFlow<()> { + // A comment + if true { + return ControlFlow::Break(()); + } + ControlFlow::Continue(()) +} +"#, + ); + } } |