Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ide-assists/src/handlers/extract_function.rs')
-rw-r--r--crates/ide-assists/src/handlers/extract_function.rs284
1 files changed, 253 insertions, 31 deletions
diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs
index c1e2f19ab1..e04a1dabb2 100644
--- a/crates/ide-assists/src/handlers/extract_function.rs
+++ b/crates/ide-assists/src/handlers/extract_function.rs
@@ -11,7 +11,9 @@ use ide_db::{
helpers::mod_path_to_ast,
imports::insert_use::{insert_use, ImportScope},
search::{FileReference, ReferenceCategory, SearchScope},
- syntax_helpers::node_ext::{preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr},
+ syntax_helpers::node_ext::{
+ for_each_tail_expr, preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr,
+ },
FxIndexSet, RootDatabase,
};
use itertools::Itertools;
@@ -78,7 +80,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
};
let body = extraction_target(&node, range)?;
- let container_info = body.analyze_container(&ctx.sema)?;
+ let (container_info, contains_tail_expr) = body.analyze_container(&ctx.sema)?;
let (locals_used, self_param) = body.analyze(&ctx.sema);
@@ -119,6 +121,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
ret_ty,
body,
outliving_locals,
+ contains_tail_expr,
mods: container_info,
};
@@ -245,6 +248,8 @@ struct Function {
ret_ty: RetType,
body: FunctionBody,
outliving_locals: Vec<OutlivedLocal>,
+ /// Whether at least one of the container's tail expr is contained in the range we're extracting.
+ contains_tail_expr: bool,
mods: ContainerInfo,
}
@@ -265,7 +270,7 @@ enum ParamKind {
MutRef,
}
-#[derive(Debug, Eq, PartialEq)]
+#[derive(Debug)]
enum FunType {
Unit,
Single(hir::Type),
@@ -294,7 +299,6 @@ struct ControlFlow {
#[derive(Clone, Debug)]
struct ContainerInfo {
is_const: bool,
- is_in_tail: bool,
parent_loop: Option<SyntaxNode>,
/// The function's return type, const's type etc.
ret_type: Option<hir::Type>,
@@ -584,7 +588,7 @@ impl FunctionBody {
FunctionBody::Expr(expr) => Some(expr.clone()),
FunctionBody::Span { parent, text_range } => {
let tail_expr = parent.tail_expr()?;
- text_range.contains_range(tail_expr.syntax().text_range()).then(|| tail_expr)
+ text_range.contains_range(tail_expr.syntax().text_range()).then_some(tail_expr)
}
}
}
@@ -743,7 +747,10 @@ impl FunctionBody {
(res, self_param)
}
- fn analyze_container(&self, sema: &Semantics<'_, RootDatabase>) -> Option<ContainerInfo> {
+ fn analyze_container(
+ &self,
+ sema: &Semantics<'_, RootDatabase>,
+ ) -> Option<(ContainerInfo, bool)> {
let mut ancestors = self.parent()?.ancestors();
let infer_expr_opt = |expr| sema.type_of_expr(&expr?).map(TypeInfo::adjusted);
let mut parent_loop = None;
@@ -815,28 +822,36 @@ impl FunctionBody {
}
};
};
- let container_tail = match expr? {
- ast::Expr::BlockExpr(block) => block.tail_expr(),
- expr => Some(expr),
- };
- let is_in_tail =
- container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| {
- container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range())
+
+ let expr = expr?;
+ let contains_tail_expr = if let Some(body_tail) = self.tail_expr() {
+ let mut contains_tail_expr = false;
+ let tail_expr_range = body_tail.syntax().text_range();
+ for_each_tail_expr(&expr, &mut |e| {
+ if tail_expr_range.contains_range(e.syntax().text_range()) {
+ contains_tail_expr = true;
+ }
});
+ contains_tail_expr
+ } else {
+ false
+ };
let parent = self.parent()?;
let parents = generic_parents(&parent);
let generic_param_lists = parents.iter().filter_map(|it| it.generic_param_list()).collect();
let where_clauses = parents.iter().filter_map(|it| it.where_clause()).collect();
- Some(ContainerInfo {
- is_in_tail,
- is_const,
- parent_loop,
- ret_type: ty,
- generic_param_lists,
- where_clauses,
- })
+ Some((
+ ContainerInfo {
+ is_const,
+ parent_loop,
+ ret_type: ty,
+ generic_param_lists,
+ where_clauses,
+ },
+ contains_tail_expr,
+ ))
}
fn return_ty(&self, ctx: &AssistContext<'_>) -> Option<RetType> {
@@ -1368,7 +1383,7 @@ impl FlowHandler {
None => FlowHandler::None,
Some(flow_kind) => {
let action = flow_kind.clone();
- if *ret_ty == FunType::Unit {
+ if let FunType::Unit = ret_ty {
match flow_kind {
FlowKind::Return(None)
| FlowKind::Break(_, None)
@@ -1633,7 +1648,7 @@ impl Function {
fn make_ret_ty(&self, ctx: &AssistContext<'_>, module: hir::Module) -> Option<ast::RetType> {
let fun_ty = self.return_type(ctx);
- let handler = if self.mods.is_in_tail {
+ let handler = if self.contains_tail_expr {
FlowHandler::None
} else {
FlowHandler::from_ret_ty(self, &fun_ty)
@@ -1707,7 +1722,7 @@ fn make_body(
fun: &Function,
) -> ast::BlockExpr {
let ret_ty = fun.return_type(ctx);
- let handler = if fun.mods.is_in_tail {
+ let handler = if fun.contains_tail_expr {
FlowHandler::None
} else {
FlowHandler::from_ret_ty(fun, &ret_ty)
@@ -1785,7 +1800,7 @@ fn make_body(
.collect::<Vec<SyntaxElement>>();
let tail_expr = tail_expr.map(|expr| expr.dedent(old_indent).indent(body_indent));
- make::hacky_block_expr_with_comments(elements, tail_expr)
+ make::hacky_block_expr(elements, tail_expr)
}
};
@@ -1845,9 +1860,29 @@ fn with_default_tail_expr(block: ast::BlockExpr, tail_expr: ast::Expr) -> ast::B
}
fn with_tail_expr(block: ast::BlockExpr, tail_expr: ast::Expr) -> ast::BlockExpr {
- let stmt_tail = block.tail_expr().map(|expr| make::expr_stmt(expr).into());
- let stmts = block.statements().chain(stmt_tail);
- make::block_expr(stmts, Some(tail_expr))
+ let stmt_tail_opt: Option<ast::Stmt> =
+ block.tail_expr().map(|expr| make::expr_stmt(expr).into());
+
+ let mut elements: Vec<SyntaxElement> = vec![];
+
+ block.statements().for_each(|stmt| {
+ elements.push(syntax::NodeOrToken::Node(stmt.syntax().clone()));
+ });
+
+ if let Some(stmt_list) = block.stmt_list() {
+ stmt_list.syntax().children_with_tokens().for_each(|node_or_token| {
+ match &node_or_token {
+ syntax::NodeOrToken::Token(_) => elements.push(node_or_token),
+ _ => (),
+ };
+ });
+ }
+
+ if let Some(stmt_tail) = stmt_tail_opt {
+ elements.push(syntax::NodeOrToken::Node(stmt_tail.syntax().clone()));
+ }
+
+ make::hacky_block_expr(elements, Some(tail_expr))
}
fn format_type(ty: &hir::Type, ctx: &AssistContext<'_>, module: hir::Module) -> String {
@@ -1946,7 +1981,7 @@ fn update_external_control_flow(handler: &FlowHandler, syntax: &SyntaxNode) {
if nested_scope.is_none() {
if let Some(expr) = ast::Expr::cast(e.clone()) {
match expr {
- ast::Expr::ReturnExpr(return_expr) if nested_scope.is_none() => {
+ ast::Expr::ReturnExpr(return_expr) => {
let expr = return_expr.expr();
if let Some(replacement) = make_rewritten_flow(handler, expr) {
ted::replace(return_expr.syntax(), replacement.syntax())
@@ -4944,9 +4979,8 @@ fn $0fun_name() {
);
}
- // FIXME: we do want to preserve whitespace
#[test]
- fn extract_function_does_not_preserve_whitespace() {
+ fn extract_function_does_preserve_whitespace() {
check_assist(
extract_function,
r#"
@@ -4965,6 +4999,7 @@ fn func() {
fn $0fun_name() {
let a = 0;
+
let x = 0;
}
"#,
@@ -5585,4 +5620,191 @@ fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> {
"#,
);
}
+
+ #[test]
+ fn non_tail_expr_of_tail_expr_loop() {
+ check_assist(
+ extract_function,
+ r#"
+pub fn f() {
+ loop {
+ $0if true {
+ continue;
+ }$0
+
+ if false {
+ break;
+ }
+ }
+}
+"#,
+ r#"
+pub fn f() {
+ loop {
+ if let ControlFlow::Break(_) = fun_name() {
+ continue;
+ }
+
+ if false {
+ break;
+ }
+ }
+}
+
+fn $0fun_name() -> ControlFlow<()> {
+ if true {
+ return ControlFlow::Break(());
+ }
+ ControlFlow::Continue(())
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn non_tail_expr_of_tail_if_block() {
+ // FIXME: double semicolon
+ check_assist(
+ extract_function,
+ r#"
+//- minicore: option, try
+impl<T> core::ops::Try for Option<T> {
+ type Output = T;
+ type Residual = Option<!>;
+}
+impl<T> core::ops::FromResidual for Option<T> {}
+
+fn f() -> Option<()> {
+ if true {
+ let a = $0if true {
+ Some(())?
+ } else {
+ ()
+ }$0;
+ Some(a)
+ } else {
+ None
+ }
+}
+"#,
+ r#"
+impl<T> core::ops::Try for Option<T> {
+ type Output = T;
+ type Residual = Option<!>;
+}
+impl<T> core::ops::FromResidual for Option<T> {}
+
+fn f() -> Option<()> {
+ if true {
+ let a = fun_name()?;;
+ Some(a)
+ } else {
+ None
+ }
+}
+
+fn $0fun_name() -> Option<()> {
+ Some(if true {
+ Some(())?
+ } else {
+ ()
+ })
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn tail_expr_of_tail_block_nested() {
+ check_assist(
+ extract_function,
+ r#"
+//- minicore: option, try
+impl<T> core::ops::Try for Option<T> {
+ type Output = T;
+ type Residual = Option<!>;
+}
+impl<T> core::ops::FromResidual for Option<T> {}
+
+fn f() -> Option<()> {
+ if true {
+ $0{
+ let a = if true {
+ Some(())?
+ } else {
+ ()
+ };
+ Some(a)
+ }$0
+ } else {
+ None
+ }
+}
+"#,
+ r#"
+impl<T> core::ops::Try for Option<T> {
+ type Output = T;
+ type Residual = Option<!>;
+}
+impl<T> core::ops::FromResidual for Option<T> {}
+
+fn f() -> Option<()> {
+ if true {
+ fun_name()?
+ } else {
+ None
+ }
+}
+
+fn $0fun_name() -> Option<()> {
+ let a = if true {
+ Some(())?
+ } else {
+ ()
+ };
+ Some(a)
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn non_tail_expr_with_comment_of_tail_expr_loop() {
+ check_assist(
+ extract_function,
+ r#"
+pub fn f() {
+ loop {
+ $0// A comment
+ if true {
+ continue;
+ }$0
+ if false {
+ break;
+ }
+ }
+}
+"#,
+ r#"
+pub fn f() {
+ loop {
+ if let ControlFlow::Break(_) = fun_name() {
+ continue;
+ }
+ if false {
+ break;
+ }
+ }
+}
+
+fn $0fun_name() -> ControlFlow<()> {
+ // A comment
+ if true {
+ return ControlFlow::Break(());
+ }
+ ControlFlow::Continue(())
+}
+"#,
+ );
+ }
}