Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-def/src/expr_store/expander.rs')
-rw-r--r--crates/hir-def/src/expr_store/expander.rs220
1 files changed, 220 insertions, 0 deletions
diff --git a/crates/hir-def/src/expr_store/expander.rs b/crates/hir-def/src/expr_store/expander.rs
new file mode 100644
index 0000000000..7eec913dd6
--- /dev/null
+++ b/crates/hir-def/src/expr_store/expander.rs
@@ -0,0 +1,220 @@
+//! Macro expansion utilities.
+
+use std::mem;
+
+use base_db::Crate;
+use drop_bomb::DropBomb;
+use hir_expand::{
+ ExpandError, ExpandErrorKind, ExpandResult, HirFileId, InFile, Lookup, MacroCallId,
+ attrs::RawAttrs, eager::EagerCallBackFn, mod_path::ModPath, span_map::SpanMap,
+};
+use span::{AstIdMap, Edition, SyntaxContext};
+use syntax::ast::HasAttrs;
+use syntax::{Parse, ast};
+use triomphe::Arc;
+use tt::TextRange;
+
+use crate::attr::Attrs;
+use crate::expr_store::HygieneId;
+use crate::nameres::DefMap;
+use crate::{AsMacroCall, MacroId, UnresolvedMacro, db::DefDatabase};
+
+#[derive(Debug)]
+pub(super) struct Expander {
+ span_map: SpanMap,
+ current_file_id: HirFileId,
+ ast_id_map: Arc<AstIdMap>,
+ /// `recursion_depth == usize::MAX` indicates that the recursion limit has been reached.
+ recursion_depth: u32,
+ recursion_limit: usize,
+}
+
+impl Expander {
+ pub(super) fn new(
+ db: &dyn DefDatabase,
+ current_file_id: HirFileId,
+ def_map: &DefMap,
+ ) -> Expander {
+ let recursion_limit = def_map.recursion_limit() as usize;
+ let recursion_limit = if cfg!(test) {
+ // Without this, `body::tests::your_stack_belongs_to_me` stack-overflows in debug
+ std::cmp::min(32, recursion_limit)
+ } else {
+ recursion_limit
+ };
+ Expander {
+ current_file_id,
+ recursion_depth: 0,
+ recursion_limit,
+ span_map: db.span_map(current_file_id),
+ ast_id_map: db.ast_id_map(current_file_id),
+ }
+ }
+
+ pub(super) fn ctx_for_range(&self, range: TextRange) -> SyntaxContext {
+ self.span_map.span_for_range(range).ctx
+ }
+
+ pub(super) fn hygiene_for_range(&self, db: &dyn DefDatabase, range: TextRange) -> HygieneId {
+ match self.span_map.as_ref() {
+ hir_expand::span_map::SpanMapRef::ExpansionSpanMap(span_map) => {
+ HygieneId::new(span_map.span_at(range.start()).ctx.opaque_and_semitransparent(db))
+ }
+ hir_expand::span_map::SpanMapRef::RealSpanMap(_) => HygieneId::ROOT,
+ }
+ }
+
+ pub(super) fn attrs(
+ &self,
+ db: &dyn DefDatabase,
+ krate: Crate,
+ has_attrs: &dyn HasAttrs,
+ ) -> Attrs {
+ Attrs::filter(db, krate, RawAttrs::new(db, has_attrs, self.span_map.as_ref()))
+ }
+
+ pub(super) fn is_cfg_enabled(
+ &self,
+ db: &dyn DefDatabase,
+ krate: Crate,
+ has_attrs: &dyn HasAttrs,
+ ) -> bool {
+ self.attrs(db, krate, has_attrs).is_cfg_enabled(krate.cfg_options(db))
+ }
+
+ pub(super) fn call_syntax_ctx(&self) -> SyntaxContext {
+ // FIXME:
+ SyntaxContext::root(Edition::CURRENT_FIXME)
+ }
+
+ pub(super) fn enter_expand<T: ast::AstNode>(
+ &mut self,
+ db: &dyn DefDatabase,
+ macro_call: ast::MacroCall,
+ krate: Crate,
+ resolver: impl Fn(&ModPath) -> Option<MacroId>,
+ eager_callback: EagerCallBackFn<'_>,
+ ) -> Result<ExpandResult<Option<(Mark, Option<Parse<T>>)>>, UnresolvedMacro> {
+ // FIXME: within_limit should support this, instead of us having to extract the error
+ let mut unresolved_macro_err = None;
+
+ let result = self.within_limit(db, |this| {
+ let macro_call = this.in_file(&macro_call);
+ match macro_call.as_call_id_with_errors(
+ db,
+ krate,
+ |path| resolver(path).map(|it| db.macro_def(it)),
+ eager_callback,
+ ) {
+ Ok(call_id) => call_id,
+ Err(resolve_err) => {
+ unresolved_macro_err = Some(resolve_err);
+ ExpandResult { value: None, err: None }
+ }
+ }
+ });
+
+ if let Some(err) = unresolved_macro_err { Err(err) } else { Ok(result) }
+ }
+
+ pub(super) fn enter_expand_id<T: ast::AstNode>(
+ &mut self,
+ db: &dyn DefDatabase,
+ call_id: MacroCallId,
+ ) -> ExpandResult<Option<(Mark, Option<Parse<T>>)>> {
+ self.within_limit(db, |_this| ExpandResult::ok(Some(call_id)))
+ }
+
+ pub(super) fn exit(&mut self, Mark { file_id, span_map, ast_id_map, mut bomb }: Mark) {
+ self.span_map = span_map;
+ self.current_file_id = file_id;
+ self.ast_id_map = ast_id_map;
+ if self.recursion_depth == u32::MAX {
+ // Recursion limit has been reached somewhere in the macro expansion tree. Reset the
+ // depth only when we get out of the tree.
+ if !self.current_file_id.is_macro() {
+ self.recursion_depth = 0;
+ }
+ } else {
+ self.recursion_depth -= 1;
+ }
+ bomb.defuse();
+ }
+
+ pub(super) fn in_file<T>(&self, value: T) -> InFile<T> {
+ InFile { file_id: self.current_file_id, value }
+ }
+
+ pub(super) fn current_file_id(&self) -> HirFileId {
+ self.current_file_id
+ }
+
+ fn within_limit<F, T: ast::AstNode>(
+ &mut self,
+ db: &dyn DefDatabase,
+ op: F,
+ ) -> ExpandResult<Option<(Mark, Option<Parse<T>>)>>
+ where
+ F: FnOnce(&mut Self) -> ExpandResult<Option<MacroCallId>>,
+ {
+ if self.recursion_depth == u32::MAX {
+ // Recursion limit has been reached somewhere in the macro expansion tree. We should
+ // stop expanding other macro calls in this tree, or else this may result in
+ // exponential number of macro expansions, leading to a hang.
+ //
+ // The overflow error should have been reported when it occurred (see the next branch),
+ // so don't return overflow error here to avoid diagnostics duplication.
+ cov_mark::hit!(overflow_but_not_me);
+ return ExpandResult::ok(None);
+ }
+
+ let ExpandResult { value, err } = op(self);
+ let Some(call_id) = value else {
+ return ExpandResult { value: None, err };
+ };
+ if self.recursion_depth as usize > self.recursion_limit {
+ self.recursion_depth = u32::MAX;
+ cov_mark::hit!(your_stack_belongs_to_me);
+ return ExpandResult::only_err(ExpandError::new(
+ db.macro_arg_considering_derives(call_id, &call_id.lookup(db).kind).2,
+ ExpandErrorKind::RecursionOverflow,
+ ));
+ }
+
+ let res = db.parse_macro_expansion(call_id);
+
+ let err = err.or(res.err);
+ ExpandResult {
+ value: {
+ let parse = res.value.0.cast::<T>();
+
+ self.recursion_depth += 1;
+ let old_file_id = std::mem::replace(&mut self.current_file_id, call_id.into());
+ let old_span_map =
+ std::mem::replace(&mut self.span_map, db.span_map(self.current_file_id));
+ let prev_ast_id_map =
+ mem::replace(&mut self.ast_id_map, db.ast_id_map(self.current_file_id));
+ let mark = Mark {
+ file_id: old_file_id,
+ span_map: old_span_map,
+ ast_id_map: prev_ast_id_map,
+ bomb: DropBomb::new("expansion mark dropped"),
+ };
+ Some((mark, parse))
+ },
+ err,
+ }
+ }
+
+ pub(super) fn ast_id_map(&self) -> &AstIdMap {
+ &self.ast_id_map
+ }
+}
+
+#[derive(Debug)]
+pub(super) struct Mark {
+ file_id: HirFileId,
+ span_map: SpanMap,
+ ast_id_map: Arc<AstIdMap>,
+ bomb: DropBomb,
+}