Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-def/src/expr_store/expander.rs')
| -rw-r--r-- | crates/hir-def/src/expr_store/expander.rs | 220 |
1 files changed, 220 insertions, 0 deletions
diff --git a/crates/hir-def/src/expr_store/expander.rs b/crates/hir-def/src/expr_store/expander.rs new file mode 100644 index 0000000000..7eec913dd6 --- /dev/null +++ b/crates/hir-def/src/expr_store/expander.rs @@ -0,0 +1,220 @@ +//! Macro expansion utilities. + +use std::mem; + +use base_db::Crate; +use drop_bomb::DropBomb; +use hir_expand::{ + ExpandError, ExpandErrorKind, ExpandResult, HirFileId, InFile, Lookup, MacroCallId, + attrs::RawAttrs, eager::EagerCallBackFn, mod_path::ModPath, span_map::SpanMap, +}; +use span::{AstIdMap, Edition, SyntaxContext}; +use syntax::ast::HasAttrs; +use syntax::{Parse, ast}; +use triomphe::Arc; +use tt::TextRange; + +use crate::attr::Attrs; +use crate::expr_store::HygieneId; +use crate::nameres::DefMap; +use crate::{AsMacroCall, MacroId, UnresolvedMacro, db::DefDatabase}; + +#[derive(Debug)] +pub(super) struct Expander { + span_map: SpanMap, + current_file_id: HirFileId, + ast_id_map: Arc<AstIdMap>, + /// `recursion_depth == usize::MAX` indicates that the recursion limit has been reached. + recursion_depth: u32, + recursion_limit: usize, +} + +impl Expander { + pub(super) fn new( + db: &dyn DefDatabase, + current_file_id: HirFileId, + def_map: &DefMap, + ) -> Expander { + let recursion_limit = def_map.recursion_limit() as usize; + let recursion_limit = if cfg!(test) { + // Without this, `body::tests::your_stack_belongs_to_me` stack-overflows in debug + std::cmp::min(32, recursion_limit) + } else { + recursion_limit + }; + Expander { + current_file_id, + recursion_depth: 0, + recursion_limit, + span_map: db.span_map(current_file_id), + ast_id_map: db.ast_id_map(current_file_id), + } + } + + pub(super) fn ctx_for_range(&self, range: TextRange) -> SyntaxContext { + self.span_map.span_for_range(range).ctx + } + + pub(super) fn hygiene_for_range(&self, db: &dyn DefDatabase, range: TextRange) -> HygieneId { + match self.span_map.as_ref() { + hir_expand::span_map::SpanMapRef::ExpansionSpanMap(span_map) => { + HygieneId::new(span_map.span_at(range.start()).ctx.opaque_and_semitransparent(db)) + } + hir_expand::span_map::SpanMapRef::RealSpanMap(_) => HygieneId::ROOT, + } + } + + pub(super) fn attrs( + &self, + db: &dyn DefDatabase, + krate: Crate, + has_attrs: &dyn HasAttrs, + ) -> Attrs { + Attrs::filter(db, krate, RawAttrs::new(db, has_attrs, self.span_map.as_ref())) + } + + pub(super) fn is_cfg_enabled( + &self, + db: &dyn DefDatabase, + krate: Crate, + has_attrs: &dyn HasAttrs, + ) -> bool { + self.attrs(db, krate, has_attrs).is_cfg_enabled(krate.cfg_options(db)) + } + + pub(super) fn call_syntax_ctx(&self) -> SyntaxContext { + // FIXME: + SyntaxContext::root(Edition::CURRENT_FIXME) + } + + pub(super) fn enter_expand<T: ast::AstNode>( + &mut self, + db: &dyn DefDatabase, + macro_call: ast::MacroCall, + krate: Crate, + resolver: impl Fn(&ModPath) -> Option<MacroId>, + eager_callback: EagerCallBackFn<'_>, + ) -> Result<ExpandResult<Option<(Mark, Option<Parse<T>>)>>, UnresolvedMacro> { + // FIXME: within_limit should support this, instead of us having to extract the error + let mut unresolved_macro_err = None; + + let result = self.within_limit(db, |this| { + let macro_call = this.in_file(¯o_call); + match macro_call.as_call_id_with_errors( + db, + krate, + |path| resolver(path).map(|it| db.macro_def(it)), + eager_callback, + ) { + Ok(call_id) => call_id, + Err(resolve_err) => { + unresolved_macro_err = Some(resolve_err); + ExpandResult { value: None, err: None } + } + } + }); + + if let Some(err) = unresolved_macro_err { Err(err) } else { Ok(result) } + } + + pub(super) fn enter_expand_id<T: ast::AstNode>( + &mut self, + db: &dyn DefDatabase, + call_id: MacroCallId, + ) -> ExpandResult<Option<(Mark, Option<Parse<T>>)>> { + self.within_limit(db, |_this| ExpandResult::ok(Some(call_id))) + } + + pub(super) fn exit(&mut self, Mark { file_id, span_map, ast_id_map, mut bomb }: Mark) { + self.span_map = span_map; + self.current_file_id = file_id; + self.ast_id_map = ast_id_map; + if self.recursion_depth == u32::MAX { + // Recursion limit has been reached somewhere in the macro expansion tree. Reset the + // depth only when we get out of the tree. + if !self.current_file_id.is_macro() { + self.recursion_depth = 0; + } + } else { + self.recursion_depth -= 1; + } + bomb.defuse(); + } + + pub(super) fn in_file<T>(&self, value: T) -> InFile<T> { + InFile { file_id: self.current_file_id, value } + } + + pub(super) fn current_file_id(&self) -> HirFileId { + self.current_file_id + } + + fn within_limit<F, T: ast::AstNode>( + &mut self, + db: &dyn DefDatabase, + op: F, + ) -> ExpandResult<Option<(Mark, Option<Parse<T>>)>> + where + F: FnOnce(&mut Self) -> ExpandResult<Option<MacroCallId>>, + { + if self.recursion_depth == u32::MAX { + // Recursion limit has been reached somewhere in the macro expansion tree. We should + // stop expanding other macro calls in this tree, or else this may result in + // exponential number of macro expansions, leading to a hang. + // + // The overflow error should have been reported when it occurred (see the next branch), + // so don't return overflow error here to avoid diagnostics duplication. + cov_mark::hit!(overflow_but_not_me); + return ExpandResult::ok(None); + } + + let ExpandResult { value, err } = op(self); + let Some(call_id) = value else { + return ExpandResult { value: None, err }; + }; + if self.recursion_depth as usize > self.recursion_limit { + self.recursion_depth = u32::MAX; + cov_mark::hit!(your_stack_belongs_to_me); + return ExpandResult::only_err(ExpandError::new( + db.macro_arg_considering_derives(call_id, &call_id.lookup(db).kind).2, + ExpandErrorKind::RecursionOverflow, + )); + } + + let res = db.parse_macro_expansion(call_id); + + let err = err.or(res.err); + ExpandResult { + value: { + let parse = res.value.0.cast::<T>(); + + self.recursion_depth += 1; + let old_file_id = std::mem::replace(&mut self.current_file_id, call_id.into()); + let old_span_map = + std::mem::replace(&mut self.span_map, db.span_map(self.current_file_id)); + let prev_ast_id_map = + mem::replace(&mut self.ast_id_map, db.ast_id_map(self.current_file_id)); + let mark = Mark { + file_id: old_file_id, + span_map: old_span_map, + ast_id_map: prev_ast_id_map, + bomb: DropBomb::new("expansion mark dropped"), + }; + Some((mark, parse)) + }, + err, + } + } + + pub(super) fn ast_id_map(&self) -> &AstIdMap { + &self.ast_id_map + } +} + +#[derive(Debug)] +pub(super) struct Mark { + file_id: HirFileId, + span_map: SpanMap, + ast_id_map: Arc<AstIdMap>, + bomb: DropBomb, +} |