Unnamed repository; edit this file 'description' to name the repository.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
1
230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506
//! Various helper functions to work with SyntaxNodes.
use std::ops::ControlFlow;

use itertools::Itertools;
use parser::T;
use span::Edition;
use syntax::{
    AstNode, AstToken, Preorder, RustLanguage, WalkEvent,
    ast::{self, HasLoopBody, MacroCall, PathSegmentKind, VisibilityKind},
};

/// Returns the single name reference that `expr` consists of.
///
/// Only path expressions are considered; the result is whatever
/// `as_single_name_ref` yields for the expression's path. Every other kind of
/// expression (or a path expression without a path) gives `None`.
pub fn expr_as_name_ref(expr: &ast::Expr) -> Option<ast::NameRef> {
    match expr {
        ast::Expr::PathExpr(path_expr) => path_expr.path()?.as_single_name_ref(),
        _ => None,
    }
}

/// Returns the outermost [`ast::Path`] in the contiguous chain of paths that
/// encloses `name_ref`.
///
/// Yields `None` when the node directly above `name_ref` is not a path segment,
/// or when no path node follows that segment in the ancestor chain.
pub fn full_path_of_name_ref(name_ref: &ast::NameRef) -> Option<ast::Path> {
    // The ancestor chain starts with the name ref's own node; drop it.
    let mut parents = name_ref.syntax().ancestors().skip(1);
    // The next node up has to be the path segment that holds the name ref.
    let segment = parents.next()?;
    if !ast::PathSegment::can_cast(segment.kind()) {
        return None;
    }
    // Climb through the run of enclosing paths and keep the last (topmost) one.
    parents.take_while(|node| ast::Path::can_cast(node.kind())).last().and_then(ast::Path::cast)
}

/// Returns the tail expression of `block`, but only if the block contains no
/// statements at all.
pub fn block_as_lone_tail(block: &ast::BlockExpr) -> Option<ast::Expr> {
    if block.statements().next().is_some() {
        return None;
    }
    block.tail_expr()
}

/// Preorder walk over `expr` and its child expressions, invoking `cb` once for
/// every expression that is entered.
///
/// This is [`preorder_expr`] with leave events dropped and without ever asking
/// to skip a subtree.
pub fn walk_expr(expr: &ast::Expr, cb: &mut dyn FnMut(ast::Expr)) {
    preorder_expr(expr, &mut |event| match event {
        WalkEvent::Enter(entered) => {
            cb(entered);
            false
        }
        WalkEvent::Leave(_) => false,
    })
}

/// Returns `true` for closure expressions and for block expressions carrying an
/// `async`, `try` or `const` modifier.
pub fn is_closure_or_blk_with_modif(expr: &ast::Expr) -> bool {
    match expr {
        ast::Expr::ClosureExpr(_) => true,
        ast::Expr::BlockExpr(block) => matches!(
            block.modifier(),
            Some(
                ast::BlockModifier::Async(_)
                    | ast::BlockModifier::Try(_)
                    | ast::BlockModifier::Const(_)
            )
        ),
        _ => false,
    }
}

/// Preorder walk over `start` and its child expressions, forwarding both enter and
/// leave events to `cb`.
///
/// Returning `true` from `cb` on a [`WalkEvent::Enter`] skips the subtree of that
/// expression. Subtrees may also be skipped by the walker itself: closures and
/// `async`/`try`/`const` blocks other than `start` are reported but not descended into.
pub fn preorder_expr(start: &ast::Expr, cb: &mut dyn FnMut(WalkEvent<ast::Expr>) -> bool) {
    let changes_context: &dyn Fn(&ast::Expr) -> bool = &is_closure_or_blk_with_modif;
    preorder_expr_with_ctx_checker(start, changes_context, cb)
}

/// Preorder walk over the expressions below `start`, reporting enter and leave events to `cb`.
///
/// `check_ctx` is consulted for every entered expression. When it returns `true` for an
/// expression other than `start`, that expression is still reported but its subtree is
/// skipped. `cb` returning `true` on a [`WalkEvent::Enter`] skips the subtree as well.
/// Nested items, generic arguments, and children of a `let` statement other than its
/// initializer and `else` branch are skipped too.
pub fn preorder_expr_with_ctx_checker(
    start: &ast::Expr,
    check_ctx: &dyn Fn(&ast::Expr) -> bool,
    cb: &mut dyn FnMut(WalkEvent<ast::Expr>) -> bool,
) {
    let mut walker = start.syntax().preorder();
    while let Some(event) = walker.next() {
        let node = match event {
            WalkEvent::Enter(node) => node,
            WalkEvent::Leave(node) => {
                // The return value only matters for enter events.
                if let Some(left) = ast::Expr::cast(node) {
                    cb(WalkEvent::Leave(left));
                }
                continue;
            }
        };

        if let Some(let_stmt) = node.parent().and_then(ast::LetStmt::cast) {
            let is_initializer = let_stmt.initializer().is_some_and(|it| it.syntax() == &node);
            let is_let_else =
                !is_initializer && let_stmt.let_else().is_some_and(|it| it.syntax() == &node);
            if !is_initializer && !is_let_else {
                // Anything else under a `let` may hold const pattern expressions; skip it.
                walker.skip_subtree();
                continue;
            }
        }

        match ast::Stmt::cast(node.clone()) {
            // Keep descending: the expression child is visited next.
            Some(ast::Stmt::ExprStmt(_) | ast::Stmt::LetStmt(_)) => {}
            // Inner items may contain expressions of their own, which belong elsewhere.
            Some(ast::Stmt::Item(_)) => walker.skip_subtree(),
            // Const args are expressions of a different context.
            None if ast::GenericArg::can_cast(node.kind()) => walker.skip_subtree(),
            None => {
                if let Some(entered) = ast::Expr::cast(node) {
                    // Evaluate the context check before handing the expression to `cb`.
                    let is_different_context =
                        check_ctx(&entered) && entered.syntax() != start.syntax();
                    let skip_requested = cb(WalkEvent::Enter(entered));
                    if skip_requested || is_different_context {
                        walker.skip_subtree();
                    }
                }
            }
        }
    }
}

/// Preorder walk over all patterns found inside `start`, calling `cb` for each pattern
/// and each of its sub patterns.
///
/// Nested items and generic arguments are not searched. Closures and
/// `async`/`try`/`const` blocks are only searched when they are `start` itself.
/// For `let` statements the pattern and the initializer are searched.
pub fn walk_patterns_in_expr(start: &ast::Expr, cb: &mut dyn FnMut(ast::Pat)) {
    /// Feeds `pat` and all of its sub patterns to `cb`, never stopping early.
    fn report_all(pat: &ast::Pat, cb: &mut dyn FnMut(ast::Pat)) {
        _ = walk_pat(pat, &mut |sub_pat| {
            cb(sub_pat);
            ControlFlow::<(), ()>::Continue(())
        });
    }

    let mut walker = start.syntax().preorder();
    while let Some(event) = walker.next() {
        let WalkEvent::Enter(node) = event else { continue };
        match ast::Stmt::cast(node.clone()) {
            Some(ast::Stmt::LetStmt(let_stmt)) => {
                if let Some(pat) = let_stmt.pat() {
                    report_all(&pat, cb);
                }
                if let Some(init) = let_stmt.initializer() {
                    walk_patterns_in_expr(&init, cb);
                }
                // Both interesting children were handled explicitly above.
                walker.skip_subtree();
            }
            // Keep descending: the expression child is visited next.
            Some(ast::Stmt::ExprStmt(_)) => {}
            // Inner items may have patterns of their own, which belong elsewhere.
            Some(ast::Stmt::Item(_)) => walker.skip_subtree(),
            // Const args are a different context.
            None if ast::GenericArg::can_cast(node.kind()) => walker.skip_subtree(),
            None => {
                if let Some(expr) = ast::Expr::cast(node.clone()) {
                    let changes_context = match &expr {
                        ast::Expr::ClosureExpr(_) => true,
                        ast::Expr::BlockExpr(block) => matches!(
                            block.modifier(),
                            Some(
                                ast::BlockModifier::Async(_)
                                    | ast::BlockModifier::Try(_)
                                    | ast::BlockModifier::Const(_)
                            )
                        ),
                        _ => false,
                    };
                    if changes_context && expr.syntax() != start.syntax() {
                        walker.skip_subtree();
                    }
                } else if let Some(pat) = ast::Pat::cast(node) {
                    // `report_all` covers the pattern's subtree, so do not descend again.
                    walker.skip_subtree();
                    report_all(&pat, cb);
                }
            }
        }
    }
}

/// Preorder walk over `pat` and all of its sub patterns.
///
/// `cb` is called for every pattern entered; a `ControlFlow::Break` returned by it
/// stops the walk and is passed on to the caller. The contents of const block
/// patterns and of generic arguments are not visited.
pub fn walk_pat<T>(
    pat: &ast::Pat,
    cb: &mut dyn FnMut(ast::Pat) -> ControlFlow<T>,
) -> ControlFlow<T> {
    let mut walker = pat.syntax().preorder();
    while let Some(event) = walker.next() {
        let WalkEvent::Enter(node) = event else { continue };
        let kind = node.kind();
        let Some(sub_pat) = ast::Pat::cast(node) else {
            // skip const args
            if ast::GenericArg::can_cast(kind) {
                walker.skip_subtree();
            }
            continue;
        };
        // A const block pattern is reported itself, but its body is left alone.
        if matches!(sub_pat, ast::Pat::ConstBlockPat(_)) {
            walker.skip_subtree();
        }
        cb(sub_pat)?;
    }
    ControlFlow::Continue(())
}

/// Preorder walk over `ty` and all of its sub types.
///
/// `cb` is called for every type entered; returning `true` skips the subtree of that
/// type. Macro types are always reported without descending into them, and const
/// arguments are skipped entirely.
// FIXME: Make the control flow more proper
pub fn walk_ty(ty: &ast::Type, cb: &mut dyn FnMut(ast::Type) -> bool) {
    let mut walker = ty.syntax().preorder();
    while let Some(event) = walker.next() {
        let WalkEvent::Enter(node) = event else { continue };
        let kind = node.kind();
        match ast::Type::cast(node) {
            Some(sub_ty) => {
                let is_macro = matches!(sub_ty, ast::Type::MacroType(_));
                let skip_requested = cb(sub_ty);
                if is_macro || skip_requested {
                    walker.skip_subtree();
                }
            }
            None => {
                // skip const args
                if ast::ConstArg::can_cast(kind) {
                    walker.skip_subtree();
                }
            }
        }
    }
}

/// Compares two visibilities for equality.
///
/// Two visibilities are equal when they have the same [`VisibilityKind`]. For
/// `VisibilityKind::In` the paths must additionally match segment by segment:
/// either the same `crate`/`self`/`super` keyword or names with identical text.
pub fn vis_eq(this: &ast::Visibility, other: &ast::Visibility) -> bool {
    match (this.kind(), other.kind()) {
        (VisibilityKind::In(this_path), VisibilityKind::In(other_path)) => {
            stdx::iter_eq_by(this_path.segments(), other_path.segments(), |lhs, rhs| {
                match (lhs.kind(), rhs.kind()) {
                    (Some(PathSegmentKind::Name(lhs)), Some(PathSegmentKind::Name(rhs))) => {
                        lhs.text() == rhs.text()
                    }
                    (Some(PathSegmentKind::CrateKw), Some(PathSegmentKind::CrateKw))
                    | (Some(PathSegmentKind::SelfKw), Some(PathSegmentKind::SelfKw))
                    | (Some(PathSegmentKind::SuperKw), Some(PathSegmentKind::SuperKw)) => true,
                    // Mismatched kinds, other segment kinds, or a segment without a kind.
                    _ => false,
                }
            })
        }
        (this_kind, other_kind) => matches!(
            (this_kind, other_kind),
            (VisibilityKind::PubSelf, VisibilityKind::PubSelf)
                | (VisibilityKind::PubSuper, VisibilityKind::PubSuper)
                | (VisibilityKind::PubCrate, VisibilityKind::PubCrate)
                | (VisibilityKind::Pub, VisibilityKind::Pub)
        ),
    }
}

/// Returns the `let` expression if `expr` is exactly one `let`, possibly wrapped in
/// any number of parentheses (`let pat = expr` or `((let pat = expr))`).
///
/// Anything else, such as `let pat = expr && expr` or an expression that is not a
/// `let`, gives `None`.
pub fn single_let(expr: ast::Expr) -> Option<ast::LetExpr> {
    let mut current = expr;
    loop {
        match current {
            // Peel one layer of parentheses and look again.
            ast::Expr::ParenExpr(paren) => current = paren.expr()?,
            ast::Expr::LetExpr(let_expr) => return Some(let_expr),
            _ => return None,
        }
    }
}

/// Returns `true` if `expr` is a condition that contains a `let` expression.
///
/// That is the case for a `let` expression itself, for a parenthesized pattern
/// condition, and for an `&&` chain in which at least one operand is a pattern
/// condition. Every other expression gives `false`.
pub fn is_pattern_cond(expr: ast::Expr) -> bool {
    match expr {
        ast::Expr::BinExpr(expr)
            if expr.op_kind() == Some(ast::BinaryOp::LogicOp(ast::LogicOp::And)) =>
        {
            // Inspect both operands: the `let` can sit on either side of `&&`
            // (e.g. `cond && let Some(x) = opt`), so a non-pattern left-hand side
            // must not end the search.
            expr.lhs().is_some_and(is_pattern_cond) || expr.rhs().is_some_and(is_pattern_cond)
        }
        ast::Expr::ParenExpr(expr) => expr.expr().is_some_and(is_pattern_cond),
        ast::Expr::LetExpr(_) => true,
        _ => false,
    }
}

/// Calls `cb` on each expression inside `expr` that is at "tail position".
/// Does not walk into `break` or `return` expressions.
/// Note that modifying the tree while iterating it will cause undefined iteration which might
/// potentially results in an out of bounds panic.
pub fn for_each_tail_expr(expr: &ast::Expr, cb: &mut dyn FnMut(&ast::Expr)) {
    let walk_loop = |cb: &mut dyn FnMut(&ast::Expr), label, body: Option<ast::BlockExpr>| {
        for_each_break_expr(label, body.and_then(|it| it.stmt_list()), &mut |b| {
            cb(&ast::Expr::BreakExpr(b))
        })
    };
    match expr {
        ast::Expr::BlockExpr(b) => {
            match b.modifier() {
                Some(
                    ast::BlockModifier::Async(_)
                    | ast::BlockModifier::Try(_)
                    | ast::BlockModifier::Const(_),
                ) => return cb(expr),

                Some(ast::BlockModifier::Label(label)) => {
                    for_each_break_expr(Some(label), b.stmt_list(), &mut |b| {
                        cb(&ast::Expr::BreakExpr(b))
                    });
                }
                Some(ast::BlockModifier::Unsafe(_)) => (),
                Some(ast::BlockModifier::Gen(_)) => (),
                Some(ast::BlockModifier::AsyncGen(_)) => (),
                None => (),
            }
            if let Some(stmt_list) = b.stmt_list() {
                if let Some(e) = stmt_list.tail_expr() {
                    for_each_tail_expr(&e, cb);
                }
            }
        }
        ast::Expr::IfExpr(if_) => {
            let mut if_ = if_.clone();
            loop {
                if let Some(block) = if_.then_branch() {
                    for_each_tail_expr(&ast::Expr::BlockExpr(block), cb);
                }
                match if_.else_branch() {
                    Some(ast::ElseBranch::IfExpr(it)) => if_ = it,
                    Some(ast::ElseBranch::Block(block)) => {
                        for_each_tail_expr(&ast::Expr::BlockExpr(block), cb);
                        break;
                    }
                    None => break,
                }
            }
        }
        ast::Expr::LoopExpr(l) => walk_loop(cb, l.label(), l.loop_body()),
        ast::Expr::WhileExpr(w) => walk_loop(cb, w.label(), w.loop_body()),
        ast::Expr::ForExpr(f) => walk_loop(cb, f.label(), f.loop_body()),
        ast::Expr::MatchExpr(m) => {
            if let Some(arms) = m.match_arm_list() {
                arms.arms().filter_map(|arm| arm.expr()).for_each(|e| for_each_tail_expr(&e, cb));
            }
        }
        ast::Expr::ArrayExpr(_)
        | ast::Expr::AwaitExpr(_)
        | ast::Expr::BinExpr(_)
        | ast::Expr::BreakExpr(_)
        | ast::Expr::CallExpr(_)
        | ast::Expr::CastExpr(_)
        | ast::Expr::ClosureExpr(_)
        | ast::Expr::ContinueExpr(_)
        | ast::Expr::FieldExpr(_)
        | ast::Expr::IndexExpr(_)
        | ast::Expr::Literal(_)
        | ast::Expr::MacroExpr(_)
        | ast::Expr::MethodCallExpr(_)
        | ast::Expr::ParenExpr(_)
        | ast::Expr::PathExpr(_)
        | ast::Expr::PrefixExpr(_)
        | ast::Expr::RangeExpr(_)
        | ast::Expr::RecordExpr(_)
        | ast::Expr::RefExpr(_)
        | ast::Expr::ReturnExpr(_)
        | ast::Expr::BecomeExpr(_)
        | ast::Expr::TryExpr(_)
        | ast::Expr::TupleExpr(_)
        | ast::Expr::LetExpr(_)
        | ast::Expr::UnderscoreExpr(_)
        | ast::Expr::YieldExpr(_)
        | ast::Expr::YeetExpr(_)
        | ast::Expr::OffsetOfExpr(_)
        | ast::Expr::FormatArgsExpr(_)
        | ast::Expr::AsmExpr(_) => cb(expr),
    }
}

/// Calls `cb` for every `break` and `continue` expression in `body` that targets the
/// construct owning `body`.
///
/// An expression qualifies when it is unlabeled and sits at nesting depth 0, or when
/// its label has the same text as `label`. Nothing happens when `body` is `None`.
pub fn for_each_break_and_continue_expr(
    label: Option<ast::Label>,
    body: Option<ast::StmtList>,
    cb: &mut dyn FnMut(ast::Expr),
) {
    let Some(body) = body else { return };
    let label = label.and_then(|lbl| lbl.lifetime());
    let targets_body = |depth: u32, lifetime: Option<ast::Lifetime>| {
        (depth == 0 && lifetime.is_none()) || eq_label_lt(&label, &lifetime)
    };
    for (expr, depth) in TreeWithDepthIterator::new(body) {
        let is_match = match &expr {
            ast::Expr::BreakExpr(brk) => targets_body(depth, brk.lifetime()),
            ast::Expr::ContinueExpr(cont) => targets_body(depth, cont.lifetime()),
            _ => false,
        };
        if is_match {
            cb(expr);
        }
    }
}

/// Calls `cb` for every `break` expression in `body` that targets the construct
/// owning `body`.
///
/// A `break` qualifies when it is unlabeled and sits at nesting depth 0, or when its
/// label has the same text as `label`. Nothing happens when `body` is `None`.
fn for_each_break_expr(
    label: Option<ast::Label>,
    body: Option<ast::StmtList>,
    cb: &mut dyn FnMut(ast::BreakExpr),
) {
    let Some(body) = body else { return };
    let label = label.and_then(|lbl| lbl.lifetime());
    for (expr, depth) in TreeWithDepthIterator::new(body) {
        let ast::Expr::BreakExpr(brk) = expr else { continue };
        let lifetime = brk.lifetime();
        if (depth == 0 && lifetime.is_none()) || eq_label_lt(&label, &lifetime) {
            cb(brk);
        }
    }
}

/// Returns `true` only if both lifetimes are present and have identical text.
pub fn eq_label_lt(lt1: &Option<ast::Lifetime>, lt2: &Option<ast::Lifetime>) -> bool {
    match (lt1, lt2) {
        (Some(lhs), Some(rhs)) => lhs.text() == rhs.text(),
        _ => false,
    }
}

/// Preorder iterator over the expressions of a statement list that pairs every
/// yielded expression with its nesting depth (see the `Iterator` impl).
struct TreeWithDepthIterator {
    /// Underlying preorder traversal of the statement list's syntax tree.
    preorder: Preorder<RustLanguage>,
    /// Number of loops and labeled blocks entered but not yet left.
    depth: u32,
}

impl TreeWithDepthIterator {
    /// Creates an iterator over the syntax tree of `body`, starting at depth 0.
    fn new(body: ast::StmtList) -> Self {
        Self { preorder: body.syntax().preorder(), depth: 0 }
    }
}

impl Iterator for TreeWithDepthIterator {
    type Item = (ast::Expr, u32);

    /// Advances to the next entered expression and returns it with the current depth.
    ///
    /// `loop`, `while` and `for` expressions and labeled blocks are never yielded;
    /// they raise the depth while the walk is inside them and lower it again on leave.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Fetch the next event that concerns an expression node, or finish.
            let (expr, entering) = self.preorder.find_map(|event| match event {
                WalkEvent::Enter(node) => ast::Expr::cast(node).map(|expr| (expr, true)),
                WalkEvent::Leave(node) => ast::Expr::cast(node).map(|expr| (expr, false)),
            })?;

            let changes_depth = match &expr {
                ast::Expr::LoopExpr(_) | ast::Expr::WhileExpr(_) | ast::Expr::ForExpr(_) => true,
                ast::Expr::BlockExpr(block) => block.label().is_some(),
                _ => false,
            };

            match (changes_depth, entering) {
                (true, true) => self.depth += 1,
                (true, false) => self.depth -= 1,
                (false, true) => return Some((expr, self.depth)),
                (false, false) => {}
            }
        }
    }
}

/// Parses the input token tree as comma separated plain paths.
///
/// The first child of `input` is skipped (presumably the opening delimiter), and
/// tokens are collected until a keyword of `edition`, the closing parenthesis, or a
/// nested node is reached. The collected tokens are split at commas; each group is
/// joined into a string and parsed as an expression under `edition`. Groups that do
/// not parse to a path expression are dropped.
///
/// The current implementation always returns `Some`, possibly with an empty vector.
pub fn parse_tt_as_comma_sep_paths(
    input: ast::TokenTree,
    edition: Edition,
) -> Option<Vec<ast::Path>> {
    let r_paren = input.r_paren_token();
    let tokens =
        input.syntax().children_with_tokens().skip(1).map_while(|it| match it.into_token() {
            // seeing a keyword means the attribute is unclosed so stop parsing here
            Some(tok) if tok.kind().is_keyword(edition) => None,
            // don't include the right token tree parenthesis if it exists
            tok @ Some(_) if tok == r_paren => None,
            // only nodes that we can find are other TokenTrees, those are unexpected in this parse though
            None => None,
            Some(tok) => Some(tok),
        });
    let input_expressions = tokens.chunk_by(|tok| tok.kind() == T![,]);
    let paths = input_expressions
        .into_iter()
        // Drop the groups that consist of the separating commas themselves.
        .filter_map(|(is_sep, group)| (!is_sep).then_some(group))
        .filter_map(|mut tokens| {
            // Parse with the caller's edition so it agrees with the keyword check above.
            syntax::hacks::parse_expr_from_str(&tokens.join(""), edition).and_then(
                |expr| match expr {
                    ast::Expr::PathExpr(it) => it.path(),
                    _ => None,
                },
            )
        })
        .collect();
    Some(paths)
}

/// Returns the closest macro call among the ancestors of the node that holds the
/// `string` token, if there is one.
pub fn macro_call_for_string_token(string: &ast::String) -> Option<MacroCall> {
    string.syntax().parent_ancestors().find_map(ast::MacroCall::cast)
}