Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-def/src/expr_store/lower/format_args.rs')
| -rw-r--r-- | crates/hir-def/src/expr_store/lower/format_args.rs | 461 |
1 files changed, 387 insertions, 74 deletions
diff --git a/crates/hir-def/src/expr_store/lower/format_args.rs b/crates/hir-def/src/expr_store/lower/format_args.rs index aa7bf0e4e2..7efc9a956c 100644 --- a/crates/hir-def/src/expr_store/lower/format_args.rs +++ b/crates/hir-def/src/expr_store/lower/format_args.rs @@ -19,6 +19,7 @@ use crate::{ FormatPlaceholder, FormatSign, FormatTrait, }, }, + lang_item::LangItemTarget, type_ref::{Mutability, Rawness}, }; @@ -94,6 +95,347 @@ impl<'db> ExprCollector<'db> { ), }; + let idx = if self.lang_items().FormatCount.is_none() { + self.collect_format_args_after_1_93_0_impl(syntax_ptr, fmt) + } else { + self.collect_format_args_before_1_93_0_impl(syntax_ptr, fmt) + }; + + self.store + .template_map + .get_or_insert_with(Default::default) + .format_args_to_captures + .insert(idx, (hygiene, mappings)); + idx + } + + fn collect_format_args_after_1_93_0_impl( + &mut self, + syntax_ptr: AstPtr<ast::Expr>, + fmt: FormatArgs, + ) -> ExprId { + let lang_items = self.lang_items(); + + // Create a list of all _unique_ (argument, format trait) combinations. + // E.g. "{0} {0:x} {0} {1}" -> [(0, Display), (0, LowerHex), (1, Display)] + // + // We use usize::MAX for arguments that don't exist, because that can never be a valid index + // into the arguments array. + let mut argmap = FxIndexSet::default(); + + let mut incomplete_lit = String::new(); + + let mut implicit_arg_index = 0; + + let mut bytecode = Vec::new(); + + let template = if fmt.template.is_empty() { + // Treat empty templates as a single literal piece (with an empty string), + // so we produce `from_str("")` for those. + &[FormatArgsPiece::Literal(sym::__empty)][..] + } else { + &fmt.template[..] + }; + + // See library/core/src/fmt/mod.rs for the format string encoding format. + + for (i, piece) in template.iter().enumerate() { + match piece { + FormatArgsPiece::Literal(sym) => { + // Coalesce adjacent literal pieces. + if let Some(FormatArgsPiece::Literal(_)) = template.get(i + 1) { + incomplete_lit.push_str(sym.as_str()); + continue; + } + let mut s = if incomplete_lit.is_empty() { + sym.as_str() + } else { + incomplete_lit.push_str(sym.as_str()); + &incomplete_lit + }; + + // If this is the last piece and was the only piece, that means + // there are no placeholders and the entire format string is just a literal. + // + // In that case, we can just use `from_str`. + if i + 1 == template.len() && bytecode.is_empty() { + // Generate: + // <core::fmt::Arguments>::from_str("meow") + let from_str = self.ty_rel_lang_path_desugared_expr( + lang_items.FormatArguments, + sym::from_str, + ); + let sym = + if incomplete_lit.is_empty() { sym.clone() } else { Symbol::intern(s) }; + let s = self.alloc_expr_desugared(Expr::Literal(Literal::String(sym))); + let from_str = self.alloc_expr( + Expr::Call { callee: from_str, args: Box::new([s]) }, + syntax_ptr, + ); + return if !fmt.arguments.arguments.is_empty() { + // With an incomplete format string (e.g. only an opening `{`), it's possible for `arguments` + // to be non-empty when reaching this code path. + self.alloc_expr( + Expr::Block { + id: None, + statements: fmt + .arguments + .arguments + .iter() + .map(|arg| Statement::Expr { + expr: arg.expr, + has_semi: true, + }) + .collect(), + tail: Some(from_str), + label: None, + }, + syntax_ptr, + ) + } else { + from_str + }; + } + + // Encode the literal in chunks of up to u16::MAX bytes, split at utf-8 boundaries. + while !s.is_empty() { + let len = s.floor_char_boundary(usize::from(u16::MAX)); + if len < 0x80 { + bytecode.push(len as u8); + } else { + bytecode.push(0x80); + bytecode.extend_from_slice(&(len as u16).to_le_bytes()); + } + bytecode.extend(&s.as_bytes()[..len]); + s = &s[len..]; + } + + incomplete_lit.clear(); + } + FormatArgsPiece::Placeholder(p) => { + // Push the start byte and remember its index so we can set the option bits later. + let i = bytecode.len(); + bytecode.push(0xC0); + + let position = match &p.argument.index { + &Ok(it) => it, + Err(_) => usize::MAX, + }; + let position = argmap + .insert_full((position, ArgumentType::Format(p.format_trait))) + .0 as u64; + + // This needs to match the constants in library/core/src/fmt/mod.rs. + let o = &p.format_options; + let align = match o.alignment { + Some(FormatAlignment::Left) => 0, + Some(FormatAlignment::Right) => 1, + Some(FormatAlignment::Center) => 2, + None => 3, + }; + let default_flags = 0x6000_0020; + let flags: u32 = o.fill.unwrap_or(' ') as u32 + | ((o.sign == Some(FormatSign::Plus)) as u32) << 21 + | ((o.sign == Some(FormatSign::Minus)) as u32) << 22 + | (o.alternate as u32) << 23 + | (o.zero_pad as u32) << 24 + | ((o.debug_hex == Some(FormatDebugHex::Lower)) as u32) << 25 + | ((o.debug_hex == Some(FormatDebugHex::Upper)) as u32) << 26 + | (o.width.is_some() as u32) << 27 + | (o.precision.is_some() as u32) << 28 + | align << 29; + if flags != default_flags { + bytecode[i] |= 1; + bytecode.extend_from_slice(&flags.to_le_bytes()); + if let Some(val) = &o.width { + let (indirect, val) = self.make_count_after_1_93_0(val, &mut argmap); + // Only encode if nonzero; zero is the default. + if indirect || val != 0 { + bytecode[i] |= 1 << 1 | (indirect as u8) << 4; + bytecode.extend_from_slice(&val.to_le_bytes()); + } + } + if let Some(val) = &o.precision { + let (indirect, val) = self.make_count_after_1_93_0(val, &mut argmap); + // Only encode if nonzero; zero is the default. + if indirect || val != 0 { + bytecode[i] |= 1 << 2 | (indirect as u8) << 5; + bytecode.extend_from_slice(&val.to_le_bytes()); + } + } + } + if implicit_arg_index != position { + bytecode[i] |= 1 << 3; + bytecode.extend_from_slice(&(position as u16).to_le_bytes()); + } + implicit_arg_index = position + 1; + } + } + } + + assert!(incomplete_lit.is_empty()); + + // Zero terminator. + bytecode.push(0); + + // Ensure all argument indexes actually fit in 16 bits, as we truncated them to 16 bits before. + if argmap.len() > u16::MAX as usize { + // FIXME: Emit an error. + // ctx.dcx().span_err(macsp, "too many format arguments"); + } + + let arguments = &fmt.arguments.arguments[..]; + + let (mut statements, args) = if arguments.is_empty() { + // Generate: + // [] + ( + Vec::new(), + self.alloc_expr_desugared(Expr::Array(Array::ElementList { + elements: Box::new([]), + })), + ) + } else { + // Generate: + // super let args = (&arg0, &arg1, &…); + let args_name = self.generate_new_name(); + let args_path = Path::from(args_name.clone()); + let args_binding = self.alloc_binding( + args_name.clone(), + BindingAnnotation::Unannotated, + HygieneId::ROOT, + ); + let args_pat = self.alloc_pat_desugared(Pat::Bind { id: args_binding, subpat: None }); + self.add_definition_to_binding(args_binding, args_pat); + let elements = arguments + .iter() + .map(|arg| { + self.alloc_expr_desugared(Expr::Ref { + expr: arg.expr, + rawness: Rawness::Ref, + mutability: Mutability::Shared, + }) + }) + .collect(); + let args_tuple = self.alloc_expr_desugared(Expr::Tuple { exprs: elements }); + // FIXME: Make this a `super let` when we have this statement. + let let_statement_1 = Statement::Let { + pat: args_pat, + type_ref: None, + initializer: Some(args_tuple), + else_branch: None, + }; + + // Generate: + // super let args = [ + // <core::fmt::Argument>::new_display(args.0), + // <core::fmt::Argument>::new_lower_hex(args.1), + // <core::fmt::Argument>::new_debug(args.0), + // … + // ]; + let args = argmap + .iter() + .map(|&(arg_index, ty)| { + let args_ident_expr = self.alloc_expr_desugared(Expr::Path(args_path.clone())); + let arg = self.alloc_expr_desugared(Expr::Field { + expr: args_ident_expr, + name: Name::new_tuple_field(arg_index), + }); + self.make_argument(arg, ty) + }) + .collect(); + let args = + self.alloc_expr_desugared(Expr::Array(Array::ElementList { elements: args })); + let args_binding = + self.alloc_binding(args_name, BindingAnnotation::Unannotated, HygieneId::ROOT); + let args_pat = self.alloc_pat_desugared(Pat::Bind { id: args_binding, subpat: None }); + self.add_definition_to_binding(args_binding, args_pat); + // FIXME: Make this a `super let` when we have this statement. + let let_statement_2 = Statement::Let { + pat: args_pat, + type_ref: None, + initializer: Some(args), + else_branch: None, + }; + ( + vec![let_statement_1, let_statement_2], + self.alloc_expr_desugared(Expr::Path(args_path)), + ) + }; + + // Generate: + // unsafe { + // <core::fmt::Arguments>::new(b"…", &args) + // } + let template = self + .alloc_expr_desugared(Expr::Literal(Literal::ByteString(bytecode.into_boxed_slice()))); + let call = { + let new = self.ty_rel_lang_path_desugared_expr(lang_items.FormatArguments, sym::new); + let args = self.alloc_expr_desugared(Expr::Ref { + expr: args, + rawness: Rawness::Ref, + mutability: Mutability::Shared, + }); + self.alloc_expr_desugared(Expr::Call { callee: new, args: Box::new([template, args]) }) + }; + let call = self.alloc_expr( + Expr::Unsafe { id: None, statements: Box::new([]), tail: Some(call) }, + syntax_ptr, + ); + + // We collect the unused expressions here so that we still infer them instead of + // dropping them out of the expression tree. We cannot store them in the `Unsafe` + // block because then unsafe blocks within them will get a false "unused unsafe" + // diagnostic (rustc has a notion of builtin unsafe blocks, but we don't). + statements + .extend(fmt.orphans.into_iter().map(|expr| Statement::Expr { expr, has_semi: true })); + + if !statements.is_empty() { + // Generate: + // { + // super let … + // super let … + // <core::fmt::Arguments>::new(…) + // } + self.alloc_expr( + Expr::Block { + id: None, + statements: statements.into_boxed_slice(), + tail: Some(call), + label: None, + }, + syntax_ptr, + ) + } else { + call + } + } + + /// Get the value for a `width` or `precision` field. + /// + /// Returns the value and whether it is indirect (an indexed argument) or not. + fn make_count_after_1_93_0( + &self, + count: &FormatCount, + argmap: &mut FxIndexSet<(usize, ArgumentType)>, + ) -> (bool, u16) { + match count { + FormatCount::Literal(n) => (false, *n), + FormatCount::Argument(arg) => { + let index = match &arg.index { + &Ok(it) => it, + Err(_) => usize::MAX, + }; + (true, argmap.insert_full((index, ArgumentType::Usize)).0 as u16) + } + } + } + + fn collect_format_args_before_1_93_0_impl( + &mut self, + syntax_ptr: AstPtr<ast::Expr>, + fmt: FormatArgs, + ) -> ExprId { // Create a list of all _unique_ (argument, format trait) combinations. // E.g. "{0} {0:x} {0} {1}" -> [(0, Display), (0, LowerHex), (1, Display)] let mut argmap = FxIndexSet::default(); @@ -160,8 +502,14 @@ impl<'db> ExprCollector<'db> { let fmt_unsafe_arg = lang_items.FormatUnsafeArg; let use_format_args_since_1_89_0 = fmt_args.is_some() && fmt_unsafe_arg.is_none(); - let idx = if use_format_args_since_1_89_0 { - self.collect_format_args_impl(syntax_ptr, fmt, argmap, lit_pieces, format_options) + if use_format_args_since_1_89_0 { + self.collect_format_args_after_1_89_0_impl( + syntax_ptr, + fmt, + argmap, + lit_pieces, + format_options, + ) } else { self.collect_format_args_before_1_89_0_impl( syntax_ptr, @@ -170,14 +518,7 @@ impl<'db> ExprCollector<'db> { lit_pieces, format_options, ) - }; - - self.store - .template_map - .get_or_insert_with(Default::default) - .format_args_to_captures - .insert(idx, (hygiene, mappings)); - idx + } } /// `format_args!` expansion implementation for rustc versions < `1.89.0` @@ -238,17 +579,10 @@ impl<'db> ExprCollector<'db> { // ) let lang_items = self.lang_items(); - let new_v1_formatted = self.ty_rel_lang_path( - lang_items.FormatArguments, - Name::new_symbol_root(sym::new_v1_formatted), - ); - let unsafe_arg_new = - self.ty_rel_lang_path(lang_items.FormatUnsafeArg, Name::new_symbol_root(sym::new)); let new_v1_formatted = - self.alloc_expr_desugared(new_v1_formatted.map_or(Expr::Missing, Expr::Path)); - + self.ty_rel_lang_path_desugared_expr(lang_items.FormatArguments, sym::new_v1_formatted); let unsafe_arg_new = - self.alloc_expr_desugared(unsafe_arg_new.map_or(Expr::Missing, Expr::Path)); + self.ty_rel_lang_path_desugared_expr(lang_items.FormatUnsafeArg, sym::new); let unsafe_arg_new = self.alloc_expr_desugared(Expr::Call { callee: unsafe_arg_new, args: Box::default() }); let mut unsafe_arg_new = self.alloc_expr_desugared(Expr::Unsafe { @@ -284,7 +618,7 @@ impl<'db> ExprCollector<'db> { /// `format_args!` expansion implementation for rustc versions >= `1.89.0`, /// especially since [this PR](https://github.com/rust-lang/rust/pull/140748) - fn collect_format_args_impl( + fn collect_format_args_after_1_89_0_impl( &mut self, syntax_ptr: AstPtr<ast::Expr>, fmt: FormatArgs, @@ -422,12 +756,10 @@ impl<'db> ExprCollector<'db> { // ) // } - let new_v1_formatted = self.ty_rel_lang_path( + let new_v1_formatted = self.ty_rel_lang_path_desugared_expr( self.lang_items().FormatArguments, - Name::new_symbol_root(sym::new_v1_formatted), + sym::new_v1_formatted, ); - let new_v1_formatted = - self.alloc_expr_desugared(new_v1_formatted.map_or(Expr::Missing, Expr::Path)); let args = [lit_pieces, args, format_options]; let call = self .alloc_expr_desugared(Expr::Call { callee: new_v1_formatted, args: args.into() }); @@ -499,8 +831,8 @@ impl<'db> ExprCollector<'db> { debug_hex, } = &placeholder.format_options; - let precision_expr = self.make_count(precision, argmap); - let width_expr = self.make_count(width, argmap); + let precision_expr = self.make_count_before_1_93_0(precision, argmap); + let width_expr = self.make_count_before_1_93_0(width, argmap); if self.krate.workspace_data(self.db).is_atleast_187() { // These need to match the constants in library/core/src/fmt/rt.rs. @@ -542,16 +874,8 @@ impl<'db> ExprCollector<'db> { spread: None, }) } else { - let format_placeholder_new = { - let format_placeholder_new = self.ty_rel_lang_path( - lang_items.FormatPlaceholder, - Name::new_symbol_root(sym::new), - ); - match format_placeholder_new { - Some(path) => self.alloc_expr_desugared(Expr::Path(path)), - None => self.missing_expr(), - } - }; + let format_placeholder_new = + self.ty_rel_lang_path_desugared_expr(lang_items.FormatPlaceholder, sym::new); // This needs to match `Flag` in library/core/src/fmt/rt.rs. let flags: u32 = ((sign == Some(FormatSign::Plus)) as u32) | (((sign == Some(FormatSign::Minus)) as u32) << 1) @@ -564,21 +888,15 @@ impl<'db> ExprCollector<'db> { Some(BuiltinUint::U32), ))); let fill = self.alloc_expr_desugared(Expr::Literal(Literal::Char(fill.unwrap_or(' ')))); - let align = { - let align = self.ty_rel_lang_path( - lang_items.FormatAlignment, - match alignment { - Some(FormatAlignment::Left) => Name::new_symbol_root(sym::Left), - Some(FormatAlignment::Right) => Name::new_symbol_root(sym::Right), - Some(FormatAlignment::Center) => Name::new_symbol_root(sym::Center), - None => Name::new_symbol_root(sym::Unknown), - }, - ); - match align { - Some(path) => self.alloc_expr_desugared(Expr::Path(path)), - None => self.missing_expr(), - } - }; + let align = self.ty_rel_lang_path_desugared_expr( + lang_items.FormatAlignment, + match alignment { + Some(FormatAlignment::Left) => sym::Left, + Some(FormatAlignment::Right) => sym::Right, + Some(FormatAlignment::Center) => sym::Center, + None => sym::Unknown, + }, + ); self.alloc_expr_desugared(Expr::Call { callee: format_placeholder_new, args: Box::new([position, fill, align, flags, precision_expr, width_expr]), @@ -605,7 +923,7 @@ impl<'db> ExprCollector<'db> { /// ```text /// <core::fmt::rt::Count>::Implied /// ``` - fn make_count( + fn make_count_before_1_93_0( &mut self, count: &Option<FormatCount>, argmap: &mut FxIndexSet<(usize, ArgumentType)>, @@ -618,12 +936,8 @@ impl<'db> ExprCollector<'db> { // FIXME: Change this to Some(BuiltinUint::U16) once we drop support for toolchains < 1.88 None, ))); - let count_is = match self - .ty_rel_lang_path(lang_items.FormatCount, Name::new_symbol_root(sym::Is)) - { - Some(count_is) => self.alloc_expr_desugared(Expr::Path(count_is)), - None => self.missing_expr(), - }; + let count_is = + self.ty_rel_lang_path_desugared_expr(lang_items.FormatCount, sym::Is); self.alloc_expr_desugared(Expr::Call { callee: count_is, args: Box::new([args]) }) } Some(FormatCount::Argument(arg)) => { @@ -634,12 +948,8 @@ impl<'db> ExprCollector<'db> { i as u128, Some(BuiltinUint::Usize), ))); - let count_param = match self - .ty_rel_lang_path(lang_items.FormatCount, Name::new_symbol_root(sym::Param)) - { - Some(count_param) => self.alloc_expr_desugared(Expr::Path(count_param)), - None => self.missing_expr(), - }; + let count_param = + self.ty_rel_lang_path_desugared_expr(lang_items.FormatCount, sym::Param); self.alloc_expr_desugared(Expr::Call { callee: count_param, args: Box::new([args]), @@ -650,9 +960,7 @@ impl<'db> ExprCollector<'db> { self.missing_expr() } } - None => match self - .ty_rel_lang_path(lang_items.FormatCount, Name::new_symbol_root(sym::Implied)) - { + None => match self.ty_rel_lang_path(lang_items.FormatCount, sym::Implied) { Some(count_param) => self.alloc_expr_desugared(Expr::Path(count_param)), None => self.missing_expr(), }, @@ -670,9 +978,9 @@ impl<'db> ExprCollector<'db> { use ArgumentType::*; use FormatTrait::*; - let new_fn = match self.ty_rel_lang_path( + let new_fn = self.ty_rel_lang_path_desugared_expr( self.lang_items().FormatArgument, - Name::new_symbol_root(match ty { + match ty { Format(Display) => sym::new_display, Format(Debug) => sym::new_debug, Format(LowerExp) => sym::new_lower_exp, @@ -683,13 +991,18 @@ impl<'db> ExprCollector<'db> { Format(LowerHex) => sym::new_lower_hex, Format(UpperHex) => sym::new_upper_hex, Usize => sym::from_usize, - }), - ) { - Some(new_fn) => self.alloc_expr_desugared(Expr::Path(new_fn)), - None => self.missing_expr(), - }; + }, + ); self.alloc_expr_desugared(Expr::Call { callee: new_fn, args: Box::new([arg]) }) } + + fn ty_rel_lang_path_desugared_expr( + &mut self, + lang: Option<impl Into<LangItemTarget>>, + relative_name: Symbol, + ) -> ExprId { + self.alloc_expr_desugared(self.ty_rel_lang_path_expr(lang, relative_name)) + } } #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] |