//! Macro expansion utilities. use std::mem; use base_db::Crate; use cfg::CfgOptions; use drop_bomb::DropBomb; use hir_expand::AstId; use hir_expand::span_map::SpanMapRef; use hir_expand::{ ExpandError, ExpandErrorKind, ExpandResult, HirFileId, InFile, Lookup, MacroCallId, eager::EagerCallBackFn, mod_path::ModPath, span_map::SpanMap, }; use span::{AstIdMap, SyntaxContext}; use syntax::ast::HasAttrs; use syntax::{AstNode, Parse, ast}; use triomphe::Arc; use tt::TextRange; use crate::{ MacroId, UnresolvedMacro, attrs::AttrFlags, db::DefDatabase, expr_store::HygieneId, macro_call_as_call_id, nameres::DefMap, }; #[derive(Debug)] pub(super) struct Expander { span_map: SpanMap, current_file_id: HirFileId, ast_id_map: Arc, /// `recursion_depth == usize::MAX` indicates that the recursion limit has been reached. recursion_depth: u32, recursion_limit: usize, } impl Expander { pub(super) fn new( db: &dyn DefDatabase, current_file_id: HirFileId, def_map: &DefMap, ) -> Expander { let recursion_limit = def_map.recursion_limit() as usize; let recursion_limit = if cfg!(test) { // Without this, `body::tests::your_stack_belongs_to_me` stack-overflows in debug std::cmp::min(32, recursion_limit) } else { recursion_limit }; Expander { current_file_id, recursion_depth: 0, recursion_limit, span_map: db.span_map(current_file_id), ast_id_map: db.ast_id_map(current_file_id), } } pub(super) fn ctx_for_range(&self, range: TextRange) -> SyntaxContext { self.span_map.span_for_range(range).ctx } pub(super) fn hygiene_for_range(&self, db: &dyn DefDatabase, range: TextRange) -> HygieneId { match self.span_map.as_ref() { hir_expand::span_map::SpanMapRef::ExpansionSpanMap(span_map) => { HygieneId::new(span_map.span_at(range.start()).ctx.opaque_and_semiopaque(db)) } hir_expand::span_map::SpanMapRef::RealSpanMap(_) => HygieneId::ROOT, } } pub(super) fn is_cfg_enabled( &self, owner: &dyn HasAttrs, cfg_options: &CfgOptions, ) -> Result<(), cfg::CfgExpr> { AttrFlags::is_cfg_enabled_for(owner, cfg_options) } pub(super) fn enter_expand( &mut self, db: &dyn DefDatabase, macro_call: ast::MacroCall, krate: Crate, resolver: impl Fn(&ModPath) -> Option, eager_callback: EagerCallBackFn<'_>, ) -> Result>)>>, UnresolvedMacro> { // FIXME: within_limit should support this, instead of us having to extract the error let mut unresolved_macro_err = None; let result = self.within_limit(db, |this| { let macro_call = this.in_file(¯o_call); let expands_to = hir_expand::ExpandTo::from_call_site(macro_call.value); let ast_id = AstId::new(macro_call.file_id, this.ast_id_map().ast_id(macro_call.value)); let path = macro_call.value.path().and_then(|path| { let range = path.syntax().text_range(); let mod_path = ModPath::from_src(db, path, &mut |range| { this.span_map.span_for_range(range).ctx })?; let call_site = this.span_map.span_for_range(range); Some((call_site, mod_path)) }); let Some((call_site, path)) = path else { return ExpandResult::only_err(ExpandError::other( this.span_map.span_for_range(macro_call.value.syntax().text_range()), "malformed macro invocation", )); }; match macro_call_as_call_id( db, ast_id, &path, call_site.ctx, expands_to, krate, |path| resolver(path).map(|it| db.macro_def(it)), eager_callback, ) { Ok(call_id) => call_id, Err(resolve_err) => { unresolved_macro_err = Some(resolve_err); ExpandResult { value: None, err: None } } } }); if let Some(err) = unresolved_macro_err { Err(err) } else { Ok(result) } } pub(super) fn enter_expand_id( &mut self, db: &dyn DefDatabase, call_id: MacroCallId, ) -> ExpandResult>)>> { self.within_limit(db, |_this| ExpandResult::ok(Some(call_id))) } pub(super) fn exit(&mut self, Mark { file_id, span_map, ast_id_map, mut bomb }: Mark) { self.span_map = span_map; self.current_file_id = file_id; self.ast_id_map = ast_id_map; if self.recursion_depth == u32::MAX { // Recursion limit has been reached somewhere in the macro expansion tree. Reset the // depth only when we get out of the tree. if !self.current_file_id.is_macro() { self.recursion_depth = 0; } } else { self.recursion_depth -= 1; } bomb.defuse(); } pub(super) fn in_file(&self, value: T) -> InFile { InFile { file_id: self.current_file_id, value } } pub(super) fn current_file_id(&self) -> HirFileId { self.current_file_id } fn within_limit( &mut self, db: &dyn DefDatabase, op: F, ) -> ExpandResult>)>> where F: FnOnce(&mut Self) -> ExpandResult>, { if self.recursion_depth == u32::MAX { // Recursion limit has been reached somewhere in the macro expansion tree. We should // stop expanding other macro calls in this tree, or else this may result in // exponential number of macro expansions, leading to a hang. // // The overflow error should have been reported when it occurred (see the next branch), // so don't return overflow error here to avoid diagnostics duplication. cov_mark::hit!(overflow_but_not_me); return ExpandResult::ok(None); } let ExpandResult { value, err } = op(self); let Some(call_id) = value else { return ExpandResult { value: None, err }; }; if self.recursion_depth as usize > self.recursion_limit { self.recursion_depth = u32::MAX; cov_mark::hit!(your_stack_belongs_to_me); return ExpandResult::only_err(ExpandError::new( db.macro_arg_considering_derives(call_id, &call_id.lookup(db).kind).2, ExpandErrorKind::RecursionOverflow, )); } let res = db.parse_macro_expansion(call_id); let err = err.or(res.err); ExpandResult { value: { let parse = res.value.0.cast::(); self.recursion_depth += 1; let old_file_id = std::mem::replace(&mut self.current_file_id, call_id.into()); let old_span_map = std::mem::replace(&mut self.span_map, db.span_map(self.current_file_id)); let prev_ast_id_map = mem::replace(&mut self.ast_id_map, db.ast_id_map(self.current_file_id)); let mark = Mark { file_id: old_file_id, span_map: old_span_map, ast_id_map: prev_ast_id_map, bomb: DropBomb::new("expansion mark dropped"), }; Some((mark, parse)) }, err, } } #[inline] pub(super) fn ast_id_map(&self) -> &AstIdMap { &self.ast_id_map } #[inline] pub(super) fn span_map(&self) -> SpanMapRef<'_> { self.span_map.as_ref() } } #[derive(Debug)] pub(super) struct Mark { file_id: HirFileId, span_map: SpanMap, ast_id_map: Arc, bomb: DropBomb, }