Unnamed repository; edit this file 'description' to name the repository.
Merge ref '0e95a0f4c677' from rust-lang/rust
Pull recent changes from https://github.com/rust-lang/rust via Josh.
Upstream ref: rust-lang/rust@0e95a0f4c677002a5d4ac5bc59d97885e6f51f71
Filtered ref: rust-lang/compiler-builtins@84dcb0e1aec6e783710ea190820a56590e23eef5
Upstream diff: https://github.com/rust-lang/rust/compare/db3e99bbab28c6ca778b13222becdea54533d908...0e95a0f4c677002a5d4ac5bc59d97885e6f51f71
This merge was created using https://github.com/rust-lang/josh-sync.
400 files changed, 14325 insertions, 6010 deletions
diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 0000000000..68eacb7d08 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,10 @@ +coverage: + range: 40...60 + status: + patch: off + project: + default: + informational: true + +# Don't leave comments on PRs +comment: false diff --git a/.github/workflows/autopublish.yaml b/.github/workflows/autopublish.yaml index 6e2be7fd3d..abb9b521f1 100644 --- a/.github/workflows/autopublish.yaml +++ b/.github/workflows/autopublish.yaml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 1a0deee564..c27d84fb0b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -30,8 +30,8 @@ jobs: outputs: typescript: ${{ steps.filter.outputs.typescript }} steps: - - uses: actions/checkout@v4 - - uses: dorny/paths-filter@1441771bbfdd59dcd748680ee64ebd8faab1a242 + - uses: actions/checkout@v6 + - uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1 id: filter with: filters: | @@ -45,12 +45,12 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: ref: ${{ github.event.pull_request.head.sha }} - name: Install rustup-toolchain-install-master - run: cargo install [email protected] + run: cargo install [email protected] # Install a pinned rustc commit to avoid surprises - name: Install Rust toolchain @@ -88,7 +88,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: ref: ${{ github.event.pull_request.head.sha }} @@ -96,7 +96,7 @@ jobs: run: | rustup update --no-self-update stable rustup default stable - rustup component add --toolchain stable rust-src clippy rustfmt + rustup component add --toolchain stable rust-src rustfmt # We also install a nightly rustfmt, because we use `--file-lines` in # a test. rustup toolchain install nightly --profile minimal --component rustfmt @@ -128,10 +128,6 @@ jobs: - name: Run cargo-machete run: cargo machete - - name: Run Clippy - if: matrix.os == 'macos-latest' - run: cargo clippy --all-targets -- -D clippy::disallowed_macros -D clippy::dbg_macro -D clippy::todo -D clippy::print_stdout -D clippy::print_stderr - analysis-stats: if: github.repository == 'rust-lang/rust-analyzer' runs-on: ubuntu-latest @@ -140,7 +136,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install Rust toolchain run: | @@ -168,7 +164,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install Rust toolchain run: | @@ -178,18 +174,43 @@ jobs: - run: cargo fmt -- --check + clippy: + if: github.repository == 'rust-lang/rust-analyzer' + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + + # Note that clippy output is currently dependent on whether rust-src is installed, + # https://github.com/rust-lang/rust-clippy/issues/14625 + - name: Install Rust toolchain + run: | + rustup update --no-self-update stable + rustup default stable + rustup component add --toolchain stable rust-src clippy + + # https://github.com/actions-rust-lang/setup-rust-toolchain/blob/main/rust.json + - name: Install Rust Problem Matcher + run: echo "::add-matcher::.github/rust.json" + + - run: cargo clippy --all-targets -- -D clippy::disallowed_macros -D clippy::dbg_macro -D clippy::todo -D clippy::print_stdout -D clippy::print_stderr + miri: if: github.repository == 'rust-lang/rust-analyzer' runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install Rust toolchain run: | - rustup update --no-self-update nightly - rustup default nightly + # FIXME: Pin nightly due to a regression in miri on nightly-2026-02-12. + # See https://github.com/rust-lang/miri/issues/4855. + # Revert to plain `nightly` once this is fixed upstream. + rustup toolchain install nightly-2026-02-10 + rustup default nightly-2026-02-10 rustup component add miri # - name: Cache Dependencies @@ -214,7 +235,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install Rust toolchain run: | @@ -242,10 +263,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install Nodejs - uses: actions/setup-node@v4 + uses: actions/setup-node@v6 with: node-version: 22 @@ -301,7 +322,7 @@ jobs: run: curl -LsSf https://github.com/crate-ci/typos/releases/download/$TYPOS_VERSION/typos-$TYPOS_VERSION-x86_64-unknown-linux-musl.tar.gz | tar zxf - -C ${CARGO_HOME:-~/.cargo}/bin - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: ref: ${{ github.event.pull_request.head.sha }} @@ -309,7 +330,18 @@ jobs: run: typos conclusion: - needs: [rust, rust-cross, typescript, typo-check, proc-macro-srv, miri, rustfmt, analysis-stats] + needs: + [ + rust, + rust-cross, + typescript, + typo-check, + proc-macro-srv, + miri, + rustfmt, + clippy, + analysis-stats, + ] # We need to ensure this job does *not* get skipped if its dependencies fail, # because a skipped job is considered a success by GitHub. So we have to # overwrite `if:`. We use `!cancelled()` to ensure the job does still not get run diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml new file mode 100644 index 0000000000..9460c6a3c7 --- /dev/null +++ b/.github/workflows/coverage.yaml @@ -0,0 +1,44 @@ +name: Coverage + +on: [pull_request, push] + +env: + CARGO_INCREMENTAL: 0 + CARGO_NET_RETRY: 10 + CI: 1 + RUST_BACKTRACE: short + RUSTUP_MAX_RETRIES: 10 + +jobs: + coverage: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + + - name: Install Rust toolchain + run: | + rustup update --no-self-update nightly + rustup default nightly + rustup component add --toolchain nightly rust-src rustc-dev rustfmt + # We also install a nightly rustfmt, because we use `--file-lines` in + # a test. + rustup toolchain install nightly --profile minimal --component rustfmt + + rustup toolchain install nightly --component llvm-tools-preview + + - name: Install cargo-llvm-cov + uses: taiki-e/install-action@cargo-llvm-cov + + - name: Install nextest + uses: taiki-e/install-action@nextest + + - name: Generate code coverage + run: cargo llvm-cov --workspace --lcov --output-path lcov.info + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0 + with: + files: lcov.info + fail_ci_if_error: false + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true diff --git a/.github/workflows/fuzz.yml b/.github/workflows/fuzz.yml index 7acfcbe351..af0e03598e 100644 --- a/.github/workflows/fuzz.yml +++ b/.github/workflows/fuzz.yml @@ -2,10 +2,10 @@ name: Fuzz on: schedule: # Once a week - - cron: '0 0 * * 0' + - cron: "0 0 * * 0" push: paths: - - '.github/workflows/fuzz.yml' + - ".github/workflows/fuzz.yml" # Allow manual trigger workflow_dispatch: @@ -27,7 +27,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: ref: ${{ github.event.pull_request.head.sha }} fetch-depth: 1 diff --git a/.github/workflows/metrics.yaml b/.github/workflows/metrics.yaml index 860837dd7f..a482235105 100644 --- a/.github/workflows/metrics.yaml +++ b/.github/workflows/metrics.yaml @@ -23,7 +23,7 @@ jobs: rustup component add --toolchain beta rust-src - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Cache cargo uses: actions/cache@v4 @@ -45,7 +45,7 @@ jobs: key: ${{ runner.os }}-target-${{ github.sha }} - name: Upload build metrics - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7 with: name: build-${{ github.sha }} path: target/build.json @@ -66,7 +66,7 @@ jobs: rustup component add --toolchain beta rust-src - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Restore target cache uses: actions/cache@v4 @@ -78,7 +78,7 @@ jobs: run: cargo xtask metrics "${{ matrix.names }}" - name: Upload metrics - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7 with: name: ${{ matrix.names }}-${{ github.sha }} path: target/${{ matrix.names }}.json @@ -89,35 +89,35 @@ jobs: needs: [build_metrics, other_metrics] steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Download build metrics - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v8 with: name: build-${{ github.sha }} - name: Download self metrics - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v8 with: name: self-${{ github.sha }} - name: Download ripgrep-13.0.0 metrics - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v8 with: name: ripgrep-13.0.0-${{ github.sha }} - name: Download webrender-2022 metrics - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v8 with: name: webrender-2022-${{ github.sha }} - name: Download diesel-1.4.8 metrics - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v8 with: name: diesel-1.4.8-${{ github.sha }} - name: Download hyper-0.14.18 metrics - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v8 with: name: hyper-0.14.18-${{ github.sha }} diff --git a/.github/workflows/publish-libs.yaml b/.github/workflows/publish-libs.yaml index f2c8b6365b..762b7bda87 100644 --- a/.github/workflows/publish-libs.yaml +++ b/.github/workflows/publish-libs.yaml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 28914118de..b35614f91b 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -70,12 +70,12 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: ${{ env.FETCH_DEPTH }} - name: Install Node.js toolchain - uses: actions/setup-node@v4 + uses: actions/setup-node@v6 with: node-version: 22 @@ -143,7 +143,7 @@ jobs: run: target/${{ matrix.target }}/release/rust-analyzer analysis-stats --with-deps --no-sysroot --no-test $(rustc --print sysroot)/lib/rustlib/src/rust/library/std -q - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7 with: name: dist-${{ matrix.target }} path: ./dist @@ -166,7 +166,7 @@ jobs: run: apk add --no-cache git clang lld musl-dev nodejs npm - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: ${{ env.FETCH_DEPTH }} @@ -189,7 +189,7 @@ jobs: - run: rm -rf editors/code/server - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7 with: name: dist-x86_64-unknown-linux-musl path: ./dist @@ -201,7 +201,7 @@ jobs: needs: ["dist", "dist-x86_64-unknown-linux-musl"] steps: - name: Install Nodejs - uses: actions/setup-node@v4 + uses: actions/setup-node@v6 with: node-version: 22 @@ -212,46 +212,46 @@ jobs: - run: 'echo "TAG: $TAG"' - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: ${{ env.FETCH_DEPTH }} - run: echo "HEAD_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV - run: 'echo "HEAD_SHA: $HEAD_SHA"' - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v8 with: name: dist-aarch64-apple-darwin path: dist - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v8 with: name: dist-x86_64-apple-darwin path: dist - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v8 with: name: dist-x86_64-unknown-linux-gnu path: dist - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v8 with: name: dist-x86_64-unknown-linux-musl path: dist - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v8 with: name: dist-aarch64-unknown-linux-gnu path: dist - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v8 with: name: dist-arm-unknown-linux-gnueabihf path: dist - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v8 with: name: dist-x86_64-pc-windows-msvc path: dist - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v8 with: name: dist-i686-pc-windows-msvc path: dist - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v8 with: name: dist-aarch64-pc-windows-msvc path: dist diff --git a/.github/workflows/rustdoc.yaml b/.github/workflows/rustdoc.yaml index 0cc7ce77dd..03fd083175 100644 --- a/.github/workflows/rustdoc.yaml +++ b/.github/workflows/rustdoc.yaml @@ -1,8 +1,8 @@ name: rustdoc on: push: - branches: - - master + branches: + - master env: CARGO_INCREMENTAL: 0 @@ -10,6 +10,7 @@ env: RUSTFLAGS: "-D warnings -W unreachable-pub" RUSTDOCFLAGS: "-D warnings" RUSTUP_MAX_RETRIES: 10 + ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true jobs: rustdoc: @@ -17,19 +18,19 @@ jobs: runs-on: ubuntu-latest steps: - - name: Checkout repository - uses: actions/checkout@v4 + - name: Checkout repository + uses: actions/checkout@v6 - - name: Install Rust toolchain - run: rustup update --no-self-update stable + - name: Install Rust toolchain + run: rustup update --no-self-update stable - - name: Build Documentation - run: cargo doc --all --no-deps --document-private-items + - name: Build Documentation + run: cargo doc --all --no-deps --document-private-items - - name: Deploy Docs - uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - publish_branch: gh-pages - publish_dir: ./target/doc - force_orphan: true + - name: Deploy Docs + uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_branch: gh-pages + publish_dir: ./target/doc + force_orphan: true diff --git a/AGENTS.md b/AGENTS.md new file mode 120000 index 0000000000..681311eb9c --- /dev/null +++ b/AGENTS.md @@ -0,0 +1 @@ +CLAUDE.md
\ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000000..e8f699d928 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,40 @@ +**Reminder: All AI usage must be disclosed in commit messages, see +CONTRIBUTING.md for more details.** + +## Build Commands + +```bash +cargo build # Build all crates +cargo test # Run all tests +cargo test -p <crate> # Run tests for a specific crate (e.g., cargo test -p hir-ty) +cargo lint # Run clippy on all targets +cargo xtask codegen # Run code generation +cargo xtask tidy # Run tidy checks +UPDATE_EXPECT=1 cargo test # Update test expectations (snapshot tests) +RUN_SLOW_TESTS=1 cargo test # Run heavy/slow tests +``` + +## Key Architectural Invariants + +- Typing in a function body never invalidates global derived data +- Parser/syntax tree is built per-file to enable parallel parsing +- The server is stateless (HTTP-like); context must be re-created from request parameters +- Cancellation uses salsa's cancellation mechanism; computations panic with a `Cancelled` payload + +### Code Generation + +Generated code is committed to the repo. Grammar and AST are generated from `ungrammar`. Run `cargo test -p xtask` after adding inline parser tests (`// test test_name` comments). + +## Testing + +Tests are snapshot-based using `expect-test`. Test fixtures use a mini-language: +- `$0` marks cursor position +- `// ^^^^` labels attach to the line above +- `//- minicore: sized, fn` includes parts of minicore (minimal core library) +- `//- /path/to/file.rs crate:name deps:dep1,dep2` declares files/crates + +## Style Notes + +- Use `stdx::never!` and `stdx::always!` instead of `assert!` for recoverable invariants +- Use `T![fn]` macro instead of `SyntaxKind::FN_KW` +- Use keyword name mangling over underscore prefixing for identifiers: `crate` → `krate`, `fn` → `func`, `struct` → `strukt`, `type` → `ty` diff --git a/Cargo.lock b/Cargo.lock index 2cf3e37a43..5370127ddc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1614,9 +1614,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" [[package]] name = "num-traits" @@ -1845,6 +1845,7 @@ dependencies = [ "paths", "postcard", "proc-macro-srv", + "rayon", "rustc-hash 2.1.1", "semver", "serde", @@ -2314,6 +2315,7 @@ dependencies = [ "ide-db", "ide-ssr", "indexmap", + "intern", "itertools 0.14.0", "load-cargo", "lsp-server 0.7.9 (registry+https://github.com/rust-lang/crates.io-index)", @@ -2634,7 +2636,7 @@ dependencies = [ [[package]] name = "smol_str" -version = "0.3.5" +version = "0.3.6" dependencies = [ "arbitrary", "borsh", @@ -2914,9 +2916,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.44" +version = "0.3.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" +checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" dependencies = [ "deranged", "itoa", @@ -2924,22 +2926,22 @@ dependencies = [ "num-conv", "num_threads", "powerfmt", - "serde", + "serde_core", "time-core", "time-macros", ] [[package]] name = "time-core" -version = "0.1.6" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" [[package]] name = "time-macros" -version = "0.2.24" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" +checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" dependencies = [ "num-conv", "time-core", diff --git a/Cargo.toml b/Cargo.toml index 2288933a96..9f31e1903a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -192,8 +192,6 @@ unused_lifetimes = "warn" unreachable_pub = "warn" [workspace.lints.clippy] -# FIXME Remove the tidy test once the lint table is stable - ## lint groups complexity = { level = "warn", priority = -1 } correctness = { level = "deny", priority = -1 } @@ -20,6 +20,8 @@ analyzing Rust code. See [Architecture](https://rust-analyzer.github.io/book/contributing/architecture.html) in the manual. +[](https://app.codecov.io/github/rust-lang/rust-analyzer/tree/master) + ## Quick Start https://rust-analyzer.github.io/book/installation.html diff --git a/bench_data/glorious_old_parser b/bench_data/glorious_old_parser index f593f2b295..8136daa832 100644 --- a/bench_data/glorious_old_parser +++ b/bench_data/glorious_old_parser @@ -724,7 +724,7 @@ impl<'a> Parser<'a> { // {foo(bar {}} // - ^ // | | - // | help: `)` may belong here (FIXME: #58270) + // | help: `)` may belong here // | // unclosed delimiter if let Some(sp) = unmatched.unclosed_span { @@ -3217,7 +3217,6 @@ impl<'a> Parser<'a> { } _ => { - // FIXME Could factor this out into non_fatal_unexpected or something. let actual = self.this_token_to_string(); self.span_err(self.span, &format!("unexpected token: `{}`", actual)); } @@ -5250,7 +5249,6 @@ impl<'a> Parser<'a> { } } } else { - // FIXME: Bad copy of attrs let old_directory_ownership = mem::replace(&mut self.directory.ownership, DirectoryOwnership::UnownedViaBlock); let item = self.parse_item_(attrs.clone(), false, true)?; @@ -5953,23 +5951,14 @@ impl<'a> Parser<'a> { }); assoc_ty_bindings.push(span); } else if self.check_const_arg() { - // FIXME(const_generics): to distinguish between idents for types and consts, - // we should introduce a GenericArg::Ident in the AST and distinguish when - // lowering to the HIR. For now, idents for const args are not permitted. - // Parse const argument. let expr = if let token::OpenDelim(token::Brace) = self.token { self.parse_block_expr(None, self.span, BlockCheckMode::Default, ThinVec::new())? } else if self.token.is_ident() { - // FIXME(const_generics): to distinguish between idents for types and consts, - // we should introduce a GenericArg::Ident in the AST and distinguish when - // lowering to the HIR. For now, idents for const args are not permitted. return Err( self.fatal("identifiers may currently not be used for const generics") ); } else { - // FIXME(const_generics): this currently conflicts with emplacement syntax - // with negative integer literals. self.parse_literal_maybe_minus()? }; let value = AnonConst { @@ -5991,9 +5980,6 @@ impl<'a> Parser<'a> { } } - // FIXME: we would like to report this in ast_validation instead, but we currently do not - // preserve ordering of generic parameters with respect to associated type binding, so we - // lose that information after parsing. if misplaced_assoc_ty_bindings.len() > 0 { let mut err = self.struct_span_err( args_lo.to(self.prev_span), @@ -6079,8 +6065,6 @@ impl<'a> Parser<'a> { bounds, } )); - // FIXME: Decide what should be used here, `=` or `==`. - // FIXME: We are just dropping the binders in lifetime_defs on the floor here. } else if self.eat(&token::Eq) || self.eat(&token::EqEq) { let rhs_ty = self.parse_ty()?; where_clause.predicates.push(ast::WherePredicate::EqPredicate( diff --git a/crates/base-db/src/editioned_file_id.rs b/crates/base-db/src/editioned_file_id.rs index 13fb05d565..8721f3a0ff 100644 --- a/crates/base-db/src/editioned_file_id.rs +++ b/crates/base-db/src/editioned_file_id.rs @@ -1,305 +1,46 @@ //! Defines [`EditionedFileId`], an interned wrapper around [`span::EditionedFileId`] that -//! is interned (so queries can take it) and remembers its crate. +//! is interned (so queries can take it) and stores only the underlying `span::EditionedFileId`. -use core::fmt; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; +use salsa::Database; use span::Edition; use vfs::FileId; -use crate::{Crate, RootQueryDb}; - -#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct EditionedFileId( - salsa::Id, - std::marker::PhantomData<&'static salsa::plumbing::interned::Value<EditionedFileId>>, -); - -const _: () = { - use salsa::plumbing as zalsa_; - use zalsa_::interned as zalsa_struct_; - type Configuration_ = EditionedFileId; - - #[derive(Debug, Clone, PartialEq, Eq)] - pub struct EditionedFileIdData { - editioned_file_id: span::EditionedFileId, - krate: Crate, - } - - // FIXME: This poses an invalidation problem, if one constructs an `EditionedFileId` with a - // different crate then whatever the input of a memo used, it will invalidate the memo causing - // it to recompute even if the crate is not really used. - /// We like to include the origin crate in an `EditionedFileId` (for use in the item tree), - /// but this poses us a problem. - /// - /// Spans contain `EditionedFileId`s, and we don't want to make them store the crate too - /// because that will increase their size, which will increase memory usage significantly. - /// Furthermore, things using spans do not generally need the crate: they are using the - /// file id for queries like `ast_id_map` or `parse`, which do not care about the crate. - /// - /// To solve this, we hash **only the `span::EditionedFileId`**, but on still compare - /// the crate in equality check. This preserves the invariant of `Hash` and `Eq` - - /// although same hashes can be used for different items, same file ids used for multiple - /// crates is a rare thing, and different items always have different hashes. Then, - /// when we only have a `span::EditionedFileId`, we use the `intern()` method to - /// reuse existing file ids, and create new one only if needed. See [`from_span_guess_origin`]. - /// - /// See this for more info: https://rust-lang.zulipchat.com/#narrow/channel/185405-t-compiler.2Frust-analyzer/topic/Letting.20EditionedFileId.20know.20its.20crate/near/530189401 - /// - /// [`from_span_guess_origin`]: EditionedFileId::from_span_guess_origin - #[derive(Hash, PartialEq, Eq)] - struct WithoutCrate { - editioned_file_id: span::EditionedFileId, - } - - impl Hash for EditionedFileIdData { - #[inline] - fn hash<H: Hasher>(&self, state: &mut H) { - let EditionedFileIdData { editioned_file_id, krate: _ } = *self; - editioned_file_id.hash(state); - } - } - - impl zalsa_struct_::HashEqLike<WithoutCrate> for EditionedFileIdData { - #[inline] - fn hash<H: Hasher>(&self, state: &mut H) { - Hash::hash(self, state); - } - - #[inline] - fn eq(&self, data: &WithoutCrate) -> bool { - let EditionedFileIdData { editioned_file_id, krate: _ } = *self; - editioned_file_id == data.editioned_file_id - } - } - - impl zalsa_::HasJar for EditionedFileId { - type Jar = zalsa_struct_::JarImpl<EditionedFileId>; - const KIND: zalsa_::JarKind = zalsa_::JarKind::Struct; - } - - zalsa_::register_jar! { - zalsa_::ErasedJar::erase::<EditionedFileId>() - } - - impl zalsa_struct_::Configuration for EditionedFileId { - const LOCATION: salsa::plumbing::Location = - salsa::plumbing::Location { file: file!(), line: line!() }; - const DEBUG_NAME: &'static str = "EditionedFileId"; - const REVISIONS: std::num::NonZeroUsize = std::num::NonZeroUsize::MAX; - const PERSIST: bool = false; - - type Fields<'a> = EditionedFileIdData; - type Struct<'db> = EditionedFileId; - - fn serialize<S>(_: &Self::Fields<'_>, _: S) -> Result<S::Ok, S::Error> - where - S: zalsa_::serde::Serializer, - { - unimplemented!("attempted to serialize value that set `PERSIST` to false") - } - - fn deserialize<'de, D>(_: D) -> Result<Self::Fields<'static>, D::Error> - where - D: zalsa_::serde::Deserializer<'de>, - { - unimplemented!("attempted to deserialize value that cannot set `PERSIST` to false"); - } - } - - impl Configuration_ { - pub fn ingredient(zalsa: &zalsa_::Zalsa) -> &zalsa_struct_::IngredientImpl<Self> { - static CACHE: zalsa_::IngredientCache<zalsa_struct_::IngredientImpl<EditionedFileId>> = - zalsa_::IngredientCache::new(); - - // SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the only - // ingredient created by our jar is the struct ingredient. - unsafe { - CACHE.get_or_create(zalsa, || { - zalsa.lookup_jar_by_type::<zalsa_struct_::JarImpl<EditionedFileId>>() - }) - } - } - } - - impl zalsa_::AsId for EditionedFileId { - fn as_id(&self) -> salsa::Id { - self.0.as_id() - } - } - impl zalsa_::FromId for EditionedFileId { - fn from_id(id: salsa::Id) -> Self { - Self(<salsa::Id>::from_id(id), std::marker::PhantomData) - } - } - - unsafe impl Send for EditionedFileId {} - unsafe impl Sync for EditionedFileId {} - - impl std::fmt::Debug for EditionedFileId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Self::default_debug_fmt(*self, f) - } - } - - impl zalsa_::SalsaStructInDb for EditionedFileId { - type MemoIngredientMap = salsa::plumbing::MemoIngredientSingletonIndex; - - fn lookup_ingredient_index(aux: &zalsa_::Zalsa) -> salsa::plumbing::IngredientIndices { - aux.lookup_jar_by_type::<zalsa_struct_::JarImpl<EditionedFileId>>().into() - } - - fn entries(zalsa: &zalsa_::Zalsa) -> impl Iterator<Item = zalsa_::DatabaseKeyIndex> + '_ { - let _ingredient_index = - zalsa.lookup_jar_by_type::<zalsa_struct_::JarImpl<EditionedFileId>>(); - <EditionedFileId>::ingredient(zalsa).entries(zalsa).map(|entry| entry.key()) - } - - #[inline] - fn cast(id: salsa::Id, type_id: std::any::TypeId) -> Option<Self> { - if type_id == std::any::TypeId::of::<EditionedFileId>() { - Some(<Self as salsa::plumbing::FromId>::from_id(id)) - } else { - None - } - } - - #[inline] - unsafe fn memo_table( - zalsa: &zalsa_::Zalsa, - id: zalsa_::Id, - current_revision: zalsa_::Revision, - ) -> zalsa_::MemoTableWithTypes<'_> { - // SAFETY: Guaranteed by caller. - unsafe { - zalsa.table().memos::<zalsa_struct_::Value<EditionedFileId>>(id, current_revision) - } - } - } - - unsafe impl zalsa_::Update for EditionedFileId { - unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool { - if unsafe { *old_pointer } != new_value { - unsafe { *old_pointer = new_value }; - true - } else { - false - } - } - } - - impl EditionedFileId { - pub fn from_span( - db: &(impl salsa::Database + ?Sized), - editioned_file_id: span::EditionedFileId, - krate: Crate, - ) -> Self { - let (zalsa, zalsa_local) = db.zalsas(); - Configuration_::ingredient(zalsa).intern( - zalsa, - zalsa_local, - EditionedFileIdData { editioned_file_id, krate }, - |_, data| data, - ) - } - - /// Guesses the crate for the file. - /// - /// Only use this if you cannot precisely determine the origin. This can happen in one of two cases: - /// - /// 1. The file is not in the module tree. - /// 2. You are latency sensitive and cannot afford calling the def map to precisely compute the origin - /// (e.g. on enter feature, folding, etc.). - pub fn from_span_guess_origin( - db: &dyn RootQueryDb, - editioned_file_id: span::EditionedFileId, - ) -> Self { - let (zalsa, zalsa_local) = db.zalsas(); - Configuration_::ingredient(zalsa).intern( - zalsa, - zalsa_local, - WithoutCrate { editioned_file_id }, - |_, _| { - // FileId not in the database. - let krate = db - .relevant_crates(editioned_file_id.file_id()) - .first() - .copied() - .or_else(|| db.all_crates().first().copied()) - .unwrap_or_else(|| { - // What we're doing here is a bit fishy. We rely on the fact that we only need - // the crate in the item tree, and we should not create an `EditionedFileId` - // without a crate except in cases where it does not matter. The chances that - // `all_crates()` will be empty are also very slim, but it can occur during startup. - // In the very unlikely case that there is a bug and we'll use this crate, Salsa - // will panic. - - // SAFETY: 0 is less than `Id::MAX_U32`. - salsa::plumbing::FromId::from_id(unsafe { salsa::Id::from_index(0) }) - }); - EditionedFileIdData { editioned_file_id, krate } - }, - ) - } - - pub fn editioned_file_id(self, db: &dyn salsa::Database) -> span::EditionedFileId { - let zalsa = db.zalsa(); - let fields = Configuration_::ingredient(zalsa).fields(zalsa, self); - fields.editioned_file_id - } - - pub fn krate(self, db: &dyn salsa::Database) -> Crate { - let zalsa = db.zalsa(); - let fields = Configuration_::ingredient(zalsa).fields(zalsa, self); - fields.krate - } +#[salsa::interned(debug, constructor = from_span_file_id, no_lifetime)] +#[derive(PartialOrd, Ord)] +pub struct EditionedFileId { + field: span::EditionedFileId, +} - /// Default debug formatting for this struct (may be useful if you define your own `Debug` impl) - pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - zalsa_::with_attached_database(|db| { - let zalsa = db.zalsa(); - let fields = Configuration_::ingredient(zalsa).fields(zalsa, this); - fmt::Debug::fmt(fields, f) - }) - .unwrap_or_else(|| { - f.debug_tuple("EditionedFileId").field(&zalsa_::AsId::as_id(&this)).finish() - }) - } +impl EditionedFileId { + #[inline] + pub fn new(db: &dyn Database, file_id: FileId, edition: Edition) -> Self { + Self::from_span_file_id(db, span::EditionedFileId::new(file_id, edition)) } -}; -impl EditionedFileId { #[inline] - pub fn new(db: &dyn salsa::Database, file_id: FileId, edition: Edition, krate: Crate) -> Self { - EditionedFileId::from_span(db, span::EditionedFileId::new(file_id, edition), krate) + pub fn current_edition(db: &dyn Database, file_id: FileId) -> Self { + Self::from_span_file_id(db, span::EditionedFileId::current_edition(file_id)) } - /// Attaches the current edition and guesses the crate for the file. - /// - /// Only use this if you cannot precisely determine the origin. This can happen in one of two cases: - /// - /// 1. The file is not in the module tree. - /// 2. You are latency sensitive and cannot afford calling the def map to precisely compute the origin - /// (e.g. on enter feature, folding, etc.). #[inline] - pub fn current_edition_guess_origin(db: &dyn RootQueryDb, file_id: FileId) -> Self { - Self::from_span_guess_origin(db, span::EditionedFileId::current_edition(file_id)) + pub fn file_id(self, db: &dyn Database) -> vfs::FileId { + self.field(db).file_id() } #[inline] - pub fn file_id(self, db: &dyn salsa::Database) -> vfs::FileId { - let id = self.editioned_file_id(db); - id.file_id() + pub fn span_file_id(self, db: &dyn Database) -> span::EditionedFileId { + self.field(db) } #[inline] - pub fn unpack(self, db: &dyn salsa::Database) -> (vfs::FileId, span::Edition) { - let id = self.editioned_file_id(db); - (id.file_id(), id.edition()) + pub fn unpack(self, db: &dyn Database) -> (vfs::FileId, span::Edition) { + self.field(db).unpack() } #[inline] - pub fn edition(self, db: &dyn salsa::Database) -> Edition { - self.editioned_file_id(db).edition() + pub fn edition(self, db: &dyn Database) -> Edition { + self.field(db).edition() } } diff --git a/crates/base-db/src/input.rs b/crates/base-db/src/input.rs index 94793a3618..246c57edc2 100644 --- a/crates/base-db/src/input.rs +++ b/crates/base-db/src/input.rs @@ -742,7 +742,7 @@ impl CrateGraphBuilder { deps.into_iter() } - /// Returns all crates in the graph, sorted in topological order (ie. dependencies of a crate + /// Returns all crates in the graph, sorted in topological order (i.e. dependencies of a crate /// come before the crate itself). fn crates_in_topological_order(&self) -> Vec<CrateBuilderId> { let mut res = Vec::new(); @@ -870,7 +870,7 @@ impl CrateGraphBuilder { impl Crate { pub fn root_file_id(self, db: &dyn salsa::Database) -> EditionedFileId { let data = self.data(db); - EditionedFileId::new(db, data.root_file_id, data.edition, self) + EditionedFileId::new(db, data.root_file_id, data.edition) } } diff --git a/crates/base-db/src/lib.rs b/crates/base-db/src/lib.rs index 24f6dd59a9..5baf4ce6f9 100644 --- a/crates/base-db/src/lib.rs +++ b/crates/base-db/src/lib.rs @@ -32,7 +32,7 @@ pub use crate::{ }, }; use dashmap::{DashMap, mapref::entry::Entry}; -pub use query_group::{self}; +pub use query_group; use rustc_hash::{FxHashSet, FxHasher}; use salsa::{Durability, Setter}; pub use semver::{BuildMetadata, Prerelease, Version, VersionReq}; diff --git a/crates/edition/src/lib.rs b/crates/edition/src/lib.rs index f1a1fe5964..eb4cec39dc 100644 --- a/crates/edition/src/lib.rs +++ b/crates/edition/src/lib.rs @@ -16,8 +16,6 @@ impl Edition { pub const DEFAULT: Edition = Edition::Edition2015; pub const LATEST: Edition = Edition::Edition2024; pub const CURRENT: Edition = Edition::Edition2024; - /// The current latest stable edition, note this is usually not the right choice in code. - pub const CURRENT_FIXME: Edition = Edition::Edition2024; pub fn from_u32(u32: u32) -> Edition { match u32 { diff --git a/crates/hir-def/src/attrs.rs b/crates/hir-def/src/attrs.rs index 0b8f656872..e3e1aac709 100644 --- a/crates/hir-def/src/attrs.rs +++ b/crates/hir-def/src/attrs.rs @@ -894,7 +894,7 @@ impl AttrFlags { def: GenericDefId, ) -> &(ArenaMap<LocalLifetimeParamId, AttrFlags>, ArenaMap<LocalTypeOrConstParamId, AttrFlags>) { - let generic_params = GenericParams::new(db, def); + let generic_params = GenericParams::of(db, def); let params_count_excluding_self = generic_params.len() - usize::from(generic_params.trait_self_param().is_some()); if params_count_excluding_self == 0 { diff --git a/crates/hir-def/src/db.rs b/crates/hir-def/src/db.rs index ccd4bc9be8..5d5d435398 100644 --- a/crates/hir-def/src/db.rs +++ b/crates/hir-def/src/db.rs @@ -4,28 +4,18 @@ use hir_expand::{ EditionedFileId, HirFileId, InFile, Lookup, MacroCallId, MacroDefId, MacroDefKind, db::ExpandDatabase, }; -use la_arena::ArenaMap; use triomphe::Arc; use crate::{ - AssocItemId, AttrDefId, BlockId, BlockLoc, ConstId, ConstLoc, DefWithBodyId, EnumId, EnumLoc, - EnumVariantId, EnumVariantLoc, ExternBlockId, ExternBlockLoc, ExternCrateId, ExternCrateLoc, - FunctionId, FunctionLoc, GenericDefId, ImplId, ImplLoc, LocalFieldId, Macro2Id, Macro2Loc, - MacroExpander, MacroId, MacroRulesId, MacroRulesLoc, MacroRulesLocFlags, ProcMacroId, - ProcMacroLoc, StaticId, StaticLoc, StructId, StructLoc, TraitId, TraitLoc, TypeAliasId, - TypeAliasLoc, UnionId, UnionLoc, UseId, UseLoc, VariantId, + AnonConstId, AnonConstLoc, AssocItemId, AttrDefId, BlockId, BlockLoc, ConstId, ConstLoc, + EnumId, EnumLoc, EnumVariantId, EnumVariantLoc, ExternBlockId, ExternBlockLoc, ExternCrateId, + ExternCrateLoc, FunctionId, FunctionLoc, ImplId, ImplLoc, Macro2Id, Macro2Loc, MacroExpander, + MacroId, MacroRulesId, MacroRulesLoc, MacroRulesLocFlags, ProcMacroId, ProcMacroLoc, StaticId, + StaticLoc, StructId, StructLoc, TraitId, TraitLoc, TypeAliasId, TypeAliasLoc, UnionId, + UnionLoc, UseId, UseLoc, attrs::AttrFlags, - expr_store::{ - Body, BodySourceMap, ExpressionStore, ExpressionStoreSourceMap, scope::ExprScopes, - }, - hir::generics::GenericParams, - import_map::ImportMap, item_tree::{ItemTree, file_item_tree_query}, nameres::crate_def_map, - signatures::{ - ConstSignature, EnumSignature, FunctionSignature, ImplSignature, StaticSignature, - StructSignature, TraitSignature, TypeAliasSignature, UnionSignature, - }, visibility::{self, Visibility}, }; @@ -62,6 +52,9 @@ pub trait InternDatabase: RootQueryDb { fn intern_static(&self, loc: StaticLoc) -> StaticId; #[salsa::interned] + fn intern_anon_const(&self, loc: AnonConstLoc) -> AnonConstId; + + #[salsa::interned] fn intern_trait(&self, loc: TraitLoc) -> TraitId; #[salsa::interned] @@ -96,151 +89,14 @@ pub trait DefDatabase: InternDatabase + ExpandDatabase + SourceDatabase { /// Computes an [`ItemTree`] for the given file or macro expansion. #[salsa::invoke(file_item_tree_query)] #[salsa::transparent] - fn file_item_tree(&self, file_id: HirFileId) -> &ItemTree; + fn file_item_tree(&self, file_id: HirFileId, krate: Crate) -> &ItemTree; /// Turns a MacroId into a MacroDefId, describing the macro's definition post name resolution. #[salsa::invoke(macro_def)] fn macro_def(&self, m: MacroId) -> MacroDefId; - // region:data - - #[salsa::tracked] - fn trait_signature(&self, trait_: TraitId) -> Arc<TraitSignature> { - self.trait_signature_with_source_map(trait_).0 - } - - #[salsa::tracked] - fn impl_signature(&self, impl_: ImplId) -> Arc<ImplSignature> { - self.impl_signature_with_source_map(impl_).0 - } - - #[salsa::tracked] - fn struct_signature(&self, struct_: StructId) -> Arc<StructSignature> { - self.struct_signature_with_source_map(struct_).0 - } - - #[salsa::tracked] - fn union_signature(&self, union_: UnionId) -> Arc<UnionSignature> { - self.union_signature_with_source_map(union_).0 - } - - #[salsa::tracked] - fn enum_signature(&self, e: EnumId) -> Arc<EnumSignature> { - self.enum_signature_with_source_map(e).0 - } - - #[salsa::tracked] - fn const_signature(&self, e: ConstId) -> Arc<ConstSignature> { - self.const_signature_with_source_map(e).0 - } - - #[salsa::tracked] - fn static_signature(&self, e: StaticId) -> Arc<StaticSignature> { - self.static_signature_with_source_map(e).0 - } - - #[salsa::tracked] - fn function_signature(&self, e: FunctionId) -> Arc<FunctionSignature> { - self.function_signature_with_source_map(e).0 - } - - #[salsa::tracked] - fn type_alias_signature(&self, e: TypeAliasId) -> Arc<TypeAliasSignature> { - self.type_alias_signature_with_source_map(e).0 - } - - #[salsa::invoke(TraitSignature::query)] - fn trait_signature_with_source_map( - &self, - trait_: TraitId, - ) -> (Arc<TraitSignature>, Arc<ExpressionStoreSourceMap>); - - #[salsa::invoke(ImplSignature::query)] - fn impl_signature_with_source_map( - &self, - impl_: ImplId, - ) -> (Arc<ImplSignature>, Arc<ExpressionStoreSourceMap>); - - #[salsa::invoke(StructSignature::query)] - fn struct_signature_with_source_map( - &self, - struct_: StructId, - ) -> (Arc<StructSignature>, Arc<ExpressionStoreSourceMap>); - - #[salsa::invoke(UnionSignature::query)] - fn union_signature_with_source_map( - &self, - union_: UnionId, - ) -> (Arc<UnionSignature>, Arc<ExpressionStoreSourceMap>); - - #[salsa::invoke(EnumSignature::query)] - fn enum_signature_with_source_map( - &self, - e: EnumId, - ) -> (Arc<EnumSignature>, Arc<ExpressionStoreSourceMap>); - - #[salsa::invoke(ConstSignature::query)] - fn const_signature_with_source_map( - &self, - e: ConstId, - ) -> (Arc<ConstSignature>, Arc<ExpressionStoreSourceMap>); - - #[salsa::invoke(StaticSignature::query)] - fn static_signature_with_source_map( - &self, - e: StaticId, - ) -> (Arc<StaticSignature>, Arc<ExpressionStoreSourceMap>); - - #[salsa::invoke(FunctionSignature::query)] - fn function_signature_with_source_map( - &self, - e: FunctionId, - ) -> (Arc<FunctionSignature>, Arc<ExpressionStoreSourceMap>); - - #[salsa::invoke(TypeAliasSignature::query)] - fn type_alias_signature_with_source_map( - &self, - e: TypeAliasId, - ) -> (Arc<TypeAliasSignature>, Arc<ExpressionStoreSourceMap>); - - // endregion:data - - #[salsa::invoke(Body::body_with_source_map_query)] - #[salsa::lru(512)] - fn body_with_source_map(&self, def: DefWithBodyId) -> (Arc<Body>, Arc<BodySourceMap>); - - #[salsa::invoke(Body::body_query)] - fn body(&self, def: DefWithBodyId) -> Arc<Body>; - - #[salsa::invoke(ExprScopes::expr_scopes_query)] - fn expr_scopes(&self, def: DefWithBodyId) -> Arc<ExprScopes>; - - #[salsa::transparent] - #[salsa::invoke(GenericParams::new)] - fn generic_params(&self, def: GenericDefId) -> Arc<GenericParams>; - - #[salsa::transparent] - #[salsa::invoke(GenericParams::generic_params_and_store)] - fn generic_params_and_store( - &self, - def: GenericDefId, - ) -> (Arc<GenericParams>, Arc<ExpressionStore>); - - #[salsa::transparent] - #[salsa::invoke(GenericParams::generic_params_and_store_and_source_map)] - fn generic_params_and_store_and_source_map( - &self, - def: GenericDefId, - ) -> (Arc<GenericParams>, Arc<ExpressionStore>, Arc<ExpressionStoreSourceMap>); - - #[salsa::invoke(ImportMap::import_map_query)] - fn import_map(&self, krate: Crate) -> Arc<ImportMap>; - // region:visibilities - #[salsa::invoke(visibility::field_visibilities_query)] - fn field_visibilities(&self, var: VariantId) -> Arc<ArenaMap<LocalFieldId, Visibility>>; - #[salsa::invoke(visibility::assoc_visibility_query)] fn assoc_visibility(&self, def: AssocItemId) -> Visibility; diff --git a/crates/hir-def/src/expr_store.rs b/crates/hir-def/src/expr_store.rs index 1ce4c881e7..ca523622ec 100644 --- a/crates/hir-def/src/expr_store.rs +++ b/crates/hir-def/src/expr_store.rs @@ -9,10 +9,7 @@ pub mod scope; #[cfg(test)] mod tests; -use std::{ - ops::{Deref, Index}, - sync::LazyLock, -}; +use std::ops::{Deref, Index}; use cfg::{CfgExpr, CfgOptions}; use either::Either; @@ -23,11 +20,10 @@ use smallvec::SmallVec; use span::{Edition, SyntaxContext}; use syntax::{AstPtr, SyntaxNodePtr, ast}; use thin_vec::ThinVec; -use triomphe::Arc; use tt::TextRange; use crate::{ - BlockId, SyntheticSyntax, + AdtId, BlockId, ExpressionStoreOwnerId, GenericDefId, SyntheticSyntax, db::DefDatabase, expr_store::path::Path, hir::{ @@ -35,6 +31,7 @@ use crate::{ PatId, RecordFieldPat, RecordSpread, Statement, }, nameres::{DefMap, block_def_map}, + signatures::VariantFields, type_ref::{LifetimeRef, LifetimeRefId, PathId, TypeRef, TypeRefId}, }; @@ -94,9 +91,26 @@ pub type TypeSource = InFile<TypePtr>; pub type LifetimePtr = AstPtr<ast::Lifetime>; pub type LifetimeSource = InFile<LifetimePtr>; +/// Describes where a const expression originated from. +/// +/// Used by signature/body inference to determine the expected type for each +/// const expression root. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum RootExprOrigin { + /// Array length expression: `[T; <expr>]` — expected type is `usize`. + ArrayLength, + /// Const parameter default value: `const N: usize = <expr>`. + ConstParam(crate::hir::generics::LocalTypeOrConstParamId), + /// Const generic argument in a path: `SomeType::<{ <expr> }>` or `some_fn::<{ <expr> }>()`. + /// Determining the expected type requires path resolution, so it is deferred. + GenericArgsPath, + /// The root expression of a body. + BodyRoot, +} + // We split the store into types-only and expressions, because most stores (e.g. generics) // don't store any expressions and this saves memory. Same thing for the source map. -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] struct ExpressionOnlyStore { exprs: Arena<Expr>, pats: Arena<Pat>, @@ -113,9 +127,12 @@ struct ExpressionOnlyStore { /// Expressions (and destructuing patterns) that can be recorded here are single segment path, although not all single segments path refer /// to variables and have hygiene (some refer to items, we don't know at this stage). ident_hygiene: FxHashMap<ExprOrPatId, HygieneId>, + + /// Maps expression roots to their origin. + expr_roots: SmallVec<[(ExprId, RootExprOrigin); 1]>, } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct ExpressionStore { expr_only: Option<Box<ExpressionOnlyStore>>, pub types: Arena<TypeRef>, @@ -226,6 +243,7 @@ pub struct ExpressionStoreBuilder { pub types: Arena<TypeRef>, block_scopes: Vec<BlockId>, ident_hygiene: FxHashMap<ExprOrPatId, HygieneId>, + pub inference_roots: Option<SmallVec<[(ExprId, RootExprOrigin); 1]>>, // AST expressions can create patterns in destructuring assignments. Therefore, `ExprSource` can also map // to `PatId`, and `PatId` can also map to `ExprSource` (the other way around is unaffected). @@ -297,6 +315,7 @@ impl ExpressionStoreBuilder { mut bindings, mut binding_owners, mut ident_hygiene, + inference_roots: mut expr_roots, mut types, mut lifetimes, @@ -356,6 +375,9 @@ impl ExpressionStoreBuilder { let store = { let expr_only = if has_exprs { + if let Some(const_expr_origins) = &mut expr_roots { + const_expr_origins.shrink_to_fit(); + } Some(Box::new(ExpressionOnlyStore { exprs, pats, @@ -364,6 +386,7 @@ impl ExpressionStoreBuilder { binding_owners, block_scopes: block_scopes.into_boxed_slice(), ident_hygiene, + expr_roots: expr_roots.unwrap_or_default(), })) } else { None @@ -404,13 +427,108 @@ impl ExpressionStoreBuilder { } impl ExpressionStore { - pub fn empty_singleton() -> (Arc<ExpressionStore>, Arc<ExpressionStoreSourceMap>) { - static EMPTY: LazyLock<(Arc<ExpressionStore>, Arc<ExpressionStoreSourceMap>)> = - LazyLock::new(|| { - let (store, source_map) = ExpressionStoreBuilder::default().finish(); - (Arc::new(store), Arc::new(source_map)) - }); - EMPTY.clone() + pub fn of(db: &dyn DefDatabase, def: ExpressionStoreOwnerId) -> &ExpressionStore { + match def { + ExpressionStoreOwnerId::Signature(def) => { + use crate::signatures::{ + ConstSignature, EnumSignature, FunctionSignature, ImplSignature, + StaticSignature, StructSignature, TraitSignature, TypeAliasSignature, + UnionSignature, + }; + match def { + GenericDefId::AdtId(AdtId::EnumId(id)) => &EnumSignature::of(db, id).store, + GenericDefId::AdtId(AdtId::StructId(id)) => &StructSignature::of(db, id).store, + GenericDefId::AdtId(AdtId::UnionId(id)) => &UnionSignature::of(db, id).store, + GenericDefId::ConstId(id) => &ConstSignature::of(db, id).store, + GenericDefId::FunctionId(id) => &FunctionSignature::of(db, id).store, + GenericDefId::ImplId(id) => &ImplSignature::of(db, id).store, + GenericDefId::StaticId(id) => &StaticSignature::of(db, id).store, + GenericDefId::TraitId(id) => &TraitSignature::of(db, id).store, + GenericDefId::TypeAliasId(id) => &TypeAliasSignature::of(db, id).store, + } + } + ExpressionStoreOwnerId::Body(body) => &Body::of(db, body).store, + ExpressionStoreOwnerId::VariantFields(variant_id) => { + &VariantFields::of(db, variant_id).store + } + } + } + + pub fn with_source_map( + db: &dyn DefDatabase, + def: ExpressionStoreOwnerId, + ) -> (&ExpressionStore, &ExpressionStoreSourceMap) { + match def { + ExpressionStoreOwnerId::Signature(def) => { + use crate::signatures::{ + ConstSignature, EnumSignature, FunctionSignature, ImplSignature, + StaticSignature, StructSignature, TraitSignature, TypeAliasSignature, + UnionSignature, + }; + match def { + GenericDefId::AdtId(AdtId::EnumId(id)) => { + let sig = EnumSignature::with_source_map(db, id); + (&sig.0.store, &sig.1) + } + GenericDefId::AdtId(AdtId::StructId(id)) => { + let sig = StructSignature::with_source_map(db, id); + (&sig.0.store, &sig.1) + } + GenericDefId::AdtId(AdtId::UnionId(id)) => { + let sig = UnionSignature::with_source_map(db, id); + (&sig.0.store, &sig.1) + } + GenericDefId::ConstId(id) => { + let sig = ConstSignature::with_source_map(db, id); + (&sig.0.store, &sig.1) + } + GenericDefId::FunctionId(id) => { + let sig = FunctionSignature::with_source_map(db, id); + (&sig.0.store, &sig.1) + } + GenericDefId::ImplId(id) => { + let sig = ImplSignature::with_source_map(db, id); + (&sig.0.store, &sig.1) + } + GenericDefId::StaticId(id) => { + let sig = StaticSignature::with_source_map(db, id); + (&sig.0.store, &sig.1) + } + GenericDefId::TraitId(id) => { + let sig = TraitSignature::with_source_map(db, id); + (&sig.0.store, &sig.1) + } + GenericDefId::TypeAliasId(id) => { + let sig = TypeAliasSignature::with_source_map(db, id); + (&sig.0.store, &sig.1) + } + } + } + ExpressionStoreOwnerId::Body(body) => { + let (store, sm) = Body::with_source_map(db, body); + (&store.store, &sm.store) + } + ExpressionStoreOwnerId::VariantFields(variant_id) => { + let (store, sm) = VariantFields::with_source_map(db, variant_id); + (&store.store, sm) + } + } + } + + /// Returns all expression root `ExprId`s found in this store. + pub fn expr_roots(&self) -> impl Iterator<Item = ExprId> { + self.const_expr_origins().iter().map(|&(id, _)| id) + } + + /// Like [`Self::signature_const_expr_roots`], but also returns the origin + /// of each expression. + pub fn expr_roots_with_origins(&self) -> impl Iterator<Item = (ExprId, RootExprOrigin)> { + self.const_expr_origins().iter().map(|&(id, origin)| (id, origin)) + } + + /// Returns the map of const expression roots to their origins. + pub fn const_expr_origins(&self) -> &[(ExprId, RootExprOrigin)] { + self.expr_only.as_ref().map_or(&[], |it| &it.expr_roots) } /// Returns an iterator over all block expressions in this store that define inner items. diff --git a/crates/hir-def/src/expr_store/body.rs b/crates/hir-def/src/expr_store/body.rs index c955393b9c..0c8320369f 100644 --- a/crates/hir-def/src/expr_store/body.rs +++ b/crates/hir-def/src/expr_store/body.rs @@ -29,8 +29,6 @@ pub struct Body { /// empty. pub params: Box<[PatId]>, pub self_param: Option<BindingId>, - /// The `ExprId` of the actual body expression. - pub body_expr: ExprId, } impl ops::Deref for Body { @@ -68,11 +66,10 @@ impl ops::Deref for BodySourceMap { } } +#[salsa::tracked] impl Body { - pub(crate) fn body_with_source_map_query( - db: &dyn DefDatabase, - def: DefWithBodyId, - ) -> (Arc<Body>, Arc<BodySourceMap>) { + #[salsa::tracked(lru = 512, returns(ref))] + pub fn with_source_map(db: &dyn DefDatabase, def: DefWithBodyId) -> (Arc<Body>, BodySourceMap) { let _p = tracing::info_span!("body_with_source_map_query").entered(); let mut params = None; @@ -99,18 +96,25 @@ impl Body { DefWithBodyId::VariantId(v) => { let s = v.lookup(db); let src = s.source(db); - src.map(|it| it.expr()) + src.map(|it| it.const_arg()?.expr()) } } }; let module = def.module(db); let (body, source_map) = lower_body(db, def, file_id, module, params, body, is_async_fn); - (Arc::new(body), Arc::new(source_map)) + (Arc::new(body), source_map) } - pub(crate) fn body_query(db: &dyn DefDatabase, def: DefWithBodyId) -> Arc<Body> { - db.body_with_source_map(def).0 + #[salsa::tracked(returns(deref))] + pub fn of(db: &dyn DefDatabase, def: DefWithBodyId) -> Arc<Body> { + Self::with_source_map(db, def).0.clone() + } +} + +impl Body { + pub fn root_expr(&self) -> ExprId { + self.store.expr_roots().next().unwrap() } pub fn pretty_print( diff --git a/crates/hir-def/src/expr_store/expander.rs b/crates/hir-def/src/expr_store/expander.rs index d34ec9bbc1..2fffa02c13 100644 --- a/crates/hir-def/src/expr_store/expander.rs +++ b/crates/hir-def/src/expr_store/expander.rs @@ -14,7 +14,6 @@ use hir_expand::{ use span::{AstIdMap, SyntaxContext}; use syntax::ast::HasAttrs; use syntax::{AstNode, Parse, ast}; -use triomphe::Arc; use tt::TextRange; use crate::{ @@ -23,21 +22,21 @@ use crate::{ }; #[derive(Debug)] -pub(super) struct Expander { +pub(super) struct Expander<'db> { span_map: SpanMap, current_file_id: HirFileId, - ast_id_map: Arc<AstIdMap>, + ast_id_map: &'db AstIdMap, /// `recursion_depth == usize::MAX` indicates that the recursion limit has been reached. recursion_depth: u32, recursion_limit: usize, } -impl Expander { +impl<'db> Expander<'db> { pub(super) fn new( - db: &dyn DefDatabase, + db: &'db dyn DefDatabase, current_file_id: HirFileId, - def_map: &DefMap, - ) -> Expander { + def_map: &'db DefMap, + ) -> Expander<'db> { let recursion_limit = def_map.recursion_limit() as usize; let recursion_limit = if cfg!(test) { // Without this, `body::tests::your_stack_belongs_to_me` stack-overflows in debug @@ -77,12 +76,12 @@ impl Expander { pub(super) fn enter_expand<T: ast::AstNode>( &mut self, - db: &dyn DefDatabase, + db: &'db dyn DefDatabase, macro_call: ast::MacroCall, krate: Crate, resolver: impl Fn(&ModPath) -> Option<MacroId>, eager_callback: EagerCallBackFn<'_>, - ) -> Result<ExpandResult<Option<(Mark, Option<Parse<T>>)>>, UnresolvedMacro> { + ) -> Result<ExpandResult<Option<(Mark<'db>, Option<Parse<T>>)>>, UnresolvedMacro> { // FIXME: within_limit should support this, instead of us having to extract the error let mut unresolved_macro_err = None; @@ -130,13 +129,13 @@ impl Expander { pub(super) fn enter_expand_id<T: ast::AstNode>( &mut self, - db: &dyn DefDatabase, + db: &'db dyn DefDatabase, call_id: MacroCallId, - ) -> ExpandResult<Option<(Mark, Option<Parse<T>>)>> { + ) -> ExpandResult<Option<(Mark<'db>, Option<Parse<T>>)>> { self.within_limit(db, |_this| ExpandResult::ok(Some(call_id))) } - pub(super) fn exit(&mut self, Mark { file_id, span_map, ast_id_map, mut bomb }: Mark) { + pub(super) fn exit(&mut self, Mark { file_id, span_map, ast_id_map, mut bomb }: Mark<'db>) { self.span_map = span_map; self.current_file_id = file_id; self.ast_id_map = ast_id_map; @@ -162,9 +161,9 @@ impl Expander { fn within_limit<F, T: ast::AstNode>( &mut self, - db: &dyn DefDatabase, + db: &'db dyn DefDatabase, op: F, - ) -> ExpandResult<Option<(Mark, Option<Parse<T>>)>> + ) -> ExpandResult<Option<(Mark<'db>, Option<Parse<T>>)>> where F: FnOnce(&mut Self) -> ExpandResult<Option<MacroCallId>>, { @@ -219,7 +218,7 @@ impl Expander { #[inline] pub(super) fn ast_id_map(&self) -> &AstIdMap { - &self.ast_id_map + self.ast_id_map } #[inline] @@ -229,9 +228,9 @@ impl Expander { } #[derive(Debug)] -pub(super) struct Mark { +pub(super) struct Mark<'db> { file_id: HirFileId, span_map: SpanMap, - ast_id_map: Arc<AstIdMap>, + ast_id_map: &'db AstIdMap, bomb: DropBomb, } diff --git a/crates/hir-def/src/expr_store/lower.rs b/crates/hir-def/src/expr_store/lower.rs index 4fbf6d9517..74006c6037 100644 --- a/crates/hir-def/src/expr_store/lower.rs +++ b/crates/hir-def/src/expr_store/lower.rs @@ -18,6 +18,7 @@ use hir_expand::{ }; use intern::{Symbol, sym}; use rustc_hash::FxHashMap; +use smallvec::smallvec; use stdx::never; use syntax::{ AstNode, AstPtr, SyntaxNodePtr, @@ -28,18 +29,17 @@ use syntax::{ }, }; use thin_vec::ThinVec; -use triomphe::Arc; use tt::TextRange; use crate::{ - AdtId, BlockId, BlockLoc, DefWithBodyId, FunctionId, GenericDefId, ImplId, MacroId, - ModuleDefId, ModuleId, TraitId, TypeAliasId, UnresolvedMacro, + AdtId, BlockId, BlockLoc, DefWithBodyId, FunctionId, GenericDefId, ImplId, ItemContainerId, + MacroId, ModuleDefId, ModuleId, TraitId, TypeAliasId, UnresolvedMacro, attrs::AttrFlags, db::DefDatabase, expr_store::{ Body, BodySourceMap, ExprPtr, ExpressionStore, ExpressionStoreBuilder, ExpressionStoreDiagnostics, ExpressionStoreSourceMap, HygieneId, LabelPtr, LifetimePtr, - PatPtr, TypePtr, + PatPtr, RootExprOrigin, TypePtr, expander::Expander, lower::generics::ImplTraitLowerFn, path::{AssociatedTypeBinding, GenericArg, GenericArgs, GenericArgsParentheses, Path}, @@ -53,6 +53,7 @@ use crate::{ item_tree::FieldsShape, lang_item::{LangItemTarget, LangItems}, nameres::{DefMap, LocalDefMap, MacroSubNs, block_def_map}, + signatures::StructSignature, type_ref::{ ArrayType, ConstRef, FnType, LifetimeRef, LifetimeRefId, Mutability, PathId, Rawness, RefType, TraitBoundModifier, TraitRef, TypeBound, TypeRef, TypeRefId, UseArgRef, @@ -79,7 +80,7 @@ pub(super) fn lower_body( let mut self_param = None; let mut source_map_self_param = None; let mut params = vec![]; - let mut collector = ExprCollector::new(db, module, current_file_id); + let mut collector = ExprCollector::body(db, module, current_file_id); let skip_body = AttrFlags::query( db, @@ -117,9 +118,10 @@ pub(super) fn lower_body( params = (0..count).map(|_| collector.missing_pat()).collect(); }; let body_expr = collector.missing_expr(); + collector.store.inference_roots = Some(smallvec![(body_expr, RootExprOrigin::BodyRoot)]); let (store, source_map) = collector.store.finish(); return ( - Body { store, params: params.into_boxed_slice(), self_param, body_expr }, + Body { store, params: params.into_boxed_slice(), self_param }, BodySourceMap { self_param: source_map_self_param, store: source_map }, ); } @@ -141,9 +143,19 @@ pub(super) fn lower_body( source_map_self_param = Some(collector.expander.in_file(AstPtr::new(&self_param_syn))); } + let is_extern = matches!( + owner, + DefWithBodyId::FunctionId(id) + if matches!(id.loc(db).container, ItemContainerId::ExternBlockId(_)), + ); + for param in param_list.params() { if collector.check_cfg(¶m) { - let param_pat = collector.collect_pat_top(param.pat()); + let param_pat = if is_extern { + collector.collect_extern_fn_param(param.pat()) + } else { + collector.collect_pat_top(param.pat()) + }; params.push(param_pat); } } @@ -163,10 +175,11 @@ pub(super) fn lower_body( } }, ); + collector.store.inference_roots = Some(smallvec![(body_expr, RootExprOrigin::BodyRoot)]); let (store, source_map) = collector.store.finish(); ( - Body { store, params: params.into_boxed_slice(), self_param, body_expr }, + Body { store, params: params.into_boxed_slice(), self_param }, BodySourceMap { self_param: source_map_self_param, store: source_map }, ) } @@ -176,7 +189,7 @@ pub(crate) fn lower_type_ref( module: ModuleId, type_ref: InFile<Option<ast::Type>>, ) -> (ExpressionStore, ExpressionStoreSourceMap, TypeRefId) { - let mut expr_collector = ExprCollector::new(db, module, type_ref.file_id); + let mut expr_collector = ExprCollector::signature(db, module, type_ref.file_id); let type_ref = expr_collector.lower_type_ref_opt(type_ref.value, &mut ExprCollector::impl_trait_allocator); let (store, source_map) = expr_collector.store.finish(); @@ -190,13 +203,13 @@ pub(crate) fn lower_generic_params( file_id: HirFileId, param_list: Option<ast::GenericParamList>, where_clause: Option<ast::WhereClause>, -) -> (Arc<ExpressionStore>, Arc<GenericParams>, ExpressionStoreSourceMap) { - let mut expr_collector = ExprCollector::new(db, module, file_id); +) -> (ExpressionStore, GenericParams, ExpressionStoreSourceMap) { + let mut expr_collector = ExprCollector::signature(db, module, file_id); let mut collector = generics::GenericParamsCollector::new(def); collector.lower(&mut expr_collector, param_list, where_clause); let params = collector.finish(); let (store, source_map) = expr_collector.store.finish(); - (Arc::new(store), params, source_map) + (store, params, source_map) } pub(crate) fn lower_impl( @@ -204,8 +217,8 @@ pub(crate) fn lower_impl( module: ModuleId, impl_syntax: InFile<ast::Impl>, impl_id: ImplId, -) -> (ExpressionStore, ExpressionStoreSourceMap, TypeRefId, Option<TraitRef>, Arc<GenericParams>) { - let mut expr_collector = ExprCollector::new(db, module, impl_syntax.file_id); +) -> (ExpressionStore, ExpressionStoreSourceMap, TypeRefId, Option<TraitRef>, GenericParams) { + let mut expr_collector = ExprCollector::signature(db, module, impl_syntax.file_id); let self_ty = expr_collector.lower_type_ref_opt_disallow_impl_trait(impl_syntax.value.self_ty()); let trait_ = impl_syntax.value.trait_().and_then(|it| match &it { @@ -232,8 +245,8 @@ pub(crate) fn lower_trait( module: ModuleId, trait_syntax: InFile<ast::Trait>, trait_id: TraitId, -) -> (ExpressionStore, ExpressionStoreSourceMap, Arc<GenericParams>) { - let mut expr_collector = ExprCollector::new(db, module, trait_syntax.file_id); +) -> (ExpressionStore, ExpressionStoreSourceMap, GenericParams) { + let mut expr_collector = ExprCollector::signature(db, module, trait_syntax.file_id); let mut collector = generics::GenericParamsCollector::with_self_param( &mut expr_collector, trait_id.into(), @@ -254,14 +267,9 @@ pub(crate) fn lower_type_alias( module: ModuleId, alias: InFile<ast::TypeAlias>, type_alias_id: TypeAliasId, -) -> ( - ExpressionStore, - ExpressionStoreSourceMap, - Arc<GenericParams>, - Box<[TypeBound]>, - Option<TypeRefId>, -) { - let mut expr_collector = ExprCollector::new(db, module, alias.file_id); +) -> (ExpressionStore, ExpressionStoreSourceMap, GenericParams, Box<[TypeBound]>, Option<TypeRefId>) +{ + let mut expr_collector = ExprCollector::signature(db, module, alias.file_id); let bounds = alias .value .type_bound_list() @@ -297,13 +305,13 @@ pub(crate) fn lower_function( ) -> ( ExpressionStore, ExpressionStoreSourceMap, - Arc<GenericParams>, + GenericParams, Box<[TypeRefId]>, Option<TypeRefId>, bool, bool, ) { - let mut expr_collector = ExprCollector::new(db, module, fn_.file_id); + let mut expr_collector = ExprCollector::signature(db, module, fn_.file_id); let mut collector = generics::GenericParamsCollector::new(function_id.into()); collector.lower(&mut expr_collector, fn_.value.generic_param_list(), fn_.value.where_clause()); let mut params = vec![]; @@ -409,7 +417,7 @@ pub(crate) fn lower_function( pub struct ExprCollector<'db> { db: &'db dyn DefDatabase, cfg_options: &'db CfgOptions, - expander: Expander, + expander: Expander<'db>, def_map: &'db DefMap, local_def_map: &'db LocalDefMap, module: ModuleId, @@ -426,7 +434,7 @@ pub struct ExprCollector<'db> { /// and we need to find the current definition. So we track the number of definitions we saw. current_block_legacy_macro_defs_count: FxHashMap<Name, usize>, - current_try_block_label: Option<LabelId>, + current_try_block: Option<TryBlock>, label_ribs: Vec<LabelRib>, unowned_bindings: Vec<BindingId>, @@ -472,6 +480,13 @@ enum Awaitable { No(&'static str), } +enum TryBlock { + // `try { ... }` + Homogeneous { label: LabelId }, + // `try bikeshed Ty { ... }` + Heterogeneous { label: LabelId }, +} + #[derive(Debug, Default)] struct BindingList { map: FxHashMap<(Name, HygieneId), BindingId>, @@ -515,7 +530,20 @@ impl BindingList { } impl<'db> ExprCollector<'db> { - pub fn new( + /// Creates a collector for a signature store, this will populate `const_expr_origins` to any + /// top level const arg roots. + pub fn signature( + db: &dyn DefDatabase, + module: ModuleId, + current_file_id: HirFileId, + ) -> ExprCollector<'_> { + let mut this = Self::body(db, module, current_file_id); + this.store.inference_roots = Some(Default::default()); + this + } + + /// Creates a collector for a bidy store. + pub fn body( db: &dyn DefDatabase, module: ModuleId, current_file_id: HirFileId, @@ -532,7 +560,7 @@ impl<'db> ExprCollector<'db> { lang_items: OnceCell::new(), store: ExpressionStoreBuilder::default(), expander, - current_try_block_label: None, + current_try_block: None, is_lowering_coroutine: false, label_ribs: Vec::new(), unowned_bindings: Vec::new(), @@ -560,7 +588,10 @@ impl<'db> ExprCollector<'db> { self.expander.span_map() } - pub fn lower_lifetime_ref(&mut self, lifetime: ast::Lifetime) -> LifetimeRefId { + pub(in crate::expr_store) fn lower_lifetime_ref( + &mut self, + lifetime: ast::Lifetime, + ) -> LifetimeRefId { // FIXME: Keyword check? let lifetime_ref = match &*lifetime.text() { "" | "'" => LifetimeRef::Error, @@ -571,7 +602,10 @@ impl<'db> ExprCollector<'db> { self.alloc_lifetime_ref(lifetime_ref, AstPtr::new(&lifetime)) } - pub fn lower_lifetime_ref_opt(&mut self, lifetime: Option<ast::Lifetime>) -> LifetimeRefId { + pub(in crate::expr_store) fn lower_lifetime_ref_opt( + &mut self, + lifetime: Option<ast::Lifetime>, + ) -> LifetimeRefId { match lifetime { Some(lifetime) => self.lower_lifetime_ref(lifetime), None => self.alloc_lifetime_ref_desugared(LifetimeRef::Placeholder), @@ -579,7 +613,7 @@ impl<'db> ExprCollector<'db> { } /// Converts an `ast::TypeRef` to a `hir::TypeRef`. - pub fn lower_type_ref( + pub(in crate::expr_store) fn lower_type_ref( &mut self, node: ast::Type, impl_trait_lower_fn: ImplTraitLowerFn<'_>, @@ -604,6 +638,9 @@ impl<'db> ExprCollector<'db> { } ast::Type::ArrayType(inner) => { let len = self.lower_const_arg_opt(inner.const_arg()); + if let Some(const_expr_origins) = &mut self.store.inference_roots { + const_expr_origins.push((len.expr, RootExprOrigin::ArrayLength)); + } TypeRef::Array(ArrayType { ty: self.lower_type_ref_opt(inner.ty(), impl_trait_lower_fn), len, @@ -793,7 +830,7 @@ impl<'db> ExprCollector<'db> { /// Collect `GenericArgs` from the parts of a fn-like path, i.e. `Fn(X, Y) /// -> Z` (which desugars to `Fn<(X, Y), Output=Z>`). - pub fn lower_generic_args_from_fn_path( + pub(in crate::expr_store) fn lower_generic_args_from_fn_path( &mut self, args: Option<ast::ParenthesizedArgList>, ret_type: Option<ast::RetType>, @@ -888,6 +925,9 @@ impl<'db> ExprCollector<'db> { } ast::GenericArg::ConstArg(arg) => { let arg = self.lower_const_arg(arg); + if let Some(const_expr_origins) = &mut self.store.inference_roots { + const_expr_origins.push((arg.expr, RootExprOrigin::GenericArgsPath)); + } args.push(GenericArg::Const(arg)) } } @@ -1028,17 +1068,30 @@ impl<'db> ExprCollector<'db> { } fn lower_const_arg_opt(&mut self, arg: Option<ast::ConstArg>) -> ConstRef { - ConstRef { expr: self.collect_expr_opt(arg.and_then(|it| it.expr())) } + let const_expr_origins = self.store.inference_roots.take(); + let r = ConstRef { expr: self.collect_expr_opt(arg.and_then(|it| it.expr())) }; + self.store.inference_roots = const_expr_origins; + r } - fn lower_const_arg(&mut self, arg: ast::ConstArg) -> ConstRef { - ConstRef { expr: self.collect_expr_opt(arg.expr()) } + pub fn lower_const_arg(&mut self, arg: ast::ConstArg) -> ConstRef { + let const_expr_origins = self.store.inference_roots.take(); + let r = ConstRef { expr: self.collect_expr_opt(arg.expr()) }; + self.store.inference_roots = const_expr_origins; + r } fn collect_expr(&mut self, expr: ast::Expr) -> ExprId { self.maybe_collect_expr(expr).unwrap_or_else(|| self.missing_expr()) } + pub(in crate::expr_store) fn collect_expr_opt(&mut self, expr: Option<ast::Expr>) -> ExprId { + match expr { + Some(expr) => self.collect_expr(expr), + None => self.missing_expr(), + } + } + /// Returns `None` if and only if the expression is `#[cfg]`d out. fn maybe_collect_expr(&mut self, expr: ast::Expr) -> Option<ExprId> { let syntax_ptr = AstPtr::new(&expr); @@ -1069,7 +1122,9 @@ impl<'db> ExprCollector<'db> { self.alloc_expr(Expr::Let { pat, expr }, syntax_ptr) } ast::Expr::BlockExpr(e) => match e.modifier() { - Some(ast::BlockModifier::Try(_)) => self.desugar_try_block(e), + Some(ast::BlockModifier::Try { try_token: _, bikeshed_token: _, result_type }) => { + self.desugar_try_block(e, result_type) + } Some(ast::BlockModifier::Unsafe(_)) => { self.collect_block_(e, |id, statements, tail| Expr::Unsafe { id, @@ -1344,7 +1399,7 @@ impl<'db> ExprCollector<'db> { .map(|it| this.lower_type_ref_disallow_impl_trait(it)); let prev_is_lowering_coroutine = mem::take(&mut this.is_lowering_coroutine); - let prev_try_block_label = this.current_try_block_label.take(); + let prev_try_block = this.current_try_block.take(); let awaitable = if e.async_token().is_some() { Awaitable::Yes @@ -1369,7 +1424,7 @@ impl<'db> ExprCollector<'db> { let capture_by = if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref }; this.is_lowering_coroutine = prev_is_lowering_coroutine; - this.current_try_block_label = prev_try_block_label; + this.current_try_block = prev_try_block; this.alloc_expr( Expr::Closure { args: args.into(), @@ -1686,11 +1741,15 @@ impl<'db> ExprCollector<'db> { /// Desugar `try { <stmts>; <expr> }` into `'<new_label>: { <stmts>; ::std::ops::Try::from_output(<expr>) }`, /// `try { <stmts>; }` into `'<new_label>: { <stmts>; ::std::ops::Try::from_output(()) }` /// and save the `<new_label>` to use it as a break target for desugaring of the `?` operator. - fn desugar_try_block(&mut self, e: BlockExpr) -> ExprId { + fn desugar_try_block(&mut self, e: BlockExpr, result_type: Option<ast::Type>) -> ExprId { let try_from_output = self.lang_path(self.lang_items().TryTraitFromOutput); let label = self.generate_new_name(); let label = self.alloc_label_desugared(Label { name: label }, AstPtr::new(&e).wrap_right()); - let old_label = self.current_try_block_label.replace(label); + let try_block_info = match result_type { + Some(_) => TryBlock::Heterogeneous { label }, + None => TryBlock::Homogeneous { label }, + }; + let old_try_block = self.current_try_block.replace(try_block_info); let ptr = AstPtr::new(&e).upcast(); let (btail, expr_id) = self.with_labeled_rib(label, HygieneId::ROOT, |this| { @@ -1720,8 +1779,38 @@ impl<'db> ExprCollector<'db> { unreachable!("block was lowered to non-block"); }; *tail = Some(next_tail); - self.current_try_block_label = old_label; - expr_id + self.current_try_block = old_try_block; + match result_type { + Some(ty) => { + // `{ let <name>: <ty> = <expr>; <name> }` + let name = self.generate_new_name(); + let type_ref = self.lower_type_ref_disallow_impl_trait(ty); + let binding = self.alloc_binding( + name.clone(), + BindingAnnotation::Unannotated, + HygieneId::ROOT, + ); + let pat = self.alloc_pat_desugared(Pat::Bind { id: binding, subpat: None }); + self.add_definition_to_binding(binding, pat); + let tail_expr = + self.alloc_expr_desugared_with_ptr(Expr::Path(Path::from(name)), ptr); + self.alloc_expr_desugared_with_ptr( + Expr::Block { + id: None, + statements: Box::new([Statement::Let { + pat, + type_ref: Some(type_ref), + initializer: Some(expr_id), + else_branch: None, + }]), + tail: Some(tail_expr), + label: None, + }, + ptr, + ) + } + None => expr_id, + } } /// Desugar `ast::WhileExpr` from: `[opt_ident]: while <cond> <body>` into: @@ -1863,6 +1952,8 @@ impl<'db> ExprCollector<'db> { /// ControlFlow::Continue(val) => val, /// ControlFlow::Break(residual) => /// // If there is an enclosing `try {...}`: + /// break 'catch_target Residual::into_try_type(residual), + /// // If there is an enclosing `try bikeshed Ty {...}`: /// break 'catch_target Try::from_residual(residual), /// // Otherwise: /// return Try::from_residual(residual), @@ -1873,7 +1964,6 @@ impl<'db> ExprCollector<'db> { let try_branch = self.lang_path(lang_items.TryTraitBranch); let cf_continue = self.lang_path(lang_items.ControlFlowContinue); let cf_break = self.lang_path(lang_items.ControlFlowBreak); - let try_from_residual = self.lang_path(lang_items.TryTraitFromResidual); let operand = self.collect_expr_opt(e.expr()); let try_branch = self.alloc_expr(try_branch.map_or(Expr::Missing, Expr::Path), syntax_ptr); let expr = self @@ -1910,13 +2000,23 @@ impl<'db> ExprCollector<'db> { guard: None, expr: { let it = self.alloc_expr(Expr::Path(Path::from(break_name)), syntax_ptr); - let callee = self - .alloc_expr(try_from_residual.map_or(Expr::Missing, Expr::Path), syntax_ptr); + let convert_fn = match self.current_try_block { + Some(TryBlock::Homogeneous { .. }) => { + self.lang_path(lang_items.ResidualIntoTryType) + } + Some(TryBlock::Heterogeneous { .. }) | None => { + self.lang_path(lang_items.TryTraitFromResidual) + } + }; + let callee = + self.alloc_expr(convert_fn.map_or(Expr::Missing, Expr::Path), syntax_ptr); let result = self.alloc_expr(Expr::Call { callee, args: Box::new([it]) }, syntax_ptr); self.alloc_expr( - match self.current_try_block_label { - Some(label) => Expr::Break { expr: Some(result), label: Some(label) }, + match self.current_try_block { + Some( + TryBlock::Heterogeneous { label } | TryBlock::Homogeneous { label }, + ) => Expr::Break { expr: Some(result), label: Some(label) }, None => Expr::Return { expr: Some(result) }, }, syntax_ptr, @@ -2001,13 +2101,6 @@ impl<'db> ExprCollector<'db> { } } - pub fn collect_expr_opt(&mut self, expr: Option<ast::Expr>) -> ExprId { - match expr { - Some(expr) => self.collect_expr(expr), - None => self.missing_expr(), - } - } - fn collect_macro_as_stmt( &mut self, statements: &mut Vec<Statement>, @@ -2194,6 +2287,32 @@ impl<'db> ExprCollector<'db> { } } + fn collect_extern_fn_param(&mut self, pat: Option<ast::Pat>) -> PatId { + // `extern` functions cannot have pattern-matched parameters, and furthermore, the identifiers + // in their parameters are always interpreted as bindings, even if in a normal function they + // won't be, because they would refer to a path pattern. + let Some(pat) = pat else { return self.missing_pat() }; + + match &pat { + ast::Pat::IdentPat(bp) => { + // FIXME: Emit an error if `!bp.is_simple_ident()`. + + let name = bp.name().map(|nr| nr.as_name()).unwrap_or_else(Name::missing); + let hygiene = bp + .name() + .map(|name| self.hygiene_id_for(name.syntax().text_range())) + .unwrap_or(HygieneId::ROOT); + let binding = self.alloc_binding(name, BindingAnnotation::Unannotated, hygiene); + let pat = + self.alloc_pat(Pat::Bind { id: binding, subpat: None }, AstPtr::new(&pat)); + self.add_definition_to_binding(binding, pat); + pat + } + // FIXME: Emit an error. + _ => self.missing_pat(), + } + } + // region: patterns fn collect_pat_top(&mut self, pat: Option<ast::Pat>) -> PatId { @@ -2242,7 +2361,7 @@ impl<'db> ExprCollector<'db> { } Some(ModuleDefId::AdtId(AdtId::StructId(s))) // FIXME: This can cause a cycle if the user is writing invalid code - if self.db.struct_signature(s).shape != FieldsShape::Record => + if StructSignature::of(self.db, s).shape != FieldsShape::Record => { (None, Pat::Path(name.into())) } diff --git a/crates/hir-def/src/expr_store/lower/generics.rs b/crates/hir-def/src/expr_store/lower/generics.rs index c570df42b2..5ffc4f5851 100644 --- a/crates/hir-def/src/expr_store/lower/generics.rs +++ b/crates/hir-def/src/expr_store/lower/generics.rs @@ -3,15 +3,12 @@ //! generic parameters. See also the `Generics` type and the `generics_of` query //! in rustc. -use std::sync::LazyLock; - use either::Either; use hir_expand::name::{AsName, Name}; use intern::sym; use la_arena::Arena; use syntax::ast::{self, HasName, HasTypeBounds}; use thin_vec::ThinVec; -use triomphe::Arc; use crate::{ GenericDefId, TypeOrConstParamId, TypeParamId, @@ -84,28 +81,16 @@ impl GenericParamsCollector { ) } - pub(crate) fn finish(self) -> Arc<GenericParams> { - let Self { mut lifetimes, mut type_or_consts, mut where_predicates, parent: _ } = self; - - if lifetimes.is_empty() && type_or_consts.is_empty() && where_predicates.is_empty() { - static EMPTY: LazyLock<Arc<GenericParams>> = LazyLock::new(|| { - Arc::new(GenericParams { - lifetimes: Arena::new(), - type_or_consts: Arena::new(), - where_predicates: Box::default(), - }) - }); - return Arc::clone(&EMPTY); - } + pub(crate) fn finish(self) -> GenericParams { + let Self { mut lifetimes, mut type_or_consts, where_predicates, parent: _ } = self; lifetimes.shrink_to_fit(); type_or_consts.shrink_to_fit(); - where_predicates.shrink_to_fit(); - Arc::new(GenericParams { + GenericParams { type_or_consts, lifetimes, where_predicates: where_predicates.into_boxed_slice(), - }) + } } fn lower_param_list(&mut self, ec: &mut ExprCollector<'_>, params: ast::GenericParamList) { @@ -141,12 +126,17 @@ impl GenericParamsCollector { const_param.ty(), &mut ExprCollector::impl_trait_error_allocator, ); - let param = ConstParamData { - name, - ty, - default: const_param.default_val().map(|it| ec.lower_const_arg(it)), - }; - let _idx = self.type_or_consts.alloc(param.into()); + let default = const_param.default_val().map(|it| ec.lower_const_arg(it)); + let param = ConstParamData { name, ty, default }; + let idx = self.type_or_consts.alloc(param.into()); + if let Some(default) = default + && let Some(const_expr_origins) = &mut ec.store.inference_roots + { + const_expr_origins.push(( + default.expr, + crate::expr_store::RootExprOrigin::ConstParam(idx), + )); + } } ast::GenericParam::LifetimeParam(lifetime_param) => { let lifetime = ec.lower_lifetime_ref_opt(lifetime_param.lifetime()); diff --git a/crates/hir-def/src/expr_store/lower/path/tests.rs b/crates/hir-def/src/expr_store/lower/path/tests.rs index f507841a91..6819eb3deb 100644 --- a/crates/hir-def/src/expr_store/lower/path/tests.rs +++ b/crates/hir-def/src/expr_store/lower/path/tests.rs @@ -21,7 +21,7 @@ fn lower_path(path: ast::Path) -> (TestDB, ExpressionStore, Option<Path>) { let (db, file_id) = TestDB::with_single_file(""); let krate = db.fetch_test_crate(); let mut ctx = - ExprCollector::new(&db, crate_def_map(&db, krate).root_module_id(), file_id.into()); + ExprCollector::signature(&db, crate_def_map(&db, krate).root_module_id(), file_id.into()); let lowered_path = ctx.lower_path(path, &mut ExprCollector::impl_trait_allocator); let (store, _) = ctx.store.finish(); (db, store, lowered_path) diff --git a/crates/hir-def/src/expr_store/pretty.rs b/crates/hir-def/src/expr_store/pretty.rs index 35f3cd114e..9c9c4db3b2 100644 --- a/crates/hir-def/src/expr_store/pretty.rs +++ b/crates/hir-def/src/expr_store/pretty.rs @@ -105,7 +105,7 @@ pub fn print_body_hir( p.buf.push(')'); p.buf.push(' '); } - p.print_expr(body.body_expr); + p.print_expr(body.root_expr()); if matches!(owner, DefWithBodyId::StaticId(_) | DefWithBodyId::ConstId(_)) { p.buf.push(';'); } @@ -168,8 +168,8 @@ pub fn print_signature(db: &dyn DefDatabase, owner: GenericDefId, edition: Editi match owner { GenericDefId::AdtId(id) => match id { AdtId::StructId(id) => { - let signature = db.struct_signature(id); - print_struct(db, id, &signature, edition) + let signature = StructSignature::of(db, id); + print_struct(db, id, signature, edition) } AdtId::UnionId(id) => { format!("unimplemented {id:?}") @@ -180,8 +180,8 @@ pub fn print_signature(db: &dyn DefDatabase, owner: GenericDefId, edition: Editi }, GenericDefId::ConstId(id) => format!("unimplemented {id:?}"), GenericDefId::FunctionId(id) => { - let signature = db.function_signature(id); - print_function(db, id, &signature, edition) + let signature = FunctionSignature::of(db, id); + print_function(db, id, signature, edition) } GenericDefId::ImplId(id) => format!("unimplemented {id:?}"), GenericDefId::StaticId(id) => format!("unimplemented {id:?}"), @@ -1212,7 +1212,7 @@ impl Printer<'_> { } pub(crate) fn print_type_param(&mut self, param: TypeParamId) { - let generic_params = self.db.generic_params(param.parent()); + let generic_params = GenericParams::of(self.db, param.parent()); match generic_params[param.local_id()].name() { Some(name) => w!(self, "{}", name.display(self.db, self.edition)), @@ -1221,7 +1221,7 @@ impl Printer<'_> { } pub(crate) fn print_lifetime_param(&mut self, param: LifetimeParamId) { - let generic_params = self.db.generic_params(param.parent); + let generic_params = GenericParams::of(self.db, param.parent); w!(self, "{}", generic_params[param.local_id].name.display(self.db, self.edition)) } diff --git a/crates/hir-def/src/expr_store/scope.rs b/crates/hir-def/src/expr_store/scope.rs index 1952dae9d7..40ae0b7de4 100644 --- a/crates/hir-def/src/expr_store/scope.rs +++ b/crates/hir-def/src/expr_store/scope.rs @@ -1,13 +1,16 @@ //! Name resolution for expressions. use hir_expand::{MacroDefId, name::Name}; use la_arena::{Arena, ArenaMap, Idx, IdxRange, RawIdx}; -use triomphe::Arc; use crate::{ - BlockId, DefWithBodyId, + BlockId, DefWithBodyId, ExpressionStoreOwnerId, GenericDefId, VariantId, db::DefDatabase, expr_store::{Body, ExpressionStore, HygieneId}, - hir::{Binding, BindingId, Expr, ExprId, Item, LabelId, Pat, PatId, Statement}, + hir::{ + Binding, BindingId, Expr, ExprId, Item, LabelId, Pat, PatId, Statement, + generics::GenericParams, + }, + signatures::VariantFields, }; pub type ScopeId = Idx<ScopeData>; @@ -50,12 +53,45 @@ pub struct ScopeData { entries: IdxRange<ScopeEntry>, } +#[salsa::tracked] impl ExprScopes { - pub(crate) fn expr_scopes_query(db: &dyn DefDatabase, def: DefWithBodyId) -> Arc<ExprScopes> { - let body = db.body(def); - let mut scopes = ExprScopes::new_body(&body); + #[salsa::tracked(returns(ref))] + pub fn body_expr_scopes(db: &dyn DefDatabase, def: DefWithBodyId) -> ExprScopes { + let body = Body::of(db, def); + let mut scopes = ExprScopes::new_body(body); scopes.shrink_to_fit(); - Arc::new(scopes) + scopes + } + + #[salsa::tracked(returns(ref))] + pub fn sig_expr_scopes(db: &dyn DefDatabase, def: GenericDefId) -> ExprScopes { + let (_, store) = GenericParams::with_store(db, def); + let roots = store.expr_roots(); + let mut scopes = ExprScopes::new_store(store, roots); + scopes.shrink_to_fit(); + scopes + } + + #[salsa::tracked(returns(ref))] + pub fn variant_scopes(db: &dyn DefDatabase, def: VariantId) -> ExprScopes { + let fields = VariantFields::of(db, def); + let roots = fields.store.expr_roots(); + let mut scopes = ExprScopes::new_store(&fields.store, roots); + scopes.shrink_to_fit(); + scopes + } +} + +impl ExprScopes { + #[inline] + pub fn of(db: &dyn DefDatabase, def: impl Into<ExpressionStoreOwnerId>) -> &ExprScopes { + match def.into() { + ExpressionStoreOwnerId::Body(def) => Self::body_expr_scopes(db, def), + ExpressionStoreOwnerId::Signature(def) => Self::sig_expr_scopes(db, def), + ExpressionStoreOwnerId::VariantFields(variant_id) => { + Self::variant_scopes(db, variant_id) + } + } } pub fn entries(&self, scope: ScopeId) -> &[ScopeEntry] { @@ -115,7 +151,23 @@ impl ExprScopes { scopes.add_bindings(body, root, self_param, body.binding_hygiene(self_param)); } scopes.add_params_bindings(body, root, &body.params); - compute_expr_scopes(body.body_expr, body, &mut scopes, &mut root); + compute_expr_scopes(body.root_expr(), body, &mut scopes, &mut root); + scopes + } + + fn new_store(store: &ExpressionStore, roots: impl IntoIterator<Item = ExprId>) -> ExprScopes { + let mut scopes = ExprScopes { + scopes: Arena::default(), + scope_entries: Arena::default(), + scope_by_expr: ArenaMap::with_capacity( + store.expr_only.as_ref().map_or(0, |it| it.exprs.len()), + ), + }; + let root = scopes.root_scope(); + for root_expr in roots { + let mut scope = scopes.new_scope(root); + compute_expr_scopes(root_expr, store, &mut scopes, &mut scope); + } scopes } @@ -327,7 +379,10 @@ mod tests { use test_utils::{assert_eq_text, extract_offset}; use crate::{ - FunctionId, ModuleDefId, db::DefDatabase, nameres::crate_def_map, test_db::TestDB, + DefWithBodyId, FunctionId, ModuleDefId, + expr_store::{Body, scope::ExprScopes}, + nameres::crate_def_map, + test_db::TestDB, }; fn find_function(db: &TestDB, file_id: FileId) -> FunctionId { @@ -363,8 +418,8 @@ mod tests { let marker: ast::PathExpr = find_node_at_offset(&file_syntax, offset).unwrap(); let function = find_function(&db, file_id); - let scopes = db.expr_scopes(function.into()); - let (_body, source_map) = db.body_with_source_map(function.into()); + let scopes = ExprScopes::of(&db, DefWithBodyId::from(function)); + let (_body, source_map) = Body::with_source_map(&db, function.into()); let expr_id = source_map .node_expr(InFile { file_id: editioned_file_id.into(), value: &marker.into() }) @@ -522,8 +577,8 @@ fn foo() { let function = find_function(&db, file_id); - let scopes = db.expr_scopes(function.into()); - let (_, source_map) = db.body_with_source_map(function.into()); + let scopes = ExprScopes::body_expr_scopes(&db, DefWithBodyId::from(function)); + let (_, source_map) = Body::with_source_map(&db, function.into()); let expr_scope = { let expr_ast = name_ref.syntax().ancestors().find_map(ast::Expr::cast).unwrap(); diff --git a/crates/hir-def/src/expr_store/tests/body.rs b/crates/hir-def/src/expr_store/tests/body.rs index 8f857aeeff..985cd96662 100644 --- a/crates/hir-def/src/expr_store/tests/body.rs +++ b/crates/hir-def/src/expr_store/tests/body.rs @@ -4,11 +4,10 @@ use crate::{DefWithBodyId, ModuleDefId, hir::MatchArm, nameres::crate_def_map, t use expect_test::{Expect, expect}; use la_arena::RawIdx; use test_fixture::WithFixture; -use triomphe::Arc; use super::super::*; -fn lower(#[rust_analyzer::rust_fixture] ra_fixture: &str) -> (TestDB, Arc<Body>, DefWithBodyId) { +fn lower(#[rust_analyzer::rust_fixture] ra_fixture: &str) -> (TestDB, DefWithBodyId) { let db = TestDB::with_files(ra_fixture); let krate = db.fetch_test_crate(); @@ -24,8 +23,27 @@ fn lower(#[rust_analyzer::rust_fixture] ra_fixture: &str) -> (TestDB, Arc<Body>, } let fn_def = fn_def.unwrap().into(); - let body = db.body(fn_def); - (db, body, fn_def) + Body::of(&db, fn_def); + (db, fn_def) +} + +fn pretty_print(#[rust_analyzer::rust_fixture] ra_fixture: &str, expect: Expect) { + let db = TestDB::with_files(ra_fixture); + + let krate = db.fetch_test_crate(); + let def_map = crate_def_map(&db, krate); + let mut fn_def = None; + 'outer: for (_, module) in def_map.modules() { + for decl in module.scope.declarations() { + if let ModuleDefId::FunctionId(it) = decl { + fn_def = Some(it); + break 'outer; + } + } + } + let fn_def = fn_def.unwrap().into(); + + expect.assert_eq(&Body::of(&db, fn_def).pretty_print(&db, fn_def, Edition::CURRENT)); } fn def_map_at(#[rust_analyzer::rust_fixture] ra_fixture: &str) -> String { @@ -144,7 +162,7 @@ mod m { #[test] fn desugar_for_loop() { - let (db, body, def) = lower( + pretty_print( r#" //- minicore: iterator fn main() { @@ -154,9 +172,7 @@ fn main() { } } "#, - ); - - expect![[r#" + expect![[r#" fn main() { match builtin#lang(into_iter)( 0..10, @@ -173,13 +189,13 @@ fn main() { } }, } - }"#]] - .assert_eq(&body.pretty_print(&db, def, Edition::CURRENT)) + }"#]], + ); } #[test] fn desugar_builtin_format_args_before_1_89_0() { - let (db, body, def) = lower( + pretty_print( r#" //- minicore: fmt_before_1_89_0 fn main() { @@ -188,9 +204,7 @@ fn main() { builtin#format_args("\u{1b}hello {count:02} {} friends, we {are:?} {0}{last}", "fancy", orphan = (), last = "!"); } "#, - ); - - expect![[r#" + expect![[r#" fn main() { let are = "are"; let count = 10; @@ -256,13 +270,13 @@ fn main() { } }, ); - }"#]] - .assert_eq(&body.pretty_print(&db, def, Edition::CURRENT)) + }"#]], + ) } #[test] fn desugar_builtin_format_args_before_1_93_0() { - let (db, body, def) = lower( + pretty_print( r#" //- minicore: fmt_before_1_93_0 fn main() { @@ -271,9 +285,7 @@ fn main() { builtin#format_args("\u{1b}hello {count:02} {} friends, we {are:?} {0}{last}", "fancy", orphan = (), last = "!"); } "#, - ); - - expect![[r#" + expect![[r#" fn main() { let are = "are"; let count = 10; @@ -339,13 +351,13 @@ fn main() { ) } }; - }"#]] - .assert_eq(&body.pretty_print(&db, def, Edition::CURRENT)) + }"#]], + ) } #[test] fn desugar_builtin_format_args() { - let (db, body, def) = lower( + pretty_print( r#" //- minicore: fmt fn main() { @@ -356,9 +368,7 @@ fn main() { builtin#format_args("hello world", orphan = ()); } "#, - ); - - expect![[r#" + expect![[r#" fn main() { let are = "are"; let count = 10; @@ -392,13 +402,13 @@ fn main() { "hello world", ) }; - }"#]] - .assert_eq(&body.pretty_print(&db, def, Edition::CURRENT)) + }"#]], + ) } #[test] fn test_macro_hygiene() { - let (db, body, def) = lower( + pretty_print( r##" //- minicore: fmt, from //- /main.rs @@ -428,10 +438,7 @@ impl SsrError { } } "##, - ); - - assert_eq!(db.body_with_source_map(def).1.diagnostics(), &[]); - expect![[r#" + expect![[r#" fn main() { _ = ra_test_fixture::error::SsrError::new( { @@ -449,13 +456,13 @@ impl SsrError { } }, ); - }"#]] - .assert_eq(&body.pretty_print(&db, def, Edition::CURRENT)) + }"#]], + ) } #[test] fn regression_10300() { - let (db, body, def) = lower( + pretty_print( r#" //- minicore: concat, panic, fmt_before_1_89_0 mod private { @@ -472,16 +479,7 @@ fn f(a: i32, b: u32) -> String { m!(); } "#, - ); - - let (_, source_map) = db.body_with_source_map(def); - assert_eq!(source_map.diagnostics(), &[]); - - for (_, def_map) in body.blocks(&db) { - assert_eq!(def_map.diagnostics(), &[]); - } - - expect![[r#" + expect![[r#" fn f(a, b) { { core::panicking::panic_fmt( @@ -497,8 +495,8 @@ fn f(a: i32, b: u32) -> String { ), ); }; - }"#]] - .assert_eq(&body.pretty_print(&db, def, Edition::CURRENT)) + }"#]], + ) } #[test] @@ -507,7 +505,7 @@ fn destructuring_assignment_tuple_macro() { // but in destructuring assignment it is valid, because `m!()()` is a valid expression, and destructuring // assignments start their lives as expressions. So we have to do the same. - let (db, body, def) = lower( + pretty_print( r#" struct Bar(); @@ -519,25 +517,16 @@ fn foo() { m!()() = Bar(); } "#, - ); - - let (_, source_map) = db.body_with_source_map(def); - assert_eq!(source_map.diagnostics(), &[]); - - for (_, def_map) in body.blocks(&db) { - assert_eq!(def_map.diagnostics(), &[]); - } - - expect![[r#" + expect![[r#" fn foo() { Bar() = Bar(); - }"#]] - .assert_eq(&body.pretty_print(&db, def, Edition::CURRENT)) + }"#]], + ) } #[test] fn shadowing_record_variant() { - let (_, body, _) = lower( + let (db, def) = lower( r#" enum A { B { field: i32 }, @@ -550,6 +539,7 @@ fn f() { } "#, ); + let body = Body::of(&db, def); assert_eq!(body.assert_expr_only().bindings.len(), 1, "should have a binding for `B`"); assert_eq!( body[BindingId::from_raw(RawIdx::from_u32(0))].name.as_str(), @@ -560,39 +550,35 @@ fn f() { #[test] fn regression_pretty_print_bind_pat() { - let (db, body, owner) = lower( + pretty_print( r#" fn foo() { let v @ u = 123; } "#, - ); - let printed = body.pretty_print(&db, owner, Edition::CURRENT); - - expect![[r#" + expect![[r#" fn foo() { let v @ u = 123; - }"#]] - .assert_eq(&printed); + }"#]], + ); } #[test] fn skip_skips_body() { - let (db, body, owner) = lower( + pretty_print( r#" #[rust_analyzer::skip] async fn foo(a: (), b: i32) -> u32 { 0 + 1 + b() } "#, + expect!["fn foo(�, �) �"], ); - let printed = body.pretty_print(&db, owner, Edition::CURRENT); - expect!["fn foo(�, �) �"].assert_eq(&printed); } #[test] fn range_bounds_are_hir_exprs() { - let (_, body, _) = lower( + let (db, body) = lower( r#" pub const L: i32 = 6; mod x { @@ -607,6 +593,7 @@ const fn f(x: i32) -> i32 { }"#, ); + let body = Body::of(&db, body); let mtch_arms = body .assert_expr_only() .exprs @@ -635,7 +622,7 @@ const fn f(x: i32) -> i32 { #[test] fn print_hir_precedences() { - let (db, body, def) = lower( + pretty_print( r#" fn main() { _ = &(1 - (2 - 3) + 4 * 5 * (6 + 7)); @@ -646,9 +633,7 @@ fn main() { let _ = &mut (*r as i32) } "#, - ); - - expect![[r#" + expect![[r#" fn main() { _ = &((1 - (2 - 3)) + (4 * 5) * (6 + 7)); _ = 1 + 2 < 3 && true && 4 < 5 && (a || b || c) || d && e; @@ -656,24 +641,22 @@ fn main() { break a && b || (return) || (return 2); let r = &2; let _ = &mut (*r as i32); - }"#]] - .assert_eq(&body.pretty_print(&db, def, Edition::CURRENT)) + }"#]], + ) } #[test] fn async_fn_weird_param_patterns() { - let (db, body, def) = lower( + pretty_print( r#" async fn main(&self, param1: i32, ref mut param2: i32, _: i32, param4 @ _: i32, 123: i32) {} "#, - ); - - expect![[r#" + expect![[r#" fn main(self, param1, mut param2, mut <ra@gennew>0, param4 @ _, mut <ra@gennew>1) async { let ref mut param2 = param2; let _ = <ra@gennew>0; let 123 = <ra@gennew>1; {} - }"#]] - .assert_eq(&body.pretty_print(&db, def, Edition::CURRENT)) + }"#]], + ) } diff --git a/crates/hir-def/src/expr_store/tests/body/block.rs b/crates/hir-def/src/expr_store/tests/body/block.rs index d457a4ca7a..83594ee021 100644 --- a/crates/hir-def/src/expr_store/tests/body/block.rs +++ b/crates/hir-def/src/expr_store/tests/body/block.rs @@ -196,7 +196,7 @@ fn f() { ), block: Some( BlockId( - 4401, + 4801, ), ), }"#]], diff --git a/crates/hir-def/src/expr_store/tests/signatures.rs b/crates/hir-def/src/expr_store/tests/signatures.rs index f1db00cf6a..5e0184dfad 100644 --- a/crates/hir-def/src/expr_store/tests/signatures.rs +++ b/crates/hir-def/src/expr_store/tests/signatures.rs @@ -2,6 +2,7 @@ use crate::{ GenericDefId, ModuleDefId, expr_store::pretty::{print_function, print_struct}, nameres::crate_def_map, + signatures::{FunctionSignature, StructSignature}, test_db::TestDB, }; use expect_test::{Expect, expect}; @@ -41,7 +42,7 @@ fn lower_and_print(#[rust_analyzer::rust_fixture] ra_fixture: &str, expect: Expe out += &print_struct( &db, struct_id, - &db.struct_signature(struct_id), + StructSignature::of(&db, struct_id), Edition::CURRENT, ); } @@ -53,7 +54,7 @@ fn lower_and_print(#[rust_analyzer::rust_fixture] ra_fixture: &str, expect: Expe out += &print_function( &db, function_id, - &db.function_signature(function_id), + FunctionSignature::of(&db, function_id), Edition::CURRENT, ) } diff --git a/crates/hir-def/src/find_path.rs b/crates/hir-def/src/find_path.rs index 5d1cac8e93..8308203693 100644 --- a/crates/hir-def/src/find_path.rs +++ b/crates/hir-def/src/find_path.rs @@ -14,6 +14,7 @@ use rustc_hash::FxHashSet; use crate::{ FindPathConfig, ModuleDefId, ModuleId, db::DefDatabase, + import_map::ImportMap, item_scope::ItemInNs, nameres::DefMap, visibility::{Visibility, VisibilityExplicitness}, @@ -426,7 +427,7 @@ fn find_in_dep( best_choice: &mut Option<Choice>, dep: Crate, ) { - let import_map = ctx.db.import_map(dep); + let import_map = ImportMap::of(ctx.db, dep); let Some(import_info_for) = import_map.import_info_for(item) else { return; }; diff --git a/crates/hir-def/src/hir/generics.rs b/crates/hir-def/src/hir/generics.rs index 482cf36f95..43dd7d1c54 100644 --- a/crates/hir-def/src/hir/generics.rs +++ b/crates/hir-def/src/hir/generics.rs @@ -5,12 +5,15 @@ use hir_expand::name::Name; use la_arena::{Arena, Idx, RawIdx}; use stdx::impl_from; use thin_vec::ThinVec; -use triomphe::Arc; use crate::{ AdtId, ConstParamId, GenericDefId, LifetimeParamId, TypeOrConstParamId, TypeParamId, db::DefDatabase, expr_store::{ExpressionStore, ExpressionStoreSourceMap}, + signatures::{ + ConstSignature, EnumSignature, FunctionSignature, ImplSignature, StaticSignature, + StructSignature, TraitSignature, TypeAliasSignature, UnionSignature, + }, type_ref::{ConstRef, LifetimeRefId, TypeBound, TypeRefId}, }; @@ -142,7 +145,7 @@ pub enum GenericParamDataRef<'a> { } /// Data about the generic parameters of a function, struct, impl, etc. -#[derive(Clone, PartialEq, Eq, Debug, Hash)] +#[derive(PartialEq, Eq, Debug, Hash, Default)] pub struct GenericParams { pub(crate) type_or_consts: Arena<TypeOrConstParamData>, pub(crate) lifetimes: Arena<LifetimeParamData>, @@ -174,125 +177,105 @@ pub enum WherePredicate { ForLifetime { lifetimes: ThinVec<Name>, target: TypeRefId, bound: TypeBound }, } -static EMPTY: LazyLock<Arc<GenericParams>> = LazyLock::new(|| { - Arc::new(GenericParams { - type_or_consts: Arena::default(), - lifetimes: Arena::default(), - where_predicates: Box::default(), - }) +static EMPTY: LazyLock<GenericParams> = LazyLock::new(|| GenericParams { + type_or_consts: Arena::default(), + lifetimes: Arena::default(), + where_predicates: Box::default(), }); impl GenericParams { /// The index of the self param in the generic of the non-parent definition. - pub(crate) const SELF_PARAM_ID_IN_SELF: la_arena::Idx<TypeOrConstParamData> = + pub const SELF_PARAM_ID_IN_SELF: la_arena::Idx<TypeOrConstParamData> = LocalTypeOrConstParamId::from_raw(RawIdx::from_u32(0)); - pub fn new(db: &dyn DefDatabase, def: GenericDefId) -> Arc<GenericParams> { - match def { - GenericDefId::AdtId(AdtId::EnumId(it)) => db.enum_signature(it).generic_params.clone(), - GenericDefId::AdtId(AdtId::StructId(it)) => { - db.struct_signature(it).generic_params.clone() - } - GenericDefId::AdtId(AdtId::UnionId(it)) => { - db.union_signature(it).generic_params.clone() - } - GenericDefId::ConstId(_) => EMPTY.clone(), - GenericDefId::FunctionId(function_id) => { - db.function_signature(function_id).generic_params.clone() - } - GenericDefId::ImplId(impl_id) => db.impl_signature(impl_id).generic_params.clone(), - GenericDefId::StaticId(_) => EMPTY.clone(), - GenericDefId::TraitId(trait_id) => db.trait_signature(trait_id).generic_params.clone(), - GenericDefId::TypeAliasId(type_alias_id) => { - db.type_alias_signature(type_alias_id).generic_params.clone() - } - } + pub fn of(db: &dyn DefDatabase, def: GenericDefId) -> &GenericParams { + Self::with_store(db, def).0 } - pub fn generic_params_and_store( + pub fn with_store( db: &dyn DefDatabase, def: GenericDefId, - ) -> (Arc<GenericParams>, Arc<ExpressionStore>) { + ) -> (&GenericParams, &ExpressionStore) { match def { GenericDefId::AdtId(AdtId::EnumId(id)) => { - let sig = db.enum_signature(id); - (sig.generic_params.clone(), sig.store.clone()) + let sig = EnumSignature::of(db, id); + (&sig.generic_params, &sig.store) } GenericDefId::AdtId(AdtId::StructId(id)) => { - let sig = db.struct_signature(id); - (sig.generic_params.clone(), sig.store.clone()) + let sig = StructSignature::of(db, id); + (&sig.generic_params, &sig.store) } GenericDefId::AdtId(AdtId::UnionId(id)) => { - let sig = db.union_signature(id); - (sig.generic_params.clone(), sig.store.clone()) + let sig = UnionSignature::of(db, id); + (&sig.generic_params, &sig.store) } GenericDefId::ConstId(id) => { - let sig = db.const_signature(id); - (EMPTY.clone(), sig.store.clone()) + let sig = ConstSignature::of(db, id); + (&EMPTY, &sig.store) } GenericDefId::FunctionId(id) => { - let sig = db.function_signature(id); - (sig.generic_params.clone(), sig.store.clone()) + let sig = FunctionSignature::of(db, id); + (&sig.generic_params, &sig.store) } GenericDefId::ImplId(id) => { - let sig = db.impl_signature(id); - (sig.generic_params.clone(), sig.store.clone()) + let sig = ImplSignature::of(db, id); + (&sig.generic_params, &sig.store) } GenericDefId::StaticId(id) => { - let sig = db.static_signature(id); - (EMPTY.clone(), sig.store.clone()) + let sig = StaticSignature::of(db, id); + (&EMPTY, &sig.store) } GenericDefId::TraitId(id) => { - let sig = db.trait_signature(id); - (sig.generic_params.clone(), sig.store.clone()) + let sig = TraitSignature::of(db, id); + (&sig.generic_params, &sig.store) } GenericDefId::TypeAliasId(id) => { - let sig = db.type_alias_signature(id); - (sig.generic_params.clone(), sig.store.clone()) + let sig = TypeAliasSignature::of(db, id); + (&sig.generic_params, &sig.store) } } } - pub fn generic_params_and_store_and_source_map( + pub fn with_source_map( db: &dyn DefDatabase, def: GenericDefId, - ) -> (Arc<GenericParams>, Arc<ExpressionStore>, Arc<ExpressionStoreSourceMap>) { + ) -> (&GenericParams, &ExpressionStore, &ExpressionStoreSourceMap) { match def { GenericDefId::AdtId(AdtId::EnumId(id)) => { - let (sig, sm) = db.enum_signature_with_source_map(id); - (sig.generic_params.clone(), sig.store.clone(), sm) + let (sig, sm) = EnumSignature::with_source_map(db, id); + (&sig.generic_params, &sig.store, sm) } GenericDefId::AdtId(AdtId::StructId(id)) => { - let (sig, sm) = db.struct_signature_with_source_map(id); - (sig.generic_params.clone(), sig.store.clone(), sm) + let (sig, sm) = StructSignature::with_source_map(db, id); + (&sig.generic_params, &sig.store, sm) } GenericDefId::AdtId(AdtId::UnionId(id)) => { - let (sig, sm) = db.union_signature_with_source_map(id); - (sig.generic_params.clone(), sig.store.clone(), sm) + let (sig, sm) = UnionSignature::with_source_map(db, id); + (&sig.generic_params, &sig.store, sm) } GenericDefId::ConstId(id) => { - let (sig, sm) = db.const_signature_with_source_map(id); - (EMPTY.clone(), sig.store.clone(), sm) + let (sig, sm) = ConstSignature::with_source_map(db, id); + (&EMPTY, &sig.store, sm) } GenericDefId::FunctionId(id) => { - let (sig, sm) = db.function_signature_with_source_map(id); - (sig.generic_params.clone(), sig.store.clone(), sm) + let (sig, sm) = FunctionSignature::with_source_map(db, id); + (&sig.generic_params, &sig.store, sm) } GenericDefId::ImplId(id) => { - let (sig, sm) = db.impl_signature_with_source_map(id); - (sig.generic_params.clone(), sig.store.clone(), sm) + let (sig, sm) = ImplSignature::with_source_map(db, id); + (&sig.generic_params, &sig.store, sm) } GenericDefId::StaticId(id) => { - let (sig, sm) = db.static_signature_with_source_map(id); - (EMPTY.clone(), sig.store.clone(), sm) + let (sig, sm) = StaticSignature::with_source_map(db, id); + (&EMPTY, &sig.store, sm) } GenericDefId::TraitId(id) => { - let (sig, sm) = db.trait_signature_with_source_map(id); - (sig.generic_params.clone(), sig.store.clone(), sm) + let (sig, sm) = TraitSignature::with_source_map(db, id); + (&sig.generic_params, &sig.store, sm) } GenericDefId::TypeAliasId(id) => { - let (sig, sm) = db.type_alias_signature_with_source_map(id); - (sig.generic_params.clone(), sig.store.clone(), sm) + let (sig, sm) = TypeAliasSignature::with_source_map(db, id); + (&sig.generic_params, &sig.store, sm) } } } diff --git a/crates/hir-def/src/import_map.rs b/crates/hir-def/src/import_map.rs index 6c5d226cac..0014e1af5c 100644 --- a/crates/hir-def/src/import_map.rs +++ b/crates/hir-def/src/import_map.rs @@ -10,7 +10,6 @@ use rustc_hash::FxHashSet; use smallvec::SmallVec; use span::Edition; use stdx::format_to; -use triomphe::Arc; use crate::{ AssocItemId, AttrDefId, Complete, FxIndexMap, ModuleDefId, ModuleId, TraitId, @@ -63,6 +62,14 @@ enum IsTraitAssocItem { type ImportMapIndex = FxIndexMap<ItemInNs, (SmallVec<[ImportInfo; 1]>, IsTraitAssocItem)>; +#[salsa::tracked] +impl ImportMap { + #[salsa::tracked(returns(ref))] + pub fn of(db: &dyn DefDatabase, krate: Crate) -> Self { + Self::import_map_query_impl(db, krate) + } +} + impl ImportMap { pub fn dump(&self, db: &dyn DefDatabase) -> String { let mut out = String::new(); @@ -76,7 +83,7 @@ impl ImportMap { out } - pub(crate) fn import_map_query(db: &dyn DefDatabase, krate: Crate) -> Arc<Self> { + fn import_map_query_impl(db: &dyn DefDatabase, krate: Crate) -> Self { let _p = tracing::info_span!("import_map_query").entered(); let map = Self::collect_import_map(db, krate); @@ -120,7 +127,7 @@ impl ImportMap { } let importables = importables.into_iter().map(|(item, _, idx)| (item, idx)).collect(); - Arc::new(ImportMap { item_to_info_map: map, fst: builder.into_map(), importables }) + ImportMap { item_to_info_map: map, fst: builder.into_map(), importables } } pub fn import_info_for(&self, item: ItemInNs) -> Option<&[ImportInfo]> { @@ -424,7 +431,7 @@ pub fn search_dependencies( let _p = tracing::info_span!("search_dependencies", ?query).entered(); let import_maps: Vec<_> = - krate.data(db).dependencies.iter().map(|dep| db.import_map(dep.crate_id)).collect(); + krate.data(db).dependencies.iter().map(|dep| ImportMap::of(db, dep.crate_id)).collect(); let mut op = fst::map::OpBuilder::new(); @@ -458,7 +465,7 @@ pub fn search_dependencies( fn search_maps( _db: &dyn DefDatabase, - import_maps: &[Arc<ImportMap>], + import_maps: &[&ImportMap], mut stream: fst::map::Union<'_>, query: &Query, ) -> FxHashSet<(ItemInNs, Complete)> { @@ -467,7 +474,7 @@ fn search_maps( for &IndexedValue { index: import_map_idx, value } in indexed_values { let end = (value & 0xFFFF_FFFF) as usize; let start = (value >> 32) as usize; - let ImportMap { item_to_info_map, importables, .. } = &*import_maps[import_map_idx]; + let ImportMap { item_to_info_map, importables, .. } = import_maps[import_map_idx]; let importables = &importables[start..end]; let iter = importables @@ -546,9 +553,9 @@ mod tests { .into_iter() .filter_map(|(dependency, _)| { let dependency_krate = dependency.krate(&db)?; - let dependency_imports = db.import_map(dependency_krate); + let dependency_imports = ImportMap::of(&db, dependency_krate); - let (path, mark) = match assoc_item_path(&db, &dependency_imports, dependency) { + let (path, mark) = match assoc_item_path(&db, dependency_imports, dependency) { Some(assoc_item_path) => (assoc_item_path, "a"), None => ( render_path(&db, &dependency_imports.import_info_for(dependency)?[0]), @@ -618,7 +625,7 @@ mod tests { let cdata = &krate.extra_data(&db); let name = cdata.display_name.as_ref()?; - let map = db.import_map(krate); + let map = ImportMap::of(&db, krate); Some(format!("{name}:\n{}\n", map.fmt_for_test(&db))) }) diff --git a/crates/hir-def/src/item_scope.rs b/crates/hir-def/src/item_scope.rs index 9e1efb9777..b11a8bcd90 100644 --- a/crates/hir-def/src/item_scope.rs +++ b/crates/hir-def/src/item_scope.rs @@ -483,6 +483,11 @@ impl ItemScope { self.declarations.push(def) } + pub(crate) fn remove_from_value_ns(&mut self, name: &Name, def: ModuleDefId) { + let entry = self.values.shift_remove(name); + assert!(entry.is_some_and(|entry| entry.def == def)) + } + pub(crate) fn get_legacy_macro(&self, name: &Name) -> Option<&[MacroId]> { self.legacy_macros.get(name).map(|it| &**it) } @@ -893,6 +898,24 @@ impl ItemScope { self.macros.get_mut(name).expect("tried to update visibility of non-existent macro"); res.vis = vis; } + + pub(crate) fn update_def_types(&mut self, name: &Name, def: ModuleDefId, vis: Visibility) { + let res = self.types.get_mut(name).expect("tried to update def of non-existent type"); + res.def = def; + res.vis = vis; + } + + pub(crate) fn update_def_values(&mut self, name: &Name, def: ModuleDefId, vis: Visibility) { + let res = self.values.get_mut(name).expect("tried to update def of non-existent value"); + res.def = def; + res.vis = vis; + } + + pub(crate) fn update_def_macros(&mut self, name: &Name, def: MacroId, vis: Visibility) { + let res = self.macros.get_mut(name).expect("tried to update def of non-existent macro"); + res.def = def; + res.vis = vis; + } } impl PerNs { diff --git a/crates/hir-def/src/item_tree.rs b/crates/hir-def/src/item_tree.rs index a1707f17be..e7ab2b390f 100644 --- a/crates/hir-def/src/item_tree.rs +++ b/crates/hir-def/src/item_tree.rs @@ -44,6 +44,7 @@ use std::{ }; use ast::{AstNode, StructKind}; +use base_db::Crate; use cfg::CfgOptions; use hir_expand::{ ExpandTo, HirFileId, @@ -121,21 +122,23 @@ fn lower_extra_crate_attrs<'a>( } #[salsa_macros::tracked(returns(deref))] -pub(crate) fn file_item_tree_query(db: &dyn DefDatabase, file_id: HirFileId) -> Arc<ItemTree> { +pub(crate) fn file_item_tree_query( + db: &dyn DefDatabase, + file_id: HirFileId, + krate: Crate, +) -> Arc<ItemTree> { let _p = tracing::info_span!("file_item_tree_query", ?file_id).entered(); static EMPTY: OnceLock<Arc<ItemTree>> = OnceLock::new(); - let ctx = lower::Ctx::new(db, file_id); + let ctx = lower::Ctx::new(db, file_id, krate); let syntax = db.parse_or_expand(file_id); let mut item_tree = match_ast! { match syntax { ast::SourceFile(file) => { - let krate = file_id.krate(db); let root_file_id = krate.root_file_id(db); let extra_top_attrs = (file_id == root_file_id).then(|| { parse_extra_crate_attrs(db, krate).map(|crate_attrs| { - let file_id = root_file_id.editioned_file_id(db); - lower_extra_crate_attrs(db, crate_attrs, file_id, &|| ctx.cfg_options()) + lower_extra_crate_attrs(db, crate_attrs, root_file_id.span_file_id(db), &|| ctx.cfg_options()) }) }).flatten(); let top_attrs = match extra_top_attrs { @@ -189,41 +192,22 @@ pub(crate) fn file_item_tree_query(db: &dyn DefDatabase, file_id: HirFileId) -> } } -#[salsa_macros::tracked(returns(deref))] -pub(crate) fn block_item_tree_query(db: &dyn DefDatabase, block: BlockId) -> Arc<ItemTree> { +#[salsa_macros::tracked(returns(ref))] +pub(crate) fn block_item_tree_query( + db: &dyn DefDatabase, + block: BlockId, + krate: Crate, +) -> ItemTree { let _p = tracing::info_span!("block_item_tree_query", ?block).entered(); - static EMPTY: OnceLock<Arc<ItemTree>> = OnceLock::new(); - let loc = block.lookup(db); let block = loc.ast_id.to_node(db); - let ctx = lower::Ctx::new(db, loc.ast_id.file_id); + let ctx = lower::Ctx::new(db, loc.ast_id.file_id, krate); let mut item_tree = ctx.lower_block(&block); - let ItemTree { top_level, top_attrs, attrs, vis, big_data, small_data } = &item_tree; - if small_data.is_empty() - && big_data.is_empty() - && top_level.is_empty() - && attrs.is_empty() - && top_attrs.is_empty() - && vis.arena.is_empty() - { - EMPTY - .get_or_init(|| { - Arc::new(ItemTree { - top_level: Box::new([]), - attrs: FxHashMap::default(), - small_data: FxHashMap::default(), - big_data: FxHashMap::default(), - top_attrs: AttrsOrCfg::empty(), - vis: ItemVisibilities { arena: ThinVec::new() }, - }) - }) - .clone() - } else { - item_tree.shrink_to_fit(); - Arc::new(item_tree) - } + item_tree.shrink_to_fit(); + item_tree } + /// The item tree of a source file. #[derive(Debug, Default, Eq, PartialEq)] pub struct ItemTree { @@ -356,10 +340,10 @@ impl TreeId { Self { file, block } } - pub(crate) fn item_tree<'db>(&self, db: &'db dyn DefDatabase) -> &'db ItemTree { + pub(crate) fn item_tree<'db>(&self, db: &'db dyn DefDatabase, krate: Crate) -> &'db ItemTree { match self.block { - Some(block) => block_item_tree_query(db, block), - None => file_item_tree_query(db, self.file), + Some(block) => block_item_tree_query(db, block, krate), + None => file_item_tree_query(db, self.file, krate), } } diff --git a/crates/hir-def/src/item_tree/lower.rs b/crates/hir-def/src/item_tree/lower.rs index 3f19e00154..31e409d86e 100644 --- a/crates/hir-def/src/item_tree/lower.rs +++ b/crates/hir-def/src/item_tree/lower.rs @@ -2,7 +2,7 @@ use std::cell::OnceCell; -use base_db::FxIndexSet; +use base_db::{Crate, FxIndexSet}; use cfg::CfgOptions; use hir_expand::{ HirFileId, @@ -16,7 +16,6 @@ use syntax::{ AstNode, ast::{self, HasModuleItem, HasName}, }; -use triomphe::Arc; use crate::{ db::DefDatabase, @@ -29,19 +28,20 @@ use crate::{ }, }; -pub(super) struct Ctx<'a> { - pub(super) db: &'a dyn DefDatabase, +pub(super) struct Ctx<'db> { + pub(super) db: &'db dyn DefDatabase, tree: ItemTree, - source_ast_id_map: Arc<AstIdMap>, + source_ast_id_map: &'db AstIdMap, span_map: OnceCell<SpanMap>, file: HirFileId, - cfg_options: OnceCell<&'a CfgOptions>, + cfg_options: OnceCell<&'db CfgOptions>, + krate: Crate, top_level: Vec<ModItemId>, visibilities: FxIndexSet<RawVisibility>, } -impl<'a> Ctx<'a> { - pub(super) fn new(db: &'a dyn DefDatabase, file: HirFileId) -> Self { +impl<'db> Ctx<'db> { + pub(super) fn new(db: &'db dyn DefDatabase, file: HirFileId, krate: Crate) -> Self { Self { db, tree: ItemTree::default(), @@ -51,12 +51,13 @@ impl<'a> Ctx<'a> { span_map: OnceCell::new(), visibilities: FxIndexSet::default(), top_level: Vec::new(), + krate, } } #[inline] - pub(super) fn cfg_options(&self) -> &'a CfgOptions { - self.cfg_options.get_or_init(|| self.file.krate(self.db).cfg_options(self.db)) + pub(super) fn cfg_options(&self) -> &'db CfgOptions { + self.cfg_options.get_or_init(|| self.krate.cfg_options(self.db)) } pub(super) fn span_map(&self) -> SpanMapRef<'_> { diff --git a/crates/hir-def/src/item_tree/tests.rs b/crates/hir-def/src/item_tree/tests.rs index 1926ed74e8..b71b25a1a5 100644 --- a/crates/hir-def/src/item_tree/tests.rs +++ b/crates/hir-def/src/item_tree/tests.rs @@ -6,7 +6,7 @@ use crate::{db::DefDatabase, test_db::TestDB}; fn check(#[rust_analyzer::rust_fixture] ra_fixture: &str, expect: Expect) { let (db, file_id) = TestDB::with_single_file(ra_fixture); - let item_tree = db.file_item_tree(file_id.into()); + let item_tree = db.file_item_tree(file_id.into(), db.test_crate()); let pretty = item_tree.pretty_print(&db, Edition::CURRENT); expect.assert_eq(&pretty); } diff --git a/crates/hir-def/src/lang_item.rs b/crates/hir-def/src/lang_item.rs index 51dd55301f..fef92c89b1 100644 --- a/crates/hir-def/src/lang_item.rs +++ b/crates/hir-def/src/lang_item.rs @@ -456,6 +456,7 @@ language_item_table! { LangItems => TryTraitFromOutput, sym::from_output, FunctionId; TryTraitBranch, sym::branch, FunctionId; TryTraitFromYeet, sym::from_yeet, FunctionId; + ResidualIntoTryType, sym::into_try_type, FunctionId; PointerLike, sym::pointer_like, TraitId; diff --git a/crates/hir-def/src/lib.rs b/crates/hir-def/src/lib.rs index 8d6c418d75..9a7fbc812f 100644 --- a/crates/hir-def/src/lib.rs +++ b/crates/hir-def/src/lib.rs @@ -49,7 +49,6 @@ pub mod visibility; use intern::{Interned, Symbol}; pub use rustc_abi as layout; use thin_vec::ThinVec; -use triomphe::Arc; pub use crate::signatures::LocalFieldId; @@ -86,14 +85,19 @@ use crate::{ builtin_type::BuiltinType, db::DefDatabase, expr_store::ExpressionStoreSourceMap, - hir::generics::{LocalLifetimeParamId, LocalTypeOrConstParamId}, + hir::{ + ExprId, + generics::{GenericParams, LocalLifetimeParamId, LocalTypeOrConstParamId}, + }, nameres::{ LocalDefMap, assoc::{ImplItems, TraitItems}, block_def_map, crate_def_map, crate_local_def_map, diagnostics::DefDiagnostics, }, - signatures::{EnumVariants, InactiveEnumVariantCode, VariantFields}, + signatures::{ + ConstSignature, EnumVariants, InactiveEnumVariantCode, StaticSignature, VariantFields, + }, }; type FxIndexMap<K, V> = indexmap::IndexMap<K, V, rustc_hash::FxBuildHasher>; @@ -255,14 +259,15 @@ impl_intern!(StructId, StructLoc, intern_struct, lookup_intern_struct); impl StructId { pub fn fields(self, db: &dyn DefDatabase) -> &VariantFields { - VariantFields::firewall(db, self.into()) + VariantFields::of(db, self.into()) } pub fn fields_with_source_map( self, db: &dyn DefDatabase, - ) -> (Arc<VariantFields>, Arc<ExpressionStoreSourceMap>) { - VariantFields::query(db, self.into()) + ) -> (&VariantFields, &ExpressionStoreSourceMap) { + let r = VariantFields::with_source_map(db, self.into()); + (&r.0, &r.1) } } @@ -271,14 +276,15 @@ impl_intern!(UnionId, UnionLoc, intern_union, lookup_intern_union); impl UnionId { pub fn fields(self, db: &dyn DefDatabase) -> &VariantFields { - VariantFields::firewall(db, self.into()) + VariantFields::of(db, self.into()) } pub fn fields_with_source_map( self, db: &dyn DefDatabase, - ) -> (Arc<VariantFields>, Arc<ExpressionStoreSourceMap>) { - VariantFields::query(db, self.into()) + ) -> (&VariantFields, &ExpressionStoreSourceMap) { + let r = VariantFields::with_source_map(db, self.into()); + (&r.0, &r.1) } } @@ -306,6 +312,19 @@ impl_intern!(ConstId, ConstLoc, intern_const, lookup_intern_const); pub type StaticLoc = AssocItemLoc<ast::Static>; impl_intern!(StaticId, StaticLoc, intern_static, lookup_intern_static); +/// An anonymous const expression that appears in a type position (e.g., array lengths, +/// const generic arguments like `{ N + 1 }`). Unlike named constants, these don't have +/// their own `Body` — their expressions live in the parent's signature `ExpressionStore`. +#[derive(Debug, Hash, PartialEq, Eq, Clone)] +pub struct AnonConstLoc { + /// The owner store containing this expression. + pub owner: ExpressionStoreOwnerId, + /// The ExprId within the owner's ExpressionStore that is the root + /// of this anonymous const expression. + pub expr: ExprId, +} +impl_intern!(AnonConstId, AnonConstLoc, intern_anon_const, lookup_intern_anon_const); + pub type TraitLoc = ItemLoc<ast::Trait>; impl_intern!(TraitId, TraitLoc, intern_trait, lookup_intern_trait); @@ -377,14 +396,15 @@ impl_loc!(EnumVariantLoc, id: Variant, parent: EnumId); impl EnumVariantId { pub fn fields(self, db: &dyn DefDatabase) -> &VariantFields { - VariantFields::firewall(db, self.into()) + VariantFields::of(db, self.into()) } pub fn fields_with_source_map( self, db: &dyn DefDatabase, - ) -> (Arc<VariantFields>, Arc<ExpressionStoreSourceMap>) { - VariantFields::query(db, self.into()) + ) -> (&VariantFields, &ExpressionStoreSourceMap) { + let r = VariantFields::with_source_map(db, self.into()); + (&r.0, &r.1) } } @@ -553,15 +573,25 @@ pub struct TypeOrConstParamId { pub struct TypeParamId(TypeOrConstParamId); impl TypeParamId { + #[inline] pub fn parent(&self) -> GenericDefId { self.0.parent } + + #[inline] pub fn local_id(&self) -> LocalTypeOrConstParamId { self.0.local_id } -} -impl TypeParamId { + #[inline] + pub fn trait_self(trait_: TraitId) -> TypeParamId { + TypeParamId::from_unchecked(TypeOrConstParamId { + parent: trait_.into(), + local_id: GenericParams::SELF_PARAM_ID_IN_SELF, + }) + } + + #[inline] /// Caller should check if this toc id really belongs to a type pub fn from_unchecked(it: TypeOrConstParamId) -> Self { Self(it) @@ -696,46 +726,47 @@ impl From<DefWithBodyId> for ModuleDefId { pub enum GeneralConstId { ConstId(ConstId), StaticId(StaticId), + AnonConstId(AnonConstId), } -impl_from!(ConstId, StaticId for GeneralConstId); +impl_from!(ConstId, StaticId, AnonConstId for GeneralConstId); impl GeneralConstId { - pub fn generic_def(self, _db: &dyn DefDatabase) -> Option<GenericDefId> { + pub fn generic_def(self, db: &dyn DefDatabase) -> Option<GenericDefId> { match self { GeneralConstId::ConstId(it) => Some(it.into()), GeneralConstId::StaticId(it) => Some(it.into()), + GeneralConstId::AnonConstId(it) => Some(it.lookup(db).owner.generic_def(db)), } } pub fn name(self, db: &dyn DefDatabase) -> String { match self { GeneralConstId::StaticId(it) => { - db.static_signature(it).name.display(db, Edition::CURRENT).to_string() + StaticSignature::of(db, it).name.display(db, Edition::CURRENT).to_string() } GeneralConstId::ConstId(const_id) => { - db.const_signature(const_id).name.as_ref().map_or_else( + ConstSignature::of(db, const_id).name.as_ref().map_or_else( || "_".to_owned(), |name| name.display(db, Edition::CURRENT).to_string(), ) } + GeneralConstId::AnonConstId(_) => "{anon const}".to_owned(), } } } -/// The defs which have a body (have root expressions for type inference). +/// The defs which have a body. #[derive(Debug, PartialOrd, Ord, Clone, Copy, PartialEq, Eq, Hash, salsa_macros::Supertype)] pub enum DefWithBodyId { + /// A function body. FunctionId(FunctionId), + /// A static item initializer. StaticId(StaticId), + /// A const item initializer ConstId(ConstId), + /// An enum variant discrimiant VariantId(EnumVariantId), - // /// All fields of a variant are inference roots - // VariantId(VariantId), - // /// The signature can contain inference roots in a bunch of places - // /// like const parameters or const arguments in paths - // This should likely be kept on its own with a separate query - // GenericDefId(GenericDefId), } impl_from!(FunctionId, ConstId, StaticId for DefWithBodyId); @@ -804,6 +835,62 @@ impl_from!( for GenericDefId ); +/// Owner of an expression store - either a body or a signature. +/// This is used for queries that operate on expression stores generically, +/// such as `expr_scopes`. +// NOTE: This type cannot be `salsa::Supertype` as its variants are overlapping. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord /* !salsa::Supertype */)] +pub enum ExpressionStoreOwnerId { + Signature(GenericDefId), + /// A body, something with a root expression. + /// + /// An enum variant's body is considered its discriminant initializer. + Body(DefWithBodyId), + VariantFields(VariantId), +} + +impl ExpressionStoreOwnerId { + // FIXME: Check callers of this, this method likely can be removed + pub fn as_def_with_body(self) -> Option<DefWithBodyId> { + if let Self::Body(v) = self { Some(v) } else { None } + } + + pub fn generic_def(self, db: &dyn DefDatabase) -> GenericDefId { + match self { + ExpressionStoreOwnerId::Signature(generic_def_id) => generic_def_id, + ExpressionStoreOwnerId::Body(def_with_body_id) => match def_with_body_id { + DefWithBodyId::FunctionId(id) => GenericDefId::FunctionId(id), + DefWithBodyId::StaticId(id) => GenericDefId::StaticId(id), + DefWithBodyId::ConstId(id) => GenericDefId::ConstId(id), + DefWithBodyId::VariantId(it) => it.lookup(db).parent.into(), + }, + ExpressionStoreOwnerId::VariantFields(variant_id) => match variant_id { + VariantId::EnumVariantId(it) => it.lookup(db).parent.into(), + VariantId::StructId(it) => it.into(), + VariantId::UnionId(it) => it.into(), + }, + } + } +} + +impl From<GenericDefId> for ExpressionStoreOwnerId { + fn from(id: GenericDefId) -> Self { + ExpressionStoreOwnerId::Signature(id) + } +} + +impl From<DefWithBodyId> for ExpressionStoreOwnerId { + fn from(id: DefWithBodyId) -> Self { + ExpressionStoreOwnerId::Body(id) + } +} + +impl From<VariantId> for ExpressionStoreOwnerId { + fn from(id: VariantId) -> Self { + ExpressionStoreOwnerId::VariantFields(id) + } +} + impl GenericDefId { pub fn file_id_and_params_of( self, @@ -944,7 +1031,9 @@ impl From<VariantId> for AttrDefId { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, salsa_macros::Supertype, salsa::Update)] +#[derive( + Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, salsa_macros::Supertype, salsa::Update, +)] pub enum VariantId { EnumVariantId(EnumVariantId), StructId(StructId), @@ -954,14 +1043,15 @@ impl_from!(EnumVariantId, StructId, UnionId for VariantId); impl VariantId { pub fn fields(self, db: &dyn DefDatabase) -> &VariantFields { - VariantFields::firewall(db, self) + VariantFields::of(db, self) } pub fn fields_with_source_map( self, db: &dyn DefDatabase, - ) -> (Arc<VariantFields>, Arc<ExpressionStoreSourceMap>) { - VariantFields::query(db, self) + ) -> (&VariantFields, &ExpressionStoreSourceMap) { + let r = VariantFields::with_source_map(db, self); + (&r.0, &r.1) } pub fn file_id(self, db: &dyn DefDatabase) -> HirFileId { @@ -1162,6 +1252,16 @@ impl HasModule for DefWithBodyId { } } +impl HasModule for ExpressionStoreOwnerId { + fn module(&self, db: &dyn DefDatabase) -> ModuleId { + match self { + ExpressionStoreOwnerId::Signature(def) => def.module(db), + ExpressionStoreOwnerId::Body(def) => def.module(db), + ExpressionStoreOwnerId::VariantFields(variant_id) => variant_id.module(db), + } + } +} + impl HasModule for GenericDefId { fn module(&self, db: &dyn DefDatabase) -> ModuleId { match self { diff --git a/crates/hir-def/src/macro_expansion_tests/builtin_fn_macro.rs b/crates/hir-def/src/macro_expansion_tests/builtin_fn_macro.rs index eeaf865338..46cdb39c5b 100644 --- a/crates/hir-def/src/macro_expansion_tests/builtin_fn_macro.rs +++ b/crates/hir-def/src/macro_expansion_tests/builtin_fn_macro.rs @@ -568,6 +568,12 @@ cfg_select! { _ => { fn true_2() {} } } +const _: ((),) = cfg_select! { _ => ((), ) }; +const _: i32 = cfg_select! { true => 2 + 3, _ => 3 + 4 }; +const _: i32 = cfg_select! { false => 2 + 3, _ => 3 + 4 }; +const _: bool = cfg_select! { _ => 2 < 3 }; +const _: bool = cfg_select! { true => foo::<(), fn() -> Foo<i32, i64>>(1,), _ => false }; + cfg_select! { false => { fn false_3() {} } } @@ -589,6 +595,12 @@ fn true_1() {} fn true_2() {} +const _: ((),) = ((), ); +const _: i32 = 2+3; +const _: i32 = 3+4; +const _: bool = 2<3; +const _: bool = foo::<(), fn() -> Foo<i32, i64>>(1, ); + /* error: none of the predicates in this `cfg_select` evaluated to true */ /* error: expected `=>` after cfg expression */ diff --git a/crates/hir-def/src/macro_expansion_tests/mbe.rs b/crates/hir-def/src/macro_expansion_tests/mbe.rs index 7b5d0103e6..d93df7af6a 100644 --- a/crates/hir-def/src/macro_expansion_tests/mbe.rs +++ b/crates/hir-def/src/macro_expansion_tests/mbe.rs @@ -35,9 +35,9 @@ macro_rules! f { }; } -struct#0:MacroRules[BE8F, 0]@58..64#17408# MyTraitMap2#0:MacroCall[BE8F, 0]@31..42#ROOT2024# {#0:MacroRules[BE8F, 0]@72..73#17408# - map#0:MacroRules[BE8F, 0]@86..89#17408#:#0:MacroRules[BE8F, 0]@89..90#17408# #0:MacroRules[BE8F, 0]@89..90#17408#::#0:MacroRules[BE8F, 0]@91..93#17408#std#0:MacroRules[BE8F, 0]@93..96#17408#::#0:MacroRules[BE8F, 0]@96..98#17408#collections#0:MacroRules[BE8F, 0]@98..109#17408#::#0:MacroRules[BE8F, 0]@109..111#17408#HashSet#0:MacroRules[BE8F, 0]@111..118#17408#<#0:MacroRules[BE8F, 0]@118..119#17408#(#0:MacroRules[BE8F, 0]@119..120#17408#)#0:MacroRules[BE8F, 0]@120..121#17408#>#0:MacroRules[BE8F, 0]@121..122#17408#,#0:MacroRules[BE8F, 0]@122..123#17408# -}#0:MacroRules[BE8F, 0]@132..133#17408# +struct#0:MacroRules[BE8F, 0]@58..64#18432# MyTraitMap2#0:MacroCall[BE8F, 0]@31..42#ROOT2024# {#0:MacroRules[BE8F, 0]@72..73#18432# + map#0:MacroRules[BE8F, 0]@86..89#18432#:#0:MacroRules[BE8F, 0]@89..90#18432# #0:MacroRules[BE8F, 0]@89..90#18432#::#0:MacroRules[BE8F, 0]@91..93#18432#std#0:MacroRules[BE8F, 0]@93..96#18432#::#0:MacroRules[BE8F, 0]@96..98#18432#collections#0:MacroRules[BE8F, 0]@98..109#18432#::#0:MacroRules[BE8F, 0]@109..111#18432#HashSet#0:MacroRules[BE8F, 0]@111..118#18432#<#0:MacroRules[BE8F, 0]@118..119#18432#(#0:MacroRules[BE8F, 0]@119..120#18432#)#0:MacroRules[BE8F, 0]@120..121#18432#>#0:MacroRules[BE8F, 0]@121..122#18432#,#0:MacroRules[BE8F, 0]@122..123#18432# +}#0:MacroRules[BE8F, 0]@132..133#18432# "#]], ); } @@ -197,7 +197,7 @@ macro_rules! mk_struct { #[macro_use] mod foo; -struct#1:MacroRules[DB0C, 0]@59..65#17408# Foo#0:MacroCall[DB0C, 0]@32..35#ROOT2024#(#1:MacroRules[DB0C, 0]@70..71#17408#u32#0:MacroCall[DB0C, 0]@41..44#ROOT2024#)#1:MacroRules[DB0C, 0]@74..75#17408#;#1:MacroRules[DB0C, 0]@75..76#17408# +struct#1:MacroRules[DB0C, 0]@59..65#18432# Foo#0:MacroCall[DB0C, 0]@32..35#ROOT2024#(#1:MacroRules[DB0C, 0]@70..71#18432#u32#0:MacroCall[DB0C, 0]@41..44#ROOT2024#)#1:MacroRules[DB0C, 0]@74..75#18432#;#1:MacroRules[DB0C, 0]@75..76#18432# "#]], ); } @@ -423,10 +423,10 @@ m! { foo, bar } macro_rules! m { ($($i:ident),*) => ( impl Bar { $(fn $i() {})* } ); } -impl#\17408# Bar#\17408# {#\17408# - fn#\17408# foo#\ROOT2024#(#\17408#)#\17408# {#\17408#}#\17408# - fn#\17408# bar#\ROOT2024#(#\17408#)#\17408# {#\17408#}#\17408# -}#\17408# +impl#\18432# Bar#\18432# {#\18432# + fn#\18432# foo#\ROOT2024#(#\18432#)#\18432# {#\18432#}#\18432# + fn#\18432# bar#\ROOT2024#(#\18432#)#\18432# {#\18432#}#\18432# +}#\18432# "#]], ); } diff --git a/crates/hir-def/src/macro_expansion_tests/mod.rs b/crates/hir-def/src/macro_expansion_tests/mod.rs index c63f2c1d78..8317c56caf 100644 --- a/crates/hir-def/src/macro_expansion_tests/mod.rs +++ b/crates/hir-def/src/macro_expansion_tests/mod.rs @@ -45,6 +45,7 @@ use tt::{TextRange, TextSize}; use crate::{ AdtId, Lookup, ModuleDefId, db::DefDatabase, + expr_store::Body, nameres::{DefMap, ModuleSource, crate_def_map}, src::HasSource, test_db::TestDB, @@ -276,7 +277,7 @@ fn resolve_macro_call_id( _ => continue, }; - let (body, sm) = db.body_with_source_map(body); + let (body, sm) = Body::with_source_map(db, body); if let Some(it) = body .blocks(db) .find_map(|block| resolve_macro_call_id(db, block.1, ast_id, ast_ptr)) @@ -458,7 +459,7 @@ m!(g); "#; let (db, file_id) = TestDB::with_single_file(fixture); - let krate = file_id.krate(&db); + let krate = db.test_crate(); let def_map = crate_def_map(&db, krate); let source = def_map[def_map.root].definition_source(&db); let source_file = match source.value { diff --git a/crates/hir-def/src/nameres.rs b/crates/hir-def/src/nameres.rs index 1e3ea50c5a..5fda1beab4 100644 --- a/crates/hir-def/src/nameres.rs +++ b/crates/hir-def/src/nameres.rs @@ -211,6 +211,7 @@ struct DefMapCrateData { /// Side table for resolving derive helpers. exported_derives: FxHashMap<MacroId, Box<[Name]>>, fn_proc_macro_mapping: FxHashMap<FunctionId, ProcMacroId>, + fn_proc_macro_mapping_back: FxHashMap<ProcMacroId, FunctionId>, /// Custom tool modules registered with `#![register_tool]`. registered_tools: Vec<Symbol>, @@ -230,6 +231,7 @@ impl DefMapCrateData { Self { exported_derives: FxHashMap::default(), fn_proc_macro_mapping: FxHashMap::default(), + fn_proc_macro_mapping_back: FxHashMap::default(), registered_tools: PREDEFINED_TOOLS.iter().map(|it| Symbol::intern(it)).collect(), unstable_features: FxHashSet::default(), rustc_coherence_is_core: false, @@ -244,6 +246,7 @@ impl DefMapCrateData { let Self { exported_derives, fn_proc_macro_mapping, + fn_proc_macro_mapping_back, registered_tools, unstable_features, rustc_coherence_is_core: _, @@ -254,6 +257,7 @@ impl DefMapCrateData { } = self; exported_derives.shrink_to_fit(); fn_proc_macro_mapping.shrink_to_fit(); + fn_proc_macro_mapping_back.shrink_to_fit(); registered_tools.shrink_to_fit(); unstable_features.shrink_to_fit(); } @@ -570,6 +574,10 @@ impl DefMap { self.data.fn_proc_macro_mapping.get(&id).copied() } + pub fn proc_macro_as_fn(&self, id: ProcMacroId) -> Option<FunctionId> { + self.data.fn_proc_macro_mapping_back.get(&id).copied() + } + pub fn krate(&self) -> Crate { self.krate } diff --git a/crates/hir-def/src/nameres/assoc.rs b/crates/hir-def/src/nameres/assoc.rs index 9d2b2109fb..f5a852b39c 100644 --- a/crates/hir-def/src/nameres/assoc.rs +++ b/crates/hir-def/src/nameres/assoc.rs @@ -17,7 +17,6 @@ use syntax::{ ast::{self, HasModuleItem, HasName}, }; use thin_vec::ThinVec; -use triomphe::Arc; use crate::{ AssocItemId, AstIdWithPath, ConstLoc, FunctionId, FunctionLoc, ImplId, ItemContainerId, @@ -133,14 +132,14 @@ impl ImplItems { } } -struct AssocItemCollector<'a> { - db: &'a dyn DefDatabase, +struct AssocItemCollector<'db> { + db: &'db dyn DefDatabase, module_id: ModuleId, - def_map: &'a DefMap, - local_def_map: &'a LocalDefMap, - ast_id_map: Arc<AstIdMap>, + def_map: &'db DefMap, + local_def_map: &'db LocalDefMap, + ast_id_map: &'db AstIdMap, span_map: SpanMap, - cfg_options: &'a CfgOptions, + cfg_options: &'db CfgOptions, file_id: HirFileId, diagnostics: Vec<DefDiagnostic>, container: ItemContainerId, @@ -150,9 +149,9 @@ struct AssocItemCollector<'a> { macro_calls: ThinVec<(AstId<ast::Item>, MacroCallId)>, } -impl<'a> AssocItemCollector<'a> { +impl<'db> AssocItemCollector<'db> { fn new( - db: &'a dyn DefDatabase, + db: &'db dyn DefDatabase, module_id: ModuleId, container: ItemContainerId, file_id: HirFileId, diff --git a/crates/hir-def/src/nameres/collector.rs b/crates/hir-def/src/nameres/collector.rs index 323060f61d..9c101c127b 100644 --- a/crates/hir-def/src/nameres/collector.rs +++ b/crates/hir-def/src/nameres/collector.rs @@ -279,7 +279,7 @@ impl<'db> DefCollector<'db> { let _p = tracing::info_span!("seed_with_top_level").entered(); let file_id = self.def_map.krate.root_file_id(self.db); - let item_tree = self.db.file_item_tree(file_id.into()); + let item_tree = self.db.file_item_tree(file_id.into(), self.def_map.krate); let attrs = match item_tree.top_level_attrs() { AttrsOrCfg::Enabled { attrs } => attrs.as_ref(), AttrsOrCfg::CfgDisabled(it) => it.1.as_ref(), @@ -387,7 +387,7 @@ impl<'db> DefCollector<'db> { } fn seed_with_inner(&mut self, tree_id: TreeId) { - let item_tree = tree_id.item_tree(self.db); + let item_tree = tree_id.item_tree(self.db, self.def_map.krate); let is_cfg_enabled = matches!(item_tree.top_level_attrs(), AttrsOrCfg::Enabled { .. }); if is_cfg_enabled { self.inject_prelude(); @@ -634,6 +634,7 @@ impl<'db> DefCollector<'db> { crate_data.exported_derives.insert(proc_macro_id.into(), helpers); } crate_data.fn_proc_macro_mapping.insert(fn_id, proc_macro_id); + crate_data.fn_proc_macro_mapping_back.insert(proc_macro_id, fn_id); } /// Define a macro with `macro_rules`. @@ -1209,42 +1210,69 @@ impl<'db> DefCollector<'db> { // `ItemScope::push_res_with_import()`. if let Some(def) = defs.types && let Some(prev_def) = prev_defs.types - && def.def == prev_def.def - && self.from_glob_import.contains_type(module_id, name.clone()) - && def.vis != prev_def.vis - && def.vis.max(self.db, prev_def.vis, &self.def_map) == Some(def.vis) { - changed = true; - // This import is being handled here, don't pass it down to - // `ItemScope::push_res_with_import()`. - defs.types = None; - self.def_map.modules[module_id].scope.update_visibility_types(name, def.vis); + if def.def == prev_def.def + && self.from_glob_import.contains_type(module_id, name.clone()) + && def.vis != prev_def.vis + && def.vis.max(self.db, prev_def.vis, &self.def_map) == Some(def.vis) + { + changed = true; + // This import is being handled here, don't pass it down to + // `ItemScope::push_res_with_import()`. + defs.types = None; + self.def_map.modules[module_id].scope.update_visibility_types(name, def.vis); + } + // When the source module's definition changed (e.g., due to an explicit import + // shadowing a glob), propagate the new definition to modules that glob-import from it. + // We check that the previous definition came from the same glob import to avoid + // incorrectly overwriting definitions from different glob sources. + // + // Note this is not a perfect fix, but it makes + // https://github.com/rust-lang/rust-analyzer/issues/19224 work for now until we + // implement a proper glob graph + else if def.def != prev_def.def && prev_def.import == def_import_type { + changed = true; + defs.types = None; + self.def_map.modules[module_id].scope.update_def_types(name, def.def, def.vis); + } } if let Some(def) = defs.values && let Some(prev_def) = prev_defs.values - && def.def == prev_def.def - && self.from_glob_import.contains_value(module_id, name.clone()) - && def.vis != prev_def.vis - && def.vis.max(self.db, prev_def.vis, &self.def_map) == Some(def.vis) { - changed = true; - // See comment above. - defs.values = None; - self.def_map.modules[module_id].scope.update_visibility_values(name, def.vis); + if def.def == prev_def.def + && self.from_glob_import.contains_value(module_id, name.clone()) + && def.vis != prev_def.vis + && def.vis.max(self.db, prev_def.vis, &self.def_map) == Some(def.vis) + { + changed = true; + defs.values = None; + self.def_map.modules[module_id].scope.update_visibility_values(name, def.vis); + } else if def.def != prev_def.def + && prev_def.import.map(ImportOrExternCrate::from) == def_import_type + { + changed = true; + defs.values = None; + self.def_map.modules[module_id].scope.update_def_values(name, def.def, def.vis); + } } if let Some(def) = defs.macros && let Some(prev_def) = prev_defs.macros - && def.def == prev_def.def - && self.from_glob_import.contains_macro(module_id, name.clone()) - && def.vis != prev_def.vis - && def.vis.max(self.db, prev_def.vis, &self.def_map) == Some(def.vis) { - changed = true; - // See comment above. - defs.macros = None; - self.def_map.modules[module_id].scope.update_visibility_macros(name, def.vis); + if def.def == prev_def.def + && self.from_glob_import.contains_macro(module_id, name.clone()) + && def.vis != prev_def.vis + && def.vis.max(self.db, prev_def.vis, &self.def_map) == Some(def.vis) + { + changed = true; + defs.macros = None; + self.def_map.modules[module_id].scope.update_visibility_macros(name, def.vis); + } else if def.def != prev_def.def && prev_def.import == def_import_type { + changed = true; + defs.macros = None; + self.def_map.modules[module_id].scope.update_def_macros(name, def.def, def.vis); + } } } @@ -1680,7 +1708,7 @@ impl<'db> DefCollector<'db> { } let file_id = macro_call_id.into(); - let item_tree = self.db.file_item_tree(file_id); + let item_tree = self.db.file_item_tree(file_id, self.def_map.krate); // Derive helpers that are in scope for an item are also in scope for attribute macro expansions // of that item (but not derive or fn like macros). @@ -2068,6 +2096,8 @@ impl ModCollector<'_, '_> { let vis = resolve_vis(def_map, local_def_map, &self.item_tree[it.visibility]); + update_def(self.def_collector, fn_id.into(), &it.name, vis, false); + if self.def_collector.def_map.block.is_none() && self.def_collector.is_proc_macro && self.module_id == self.def_collector.def_map.root @@ -2078,9 +2108,14 @@ impl ModCollector<'_, '_> { InFile::new(self.file_id(), id), fn_id, ); - } - update_def(self.def_collector, fn_id.into(), &it.name, vis, false); + // A proc macro is implemented as a function, but it's treated as a macro, not a function. + // You cannot call it like a function, for example, except in its defining crate. + // So we keep the function definition, but remove it from the scope, leaving only the macro. + self.def_collector.def_map[module_id] + .scope + .remove_from_value_ns(&it.name, fn_id.into()); + } } ModItemId::Struct(id) => { let it = &self.item_tree[id]; @@ -2300,10 +2335,10 @@ impl ModCollector<'_, '_> { self.file_id(), &module.name, path_attr.as_deref(), - self.def_collector.def_map.krate, ) { Ok((file_id, is_mod_rs, mod_dir)) => { - let item_tree = db.file_item_tree(file_id.into()); + let item_tree = + db.file_item_tree(file_id.into(), self.def_collector.def_map.krate); match item_tree.top_level_attrs() { AttrsOrCfg::CfgDisabled(cfg) => { self.emit_unconfigured_diagnostic( @@ -2793,8 +2828,8 @@ foo!(KABOOM); let fixture = r#" //- /lib.rs crate:foo crate-attr:recursion_limit="4" crate-attr:no_core crate-attr:no_std crate-attr:feature(register_tool) "#; - let (db, file_id) = TestDB::with_single_file(fixture); - let def_map = crate_def_map(&db, file_id.krate(&db)); + let (db, _) = TestDB::with_single_file(fixture); + let def_map = crate_def_map(&db, db.test_crate()); assert_eq!(def_map.recursion_limit(), 4); assert!(def_map.is_no_core()); assert!(def_map.is_no_std()); diff --git a/crates/hir-def/src/nameres/mod_resolution.rs b/crates/hir-def/src/nameres/mod_resolution.rs index 140b77ac00..0c50f13edf 100644 --- a/crates/hir-def/src/nameres/mod_resolution.rs +++ b/crates/hir-def/src/nameres/mod_resolution.rs @@ -1,6 +1,6 @@ //! This module resolves `mod foo;` declaration to file. use arrayvec::ArrayVec; -use base_db::{AnchoredPath, Crate}; +use base_db::AnchoredPath; use hir_expand::{EditionedFileId, name::Name}; use crate::{HirFileId, db::DefDatabase}; @@ -62,7 +62,6 @@ impl ModDir { file_id: HirFileId, name: &Name, attr_path: Option<&str>, - krate: Crate, ) -> Result<(EditionedFileId, bool, ModDir), Box<[String]>> { let name = name.as_str(); @@ -92,7 +91,7 @@ impl ModDir { if let Some(mod_dir) = self.child(dir_path, !root_dir_owner) { return Ok(( // FIXME: Edition, is this rightr? - EditionedFileId::new(db, file_id, orig_file_id.edition(db), krate), + EditionedFileId::new(db, file_id, orig_file_id.edition(db)), is_mod_rs, mod_dir, )); diff --git a/crates/hir-def/src/nameres/tests/incremental.rs b/crates/hir-def/src/nameres/tests/incremental.rs index 225ba95863..5b75c078ec 100644 --- a/crates/hir-def/src/nameres/tests/incremental.rs +++ b/crates/hir-def/src/nameres/tests/incremental.rs @@ -166,15 +166,15 @@ fn no() {} [ "crate_local_def_map", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "EnumVariants::of_", @@ -183,7 +183,7 @@ fn no() {} expect![[r#" [ "parse_shim", - "ast_id_map_shim", + "ast_id_map", "file_item_tree_query", "real_span_map_shim", "EnumVariants::of_", @@ -224,21 +224,21 @@ pub struct S {} [ "crate_local_def_map", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "decl_macro_expander_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "macro_def_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_macro_expansion_shim", "macro_arg_shim", ] @@ -246,12 +246,12 @@ pub struct S {} expect![[r#" [ "parse_shim", - "ast_id_map_shim", + "ast_id_map", "file_item_tree_query", "real_span_map_shim", "macro_arg_shim", "parse_macro_expansion_shim", - "ast_id_map_shim", + "ast_id_map", "file_item_tree_query", ] "#]], @@ -282,26 +282,26 @@ fn f() { foo } [ "crate_local_def_map", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "crate_local_def_map", "proc_macros_for_crate_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "macro_def_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_macro_expansion_shim", "expand_proc_macro_shim", "macro_arg_shim", @@ -311,13 +311,13 @@ fn f() { foo } expect![[r#" [ "parse_shim", - "ast_id_map_shim", + "ast_id_map", "file_item_tree_query", "real_span_map_shim", "macro_arg_shim", "expand_proc_macro_shim", "parse_macro_expansion_shim", - "ast_id_map_shim", + "ast_id_map", "file_item_tree_query", ] "#]], @@ -406,38 +406,38 @@ pub struct S {} [ "crate_local_def_map", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "crate_local_def_map", "proc_macros_for_crate_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "decl_macro_expander_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "macro_def_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_macro_expansion_shim", "macro_arg_shim", "decl_macro_expander_shim", "macro_def_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_macro_expansion_shim", "macro_arg_shim", "macro_def_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_macro_expansion_shim", "expand_proc_macro_shim", "macro_arg_shim", @@ -447,7 +447,7 @@ pub struct S {} expect![[r#" [ "parse_shim", - "ast_id_map_shim", + "ast_id_map", "file_item_tree_query", "real_span_map_shim", "macro_arg_shim", @@ -523,29 +523,29 @@ m!(Z); [ "crate_local_def_map", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "decl_macro_expander_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "macro_def_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_macro_expansion_shim", "macro_arg_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_macro_expansion_shim", "macro_arg_shim", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_macro_expansion_shim", "macro_arg_shim", ] @@ -572,7 +572,7 @@ m!(Z); expect![[r#" [ "parse_shim", - "ast_id_map_shim", + "ast_id_map", "file_item_tree_query", "real_span_map_shim", "macro_arg_shim", @@ -604,13 +604,13 @@ pub type Ty = (); execute_assert_events( &db, || { - db.file_item_tree(pos.file_id.into()); + db.file_item_tree(pos.file_id.into(), db.test_crate()); }, &[("file_item_tree_query", 1), ("parse", 1)], expect![[r#" [ "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", ] @@ -624,13 +624,13 @@ pub type Ty = (); execute_assert_events( &db, || { - db.file_item_tree(pos.file_id.into()); + db.file_item_tree(pos.file_id.into(), db.test_crate()); }, &[("file_item_tree_query", 1), ("parse", 1)], expect![[r#" [ "parse_shim", - "ast_id_map_shim", + "ast_id_map", "file_item_tree_query", "real_span_map_shim", ] diff --git a/crates/hir-def/src/nameres/tests/macros.rs b/crates/hir-def/src/nameres/tests/macros.rs index a943f6f0ac..a013f8b2bc 100644 --- a/crates/hir-def/src/nameres/tests/macros.rs +++ b/crates/hir-def/src/nameres/tests/macros.rs @@ -1068,10 +1068,8 @@ pub fn derive_macro_2(_item: TokenStream) -> TokenStream { - AnotherTrait : macro# - DummyTrait : macro# - TokenStream : type value - - attribute_macro : value macro# - - derive_macro : value - - derive_macro_2 : value - - function_like_macro : value macro! + - attribute_macro : macro# + - function_like_macro : macro! "#]], ); } diff --git a/crates/hir-def/src/resolver.rs b/crates/hir-def/src/resolver.rs index 2ac0f90fb2..bb292ac1a6 100644 --- a/crates/hir-def/src/resolver.rs +++ b/crates/hir-def/src/resolver.rs @@ -13,14 +13,13 @@ use rustc_hash::FxHashSet; use smallvec::{SmallVec, smallvec}; use span::SyntaxContext; use syntax::ast::HasName; -use triomphe::Arc; use crate::{ - AdtId, AstIdLoc, ConstId, ConstParamId, DefWithBodyId, EnumId, EnumVariantId, ExternBlockId, - ExternCrateId, FunctionId, FxIndexMap, GenericDefId, GenericParamId, HasModule, ImplId, - ItemContainerId, LifetimeParamId, Lookup, Macro2Id, MacroId, MacroRulesId, ModuleDefId, - ModuleId, ProcMacroId, StaticId, StructId, TraitId, TypeAliasId, TypeOrConstParamId, - TypeParamId, UseId, VariantId, + AdtId, AstIdLoc, ConstId, ConstParamId, DefWithBodyId, EnumId, EnumVariantId, + ExpressionStoreOwnerId, ExternBlockId, ExternCrateId, FunctionId, FxIndexMap, GenericDefId, + GenericParamId, HasModule, ImplId, ItemContainerId, LifetimeParamId, Lookup, Macro2Id, MacroId, + MacroRulesId, ModuleDefId, ModuleId, ProcMacroId, StaticId, StructId, TraitId, TypeAliasId, + TypeOrConstParamId, TypeParamId, UseId, VariantId, builtin_type::BuiltinType, db::DefDatabase, expr_store::{ @@ -32,10 +31,11 @@ use crate::{ BindingId, ExprId, LabelId, generics::{GenericParams, TypeOrConstParamData}, }, - item_scope::{BUILTIN_SCOPE, BuiltinShadowMode, ImportOrExternCrate, ImportOrGlob, ItemScope}, + item_scope::{BUILTIN_SCOPE, BuiltinShadowMode, ImportOrExternCrate, ItemScope}, lang_item::LangItemTarget, nameres::{DefMap, LocalDefMap, MacroSubNs, ResolvePathResultPrefixInfo, block_def_map}, per_ns::PerNs, + signatures::ImplSignature, src::HasSource, type_ref::LifetimeRef, visibility::{RawVisibility, Visibility}, @@ -65,13 +65,13 @@ impl fmt::Debug for ModuleItemMap<'_> { } #[derive(Clone)] -struct ExprScope { - owner: DefWithBodyId, - expr_scopes: Arc<ExprScopes>, +struct ExprScope<'db> { + owner: ExpressionStoreOwnerId, + expr_scopes: &'db ExprScopes, scope_id: ScopeId, } -impl fmt::Debug for ExprScope { +impl fmt::Debug for ExprScope<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ExprScope") .field("owner", &self.owner) @@ -86,9 +86,9 @@ enum Scope<'db> { BlockScope(ModuleItemMap<'db>), /// Brings the generic parameters of an item into scope as well as the `Self` type alias / /// generic for ADTs and impls. - GenericParams { def: GenericDefId, params: Arc<GenericParams> }, + GenericParams { def: GenericDefId, params: &'db GenericParams }, /// Local bindings - ExprScope(ExprScope), + ExprScope(ExprScope<'db>), /// Macro definition inside bodies that affects all paths after it in the same block. MacroDefScope(MacroDefId), } @@ -111,8 +111,8 @@ pub enum TypeNs { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum ResolveValueResult { - ValueNs(ValueNs, Option<ImportOrGlob>), - Partial(TypeNs, usize, Option<ImportOrExternCrate>), + ValueNs(ValueNs), + Partial(TypeNs, usize), } #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] @@ -332,20 +332,17 @@ impl<'db> Resolver<'db> { Path::Normal(it) => &it.mod_path, Path::LangItem(l, None) => { return Some(( - ResolveValueResult::ValueNs( - match *l { - LangItemTarget::FunctionId(it) => ValueNs::FunctionId(it), - LangItemTarget::StaticId(it) => ValueNs::StaticId(it), - LangItemTarget::StructId(it) => ValueNs::StructId(it), - LangItemTarget::EnumVariantId(it) => ValueNs::EnumVariantId(it), - LangItemTarget::UnionId(_) - | LangItemTarget::ImplId(_) - | LangItemTarget::TypeAliasId(_) - | LangItemTarget::TraitId(_) - | LangItemTarget::EnumId(_) => return None, - }, - None, - ), + ResolveValueResult::ValueNs(match *l { + LangItemTarget::FunctionId(it) => ValueNs::FunctionId(it), + LangItemTarget::StaticId(it) => ValueNs::StaticId(it), + LangItemTarget::StructId(it) => ValueNs::StructId(it), + LangItemTarget::EnumVariantId(it) => ValueNs::EnumVariantId(it), + LangItemTarget::UnionId(_) + | LangItemTarget::ImplId(_) + | LangItemTarget::TypeAliasId(_) + | LangItemTarget::TraitId(_) + | LangItemTarget::EnumId(_) => return None, + }), ResolvePathResultPrefixInfo::default(), )); } @@ -363,7 +360,7 @@ impl<'db> Resolver<'db> { }; // Remaining segments start from 0 because lang paths have no segments other than the remaining. return Some(( - ResolveValueResult::Partial(type_ns, 0, None), + ResolveValueResult::Partial(type_ns, 0), ResolvePathResultPrefixInfo::default(), )); } @@ -388,10 +385,7 @@ impl<'db> Resolver<'db> { if let Some(e) = entry { return Some(( - ResolveValueResult::ValueNs( - ValueNs::LocalBinding(e.binding()), - None, - ), + ResolveValueResult::ValueNs(ValueNs::LocalBinding(e.binding())), ResolvePathResultPrefixInfo::default(), )); } @@ -404,14 +398,14 @@ impl<'db> Resolver<'db> { && *first_name == sym::Self_ { return Some(( - ResolveValueResult::ValueNs(ValueNs::ImplSelf(impl_), None), + ResolveValueResult::ValueNs(ValueNs::ImplSelf(impl_)), ResolvePathResultPrefixInfo::default(), )); } if let Some(id) = params.find_const_by_name(first_name, *def) { let val = ValueNs::GenericParam(id); return Some(( - ResolveValueResult::ValueNs(val, None), + ResolveValueResult::ValueNs(val), ResolvePathResultPrefixInfo::default(), )); } @@ -431,7 +425,7 @@ impl<'db> Resolver<'db> { if let &GenericDefId::ImplId(impl_) = def { if *first_name == sym::Self_ { return Some(( - ResolveValueResult::Partial(TypeNs::SelfType(impl_), 1, None), + ResolveValueResult::Partial(TypeNs::SelfType(impl_), 1), ResolvePathResultPrefixInfo::default(), )); } @@ -440,14 +434,14 @@ impl<'db> Resolver<'db> { { let ty = TypeNs::AdtSelfType(adt); return Some(( - ResolveValueResult::Partial(ty, 1, None), + ResolveValueResult::Partial(ty, 1), ResolvePathResultPrefixInfo::default(), )); } if let Some(id) = params.find_type_by_name(first_name, *def) { let ty = TypeNs::GenericParam(id); return Some(( - ResolveValueResult::Partial(ty, 1, None), + ResolveValueResult::Partial(ty, 1), ResolvePathResultPrefixInfo::default(), )); } @@ -473,7 +467,7 @@ impl<'db> Resolver<'db> { && let Some(builtin) = BuiltinType::by_name(first_name) { return Some(( - ResolveValueResult::Partial(TypeNs::BuiltinType(builtin), 1, None), + ResolveValueResult::Partial(TypeNs::BuiltinType(builtin), 1), ResolvePathResultPrefixInfo::default(), )); } @@ -488,7 +482,7 @@ impl<'db> Resolver<'db> { hygiene: HygieneId, ) -> Option<ValueNs> { match self.resolve_path_in_value_ns(db, path, hygiene)? { - ResolveValueResult::ValueNs(it, _) => Some(it), + ResolveValueResult::ValueNs(it) => Some(it), ResolveValueResult::Partial(..) => None, } } @@ -659,7 +653,7 @@ impl<'db> Resolver<'db> { match scope { Scope::BlockScope(m) => traits.extend(m.def_map[m.module_id].scope.traits()), &Scope::GenericParams { def: GenericDefId::ImplId(impl_), .. } => { - let impl_data = db.impl_signature(impl_); + let impl_data = ImplSignature::of(db, impl_); if let Some(target_trait) = impl_data.target_trait && let Some(TypeNs::TraitId(trait_)) = self .resolve_path_in_type_ns_fully(db, &impl_data.store[target_trait.path]) @@ -730,19 +724,19 @@ impl<'db> Resolver<'db> { pub fn generic_params(&self) -> Option<&GenericParams> { self.scopes().find_map(|scope| match scope { - Scope::GenericParams { params, .. } => Some(&**params), + &Scope::GenericParams { params, .. } => Some(params), _ => None, }) } - pub fn all_generic_params(&self) -> impl Iterator<Item = (&GenericParams, &GenericDefId)> { + pub fn all_generic_params(&self) -> impl Iterator<Item = (&GenericParams, GenericDefId)> { self.scopes().filter_map(|scope| match scope { - Scope::GenericParams { params, def } => Some((&**params, def)), + &Scope::GenericParams { params, def } => Some((params, def)), _ => None, }) } - pub fn body_owner(&self) -> Option<DefWithBodyId> { + pub fn expression_store_owner(&self) -> Option<ExpressionStoreOwnerId> { self.scopes().find_map(|scope| match scope { Scope::ExprScope(it) => Some(it.owner), _ => None, @@ -860,25 +854,30 @@ impl<'db> Resolver<'db> { pub fn update_to_inner_scope( &mut self, db: &'db dyn DefDatabase, - owner: DefWithBodyId, + owner: impl Into<ExpressionStoreOwnerId>, + expr_id: ExprId, + ) -> UpdateGuard { + self.update_to_inner_scope_(db, owner.into(), expr_id) + } + + fn update_to_inner_scope_( + &mut self, + db: &'db dyn DefDatabase, + owner: ExpressionStoreOwnerId, expr_id: ExprId, ) -> UpdateGuard { #[inline(always)] fn append_expr_scope<'db>( db: &'db dyn DefDatabase, resolver: &mut Resolver<'db>, - owner: DefWithBodyId, - expr_scopes: &Arc<ExprScopes>, + owner: ExpressionStoreOwnerId, + expr_scopes: &'db ExprScopes, scope_id: ScopeId, ) { if let Some(macro_id) = expr_scopes.macro_def(scope_id) { resolver.scopes.push(Scope::MacroDefScope(**macro_id)); } - resolver.scopes.push(Scope::ExprScope(ExprScope { - owner, - expr_scopes: expr_scopes.clone(), - scope_id, - })); + resolver.scopes.push(Scope::ExprScope(ExprScope { owner, expr_scopes, scope_id })); if let Some(block) = expr_scopes.block(scope_id) { let def_map = block_def_map(db, block); let local_def_map = block.lookup(db).module.only_local_def_map(db); @@ -896,21 +895,20 @@ impl<'db> Resolver<'db> { let start = self.scopes.len(); let innermost_scope = self.scopes().find(|scope| !matches!(scope, Scope::MacroDefScope(_))); match innermost_scope { - Some(&Scope::ExprScope(ExprScope { scope_id, ref expr_scopes, owner })) => { - let expr_scopes = expr_scopes.clone(); + Some(&Scope::ExprScope(ExprScope { scope_id, expr_scopes, owner })) => { let scope_chain = expr_scopes .scope_chain(expr_scopes.scope_for(expr_id)) .take_while(|&it| it != scope_id); for scope_id in scope_chain { - append_expr_scope(db, self, owner, &expr_scopes, scope_id); + append_expr_scope(db, self, owner, expr_scopes, scope_id); } } _ => { - let expr_scopes = db.expr_scopes(owner); + let expr_scopes = ExprScopes::of(db, owner); let scope_chain = expr_scopes.scope_chain(expr_scopes.scope_for(expr_id)); for scope_id in scope_chain { - append_expr_scope(db, self, owner, &expr_scopes, scope_id); + append_expr_scope(db, self, owner, expr_scopes, scope_id); } } } @@ -1022,7 +1020,7 @@ impl<'db> Scope<'db> { }) }); } - &Scope::GenericParams { ref params, def: parent } => { + &Scope::GenericParams { params, def: parent } => { if let GenericDefId::ImplId(impl_) = parent { acc.add(&Name::new_symbol_root(sym::Self_), ScopeDef::ImplSelfType(impl_)); } else if let GenericDefId::AdtId(adt) = parent { @@ -1032,7 +1030,7 @@ impl<'db> Scope<'db> { for (local_id, param) in params.iter_type_or_consts() { if let Some(name) = ¶m.name() { let id = TypeOrConstParamId { parent, local_id }; - let data = &db.generic_params(parent)[local_id]; + let data = &GenericParams::of(db, parent)[local_id]; acc.add( name, ScopeDef::GenericParam(match data { @@ -1066,20 +1064,21 @@ impl<'db> Scope<'db> { pub fn resolver_for_scope( db: &dyn DefDatabase, - owner: DefWithBodyId, + owner: impl Into<ExpressionStoreOwnerId> + HasResolver, scope_id: Option<ScopeId>, ) -> Resolver<'_> { - let r = owner.resolver(db); - let scopes = db.expr_scopes(owner); - resolver_for_scope_(db, scopes, scope_id, r, owner) + let store_owner = owner.into(); + let r = store_owner.resolver(db); + let scopes = ExprScopes::of(db, store_owner); + resolver_for_scope_(db, scopes, scope_id, r, store_owner) } fn resolver_for_scope_<'db>( db: &'db dyn DefDatabase, - scopes: Arc<ExprScopes>, + scopes: &'db ExprScopes, scope_id: Option<ScopeId>, mut r: Resolver<'db>, - owner: DefWithBodyId, + owner: ExpressionStoreOwnerId, ) -> Resolver<'db> { let scope_chain = scopes.scope_chain(scope_id).collect::<Vec<_>>(); r.scopes.reserve(scope_chain.len()); @@ -1099,7 +1098,7 @@ fn resolver_for_scope_<'db>( r = r.push_scope(Scope::MacroDefScope(**macro_id)); } - r = r.push_expr_scope(owner, Arc::clone(&scopes), scope); + r = r.push_expr_scope(owner, scopes, scope); } r } @@ -1115,7 +1114,7 @@ impl<'db> Resolver<'db> { db: &'db dyn DefDatabase, def: GenericDefId, ) -> Resolver<'db> { - let params = db.generic_params(def); + let params = GenericParams::of(db, def); self.push_scope(Scope::GenericParams { def, params }) } @@ -1130,8 +1129,8 @@ impl<'db> Resolver<'db> { fn push_expr_scope( self, - owner: DefWithBodyId, - expr_scopes: Arc<ExprScopes>, + owner: ExpressionStoreOwnerId, + expr_scopes: &'db ExprScopes, scope_id: ScopeId, ) -> Resolver<'db> { self.push_scope(Scope::ExprScope(ExprScope { owner, expr_scopes, scope_id })) @@ -1153,12 +1152,12 @@ impl<'db> ModuleItemMap<'db> { ); match unresolved_idx { None => { - let (value, import) = to_value_ns(module_def)?; - Some((ResolveValueResult::ValueNs(value, import), prefix_info)) + let value = to_value_ns(module_def, self.def_map)?; + Some((ResolveValueResult::ValueNs(value), prefix_info)) } Some(unresolved_idx) => { - let def = module_def.take_types_full()?; - let ty = match def.def { + let def = module_def.take_types()?; + let ty = match def { ModuleDefId::AdtId(it) => TypeNs::AdtId(it), ModuleDefId::TraitId(it) => TypeNs::TraitId(it), ModuleDefId::TypeAliasId(it) => TypeNs::TypeAliasId(it), @@ -1171,7 +1170,7 @@ impl<'db> ModuleItemMap<'db> { | ModuleDefId::MacroId(_) | ModuleDefId::StaticId(_) => return None, }; - Some((ResolveValueResult::Partial(ty, unresolved_idx, def.import), prefix_info)) + Some((ResolveValueResult::Partial(ty, unresolved_idx), prefix_info)) } } } @@ -1194,8 +1193,13 @@ impl<'db> ModuleItemMap<'db> { } } -fn to_value_ns(per_ns: PerNs) -> Option<(ValueNs, Option<ImportOrGlob>)> { - let (def, import) = per_ns.take_values_import()?; +fn to_value_ns(per_ns: PerNs, def_map: &DefMap) -> Option<ValueNs> { + let def = per_ns.take_values().or_else(|| { + let Some(MacroId::ProcMacroId(proc_macro)) = per_ns.take_macros() else { return None }; + // If we cannot resolve to value ns, but we can resolve to a proc macro, and this is the crate + // defining this proc macro - inside this crate, we should treat the macro as a function. + def_map.proc_macro_as_fn(proc_macro).map(ModuleDefId::FunctionId) + })?; let res = match def { ModuleDefId::FunctionId(it) => ValueNs::FunctionId(it), ModuleDefId::AdtId(AdtId::StructId(it)) => ValueNs::StructId(it), @@ -1210,7 +1214,7 @@ fn to_value_ns(per_ns: PerNs) -> Option<(ValueNs, Option<ImportOrGlob>)> { | ModuleDefId::MacroId(_) | ModuleDefId::ModuleId(_) => return None, }; - Some((res, import)) + Some(res) } fn to_type_ns(per_ns: PerNs) -> Option<(TypeNs, Option<ImportOrExternCrate>)> { @@ -1410,6 +1414,16 @@ impl HasResolver for GenericDefId { } } +impl HasResolver for ExpressionStoreOwnerId { + fn resolver(self, db: &dyn DefDatabase) -> Resolver<'_> { + match self { + ExpressionStoreOwnerId::Signature(def) => def.resolver(db), + ExpressionStoreOwnerId::Body(def) => def.resolver(db), + ExpressionStoreOwnerId::VariantFields(variant_id) => variant_id.resolver(db), + } + } +} + impl HasResolver for EnumVariantId { fn resolver(self, db: &dyn DefDatabase) -> Resolver<'_> { self.lookup(db).parent.resolver(db) diff --git a/crates/hir-def/src/signatures.rs b/crates/hir-def/src/signatures.rs index 37c8f762fe..6d704274f4 100644 --- a/crates/hir-def/src/signatures.rs +++ b/crates/hir-def/src/signatures.rs @@ -24,7 +24,7 @@ use crate::{ attrs::AttrFlags, db::DefDatabase, expr_store::{ - ExpressionStore, ExpressionStoreSourceMap, + Body, ExpressionStore, ExpressionStoreBuilder, ExpressionStoreSourceMap, lower::{ ExprCollector, lower_function, lower_generic_params, lower_trait, lower_type_alias, }, @@ -32,7 +32,7 @@ use crate::{ hir::{ExprId, PatId, generics::GenericParams}, item_tree::{FieldsShape, RawVisibility, visibility_from_ast}, src::HasSource, - type_ref::{TraitRef, TypeBound, TypeRefId}, + type_ref::{ConstRef, TraitRef, TypeBound, TypeRefId}, }; #[inline] @@ -43,8 +43,8 @@ fn as_name_opt(name: Option<ast::Name>) -> Name { #[derive(Debug, PartialEq, Eq)] pub struct StructSignature { pub name: Name, - pub generic_params: Arc<GenericParams>, - pub store: Arc<ExpressionStore>, + pub generic_params: GenericParams, + pub store: ExpressionStore, pub flags: StructFlags, pub shape: FieldsShape, } @@ -71,8 +71,18 @@ bitflags! { } } +#[salsa::tracked] impl StructSignature { - pub fn query(db: &dyn DefDatabase, id: StructId) -> (Arc<Self>, Arc<ExpressionStoreSourceMap>) { + #[salsa::tracked(returns(deref))] + pub fn of(db: &dyn DefDatabase, id: StructId) -> Arc<Self> { + Self::with_source_map(db, id).0.clone() + } + + #[salsa::tracked(returns(ref))] + pub fn with_source_map( + db: &dyn DefDatabase, + id: StructId, + ) -> (Arc<Self>, ExpressionStoreSourceMap) { let loc = id.lookup(db); let InFile { file_id, value: source } = loc.source(db); let attrs = AttrFlags::query(db, id.into()); @@ -115,10 +125,12 @@ impl StructSignature { shape, name: as_name_opt(source.name()), }), - Arc::new(source_map), + source_map, ) } +} +impl StructSignature { #[inline] pub fn repr(&self, db: &dyn DefDatabase, id: StructId) -> Option<ReprOptions> { if self.flags.contains(StructFlags::HAS_REPR) { @@ -141,13 +153,23 @@ fn adt_shape(adt_kind: ast::StructKind) -> FieldsShape { #[derive(Debug, PartialEq, Eq)] pub struct UnionSignature { pub name: Name, - pub generic_params: Arc<GenericParams>, - pub store: Arc<ExpressionStore>, + pub generic_params: GenericParams, + pub store: ExpressionStore, pub flags: StructFlags, } +#[salsa::tracked] impl UnionSignature { - pub fn query(db: &dyn DefDatabase, id: UnionId) -> (Arc<Self>, Arc<ExpressionStoreSourceMap>) { + #[salsa::tracked(returns(deref))] + pub fn of(db: &dyn DefDatabase, id: UnionId) -> Arc<Self> { + Self::with_source_map(db, id).0.clone() + } + + #[salsa::tracked(returns(ref))] + pub fn with_source_map( + db: &dyn DefDatabase, + id: UnionId, + ) -> (Arc<Self>, ExpressionStoreSourceMap) { let loc = id.lookup(db); let attrs = AttrFlags::query(db, id.into()); let mut flags = StructFlags::empty(); @@ -177,7 +199,7 @@ impl UnionSignature { flags, name: as_name_opt(source.name()), }), - Arc::new(source_map), + source_map, ) } } @@ -195,13 +217,23 @@ bitflags! { #[derive(Debug, PartialEq, Eq)] pub struct EnumSignature { pub name: Name, - pub generic_params: Arc<GenericParams>, - pub store: Arc<ExpressionStore>, + pub generic_params: GenericParams, + pub store: ExpressionStore, pub flags: EnumFlags, } +#[salsa::tracked] impl EnumSignature { - pub fn query(db: &dyn DefDatabase, id: EnumId) -> (Arc<Self>, Arc<ExpressionStoreSourceMap>) { + #[salsa::tracked(returns(deref))] + pub fn of(db: &dyn DefDatabase, id: EnumId) -> Arc<Self> { + Self::with_source_map(db, id).0.clone() + } + + #[salsa::tracked(returns(ref))] + pub fn with_source_map( + db: &dyn DefDatabase, + id: EnumId, + ) -> (Arc<Self>, ExpressionStoreSourceMap) { let loc = id.lookup(db); let attrs = AttrFlags::query(db, id.into()); let mut flags = EnumFlags::empty(); @@ -229,10 +261,12 @@ impl EnumSignature { flags, name: as_name_opt(source.name()), }), - Arc::new(source_map), + source_map, ) } +} +impl EnumSignature { pub fn variant_body_type(db: &dyn DefDatabase, id: EnumId) -> IntegerType { match AttrFlags::repr(db, id.into()) { Some(ReprOptions { int: Some(builtin), .. }) => builtin, @@ -256,14 +290,24 @@ bitflags::bitflags! { #[derive(Debug, PartialEq, Eq)] pub struct ConstSignature { pub name: Option<Name>, - // generic_params: Arc<GenericParams>, - pub store: Arc<ExpressionStore>, + // generic_params: GenericParams, + pub store: ExpressionStore, pub type_ref: TypeRefId, pub flags: ConstFlags, } +#[salsa::tracked] impl ConstSignature { - pub fn query(db: &dyn DefDatabase, id: ConstId) -> (Arc<Self>, Arc<ExpressionStoreSourceMap>) { + #[salsa::tracked(returns(deref))] + pub fn of(db: &dyn DefDatabase, id: ConstId) -> Arc<Self> { + Self::with_source_map(db, id).0.clone() + } + + #[salsa::tracked(returns(ref))] + pub fn with_source_map( + db: &dyn DefDatabase, + id: ConstId, + ) -> (Arc<Self>, ExpressionStoreSourceMap) { let loc = id.lookup(db); let module = loc.container.module(db); @@ -282,15 +326,17 @@ impl ConstSignature { ( Arc::new(ConstSignature { - store: Arc::new(store), + store, type_ref, flags, name: source.value.name().map(|it| it.as_name()), }), - Arc::new(source_map), + source_map, ) } +} +impl ConstSignature { pub fn has_body(&self) -> bool { self.flags.contains(ConstFlags::HAS_BODY) } @@ -312,13 +358,24 @@ bitflags::bitflags! { pub struct StaticSignature { pub name: Name, - // generic_params: Arc<GenericParams>, - pub store: Arc<ExpressionStore>, + // generic_params: GenericParams, + pub store: ExpressionStore, pub type_ref: TypeRefId, pub flags: StaticFlags, } + +#[salsa::tracked] impl StaticSignature { - pub fn query(db: &dyn DefDatabase, id: StaticId) -> (Arc<Self>, Arc<ExpressionStoreSourceMap>) { + #[salsa::tracked(returns(deref))] + pub fn of(db: &dyn DefDatabase, id: StaticId) -> Arc<Self> { + Self::with_source_map(db, id).0.clone() + } + + #[salsa::tracked(returns(ref))] + pub fn with_source_map( + db: &dyn DefDatabase, + id: StaticId, + ) -> (Arc<Self>, ExpressionStoreSourceMap) { let loc = id.lookup(db); let module = loc.container.module(db); @@ -351,12 +408,12 @@ impl StaticSignature { ( Arc::new(StaticSignature { - store: Arc::new(store), + store, type_ref, flags, name: as_name_opt(source.value.name()), }), - Arc::new(source_map), + source_map, ) } } @@ -372,15 +429,25 @@ bitflags::bitflags! { #[derive(Debug, PartialEq, Eq)] pub struct ImplSignature { - pub generic_params: Arc<GenericParams>, - pub store: Arc<ExpressionStore>, + pub generic_params: GenericParams, + pub store: ExpressionStore, pub self_ty: TypeRefId, pub target_trait: Option<TraitRef>, pub flags: ImplFlags, } +#[salsa::tracked] impl ImplSignature { - pub fn query(db: &dyn DefDatabase, id: ImplId) -> (Arc<Self>, Arc<ExpressionStoreSourceMap>) { + #[salsa::tracked(returns(deref))] + pub fn of(db: &dyn DefDatabase, id: ImplId) -> Arc<Self> { + Self::with_source_map(db, id).0.clone() + } + + #[salsa::tracked(returns(ref))] + pub fn with_source_map( + db: &dyn DefDatabase, + id: ImplId, + ) -> (Arc<Self>, ExpressionStoreSourceMap) { let loc = id.lookup(db); let mut flags = ImplFlags::empty(); @@ -399,17 +466,13 @@ impl ImplSignature { crate::expr_store::lower::lower_impl(db, loc.container, src, id); ( - Arc::new(ImplSignature { - store: Arc::new(store), - generic_params, - self_ty, - target_trait, - flags, - }), - Arc::new(source_map), + Arc::new(ImplSignature { store, generic_params, self_ty, target_trait, flags }), + source_map, ) } +} +impl ImplSignature { #[inline] pub fn is_negative(&self) -> bool { self.flags.contains(ImplFlags::NEGATIVE) @@ -439,13 +502,23 @@ bitflags::bitflags! { #[derive(Debug, PartialEq, Eq)] pub struct TraitSignature { pub name: Name, - pub generic_params: Arc<GenericParams>, - pub store: Arc<ExpressionStore>, + pub generic_params: GenericParams, + pub store: ExpressionStore, pub flags: TraitFlags, } +#[salsa::tracked] impl TraitSignature { - pub fn query(db: &dyn DefDatabase, id: TraitId) -> (Arc<Self>, Arc<ExpressionStoreSourceMap>) { + #[salsa::tracked(returns(deref))] + pub fn of(db: &dyn DefDatabase, id: TraitId) -> Arc<Self> { + Self::with_source_map(db, id).0.clone() + } + + #[salsa::tracked(returns(ref))] + pub fn with_source_map( + db: &dyn DefDatabase, + id: TraitId, + ) -> (Arc<Self>, ExpressionStoreSourceMap) { let loc = id.lookup(db); let mut flags = TraitFlags::empty(); @@ -483,10 +556,7 @@ impl TraitSignature { let name = as_name_opt(source.value.name()); let (store, source_map, generic_params) = lower_trait(db, loc.container, source, id); - ( - Arc::new(TraitSignature { store: Arc::new(store), generic_params, flags, name }), - Arc::new(source_map), - ) + (Arc::new(TraitSignature { store, generic_params, flags, name }), source_map) } } @@ -516,19 +586,26 @@ bitflags! { #[derive(Debug, PartialEq, Eq)] pub struct FunctionSignature { pub name: Name, - pub generic_params: Arc<GenericParams>, - pub store: Arc<ExpressionStore>, + pub generic_params: GenericParams, + pub store: ExpressionStore, pub params: Box<[TypeRefId]>, pub ret_type: Option<TypeRefId>, pub abi: Option<Symbol>, pub flags: FnFlags, } +#[salsa::tracked] impl FunctionSignature { - pub fn query( + #[salsa::tracked(returns(deref))] + pub fn of(db: &dyn DefDatabase, id: FunctionId) -> Arc<Self> { + Self::with_source_map(db, id).0.clone() + } + + #[salsa::tracked(returns(ref))] + pub fn with_source_map( db: &dyn DefDatabase, id: FunctionId, - ) -> (Arc<Self>, Arc<ExpressionStoreSourceMap>) { + ) -> (Arc<Self>, ExpressionStoreSourceMap) { let loc = id.lookup(db); let module = loc.container.module(db); @@ -589,17 +666,19 @@ impl FunctionSignature { ( Arc::new(FunctionSignature { generic_params, - store: Arc::new(store), + store, params, ret_type, abi, flags, name, }), - Arc::new(source_map), + source_map, ) } +} +impl FunctionSignature { pub fn has_body(&self) -> bool { self.flags.contains(FnFlags::HAS_BODY) } @@ -656,7 +735,7 @@ impl FunctionSignature { } pub fn is_intrinsic(db: &dyn DefDatabase, id: FunctionId) -> bool { - let data = db.function_signature(id); + let data = FunctionSignature::of(db, id); data.flags.contains(FnFlags::RUSTC_INTRINSIC) // Keep this around for a bit until extern "rustc-intrinsic" abis are no longer used || match &data.abi { @@ -683,18 +762,25 @@ bitflags! { #[derive(Debug, PartialEq, Eq)] pub struct TypeAliasSignature { pub name: Name, - pub generic_params: Arc<GenericParams>, - pub store: Arc<ExpressionStore>, + pub generic_params: GenericParams, + pub store: ExpressionStore, pub bounds: Box<[TypeBound]>, pub ty: Option<TypeRefId>, pub flags: TypeAliasFlags, } +#[salsa::tracked] impl TypeAliasSignature { - pub fn query( + #[salsa::tracked(returns(deref))] + pub fn of(db: &dyn DefDatabase, id: TypeAliasId) -> Arc<Self> { + Self::with_source_map(db, id).0.clone() + } + + #[salsa::tracked(returns(ref))] + pub fn with_source_map( db: &dyn DefDatabase, id: TypeAliasId, - ) -> (Arc<Self>, Arc<ExpressionStoreSourceMap>) { + ) -> (Arc<Self>, ExpressionStoreSourceMap) { let loc = id.lookup(db); let mut flags = TypeAliasFlags::empty(); @@ -714,28 +800,21 @@ impl TypeAliasSignature { lower_type_alias(db, loc.container.module(db), source, id); ( - Arc::new(TypeAliasSignature { - store: Arc::new(store), - generic_params, - flags, - bounds, - name, - ty, - }), - Arc::new(source_map), + Arc::new(TypeAliasSignature { store, generic_params, flags, bounds, name, ty }), + source_map, ) } } #[derive(Debug, PartialEq, Eq)] pub struct FunctionBody { - pub store: Arc<ExpressionStore>, + pub store: ExpressionStore, pub parameters: Box<[PatId]>, } #[derive(Debug, PartialEq, Eq)] pub struct SimpleBody { - pub store: Arc<ExpressionStore>, + pub store: ExpressionStore, } pub type StaticBody = SimpleBody; pub type ConstBody = SimpleBody; @@ -743,7 +822,7 @@ pub type EnumVariantBody = SimpleBody; #[derive(Debug, PartialEq, Eq)] pub struct VariantFieldsBody { - pub store: Arc<ExpressionStore>, + pub store: ExpressionStore, pub fields: Box<[Option<ExprId>]>, } @@ -754,7 +833,7 @@ pub struct FieldData { pub type_ref: TypeRefId, pub visibility: RawVisibility, pub is_unsafe: bool, - pub default_value: Option<ExprId>, + pub default_value: Option<ConstRef>, } pub type LocalFieldId = Idx<FieldData>; @@ -762,17 +841,17 @@ pub type LocalFieldId = Idx<FieldData>; #[derive(Debug, Clone, PartialEq, Eq)] pub struct VariantFields { fields: Arena<FieldData>, - pub store: Arc<ExpressionStore>, + pub store: ExpressionStore, pub shape: FieldsShape, } #[salsa::tracked] impl VariantFields { - #[salsa::tracked(returns(clone))] - pub(crate) fn query( + #[salsa::tracked(returns(ref))] + pub fn with_source_map( db: &dyn DefDatabase, id: VariantId, - ) -> (Arc<Self>, Arc<ExpressionStoreSourceMap>) { + ) -> (Arc<Self>, ExpressionStoreSourceMap) { let (shape, result) = match id { VariantId::EnumVariantId(id) => { let loc = id.lookup(db); @@ -809,20 +888,26 @@ impl VariantFields { } }; match result { - Some((fields, store, source_map)) => ( - Arc::new(VariantFields { fields, store: Arc::new(store), shape }), - Arc::new(source_map), - ), + Some((fields, store, source_map)) => { + (Arc::new(VariantFields { fields, store, shape }), source_map) + } None => { - let (store, source_map) = ExpressionStore::empty_singleton(); - (Arc::new(VariantFields { fields: Arena::default(), store, shape }), source_map) + let source_map = ExpressionStoreSourceMap::default(); + ( + Arc::new(VariantFields { + fields: Arena::default(), + store: ExpressionStoreBuilder::default().finish().0, + shape, + }), + source_map, + ) } } } #[salsa::tracked(returns(deref))] - pub(crate) fn firewall(db: &dyn DefDatabase, id: VariantId) -> Arc<Self> { - Self::query(db, id).0 + pub fn of(db: &dyn DefDatabase, id: VariantId) -> Arc<Self> { + Self::with_source_map(db, id).0.clone() } } @@ -873,7 +958,7 @@ fn lower_fields<Field: ast::HasAttrs + ast::HasVisibility>( override_visibility: Option<Option<ast::Visibility>>, ) -> Option<(Arena<FieldData>, ExpressionStore, ExpressionStoreSourceMap)> { let cfg_options = module.krate(db).cfg_options(db); - let mut col = ExprCollector::new(db, module, fields.file_id); + let mut col = ExprCollector::signature(db, module, fields.file_id); let override_visibility = override_visibility.map(|vis| { LazyCell::new(|| { let span_map = db.span_map(fields.file_id); @@ -907,9 +992,9 @@ fn lower_fields<Field: ast::HasAttrs + ast::HasVisibility>( // Check if field has default value (only for record fields) let default_value = ast::RecordField::cast(field.syntax().clone()) - .and_then(|rf| rf.eq_token().is_some().then_some(rf.expr())) + .and_then(|rf| rf.eq_token().is_some().then_some(rf.default_val())) .flatten() - .map(|expr| col.collect_expr_opt(Some(expr))); + .map(|expr| col.lower_const_arg(expr)); arena.alloc(FieldData { name, type_ref, visibility, is_unsafe, default_value }); idx += 1; @@ -1014,9 +1099,9 @@ impl EnumVariants { } // The outer if condition is whether this variant has const ctor or not if !matches!(variant.shape, FieldsShape::Unit) { - let body = db.body(v.into()); + let body = Body::of(db, v.into()); // A variant with explicit discriminant - if !matches!(body[body.body_expr], crate::hir::Expr::Missing) { + if !matches!(body[body.root_expr()], crate::hir::Expr::Missing) { return false; } } diff --git a/crates/hir-def/src/src.rs b/crates/hir-def/src/src.rs index 6fe016f1e6..e33fd95908 100644 --- a/crates/hir-def/src/src.rs +++ b/crates/hir-def/src/src.rs @@ -7,7 +7,7 @@ use syntax::{AstNode, AstPtr, ast}; use crate::{ AstIdLoc, GenericDefId, LocalFieldId, LocalLifetimeParamId, LocalTypeOrConstParamId, Lookup, - UseId, VariantId, attrs::AttrFlags, db::DefDatabase, + UseId, VariantId, attrs::AttrFlags, db::DefDatabase, hir::generics::GenericParams, }; pub trait HasSource { @@ -76,7 +76,7 @@ impl HasChildSource<LocalTypeOrConstParamId> for GenericDefId { &self, db: &dyn DefDatabase, ) -> InFile<ArenaMap<LocalTypeOrConstParamId, Self::Value>> { - let generic_params = db.generic_params(*self); + let generic_params = GenericParams::of(db, *self); let mut idx_iter = generic_params.iter_type_or_consts().map(|(idx, _)| idx); let (file_id, generic_params_list) = self.file_id_and_params_of(db); @@ -110,7 +110,7 @@ impl HasChildSource<LocalLifetimeParamId> for GenericDefId { &self, db: &dyn DefDatabase, ) -> InFile<ArenaMap<LocalLifetimeParamId, Self::Value>> { - let generic_params = db.generic_params(*self); + let generic_params = GenericParams::of(db, *self); let idx_iter = generic_params.iter_lt().map(|(idx, _)| idx); let (file_id, generic_params_list) = self.file_id_and_params_of(db); diff --git a/crates/hir-def/src/test_db.rs b/crates/hir-def/src/test_db.rs index e8377fde49..0d260279f9 100644 --- a/crates/hir-def/src/test_db.rs +++ b/crates/hir-def/src/test_db.rs @@ -15,6 +15,7 @@ use triomphe::Arc; use crate::{ Lookup, ModuleDefId, ModuleId, db::DefDatabase, + expr_store::{Body, scope::ExprScopes}, nameres::{DefMap, ModuleSource, block_def_map, crate_def_map}, src::HasSource, }; @@ -284,8 +285,8 @@ impl TestDB { // Find the innermost block expression that has a `DefMap`. let (def_with_body, file_id) = fn_def?; let def_with_body = def_with_body.into(); - let source_map = self.body_with_source_map(def_with_body).1; - let scopes = self.expr_scopes(def_with_body); + let source_map = &Body::with_source_map(self, def_with_body).1; + let scopes = ExprScopes::body_expr_scopes(self, def_with_body); let root_syntax_node = self.parse(file_id).syntax_node(); let scope_iter = diff --git a/crates/hir-def/src/visibility.rs b/crates/hir-def/src/visibility.rs index a1645de6ec..81a61ec20f 100644 --- a/crates/hir-def/src/visibility.rs +++ b/crates/hir-def/src/visibility.rs @@ -6,11 +6,11 @@ use base_db::Crate; use hir_expand::{InFile, Lookup}; use la_arena::ArenaMap; use syntax::ast::{self, HasVisibility}; -use triomphe::Arc; use crate::{ AssocItemId, HasModule, ItemContainerId, LocalFieldId, ModuleId, TraitId, VariantId, - db::DefDatabase, nameres::DefMap, resolver::HasResolver, src::HasSource, + db::DefDatabase, nameres::DefMap, resolver::HasResolver, signatures::VariantFields, + src::HasSource, }; pub use crate::item_tree::{RawVisibility, VisibilityExplicitness}; @@ -146,7 +146,7 @@ impl Visibility { /// Returns the most permissive visibility of `self` and `other`. /// - /// If there is no subset relation between `self` and `other`, returns `None` (ie. they're only + /// If there is no subset relation between `self` and `other`, returns `None` (i.e. they're only /// visible in unrelated modules). pub(crate) fn max( self, @@ -212,7 +212,7 @@ impl Visibility { /// Returns the least permissive visibility of `self` and `other`. /// - /// If there is no subset relation between `self` and `other`, returns `None` (ie. they're only + /// If there is no subset relation between `self` and `other`, returns `None` (i.e. they're only /// visible in unrelated modules). pub(crate) fn min( self, @@ -234,7 +234,7 @@ impl Visibility { if mod_.krate(db) == krate { Some(Visibility::Module(mod_, exp)) } else { None } } (Visibility::Module(mod_a, expl_a), Visibility::Module(mod_b, expl_b)) => { - if mod_a.krate(db) != mod_b.krate(db) { + if mod_a == mod_b { // Most module visibilities are `pub(self)`, and assuming no errors // this will be the common and thus fast path. return Some(Visibility::Module( @@ -277,23 +277,26 @@ impl Visibility { } } -/// Resolve visibility of all specific fields of a struct or union variant. -pub(crate) fn field_visibilities_query( - db: &dyn DefDatabase, - variant_id: VariantId, -) -> Arc<ArenaMap<LocalFieldId, Visibility>> { - let variant_fields = variant_id.fields(db); - let fields = variant_fields.fields(); - if fields.is_empty() { - return Arc::default(); - } - let resolver = variant_id.module(db).resolver(db); - let mut res = ArenaMap::default(); - for (field_id, field_data) in fields.iter() { - res.insert(field_id, Visibility::resolve(db, &resolver, &field_data.visibility)); +#[salsa::tracked] +impl VariantFields { + /// Resolve visibility of all specific fields of a struct or union variant. + #[salsa::tracked(returns(ref))] + pub fn field_visibilities( + db: &dyn DefDatabase, + variant_id: VariantId, + ) -> ArenaMap<LocalFieldId, Visibility> { + let variant_fields = variant_id.fields(db); + let fields = variant_fields.fields(); + if fields.is_empty() { + return ArenaMap::default(); + } + let resolver = variant_id.module(db).resolver(db); + let mut res = ArenaMap::with_capacity(fields.len()); + for (field_id, field_data) in fields.iter() { + res.insert(field_id, Visibility::resolve(db, &resolver, &field_data.visibility)); + } + res } - res.shrink_to_fit(); - Arc::new(res) } pub fn visibility_from_ast( diff --git a/crates/hir-expand/src/builtin/fn_macro.rs b/crates/hir-expand/src/builtin/fn_macro.rs index 6e4b96b050..b3572a1cef 100644 --- a/crates/hir-expand/src/builtin/fn_macro.rs +++ b/crates/hir-expand/src/builtin/fn_macro.rs @@ -5,10 +5,7 @@ use std::borrow::Cow; use base_db::AnchoredPath; use cfg::CfgExpr; use either::Either; -use intern::{ - Symbol, - sym::{self}, -}; +use intern::{Symbol, sym}; use itertools::Itertools; use mbe::{DelimiterKind, expect_fragment}; use span::{Edition, FileId, Span}; @@ -384,16 +381,40 @@ fn cfg_select_expand( ); } } - let expand_to_if_active = match iter.next() { - Some(tt::TtElement::Subtree(_, tt)) => tt.remaining(), - _ => { + let expand_to_if_active = match iter.peek() { + Some(tt::TtElement::Subtree(sub, tt)) if sub.delimiter.kind == DelimiterKind::Brace => { + iter.next(); + tt.remaining() + } + None | Some(TtElement::Leaf(tt::Leaf::Punct(tt::Punct { char: ',', .. }))) => { let err_span = iter.peek().map(|it| it.first_span()).unwrap_or(span); + iter.next(); return ExpandResult::new( tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), ExpandError::other(err_span, "expected a token tree after `=>`"), ); } + Some(_) => { + let expr = expect_fragment( + db, + &mut iter, + parser::PrefixEntryPoint::Expr, + tt.top_subtree().delimiter.delim_span(), + ); + if let Some(err) = expr.err { + return ExpandResult::new( + tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), + err.into(), + ); + } + expr.value + } }; + if let Some(TtElement::Leaf(tt::Leaf::Punct(p))) = iter.peek() + && p.char == ',' + { + iter.next(); + } if expand_to.is_none() && active { expand_to = Some(expand_to_if_active); @@ -753,7 +774,7 @@ fn relative_file( if res == call_site && !allow_recursion { Err(ExpandError::other(err_span, format!("recursive inclusion of `{path_str}`"))) } else { - Ok(EditionedFileId::new(db, res, lookup.krate.data(db).edition, lookup.krate)) + Ok(EditionedFileId::new(db, res, lookup.krate.data(db).edition)) } } @@ -830,15 +851,18 @@ fn include_bytes_expand( span: Span, ) -> ExpandResult<tt::TopSubtree> { // FIXME: actually read the file here if the user asked for macro expansion - let res = tt::TopSubtree::invisible_from_leaves( + let underscore = sym::underscore; + let zero = tt::Literal { + text_and_suffix: sym::_0_u8, span, - [tt::Leaf::Literal(tt::Literal { - text_and_suffix: Symbol::empty(), - span, - kind: tt::LitKind::ByteStrRaw(1), - suffix_len: 0, - })], - ); + kind: tt::LitKind::Integer, + suffix_len: 3, + }; + // We don't use a real length since we can't know the file length, so we use an underscore + // to infer it. + let res = quote! {span => + &[#zero; #underscore] + }; ExpandResult::ok(res) } diff --git a/crates/hir-expand/src/db.rs b/crates/hir-expand/src/db.rs index 51767f87ff..020731cf9a 100644 --- a/crates/hir-expand/src/db.rs +++ b/crates/hir-expand/src/db.rs @@ -58,8 +58,8 @@ pub trait ExpandDatabase: RootQueryDb { fn proc_macros_for_crate(&self, krate: Crate) -> Option<Arc<CrateProcMacros>>; #[salsa::invoke(ast_id_map)] - #[salsa::lru(1024)] - fn ast_id_map(&self, file_id: HirFileId) -> Arc<AstIdMap>; + #[salsa::transparent] + fn ast_id_map(&self, file_id: HirFileId) -> &AstIdMap; #[salsa::transparent] fn resolve_span(&self, span: Span) -> FileRange; @@ -162,7 +162,7 @@ fn syntax_context(db: &dyn ExpandDatabase, file: HirFileId, edition: Edition) -> } fn resolve_span(db: &dyn ExpandDatabase, Span { range, anchor, ctx: _ }: Span) -> FileRange { - let file_id = EditionedFileId::from_span_guess_origin(db, anchor.file_id); + let file_id = EditionedFileId::from_span_file_id(db, anchor.file_id); let anchor_offset = db.ast_id_map(file_id.into()).get_erased(anchor.ast_id).text_range().start(); FileRange { file_id, range: range + anchor_offset } @@ -334,8 +334,9 @@ pub fn expand_speculative( Some((node.syntax_node(), token)) } -fn ast_id_map(db: &dyn ExpandDatabase, file_id: HirFileId) -> triomphe::Arc<AstIdMap> { - triomphe::Arc::new(AstIdMap::from_source(&db.parse_or_expand(file_id))) +#[salsa::tracked(lru = 1024, returns(ref))] +fn ast_id_map(db: &dyn ExpandDatabase, file_id: HirFileId) -> AstIdMap { + AstIdMap::from_source(&db.parse_or_expand(file_id)) } /// Main public API -- parses a hir file, not caring whether it's a real diff --git a/crates/hir-expand/src/fixup.rs b/crates/hir-expand/src/fixup.rs index 92ddd7fa8b..424655ed65 100644 --- a/crates/hir-expand/src/fixup.rs +++ b/crates/hir-expand/src/fixup.rs @@ -208,7 +208,6 @@ pub(crate) fn fixup_syntax( ]); } }, - // FIXME: foo:: ast::MatchExpr(it) => { if it.expr().is_none() { let match_token = match it.match_token() { diff --git a/crates/hir-expand/src/inert_attr_macro.rs b/crates/hir-expand/src/inert_attr_macro.rs index 6dec2c5b32..53b624d9a6 100644 --- a/crates/hir-expand/src/inert_attr_macro.rs +++ b/crates/hir-expand/src/inert_attr_macro.rs @@ -124,8 +124,6 @@ pub const INERT_ATTRIBUTES: &[BuiltinAttribute] = &[ should_panic, Normal, template!(Word, List: r#"expected = "reason""#, NameValueStr: "reason"), FutureWarnFollowing, ), - // FIXME(Centril): This can be used on stable but shouldn't. - ungated!(reexport_test_harness_main, CrateLevel, template!(NameValueStr: "name"), ErrorFollowing), // Macros: ungated!(automatically_derived, Normal, template!(Word), WarnFollowing), @@ -264,6 +262,13 @@ pub const INERT_ATTRIBUTES: &[BuiltinAttribute] = &[ test_runner, CrateLevel, template!(List: "path"), ErrorFollowing, custom_test_frameworks, "custom test frameworks are an unstable feature", ), + + gated!( + reexport_test_harness_main, CrateLevel, template!(NameValueStr: "name"), + ErrorFollowing, custom_test_frameworks, + "custom test frameworks are an unstable feature", + ), + // RFC #1268 gated!( marker, Normal, template!(Word), WarnFollowing, @only_local: true, diff --git a/crates/hir-expand/src/lib.rs b/crates/hir-expand/src/lib.rs index 05541e782e..4b2c75ed38 100644 --- a/crates/hir-expand/src/lib.rs +++ b/crates/hir-expand/src/lib.rs @@ -386,7 +386,7 @@ impl MacroCallKind { impl HirFileId { pub fn edition(self, db: &dyn ExpandDatabase) -> Edition { match self { - HirFileId::FileId(file_id) => file_id.editioned_file_id(db).edition(), + HirFileId::FileId(file_id) => file_id.edition(db), HirFileId::MacroFile(m) => db.lookup_intern_macro_call(m).def.edition, } } @@ -1118,14 +1118,6 @@ impl HirFileId { HirFileId::MacroFile(_) => None, } } - - #[inline] - pub fn krate(self, db: &dyn ExpandDatabase) -> Crate { - match self { - HirFileId::FileId(it) => it.krate(db), - HirFileId::MacroFile(it) => it.loc(db).krate, - } - } } impl PartialEq<EditionedFileId> for HirFileId { diff --git a/crates/hir-expand/src/span_map.rs b/crates/hir-expand/src/span_map.rs index 586b815294..71d0b880ca 100644 --- a/crates/hir-expand/src/span_map.rs +++ b/crates/hir-expand/src/span_map.rs @@ -135,7 +135,7 @@ pub(crate) fn real_span_map( }); Arc::new(RealSpanMap::from_file( - editioned_file_id.editioned_file_id(db), + editioned_file_id.span_file_id(db), pairs.into_boxed_slice(), tree.syntax().text_range().end(), )) diff --git a/crates/hir-ty/src/builtin_derive.rs b/crates/hir-ty/src/builtin_derive.rs index f3e67d01e5..92629b7a05 100644 --- a/crates/hir-ty/src/builtin_derive.rs +++ b/crates/hir-ty/src/builtin_derive.rs @@ -35,6 +35,24 @@ fn coerce_pointee_new_type_param(trait_id: TraitId) -> TypeParamId { }) } +fn trait_args(trait_: BuiltinDeriveImplTrait, self_ty: Ty<'_>) -> GenericArgs<'_> { + match trait_ { + BuiltinDeriveImplTrait::Copy + | BuiltinDeriveImplTrait::Clone + | BuiltinDeriveImplTrait::Default + | BuiltinDeriveImplTrait::Debug + | BuiltinDeriveImplTrait::Hash + | BuiltinDeriveImplTrait::Eq + | BuiltinDeriveImplTrait::Ord => GenericArgs::new_from_slice(&[self_ty.into()]), + BuiltinDeriveImplTrait::PartialOrd | BuiltinDeriveImplTrait::PartialEq => { + GenericArgs::new_from_slice(&[self_ty.into(), self_ty.into()]) + } + BuiltinDeriveImplTrait::CoerceUnsized | BuiltinDeriveImplTrait::DispatchFromDyn => { + panic!("`CoerceUnsized` and `DispatchFromDyn` have special generics") + } + } +} + pub(crate) fn generics_of<'db>(interner: DbInterner<'db>, id: BuiltinDeriveImplId) -> Generics { let db = interner.db; let loc = id.loc(db); @@ -62,7 +80,7 @@ pub(crate) fn generics_of<'db>(interner: DbInterner<'db>, id: BuiltinDeriveImplI pub fn generic_params_count(db: &dyn HirDatabase, id: BuiltinDeriveImplId) -> usize { let loc = id.loc(db); - let adt_params = GenericParams::new(db, loc.adt.into()); + let adt_params = GenericParams::of(db, loc.adt.into()); let extra_params_count = match loc.trait_ { BuiltinDeriveImplTrait::Copy | BuiltinDeriveImplTrait::Clone @@ -95,29 +113,27 @@ pub fn impl_trait<'db>( | BuiltinDeriveImplTrait::Debug | BuiltinDeriveImplTrait::Hash | BuiltinDeriveImplTrait::Ord - | BuiltinDeriveImplTrait::Eq => { + | BuiltinDeriveImplTrait::Eq + | BuiltinDeriveImplTrait::PartialOrd + | BuiltinDeriveImplTrait::PartialEq => { let self_ty = Ty::new_adt( interner, loc.adt, GenericArgs::identity_for_item(interner, loc.adt.into()), ); - EarlyBinder::bind(TraitRef::new(interner, trait_id.into(), [self_ty])) - } - BuiltinDeriveImplTrait::PartialOrd | BuiltinDeriveImplTrait::PartialEq => { - let self_ty = Ty::new_adt( + EarlyBinder::bind(TraitRef::new_from_args( interner, - loc.adt, - GenericArgs::identity_for_item(interner, loc.adt.into()), - ); - EarlyBinder::bind(TraitRef::new(interner, trait_id.into(), [self_ty, self_ty])) + trait_id.into(), + trait_args(loc.trait_, self_ty), + )) } BuiltinDeriveImplTrait::CoerceUnsized | BuiltinDeriveImplTrait::DispatchFromDyn => { - let generic_params = GenericParams::new(db, loc.adt.into()); + let generic_params = GenericParams::of(db, loc.adt.into()); let interner = DbInterner::new_no_crate(db); let args = GenericArgs::identity_for_item(interner, loc.adt.into()); let self_ty = Ty::new_adt(interner, loc.adt, args); let Some((pointee_param_idx, _, new_param_ty)) = - coerce_pointee_params(interner, loc, &generic_params, trait_id) + coerce_pointee_params(interner, loc, generic_params, trait_id) else { // Malformed derive. return EarlyBinder::bind(TraitRef::new( @@ -133,10 +149,10 @@ pub fn impl_trait<'db>( } } -#[salsa::tracked(returns(ref), unsafe(non_update_types))] +#[salsa::tracked(returns(ref))] pub fn predicates<'db>(db: &'db dyn HirDatabase, impl_: BuiltinDeriveImplId) -> GenericPredicates { let loc = impl_.loc(db); - let generic_params = GenericParams::new(db, loc.adt.into()); + let generic_params = GenericParams::of(db, loc.adt.into()); let interner = DbInterner::new_with(db, loc.module(db).krate(db)); let adt_predicates = GenericPredicates::query(db, loc.adt.into()); let trait_id = loc @@ -152,7 +168,7 @@ pub fn predicates<'db>(db: &'db dyn HirDatabase, impl_: BuiltinDeriveImplId) -> | BuiltinDeriveImplTrait::PartialOrd | BuiltinDeriveImplTrait::Eq | BuiltinDeriveImplTrait::PartialEq => { - simple_trait_predicates(interner, loc, &generic_params, adt_predicates, trait_id) + simple_trait_predicates(interner, loc, generic_params, adt_predicates, trait_id) } BuiltinDeriveImplTrait::Default => { if matches!(loc.adt, AdtId::EnumId(_)) { @@ -162,12 +178,12 @@ pub fn predicates<'db>(db: &'db dyn HirDatabase, impl_: BuiltinDeriveImplId) -> .store(), )) } else { - simple_trait_predicates(interner, loc, &generic_params, adt_predicates, trait_id) + simple_trait_predicates(interner, loc, generic_params, adt_predicates, trait_id) } } BuiltinDeriveImplTrait::CoerceUnsized | BuiltinDeriveImplTrait::DispatchFromDyn => { let Some((pointee_param_idx, pointee_param_id, new_param_ty)) = - coerce_pointee_params(interner, loc, &generic_params, trait_id) + coerce_pointee_params(interner, loc, generic_params, trait_id) else { // Malformed derive. return GenericPredicates::from_explicit_own_predicates(StoredEarlyBinder::bind( @@ -260,7 +276,8 @@ fn simple_trait_predicates<'db>( let param_idx = param_idx.into_raw().into_u32() + (generic_params.len_lifetimes() as u32); let param_ty = Ty::new_param(interner, param_id, param_idx); - let trait_ref = TraitRef::new(interner, trait_id.into(), [param_ty]); + let trait_args = trait_args(loc.trait_, param_ty); + let trait_ref = TraitRef::new_from_args(interner, trait_id.into(), trait_args); trait_ref.upcast(interner) }); let mut assoc_type_bounds = Vec::new(); @@ -270,12 +287,14 @@ fn simple_trait_predicates<'db>( &mut assoc_type_bounds, interner.db.field_types(id.into()), trait_id, + loc.trait_, ), AdtId::UnionId(id) => extend_assoc_type_bounds( interner, &mut assoc_type_bounds, interner.db.field_types(id.into()), trait_id, + loc.trait_, ), AdtId::EnumId(id) => { for &(variant_id, _, _) in &id.enum_variants(interner.db).variants { @@ -284,6 +303,7 @@ fn simple_trait_predicates<'db>( &mut assoc_type_bounds, interner.db.field_types(variant_id.into()), trait_id, + loc.trait_, ) } } @@ -305,12 +325,14 @@ fn extend_assoc_type_bounds<'db>( interner: DbInterner<'db>, assoc_type_bounds: &mut Vec<Clause<'db>>, fields: &ArenaMap<LocalFieldId, StoredEarlyBinder<StoredTy>>, - trait_: TraitId, + trait_id: TraitId, + trait_: BuiltinDeriveImplTrait, ) { struct ProjectionFinder<'a, 'db> { interner: DbInterner<'db>, assoc_type_bounds: &'a mut Vec<Clause<'db>>, - trait_: TraitId, + trait_id: TraitId, + trait_: BuiltinDeriveImplTrait, } impl<'db> TypeVisitor<DbInterner<'db>> for ProjectionFinder<'_, 'db> { @@ -319,7 +341,12 @@ fn extend_assoc_type_bounds<'db>( fn visit_ty(&mut self, t: Ty<'db>) -> Self::Result { if let TyKind::Alias(AliasTyKind::Projection, _) = t.kind() { self.assoc_type_bounds.push( - TraitRef::new(self.interner, self.trait_.into(), [t]).upcast(self.interner), + TraitRef::new_from_args( + self.interner, + self.trait_id.into(), + trait_args(self.trait_, t), + ) + .upcast(self.interner), ); } @@ -327,7 +354,7 @@ fn extend_assoc_type_bounds<'db>( } } - let mut visitor = ProjectionFinder { interner, assoc_type_bounds, trait_ }; + let mut visitor = ProjectionFinder { interner, assoc_type_bounds, trait_id, trait_ }; for (_, field) in fields.iter() { field.get().instantiate_identity().visit_with(&mut visitor); } @@ -488,10 +515,12 @@ struct MultiGenericParams<'a, T, #[pointee] U: ?Sized, const N: usize>(*const U) #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] struct Simple; -trait Trait {} +trait Trait { + type Assoc; +} #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -struct WithGenerics<'a, T: Trait, const N: usize>(&'a [T; N]); +struct WithGenerics<'a, T: Trait, const N: usize>(&'a [T; N], T::Assoc); "#, expect![[r#" @@ -514,41 +543,49 @@ struct WithGenerics<'a, T: Trait, const N: usize>(&'a [T; N]); Clause(Binder { value: ConstArgHasType(#2, usize), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Sized, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Debug, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. }): Debug, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Trait, polarity:Positive), bound_vars: [] }) Clause(Binder { value: ConstArgHasType(#2, usize), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Sized, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Clone, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. }): Clone, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Trait, polarity:Positive), bound_vars: [] }) Clause(Binder { value: ConstArgHasType(#2, usize), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Sized, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Copy, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. }): Copy, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Trait, polarity:Positive), bound_vars: [] }) Clause(Binder { value: ConstArgHasType(#2, usize), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Sized, polarity:Positive), bound_vars: [] }) - Clause(Binder { value: TraitPredicate(#1: PartialEq, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(#1: PartialEq<[#1]>, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. }): PartialEq<[Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. })]>, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Trait, polarity:Positive), bound_vars: [] }) Clause(Binder { value: ConstArgHasType(#2, usize), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Sized, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Eq, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. }): Eq, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Trait, polarity:Positive), bound_vars: [] }) Clause(Binder { value: ConstArgHasType(#2, usize), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Sized, polarity:Positive), bound_vars: [] }) - Clause(Binder { value: TraitPredicate(#1: PartialOrd, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(#1: PartialOrd<[#1]>, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. }): PartialOrd<[Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. })]>, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Trait, polarity:Positive), bound_vars: [] }) Clause(Binder { value: ConstArgHasType(#2, usize), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Sized, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Ord, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. }): Ord, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Trait, polarity:Positive), bound_vars: [] }) Clause(Binder { value: ConstArgHasType(#2, usize), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Sized, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Hash, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. }): Hash, polarity:Positive), bound_vars: [] }) "#]], ); diff --git a/crates/hir-ty/src/consteval.rs b/crates/hir-ty/src/consteval.rs index 5bc2446fdd..928396c63a 100644 --- a/crates/hir-ty/src/consteval.rs +++ b/crates/hir-ty/src/consteval.rs @@ -5,10 +5,11 @@ mod tests; use base_db::Crate; use hir_def::{ - ConstId, EnumVariantId, GeneralConstId, HasModule, StaticId, + ConstId, EnumVariantId, ExpressionStoreOwnerId, GeneralConstId, GenericDefId, HasModule, + StaticId, attrs::AttrFlags, builtin_type::{BuiltinInt, BuiltinType, BuiltinUint}, - expr_store::Body, + expr_store::{Body, ExpressionStore}, hir::{Expr, ExprId, Literal}, }; use hir_expand::Lookup; @@ -28,7 +29,7 @@ use crate::{ traits::StoredParamEnvAndCrate, }; -use super::mir::{interpret_mir, lower_to_mir, pad16}; +use super::mir::{interpret_mir, lower_body_to_mir, pad16}; pub fn unknown_const<'db>(_ty: Ty<'db>) -> Const<'db> { Const::new(DbInterner::conjure(), rustc_type_ir::ConstKind::Error(ErrorGuaranteed)) @@ -235,6 +236,7 @@ pub fn try_const_usize<'db>(db: &'db dyn HirDatabase, c: Const<'db>) -> Option<u let ec = db.const_eval_static(id).ok()?; try_const_usize(db, ec) } + GeneralConstId::AnonConstId(_) => None, }, ConstKind::Value(val) => Some(u128::from_le_bytes(pad16(&val.value.inner().memory, false))), ConstKind::Error(_) => None, @@ -258,6 +260,7 @@ pub fn try_const_isize<'db>(db: &'db dyn HirDatabase, c: &Const<'db>) -> Option< let ec = db.const_eval_static(id).ok()?; try_const_isize(db, &ec) } + GeneralConstId::AnonConstId(_) => None, }, ConstKind::Value(val) => Some(i128::from_le_bytes(pad16(&val.value.inner().memory, true))), ConstKind::Error(_) => None, @@ -271,9 +274,9 @@ pub(crate) fn const_eval_discriminant_variant( ) -> Result<i128, ConstEvalError> { let interner = DbInterner::new_no_crate(db); let def = variant_id.into(); - let body = db.body(def); + let body = Body::of(db, def); let loc = variant_id.lookup(db); - if matches!(body[body.body_expr], Expr::Missing) { + if matches!(body[body.root_expr()], Expr::Missing) { let prev_idx = loc.index.checked_sub(1); let value = match prev_idx { Some(prev_idx) => { @@ -292,7 +295,7 @@ pub(crate) fn const_eval_discriminant_variant( let mir_body = db.monomorphized_mir_body( def, GenericArgs::empty(interner).store(), - ParamEnvAndCrate { param_env: db.trait_environment_for_body(def), krate: def.krate(db) } + ParamEnvAndCrate { param_env: db.trait_environment(def.into()), krate: def.krate(db) } .store(), )?; let c = interpret_mir(db, mir_body, false, None)?.0?; @@ -309,23 +312,23 @@ pub(crate) fn const_eval_discriminant_variant( // and make this function private. See the fixme comment on `InferenceContext::resolve_all`. pub(crate) fn eval_to_const<'db>(expr: ExprId, ctx: &mut InferenceContext<'_, 'db>) -> Const<'db> { let infer = ctx.fixme_resolve_all_clone(); - fn has_closure(body: &Body, expr: ExprId) -> bool { - if matches!(body[expr], Expr::Closure { .. }) { + fn has_closure(store: &ExpressionStore, expr: ExprId) -> bool { + if matches!(store[expr], Expr::Closure { .. }) { return true; } let mut r = false; - body.walk_child_exprs(expr, |idx| r |= has_closure(body, idx)); + store.walk_child_exprs(expr, |idx| r |= has_closure(store, idx)); r } - if has_closure(ctx.body, expr) { + if has_closure(ctx.store, expr) { // Type checking clousres need an isolated body (See the above FIXME). Bail out early to prevent panic. return Const::error(ctx.interner()); } - if let Expr::Path(p) = &ctx.body[expr] { + if let Expr::Path(p) = &ctx.store[expr] { let mut ctx = TyLoweringContext::new( ctx.db, &ctx.resolver, - ctx.body, + ctx.store, ctx.generic_def, LifetimeElisionKind::Infer, ); @@ -333,7 +336,9 @@ pub(crate) fn eval_to_const<'db>(expr: ExprId, ctx: &mut InferenceContext<'_, 'd return c; } } - if let Ok(mir_body) = lower_to_mir(ctx.db, ctx.owner, ctx.body, &infer, expr) + if let Some(body_owner) = ctx.owner.as_def_with_body() + && let Ok(mir_body) = + lower_body_to_mir(ctx.db, body_owner, Body::of(ctx.db, body_owner), &infer, expr) && let Ok((Ok(result), _)) = interpret_mir(ctx.db, Arc::new(mir_body), true, None) { return result; @@ -370,8 +375,12 @@ pub(crate) fn const_eval<'db>( let body = db.monomorphized_mir_body( def.into(), subst, - ParamEnvAndCrate { param_env: db.trait_environment(def.into()), krate: def.krate(db) } - .store(), + ParamEnvAndCrate { + param_env: db + .trait_environment(ExpressionStoreOwnerId::from(GenericDefId::from(def))), + krate: def.krate(db), + } + .store(), )?; let c = interpret_mir(db, body, false, trait_env.as_ref().map(|env| env.as_ref()))?.0?; Ok(c.store()) @@ -407,7 +416,8 @@ pub(crate) fn const_eval_static<'db>( def.into(), GenericArgs::empty(interner).store(), ParamEnvAndCrate { - param_env: db.trait_environment_for_body(def.into()), + param_env: db + .trait_environment(ExpressionStoreOwnerId::from(GenericDefId::from(def))), krate: def.krate(db), } .store(), diff --git a/crates/hir-ty/src/consteval/tests.rs b/crates/hir-ty/src/consteval/tests.rs index 5f6bcb4a60..31cf86476f 100644 --- a/crates/hir-ty/src/consteval/tests.rs +++ b/crates/hir-ty/src/consteval/tests.rs @@ -1,5 +1,5 @@ use base_db::RootQueryDb; -use hir_def::db::DefDatabase; +use hir_def::signatures::ConstSignature; use hir_expand::EditionedFileId; use rustc_apfloat::{ Float, @@ -131,7 +131,11 @@ fn eval_goal(db: &TestDB, file_id: EditionedFileId) -> Result<Const<'_>, ConstEv .declarations() .find_map(|x| match x { hir_def::ModuleDefId::ConstId(x) => { - if db.const_signature(x).name.as_ref()?.display(db, file_id.edition(db)).to_string() + if ConstSignature::of(db, x) + .name + .as_ref()? + .display(db, file_id.edition(db)) + .to_string() == "GOAL" { Some(x) diff --git a/crates/hir-ty/src/db.rs b/crates/hir-ty/src/db.rs index 70474fc469..a0fb75397a 100644 --- a/crates/hir-ty/src/db.rs +++ b/crates/hir-ty/src/db.rs @@ -5,9 +5,9 @@ use base_db::{Crate, target::TargetLoadError}; use either::Either; use hir_def::{ AdtId, BuiltinDeriveImplId, CallableDefId, ConstId, ConstParamId, DefWithBodyId, EnumVariantId, - FunctionId, GenericDefId, ImplId, LifetimeParamId, LocalFieldId, StaticId, TraitId, - TypeAliasId, VariantId, builtin_derive::BuiltinDeriveImplMethod, db::DefDatabase, hir::ExprId, - layout::TargetDataLayout, + ExpressionStoreOwnerId, FunctionId, GenericDefId, ImplId, LifetimeParamId, LocalFieldId, + StaticId, TraitId, TypeAliasId, VariantId, builtin_derive::BuiltinDeriveImplMethod, + db::DefDatabase, hir::ExprId, layout::TargetDataLayout, }; use la_arena::ArenaMap; use salsa::plumbing::AsId; @@ -178,13 +178,9 @@ pub trait HirDatabase: DefDatabase + std::fmt::Debug { def: CallableDefId, ) -> EarlyBinder<'db, PolyFnSig<'db>>; - #[salsa::invoke(crate::lower::trait_environment_for_body_query)] - #[salsa::transparent] - fn trait_environment_for_body<'db>(&'db self, def: DefWithBodyId) -> ParamEnv<'db>; - #[salsa::invoke(crate::lower::trait_environment)] #[salsa::transparent] - fn trait_environment<'db>(&'db self, def: GenericDefId) -> ParamEnv<'db>; + fn trait_environment<'db>(&'db self, def: ExpressionStoreOwnerId) -> ParamEnv<'db>; #[salsa::invoke(crate::lower::generic_defaults_with_diagnostics_query)] #[salsa::cycle(cycle_result = crate::lower::generic_defaults_with_diagnostics_cycle_result)] @@ -240,7 +236,7 @@ pub struct InternedOpaqueTyId { } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct InternedClosure(pub DefWithBodyId, pub ExprId); +pub struct InternedClosure(pub ExpressionStoreOwnerId, pub ExprId); #[salsa_macros::interned(no_lifetime, debug, revisions = usize::MAX)] #[derive(PartialOrd, Ord)] @@ -249,7 +245,7 @@ pub struct InternedClosureId { } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct InternedCoroutine(pub DefWithBodyId, pub ExprId); +pub struct InternedCoroutine(pub ExpressionStoreOwnerId, pub ExprId); #[salsa_macros::interned(no_lifetime, debug, revisions = usize::MAX)] #[derive(PartialOrd, Ord)] diff --git a/crates/hir-ty/src/diagnostics/decl_check.rs b/crates/hir-ty/src/diagnostics/decl_check.rs index 29da1b0c51..89d8c0e91d 100644 --- a/crates/hir-ty/src/diagnostics/decl_check.rs +++ b/crates/hir-ty/src/diagnostics/decl_check.rs @@ -17,8 +17,17 @@ use std::fmt; use hir_def::{ AdtId, ConstId, EnumId, EnumVariantId, FunctionId, HasModule, ItemContainerId, Lookup, - ModuleDefId, ModuleId, StaticId, StructId, TraitId, TypeAliasId, attrs::AttrFlags, - db::DefDatabase, hir::Pat, item_tree::FieldsShape, signatures::StaticFlags, src::HasSource, + ModuleDefId, ModuleId, StaticId, StructId, TraitId, TypeAliasId, UnionId, + attrs::AttrFlags, + db::DefDatabase, + expr_store::Body, + hir::Pat, + item_tree::FieldsShape, + signatures::{ + ConstSignature, EnumSignature, FunctionSignature, StaticFlags, StaticSignature, + StructSignature, TraitSignature, TypeAliasSignature, UnionSignature, + }, + src::HasSource, }; use hir_expand::{ HirFileId, @@ -77,6 +86,7 @@ pub enum IdentType { Structure, Trait, TypeAlias, + Union, Variable, Variant, } @@ -94,6 +104,7 @@ impl fmt::Display for IdentType { IdentType::Structure => "Structure", IdentType::Trait => "Trait", IdentType::TypeAlias => "Type alias", + IdentType::Union => "Union", IdentType::Variable => "Variable", IdentType::Variant => "Variant", }; @@ -146,9 +157,7 @@ impl<'a> DeclValidator<'a> { match adt { AdtId::StructId(struct_id) => self.validate_struct(struct_id), AdtId::EnumId(enum_id) => self.validate_enum(enum_id), - AdtId::UnionId(_) => { - // FIXME: Unions aren't yet supported by this validator. - } + AdtId::UnionId(union_id) => self.validate_union(union_id), } } @@ -178,7 +187,7 @@ impl<'a> DeclValidator<'a> { fn validate_trait(&mut self, trait_id: TraitId) { // Check the trait name. - let data = self.db.trait_signature(trait_id); + let data = TraitSignature::of(self.db, trait_id); self.create_incorrect_case_diagnostic_for_item_name( trait_id, &data.name, @@ -197,7 +206,7 @@ impl<'a> DeclValidator<'a> { // Check the function name. // Skipped if function is an associated item of a trait implementation. if !self.is_trait_impl_container(container) { - let data = self.db.function_signature(func); + let data = FunctionSignature::of(self.db, func); // Don't run the lint on extern "[not Rust]" fn items with the // #[no_mangle] attribute. @@ -223,7 +232,7 @@ impl<'a> DeclValidator<'a> { /// Check incorrect names for patterns inside the function body. /// This includes function parameters except for trait implementation associated functions. fn validate_func_body(&mut self, func: FunctionId) { - let body = self.db.body(func.into()); + let body = Body::of(self.db, func.into()); let edition = self.edition(func); let mut pats_replacements = body .pats() @@ -250,7 +259,7 @@ impl<'a> DeclValidator<'a> { return; } - let source_map = self.db.body_with_source_map(func.into()).1; + let source_map = &Body::with_source_map(self.db, func.into()).1; for (id, replacement) in pats_replacements { let Ok(source_ptr) = source_map.pat_syntax(id) else { continue; @@ -292,7 +301,7 @@ impl<'a> DeclValidator<'a> { fn validate_struct(&mut self, struct_id: StructId) { // Check the structure name. - let data = self.db.struct_signature(struct_id); + let data = StructSignature::of(self.db, struct_id); // rustc implementation excuses repr(C) since C structs predominantly don't // use camel case. @@ -383,9 +392,97 @@ impl<'a> DeclValidator<'a> { } } + fn validate_union(&mut self, union_id: UnionId) { + // Check the union name. + let data = UnionSignature::of(self.db, union_id); + + // rustc implementation excuses repr(C) since C unions predominantly don't + // use camel case. + let has_repr_c = AttrFlags::repr(self.db, union_id.into()).is_some_and(|repr| repr.c()); + if !has_repr_c { + self.create_incorrect_case_diagnostic_for_item_name( + union_id, + &data.name, + CaseType::UpperCamelCase, + IdentType::Union, + ); + } + + // Check the field names. + self.validate_union_fields(union_id); + } + + /// Check incorrect names for union fields. + fn validate_union_fields(&mut self, union_id: UnionId) { + let data = union_id.fields(self.db); + let edition = self.edition(union_id); + let mut union_fields_replacements = data + .fields() + .iter() + .filter_map(|(_, field)| { + to_lower_snake_case(&field.name.display_no_db(edition).to_smolstr()).map( + |new_name| Replacement { + current_name: field.name.clone(), + suggested_text: new_name, + expected_case: CaseType::LowerSnakeCase, + }, + ) + }) + .peekable(); + + // XXX: Only look at sources if we do have incorrect names. + if union_fields_replacements.peek().is_none() { + return; + } + + let union_loc = union_id.lookup(self.db); + let union_src = union_loc.source(self.db); + + let Some(union_fields_list) = union_src.value.record_field_list() else { + always!( + union_fields_replacements.peek().is_none(), + "Replacements ({:?}) were generated for a union fields \ + which had no fields list: {:?}", + union_fields_replacements.collect::<Vec<_>>(), + union_src + ); + return; + }; + let mut union_fields_iter = union_fields_list.fields(); + for field_replacement in union_fields_replacements { + // We assume that parameters in replacement are in the same order as in the + // actual params list, but just some of them (ones that named correctly) are skipped. + let field = loop { + if let Some(field) = union_fields_iter.next() { + let Some(field_name) = field.name() else { + continue; + }; + if field_name.as_name() == field_replacement.current_name { + break field; + } + } else { + never!( + "Replacement ({:?}) was generated for a union field \ + which was not found: {:?}", + field_replacement, + union_src + ); + return; + } + }; + + self.create_incorrect_case_diagnostic_for_ast_node( + field_replacement, + union_src.file_id, + &field, + IdentType::Field, + ); + } + } + fn validate_enum(&mut self, enum_id: EnumId) { // Check the enum name. - let data = self.db.enum_signature(enum_id); + let data = EnumSignature::of(self.db, enum_id); // rustc implementation excuses repr(C) since C structs predominantly don't // use camel case. @@ -556,7 +653,7 @@ impl<'a> DeclValidator<'a> { return; } - let data = self.db.const_signature(const_id); + let data = ConstSignature::of(self.db, const_id); let Some(name) = &data.name else { return; }; @@ -569,7 +666,7 @@ impl<'a> DeclValidator<'a> { } fn validate_static(&mut self, static_id: StaticId) { - let data = self.db.static_signature(static_id); + let data = StaticSignature::of(self.db, static_id); if data.flags.contains(StaticFlags::EXTERN) { cov_mark::hit!(extern_static_incorrect_case_ignored); return; @@ -595,7 +692,7 @@ impl<'a> DeclValidator<'a> { } // Check the type alias name. - let data = self.db.type_alias_signature(type_alias_id); + let data = TypeAliasSignature::of(self.db, type_alias_id); self.create_incorrect_case_diagnostic_for_item_name( type_alias_id, &data.name, diff --git a/crates/hir-ty/src/diagnostics/expr.rs b/crates/hir-ty/src/diagnostics/expr.rs index 4e1bb6f4c5..33d9dd538d 100644 --- a/crates/hir-ty/src/diagnostics/expr.rs +++ b/crates/hir-ty/src/diagnostics/expr.rs @@ -21,7 +21,7 @@ use syntax::{ ast::{self, UnaryOp}, }; use tracing::debug; -use triomphe::Arc; + use typed_arena::Arena; use crate::{ @@ -76,9 +76,9 @@ impl BodyValidationDiagnostic { validate_lints: bool, ) -> Vec<BodyValidationDiagnostic> { let _p = tracing::info_span!("BodyValidationDiagnostic::collect").entered(); - let infer = InferenceResult::for_body(db, owner); - let body = db.body(owner); - let env = db.trait_environment_for_body(owner); + let infer = InferenceResult::of(db, owner); + let body = Body::of(db, owner); + let env = db.trait_environment(owner.into()); let interner = DbInterner::new_with(db, owner.krate(db)); let infcx = interner.infer_ctxt().build(TypingMode::typeck_for_body(interner, owner.into())); @@ -98,7 +98,7 @@ impl BodyValidationDiagnostic { struct ExprValidator<'db> { owner: DefWithBodyId, - body: Arc<Body>, + body: &'db Body, infer: &'db InferenceResult, env: ParamEnv<'db>, diagnostics: Vec<BodyValidationDiagnostic>, @@ -116,10 +116,10 @@ impl<'db> ExprValidator<'db> { let db = self.db(); let mut filter_map_next_checker = None; // we'll pass &mut self while iterating over body.exprs, so they need to be disjoint - let body = Arc::clone(&self.body); + let body = self.body; if matches!(self.owner, DefWithBodyId::FunctionId(_)) { - self.check_for_trailing_return(body.body_expr, &body); + self.check_for_trailing_return(body.root_expr(), body); } for (id, expr) in body.exprs() { @@ -141,7 +141,7 @@ impl<'db> ExprValidator<'db> { self.validate_call(id, expr, &mut filter_map_next_checker); } Expr::Closure { body: body_expr, .. } => { - self.check_for_trailing_return(*body_expr, &body); + self.check_for_trailing_return(*body_expr, body); } Expr::If { .. } => { self.check_for_unnecessary_else(id, expr); @@ -240,7 +240,7 @@ impl<'db> ExprValidator<'db> { .as_reference() .map(|(match_expr_ty, ..)| match_expr_ty == pat_ty) .unwrap_or(false)) - && types_of_subpatterns_do_match(arm.pat, &self.body, self.infer) + && types_of_subpatterns_do_match(arm.pat, self.body, self.infer) { // If we had a NotUsefulMatchArm diagnostic, we could // check the usefulness of each pattern as we added it @@ -388,7 +388,7 @@ impl<'db> ExprValidator<'db> { pat: PatId, have_errors: &mut bool, ) -> DeconstructedPat<'a, 'db> { - let mut patcx = match_check::PatCtxt::new(self.db(), self.infer, &self.body); + let mut patcx = match_check::PatCtxt::new(self.db(), self.infer, self.body); let pattern = patcx.lower_pattern(pat); let pattern = cx.lower_pat(&pattern); if !patcx.errors.is_empty() { @@ -451,7 +451,7 @@ impl<'db> ExprValidator<'db> { && last_then_expr_ty.is_never() { // Only look at sources if the then branch diverges and we have an else branch. - let source_map = self.db().body_with_source_map(self.owner).1; + let source_map = &Body::with_source_map(self.db(), self.owner).1; let Ok(source_ptr) = source_map.expr_syntax(id) else { return; }; diff --git a/crates/hir-ty/src/diagnostics/match_check.rs b/crates/hir-ty/src/diagnostics/match_check.rs index 8e6101e6a0..f559c26bf5 100644 --- a/crates/hir-ty/src/diagnostics/match_check.rs +++ b/crates/hir-ty/src/diagnostics/match_check.rs @@ -14,6 +14,7 @@ use hir_def::{ expr_store::{Body, path::Path}, hir::PatId, item_tree::FieldsShape, + signatures::{StructSignature, UnionSignature}, }; use hir_expand::name::Name; use rustc_type_ir::inherent::IntoKind; @@ -340,12 +341,12 @@ impl<'db> HirDisplay<'db> for Pat<'db> { VariantId::StructId(s) => write!( f, "{}", - f.db.struct_signature(s).name.display(f.db, f.edition()) + StructSignature::of(f.db, s).name.display(f.db, f.edition()) )?, VariantId::UnionId(u) => write!( f, "{}", - f.db.union_signature(u).name.display(f.db, f.edition()) + UnionSignature::of(f.db, u).name.display(f.db, f.edition()) )?, }; diff --git a/crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs b/crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs index eda7e7e249..bc3d9bbec6 100644 --- a/crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs +++ b/crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs @@ -4,6 +4,7 @@ use std::{cell::LazyCell, fmt}; use hir_def::{ EnumId, EnumVariantId, HasModule, LocalFieldId, ModuleId, VariantId, attrs::AttrFlags, + signatures::VariantFields, }; use intern::sym; use rustc_pattern_analysis::{ @@ -363,7 +364,8 @@ impl<'a, 'db> PatCx for MatchCheckCtx<'a, 'db> { let adt = adt_def.def_id().0; let variant = Self::variant_id_for_adt(self.db, ctor, adt).unwrap(); - let visibilities = LazyCell::new(|| self.db.field_visibilities(variant)); + let visibilities = + LazyCell::new(|| VariantFields::field_visibilities(self.db, variant)); self.list_variant_fields(*ty, variant) .map(move |(fid, ty)| { diff --git a/crates/hir-ty/src/diagnostics/unsafe_check.rs b/crates/hir-ty/src/diagnostics/unsafe_check.rs index 50d4517d01..09c648139c 100644 --- a/crates/hir-ty/src/diagnostics/unsafe_check.rs +++ b/crates/hir-ty/src/diagnostics/unsafe_check.rs @@ -5,11 +5,12 @@ use std::mem; use either::Either; use hir_def::{ - AdtId, CallableDefId, DefWithBodyId, FieldId, FunctionId, VariantId, - expr_store::{Body, path::Path}, + AdtId, CallableDefId, DefWithBodyId, ExpressionStoreOwnerId, FieldId, FunctionId, GenericDefId, + VariantId, + expr_store::{Body, ExpressionStore, path::Path}, hir::{AsmOperand, Expr, ExprId, ExprOrPatId, InlineAsmKind, Pat, PatId, Statement, UnaryOp}, resolver::{HasResolver, ResolveValueResult, Resolver, ValueNs}, - signatures::StaticFlags, + signatures::{FunctionSignature, StaticFlags, StaticSignature}, type_ref::Rawness, }; use rustc_type_ir::inherent::IntoKind; @@ -34,15 +35,15 @@ pub fn missing_unsafe(db: &dyn HirDatabase, def: DefWithBodyId) -> MissingUnsafe let _p = tracing::info_span!("missing_unsafe").entered(); let is_unsafe = match def { - DefWithBodyId::FunctionId(it) => db.function_signature(it).is_unsafe(), + DefWithBodyId::FunctionId(it) => FunctionSignature::of(db, it).is_unsafe(), DefWithBodyId::StaticId(_) | DefWithBodyId::ConstId(_) | DefWithBodyId::VariantId(_) => { false } }; let mut res = MissingUnsafeResult { fn_is_unsafe: is_unsafe, ..MissingUnsafeResult::default() }; - let body = db.body(def); - let infer = InferenceResult::for_body(db, def); + let body = Body::of(db, def); + let infer = InferenceResult::of(db, def); let mut callback = |diag| match diag { UnsafeDiagnostic::UnsafeOperation { node, inside_unsafe_block, reason } => { if inside_unsafe_block == InsideUnsafeBlock::No { @@ -55,8 +56,8 @@ pub fn missing_unsafe(db: &dyn HirDatabase, def: DefWithBodyId) -> MissingUnsafe } } }; - let mut visitor = UnsafeVisitor::new(db, infer, &body, def, &mut callback); - visitor.walk_expr(body.body_expr); + let mut visitor = UnsafeVisitor::new(db, infer, body, def.into(), &mut callback); + visitor.walk_expr(body.root_expr()); if !is_unsafe { // Unsafety in function parameter patterns (that can only be union destructuring) @@ -109,8 +110,8 @@ pub fn unsafe_operations_for_body( callback(node); } }; - let mut visitor = UnsafeVisitor::new(db, infer, body, def, &mut visitor_callback); - visitor.walk_expr(body.body_expr); + let mut visitor = UnsafeVisitor::new(db, infer, body, def.into(), &mut visitor_callback); + visitor.walk_expr(body.root_expr()); for ¶m in &body.params { visitor.walk_pat(param); } @@ -119,8 +120,8 @@ pub fn unsafe_operations_for_body( pub fn unsafe_operations( db: &dyn HirDatabase, infer: &InferenceResult, - def: DefWithBodyId, - body: &Body, + def: ExpressionStoreOwnerId, + body: &ExpressionStore, current: ExprId, callback: &mut dyn FnMut(ExprOrPatId, InsideUnsafeBlock), ) { @@ -137,9 +138,9 @@ pub fn unsafe_operations( struct UnsafeVisitor<'db> { db: &'db dyn HirDatabase, infer: &'db InferenceResult, - body: &'db Body, + body: &'db ExpressionStore, resolver: Resolver<'db>, - def: DefWithBodyId, + def: ExpressionStoreOwnerId, inside_unsafe_block: InsideUnsafeBlock, inside_assignment: bool, inside_union_destructure: bool, @@ -156,13 +157,16 @@ impl<'db> UnsafeVisitor<'db> { fn new( db: &'db dyn HirDatabase, infer: &'db InferenceResult, - body: &'db Body, - def: DefWithBodyId, + body: &'db ExpressionStore, + def: ExpressionStoreOwnerId, unsafe_expr_cb: &'db mut dyn FnMut(UnsafeDiagnostic), ) -> Self { let resolver = def.resolver(db); let def_target_features = match def { - DefWithBodyId::FunctionId(func) => TargetFeatures::from_fn(db, func), + ExpressionStoreOwnerId::Body(DefWithBodyId::FunctionId(func)) + | ExpressionStoreOwnerId::Signature(GenericDefId::FunctionId(func)) => { + TargetFeatures::from_fn(db, func) + } _ => TargetFeatures::default(), }; let krate = resolver.krate(); @@ -430,8 +434,8 @@ impl<'db> UnsafeVisitor<'db> { fn mark_unsafe_path(&mut self, node: ExprOrPatId, path: &Path) { let hygiene = self.body.expr_or_pat_path_hygiene(node); let value_or_partial = self.resolver.resolve_path_in_value_ns(self.db, path, hygiene); - if let Some(ResolveValueResult::ValueNs(ValueNs::StaticId(id), _)) = value_or_partial { - let static_data = self.db.static_signature(id); + if let Some(ResolveValueResult::ValueNs(ValueNs::StaticId(id))) = value_or_partial { + let static_data = StaticSignature::of(self.db, id); if static_data.flags.contains(StaticFlags::MUTABLE) { self.on_unsafe_op(node, UnsafetyReason::MutableStatic); } else if static_data.flags.contains(StaticFlags::EXTERN) diff --git a/crates/hir-ty/src/display.rs b/crates/hir-ty/src/display.rs index 43b428c3fa..d680588645 100644 --- a/crates/hir-ty/src/display.rs +++ b/crates/hir-ty/src/display.rs @@ -10,16 +10,18 @@ use std::{ use base_db::{Crate, FxIndexMap}; use either::Either; use hir_def::{ - FindPathConfig, GenericDefId, GenericParamId, HasModule, LocalFieldId, Lookup, ModuleDefId, - ModuleId, TraitId, - db::DefDatabase, + ExpressionStoreOwnerId, FindPathConfig, GenericDefId, GenericParamId, HasModule, LocalFieldId, + Lookup, ModuleDefId, ModuleId, TraitId, expr_store::{ExpressionStore, path::Path}, find_path::{self, PrefixKind}, - hir::generics::{TypeOrConstParamData, TypeParamProvenance, WherePredicate}, + hir::generics::{GenericParams, TypeOrConstParamData, TypeParamProvenance, WherePredicate}, item_scope::ItemInNs, item_tree::FieldsShape, lang_item::LangItems, - signatures::VariantFields, + signatures::{ + EnumSignature, FunctionSignature, StructSignature, TraitSignature, TypeAliasSignature, + UnionSignature, VariantFields, + }, type_ref::{ ConstRef, LifetimeRef, LifetimeRefId, TraitBoundModifier, TypeBound, TypeRef, TypeRefId, UseArgRef, @@ -100,6 +102,9 @@ pub struct HirFormatter<'a, 'db> { display_kind: DisplayKind, display_target: DisplayTarget, bounds_formatting_ctx: BoundsFormattingCtx<'db>, + /// Whether formatting `impl Trait1 + Trait2` or `dyn Trait1 + Trait2` needs parentheses around it, + /// for example when formatting `&(impl Trait1 + Trait2)`. + trait_bounds_need_parens: bool, } // FIXME: To consider, ref and dyn trait lifetimes can be omitted if they are `'_`, path args should @@ -331,6 +336,7 @@ pub trait HirDisplay<'db> { show_container_bounds: false, display_lifetimes: DisplayLifetime::OnlyNamedOrStatic, bounds_formatting_ctx: Default::default(), + trait_bounds_need_parens: false, }) { Ok(()) => {} Err(HirDisplayError::FmtError) => panic!("Writing to String can't fail!"), @@ -566,6 +572,7 @@ impl<'db, T: HirDisplay<'db>> HirDisplayWrapper<'_, 'db, T> { show_container_bounds: self.show_container_bounds, display_lifetimes: self.display_lifetimes, bounds_formatting_ctx: Default::default(), + trait_bounds_need_parens: false, }) } @@ -612,7 +619,11 @@ impl<'db, T: HirDisplay<'db> + Internable> HirDisplay<'db> for Interned<T> { } } -fn write_projection<'db>(f: &mut HirFormatter<'_, 'db>, alias: &AliasTy<'db>) -> Result { +fn write_projection<'db>( + f: &mut HirFormatter<'_, 'db>, + alias: &AliasTy<'db>, + needs_parens_if_multi: bool, +) -> Result { if f.should_truncate() { return write!(f, "{TYPE_HINT_TRUNCATION}"); } @@ -650,6 +661,7 @@ fn write_projection<'db>(f: &mut HirFormatter<'_, 'db>, alias: &AliasTy<'db>) -> Either::Left(Ty::new_alias(f.interner, AliasTyKind::Projection, *alias)), &bounds, SizedByDefault::NotSized, + needs_parens_if_multi, ) }); } @@ -662,7 +674,9 @@ fn write_projection<'db>(f: &mut HirFormatter<'_, 'db>, alias: &AliasTy<'db>) -> write!( f, ">::{}", - f.db.type_alias_signature(alias.def_id.expect_type_alias()).name.display(f.db, f.edition()) + TypeAliasSignature::of(f.db, alias.def_id.expect_type_alias()) + .name + .display(f.db, f.edition()) )?; let proj_params = &alias.args.as_slice()[trait_ref.args.len()..]; hir_fmt_generics(f, proj_params, None, None) @@ -844,7 +858,7 @@ fn render_const_scalar_inner<'db>( } TyKind::Adt(adt, _) if b.len() == 2 * size_of::<usize>() => match adt.def_id().0 { hir_def::AdtId::StructId(s) => { - let data = f.db.struct_signature(s); + let data = StructSignature::of(f.db, s); write!(f, "&{}", data.name.display(f.db, f.edition()))?; Ok(()) } @@ -902,14 +916,16 @@ fn render_const_scalar_inner<'db>( }; match def { hir_def::AdtId::StructId(s) => { - let data = f.db.struct_signature(s); + let data = StructSignature::of(f.db, s); write!(f, "{}", data.name.display(f.db, f.edition()))?; let field_types = f.db.field_types(s.into()); render_variant_after_name( s.fields(f.db), f, field_types, - f.db.trait_environment(def.into()), + f.db.trait_environment(ExpressionStoreOwnerId::from(GenericDefId::from( + def, + ))), &layout, args, b, @@ -917,7 +933,7 @@ fn render_const_scalar_inner<'db>( ) } hir_def::AdtId::UnionId(u) => { - write!(f, "{}", f.db.union_signature(u).name.display(f.db, f.edition())) + write!(f, "{}", UnionSignature::of(f.db, u).name.display(f.db, f.edition())) } hir_def::AdtId::EnumId(e) => { let Ok(target_data_layout) = f.db.target_data_layout(f.krate()) else { @@ -941,7 +957,9 @@ fn render_const_scalar_inner<'db>( var_id.fields(f.db), f, field_types, - f.db.trait_environment(def.into()), + f.db.trait_environment(ExpressionStoreOwnerId::from(GenericDefId::from( + def, + ))), var_layout, args, b, @@ -1056,7 +1074,7 @@ impl<'db> HirDisplay<'db> for Ty<'db> { return write!(f, "{TYPE_HINT_TRUNCATION}"); } - use TyKind; + let trait_bounds_need_parens = mem::replace(&mut f.trait_bounds_need_parens, false); match self.kind() { TyKind::Never => write!(f, "!")?, TyKind::Str => write!(f, "str")?, @@ -1077,103 +1095,34 @@ impl<'db> HirDisplay<'db> for Ty<'db> { c.hir_fmt(f)?; write!(f, "]")?; } - kind @ (TyKind::RawPtr(t, m) | TyKind::Ref(_, t, m)) => { - if let TyKind::Ref(l, _, _) = kind { - f.write_char('&')?; - if f.render_region(l) { - l.hir_fmt(f)?; - f.write_char(' ')?; - } - match m { - rustc_ast_ir::Mutability::Not => (), - rustc_ast_ir::Mutability::Mut => f.write_str("mut ")?, - } - } else { - write!( - f, - "*{}", - match m { - rustc_ast_ir::Mutability::Not => "const ", - rustc_ast_ir::Mutability::Mut => "mut ", - } - )?; + TyKind::Ref(l, t, m) => { + f.write_char('&')?; + if f.render_region(l) { + l.hir_fmt(f)?; + f.write_char(' ')?; + } + match m { + rustc_ast_ir::Mutability::Not => (), + rustc_ast_ir::Mutability::Mut => f.write_str("mut ")?, } - // FIXME: all this just to decide whether to use parentheses... - let (preds_to_print, has_impl_fn_pred) = match t.kind() { - TyKind::Dynamic(bounds, region) => { - let contains_impl_fn = - bounds.iter().any(|bound| match bound.skip_binder() { - ExistentialPredicate::Trait(trait_ref) => { - let trait_ = trait_ref.def_id.0; - fn_traits(f.lang_items()).any(|it| it == trait_) - } - _ => false, - }); - let render_lifetime = f.render_region(region); - (bounds.len() + render_lifetime as usize, contains_impl_fn) - } - TyKind::Alias(AliasTyKind::Opaque, ty) => { - let opaque_ty_id = match ty.def_id { - SolverDefId::InternedOpaqueTyId(id) => id, - _ => unreachable!(), - }; - let impl_trait_id = db.lookup_intern_impl_trait_id(opaque_ty_id); - if let ImplTraitId::ReturnTypeImplTrait(func, _) = impl_trait_id { - let data = impl_trait_id.predicates(db); - let bounds = - || data.iter_instantiated_copied(f.interner, ty.args.as_slice()); - let mut len = bounds().count(); - - // Don't count Sized but count when it absent - // (i.e. when explicit ?Sized bound is set). - let default_sized = SizedByDefault::Sized { anchor: func.krate(db) }; - let sized_bounds = bounds() - .filter(|b| { - matches!( - b.kind().skip_binder(), - ClauseKind::Trait(trait_ref) - if default_sized.is_sized_trait( - trait_ref.def_id().0, - db, - ), - ) - }) - .count(); - match sized_bounds { - 0 => len += 1, - _ => { - len = len.saturating_sub(sized_bounds); - } - } - - let contains_impl_fn = bounds().any(|bound| { - if let ClauseKind::Trait(trait_ref) = bound.kind().skip_binder() { - let trait_ = trait_ref.def_id().0; - fn_traits(f.lang_items()).any(|it| it == trait_) - } else { - false - } - }); - (len, contains_impl_fn) - } else { - (0, false) - } + f.trait_bounds_need_parens = true; + t.hir_fmt(f)?; + f.trait_bounds_need_parens = false; + } + TyKind::RawPtr(t, m) => { + write!( + f, + "*{}", + match m { + rustc_ast_ir::Mutability::Not => "const ", + rustc_ast_ir::Mutability::Mut => "mut ", } - _ => (0, false), - }; - - if has_impl_fn_pred && preds_to_print <= 2 { - return t.hir_fmt(f); - } + )?; - if preds_to_print > 1 { - write!(f, "(")?; - t.hir_fmt(f)?; - write!(f, ")")?; - } else { - t.hir_fmt(f)?; - } + f.trait_bounds_need_parens = true; + t.hir_fmt(f)?; + f.trait_bounds_need_parens = false; } TyKind::Tuple(tys) => { if tys.len() == 1 { @@ -1212,11 +1161,13 @@ impl<'db> HirDisplay<'db> for Ty<'db> { write!(f, "fn ")?; f.start_location_link(def.into()); match def { - CallableDefId::FunctionId(ff) => { - write!(f, "{}", db.function_signature(ff).name.display(f.db, f.edition()))? - } + CallableDefId::FunctionId(ff) => write!( + f, + "{}", + FunctionSignature::of(db, ff).name.display(f.db, f.edition()) + )?, CallableDefId::StructId(s) => { - write!(f, "{}", db.struct_signature(s).name.display(f.db, f.edition()))? + write!(f, "{}", StructSignature::of(db, s).name.display(f.db, f.edition()))? } CallableDefId::EnumVariantId(e) => { let loc = e.lookup(db); @@ -1295,9 +1246,11 @@ impl<'db> HirDisplay<'db> for Ty<'db> { match f.display_kind { DisplayKind::Diagnostics | DisplayKind::Test => { let name = match def_id { - hir_def::AdtId::StructId(it) => db.struct_signature(it).name.clone(), - hir_def::AdtId::UnionId(it) => db.union_signature(it).name.clone(), - hir_def::AdtId::EnumId(it) => db.enum_signature(it).name.clone(), + hir_def::AdtId::StructId(it) => { + StructSignature::of(db, it).name.clone() + } + hir_def::AdtId::UnionId(it) => UnionSignature::of(db, it).name.clone(), + hir_def::AdtId::EnumId(it) => EnumSignature::of(db, it).name.clone(), }; write!(f, "{}", name.display(f.db, f.edition()))?; } @@ -1328,9 +1281,11 @@ impl<'db> HirDisplay<'db> for Ty<'db> { hir_fmt_generics(f, parameters.as_slice(), Some(def.def_id().0.into()), None)?; } - TyKind::Alias(AliasTyKind::Projection, alias_ty) => write_projection(f, &alias_ty)?, + TyKind::Alias(AliasTyKind::Projection, alias_ty) => { + write_projection(f, &alias_ty, trait_bounds_need_parens)? + } TyKind::Foreign(alias) => { - let type_alias = db.type_alias_signature(alias.0); + let type_alias = TypeAliasSignature::of(db, alias.0); f.start_location_link(alias.0.into()); write!(f, "{}", type_alias.name.display(f.db, f.edition()))?; f.end_location_link(); @@ -1363,6 +1318,7 @@ impl<'db> HirDisplay<'db> for Ty<'db> { Either::Left(*self), &bounds, SizedByDefault::Sized { anchor: krate }, + trait_bounds_need_parens, )?; } TyKind::Closure(id, substs) => { @@ -1393,8 +1349,8 @@ impl<'db> HirDisplay<'db> for Ty<'db> { } let sig = interner.signature_unclosure(substs.as_closure().sig(), Safety::Safe); let sig = sig.skip_binder(); - let InternedClosure(def, _) = db.lookup_intern_closure(id); - let infer = InferenceResult::for_body(db, def); + let InternedClosure(owner, _) = db.lookup_intern_closure(id); + let infer = InferenceResult::of(db, owner); let (_, kind) = infer.closure_info(id); match f.closure_style { ClosureStyle::ImplFn => write!(f, "impl {kind:?}(")?, @@ -1525,6 +1481,7 @@ impl<'db> HirDisplay<'db> for Ty<'db> { Either::Left(*self), &bounds, SizedByDefault::Sized { anchor: krate }, + trait_bounds_need_parens, )?; } }, @@ -1567,6 +1524,7 @@ impl<'db> HirDisplay<'db> for Ty<'db> { Either::Left(*self), &bounds_to_display, SizedByDefault::NotSized, + trait_bounds_need_parens, )?; } TyKind::Error(_) => { @@ -1581,7 +1539,7 @@ impl<'db> HirDisplay<'db> for Ty<'db> { let InternedCoroutine(owner, expr_id) = coroutine_id.0.loc(db); let CoroutineArgsParts { resume_ty, yield_ty, return_ty, .. } = subst.split_coroutine_args(); - let body = db.body(owner); + let body = ExpressionStore::of(db, owner); let expr = &body[expr_id]; match expr { hir_def::hir::Expr::Closure { @@ -1806,11 +1764,11 @@ pub enum SizedByDefault { } impl SizedByDefault { - fn is_sized_trait(self, trait_: TraitId, db: &dyn DefDatabase) -> bool { + fn is_sized_trait(self, trait_: TraitId, interner: DbInterner<'_>) -> bool { match self { Self::NotSized => false, - Self::Sized { anchor } => { - let sized_trait = hir_def::lang_item::lang_items(db, anchor).Sized; + Self::Sized { .. } => { + let sized_trait = interner.lang_items().Sized; Some(trait_) == sized_trait } } @@ -1823,16 +1781,62 @@ pub fn write_bounds_like_dyn_trait_with_prefix<'db>( this: Either<Ty<'db>, Region<'db>>, predicates: &[Clause<'db>], default_sized: SizedByDefault, + needs_parens_if_multi: bool, ) -> Result { + let needs_parens = + needs_parens_if_multi && trait_bounds_need_parens(f, this, predicates, default_sized); + if needs_parens { + write!(f, "(")?; + } write!(f, "{prefix}")?; if !predicates.is_empty() || predicates.is_empty() && matches!(default_sized, SizedByDefault::Sized { .. }) { write!(f, " ")?; - write_bounds_like_dyn_trait(f, this, predicates, default_sized) - } else { - Ok(()) + write_bounds_like_dyn_trait(f, this, predicates, default_sized)?; + } + if needs_parens { + write!(f, ")")?; + } + Ok(()) +} + +fn trait_bounds_need_parens<'db>( + f: &mut HirFormatter<'_, 'db>, + this: Either<Ty<'db>, Region<'db>>, + predicates: &[Clause<'db>], + default_sized: SizedByDefault, +) -> bool { + // This needs to be kept in sync with `write_bounds_like_dyn_trait()`. + let mut distinct_bounds = 0usize; + let mut is_sized = false; + for p in predicates { + match p.kind().skip_binder() { + ClauseKind::Trait(trait_ref) => { + let trait_ = trait_ref.def_id().0; + if default_sized.is_sized_trait(trait_, f.interner) { + is_sized = true; + if matches!(default_sized, SizedByDefault::Sized { .. }) { + // Don't print +Sized, but rather +?Sized if absent. + continue; + } + } + + distinct_bounds += 1; + } + ClauseKind::TypeOutlives(to) if Either::Left(to.0) == this => distinct_bounds += 1, + ClauseKind::RegionOutlives(lo) if Either::Right(lo.0) == this => distinct_bounds += 1, + _ => {} + } } + + if let SizedByDefault::Sized { .. } = default_sized + && !is_sized + { + distinct_bounds += 1; + } + + distinct_bounds > 1 } fn write_bounds_like_dyn_trait<'db>( @@ -1855,7 +1859,7 @@ fn write_bounds_like_dyn_trait<'db>( match p.kind().skip_binder() { ClauseKind::Trait(trait_ref) => { let trait_ = trait_ref.def_id().0; - if default_sized.is_sized_trait(trait_, f.db) { + if default_sized.is_sized_trait(trait_, f.interner) { is_sized = true; if matches!(default_sized, SizedByDefault::Sized { .. }) { // Don't print +Sized, but rather +?Sized if absent. @@ -1876,7 +1880,7 @@ fn write_bounds_like_dyn_trait<'db>( // existential) here, which is the only thing that's // possible in actual Rust, and hence don't print it f.start_location_link(trait_.into()); - write!(f, "{}", f.db.trait_signature(trait_).name.display(f.db, f.edition()))?; + write!(f, "{}", TraitSignature::of(f.db, trait_).name.display(f.db, f.edition()))?; f.end_location_link(); if is_fn_trait { if let [_self, params @ ..] = trait_ref.trait_ref.args.as_slice() @@ -1939,7 +1943,7 @@ fn write_bounds_like_dyn_trait<'db>( angle_open = true; } let assoc_ty_id = projection.def_id().expect_type_alias(); - let type_alias = f.db.type_alias_signature(assoc_ty_id); + let type_alias = TypeAliasSignature::of(f.db, assoc_ty_id); f.start_location_link(assoc_ty_id.into()); write!(f, "{}", type_alias.name.display(f.db, f.edition()))?; f.end_location_link(); @@ -2030,7 +2034,7 @@ impl<'db> HirDisplay<'db> for TraitRef<'db> { fn hir_fmt(&self, f: &mut HirFormatter<'_, 'db>) -> Result { let trait_ = self.def_id.0; f.start_location_link(trait_.into()); - write!(f, "{}", f.db.trait_signature(trait_).name.display(f.db, f.edition()))?; + write!(f, "{}", TraitSignature::of(f.db, trait_).name.display(f.db, f.edition()))?; f.end_location_link(); let substs = self.args.as_slice(); hir_fmt_generic_args(f, &substs[1..], None, Some(self.self_ty())) @@ -2137,7 +2141,7 @@ impl<'db> HirDisplayWithExpressionStore<'db> for LifetimeRefId { LifetimeRef::Placeholder => write!(f, "'_"), LifetimeRef::Error => write!(f, "'{{error}}"), &LifetimeRef::Param(lifetime_param_id) => { - let generic_params = f.db.generic_params(lifetime_param_id.parent); + let generic_params = GenericParams::of(f.db, lifetime_param_id.parent); write!( f, "{}", @@ -2153,7 +2157,7 @@ impl<'db> HirDisplayWithExpressionStore<'db> for TypeRefId { match &store[*self] { TypeRef::Never => write!(f, "!")?, TypeRef::TypeParam(param) => { - let generic_params = f.db.generic_params(param.parent()); + let generic_params = GenericParams::of(f.db, param.parent()); match generic_params[param.local_id()].name() { Some(name) => write!(f, "{}", name.display(f.db, f.edition()))?, None => { diff --git a/crates/hir-ty/src/drop.rs b/crates/hir-ty/src/drop.rs index 9d6869eee9..ddc4e4ce85 100644 --- a/crates/hir-ty/src/drop.rs +++ b/crates/hir-ty/src/drop.rs @@ -1,6 +1,9 @@ //! Utilities for computing drop info about types. -use hir_def::{AdtId, signatures::StructFlags}; +use hir_def::{ + AdtId, + signatures::{StructFlags, StructSignature}, +}; use rustc_hash::FxHashSet; use rustc_type_ir::inherent::{AdtDef, IntoKind}; use stdx::never; @@ -73,8 +76,7 @@ fn has_drop_glue_impl<'db>( } match adt_id { AdtId::StructId(id) => { - if db - .struct_signature(id) + if StructSignature::of(db, id) .flags .intersects(StructFlags::IS_MANUALLY_DROP | StructFlags::IS_PHANTOM_DATA) { @@ -132,9 +134,9 @@ fn has_drop_glue_impl<'db>( TyKind::Slice(ty) => has_drop_glue_impl(infcx, ty, env, visited), TyKind::Closure(closure_id, subst) => { let owner = db.lookup_intern_closure(closure_id.0).0; - let infer = InferenceResult::for_body(db, owner); + let infer = InferenceResult::of(db, owner); let (captures, _) = infer.closure_info(closure_id.0); - let env = db.trait_environment_for_body(owner); + let env = db.trait_environment(owner); captures .iter() .map(|capture| has_drop_glue_impl(infcx, capture.ty(db, subst), env, visited)) diff --git a/crates/hir-ty/src/dyn_compatibility.rs b/crates/hir-ty/src/dyn_compatibility.rs index 59cfd3fdc9..4c300affd8 100644 --- a/crates/hir-ty/src/dyn_compatibility.rs +++ b/crates/hir-ty/src/dyn_compatibility.rs @@ -4,8 +4,10 @@ use std::ops::ControlFlow; use hir_def::{ AssocItemId, ConstId, FunctionId, GenericDefId, HasModule, TraitId, TypeAliasId, - TypeOrConstParamId, TypeParamId, hir::generics::LocalTypeOrConstParamId, - nameres::crate_def_map, signatures::TraitFlags, + TypeOrConstParamId, TypeParamId, + hir::generics::{GenericParams, LocalTypeOrConstParamId}, + nameres::crate_def_map, + signatures::{FunctionSignature, TraitFlags, TraitSignature}, }; use rustc_hash::FxHashSet; use rustc_type_ir::{ @@ -298,7 +300,7 @@ where if def_map.is_unstable_feature_enabled(&intern::sym::generic_associated_type_extended) { ControlFlow::Continue(()) } else { - let generic_params = db.generic_params(item.into()); + let generic_params = GenericParams::of(db, item.into()); if !generic_params.is_empty() { cb(DynCompatibilityViolation::GAT(it)) } else { @@ -318,7 +320,7 @@ fn virtual_call_violations_for_method<F>( where F: FnMut(MethodViolationCode) -> ControlFlow<()>, { - let func_data = db.function_signature(func); + let func_data = FunctionSignature::of(db, func); if !func_data.has_self_param() { cb(MethodViolationCode::StaticMethod)?; } @@ -349,7 +351,7 @@ where cb(mvc)?; } - let generic_params = db.generic_params(func.into()); + let generic_params = GenericParams::of(db, func.into()); if generic_params.len_type_or_consts() > 0 { cb(MethodViolationCode::Generic)?; } @@ -371,7 +373,7 @@ where trait_ref: pred_trait_ref, polarity: PredicatePolarity::Positive, }) = pred - && let trait_data = db.trait_signature(pred_trait_ref.def_id.0) + && let trait_data = TraitSignature::of(db, pred_trait_ref.def_id.0) && trait_data.flags.contains(TraitFlags::AUTO) && let rustc_type_ir::TyKind::Param(ParamTy { index: 0, .. }) = pred_trait_ref.self_ty().kind() diff --git a/crates/hir-ty/src/dyn_compatibility/tests.rs b/crates/hir-ty/src/dyn_compatibility/tests.rs index 5c9b06e39a..a70f98a0fe 100644 --- a/crates/hir-ty/src/dyn_compatibility/tests.rs +++ b/crates/hir-ty/src/dyn_compatibility/tests.rs @@ -1,6 +1,6 @@ use std::ops::ControlFlow; -use hir_def::db::DefDatabase; +use hir_def::signatures::TraitSignature; use rustc_hash::{FxHashMap, FxHashSet}; use syntax::ToSmolStr; use test_fixture::WithFixture; @@ -40,8 +40,7 @@ fn check_dyn_compatibility<'a>( .declarations() .filter_map(|def| { if let hir_def::ModuleDefId::TraitId(trait_id) = def { - let name = db - .trait_signature(trait_id) + let name = TraitSignature::of(&db, trait_id) .name .display_no_db(file_id.edition(&db)) .to_smolstr(); diff --git a/crates/hir-ty/src/generics.rs b/crates/hir-ty/src/generics.rs index 5f0261437b..822942eec3 100644 --- a/crates/hir-ty/src/generics.rs +++ b/crates/hir-ty/src/generics.rs @@ -20,24 +20,23 @@ use hir_def::{ }, }; use itertools::chain; -use triomphe::Arc; -pub fn generics(db: &dyn DefDatabase, def: GenericDefId) -> Generics { +pub fn generics(db: &dyn DefDatabase, def: GenericDefId) -> Generics<'_> { let parent_generics = parent_generic_def(db, def).map(|def| Box::new(generics(db, def))); - let (params, store) = db.generic_params_and_store(def); + let (params, store) = GenericParams::with_store(db, def); let has_trait_self_param = params.trait_self_param().is_some(); Generics { def, params, parent_generics, has_trait_self_param, store } } #[derive(Clone, Debug)] -pub struct Generics { +pub struct Generics<'db> { def: GenericDefId, - params: Arc<GenericParams>, - store: Arc<ExpressionStore>, - parent_generics: Option<Box<Generics>>, + params: &'db GenericParams, + store: &'db ExpressionStore, + parent_generics: Option<Box<Generics<'db>>>, has_trait_self_param: bool, } -impl<T> ops::Index<T> for Generics +impl<T> ops::Index<T> for Generics<'_> where GenericParams: ops::Index<T>, { @@ -47,13 +46,13 @@ where } } -impl Generics { +impl<'db> Generics<'db> { pub(crate) fn def(&self) -> GenericDefId { self.def } pub(crate) fn store(&self) -> &ExpressionStore { - &self.store + self.store } pub(crate) fn where_predicates(&self) -> impl Iterator<Item = &WherePredicate> { @@ -97,7 +96,7 @@ impl Generics { ) -> impl Iterator<Item = ((GenericParamId, GenericParamDataRef<'_>), &ExpressionStore)> + '_ { self.iter_parent() - .zip(self.parent_generics().into_iter().flat_map(|it| std::iter::repeat(&*it.store))) + .zip(self.parent_generics().into_iter().flat_map(|it| std::iter::repeat(it.store))) } /// Iterate over the params without parent params. @@ -185,7 +184,7 @@ impl Generics { if param.parent == self.def { let idx = param.local_id.into_raw().into_u32() as usize; debug_assert!( - idx <= self.params.len_type_or_consts(), + idx < self.params.len_type_or_consts(), "idx: {} len: {}", idx, self.params.len_type_or_consts() @@ -219,27 +218,11 @@ impl Generics { } } - pub(crate) fn parent_generics(&self) -> Option<&Generics> { + pub(crate) fn parent_generics(&self) -> Option<&Generics<'db>> { self.parent_generics.as_deref() } } -pub(crate) fn trait_self_param_idx(db: &dyn DefDatabase, def: GenericDefId) -> Option<usize> { - match def { - GenericDefId::TraitId(_) => { - let params = db.generic_params(def); - params.trait_self_param().map(|idx| idx.into_raw().into_u32() as usize) - } - GenericDefId::ImplId(_) => None, - _ => { - let parent_def = parent_generic_def(db, def)?; - let parent_params = db.generic_params(parent_def); - let parent_self_idx = parent_params.trait_self_param()?.into_raw().into_u32() as usize; - Some(parent_self_idx) - } - } -} - pub(crate) fn parent_generic_def(db: &dyn DefDatabase, def: GenericDefId) -> Option<GenericDefId> { let container = match def { GenericDefId::FunctionId(it) => it.lookup(db).container, @@ -259,7 +242,7 @@ pub(crate) fn parent_generic_def(db: &dyn DefDatabase, def: GenericDefId) -> Opt } fn from_toc_id<'a>( - it: &'a Generics, + it: &'a Generics<'a>, ) -> impl Fn( (LocalTypeOrConstParamId, &'a TypeOrConstParamData), ) -> (GenericParamId, GenericParamDataRef<'a>) { @@ -279,7 +262,7 @@ fn from_toc_id<'a>( } fn from_lt_id<'a>( - it: &'a Generics, + it: &'a Generics<'a>, ) -> impl Fn((LocalLifetimeParamId, &'a LifetimeParamData)) -> (GenericParamId, GenericParamDataRef<'a>) { move |(local_id, p): (_, _)| { diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index 35d744e7d1..d14e9d6526 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -33,14 +33,15 @@ use std::{cell::OnceCell, convert::identity, iter}; use base_db::Crate; use either::Either; use hir_def::{ - AdtId, AssocItemId, ConstId, DefWithBodyId, FieldId, FunctionId, GenericDefId, GenericParamId, - ItemContainerId, LocalFieldId, Lookup, TraitId, TupleFieldId, TupleId, TypeAliasId, VariantId, - expr_store::{Body, ExpressionStore, HygieneId, path::Path}, + AdtId, AssocItemId, ConstId, ConstParamId, DefWithBodyId, ExpressionStoreOwnerId, FieldId, + FunctionId, GenericDefId, GenericParamId, ItemContainerId, LocalFieldId, Lookup, TraitId, + TupleFieldId, TupleId, TypeAliasId, TypeOrConstParamId, VariantId, + expr_store::{Body, ExpressionStore, HygieneId, RootExprOrigin, path::Path}, hir::{BindingAnnotation, BindingId, ExprId, ExprOrPatId, LabelId, PatId}, lang_item::LangItems, layout::Integer, resolver::{HasResolver, ResolveValueResult, Resolver, TypeNs, ValueNs}, - signatures::{ConstSignature, EnumSignature, StaticSignature}, + signatures::{ConstSignature, EnumSignature, FunctionSignature, StaticSignature}, type_ref::{ConstRef, LifetimeRefId, TypeRef, TypeRefId}, }; use hir_expand::{mod_path::ModPath, name::Name}; @@ -104,19 +105,18 @@ pub fn infer_query_with_inspect<'db>( ) -> InferenceResult { let _p = tracing::info_span!("infer_query").entered(); let resolver = def.resolver(db); - let body = db.body(def); - let mut ctx = InferenceContext::new(db, def, &body, resolver); + let body = Body::of(db, def); + let mut ctx = + InferenceContext::new(db, ExpressionStoreOwnerId::Body(def), &body.store, resolver); if let Some(inspect) = inspect { ctx.table.infer_ctxt.attach_obligation_inspector(inspect); } match def { - DefWithBodyId::FunctionId(f) => { - ctx.collect_fn(f); - } - DefWithBodyId::ConstId(c) => ctx.collect_const(c, &db.const_signature(c)), - DefWithBodyId::StaticId(s) => ctx.collect_static(&db.static_signature(s)), + DefWithBodyId::FunctionId(f) => ctx.collect_fn(f, body.self_param, &body.params), + DefWithBodyId::ConstId(c) => ctx.collect_const(c, ConstSignature::of(db, c)), + DefWithBodyId::StaticId(s) => ctx.collect_static(StaticSignature::of(db, s)), DefWithBodyId::VariantId(v) => { ctx.return_ty = match EnumSignature::variant_body_type(db, v.lookup(db).parent) { hir_def::layout::IntegerType::Pointer(signed) => match signed { @@ -143,10 +143,113 @@ pub fn infer_query_with_inspect<'db>( } } - ctx.infer_body(); + ctx.infer_body(body.root_expr()); + + ctx.infer_mut_body(body.root_expr()); + + infer_finalize(ctx) +} + +fn infer_cycle_result(db: &dyn HirDatabase, _: salsa::Id, _: DefWithBodyId) -> InferenceResult { + InferenceResult { + has_errors: true, + ..InferenceResult::new(Ty::new_error(DbInterner::new_no_crate(db), ErrorGuaranteed)) + } +} + +/// Infer types for all const expressions in an item's signature. +/// +/// This handles const expressions that appear in type positions within a generic +/// item's signature, such as array lengths (`[T; N]`) and const generic arguments +/// (`Foo<{ expr }>`). Each root expression is inferred independently within +/// a shared `InferenceContext`, accumulating results into a single `InferenceResult`. +fn infer_signature_query(db: &dyn HirDatabase, def: GenericDefId) -> InferenceResult { + let _p = tracing::info_span!("infer_signature_query").entered(); + let store = ExpressionStore::of(db, def.into()); + let mut roots = store.expr_roots_with_origins().peekable(); + let Some(_) = roots.peek() else { + return InferenceResult::new(crate::next_solver::default_types(db).types.error); + }; + + let resolver = def.resolver(db); + let owner = ExpressionStoreOwnerId::Signature(def); + + let mut ctx = InferenceContext::new(db, owner, store, resolver); + + for (root_expr, origin) in roots { + let expected = match origin { + // Array lengths are always `usize`. + RootExprOrigin::ArrayLength => Expectation::has_type(ctx.types.types.usize), + // Const parameter default: look up the param's declared type. + RootExprOrigin::ConstParam(local_id) => Expectation::has_type(db.const_param_ty_ns( + ConstParamId::from_unchecked(TypeOrConstParamId { parent: def, local_id }), + )), + // Path const generic args: determining the expected type requires + // path resolution. + // FIXME + RootExprOrigin::GenericArgsPath => Expectation::None, + RootExprOrigin::BodyRoot => Expectation::None, + }; + ctx.infer_expr(root_expr, &expected, ExprIsRead::Yes); + } + + infer_finalize(ctx) +} + +fn infer_variant_fields_query(db: &dyn HirDatabase, def: VariantId) -> InferenceResult { + let _p = tracing::info_span!("infer_variant_fields_query").entered(); + let store = ExpressionStore::of(db, def.into()); + let mut roots = store.expr_roots_with_origins().peekable(); + let Some(_) = roots.peek() else { + return InferenceResult::new(crate::next_solver::default_types(db).types.error); + }; + + let resolver = def.resolver(db); + let owner = ExpressionStoreOwnerId::VariantFields(def); + + let mut ctx = InferenceContext::new(db, owner, store, resolver); + + for (root_expr, origin) in roots { + let expected = match origin { + // Array lengths are always `usize`. + RootExprOrigin::ArrayLength => Expectation::has_type(ctx.types.types.usize), + // unreachable + RootExprOrigin::ConstParam(_) => Expectation::None, + // Path const generic args: determining the expected type requires + // path resolution. + // FIXME + RootExprOrigin::GenericArgsPath => Expectation::None, + RootExprOrigin::BodyRoot => Expectation::None, + }; + ctx.infer_expr(root_expr, &expected, ExprIsRead::Yes); + } + + infer_finalize(ctx) +} + +fn infer_signature_cycle_result( + db: &dyn HirDatabase, + _: salsa::Id, + _: GenericDefId, +) -> InferenceResult { + InferenceResult { + has_errors: true, + ..InferenceResult::new(Ty::new_error(DbInterner::new_no_crate(db), ErrorGuaranteed)) + } +} - ctx.infer_mut_body(); +fn infer_variant_fields_cycle_result( + db: &dyn HirDatabase, + _: salsa::Id, + _: VariantId, +) -> InferenceResult { + InferenceResult { + has_errors: true, + ..InferenceResult::new(Ty::new_error(DbInterner::new_no_crate(db), ErrorGuaranteed)) + } +} +fn infer_finalize(mut ctx: InferenceContext<'_, '_>) -> InferenceResult { ctx.handle_opaque_type_uses(); ctx.type_inference_fallback(); @@ -171,14 +274,6 @@ pub fn infer_query_with_inspect<'db>( ctx.resolve_all() } - -fn infer_cycle_result(db: &dyn HirDatabase, _: salsa::Id, _: DefWithBodyId) -> InferenceResult { - InferenceResult { - has_errors: true, - ..InferenceResult::new(Ty::new_error(DbInterner::new_no_crate(db), ErrorGuaranteed)) - } -} - /// Binding modes inferred for patterns. /// <https://doc.rust-lang.org/reference/patterns.html#binding-modes> #[derive(Copy, Clone, Debug, Eq, PartialEq, Default)] @@ -552,12 +647,39 @@ pub struct InferenceResult { #[salsa::tracked] impl InferenceResult { #[salsa::tracked(returns(ref), cycle_result = infer_cycle_result)] - pub fn for_body(db: &dyn HirDatabase, def: DefWithBodyId) -> InferenceResult { + fn for_body(db: &dyn HirDatabase, def: DefWithBodyId) -> InferenceResult { infer_query(db, def) } + + /// Infer types for all const expressions in an item's signature. + /// + /// Returns an `InferenceResult` containing type information for array lengths, + /// const generic arguments, and other const expressions appearing in type + /// positions within the item's signature. + #[salsa::tracked(returns(ref), cycle_result = infer_signature_cycle_result)] + fn for_signature(db: &dyn HirDatabase, def: GenericDefId) -> InferenceResult { + infer_signature_query(db, def) + } + + #[salsa::tracked(returns(ref), cycle_result = infer_variant_fields_cycle_result)] + fn for_variant_fields(db: &dyn HirDatabase, def: VariantId) -> InferenceResult { + infer_variant_fields_query(db, def) + } } impl InferenceResult { + pub fn of(db: &dyn HirDatabase, def: impl Into<ExpressionStoreOwnerId>) -> &InferenceResult { + match def.into() { + ExpressionStoreOwnerId::Signature(generic_def_id) => { + Self::for_signature(db, generic_def_id) + } + ExpressionStoreOwnerId::Body(def_with_body_id) => Self::for_body(db, def_with_body_id), + ExpressionStoreOwnerId::VariantFields(variant_id) => { + Self::for_variant_fields(db, variant_id) + } + } + } + fn new(error_ty: Ty<'_>) -> Self { Self { method_resolutions: Default::default(), @@ -754,8 +876,8 @@ impl InferenceResult { #[derive(Clone, Debug)] pub(crate) struct InferenceContext<'body, 'db> { pub(crate) db: &'db dyn HirDatabase, - pub(crate) owner: DefWithBodyId, - pub(crate) body: &'body Body, + pub(crate) owner: ExpressionStoreOwnerId, + pub(crate) store: &'body ExpressionStore, /// Generally you should not resolve things via this resolver. Instead create a TyLoweringContext /// and resolve the path via its methods. This will ensure proper error reporting. pub(crate) resolver: Resolver<'db>, @@ -855,11 +977,21 @@ fn find_continuable<'a, 'db>( impl<'body, 'db> InferenceContext<'body, 'db> { fn new( db: &'db dyn HirDatabase, - owner: DefWithBodyId, - body: &'body Body, + owner: ExpressionStoreOwnerId, + store: &'body ExpressionStore, resolver: Resolver<'db>, ) -> Self { - let trait_env = db.trait_environment_for_body(owner); + let trait_env = match owner { + ExpressionStoreOwnerId::Signature(generic_def_id) => { + db.trait_environment(ExpressionStoreOwnerId::from(generic_def_id)) + } + ExpressionStoreOwnerId::Body(def_with_body_id) => { + db.trait_environment(ExpressionStoreOwnerId::Body(def_with_body_id)) + } + ExpressionStoreOwnerId::VariantFields(variant_id) => { + db.trait_environment(ExpressionStoreOwnerId::VariantFields(variant_id)) + } + }; let table = unify::InferenceTable::new(db, trait_env, resolver.krate(), Some(owner)); let types = crate::next_solver::default_types(db); InferenceContext { @@ -878,13 +1010,8 @@ impl<'body, 'db> InferenceContext<'body, 'db> { return_coercion: None, db, owner, - generic_def: match owner { - DefWithBodyId::FunctionId(it) => it.into(), - DefWithBodyId::StaticId(it) => it.into(), - DefWithBodyId::ConstId(it) => it.into(), - DefWithBodyId::VariantId(it) => it.lookup(db).parent.into(), - }, - body, + generic_def: owner.generic_def(db), + store, traits_in_scope: resolver.traits_in_scope(db), resolver, diverges: Diverges::Maybe, @@ -908,7 +1035,9 @@ impl<'body, 'db> InferenceContext<'body, 'db> { fn target_features(&self) -> (&TargetFeatures<'db>, TargetFeatureIsSafeInTarget) { let (target_features, target_feature_is_safe) = self.target_features.get_or_init(|| { let target_features = match self.owner { - DefWithBodyId::FunctionId(id) => TargetFeatures::from_fn(self.db, id), + ExpressionStoreOwnerId::Body(DefWithBodyId::FunctionId(id)) => { + TargetFeatures::from_fn(self.db, id) + } _ => TargetFeatures::default(), }; let target_feature_is_safe = match &self.krate().workspace_data(self.db).target { @@ -1102,12 +1231,12 @@ impl<'body, 'db> InferenceContext<'body, 'db> { self.return_ty = return_ty; } - fn collect_fn(&mut self, func: FunctionId) { - let data = self.db.function_signature(func); + fn collect_fn(&mut self, func: FunctionId, self_param: Option<BindingId>, params: &[PatId]) { + let data = FunctionSignature::of(self.db, func); let mut param_tys = self.with_ty_lowering( &data.store, InferenceTyDiagnosticSource::Signature, - LifetimeElisionKind::for_fn_params(&data), + LifetimeElisionKind::for_fn_params(data), |ctx| data.params.iter().map(|&type_ref| ctx.lower_ty(type_ref)).collect::<Vec<_>>(), ); @@ -1130,13 +1259,13 @@ impl<'body, 'db> InferenceContext<'body, 'db> { param_tys.push(va_list_ty); } let mut param_tys = param_tys.into_iter().chain(iter::repeat(self.table.next_ty_var())); - if let Some(self_param) = self.body.self_param + if let Some(self_param) = self_param && let Some(ty) = param_tys.next() { let ty = self.process_user_written_ty(ty); self.write_binding_ty(self_param, ty); } - for (ty, pat) in param_tys.zip(&*self.body.params) { + for (ty, pat) in param_tys.zip(params) { let ty = self.process_user_written_ty(ty); self.infer_top_pat(*pat, ty, None); @@ -1170,12 +1299,12 @@ impl<'body, 'db> InferenceContext<'body, 'db> { &self.table.infer_ctxt } - fn infer_body(&mut self) { + fn infer_body(&mut self, body_expr: ExprId) { match self.return_coercion { - Some(_) => self.infer_return(self.body.body_expr), + Some(_) => self.infer_return(body_expr), None => { _ = self.infer_expr_coerce( - self.body.body_expr, + body_expr, &Expectation::has_type(self.return_ty), ExprIsRead::Yes, ) @@ -1282,7 +1411,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { f: impl FnOnce(&mut TyLoweringContext<'db, '_>) -> R, ) -> R { self.with_ty_lowering( - self.body, + self.store, InferenceTyDiagnosticSource::Body, LifetimeElisionKind::Infer, f, @@ -1324,7 +1453,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { pub(crate) fn make_body_ty(&mut self, type_ref: TypeRefId) -> Ty<'db> { self.make_ty( type_ref, - self.body, + self.store, InferenceTyDiagnosticSource::Body, LifetimeElisionKind::Infer, ) @@ -1332,7 +1461,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { pub(crate) fn make_body_const(&mut self, const_ref: ConstRef, ty: Ty<'db>) -> Const<'db> { let const_ = self.with_ty_lowering( - self.body, + self.store, InferenceTyDiagnosticSource::Body, LifetimeElisionKind::Infer, |ctx| ctx.lower_const(const_ref, ty), @@ -1342,7 +1471,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { pub(crate) fn make_path_as_body_const(&mut self, path: &Path, ty: Ty<'db>) -> Const<'db> { let const_ = self.with_ty_lowering( - self.body, + self.store, InferenceTyDiagnosticSource::Body, LifetimeElisionKind::Infer, |ctx| ctx.lower_path_as_const(path, ty), @@ -1356,7 +1485,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { pub(crate) fn make_body_lifetime(&mut self, lifetime_ref: LifetimeRefId) -> Region<'db> { let lt = self.with_ty_lowering( - self.body, + self.store, InferenceTyDiagnosticSource::Body, LifetimeElisionKind::Infer, |ctx| ctx.lower_lifetime(lifetime_ref), @@ -1571,7 +1700,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { let mut ctx = TyLoweringContext::new( self.db, &self.resolver, - &self.body.store, + self.store, &self.diagnostics, InferenceTyDiagnosticSource::Body, self.generic_def, @@ -1584,7 +1713,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { return (self.err_ty(), None); }; match res { - ResolveValueResult::ValueNs(value, _) => match value { + ResolveValueResult::ValueNs(value) => match value { ValueNs::EnumVariantId(var) => { let args = path_ctx.substs_from_path(var.into(), true, false); drop(ctx); @@ -1608,7 +1737,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { return (self.err_ty(), None); } }, - ResolveValueResult::Partial(typens, unresolved, _) => (typens, Some(unresolved)), + ResolveValueResult::Partial(typens, unresolved) => (typens, Some(unresolved)), } } else { match path_ctx.resolve_path_in_type_ns() { diff --git a/crates/hir-ty/src/infer/cast.rs b/crates/hir-ty/src/infer/cast.rs index d69b00adb7..e5ee734474 100644 --- a/crates/hir-ty/src/infer/cast.rs +++ b/crates/hir-ty/src/infer/cast.rs @@ -1,6 +1,10 @@ //! Type cast logic. Basically coercion + additional casts. -use hir_def::{AdtId, hir::ExprId, signatures::TraitFlags}; +use hir_def::{ + AdtId, + hir::ExprId, + signatures::{TraitFlags, TraitSignature}, +}; use rustc_ast_ir::Mutability; use rustc_hash::FxHashSet; use rustc_type_ir::{ @@ -328,11 +332,7 @@ impl<'db> CastCheck<'db> { // // Note that trait upcasting goes through a different mechanism (`coerce_unsized`) // and is unaffected by this check. - (Some(src_principal), Some(dst_principal)) => { - if src_principal == dst_principal { - return Ok(()); - } - + (Some(src_principal), Some(_)) => { // We need to reconstruct trait object types. // `m_src` and `m_dst` won't work for us here because they will potentially // contain wrappers, which we do not care about. @@ -387,8 +387,7 @@ impl<'db> CastCheck<'db> { .chain( elaborate::supertrait_def_ids(ctx.interner(), src_principal) .filter(|trait_| { - ctx.db - .trait_signature(trait_.0) + TraitSignature::of(ctx.db, trait_.0) .flags .contains(TraitFlags::AUTO) }), diff --git a/crates/hir-ty/src/infer/closure/analysis.rs b/crates/hir-ty/src/infer/closure/analysis.rs index 5a3eba1a71..ce0ccfe82f 100644 --- a/crates/hir-ty/src/infer/closure/analysis.rs +++ b/crates/hir-ty/src/infer/closure/analysis.rs @@ -4,14 +4,15 @@ use std::{cmp, mem}; use base_db::Crate; use hir_def::{ - DefWithBodyId, FieldId, HasModule, VariantId, - expr_store::path::Path, + ExpressionStoreOwnerId, FieldId, HasModule, VariantId, + expr_store::{Body, ExpressionStore, path::Path}, hir::{ Array, AsmOperand, BinaryOp, BindingId, CaptureBy, Expr, ExprId, ExprOrPatId, Pat, PatId, RecordSpread, Statement, UnaryOp, }, item_tree::FieldsShape, resolver::ValueNs, + signatures::VariantFields, }; use rustc_ast_ir::Mutability; use rustc_hash::{FxHashMap, FxHashSet}; @@ -179,9 +180,26 @@ impl CapturedItem { } /// Converts the place to a name that can be inserted into source code. - pub fn place_to_name(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String { - let body = db.body(owner); - let mut result = body[self.place.local].name.as_str().to_owned(); + pub fn place_to_name(&self, owner: ExpressionStoreOwnerId, db: &dyn HirDatabase) -> String { + let krate = owner.krate(db); + let edition = krate.data(db).edition; + let mut result = match owner { + ExpressionStoreOwnerId::Signature(generic_def_id) => { + ExpressionStore::of(db, generic_def_id.into())[self.place.local] + .name + .display(db, edition) + .to_string() + } + ExpressionStoreOwnerId::Body(def_with_body_id) => Body::of(db, def_with_body_id) + [self.place.local] + .name + .display(db, edition) + .to_string(), + ExpressionStoreOwnerId::VariantFields(variant_id) => { + let fields = VariantFields::of(db, variant_id); + fields.store[self.place.local].name.display(db, edition).to_string() + } + }; for proj in &self.place.projections { match proj { HirPlaceProjection::Deref => {} @@ -213,11 +231,30 @@ impl CapturedItem { result } - pub fn display_place_source_code(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String { - let body = db.body(owner); + pub fn display_place_source_code( + &self, + owner: ExpressionStoreOwnerId, + db: &dyn HirDatabase, + ) -> String { let krate = owner.krate(db); let edition = krate.data(db).edition; - let mut result = body[self.place.local].name.display(db, edition).to_string(); + let mut result = match owner { + ExpressionStoreOwnerId::Signature(generic_def_id) => { + ExpressionStore::of(db, generic_def_id.into())[self.place.local] + .name + .display(db, edition) + .to_string() + } + ExpressionStoreOwnerId::Body(def_with_body_id) => Body::of(db, def_with_body_id) + [self.place.local] + .name + .display(db, edition) + .to_string(), + ExpressionStoreOwnerId::VariantFields(variant_id) => { + let fields = VariantFields::of(db, variant_id); + fields.store[self.place.local].name.display(db, edition).to_string() + } + }; for proj in &self.place.projections { match proj { // In source code autoderef kicks in. @@ -258,11 +295,26 @@ impl CapturedItem { result } - pub fn display_place(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String { - let body = db.body(owner); + pub fn display_place(&self, owner: ExpressionStoreOwnerId, db: &dyn HirDatabase) -> String { let krate = owner.krate(db); let edition = krate.data(db).edition; - let mut result = body[self.place.local].name.display(db, edition).to_string(); + let mut result = match owner { + ExpressionStoreOwnerId::Signature(generic_def_id) => { + ExpressionStore::of(db, generic_def_id.into())[self.place.local] + .name + .display(db, edition) + .to_string() + } + ExpressionStoreOwnerId::Body(def_with_body_id) => Body::of(db, def_with_body_id) + [self.place.local] + .name + .display(db, edition) + .to_string(), + ExpressionStoreOwnerId::VariantFields(variant_id) => { + let fields = VariantFields::of(db, variant_id); + fields.store[self.place.local].name.display(db, edition).to_string() + } + }; let mut field_need_paren = false; for proj in &self.place.projections { match proj { @@ -346,7 +398,7 @@ impl<'db> InferenceContext<'_, 'db> { if path.type_anchor().is_some() { return None; } - let hygiene = self.body.expr_or_pat_path_hygiene(id); + let hygiene = self.store.expr_or_pat_path_hygiene(id); self.resolver.resolve_path_in_value_ns_fully(self.db, path, hygiene).and_then(|result| { match result { ValueNs::LocalBinding(binding) => { @@ -365,7 +417,7 @@ impl<'db> InferenceContext<'_, 'db> { /// Changes `current_capture_span_stack` to contain the stack of spans for this expr. fn place_of_expr_without_adjust(&mut self, tgt_expr: ExprId) -> Option<HirPlace> { self.current_capture_span_stack.clear(); - match &self.body[tgt_expr] { + match &self.store[tgt_expr] { Expr::Path(p) => { let resolver_guard = self.resolver.update_to_inner_scope(self.db, self.owner, tgt_expr); @@ -416,7 +468,7 @@ impl<'db> InferenceContext<'_, 'db> { let mut actual_truncate_to = 0; for &span in &*span_stack { actual_truncate_to += 1; - if !span.is_ref_span(self.body) { + if !span.is_ref_span(self.store) { remained -= 1; if remained == 0 { break; @@ -424,7 +476,7 @@ impl<'db> InferenceContext<'_, 'db> { } } if actual_truncate_to < span_stack.len() - && span_stack[actual_truncate_to].is_ref_span(self.body) + && span_stack[actual_truncate_to].is_ref_span(self.store) { // Include the ref operator if there is one, we will fix it later (in `strip_captures_ref_span()`) if it's incorrect. actual_truncate_to += 1; @@ -533,7 +585,7 @@ impl<'db> InferenceContext<'_, 'db> { } fn walk_expr_without_adjust(&mut self, tgt_expr: ExprId) { - match &self.body[tgt_expr] { + match &self.store[tgt_expr] { Expr::OffsetOf(_) => (), Expr::InlineAsm(e) => e.operands.iter().for_each(|(_, op)| match op { AsmOperand::In { expr, .. } @@ -733,7 +785,7 @@ impl<'db> InferenceContext<'_, 'db> { self.consume_with_pat(rhs_place, target); self.inside_assignment = false; } - None => self.body.walk_pats(target, &mut |pat| match &self.body[pat] { + None => self.store.walk_pats(target, &mut |pat| match &self.store[pat] { Pat::Path(path) => self.mutate_path_pat(path, pat), &Pat::Expr(expr) => { let place = self.place_of_expr(expr); @@ -775,7 +827,7 @@ impl<'db> InferenceContext<'_, 'db> { update_result: &mut impl FnMut(CaptureKind), mut for_mut: BorrowKind, ) { - match &self.body[p] { + match &self.store[p] { Pat::Ref { .. } | Pat::Box { .. } | Pat::Missing @@ -819,13 +871,13 @@ impl<'db> InferenceContext<'_, 'db> { if self.result.pat_adjustments.get(&p).is_some_and(|it| !it.is_empty()) { for_mut = BorrowKind::Mut { kind: MutBorrowKind::ClosureCapture }; } - self.body.walk_pats_shallow(p, |p| self.walk_pat_inner(p, update_result, for_mut)); + self.store.walk_pats_shallow(p, |p| self.walk_pat_inner(p, update_result, for_mut)); } fn is_upvar(&self, place: &HirPlace) -> bool { if let Some(c) = self.current_closure { let InternedClosure(_, root) = self.db.lookup_intern_closure(c); - return self.body.is_binding_upvar(place.local, root); + return self.store.is_binding_upvar(place.local, root); } false } @@ -858,7 +910,7 @@ impl<'db> InferenceContext<'_, 'db> { if ty.is_raw_ptr() || ty.is_union() { capture.kind = CaptureKind::ByRef(BorrowKind::Shared); self.truncate_capture_spans(capture, 0); - capture.place.projections.truncate(0); + capture.place.projections.clear(); continue; } for (i, p) in capture.place.projections.iter().enumerate() { @@ -866,7 +918,7 @@ impl<'db> InferenceContext<'_, 'db> { &self.table.infer_ctxt, self.table.param_env, ty, - self.owner.module(self.db).krate(self.db), + self.owner.krate(self.db), ); if ty.is_raw_ptr() || ty.is_union() { capture.kind = CaptureKind::ByRef(BorrowKind::Shared); @@ -938,7 +990,7 @@ impl<'db> InferenceContext<'_, 'db> { self.current_capture_span_stack .extend((0..adjustments_count).map(|_| MirSpan::PatId(tgt_pat))); 'reset_span_stack: { - match &self.body[tgt_pat] { + match &self.store[tgt_pat] { Pat::Missing | Pat::Wild => (), Pat::Tuple { args, ellipsis } => { let (al, ar) = args.split_at(ellipsis.map_or(args.len(), |it| it as usize)); @@ -1089,7 +1141,7 @@ impl<'db> InferenceContext<'_, 'db> { fn analyze_closure(&mut self, closure: InternedClosureId) -> FnTrait { let InternedClosure(_, root) = self.db.lookup_intern_closure(closure); self.current_closure = Some(closure); - let Expr::Closure { body, capture_by, .. } = &self.body[root] else { + let Expr::Closure { body, capture_by, .. } = &self.store[root] else { unreachable!("Closure expression id is always closure"); }; self.consume_expr(*body); @@ -1133,7 +1185,7 @@ impl<'db> InferenceContext<'_, 'db> { for capture in &mut captures { if matches!(capture.kind, CaptureKind::ByValue) { for span_stack in &mut capture.span_stacks { - if span_stack[span_stack.len() - 1].is_ref_span(self.body) { + if span_stack[span_stack.len() - 1].is_ref_span(self.store) { span_stack.truncate(span_stack.len() - 1); } } @@ -1149,7 +1201,7 @@ impl<'db> InferenceContext<'_, 'db> { let kind = self.analyze_closure(closure); for (derefed_callee, callee_ty, params, expr) in exprs { - if let &Expr::Call { callee, .. } = &self.body[expr] { + if let &Expr::Call { callee, .. } = &self.store[expr] { let mut adjustments = self.result.expr_adjustments.remove(&callee).unwrap_or_default().into_vec(); self.write_fn_trait_method_resolution( diff --git a/crates/hir-ty/src/infer/coerce.rs b/crates/hir-ty/src/infer/coerce.rs index e79868f4ae..47a7049248 100644 --- a/crates/hir-ty/src/infer/coerce.rs +++ b/crates/hir-ty/src/infer/coerce.rs @@ -1718,6 +1718,9 @@ fn coerce<'db>( fn is_capturing_closure(db: &dyn HirDatabase, closure: InternedClosureId) -> bool { let InternedClosure(owner, expr) = closure.loc(db); - upvars_mentioned(db, owner) + let Some(body_owner) = owner.as_def_with_body() else { + return false; + }; + upvars_mentioned(db, body_owner) .is_some_and(|upvars| upvars.get(&expr).is_some_and(|upvars| !upvars.is_empty())) } diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index 9f2d9d25b9..dc57b1d1c2 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -11,6 +11,7 @@ use hir_def::{ InlineAsmKind, LabelId, Literal, Pat, PatId, RecordSpread, Statement, UnaryOp, }, resolver::ValueNs, + signatures::{FunctionSignature, VariantFields}, }; use hir_def::{FunctionId, hir::ClosureKind}; use hir_expand::name::Name; @@ -155,7 +156,7 @@ impl<'db> InferenceContext<'_, 'db> { /// it is matching against. This is used to determine whether we should /// perform `NeverToAny` coercions. fn pat_guaranteed_to_constitute_read_for_never(&self, pat: PatId) -> bool { - match &self.body[pat] { + match &self.store[pat] { // Does not constitute a read. Pat::Wild => false, @@ -197,25 +198,25 @@ impl<'db> InferenceContext<'_, 'db> { // FIXME(tschottdorf): this is problematic as the HIR is being scraped, but // ref bindings are be implicit after #42640 (default match binding modes). See issue #44848. fn contains_explicit_ref_binding(&self, pat: PatId) -> bool { - if let Pat::Bind { id, .. } = self.body[pat] - && matches!(self.body[id].mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) + if let Pat::Bind { id, .. } = self.store[pat] + && matches!(self.store[id].mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) { return true; } let mut result = false; - self.body.walk_pats_shallow(pat, |pat| result |= self.contains_explicit_ref_binding(pat)); + self.store.walk_pats_shallow(pat, |pat| result |= self.contains_explicit_ref_binding(pat)); result } fn is_syntactic_place_expr(&self, expr: ExprId) -> bool { - match &self.body[expr] { + match &self.store[expr] { // Lang item paths cannot currently be local variables or statics. Expr::Path(Path::LangItem(_, _)) => false, Expr::Path(Path::Normal(path)) => path.type_anchor.is_none(), Expr::Path(path) => self .resolver - .resolve_path_in_value_ns_fully(self.db, path, self.body.expr_path_hygiene(expr)) + .resolve_path_in_value_ns_fully(self.db, path, self.store.expr_path_hygiene(expr)) .is_none_or(|res| matches!(res, ValueNs::LocalBinding(_) | ValueNs::StaticId(_))), Expr::Underscore => true, Expr::UnaryOp { op: UnaryOp::Deref, .. } => true, @@ -311,7 +312,7 @@ impl<'db> InferenceContext<'_, 'db> { ) -> Ty<'db> { self.db.unwind_if_revision_cancelled(); - let expr = &self.body[tgt_expr]; + let expr = &self.store[tgt_expr]; tracing::trace!(?expr); let ty = match expr { Expr::Missing => self.err_ty(), @@ -608,7 +609,7 @@ impl<'db> InferenceContext<'_, 'db> { Some(def) => { let field_types = self.db.field_types(def); let variant_data = def.fields(self.db); - let visibilities = self.db.field_visibilities(def); + let visibilities = VariantFields::field_visibilities(self.db, def); for field in fields.iter() { let field_def = { match variant_data.field(&field.name) { @@ -658,7 +659,7 @@ impl<'db> InferenceContext<'_, 'db> { } } if let RecordSpread::Expr(expr) = *spread { - self.infer_expr(expr, &Expectation::has_type(ty), ExprIsRead::Yes); + self.infer_expr_coerce_never(expr, &Expectation::has_type(ty), ExprIsRead::Yes); } ty } @@ -717,7 +718,7 @@ impl<'db> InferenceContext<'_, 'db> { // instantiations in RHS can be coerced to it. Note that this // cannot happen in destructuring assignments because of how // they are desugared. - let lhs_ty = match &self.body[target] { + let lhs_ty = match &self.store[target] { // LHS of assignment doesn't constitute reads. &Pat::Expr(expr) => { Some(self.infer_expr(expr, &Expectation::none(), ExprIsRead::No)) @@ -728,7 +729,7 @@ impl<'db> InferenceContext<'_, 'db> { let resolution = self.resolver.resolve_path_in_value_ns_fully( self.db, path, - self.body.pat_path_hygiene(target), + self.store.pat_path_hygiene(target), ); self.resolver.reset_to_guard(resolver_guard); @@ -751,7 +752,7 @@ impl<'db> InferenceContext<'_, 'db> { if let Some(lhs_ty) = lhs_ty { self.write_pat_ty(target, lhs_ty); - self.infer_expr_coerce(value, &Expectation::has_type(lhs_ty), ExprIsRead::No); + self.infer_expr_coerce(value, &Expectation::has_type(lhs_ty), ExprIsRead::Yes); } else { let rhs_ty = self.infer_expr(value, &Expectation::none(), ExprIsRead::Yes); let resolver_guard = @@ -1351,7 +1352,7 @@ impl<'db> InferenceContext<'_, 'db> { ExprIsRead::Yes, ); let usize = self.types.types.usize; - let len = match self.body[repeat] { + let len = match self.store[repeat] { Expr::Underscore => { self.write_expr_ty(repeat, usize); self.table.next_const_var() @@ -1491,7 +1492,7 @@ impl<'db> InferenceContext<'_, 'db> { } else { ExprIsRead::No }; - let ty = if contains_explicit_ref_binding(this.body, *pat) { + let ty = if contains_explicit_ref_binding(this.store, *pat) { this.infer_expr( *expr, &Expectation::has_type(decl_ty), @@ -1624,7 +1625,8 @@ impl<'db> InferenceContext<'_, 'db> { }, _ => return None, }; - let is_visible = self.db.field_visibilities(field_id.parent)[field_id.local_id] + let is_visible = VariantFields::field_visibilities(self.db, field_id.parent) + [field_id.local_id] .is_visible_from(self.db, self.resolver.module()); if !is_visible { if private_field.is_none() { @@ -2117,7 +2119,7 @@ impl<'db> InferenceContext<'_, 'db> { // the return value of an argument-position async block to an argument-position // closure wrapped in a block. // See <https://github.com/rust-lang/rust/issues/112225>. - let is_closure = if let Expr::Closure { closure_kind, .. } = self.body[*arg] { + let is_closure = if let Expr::Closure { closure_kind, .. } = self.store[*arg] { !matches!(closure_kind, ClosureKind::Coroutine(_)) } else { false @@ -2194,7 +2196,7 @@ impl<'db> InferenceContext<'_, 'db> { _ => return Default::default(), }; - let data = self.db.function_signature(func); + let data = FunctionSignature::of(self.db, func); let Some(legacy_const_generics_indices) = data.legacy_const_generics_indices(self.db, func) else { return Default::default(); diff --git a/crates/hir-ty/src/infer/mutability.rs b/crates/hir-ty/src/infer/mutability.rs index 45fa141b6d..bfe43fc928 100644 --- a/crates/hir-ty/src/infer/mutability.rs +++ b/crates/hir-ty/src/infer/mutability.rs @@ -14,8 +14,8 @@ use crate::{ }; impl<'db> InferenceContext<'_, 'db> { - pub(crate) fn infer_mut_body(&mut self) { - self.infer_mut_expr(self.body.body_expr, Mutability::Not); + pub(crate) fn infer_mut_body(&mut self, body_expr: ExprId) { + self.infer_mut_expr(body_expr, Mutability::Not); } fn infer_mut_expr(&mut self, tgt_expr: ExprId, mut mutability: Mutability) { @@ -52,7 +52,7 @@ impl<'db> InferenceContext<'_, 'db> { } fn infer_mut_expr_without_adjust(&mut self, tgt_expr: ExprId, mutability: Mutability) { - match &self.body[tgt_expr] { + match &self.store[tgt_expr] { Expr::Missing => (), Expr::InlineAsm(e) => { e.operands.iter().for_each(|(_, op)| match op { @@ -173,7 +173,7 @@ impl<'db> InferenceContext<'_, 'db> { self.infer_mut_expr(*rhs, Mutability::Not); } &Expr::Assignment { target, value } => { - self.body.walk_pats(target, &mut |pat| match self.body[pat] { + self.store.walk_pats(target, &mut |pat| match self.store[pat] { Pat::Expr(expr) => self.infer_mut_expr(expr, Mutability::Mut), Pat::ConstBlock(block) => self.infer_mut_expr(block, Mutability::Not), _ => {} @@ -220,8 +220,8 @@ impl<'db> InferenceContext<'_, 'db> { /// `let (ref x0, ref x1) = *it;` we should use `Deref`. fn pat_bound_mutability(&self, pat: PatId) -> Mutability { let mut r = Mutability::Not; - self.body.walk_bindings_in_pat(pat, |b| { - if self.body[b].mode == BindingAnnotation::RefMut { + self.store.walk_bindings_in_pat(pat, |b| { + if self.store[b].mode == BindingAnnotation::RefMut { r = Mutability::Mut; } }); diff --git a/crates/hir-ty/src/infer/op.rs b/crates/hir-ty/src/infer/op.rs index c79c828cd4..95d63ffb50 100644 --- a/crates/hir-ty/src/infer/op.rs +++ b/crates/hir-ty/src/infer/op.rs @@ -178,9 +178,9 @@ impl<'a, 'db> InferenceContext<'a, 'db> { // trait matching creating lifetime constraints that are too strict. // e.g., adding `&'a T` and `&'b T`, given `&'x T: Add<&'x T>`, will result // in `&'a T <: &'x T` and `&'b T <: &'x T`, instead of `'a = 'b = 'x`. - let lhs_ty = self.infer_expr_no_expect(lhs_expr, ExprIsRead::No); + let lhs_ty = self.infer_expr_no_expect(lhs_expr, ExprIsRead::Yes); let fresh_var = self.table.next_ty_var(); - self.demand_coerce(lhs_expr, lhs_ty, fresh_var, AllowTwoPhase::No, ExprIsRead::No) + self.demand_coerce(lhs_expr, lhs_ty, fresh_var, AllowTwoPhase::No, ExprIsRead::Yes) } }; let lhs_ty = self.table.resolve_vars_with_obligations(lhs_ty); @@ -200,7 +200,7 @@ impl<'a, 'db> InferenceContext<'a, 'db> { // see `NB` above let rhs_ty = - self.infer_expr_coerce(rhs_expr, &Expectation::HasType(rhs_ty_var), ExprIsRead::No); + self.infer_expr_coerce(rhs_expr, &Expectation::HasType(rhs_ty_var), ExprIsRead::Yes); let rhs_ty = self.table.resolve_vars_with_obligations(rhs_ty); let return_ty = match result { @@ -320,7 +320,11 @@ impl<'a, 'db> InferenceContext<'a, 'db> { if let Some((rhs_expr, rhs_ty)) = opt_rhs && rhs_ty.is_ty_var() { - self.infer_expr_coerce(rhs_expr, &Expectation::HasType(rhs_ty), ExprIsRead::No); + self.infer_expr_coerce( + rhs_expr, + &Expectation::HasType(rhs_ty), + ExprIsRead::Yes, + ); } // Construct an obligation `self_ty : Trait<input_tys>` diff --git a/crates/hir-ty/src/infer/pat.rs b/crates/hir-ty/src/infer/pat.rs index 1b8ce5ceaf..8033680dcc 100644 --- a/crates/hir-ty/src/infer/pat.rs +++ b/crates/hir-ty/src/infer/pat.rs @@ -3,9 +3,10 @@ use std::{cmp, iter}; use hir_def::{ - HasModule, - expr_store::{Body, path::Path}, + HasModule as _, + expr_store::{ExpressionStore, path::Path}, hir::{Binding, BindingAnnotation, BindingId, Expr, ExprId, Literal, Pat, PatId}, + signatures::VariantFields, }; use hir_expand::name::Name; use rustc_ast_ir::Mutability; @@ -60,7 +61,7 @@ impl<'db> InferenceContext<'_, 'db> { Some(def) => { let field_types = self.db.field_types(def); let variant_data = def.fields(self.db); - let visibilities = self.db.field_visibilities(def); + let visibilities = VariantFields::field_visibilities(self.db, def); let (pre, post) = match ellipsis { Some(idx) => subs.split_at(idx as usize), @@ -129,7 +130,7 @@ impl<'db> InferenceContext<'_, 'db> { Some(def) => { let field_types = self.db.field_types(def); let variant_data = def.fields(self.db); - let visibilities = self.db.field_visibilities(def); + let visibilities = VariantFields::field_visibilities(self.db, def); let substs = ty.as_adt().map(TupleExt::tail); @@ -260,14 +261,14 @@ impl<'db> InferenceContext<'_, 'db> { ) -> Ty<'db> { let mut expected = self.table.structurally_resolve_type(expected); - if matches!(&self.body[pat], Pat::Ref { .. }) || self.inside_assignment { + if matches!(&self.store[pat], Pat::Ref { .. }) || self.inside_assignment { cov_mark::hit!(match_ergonomics_ref); // When you encounter a `&pat` pattern, reset to Move. // This is so that `w` is by value: `let (_, &w) = &(1, &2);` // Destructuring assignments also reset the binding mode and // don't do match ergonomics. default_bm = BindingMode::Move; - } else if self.is_non_ref_pat(self.body, pat) { + } else if self.is_non_ref_pat(self.store, pat) { let mut pat_adjustments = Vec::new(); while let TyKind::Ref(_lifetime, inner, mutability) = expected.kind() { pat_adjustments.push(expected.store()); @@ -289,7 +290,7 @@ impl<'db> InferenceContext<'_, 'db> { let default_bm = default_bm; let expected = expected; - let ty = match &self.body[pat] { + let ty = match &self.store[pat] { Pat::Tuple { args, ellipsis } => { self.infer_tuple_pat_like(pat, expected, default_bm, *ellipsis, args, decl) } @@ -485,7 +486,7 @@ impl<'db> InferenceContext<'_, 'db> { expected: Ty<'db>, decl: Option<DeclContext>, ) -> Ty<'db> { - let Binding { mode, .. } = self.body[binding]; + let Binding { mode, .. } = self.store[binding]; let mode = if mode == BindingAnnotation::Unannotated { default_bm } else { @@ -569,7 +570,7 @@ impl<'db> InferenceContext<'_, 'db> { fn infer_lit_pat(&mut self, expr: ExprId, expected: Ty<'db>) -> Ty<'db> { // Like slice patterns, byte string patterns can denote both `&[u8; N]` and `&[u8]`. - if let Expr::Literal(Literal::ByteString(_)) = self.body[expr] + if let Expr::Literal(Literal::ByteString(_)) = self.store[expr] && let TyKind::Ref(_, inner, _) = expected.kind() { let inner = self.table.try_structurally_resolve_type(inner); @@ -590,14 +591,14 @@ impl<'db> InferenceContext<'_, 'db> { self.infer_expr(expr, &Expectation::has_type(expected), ExprIsRead::Yes) } - fn is_non_ref_pat(&mut self, body: &hir_def::expr_store::Body, pat: PatId) -> bool { - match &body[pat] { + fn is_non_ref_pat(&mut self, store: &hir_def::expr_store::ExpressionStore, pat: PatId) -> bool { + match &store[pat] { Pat::Tuple { .. } | Pat::TupleStruct { .. } | Pat::Record { .. } | Pat::Range { .. } | Pat::Slice { .. } => true, - Pat::Or(pats) => pats.iter().all(|p| self.is_non_ref_pat(body, *p)), + Pat::Or(pats) => pats.iter().all(|p| self.is_non_ref_pat(store, *p)), Pat::Path(path) => { // A const is a reference pattern, but other value ns things aren't (see #16131). let resolved = self.resolve_value_path_inner(path, pat.into(), true); @@ -605,7 +606,7 @@ impl<'db> InferenceContext<'_, 'db> { } Pat::ConstBlock(..) => false, Pat::Lit(expr) => !matches!( - body[*expr], + store[*expr], Expr::Literal(Literal::String(..) | Literal::CString(..) | Literal::ByteString(..)) ), Pat::Wild @@ -670,10 +671,10 @@ impl<'db> InferenceContext<'_, 'db> { } } -pub(super) fn contains_explicit_ref_binding(body: &Body, pat_id: PatId) -> bool { +pub(super) fn contains_explicit_ref_binding(store: &ExpressionStore, pat_id: PatId) -> bool { let mut res = false; - body.walk_pats(pat_id, &mut |pat| { - res |= matches!(body[pat], Pat::Bind { id, .. } if body[id].mode == BindingAnnotation::Ref); + store.walk_pats(pat_id, &mut |pat| { + res |= matches!(store[pat], Pat::Bind { id, .. } if matches!(store[id].mode, BindingAnnotation::Ref | BindingAnnotation::RefMut)); }); res } diff --git a/crates/hir-ty/src/infer/path.rs b/crates/hir-ty/src/infer/path.rs index ef1a610a32..71d68ccd47 100644 --- a/crates/hir-ty/src/infer/path.rs +++ b/crates/hir-ty/src/infer/path.rs @@ -4,6 +4,7 @@ use hir_def::{ AdtId, AssocItemId, GenericDefId, ItemContainerId, Lookup, expr_store::path::{Path, PathSegment}, resolver::{ResolveValueResult, TypeNs, ValueNs}, + signatures::{ConstSignature, FunctionSignature}, }; use hir_expand::name::Name; use rustc_type_ir::inherent::{SliceLike, Ty as _}; @@ -136,7 +137,7 @@ impl<'db> InferenceContext<'_, 'db> { let mut ctx = TyLoweringContext::new( self.db, &self.resolver, - self.body, + self.store, &self.diagnostics, InferenceTyDiagnosticSource::Body, self.generic_def, @@ -159,16 +160,16 @@ impl<'db> InferenceContext<'_, 'db> { let ty = self.table.process_user_written_ty(ty); self.resolve_ty_assoc_item(ty, last.name, id).map(|(it, substs)| (it, Some(substs)))? } else { - let hygiene = self.body.expr_or_pat_path_hygiene(id); + let hygiene = self.store.expr_or_pat_path_hygiene(id); // FIXME: report error, unresolved first path segment let value_or_partial = path_ctx.resolve_path_in_value_ns(hygiene)?; match value_or_partial { - ResolveValueResult::ValueNs(it, _) => { + ResolveValueResult::ValueNs(it) => { drop_ctx(ctx, no_diagnostics); (it, None) } - ResolveValueResult::Partial(def, remaining_index, _) => { + ResolveValueResult::Partial(def, remaining_index) => { // there may be more intermediate segments between the resolved one and // the end. Only the last segment needs to be resolved to a value; from // the segments before that, we need to get either a type or a trait ref. @@ -263,7 +264,7 @@ impl<'db> InferenceContext<'_, 'db> { trait_.trait_items(self.db).items.iter().map(|(_name, id)| *id).find_map(|item| { match item { AssocItemId::FunctionId(func) => { - if segment.name == &self.db.function_signature(func).name { + if segment.name == &FunctionSignature::of(self.db, func).name { Some(CandidateId::FunctionId(func)) } else { None @@ -271,7 +272,7 @@ impl<'db> InferenceContext<'_, 'db> { } AssocItemId::ConstId(konst) => { - if self.db.const_signature(konst).name.as_ref() == Some(segment.name) { + if ConstSignature::of(self.db, konst).name.as_ref() == Some(segment.name) { Some(CandidateId::ConstId(konst)) } else { None diff --git a/crates/hir-ty/src/infer/unify.rs b/crates/hir-ty/src/infer/unify.rs index 2057159c46..d093412b42 100644 --- a/crates/hir-ty/src/infer/unify.rs +++ b/crates/hir-ty/src/infer/unify.rs @@ -3,7 +3,7 @@ use std::fmt; use base_db::Crate; -use hir_def::{AdtId, DefWithBodyId, GenericParamId}; +use hir_def::{AdtId, ExpressionStoreOwnerId, GenericParamId}; use hir_expand::name::Name; use intern::sym; use rustc_hash::FxHashSet; @@ -147,7 +147,7 @@ impl<'db> InferenceTable<'db> { db: &'db dyn HirDatabase, trait_env: ParamEnv<'db>, krate: Crate, - owner: Option<DefWithBodyId>, + owner: Option<ExpressionStoreOwnerId>, ) -> Self { let interner = DbInterner::new_with(db, krate); let typing_mode = match owner { diff --git a/crates/hir-ty/src/inhabitedness.rs b/crates/hir-ty/src/inhabitedness.rs index 402e9ce969..74d66123ea 100644 --- a/crates/hir-ty/src/inhabitedness.rs +++ b/crates/hir-ty/src/inhabitedness.rs @@ -1,7 +1,9 @@ //! Type inhabitedness logic. use std::ops::ControlFlow::{self, Break, Continue}; -use hir_def::{AdtId, EnumVariantId, ModuleId, VariantId, visibility::Visibility}; +use hir_def::{ + AdtId, EnumVariantId, ModuleId, VariantId, signatures::VariantFields, visibility::Visibility, +}; use rustc_hash::FxHashSet; use rustc_type_ir::{ TypeSuperVisitable, TypeVisitable, TypeVisitor, @@ -151,7 +153,11 @@ impl<'a, 'db> UninhabitedFrom<'a, 'db> { let is_enum = matches!(variant, VariantId::EnumVariantId(..)); let field_tys = self.db().field_types(variant); - let field_vis = if is_enum { None } else { Some(self.db().field_visibilities(variant)) }; + let field_vis = if is_enum { + None + } else { + Some(VariantFields::field_visibilities(self.db(), variant)) + }; for (fid, _) in fields.iter() { self.visit_field(field_vis.as_ref().map(|it| it[fid]), &field_tys[fid].get(), subst)?; diff --git a/crates/hir-ty/src/lang_items.rs b/crates/hir-ty/src/lang_items.rs index 18feb0f46a..ae53276f56 100644 --- a/crates/hir-ty/src/lang_items.rs +++ b/crates/hir-ty/src/lang_items.rs @@ -1,13 +1,17 @@ //! Functions to detect special lang items -use hir_def::{AdtId, TraitId, lang_item::LangItems, signatures::StructFlags}; +use hir_def::{ + AdtId, TraitId, + lang_item::LangItems, + signatures::{StructFlags, StructSignature}, +}; use intern::{Symbol, sym}; use crate::db::HirDatabase; pub fn is_box(db: &dyn HirDatabase, adt: AdtId) -> bool { let AdtId::StructId(id) = adt else { return false }; - db.struct_signature(id).flags.contains(StructFlags::IS_BOX) + StructSignature::of(db, id).flags.contains(StructFlags::IS_BOX) } pub fn lang_items_for_bin_op( diff --git a/crates/hir-ty/src/layout.rs b/crates/hir-ty/src/layout.rs index 525100439f..54332122d0 100644 --- a/crates/hir-ty/src/layout.rs +++ b/crates/hir-ty/src/layout.rs @@ -333,7 +333,7 @@ pub fn layout_of_ty_query( } TyKind::Closure(id, args) => { let def = db.lookup_intern_closure(id.0); - let infer = InferenceResult::for_body(db, def.0); + let infer = InferenceResult::of(db, def.0); let (captures, _) = infer.closure_info(id.0); let fields = captures .iter() diff --git a/crates/hir-ty/src/layout/adt.rs b/crates/hir-ty/src/layout/adt.rs index d249591718..6090ddfd45 100644 --- a/crates/hir-ty/src/layout/adt.rs +++ b/crates/hir-ty/src/layout/adt.rs @@ -5,7 +5,7 @@ use std::{cmp, ops::Bound}; use hir_def::{ AdtId, VariantId, attrs::AttrFlags, - signatures::{StructFlags, VariantFields}, + signatures::{StructFlags, StructSignature, VariantFields}, }; use rustc_abi::{Integer, ReprOptions, TargetDataLayout}; use rustc_index::IndexVec; @@ -41,7 +41,7 @@ pub fn layout_of_adt_query( }; let (variants, repr, is_special_no_niche) = match def { AdtId::StructId(s) => { - let sig = db.struct_signature(s); + let sig = StructSignature::of(db, s); let mut r = SmallVec::<[_; 1]>::new(); r.push(handle_variant(s.into(), s.fields(db))?); ( diff --git a/crates/hir-ty/src/layout/tests.rs b/crates/hir-ty/src/layout/tests.rs index 8c91be1d78..484ecebba5 100644 --- a/crates/hir-ty/src/layout/tests.rs +++ b/crates/hir-ty/src/layout/tests.rs @@ -1,6 +1,12 @@ use base_db::target::TargetData; use either::Either; -use hir_def::{HasModule, db::DefDatabase}; +use hir_def::{ + DefWithBodyId, ExpressionStoreOwnerId, GenericDefId, HasModule, + expr_store::Body, + signatures::{ + EnumSignature, FunctionSignature, StructSignature, TypeAliasSignature, UnionSignature, + }, +}; use project_model::{Sysroot, toolchain_info::QueryConfig}; use rustc_hash::FxHashMap; use rustc_type_ir::inherent::GenericArgs as _; @@ -49,18 +55,15 @@ fn eval_goal( let adt_or_type_alias_id = scope.declarations().find_map(|x| match x { hir_def::ModuleDefId::AdtId(x) => { let name = match x { - hir_def::AdtId::StructId(x) => db - .struct_signature(x) + hir_def::AdtId::StructId(x) => StructSignature::of(&db, x) .name .display_no_db(file_id.edition(&db)) .to_smolstr(), - hir_def::AdtId::UnionId(x) => db - .union_signature(x) + hir_def::AdtId::UnionId(x) => UnionSignature::of(&db, x) .name .display_no_db(file_id.edition(&db)) .to_smolstr(), - hir_def::AdtId::EnumId(x) => db - .enum_signature(x) + hir_def::AdtId::EnumId(x) => EnumSignature::of(&db, x) .name .display_no_db(file_id.edition(&db)) .to_smolstr(), @@ -68,8 +71,7 @@ fn eval_goal( (name == "Goal").then_some(Either::Left(x)) } hir_def::ModuleDefId::TypeAliasId(x) => { - let name = db - .type_alias_signature(x) + let name = TypeAliasSignature::of(&db, x) .name .display_no_db(file_id.edition(&db)) .to_smolstr(); @@ -90,10 +92,13 @@ fn eval_goal( ), Either::Right(ty_id) => db.ty(ty_id.into()).instantiate_identity(), }; - let param_env = db.trait_environment(match adt_or_type_alias_id { - Either::Left(adt) => hir_def::GenericDefId::AdtId(adt), - Either::Right(ty) => hir_def::GenericDefId::TypeAliasId(ty), - }); + let param_env = db.trait_environment( + match adt_or_type_alias_id { + Either::Left(adt) => hir_def::GenericDefId::AdtId(adt), + Either::Right(ty) => hir_def::GenericDefId::TypeAliasId(ty), + } + .into(), + ); let krate = match adt_or_type_alias_id { Either::Left(it) => it.krate(&db), Either::Right(it) => it.krate(&db), @@ -123,8 +128,7 @@ fn eval_expr( .declarations() .find_map(|x| match x { hir_def::ModuleDefId::FunctionId(x) => { - let name = db - .function_signature(x) + let name = FunctionSignature::of(&db, x) .name .display_no_db(file_id.edition(&db)) .to_smolstr(); @@ -133,15 +137,16 @@ fn eval_expr( _ => None, }) .unwrap(); - let hir_body = db.body(function_id.into()); + let hir_body = Body::of(&db, function_id.into()); let b = hir_body .bindings() .find(|x| x.1.name.display_no_db(file_id.edition(&db)).to_smolstr() == "goal") .unwrap() .0; - let infer = InferenceResult::for_body(&db, function_id.into()); + let infer = InferenceResult::of(&db, DefWithBodyId::from(function_id)); let goal_ty = infer.type_of_binding[b].clone(); - let param_env = db.trait_environment(function_id.into()); + let param_env = + db.trait_environment(ExpressionStoreOwnerId::from(GenericDefId::from(function_id))); let krate = function_id.krate(&db); db.layout_of_ty(goal_ty, ParamEnvAndCrate { param_env, krate }.store()) }) @@ -379,6 +384,11 @@ struct Goal(Foo<S>); #[test] fn simd_types() { + let size = 16; + #[cfg(not(target_arch = "s390x"))] + let align = 16; + #[cfg(target_arch = "s390x")] + let align = 8; check_size_and_align( r#" #[repr(simd)] @@ -386,8 +396,8 @@ fn simd_types() { struct Goal(SimdType); "#, "", - 16, - 16, + size, + align, ); } diff --git a/crates/hir-ty/src/lib.rs b/crates/hir-ty/src/lib.rs index f8920904f0..e6b8329ca8 100644 --- a/crates/hir-ty/src/lib.rs +++ b/crates/hir-ty/src/lib.rs @@ -31,6 +31,7 @@ mod inhabitedness; mod lower; pub mod next_solver; mod opaques; +mod representability; mod specialization; mod target_feature; mod utils; @@ -57,9 +58,12 @@ mod test_db; #[cfg(test)] mod tests; -use std::hash::Hash; +use std::{hash::Hash, ops::ControlFlow}; -use hir_def::{CallableDefId, TypeOrConstParamId, type_ref::Rawness}; +use hir_def::{ + CallableDefId, ExpressionStoreOwnerId, GenericDefId, TypeAliasId, TypeOrConstParamId, + TypeParamId, hir::generics::GenericParams, resolver::TypeNs, type_ref::Rawness, +}; use hir_expand::name::Name; use indexmap::{IndexMap, map::Entry}; use intern::{Symbol, sym}; @@ -77,10 +81,11 @@ use crate::{ db::HirDatabase, display::{DisplayTarget, HirDisplay}, infer::unify::InferenceTable, + lower::SupertraitsInfo, next_solver::{ AliasTy, Binder, BoundConst, BoundRegion, BoundRegionKind, BoundTy, BoundTyKind, Canonical, - CanonicalVarKind, CanonicalVars, Const, ConstKind, DbInterner, FnSig, GenericArgs, - PolyFnSig, Predicate, Region, RegionKind, TraitRef, Ty, TyKind, Tys, abi, + CanonicalVarKind, CanonicalVars, ClauseKind, Const, ConstKind, DbInterner, FnSig, + GenericArgs, PolyFnSig, Predicate, Region, RegionKind, TraitRef, Ty, TyKind, Tys, abi, }, }; @@ -94,7 +99,7 @@ pub use infer::{ }; pub use lower::{ GenericPredicates, ImplTraits, LifetimeElisionKind, TyDefId, TyLoweringContext, ValueTyDefId, - associated_type_shorthand_candidates, diagnostics::*, + diagnostics::*, }; pub use next_solver::interner::{attach_db, attach_db_allow_change, with_attached_db}; pub use target_feature::TargetFeatures; @@ -479,6 +484,55 @@ where } /// To be used from `hir` only. +pub fn associated_type_shorthand_candidates( + db: &dyn HirDatabase, + def: GenericDefId, + res: TypeNs, + mut cb: impl FnMut(&Name, TypeAliasId) -> bool, +) -> Option<TypeAliasId> { + let interner = DbInterner::new_no_crate(db); + let (def, param) = match res { + TypeNs::GenericParam(param) => (def, param), + TypeNs::SelfType(impl_) => { + let impl_trait = db.impl_trait(impl_)?.skip_binder().def_id.0; + let param = TypeParamId::from_unchecked(TypeOrConstParamId { + parent: impl_trait.into(), + local_id: GenericParams::SELF_PARAM_ID_IN_SELF, + }); + (impl_trait.into(), param) + } + _ => return None, + }; + + let mut dedup_map = FxHashSet::default(); + let param_ty = Ty::new_param(interner, param, param_idx(db, param.into()).unwrap() as u32); + // We use the ParamEnv and not the predicates because the ParamEnv elaborates bounds. + let param_env = db.trait_environment(ExpressionStoreOwnerId::from(def)); + for clause in param_env.clauses { + let ClauseKind::Trait(trait_clause) = clause.kind().skip_binder() else { continue }; + if trait_clause.self_ty() != param_ty { + continue; + } + let trait_id = trait_clause.def_id().0; + dedup_map.extend( + SupertraitsInfo::query(db, trait_id) + .defined_assoc_types + .iter() + .map(|(name, id)| (name, *id)), + ); + } + + dedup_map + .into_iter() + .try_for_each( + |(name, id)| { + if cb(name, id) { ControlFlow::Break(id) } else { ControlFlow::Continue(()) } + }, + ) + .break_value() +} + +/// To be used from `hir` only. pub fn callable_sig_from_fn_trait<'db>( self_ty: Ty<'db>, trait_env: ParamEnvAndCrate<'db>, diff --git a/crates/hir-ty/src/lower.rs b/crates/hir-ty/src/lower.rs index 5789bf02a4..7259099107 100644 --- a/crates/hir-ty/src/lower.rs +++ b/crates/hir-ty/src/lower.rs @@ -13,19 +13,23 @@ use std::{cell::OnceCell, iter, mem}; use arrayvec::ArrayVec; use either::Either; use hir_def::{ - AdtId, AssocItemId, CallableDefId, ConstId, ConstParamId, DefWithBodyId, EnumId, EnumVariantId, - FunctionId, GeneralConstId, GenericDefId, GenericParamId, HasModule, ImplId, ItemContainerId, - LifetimeParamId, LocalFieldId, Lookup, StaticId, StructId, TraitId, TypeAliasId, - TypeOrConstParamId, TypeParamId, UnionId, VariantId, + AdtId, AssocItemId, CallableDefId, ConstId, ConstParamId, EnumId, EnumVariantId, + ExpressionStoreOwnerId, FunctionId, GeneralConstId, GenericDefId, GenericParamId, HasModule, + ImplId, ItemContainerId, LifetimeParamId, LocalFieldId, Lookup, StaticId, StructId, TraitId, + TypeAliasId, TypeOrConstParamId, TypeParamId, UnionId, VariantId, builtin_type::BuiltinType, expr_store::{ExpressionStore, HygieneId, path::Path}, hir::generics::{ - GenericParamDataRef, TypeOrConstParamData, TypeParamProvenance, WherePredicate, + GenericParamDataRef, GenericParams, TypeOrConstParamData, TypeParamProvenance, + WherePredicate, }, item_tree::FieldsShape, lang_item::LangItems, resolver::{HasResolver, LifetimeNs, Resolver, TypeNs, ValueNs}, - signatures::{FunctionSignature, TraitFlags, TypeAliasFlags}, + signatures::{ + ConstSignature, FunctionSignature, ImplSignature, StaticSignature, StructSignature, + TraitFlags, TraitSignature, TypeAliasFlags, TypeAliasSignature, + }, type_ref::{ ConstRef, FnType, LifetimeRefId, PathId, TraitBoundModifier, TraitRef as HirTraitRef, TypeBound, TypeRef, TypeRefId, @@ -36,27 +40,22 @@ use la_arena::{Arena, ArenaMap, Idx}; use path::{PathDiagnosticCallback, PathLoweringContext}; use rustc_ast_ir::Mutability; use rustc_hash::FxHashSet; -use rustc_pattern_analysis::Captures; use rustc_type_ir::{ AliasTyKind, BoundVarIndexKind, ConstKind, DebruijnIndex, ExistentialPredicate, ExistentialProjection, ExistentialTraitRef, FnSig, Interner, OutlivesPredicate, TermKind, - TyKind::{self}, - TypeFoldable, TypeVisitableExt, Upcast, UpcastFrom, elaborate, - inherent::{ - Clause as _, GenericArg as _, GenericArgs as _, IntoKind as _, Region as _, SliceLike, - Ty as _, - }, + TyKind, TypeFoldable, TypeVisitableExt, Upcast, UpcastFrom, elaborate, + inherent::{Clause as _, GenericArgs as _, IntoKind as _, Region as _, Ty as _}, }; -use smallvec::{SmallVec, smallvec}; +use smallvec::SmallVec; use stdx::{impl_from, never}; use tracing::debug; use triomphe::{Arc, ThinArc}; use crate::{ - FnAbi, ImplTraitId, TyLoweringDiagnostic, TyLoweringDiagnosticKind, all_super_traits, + FnAbi, ImplTraitId, TyLoweringDiagnostic, TyLoweringDiagnosticKind, consteval::intern_const_ref, db::{HirDatabase, InternedOpaqueTyId}, - generics::{Generics, generics, trait_self_param_idx}, + generics::{Generics, generics}, next_solver::{ AliasTy, Binder, BoundExistentialPredicates, Clause, ClauseKind, Clauses, Const, DbInterner, EarlyBinder, EarlyParamRegion, ErrorGuaranteed, FxIndexMap, GenericArg, @@ -182,7 +181,7 @@ pub struct TyLoweringContext<'db, 'a> { resolver: &'a Resolver<'db>, store: &'a ExpressionStore, def: GenericDefId, - generics: OnceCell<Generics>, + generics: OnceCell<Generics<'db>>, in_binders: DebruijnIndex, impl_trait_mode: ImplTraitLoweringState, /// Tracks types with explicit `?Sized` bounds. @@ -283,11 +282,12 @@ impl<'db, 'a> TyLoweringContext<'db, 'a> { } pub(crate) fn lower_const(&mut self, const_ref: ConstRef, const_type: Ty<'db>) -> Const<'db> { - let const_ref = &self.store[const_ref.expr]; - match const_ref { - hir_def::hir::Expr::Path(path) => { - self.path_to_const(path).unwrap_or_else(|| unknown_const(const_type)) - } + let expr_id = const_ref.expr; + let expr = &self.store[expr_id]; + match expr { + hir_def::hir::Expr::Path(path) => self + .path_to_const(path) + .unwrap_or_else(|| Const::new(self.interner, ConstKind::Error(ErrorGuaranteed))), hir_def::hir::Expr::Literal(literal) => { intern_const_ref(self.db, literal, const_type, self.resolver.krate()) } @@ -304,20 +304,74 @@ impl<'db, 'a> TyLoweringContext<'db, 'a> { self.resolver.krate(), ) } else { - unknown_const(const_type) + Const::new(self.interner, ConstKind::Error(ErrorGuaranteed)) } } // For unsigned integers, chars, bools, etc., negation is not meaningful - _ => unknown_const(const_type), + _ => Const::new(self.interner, ConstKind::Error(ErrorGuaranteed)), } } else { - unknown_const(const_type) + // Complex negation expression (e.g. `-N` where N is a const param) + self.lower_const_as_unevaluated(expr_id, const_type) } } - _ => unknown_const(const_type), + hir_def::hir::Expr::Underscore => { + Const::new(self.interner, ConstKind::Error(ErrorGuaranteed)) + } + // Any other complex expression becomes an unevaluated anonymous const. + _ => self.lower_const_as_unevaluated(expr_id, const_type), } } + /// Lower a complex const expression to an `UnevaluatedConst` backed by an `AnonConstId`. + /// + /// The `expected_ty_ref` is `None` for array lengths (implicitly `usize`) or + /// `Some(type_ref_id)` for const generic arguments where the expected type comes + /// from the const parameter declaration. + fn lower_const_as_unevaluated( + &mut self, + _expr: hir_def::hir::ExprId, + _expected_ty: Ty<'db>, + ) -> Const<'db> { + // /// Build the identity generic args for the current generic context. + // /// + // /// This maps each generic parameter to itself (as a `ParamTy`, `ParamConst`, + // /// or `EarlyParamRegion`), which is the correct substitution when creating + // /// an `UnevaluatedConst` during type lowering — the anon const inherits the + // /// parent's generics and they haven't been substituted yet. + // fn current_generic_args(&self) -> GenericArgs<'db> { + // let generics = self.generics(); + // let interner = self.interner; + // GenericArgs::new_from_iter( + // interner, + // generics.iter_id().enumerate().map(|(index, id)| match id { + // GenericParamId::TypeParamId(id) => { + // GenericArg::from(Ty::new_param(interner, id, index as u32)) + // } + // GenericParamId::ConstParamId(id) => GenericArg::from(Const::new_param( + // interner, + // ParamConst { id, index: index as u32 }, + // )), + // GenericParamId::LifetimeParamId(id) => GenericArg::from(Region::new_early_param( + // interner, + // EarlyParamRegion { id, index: index as u32 }, + // )), + // }), + // ) + // } + // let loc = AnonConstLoc { owner: self.def, expr }; + // let id = loc.intern(self.db); + // let args = self.current_generic_args(); + // Const::new( + // self.interner, + // ConstKind::Unevaluated(UnevaluatedConst::new( + // GeneralConstId::AnonConstId(id).into(), + // args, + // )), + // ) + Const::new(self.interner, ConstKind::Error(ErrorGuaranteed)) + } + pub(crate) fn path_to_const(&mut self, path: &Path) -> Option<Const<'db>> { match self.resolver.resolve_path_in_value_ns_fully(self.db, path, HygieneId::ROOT) { Some(ValueNs::GenericParam(p)) => { @@ -353,7 +407,7 @@ impl<'db, 'a> TyLoweringContext<'db, 'a> { self.path_to_const(path).unwrap_or_else(|| unknown_const(const_type)) } - fn generics(&self) -> &Generics { + fn generics(&self) -> &Generics<'db> { self.generics.get_or_init(|| generics(self.db, self.def)) } @@ -618,33 +672,12 @@ impl<'db, 'a> TyLoweringContext<'db, 'a> { &'b mut self, where_predicate: &'b WherePredicate, ignore_bindings: bool, - generics: &Generics, - predicate_filter: PredicateFilter, ) -> impl Iterator<Item = (Clause<'db>, GenericPredicateSource)> + use<'a, 'b, 'db> { match where_predicate { WherePredicate::ForLifetime { target, bound, .. } | WherePredicate::TypeBound { target, bound } => { - if let PredicateFilter::SelfTrait = predicate_filter { - let target_type = &self.store[*target]; - let self_type = 'is_self: { - if let TypeRef::Path(path) = target_type - && path.is_self_type() - { - break 'is_self true; - } - if let TypeRef::TypeParam(param) = target_type - && generics[param.local_id()].is_trait_self() - { - break 'is_self true; - } - false - }; - if !self_type { - return Either::Left(Either::Left(iter::empty())); - } - } let self_ty = self.lower_ty(*target); - Either::Left(Either::Right(self.lower_type_bound(bound, self_ty, ignore_bindings))) + Either::Left(self.lower_type_bound(bound, self_ty, ignore_bindings)) } &WherePredicate::Lifetime { bound, target } => Either::Right(iter::once(( Clause(Predicate::new( @@ -752,7 +785,8 @@ impl<'db, 'a> TyLoweringContext<'db, 'a> { match b.kind().skip_binder() { rustc_type_ir::ClauseKind::Trait(t) => { let id = t.def_id(); - let is_auto = db.trait_signature(id.0).flags.contains(TraitFlags::AUTO); + let is_auto = + TraitSignature::of(db, id.0).flags.contains(TraitFlags::AUTO); if is_auto { auto_traits.push(t.def_id().0); } else { @@ -831,6 +865,8 @@ impl<'db, 'a> TyLoweringContext<'db, 'a> { let mut ordered_associated_types = vec![]; if let Some(principal_trait) = principal { + // Generally we should not elaborate in lowering as this can lead to cycles, but + // here rustc cycles as well. for clause in elaborate::elaborate( interner, [Clause::upcast_from( @@ -1134,7 +1170,7 @@ pub(crate) fn impl_trait_with_diagnostics<'db>( db: &'db dyn HirDatabase, impl_id: ImplId, ) -> Option<(StoredEarlyBinder<(TraitId, StoredGenericArgs)>, Diagnostics)> { - let impl_data = db.impl_signature(impl_id); + let impl_data = ImplSignature::of(db, impl_id); let resolver = impl_id.resolver(db); let mut ctx = TyLoweringContext::new( db, @@ -1219,7 +1255,7 @@ impl ImplTraits { def: hir_def::FunctionId, ) -> Option<Box<StoredEarlyBinder<ImplTraits>>> { // FIXME unify with fn_sig_for_fn instead of doing lowering twice, maybe - let data = db.function_signature(def); + let data = FunctionSignature::of(db, def); let resolver = def.resolver(db); let mut ctx_ret = TyLoweringContext::new( db, @@ -1247,7 +1283,7 @@ impl ImplTraits { db: &dyn HirDatabase, def: hir_def::TypeAliasId, ) -> Option<Box<StoredEarlyBinder<ImplTraits>>> { - let data = db.type_alias_signature(def); + let data = TypeAliasSignature::of(db, def); let resolver = def.resolver(db); let mut ctx = TyLoweringContext::new( db, @@ -1337,7 +1373,7 @@ fn type_for_fn(db: &dyn HirDatabase, def: FunctionId) -> StoredEarlyBinder<Store /// Build the declared type of a const. fn type_for_const(db: &dyn HirDatabase, def: ConstId) -> StoredEarlyBinder<StoredTy> { let resolver = def.resolver(db); - let data = db.const_signature(def); + let data = ConstSignature::of(db, def); let parent = def.loc(db).container; let mut ctx = TyLoweringContext::new( db, @@ -1353,7 +1389,7 @@ fn type_for_const(db: &dyn HirDatabase, def: ConstId) -> StoredEarlyBinder<Store /// Build the declared type of a static. fn type_for_static(db: &dyn HirDatabase, def: StaticId) -> StoredEarlyBinder<StoredTy> { let resolver = def.resolver(db); - let data = db.static_signature(def); + let data = StaticSignature::of(db, def); let mut ctx = TyLoweringContext::new( db, &resolver, @@ -1370,7 +1406,7 @@ fn type_for_struct_constructor( db: &dyn HirDatabase, def: StructId, ) -> Option<StoredEarlyBinder<StoredTy>> { - let struct_data = db.struct_signature(def); + let struct_data = StructSignature::of(db, def); match struct_data.shape { FieldsShape::Record => None, FieldsShape::Unit => Some(type_for_adt(db, def.into())), @@ -1445,7 +1481,7 @@ pub(crate) fn type_for_type_alias_with_diagnostics<'db>( db: &'db dyn HirDatabase, t: TypeAliasId, ) -> (StoredEarlyBinder<StoredTy>, Diagnostics) { - let type_alias_data = db.type_alias_signature(t); + let type_alias_data = TypeAliasSignature::of(db, t); let mut diags = None; let resolver = t.resolver(db); let interner = DbInterner::new_no_crate(db); @@ -1508,7 +1544,7 @@ pub(crate) fn impl_self_ty_with_diagnostics<'db>( ) -> (StoredEarlyBinder<StoredTy>, Diagnostics) { let resolver = impl_id.resolver(db); - let impl_data = db.impl_signature(impl_id); + let impl_data = ImplSignature::of(db, impl_id); let mut ctx = TyLoweringContext::new( db, &resolver, @@ -1554,14 +1590,14 @@ pub(crate) fn const_param_ty_with_diagnostics<'db>( _: (), def: ConstParamId, ) -> (StoredTy, Diagnostics) { - let (parent_data, store) = db.generic_params_and_store(def.parent()); + let (parent_data, store) = GenericParams::with_store(db, def.parent()); let data = &parent_data[def.local_id()]; let resolver = def.parent().resolver(db); let interner = DbInterner::new_no_crate(db); let mut ctx = TyLoweringContext::new( db, &resolver, - &store, + store, def.parent(), LifetimeElisionKind::AnonymousReportError, ); @@ -1624,6 +1660,98 @@ pub(crate) fn field_types_with_diagnostics_query<'db>( (res, create_diagnostics(ctx.diagnostics)) } +#[derive(Debug, PartialEq, Eq, Default)] +pub(crate) struct SupertraitsInfo { + /// This includes the trait itself. + pub(crate) all_supertraits: Box<[TraitId]>, + pub(crate) direct_supertraits: Box<[TraitId]>, + pub(crate) defined_assoc_types: Box<[(Name, TypeAliasId)]>, +} + +impl SupertraitsInfo { + #[inline] + pub(crate) fn query(db: &dyn HirDatabase, trait_: TraitId) -> &Self { + return supertraits_info(db, trait_); + + #[salsa::tracked(returns(ref), cycle_result = supertraits_info_cycle)] + fn supertraits_info(db: &dyn HirDatabase, trait_: TraitId) -> SupertraitsInfo { + let mut all_supertraits = FxHashSet::default(); + let mut direct_supertraits = FxHashSet::default(); + let mut defined_assoc_types = FxHashSet::default(); + + all_supertraits.insert(trait_); + defined_assoc_types.extend(trait_.trait_items(db).items.iter().filter_map( + |(name, id)| match *id { + AssocItemId::TypeAliasId(id) => Some((name.clone(), id)), + _ => None, + }, + )); + + let resolver = trait_.resolver(db); + let signature = TraitSignature::of(db, trait_); + for pred in signature.generic_params.where_predicates() { + let (WherePredicate::TypeBound { target, bound } + | WherePredicate::ForLifetime { lifetimes: _, target, bound }) = pred + else { + continue; + }; + let (TypeBound::Path(bounded_trait, TraitBoundModifier::None) + | TypeBound::ForLifetime(_, bounded_trait)) = *bound + else { + continue; + }; + let target = &signature.store[*target]; + match target { + TypeRef::TypeParam(param) + if param.local_id() == GenericParams::SELF_PARAM_ID_IN_SELF => {} + TypeRef::Path(path) if path.is_self_type() => {} + _ => continue, + } + let Some(TypeNs::TraitId(bounded_trait)) = + resolver.resolve_path_in_type_ns_fully(db, &signature.store[bounded_trait]) + else { + continue; + }; + let SupertraitsInfo { + all_supertraits: bounded_trait_all_supertraits, + direct_supertraits: _, + defined_assoc_types: bounded_traits_defined_assoc_types, + } = SupertraitsInfo::query(db, bounded_trait); + all_supertraits.extend(bounded_trait_all_supertraits); + direct_supertraits.insert(bounded_trait); + defined_assoc_types.extend(bounded_traits_defined_assoc_types.iter().cloned()); + } + + SupertraitsInfo { + all_supertraits: Box::from_iter(all_supertraits), + direct_supertraits: Box::from_iter(direct_supertraits), + defined_assoc_types: Box::from_iter(defined_assoc_types), + } + } + + fn supertraits_info_cycle( + _db: &dyn HirDatabase, + _: salsa::Id, + _trait_: TraitId, + ) -> SupertraitsInfo { + SupertraitsInfo::default() + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum AssocTypeShorthandResolution { + Resolved(StoredEarlyBinder<(TypeAliasId, StoredGenericArgs)>), + Ambiguous { + /// If one resolution belongs to a sub-trait and one to a supertrait, this contains + /// the sub-trait's resolution. This can be `None` if there is no trait inheritance + /// relationship between the resolutions. + sub_trait_resolution: Option<StoredEarlyBinder<(TypeAliasId, StoredGenericArgs)>>, + }, + NotFound, + Cycle, +} + /// Predicates for `param_id` of the form `P: SomeTrait`. If /// `assoc_name` is provided, only return predicates referencing traits /// that have an associated type of that name. @@ -1638,15 +1766,14 @@ pub(crate) fn field_types_with_diagnostics_query<'db>( /// following bounds are disallowed: `T: Foo<U::Item>, U: Foo<T::Item>`, but /// these are fine: `T: Foo<U::Item>, U: Foo<()>`. #[tracing::instrument(skip(db), ret)] -#[salsa::tracked(returns(ref), cycle_result = generic_predicates_for_param_cycle_result)] -pub(crate) fn generic_predicates_for_param<'db>( - db: &'db dyn HirDatabase, +#[salsa::tracked(returns(ref), cycle_result = resolve_type_param_assoc_type_shorthand_cycle_result)] +fn resolve_type_param_assoc_type_shorthand( + db: &dyn HirDatabase, def: GenericDefId, - param_id: TypeOrConstParamId, - assoc_name: Option<Name>, -) -> StoredEarlyBinder<StoredClauses> { + param: TypeParamId, + assoc_name: Name, +) -> AssocTypeShorthandResolution { let generics = generics(db, def); - let interner = DbInterner::new_no_crate(db); let resolver = def.resolver(db); let mut ctx = TyLoweringContext::new( db, @@ -1655,128 +1782,138 @@ pub(crate) fn generic_predicates_for_param<'db>( def, LifetimeElisionKind::AnonymousReportError, ); + let interner = ctx.interner; + let param_ty = Ty::new_param( + interner, + param, + generics.type_or_const_param_idx(param.into()).unwrap() as u32, + ); - // we have to filter out all other predicates *first*, before attempting to lower them - let has_relevant_bound = |pred: &_, ctx: &mut TyLoweringContext<'_, '_>| match pred { - WherePredicate::ForLifetime { target, bound, .. } - | WherePredicate::TypeBound { target, bound, .. } => { - let invalid_target = { ctx.lower_ty_only_param(*target) != Some(param_id) }; - if invalid_target { - // FIXME(sized-hierarchy): Revisit and adjust this properly once we have implemented - // sized-hierarchy correctly. - // If this is filtered out without lowering, `?Sized` or `PointeeSized` is not gathered into - // `ctx.unsized_types` - let lower = || -> bool { - match bound { - TypeBound::Path(_, TraitBoundModifier::Maybe) => true, - TypeBound::Path(path, _) | TypeBound::ForLifetime(_, path) => { - let TypeRef::Path(path) = &ctx.store[path.type_ref()] else { - return false; - }; - let Some(pointee_sized) = ctx.lang_items.PointeeSized else { - return false; - }; - // Lower the path directly with `Resolver` instead of PathLoweringContext` - // to prevent diagnostics duplications. - ctx.resolver.resolve_path_in_type_ns_fully(ctx.db, path).is_some_and( - |it| matches!(it, TypeNs::TraitId(tr) if tr == pointee_sized), - ) - } - _ => false, - } - }(); - if lower { - ctx.lower_where_predicate(pred, true, &generics, PredicateFilter::All) - .for_each(drop); - } - return false; - } - - match bound { - &TypeBound::ForLifetime(_, path) | &TypeBound::Path(path, _) => { - // Only lower the bound if the trait could possibly define the associated - // type we're looking for. - let path = &ctx.store[path]; - - let Some(assoc_name) = &assoc_name else { return true }; - let Some(TypeNs::TraitId(tr)) = - resolver.resolve_path_in_type_ns_fully(db, path) - else { - return false; - }; - - trait_or_supertrait_has_assoc_type(db, tr, assoc_name) - } - TypeBound::Use(_) | TypeBound::Lifetime(_) | TypeBound::Error => false, - } + let mut this_trait_resolution = None; + if let GenericDefId::TraitId(containing_trait) = param.parent() + && param.local_id() == GenericParams::SELF_PARAM_ID_IN_SELF + { + // Add the trait's own associated types. + if let Some(assoc_type) = + containing_trait.trait_items(db).associated_type_by_name(&assoc_name) + { + let args = GenericArgs::identity_for_item(interner, containing_trait.into()); + this_trait_resolution = Some(StoredEarlyBinder::bind((assoc_type, args.store()))); } - WherePredicate::Lifetime { .. } => false, - }; - let mut predicates = Vec::new(); + } + + let mut supertraits_resolution = None; for maybe_parent_generics in std::iter::successors(Some(&generics), |generics| generics.parent_generics()) { ctx.store = maybe_parent_generics.store(); for pred in maybe_parent_generics.where_predicates() { - if has_relevant_bound(pred, &mut ctx) { - predicates.extend( - ctx.lower_where_predicate( - pred, - true, - maybe_parent_generics, - PredicateFilter::All, - ) - .map(|(pred, _)| pred), - ); + let (WherePredicate::TypeBound { target, bound } + | WherePredicate::ForLifetime { lifetimes: _, target, bound }) = pred + else { + continue; + }; + let (TypeBound::Path(bounded_trait_path, TraitBoundModifier::None) + | TypeBound::ForLifetime(_, bounded_trait_path)) = *bound + else { + continue; + }; + let Some(target) = ctx.lower_ty_only_param(*target) else { continue }; + if target != param.into() { + continue; + } + let Some(TypeNs::TraitId(bounded_trait)) = + resolver.resolve_path_in_type_ns_fully(db, &ctx.store[bounded_trait_path]) + else { + continue; + }; + if !SupertraitsInfo::query(db, bounded_trait) + .defined_assoc_types + .iter() + .any(|(name, _)| *name == assoc_name) + { + continue; + } + + let Some((bounded_trait_ref, _)) = + ctx.lower_trait_ref_from_path(bounded_trait_path, param_ty) + else { + continue; + }; + // Now, search from the start on the *bounded* trait like if we wrote `Self::Assoc`. Eventually, we'll get + // the correct trait ref (or a cycle). + let lookup_on_bounded_trait = resolve_type_param_assoc_type_shorthand( + db, + bounded_trait.into(), + TypeParamId::trait_self(bounded_trait), + assoc_name.clone(), + ); + let assoc_type_and_args = match &lookup_on_bounded_trait { + AssocTypeShorthandResolution::Resolved(trait_ref) => trait_ref, + AssocTypeShorthandResolution::Ambiguous { + sub_trait_resolution: Some(trait_ref), + } => trait_ref, + AssocTypeShorthandResolution::Ambiguous { sub_trait_resolution: None } => { + return AssocTypeShorthandResolution::Ambiguous { + sub_trait_resolution: this_trait_resolution, + }; + } + AssocTypeShorthandResolution::NotFound => { + never!("we checked that the trait defines this assoc type"); + continue; + } + AssocTypeShorthandResolution::Cycle => return AssocTypeShorthandResolution::Cycle, + }; + let (assoc_type, args) = assoc_type_and_args + .get_with(|(assoc_type, args)| (*assoc_type, args.as_ref())) + .skip_binder(); + let args = EarlyBinder::bind(args).instantiate(interner, bounded_trait_ref.args); + let current_result = StoredEarlyBinder::bind((assoc_type, args.store())); + if let Some(this_trait_resolution) = this_trait_resolution { + return AssocTypeShorthandResolution::Ambiguous { + sub_trait_resolution: Some(this_trait_resolution), + }; + } else if let Some(prev_resolution) = &supertraits_resolution { + if let AssocTypeShorthandResolution::Ambiguous { + sub_trait_resolution: Some(prev_resolution), + } + | AssocTypeShorthandResolution::Resolved(prev_resolution) = prev_resolution + && *prev_resolution == current_result + { + continue; + } else { + return AssocTypeShorthandResolution::Ambiguous { sub_trait_resolution: None }; + } + } else { + supertraits_resolution = Some(match lookup_on_bounded_trait { + AssocTypeShorthandResolution::Resolved(_) => { + AssocTypeShorthandResolution::Resolved(current_result) + } + AssocTypeShorthandResolution::Ambiguous { .. } => { + AssocTypeShorthandResolution::Ambiguous { + sub_trait_resolution: Some(current_result), + } + } + AssocTypeShorthandResolution::NotFound + | AssocTypeShorthandResolution::Cycle => unreachable!(), + }); } } } - let args = GenericArgs::identity_for_item(interner, def.into()); - if !args.is_empty() { - let explicitly_unsized_tys = ctx.unsized_types; - if let Some(implicitly_sized_predicates) = implicitly_sized_clauses( - db, - ctx.lang_items, - param_id.parent, - &explicitly_unsized_tys, - &args, - ) { - predicates.extend(implicitly_sized_predicates); - }; - } - StoredEarlyBinder::bind(Clauses::new_from_slice(&predicates).store()) + supertraits_resolution + .or_else(|| this_trait_resolution.map(AssocTypeShorthandResolution::Resolved)) + .unwrap_or(AssocTypeShorthandResolution::NotFound) } -pub(crate) fn generic_predicates_for_param_cycle_result( - db: &dyn HirDatabase, +fn resolve_type_param_assoc_type_shorthand_cycle_result( + _db: &dyn HirDatabase, _: salsa::Id, _def: GenericDefId, - _param_id: TypeOrConstParamId, - _assoc_name: Option<Name>, -) -> StoredEarlyBinder<StoredClauses> { - StoredEarlyBinder::bind(Clauses::empty(DbInterner::new_no_crate(db)).store()) -} - -/// Check if this trait or any of its supertraits define an associated -/// type with the given name. -fn trait_or_supertrait_has_assoc_type( - db: &dyn HirDatabase, - tr: TraitId, - assoc_name: &Name, -) -> bool { - for trait_id in all_super_traits(db, tr) { - if trait_id - .trait_items(db) - .items - .iter() - .any(|(name, item)| matches!(item, AssocItemId::TypeAliasId(_)) && name == assoc_name) - { - return true; - } - } - - false + _param: TypeParamId, + _assoc_name: Name, +) -> AssocTypeShorthandResolution { + AssocTypeShorthandResolution::Cycle } #[inline] @@ -1822,7 +1959,7 @@ fn type_alias_bounds_with_diagnostics<'db>( db: &'db dyn HirDatabase, type_alias: TypeAliasId, ) -> (TypeAliasBounds<StoredEarlyBinder<StoredClauses>>, Diagnostics) { - let type_alias_data = db.type_alias_signature(type_alias); + let type_alias_data = TypeAliasSignature::of(db, type_alias); let resolver = hir_def::resolver::HasResolver::resolver(type_alias, db); let mut ctx = TyLoweringContext::new( db, @@ -1897,15 +2034,29 @@ impl<'db> GenericPredicates { /// Resolve the where clause(s) of an item with generics. /// /// Diagnostics are computed only for this item's predicates, not for parents. - #[salsa::tracked(returns(ref))] + #[salsa::tracked(returns(ref), cycle_result=generic_predicates_cycle_result)] pub fn query_with_diagnostics( db: &'db dyn HirDatabase, def: GenericDefId, ) -> (GenericPredicates, Diagnostics) { - generic_predicates_filtered_by(db, def, PredicateFilter::All, |_| true) + generic_predicates(db, def) } } +/// A cycle can occur from malformed code. +fn generic_predicates_cycle_result( + _db: &dyn HirDatabase, + _: salsa::Id, + _def: GenericDefId, +) -> (GenericPredicates, Diagnostics) { + ( + GenericPredicates::from_explicit_own_predicates(StoredEarlyBinder::bind( + Clauses::default().store(), + )), + None, + ) +} + impl GenericPredicates { #[inline] pub(crate) fn from_explicit_own_predicates( @@ -1987,16 +2138,6 @@ impl GenericPredicates { } } -pub(crate) fn trait_environment_for_body_query( - db: &dyn HirDatabase, - def: DefWithBodyId, -) -> ParamEnv<'_> { - let Some(def) = def.as_generic_def_id(db) else { - return ParamEnv::empty(); - }; - db.trait_environment(def) -} - pub(crate) fn param_env_from_predicates<'db>( interner: DbInterner<'db>, predicates: &'db GenericPredicates, @@ -2011,7 +2152,12 @@ pub(crate) fn param_env_from_predicates<'db>( ParamEnv { clauses } } -pub(crate) fn trait_environment<'db>(db: &'db dyn HirDatabase, def: GenericDefId) -> ParamEnv<'db> { +pub(crate) fn trait_environment<'db>( + db: &'db dyn HirDatabase, + def: ExpressionStoreOwnerId, +) -> ParamEnv<'db> { + let def = def.generic_def(db); + return ParamEnv { clauses: trait_environment_query(db, def).as_ref() }; #[salsa::tracked(returns(ref))] @@ -2026,24 +2172,10 @@ pub(crate) fn trait_environment<'db>(db: &'db dyn HirDatabase, def: GenericDefId } } -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub(crate) enum PredicateFilter { - SelfTrait, - All, -} - /// Resolve the where clause(s) of an item with generics, /// with a given filter -#[tracing::instrument(skip(db, filter), ret)] -pub(crate) fn generic_predicates_filtered_by<F>( - db: &dyn HirDatabase, - def: GenericDefId, - predicate_filter: PredicateFilter, - filter: F, -) -> (GenericPredicates, Diagnostics) -where - F: Fn(GenericDefId) -> bool, -{ +#[tracing::instrument(skip(db), ret)] +fn generic_predicates(db: &dyn HirDatabase, def: GenericDefId) -> (GenericPredicates, Diagnostics) { let generics = generics(db, def); let resolver = def.resolver(db); let interner = DbInterner::new_no_crate(db); @@ -2065,9 +2197,9 @@ where let all_generics = std::iter::successors(Some(&generics), |generics| generics.parent_generics()) .collect::<ArrayVec<_, 2>>(); - let own_implicit_trait_predicate = implicit_trait_predicate(interner, def, predicate_filter); + let own_implicit_trait_predicate = implicit_trait_predicate(interner, def); let parent_implicit_trait_predicate = if all_generics.len() > 1 { - implicit_trait_predicate(interner, all_generics.last().unwrap().def(), predicate_filter) + implicit_trait_predicate(interner, all_generics.last().unwrap().def()) } else { None }; @@ -2075,97 +2207,85 @@ where // Collect only diagnostics from the child, not including parents. ctx.diagnostics.clear(); - if filter(maybe_parent_generics.def()) { - ctx.store = maybe_parent_generics.store(); - for pred in maybe_parent_generics.where_predicates() { - tracing::debug!(?pred); - for (pred, source) in - ctx.lower_where_predicate(pred, false, maybe_parent_generics, predicate_filter) - { - match source { - GenericPredicateSource::SelfOnly => { - if maybe_parent_generics.def() == def { - own_predicates.push(pred); - } else { - parent_predicates.push(pred); - } + ctx.store = maybe_parent_generics.store(); + for pred in maybe_parent_generics.where_predicates() { + tracing::debug!(?pred); + for (pred, source) in ctx.lower_where_predicate(pred, false) { + match source { + GenericPredicateSource::SelfOnly => { + if maybe_parent_generics.def() == def { + own_predicates.push(pred); + } else { + parent_predicates.push(pred); } - GenericPredicateSource::AssocTyBound => { - if maybe_parent_generics.def() == def { - own_assoc_ty_bounds.push(pred); - } else { - parent_assoc_ty_bounds.push(pred); - } + } + GenericPredicateSource::AssocTyBound => { + if maybe_parent_generics.def() == def { + own_assoc_ty_bounds.push(pred); + } else { + parent_assoc_ty_bounds.push(pred); } } } } + } - if maybe_parent_generics.def() == def { - push_const_arg_has_type_predicates(db, &mut own_predicates, maybe_parent_generics); - } else { - push_const_arg_has_type_predicates( - db, - &mut parent_predicates, - maybe_parent_generics, - ); - } + if maybe_parent_generics.def() == def { + push_const_arg_has_type_predicates(db, &mut own_predicates, maybe_parent_generics); + } else { + push_const_arg_has_type_predicates(db, &mut parent_predicates, maybe_parent_generics); + } - if let Some(sized_trait) = sized_trait { - let mut add_sized_clause = |param_idx, param_id, param_data| { - let ( - GenericParamId::TypeParamId(param_id), - GenericParamDataRef::TypeParamData(param_data), - ) = (param_id, param_data) - else { - return; - }; + if let Some(sized_trait) = sized_trait { + let mut add_sized_clause = |param_idx, param_id, param_data| { + let ( + GenericParamId::TypeParamId(param_id), + GenericParamDataRef::TypeParamData(param_data), + ) = (param_id, param_data) + else { + return; + }; - if param_data.provenance == TypeParamProvenance::TraitSelf { - return; - } + if param_data.provenance == TypeParamProvenance::TraitSelf { + return; + } - let param_ty = Ty::new_param(interner, param_id, param_idx); - if ctx.unsized_types.contains(¶m_ty) { - return; - } - let trait_ref = TraitRef::new_from_args( - interner, - sized_trait.into(), - GenericArgs::new_from_slice(&[param_ty.into()]), - ); - let clause = Clause(Predicate::new( - interner, - Binder::dummy(rustc_type_ir::PredicateKind::Clause( - rustc_type_ir::ClauseKind::Trait(TraitPredicate { - trait_ref, - polarity: rustc_type_ir::PredicatePolarity::Positive, - }), - )), - )); - if maybe_parent_generics.def() == def { - own_predicates.push(clause); - } else { - parent_predicates.push(clause); - } - }; - let parent_params_len = maybe_parent_generics.len_parent(); - maybe_parent_generics.iter_self().enumerate().for_each( - |(param_idx, (param_id, param_data))| { - add_sized_clause( - (param_idx + parent_params_len) as u32, - param_id, - param_data, - ); - }, + let param_ty = Ty::new_param(interner, param_id, param_idx); + if ctx.unsized_types.contains(¶m_ty) { + return; + } + let trait_ref = TraitRef::new_from_args( + interner, + sized_trait.into(), + GenericArgs::new_from_slice(&[param_ty.into()]), ); - } - - // We do not clear `ctx.unsized_types`, as the `?Sized` clause of a child (e.g. an associated type) can - // be declared on the parent (e.g. the trait). It is nevertheless fine to register the implicit `Sized` - // predicates before lowering the child, as a child cannot define a `?Sized` predicate for its parent. - // But we do have to lower the parent first. + let clause = Clause(Predicate::new( + interner, + Binder::dummy(rustc_type_ir::PredicateKind::Clause( + rustc_type_ir::ClauseKind::Trait(TraitPredicate { + trait_ref, + polarity: rustc_type_ir::PredicatePolarity::Positive, + }), + )), + )); + if maybe_parent_generics.def() == def { + own_predicates.push(clause); + } else { + parent_predicates.push(clause); + } + }; + let parent_params_len = maybe_parent_generics.len_parent(); + maybe_parent_generics.iter_self().enumerate().for_each( + |(param_idx, (param_id, param_data))| { + add_sized_clause((param_idx + parent_params_len) as u32, param_id, param_data); + }, + ); } + + // We do not clear `ctx.unsized_types`, as the `?Sized` clause of a child (e.g. an associated type) can + // be declared on the parent (e.g. the trait). It is nevertheless fine to register the implicit `Sized` + // predicates before lowering the child, as a child cannot define a `?Sized` predicate for its parent. + // But we do have to lower the parent first. } let diagnostics = create_diagnostics(ctx.diagnostics); @@ -2213,7 +2333,6 @@ where fn implicit_trait_predicate<'db>( interner: DbInterner<'db>, def: GenericDefId, - predicate_filter: PredicateFilter, ) -> Option<Clause<'db>> { // For traits, add `Self: Trait` predicate. This is // not part of the predicates that a user writes, but it @@ -2227,9 +2346,7 @@ where // prove that the trait applies to the types that were // used, and adding the predicate into this list ensures // that this is done. - if let GenericDefId::TraitId(def_id) = def - && predicate_filter == PredicateFilter::All - { + if let GenericDefId::TraitId(def_id) = def { Some(TraitRef::identity(interner, def_id.into()).upcast(interner)) } else { None @@ -2240,7 +2357,7 @@ where fn push_const_arg_has_type_predicates<'db>( db: &'db dyn HirDatabase, predicates: &mut Vec<Clause<'db>>, - generics: &Generics, + generics: &Generics<'db>, ) { let interner = DbInterner::new_no_crate(db); let const_params_offset = generics.len_parent() + generics.len_lifetimes_self(); @@ -2266,49 +2383,6 @@ fn push_const_arg_has_type_predicates<'db>( } } -/// Generate implicit `: Sized` predicates for all generics that has no `?Sized` bound. -/// Exception is Self of a trait def. -fn implicitly_sized_clauses<'a, 'subst, 'db>( - db: &'db dyn HirDatabase, - lang_items: &LangItems, - def: GenericDefId, - explicitly_unsized_tys: &'a FxHashSet<Ty<'db>>, - args: &'subst GenericArgs<'db>, -) -> Option<impl Iterator<Item = Clause<'db>> + Captures<'a> + Captures<'subst>> { - let interner = DbInterner::new_no_crate(db); - let sized_trait = lang_items.Sized?; - - let trait_self_idx = trait_self_param_idx(db, def); - - Some( - args.iter() - .enumerate() - .filter_map( - move |(idx, generic_arg)| { - if Some(idx) == trait_self_idx { None } else { Some(generic_arg) } - }, - ) - .filter_map(|generic_arg| generic_arg.as_type()) - .filter(move |self_ty| !explicitly_unsized_tys.contains(self_ty)) - .map(move |self_ty| { - let trait_ref = TraitRef::new_from_args( - interner, - sized_trait.into(), - GenericArgs::new_from_slice(&[self_ty.into()]), - ); - Clause(Predicate::new( - interner, - Binder::dummy(rustc_type_ir::PredicateKind::Clause( - rustc_type_ir::ClauseKind::Trait(TraitPredicate { - trait_ref, - polarity: rustc_type_ir::PredicatePolarity::Positive, - }), - )), - )) - }), - ) -} - #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct GenericDefaults(Option<Arc<[Option<StoredEarlyBinder<StoredGenericArg>>]>>); @@ -2336,10 +2410,11 @@ pub(crate) fn generic_defaults_with_diagnostics_query( } let resolver = def.resolver(db); + let store_for_self = generic_params.store(); let mut ctx = TyLoweringContext::new( db, &resolver, - generic_params.store(), + store_for_self, def, LifetimeElisionKind::AnonymousReportError, ) @@ -2357,6 +2432,7 @@ pub(crate) fn generic_defaults_with_diagnostics_query( }) .collect::<Vec<_>>(); ctx.diagnostics.clear(); // Don't include diagnostics from the parent. + ctx.store = store_for_self; defaults.extend(generic_params.iter_self().map(|(_id, p)| { let (result, has_default) = handle_generic_param(&mut ctx, idx, p); has_any_default |= has_default; @@ -2427,7 +2503,7 @@ pub(crate) fn callable_item_signature<'db>( } fn fn_sig_for_fn(db: &dyn HirDatabase, def: FunctionId) -> StoredEarlyBinder<StoredPolyFnSig> { - let data = db.function_signature(def); + let data = FunctionSignature::of(db, def); let resolver = def.resolver(db); let interner = DbInterner::new_no_crate(db); let mut ctx_params = TyLoweringContext::new( @@ -2435,7 +2511,7 @@ fn fn_sig_for_fn(db: &dyn HirDatabase, def: FunctionId) -> StoredEarlyBinder<Sto &resolver, &data.store, def.into(), - LifetimeElisionKind::for_fn_params(&data), + LifetimeElisionKind::for_fn_params(data), ); let params = data.params.iter().map(|&tr| ctx_params.lower_ty(tr)); @@ -2482,7 +2558,7 @@ fn fn_sig_for_struct_constructor( let inputs_and_output = Tys::new_from_iter(DbInterner::new_no_crate(db), params.chain(Some(ret.as_ref()))); StoredEarlyBinder::bind(StoredPolyFnSig::new(Binder::dummy(FnSig { - abi: FnAbi::RustCall, + abi: FnAbi::Rust, c_variadic: false, safety: Safety::Safe, inputs_and_output, @@ -2501,7 +2577,7 @@ fn fn_sig_for_enum_variant_constructor( let inputs_and_output = Tys::new_from_iter(DbInterner::new_no_crate(db), params.chain(Some(ret.as_ref()))); StoredEarlyBinder::bind(StoredPolyFnSig::new(Binder::dummy(FnSig { - abi: FnAbi::RustCall, + abi: FnAbi::Rust, c_variadic: false, safety: Safety::Safe, inputs_and_output, @@ -2513,7 +2589,7 @@ pub(crate) fn associated_ty_item_bounds<'db>( db: &'db dyn HirDatabase, type_alias: TypeAliasId, ) -> EarlyBinder<'db, BoundExistentialPredicates<'db>> { - let type_alias_data = db.type_alias_signature(type_alias); + let type_alias_data = TypeAliasSignature::of(db, type_alias); let resolver = hir_def::resolver::HasResolver::resolver(type_alias, db); let interner = DbInterner::new_no_crate(db); let mut ctx = TyLoweringContext::new( @@ -2535,7 +2611,7 @@ pub(crate) fn associated_ty_item_bounds<'db>( .map_bound(|c| match c { rustc_type_ir::ClauseKind::Trait(t) => { let id = t.def_id(); - let is_auto = db.trait_signature(id.0).flags.contains(TraitFlags::AUTO); + let is_auto = TraitSignature::of(db, id.0).flags.contains(TraitFlags::AUTO); if is_auto { Some(ExistentialPredicate::AutoTrait(t.def_id())) } else { @@ -2583,143 +2659,25 @@ pub(crate) fn associated_ty_item_bounds<'db>( EarlyBinder::bind(BoundExistentialPredicates::new_from_slice(&bounds)) } -pub(crate) fn associated_type_by_name_including_super_traits<'db>( +pub(crate) fn associated_type_by_name_including_super_traits_allow_ambiguity<'db>( db: &'db dyn HirDatabase, trait_ref: TraitRef<'db>, - name: &Name, -) -> Option<(TraitRef<'db>, TypeAliasId)> { - let module = trait_ref.def_id.0.module(db); - let interner = DbInterner::new_with(db, module.krate(db)); - rustc_type_ir::elaborate::supertraits(interner, Binder::dummy(trait_ref)).find_map(|t| { - let trait_id = t.as_ref().skip_binder().def_id.0; - let assoc_type = trait_id.trait_items(db).associated_type_by_name(name)?; - Some((t.skip_binder(), assoc_type)) - }) -} - -pub fn associated_type_shorthand_candidates( - db: &dyn HirDatabase, - def: GenericDefId, - res: TypeNs, - mut cb: impl FnMut(&Name, TypeAliasId) -> bool, -) -> Option<TypeAliasId> { - let interner = DbInterner::new_no_crate(db); - named_associated_type_shorthand_candidates(interner, def, res, None, |name, _, id| { - cb(name, id).then_some(id) - }) -} - -#[tracing::instrument(skip(interner, check_alias))] -fn named_associated_type_shorthand_candidates<'db, R>( - interner: DbInterner<'db>, - // If the type parameter is defined in an impl and we're in a method, there - // might be additional where clauses to consider - def: GenericDefId, - res: TypeNs, - assoc_name: Option<Name>, - mut check_alias: impl FnMut(&Name, TraitRef<'db>, TypeAliasId) -> Option<R>, -) -> Option<R> { - let db = interner.db; - let mut search = |t: TraitRef<'db>| -> Option<R> { - let mut checked_traits = FxHashSet::default(); - let mut check_trait = |trait_ref: TraitRef<'db>| { - let trait_id = trait_ref.def_id.0; - let name = &db.trait_signature(trait_id).name; - tracing::debug!(?trait_id, ?name); - if !checked_traits.insert(trait_id) { - return None; - } - let data = trait_id.trait_items(db); - - tracing::debug!(?data.items); - for (name, assoc_id) in &data.items { - if let &AssocItemId::TypeAliasId(alias) = assoc_id - && let Some(ty) = check_alias(name, trait_ref, alias) - { - return Some(ty); - } - } - None - }; - let mut stack: SmallVec<[_; 4]> = smallvec![t]; - while let Some(trait_ref) = stack.pop() { - if let Some(alias) = check_trait(trait_ref) { - return Some(alias); - } - let predicates = generic_predicates_filtered_by( - db, - GenericDefId::TraitId(trait_ref.def_id.0), - PredicateFilter::SelfTrait, - // We are likely in the midst of lowering generic predicates of `def`. - // So, if we allow `pred == def` we might fall into an infinite recursion. - // Actually, we have already checked for the case `pred == def` above as we started - // with a stack including `trait_id` - |pred| pred != def && pred == GenericDefId::TraitId(trait_ref.def_id.0), - ) - .0 - .predicates; - for pred in predicates.get().instantiate_identity() { - tracing::debug!(?pred); - let sup_trait_ref = match pred.kind().skip_binder() { - rustc_type_ir::ClauseKind::Trait(pred) => pred.trait_ref, - _ => continue, - }; - let sup_trait_ref = - EarlyBinder::bind(sup_trait_ref).instantiate(interner, trait_ref.args); - stack.push(sup_trait_ref); - } - tracing::debug!(?stack); - } - - None + name: Name, +) -> Option<(TypeAliasId, GenericArgs<'db>)> { + let (AssocTypeShorthandResolution::Resolved(assoc_type) + | AssocTypeShorthandResolution::Ambiguous { sub_trait_resolution: Some(assoc_type) }) = + resolve_type_param_assoc_type_shorthand( + db, + trait_ref.def_id.0.into(), + TypeParamId::trait_self(trait_ref.def_id.0), + name.clone(), + ) + else { + return None; }; - - match res { - TypeNs::SelfType(impl_id) => { - let trait_ref = db.impl_trait(impl_id)?; - - // FIXME(next-solver): same method in `lower` checks for impl or not - // Is that needed here? - - // we're _in_ the impl -- the binders get added back later. Correct, - // but it would be nice to make this more explicit - search(trait_ref.skip_binder()) - } - TypeNs::GenericParam(param_id) => { - // Handle `Self::Type` referring to own associated type in trait definitions - // This *must* be done first to avoid cycles with - // `generic_predicates_for_param`, but not sure that it's sufficient, - if let GenericDefId::TraitId(trait_id) = param_id.parent() { - let trait_name = &db.trait_signature(trait_id).name; - tracing::debug!(?trait_name); - let trait_generics = generics(db, trait_id.into()); - tracing::debug!(?trait_generics); - if trait_generics[param_id.local_id()].is_trait_self() { - let args = GenericArgs::identity_for_item(interner, trait_id.into()); - let trait_ref = TraitRef::new_from_args(interner, trait_id.into(), args); - tracing::debug!(?args, ?trait_ref); - return search(trait_ref); - } - } - - let predicates = - generic_predicates_for_param(db, def, param_id.into(), assoc_name.clone()); - predicates - .get() - .iter_identity() - .find_map(|pred| match pred.kind().skip_binder() { - rustc_type_ir::ClauseKind::Trait(trait_predicate) => Some(trait_predicate), - _ => None, - }) - .and_then(|trait_predicate| { - let trait_ref = trait_predicate.trait_ref; - assert!( - !trait_ref.has_escaping_bound_vars(), - "FIXME unexpected higher-ranked trait bound" - ); - search(trait_ref) - }) - } - _ => None, - } + let (assoc_type, trait_args) = assoc_type + .get_with(|(assoc_type, trait_args)| (*assoc_type, trait_args.as_ref())) + .skip_binder(); + let interner = DbInterner::new_no_crate(db); + Some((assoc_type, EarlyBinder::bind(trait_args).instantiate(interner, trait_ref.args))) } diff --git a/crates/hir-ty/src/lower/path.rs b/crates/hir-ty/src/lower/path.rs index f3d0de1227..889f0792d3 100644 --- a/crates/hir-ty/src/lower/path.rs +++ b/crates/hir-ty/src/lower/path.rs @@ -2,7 +2,7 @@ use either::Either; use hir_def::{ - GenericDefId, GenericParamId, Lookup, TraitId, TypeAliasId, + GenericDefId, GenericParamId, Lookup, TraitId, TypeParamId, expr_store::{ ExpressionStore, HygieneId, path::{ @@ -14,10 +14,9 @@ use hir_def::{ GenericParamDataRef, TypeOrConstParamData, TypeParamData, TypeParamProvenance, }, resolver::{ResolveValueResult, TypeNs, ValueNs}, - signatures::TraitFlags, + signatures::{TraitFlags, TraitSignature}, type_ref::{TypeRef, TypeRefId}, }; -use hir_expand::name::Name; use rustc_type_ir::{ AliasTerm, AliasTy, AliasTyKind, inherent::{GenericArgs as _, Region as _, Ty as _}, @@ -32,18 +31,18 @@ use crate::{ db::HirDatabase, generics::{Generics, generics}, lower::{ - GenericPredicateSource, LifetimeElisionKind, PathDiagnosticCallbackData, - named_associated_type_shorthand_candidates, + AssocTypeShorthandResolution, GenericPredicateSource, LifetimeElisionKind, + PathDiagnosticCallbackData, }, next_solver::{ - Binder, Clause, Const, DbInterner, ErrorGuaranteed, GenericArg, GenericArgs, Predicate, - ProjectionPredicate, Region, TraitRef, Ty, + Binder, Clause, Const, DbInterner, EarlyBinder, ErrorGuaranteed, GenericArg, GenericArgs, + Predicate, ProjectionPredicate, Region, TraitRef, Ty, }, }; use super::{ - ImplTraitLoweringMode, TyLoweringContext, associated_type_by_name_including_super_traits, - const_param_ty_query, ty_query, + ImplTraitLoweringMode, TyLoweringContext, + associated_type_by_name_including_super_traits_allow_ambiguity, const_param_ty_query, ty_query, }; type CallbackData<'a> = @@ -184,7 +183,7 @@ impl<'a, 'b, 'db> PathLoweringContext<'a, 'b, 'db> { let trait_ref = self.lower_trait_ref_from_resolved_path( trait_, Ty::new_error(self.ctx.interner, ErrorGuaranteed), - false, + infer_args, ); tracing::debug!(?trait_ref); self.skip_resolved_segment(); @@ -202,7 +201,7 @@ impl<'a, 'b, 'db> PathLoweringContext<'a, 'b, 'db> { // this point (`trait_ref.substitution`). let substitution = self.substs_from_path_segment( associated_ty.into(), - false, + infer_args, None, true, ); @@ -396,12 +395,10 @@ impl<'a, 'b, 'db> PathLoweringContext<'a, 'b, 'db> { } let (mod_segments, enum_segment, resolved_segment_idx) = match res { - ResolveValueResult::Partial(_, unresolved_segment, _) => { + ResolveValueResult::Partial(_, unresolved_segment) => { (segments.take(unresolved_segment - 1), None, unresolved_segment - 1) } - ResolveValueResult::ValueNs(ValueNs::EnumVariantId(_), _) - if prefix_info.enum_variant => - { + ResolveValueResult::ValueNs(ValueNs::EnumVariantId(_)) if prefix_info.enum_variant => { (segments.strip_last_two(), segments.len().checked_sub(2), segments.len() - 1) } ResolveValueResult::ValueNs(..) => (segments.strip_last(), None, segments.len() - 1), @@ -431,7 +428,7 @@ impl<'a, 'b, 'db> PathLoweringContext<'a, 'b, 'db> { } match &res { - ResolveValueResult::ValueNs(resolution, _) => { + ResolveValueResult::ValueNs(resolution) => { let resolved_segment_idx = self.current_segment_u32(); let resolved_segment = self.current_or_prev_segment; @@ -469,7 +466,7 @@ impl<'a, 'b, 'db> PathLoweringContext<'a, 'b, 'db> { | ValueNs::ConstId(_) => {} } } - ResolveValueResult::Partial(resolution, _, _) => { + ResolveValueResult::Partial(resolution, _) => { if !self.handle_type_ns_resolution(resolution) { return None; } @@ -481,43 +478,63 @@ impl<'a, 'b, 'db> PathLoweringContext<'a, 'b, 'db> { #[tracing::instrument(skip(self), ret)] fn select_associated_type(&mut self, res: Option<TypeNs>, infer_args: bool) -> Ty<'db> { let interner = self.ctx.interner; - let Some(res) = res else { - return Ty::new_error(self.ctx.interner, ErrorGuaranteed); - }; + let db = self.ctx.db; let def = self.ctx.def; let segment = self.current_or_prev_segment; let assoc_name = segment.name; - let check_alias = |name: &Name, t: TraitRef<'db>, associated_ty: TypeAliasId| { - if name != assoc_name { - return None; + let error_ty = || Ty::new_error(self.ctx.interner, ErrorGuaranteed); + let (assoc_type, trait_args) = match res { + Some(TypeNs::GenericParam(param)) => { + let AssocTypeShorthandResolution::Resolved(assoc_type) = + super::resolve_type_param_assoc_type_shorthand( + db, + def, + param, + assoc_name.clone(), + ) + else { + return error_ty(); + }; + assoc_type + .get_with(|(assoc_type, trait_args)| (*assoc_type, trait_args.as_ref())) + .skip_binder() } - - // FIXME: `substs_from_path_segment()` pushes `TyKind::Error` for every parent - // generic params. It's inefficient to splice the `Substitution`s, so we may want - // that method to optionally take parent `Substitution` as we already know them at - // this point (`t.substitution`). - let substs = - self.substs_from_path_segment(associated_ty.into(), infer_args, None, true); - - let substs = GenericArgs::new_from_iter( - interner, - t.args.iter().chain(substs.iter().skip(t.args.len())), - ); - - Some(Ty::new_alias( - interner, - AliasTyKind::Projection, - AliasTy::new_from_args(interner, associated_ty.into(), substs), - )) + Some(TypeNs::SelfType(impl_)) => { + let Some(impl_trait) = db.impl_trait(impl_) else { + return error_ty(); + }; + let impl_trait = impl_trait.instantiate_identity(); + // Searching for `Self::Assoc` in `impl Trait for Type` is like searching for `Self::Assoc` in `Trait`. + let AssocTypeShorthandResolution::Resolved(assoc_type) = + super::resolve_type_param_assoc_type_shorthand( + db, + impl_trait.def_id.0.into(), + TypeParamId::trait_self(impl_trait.def_id.0), + assoc_name.clone(), + ) + else { + return error_ty(); + }; + let (assoc_type, trait_args) = assoc_type + .get_with(|(assoc_type, trait_args)| (*assoc_type, trait_args.as_ref())) + .skip_binder(); + (assoc_type, EarlyBinder::bind(trait_args).instantiate(interner, impl_trait.args)) + } + _ => return error_ty(), }; - named_associated_type_shorthand_candidates( + + // FIXME: `substs_from_path_segment()` pushes `TyKind::Error` for every parent + // generic params. It's inefficient to splice the `Substitution`s, so we may want + // that method to optionally take parent `Substitution` as we already know them at + // this point (`t.substitution`). + let substs = self.substs_from_path_segment(assoc_type.into(), infer_args, None, true); + + let substs = GenericArgs::new_from_iter( interner, - def, - res, - Some(assoc_name.clone()), - check_alias, - ) - .unwrap_or_else(|| Ty::new_error(interner, ErrorGuaranteed)) + trait_args.iter().chain(substs.iter().skip(trait_args.len())), + ); + + Ty::new_projection_from_args(interner, assoc_type.into(), substs) } fn lower_path_inner(&mut self, typeable: TyDefId, infer_args: bool) -> Ty<'db> { @@ -608,10 +625,7 @@ impl<'a, 'b, 'db> PathLoweringContext<'a, 'b, 'db> { GenericDefId::TraitId(trait_) => { // RTN is prohibited anyways if we got here. let is_rtn = args.parenthesized == GenericArgsParentheses::ReturnTypeNotation; - let is_fn_trait = self - .ctx - .db - .trait_signature(trait_) + let is_fn_trait = TraitSignature::of(self.ctx.db, trait_) .flags .contains(TraitFlags::RUSTC_PAREN_SUGAR); is_rtn || !is_fn_trait @@ -859,12 +873,12 @@ impl<'a, 'b, 'db> PathLoweringContext<'a, 'b, 'db> { let interner = self.ctx.interner; self.current_or_prev_segment.args_and_bindings.map(|args_and_bindings| { args_and_bindings.bindings.iter().enumerate().flat_map(move |(binding_idx, binding)| { - let found = associated_type_by_name_including_super_traits( + let found = associated_type_by_name_including_super_traits_allow_ambiguity( self.ctx.db, trait_ref, - &binding.name, + binding.name.clone(), ); - let (super_trait_ref, associated_ty) = match found { + let (associated_ty, super_trait_args) = match found { None => return SmallVec::new(), Some(t) => t, }; @@ -878,7 +892,7 @@ impl<'a, 'b, 'db> PathLoweringContext<'a, 'b, 'db> { binding.args.as_ref(), associated_ty.into(), false, // this is not relevant - Some(super_trait_ref.self_ty()), + Some(super_trait_args.type_at(0)), PathGenericsSource::AssocType { segment: this.current_segment_u32(), assoc_type: binding_idx as u32, @@ -889,7 +903,7 @@ impl<'a, 'b, 'db> PathLoweringContext<'a, 'b, 'db> { }); let args = GenericArgs::new_from_iter( interner, - super_trait_ref.args.iter().chain(args.iter().skip(super_trait_ref.args.len())), + super_trait_args.iter().chain(args.iter().skip(super_trait_args.len())), ); let projection_term = AliasTerm::new_from_args(interner, associated_ty.into(), args); @@ -1007,7 +1021,7 @@ pub(crate) trait GenericArgsLowerer<'db> { fn check_generic_args_len<'db>( args_and_bindings: Option<&HirGenericArgs>, def: GenericDefId, - def_generics: &Generics, + def_generics: &Generics<'db>, infer_args: bool, lifetime_elision: &LifetimeElisionKind<'db>, lowering_assoc_type_generics: bool, diff --git a/crates/hir-ty/src/method_resolution.rs b/crates/hir-ty/src/method_resolution.rs index ad4d79e68a..05b9ea5d74 100644 --- a/crates/hir-ty/src/method_resolution.rs +++ b/crates/hir-ty/src/method_resolution.rs @@ -17,11 +17,12 @@ use hir_def::{ ImplId, ItemContainerId, ModuleId, TraitId, attrs::AttrFlags, builtin_derive::BuiltinDeriveImplMethod, - expr_store::path::GenericArgs as HirGenericArgs, - hir::ExprId, + expr_store::{Body, path::GenericArgs as HirGenericArgs}, + hir::{ExprId, generics::GenericParams}, lang_item::LangItems, nameres::{DefMap, block_def_map, crate_def_map}, resolver::Resolver, + signatures::{ConstSignature, FunctionSignature}, }; use intern::{Symbol, sym}; use rustc_hash::{FxHashMap, FxHashSet}; @@ -366,7 +367,7 @@ pub fn lookup_impl_const<'db>( }; let trait_ref = TraitRef::new_from_args(interner, trait_id.into(), subs); - let const_signature = db.const_signature(const_id); + let const_signature = ConstSignature::of(db, const_id); let name = match const_signature.name.as_ref() { Some(name) => name, None => return (const_id, subs), @@ -396,7 +397,7 @@ pub fn is_dyn_method<'db>( let ItemContainerId::TraitId(trait_id) = func.loc(db).container else { return None; }; - let trait_params = db.generic_params(trait_id.into()).len(); + let trait_params = GenericParams::of(db, trait_id.into()).len(); let fn_params = fn_subst.len() - trait_params; let trait_ref = TraitRef::new_from_args( interner, @@ -432,14 +433,14 @@ pub(crate) fn lookup_impl_method_query<'db>( let ItemContainerId::TraitId(trait_id) = func.loc(db).container else { return (Either::Left(func), fn_subst); }; - let trait_params = db.generic_params(trait_id.into()).len(); + let trait_params = GenericParams::of(db, trait_id.into()).len(); let trait_ref = TraitRef::new_from_args( interner, trait_id.into(), GenericArgs::new_from_slice(&fn_subst[..trait_params]), ); - let name = &db.function_signature(func).name; + let name = &FunctionSignature::of(db, func).name; let Some((impl_fn, impl_subst)) = lookup_impl_assoc_item_for_trait_ref(&infcx, trait_ref, env.param_env, name).and_then( |(assoc, impl_args)| { @@ -623,7 +624,7 @@ impl InherentImpls { // To better support custom derives, collect impls in all unnamed const items. // const _: () = { ... }; for konst in module_data.scope.unnamed_consts() { - let body = db.body(konst.into()); + let body = Body::of(db, konst.into()); for (_, block_def_map) in body.blocks(db) { collect(db, block_def_map, map); } @@ -766,7 +767,7 @@ impl TraitImpls { // To better support custom derives, collect impls in all unnamed const items. // const _: () = { ... }; for konst in module_data.scope.unnamed_consts() { - let body = db.body(konst.into()); + let body = Body::of(db, konst.into()); for (_, block_def_map) in body.blocks(db) { collect(db, block_def_map, lang_items, map); } diff --git a/crates/hir-ty/src/method_resolution/confirm.rs b/crates/hir-ty/src/method_resolution/confirm.rs index 0024ca16a5..ec589085a8 100644 --- a/crates/hir-ty/src/method_resolution/confirm.rs +++ b/crates/hir-ty/src/method_resolution/confirm.rs @@ -456,7 +456,7 @@ impl<'a, 'b, 'db> ConfirmContext<'a, 'b, 'db> { substs_from_args_and_bindings( self.db(), - self.ctx.body, + self.ctx.store, generic_args, self.candidate.into(), true, diff --git a/crates/hir-ty/src/method_resolution/probe.rs b/crates/hir-ty/src/method_resolution/probe.rs index fc2bd87ee4..8c76bfbc07 100644 --- a/crates/hir-ty/src/method_resolution/probe.rs +++ b/crates/hir-ty/src/method_resolution/probe.rs @@ -5,7 +5,8 @@ use std::{cell::RefCell, convert::Infallible, ops::ControlFlow}; use hir_def::{ AssocItemId, FunctionId, GenericParamId, ImplId, ItemContainerId, TraitId, - signatures::TraitFlags, + hir::generics::GenericParams, + signatures::{FunctionSignature, TraitFlags, TraitSignature}, }; use hir_expand::name::Name; use rustc_ast_ir::Mutability; @@ -1605,7 +1606,8 @@ impl<'a, 'db, Choice: ProbeChoice<'db>> ProbeContext<'a, 'db, Choice> { // Some trait methods are excluded for arrays before 2021. // (`array.into_iter()` wants a slice iterator for compatibility.) if self_ty.is_array() && !self.ctx.edition.at_least_2021() { - let trait_signature = self.db().trait_signature(poly_trait_ref.def_id().0); + let trait_signature = + TraitSignature::of(self.db(), poly_trait_ref.def_id().0); if trait_signature .flags .contains(TraitFlags::SKIP_ARRAY_DURING_METHOD_DISPATCH) @@ -1619,7 +1621,8 @@ impl<'a, 'db, Choice: ProbeChoice<'db>> ProbeContext<'a, 'db, Choice> { if self_ty.boxed_ty().is_some_and(Ty::is_slice) && !self.ctx.edition.at_least_2024() { - let trait_signature = self.db().trait_signature(poly_trait_ref.def_id().0); + let trait_signature = + TraitSignature::of(self.db(), poly_trait_ref.def_id().0); if trait_signature .flags .contains(TraitFlags::SKIP_BOXED_SLICE_DURING_METHOD_DISPATCH) @@ -1963,7 +1966,7 @@ impl<'a, 'db, Choice: ProbeChoice<'db>> ProbeContext<'a, 'db, Choice> { // associated value (i.e., methods, constants). match item { CandidateId::FunctionId(id) if self.mode == Mode::MethodCall => { - self.db().function_signature(id).has_self_param() + FunctionSignature::of(self.db(), id).has_self_param() } _ => true, } @@ -2008,7 +2011,7 @@ impl<'a, 'db, Choice: ProbeChoice<'db>> ProbeContext<'a, 'db, Choice> { // we are given do not include type/lifetime parameters for the // method yet. So create fresh variables here for those too, // if there are any. - let generics = self.db().generic_params(method.into()); + let generics = GenericParams::of(self.db(), method.into()); let xform_fn_sig = if generics.is_empty() { fn_sig.instantiate(self.interner(), args) diff --git a/crates/hir-ty/src/mir.rs b/crates/hir-ty/src/mir.rs index 6642386011..a8865cd54e 100644 --- a/crates/hir-ty/src/mir.rs +++ b/crates/hir-ty/src/mir.rs @@ -6,7 +6,7 @@ use base_db::Crate; use either::Either; use hir_def::{ DefWithBodyId, FieldId, StaticId, TupleFieldId, UnionId, VariantId, - expr_store::Body, + expr_store::ExpressionStore, hir::{BindingAnnotation, BindingId, Expr, ExprId, Ordering, PatId}, }; use la_arena::{Arena, ArenaMap, Idx, RawIdx}; @@ -40,7 +40,10 @@ pub use borrowck::{BorrowckResult, MutabilityReason, borrowck_query}; pub use eval::{ Evaluator, MirEvalError, VTableMap, interpret_mir, pad16, render_const_using_debug_impl, }; -pub use lower::{MirLowerError, lower_to_mir, mir_body_for_closure_query, mir_body_query}; +pub use lower::{ + MirLowerError, lower_body_to_mir, lower_to_mir_with_store, mir_body_for_closure_query, + mir_body_query, +}; pub use monomorphization::{ monomorphized_mir_body_for_closure_query, monomorphized_mir_body_query, }; @@ -1207,12 +1210,12 @@ pub enum MirSpan { } impl MirSpan { - pub fn is_ref_span(&self, body: &Body) -> bool { + pub fn is_ref_span(&self, store: &ExpressionStore) -> bool { match *self { - MirSpan::ExprId(expr) => matches!(body[expr], Expr::Ref { .. }), + MirSpan::ExprId(expr) => matches!(store[expr], Expr::Ref { .. }), // FIXME: Figure out if this is correct wrt. match ergonomics. MirSpan::BindingId(binding) => { - matches!(body[binding].mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) + matches!(store[binding].mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) } MirSpan::PatId(_) | MirSpan::SelfParam | MirSpan::Unknown => false, } diff --git a/crates/hir-ty/src/mir/borrowck.rs b/crates/hir-ty/src/mir/borrowck.rs index dece61a57d..3ff2db15aa 100644 --- a/crates/hir-ty/src/mir/borrowck.rs +++ b/crates/hir-ty/src/mir/borrowck.rs @@ -5,7 +5,7 @@ use std::iter; -use hir_def::{DefWithBodyId, HasModule}; +use hir_def::{DefWithBodyId, ExpressionStoreOwnerId, HasModule}; use la_arena::ArenaMap; use rustc_hash::FxHashMap; use rustc_type_ir::inherent::GenericArgs as _; @@ -99,7 +99,7 @@ pub fn borrowck_query( let _p = tracing::info_span!("borrowck_query").entered(); let module = def.module(db); let interner = DbInterner::new_with(db, module.krate(db)); - let env = db.trait_environment_for_body(def); + let env = db.trait_environment(ExpressionStoreOwnerId::from(def)); let mut res = vec![]; // This calculates opaques defining scope which is a bit costly therefore is put outside `all_mir_bodies()`. let typing_mode = TypingMode::borrowck(interner, def.into()); @@ -121,11 +121,11 @@ fn make_fetch_closure_field<'db>( db: &'db dyn HirDatabase, ) -> impl FnOnce(InternedClosureId, GenericArgs<'db>, usize) -> Ty<'db> + use<'db> { |c: InternedClosureId, subst: GenericArgs<'db>, f: usize| { - let InternedClosure(def, _) = db.lookup_intern_closure(c); - let infer = InferenceResult::for_body(db, def); + let InternedClosure(owner, _) = db.lookup_intern_closure(c); + let interner = DbInterner::new_no_crate(db); + let infer = InferenceResult::of(db, owner); let (captures, _) = infer.closure_info(c); let parent_subst = subst.as_closure().parent_args(); - let interner = DbInterner::new_no_crate(db); captures.get(f).expect("broken closure field").ty.get().instantiate(interner, parent_subst) } } diff --git a/crates/hir-ty/src/mir/eval.rs b/crates/hir-ty/src/mir/eval.rs index 5de08313f4..505db1776f 100644 --- a/crates/hir-ty/src/mir/eval.rs +++ b/crates/hir-ty/src/mir/eval.rs @@ -5,14 +5,17 @@ use std::{borrow::Cow, cell::RefCell, fmt::Write, iter, mem, ops::Range}; use base_db::{Crate, target::TargetLoadError}; use either::Either; use hir_def::{ - AdtId, DefWithBodyId, EnumVariantId, FunctionId, GeneralConstId, HasModule, ItemContainerId, - Lookup, StaticId, VariantId, - expr_store::HygieneId, + AdtId, DefWithBodyId, EnumVariantId, ExpressionStoreOwnerId, FunctionId, GeneralConstId, + HasModule, ItemContainerId, Lookup, StaticId, VariantId, + expr_store::{Body, HygieneId}, item_tree::FieldsShape, lang_item::LangItems, layout::{TagEncoding, Variants}, resolver::{HasResolver, TypeNs, ValueNs}, - signatures::{StaticFlags, StructFlags}, + signatures::{ + EnumSignature, FunctionSignature, StaticFlags, StaticSignature, StructFlags, + StructSignature, TraitSignature, + }, }; use hir_expand::{InFile, mod_path::path, name::Name}; use intern::sym; @@ -386,7 +389,7 @@ impl MirEvalError { for (func, span, def) in stack.iter().take(30).rev() { match func { Either::Left(func) => { - let function_name = db.function_signature(*func); + let function_name = FunctionSignature::of(db, *func); writeln!( f, "In function {} ({:?})", @@ -398,7 +401,7 @@ impl MirEvalError { writeln!(f, "In {closure:?}")?; } } - let source_map = db.body_with_source_map(*def).1; + let source_map = &Body::with_source_map(db, *def).1; let span: InFile<SyntaxNodePtr> = match span { MirSpan::ExprId(e) => match source_map.expr_syntax(*e) { Ok(s) => s.map(|it| it.into()), @@ -441,7 +444,7 @@ impl MirEvalError { )?; } MirEvalError::MirLowerError(func, err) => { - let function_name = db.function_signature(*func); + let function_name = FunctionSignature::of(db, *func); let self_ = match func.lookup(db).container { ItemContainerId::ImplId(impl_id) => Some({ db.impl_self_ty(impl_id) @@ -450,7 +453,10 @@ impl MirEvalError { .to_string() }), ItemContainerId::TraitId(it) => Some( - db.trait_signature(it).name.display(db, display_target.edition).to_string(), + TraitSignature::of(db, it) + .name + .display(db, display_target.edition) + .to_string(), ), _ => None, }; @@ -660,7 +666,7 @@ impl<'db> Evaluator<'db> { db, random_state: oorandom::Rand64::new(0), param_env: trait_env.unwrap_or_else(|| ParamEnvAndCrate { - param_env: db.trait_environment_for_body(owner), + param_env: db.trait_environment(ExpressionStoreOwnerId::from(owner)), krate: crate_id, }), crate_id, @@ -730,8 +736,8 @@ impl<'db> Evaluator<'db> { self.param_env.param_env, ty, |c, subst, f| { - let InternedClosure(def, _) = self.db.lookup_intern_closure(c); - let infer = InferenceResult::for_body(self.db, def); + let InternedClosure(owner, _) = self.db.lookup_intern_closure(c); + let infer = InferenceResult::of(self.db, owner); let (captures, _) = infer.closure_info(c); let parent_subst = subst.as_closure().parent_args(); captures @@ -893,8 +899,8 @@ impl<'db> Evaluator<'db> { OperandKind::Copy(p) | OperandKind::Move(p) => self.place_ty(p, locals)?, OperandKind::Constant { konst: _, ty } => ty.as_ref(), &OperandKind::Static(s) => { - let ty = InferenceResult::for_body(self.db, s.into()) - .expr_ty(self.db.body(s.into()).body_expr); + let ty = InferenceResult::of(self.db, DefWithBodyId::from(s)) + .expr_ty(Body::of(self.db, s.into()).root_expr()); Ty::new_ref( self.interner(), Region::new_static(self.interner()), @@ -1625,7 +1631,7 @@ impl<'db> Evaluator<'db> { }; match target_ty { rustc_type_ir::FloatTy::F32 => Owned((value as f32).to_le_bytes().to_vec()), - rustc_type_ir::FloatTy::F64 => Owned((value as f64).to_le_bytes().to_vec()), + rustc_type_ir::FloatTy::F64 => Owned(value.to_le_bytes().to_vec()), rustc_type_ir::FloatTy::F16 | rustc_type_ir::FloatTy::F128 => { not_supported!("unstable floating point type f16 and f128"); } @@ -1954,6 +1960,9 @@ impl<'db> Evaluator<'db> { MirEvalError::ConstEvalError(name, Box::new(e)) })? } + GeneralConstId::AnonConstId(_) => { + not_supported!("anonymous const evaluation") + } }; if let ConstKind::Value(value) = result_owner.kind() { break 'b value; @@ -2818,15 +2827,15 @@ impl<'db> Evaluator<'db> { if let Some(o) = self.static_locations.get(&st) { return Ok(*o); }; - let static_data = self.db.static_signature(st); + let static_data = StaticSignature::of(self.db, st); let result = if !static_data.flags.contains(StaticFlags::EXTERN) { let konst = self.db.const_eval_static(st).map_err(|e| { MirEvalError::ConstEvalError(static_data.name.as_str().to_owned(), Box::new(e)) })?; self.allocate_const_in_heap(locals, konst)? } else { - let ty = InferenceResult::for_body(self.db, st.into()) - .expr_ty(self.db.body(st.into()).body_expr); + let ty = InferenceResult::of(self.db, DefWithBodyId::from(st)) + .expr_ty(Body::of(self.db, st.into()).root_expr()); let Some((size, align)) = self.size_align_of(ty, locals)? else { not_supported!("unsized extern static"); }; @@ -2849,7 +2858,7 @@ impl<'db> Evaluator<'db> { let edition = self.crate_id.data(self.db).edition; let name = format!( "{}::{}", - self.db.enum_signature(loc.parent).name.display(db, edition), + EnumSignature::of(self.db, loc.parent).name.display(db, edition), loc.parent .enum_variants(self.db) .variant_name_by_id(variant) @@ -2909,7 +2918,7 @@ impl<'db> Evaluator<'db> { let id = adt_def.def_id().0; match id { AdtId::StructId(s) => { - let data = self.db.struct_signature(s); + let data = StructSignature::of(self.db, s); if data.flags.contains(StructFlags::IS_MANUALLY_DROP) { return Ok(()); } diff --git a/crates/hir-ty/src/mir/eval/shim.rs b/crates/hir-ty/src/mir/eval/shim.rs index 76c8701ea2..ff6c99ca53 100644 --- a/crates/hir-ty/src/mir/eval/shim.rs +++ b/crates/hir-ty/src/mir/eval/shim.rs @@ -45,7 +45,7 @@ impl<'db> Evaluator<'db> { return Ok(false); } - let function_data = self.db.function_signature(def); + let function_data = FunctionSignature::of(self.db, def); let attrs = AttrFlags::query(self.db, def.into()); let is_intrinsic = FunctionSignature::is_intrinsic(self.db, def); @@ -152,8 +152,8 @@ impl<'db> Evaluator<'db> { not_supported!("wrong arg count for clone"); }; let addr = Address::from_bytes(arg.get(self)?)?; - let InternedClosure(closure_owner, _) = self.db.lookup_intern_closure(id.0); - let infer = InferenceResult::for_body(self.db, closure_owner); + let InternedClosure(owner, _) = self.db.lookup_intern_closure(id.0); + let infer = InferenceResult::of(self.db, owner); let (captures, _) = infer.closure_info(id.0); let layout = self.layout(self_ty)?; let db = self.db; @@ -840,7 +840,7 @@ impl<'db> Evaluator<'db> { // cases. let [lhs, rhs] = args else { return Err(MirEvalError::InternalError( - "wrapping_add args are not provided".into(), + "ptr_guaranteed_cmp args are not provided".into(), )); }; let ans = lhs.get(self)? == rhs.get(self)?; diff --git a/crates/hir-ty/src/mir/eval/tests.rs b/crates/hir-ty/src/mir/eval/tests.rs index 61dd7757c9..6bf966c3ef 100644 --- a/crates/hir-ty/src/mir/eval/tests.rs +++ b/crates/hir-ty/src/mir/eval/tests.rs @@ -1,4 +1,4 @@ -use hir_def::{HasModule, db::DefDatabase}; +use hir_def::{GenericDefId, HasModule, signatures::FunctionSignature}; use hir_expand::EditionedFileId; use span::Edition; use syntax::{TextRange, TextSize}; @@ -25,7 +25,7 @@ fn eval_main(db: &TestDB, file_id: EditionedFileId) -> Result<(String, String), .declarations() .find_map(|x| match x { hir_def::ModuleDefId::FunctionId(x) => { - if db.function_signature(x).name.display(db, Edition::CURRENT).to_string() + if FunctionSignature::of(db, x).name.display(db, Edition::CURRENT).to_string() == "main" { Some(x) @@ -41,7 +41,7 @@ fn eval_main(db: &TestDB, file_id: EditionedFileId) -> Result<(String, String), func_id.into(), GenericArgs::empty(interner).store(), crate::ParamEnvAndCrate { - param_env: db.trait_environment(func_id.into()), + param_env: db.trait_environment(GenericDefId::from(func_id).into()), krate: func_id.krate(db), } .store(), diff --git a/crates/hir-ty/src/mir/lower.rs b/crates/hir-ty/src/mir/lower.rs index 199db7a3e7..44785d948a 100644 --- a/crates/hir-ty/src/mir/lower.rs +++ b/crates/hir-ty/src/mir/lower.rs @@ -4,16 +4,17 @@ use std::{fmt::Write, iter, mem}; use base_db::Crate; use hir_def::{ - AdtId, DefWithBodyId, EnumVariantId, GeneralConstId, GenericParamId, HasModule, - ItemContainerId, LocalFieldId, Lookup, TraitId, TupleId, + AdtId, DefWithBodyId, EnumVariantId, ExpressionStoreOwnerId, GeneralConstId, GenericParamId, + HasModule, ItemContainerId, LocalFieldId, Lookup, TraitId, TupleId, expr_store::{Body, ExpressionStore, HygieneId, path::Path}, hir::{ ArithOp, Array, BinaryOp, BindingAnnotation, BindingId, ExprId, LabelId, Literal, MatchArm, - Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, + Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, generics::GenericParams, }, item_tree::FieldsShape, lang_item::LangItems, resolver::{HasResolver, ResolveValueResult, Resolver, ValueNs}, + signatures::{ConstSignature, EnumSignature, FunctionSignature, StaticSignature}, }; use hir_expand::name::Name; use la_arena::ArenaMap; @@ -82,7 +83,7 @@ struct MirLowerCtx<'a, 'db> { labeled_loop_blocks: FxHashMap<LabelId, LoopBlocks>, discr_temp: Option<Place>, db: &'db dyn HirDatabase, - body: &'a Body, + store: &'a ExpressionStore, infer: &'a InferenceResult, types: &'db crate::next_solver::DefaultAny<'db>, resolver: Resolver<'db>, @@ -185,7 +186,7 @@ impl MirLowerError { } } MirLowerError::MissingFunctionDefinition(owner, it) => { - let body = db.body(*owner); + let body = Body::of(db, *owner); writeln!( f, "Missing function definition for {}", @@ -202,13 +203,13 @@ impl MirLowerError { MirLowerError::GenericArgNotProvided(id, subst) => { let param_name = match *id { GenericParamId::TypeParamId(id) => { - db.generic_params(id.parent())[id.local_id()].name().cloned() + GenericParams::of(db, id.parent())[id.local_id()].name().cloned() } GenericParamId::ConstParamId(id) => { - db.generic_params(id.parent())[id.local_id()].name().cloned() + GenericParams::of(db, id.parent())[id.local_id()].name().cloned() } GenericParamId::LifetimeParamId(id) => { - Some(db.generic_params(id.parent)[id.local_id].name.clone()) + Some(GenericParams::of(db, id.parent)[id.local_id].name.clone()) } }; writeln!( @@ -285,7 +286,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { fn new( db: &'db dyn HirDatabase, owner: DefWithBodyId, - body: &'a Body, + store: &'a ExpressionStore, infer: &'a InferenceResult, ) -> Self { let mut basic_blocks = Arena::new(); @@ -307,7 +308,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { closures: vec![], }; let resolver = owner.resolver(db); - let env = db.trait_environment_for_body(owner); + let env = db.trait_environment(ExpressionStoreOwnerId::from(owner)); let interner = DbInterner::new_with(db, resolver.krate()); // FIXME(next-solver): Is `non_body_analysis()` correct here? Don't we want to reveal opaque types defined by this body? let infcx = interner.infer_ctxt().build(TypingMode::non_body_analysis()); @@ -316,7 +317,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { result: mir, db, infer, - body, + store, types: crate::next_solver::default_types(db), owner, resolver, @@ -354,7 +355,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { current: BasicBlockId, ) -> Result<'db, Option<(Operand, BasicBlockId)>> { if !self.has_adjustments(expr_id) - && let Expr::Literal(l) = &self.body[expr_id] + && let Expr::Literal(l) = &self.store[expr_id] { let ty = self.expr_ty_without_adjust(expr_id); return Ok(Some((self.lower_literal_to_operand(ty, l)?, current))); @@ -461,7 +462,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { place: Place, mut current: BasicBlockId, ) -> Result<'db, Option<BasicBlockId>> { - match &self.body[expr_id] { + match &self.store[expr_id] { Expr::OffsetOf(_) => { not_supported!("builtin#offset_of") } @@ -472,7 +473,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { if let DefWithBodyId::FunctionId(f) = self.owner { let assoc = f.lookup(self.db); if let ItemContainerId::TraitId(t) = assoc.container { - let name = &self.db.function_signature(f).name; + let name = &FunctionSignature::of(self.db, f).name; return Err(MirLowerError::TraitFunctionDefinition(t, name.clone())); } } @@ -500,7 +501,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { } else { let resolver_guard = self.resolver.update_to_inner_scope(self.db, self.owner, expr_id); - let hygiene = self.body.expr_path_hygiene(expr_id); + let hygiene = self.store.expr_path_hygiene(expr_id); let result = self .resolver .resolve_path_in_value_ns_fully(self.db, p, hygiene) @@ -509,7 +510,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { self.db, p, DisplayTarget::from_crate(self.db, self.krate()), - self.body, + self.store, ) })?; self.resolver.reset_to_guard(resolver_guard); @@ -882,7 +883,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { let variant_id = self.infer.variant_resolution_for_expr(expr_id).ok_or_else(|| match path { Some(p) => MirLowerError::UnresolvedName( - hir_display_with_store(&**p, self.body) + hir_display_with_store(&**p, self.store) .display(self.db, self.display_target()) .to_string(), ), @@ -1382,7 +1383,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { } fn push_field_projection(&mut self, place: &mut Place, expr_id: ExprId) -> Result<'db, ()> { - if let Expr::Field { expr, name } = &self.body[expr_id] { + if let Expr::Field { expr, name } = &self.store[expr_id] { if let TyKind::Tuple(..) = self.expr_ty_after_adjustments(*expr).kind() { let index = name.as_tuple_index().ok_or(MirLowerError::TypeError("named field on tuple"))? @@ -1411,7 +1412,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { ty: Ty<'db>, loc: &ExprId, ) -> Result<'db, Operand> { - match &self.body[*loc] { + match &self.store[*loc] { Expr::Literal(l) => self.lower_literal_to_operand(ty, l), Expr::Path(c) => { let owner = self.owner; @@ -1421,7 +1422,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { self.db, c, DisplayTarget::from_crate(db, owner.krate(db)), - self.body, + self.store, ) }; let pr = self @@ -1429,7 +1430,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { .resolve_path_in_value_ns(self.db, c, HygieneId::ROOT) .ok_or_else(unresolved_name)?; match pr { - ResolveValueResult::ValueNs(v, _) => { + ResolveValueResult::ValueNs(v) => { if let ValueNs::ConstId(c) = v { self.lower_const_to_operand( GenericArgs::empty(self.interner()), @@ -1439,7 +1440,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { not_supported!("bad path in range pattern"); } } - ResolveValueResult::Partial(_, _, _) => { + ResolveValueResult::Partial(_, _) => { not_supported!("associated constants in range pattern") } } @@ -1546,6 +1547,9 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { MirLowerError::ConstEvalError(name.into(), Box::new(e)) })? } + GeneralConstId::AnonConstId(_) => { + return Err(MirLowerError::IncompleteExpr); + } } }; let ty = self @@ -1553,6 +1557,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { .value_ty(match const_id { GeneralConstId::ConstId(id) => id.into(), GeneralConstId::StaticId(id) => id.into(), + GeneralConstId::AnonConstId(_) => unreachable!("handled above"), }) .unwrap() .instantiate(self.interner(), subst); @@ -1859,7 +1864,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { } } else { let mut err = None; - self.body.walk_bindings_in_pat(*pat, |b| { + self.store.walk_bindings_in_pat(*pat, |b| { if let Err(e) = self.push_storage_live(b, current) { err = Some(e); } @@ -1913,9 +1918,9 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { self.result.param_locals.extend(params.clone().map(|(it, ty)| { let local_id = self.result.locals.alloc(Local { ty: ty.store() }); self.drop_scopes.last_mut().unwrap().locals.push(local_id); - if let Pat::Bind { id, subpat: None } = self.body[it] + if let Pat::Bind { id, subpat: None } = self.store[it] && matches!( - self.body[id].mode, + self.store[id].mode, BindingAnnotation::Unannotated | BindingAnnotation::Mutable ) { @@ -1924,7 +1929,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { local_id })); // and then rest of bindings - for (id, _) in self.body.bindings() { + for (id, _) in self.store.bindings() { if !pick_binding(id) { continue; } @@ -1953,7 +1958,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { .into_iter() .skip(base_param_count + self_binding.is_some() as usize); for ((param, _), local) in params.zip(local_params) { - if let Pat::Bind { id, .. } = self.body[param] + if let Pat::Bind { id, .. } = self.store[param] && local == self.binding_local(id)? { continue; @@ -1989,7 +1994,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { let loc = variant.lookup(db); let name = format!( "{}::{}", - self.db.enum_signature(loc.parent).name.display(db, edition), + EnumSignature::of(db, loc.parent).name.display(db, edition), loc.parent .enum_variants(self.db) .variant_name_by_id(variant) @@ -2106,8 +2111,10 @@ pub fn mir_body_for_closure_query<'db>( closure: InternedClosureId, ) -> Result<'db, Arc<MirBody>> { let InternedClosure(owner, expr) = db.lookup_intern_closure(closure); - let body = db.body(owner); - let infer = InferenceResult::for_body(db, owner); + let body_owner = + owner.as_def_with_body().expect("MIR lowering should only happen for body-owned closures"); + let body = Body::of(db, body_owner); + let infer = InferenceResult::of(db, body_owner); let Expr::Closure { args, body: root, .. } = &body[expr] else { implementation_error!("closure expression is not closure"); }; @@ -2115,7 +2122,7 @@ pub fn mir_body_for_closure_query<'db>( implementation_error!("closure expression is not closure"); }; let (captures, kind) = infer.closure_info(closure); - let mut ctx = MirLowerCtx::new(db, owner, &body, infer); + let mut ctx = MirLowerCtx::new(db, body_owner, &body.store, infer); // 0 is return local ctx.result.locals.alloc(Local { ty: infer.expr_ty(*root).store() }); let closure_local = ctx.result.locals.alloc(Local { @@ -2138,7 +2145,7 @@ pub fn mir_body_for_closure_query<'db>( }); ctx.result.param_locals.push(closure_local); let sig = ctx.interner().signature_unclosure(substs.as_closure().sig(), Safety::Safe); - let resolver_guard = ctx.resolver.update_to_inner_scope(db, owner, expr); + let resolver_guard = ctx.resolver.update_to_inner_scope(db, body_owner, expr); let current = ctx.lower_params_and_bindings( args.iter().zip(sig.skip_binder().inputs().iter()).map(|(it, y)| (*it, *y)), None, @@ -2205,7 +2212,7 @@ pub fn mir_body_for_closure_query<'db>( .result .binding_locals .into_iter() - .filter(|it| ctx.body.binding_owner(it.0) == Some(expr)) + .filter(|it| ctx.store.binding_owner(it.0) == Some(expr)) .collect(); if let Some(err) = err { return Err(MirLowerError::UnresolvedUpvar(err)); @@ -2222,13 +2229,12 @@ pub fn mir_body_query<'db>( let edition = krate.data(db).edition; let detail = match def { DefWithBodyId::FunctionId(it) => { - db.function_signature(it).name.display(db, edition).to_string() + FunctionSignature::of(db, it).name.display(db, edition).to_string() } DefWithBodyId::StaticId(it) => { - db.static_signature(it).name.display(db, edition).to_string() + StaticSignature::of(db, it).name.display(db, edition).to_string() } - DefWithBodyId::ConstId(it) => db - .const_signature(it) + DefWithBodyId::ConstId(it) => ConstSignature::of(db, it) .name .clone() .unwrap_or_else(Name::missing) @@ -2243,9 +2249,9 @@ pub fn mir_body_query<'db>( } }; let _p = tracing::info_span!("mir_body_query", ?detail).entered(); - let body = db.body(def); - let infer = InferenceResult::for_body(db, def); - let mut result = lower_to_mir(db, def, &body, infer, body.body_expr)?; + let body = Body::of(db, def); + let infer = InferenceResult::of(db, def); + let mut result = lower_body_to_mir(db, def, body, infer, body.root_expr())?; result.shrink_to_fit(); Ok(Arc::new(result)) } @@ -2258,44 +2264,74 @@ pub(crate) fn mir_body_cycle_result<'db>( Err(MirLowerError::Loop) } -pub fn lower_to_mir<'db>( +/// Extracts params from `body.params`/`body.self_param` and the callable signature, +/// then delegates to [`lower_to_mir_with_store`]. +pub fn lower_body_to_mir<'db>( db: &'db dyn HirDatabase, owner: DefWithBodyId, body: &Body, infer: &InferenceResult, - // FIXME: root_expr should always be the body.body_expr, but since `X` in `[(); X]` doesn't have its own specific body yet, we - // need to take this input explicitly. + // FIXME: root_expr should always be the body.body_expr, + // but this is currently also used for `X` in `[(); X]` which live in the same expression store root_expr: ExprId, ) -> Result<'db, MirBody> { + let is_root = root_expr == body.root_expr(); + // Extract params and self_param only when lowering the body's root expression for a function. + if is_root && let DefWithBodyId::FunctionId(fid) = owner { + let callable_sig = + db.callable_item_signature(fid.into()).instantiate_identity().skip_binder(); + let mut param_tys = callable_sig.inputs().iter().copied(); + let self_param = body.self_param.and_then(|id| Some((id, param_tys.next()?))); + + lower_to_mir_with_store( + db, + owner, + &body.store, + infer, + root_expr, + body.params.iter().copied().zip(param_tys), + self_param, + is_root, + ) + } else { + lower_to_mir_with_store( + db, + owner, + &body.store, + infer, + root_expr, + iter::empty(), + None, + is_root, + ) + } +} + +/// # Parameters +/// - `is_root`: `true` when `root_expr` is the body's top-level expression (picks +/// bindings with no owner); `false` when lowering an inline const or anonymous +/// const (picks bindings owned by `root_expr`). +pub fn lower_to_mir_with_store<'db>( + db: &'db dyn HirDatabase, + owner: DefWithBodyId, + store: &ExpressionStore, + infer: &InferenceResult, + root_expr: ExprId, + params: impl Iterator<Item = (PatId, Ty<'db>)> + Clone, + self_param: Option<(BindingId, Ty<'db>)>, + is_root: bool, +) -> Result<'db, MirBody> { if infer.type_mismatches().next().is_some() || infer.is_erroneous() { return Err(MirLowerError::HasErrors); } - let mut ctx = MirLowerCtx::new(db, owner, body, infer); + let mut ctx = MirLowerCtx::new(db, owner, store, infer); // 0 is return local ctx.result.locals.alloc(Local { ty: ctx.expr_ty_after_adjustments(root_expr).store() }); let binding_picker = |b: BindingId| { - let owner = ctx.body.binding_owner(b); - if root_expr == body.body_expr { owner.is_none() } else { owner == Some(root_expr) } - }; - // 1 to param_len is for params - // FIXME: replace with let chain once it becomes stable - let current = 'b: { - if body.body_expr == root_expr { - // otherwise it's an inline const, and has no parameter - if let DefWithBodyId::FunctionId(fid) = owner { - let callable_sig = - db.callable_item_signature(fid.into()).instantiate_identity().skip_binder(); - let mut params = callable_sig.inputs().iter().copied(); - let self_param = body.self_param.and_then(|id| Some((id, params.next()?))); - break 'b ctx.lower_params_and_bindings( - body.params.iter().zip(params).map(|(it, y)| (*it, y)), - self_param, - binding_picker, - )?; - } - } - ctx.lower_params_and_bindings([].into_iter(), None, binding_picker)? + let owner = ctx.store.binding_owner(b); + if is_root { owner.is_none() } else { owner == Some(root_expr) } }; + let current = ctx.lower_params_and_bindings(params, self_param, binding_picker)?; if let Some(current) = ctx.lower_expr_to_place(root_expr, return_slot().into(), current)? { let current = ctx.pop_drop_scope_assert_finished(current, root_expr.into())?; ctx.set_terminator(current, TerminatorKind::Return, root_expr.into()); diff --git a/crates/hir-ty/src/mir/lower/as_place.rs b/crates/hir-ty/src/mir/lower/as_place.rs index cf05ec27ac..17dc95fb24 100644 --- a/crates/hir-ty/src/mir/lower/as_place.rs +++ b/crates/hir-ty/src/mir/lower/as_place.rs @@ -137,11 +137,11 @@ impl<'db> MirLowerCtx<'_, 'db> { } this.lower_expr_to_some_place_without_adjust(expr_id, current) }; - match &self.body[expr_id] { + match &self.store[expr_id] { Expr::Path(p) => { let resolver_guard = self.resolver.update_to_inner_scope(self.db, self.owner, expr_id); - let hygiene = self.body.expr_path_hygiene(expr_id); + let hygiene = self.store.expr_path_hygiene(expr_id); let resolved = self.resolver.resolve_path_in_value_ns_fully(self.db, p, hygiene); self.resolver.reset_to_guard(resolver_guard); let Some(pr) = resolved else { diff --git a/crates/hir-ty/src/mir/lower/pattern_matching.rs b/crates/hir-ty/src/mir/lower/pattern_matching.rs index a8aacbff16..99c5f0fc65 100644 --- a/crates/hir-ty/src/mir/lower/pattern_matching.rs +++ b/crates/hir-ty/src/mir/lower/pattern_matching.rs @@ -131,7 +131,7 @@ impl<'db> MirLowerCtx<'_, 'db> { .collect::<Vec<_>>() .into(), ); - Ok(match &self.body[pattern] { + Ok(match &self.store[pattern] { Pat::Missing => return Err(MirLowerError::IncompletePattern), Pat::Wild => (current, current_else), Pat::Tuple { args, ellipsis } => { @@ -322,7 +322,7 @@ impl<'db> MirLowerCtx<'_, 'db> { } if let &Some(slice) = slice && mode != MatchingMode::Check - && let Pat::Bind { id, subpat: _ } = self.body[slice] + && let Pat::Bind { id, subpat: _ } = self.store[slice] { let next_place = cond_place.project( ProjectionElem::Subslice { @@ -363,9 +363,14 @@ impl<'db> MirLowerCtx<'_, 'db> { )?, None => { let unresolved_name = || { - MirLowerError::unresolved_path(self.db, p, self.display_target(), self.body) + MirLowerError::unresolved_path( + self.db, + p, + self.display_target(), + self.store, + ) }; - let hygiene = self.body.pat_path_hygiene(pattern); + let hygiene = self.store.pat_path_hygiene(pattern); let pr = self .resolver .resolve_path_in_value_ns(self.db, p, hygiene) @@ -373,7 +378,7 @@ impl<'db> MirLowerCtx<'_, 'db> { if let ( MatchingMode::Assign, - ResolveValueResult::ValueNs(ValueNs::LocalBinding(binding), _), + ResolveValueResult::ValueNs(ValueNs::LocalBinding(binding)), ) = (mode, &pr) { let local = self.binding_local(*binding)?; @@ -398,7 +403,7 @@ impl<'db> MirLowerCtx<'_, 'db> { { break 'b (c, x.1); } - if let ResolveValueResult::ValueNs(ValueNs::ConstId(c), _) = pr { + if let ResolveValueResult::ValueNs(ValueNs::ConstId(c)) = pr { break 'b (c, GenericArgs::empty(self.interner())); } not_supported!("path in pattern position that is not const or variant") @@ -432,7 +437,7 @@ impl<'db> MirLowerCtx<'_, 'db> { (next, Some(else_target)) } }, - Pat::Lit(l) => match &self.body[*l] { + Pat::Lit(l) => match &self.store[*l] { Expr::Literal(l) => { if mode == MatchingMode::Check { let c = self.lower_literal_to_operand(self.infer.pat_ty(pattern), l)?; diff --git a/crates/hir-ty/src/mir/pretty.rs b/crates/hir-ty/src/mir/pretty.rs index 96b90a3f40..4b654a0fbe 100644 --- a/crates/hir-ty/src/mir/pretty.rs +++ b/crates/hir-ty/src/mir/pretty.rs @@ -6,7 +6,11 @@ use std::{ }; use either::Either; -use hir_def::{expr_store::Body, hir::BindingId}; +use hir_def::{ + expr_store::Body, + hir::BindingId, + signatures::{ConstSignature, EnumSignature, FunctionSignature, StaticSignature}, +}; use hir_expand::{Lookup, name::Name}; use la_arena::ArenaMap; @@ -38,19 +42,19 @@ macro_rules! wln { impl MirBody { pub fn pretty_print(&self, db: &dyn HirDatabase, display_target: DisplayTarget) -> String { - let hir_body = db.body(self.owner); - let mut ctx = MirPrettyCtx::new(self, &hir_body, db, display_target); + let hir_body = Body::of(db, self.owner); + let mut ctx = MirPrettyCtx::new(self, hir_body, db, display_target); ctx.for_body(|this| match ctx.body.owner { hir_def::DefWithBodyId::FunctionId(id) => { - let data = db.function_signature(id); + let data = FunctionSignature::of(db, id); w!(this, "fn {}() ", data.name.display(db, this.display_target.edition)); } hir_def::DefWithBodyId::StaticId(id) => { - let data = db.static_signature(id); + let data = StaticSignature::of(db, id); w!(this, "static {}: _ = ", data.name.display(db, this.display_target.edition)); } hir_def::DefWithBodyId::ConstId(id) => { - let data = db.const_signature(id); + let data = ConstSignature::of(db, id); w!( this, "const {}: _ = ", @@ -66,7 +70,7 @@ impl MirBody { w!( this, "enum {}::{} = ", - db.enum_signature(loc.parent).name.display(db, edition), + EnumSignature::of(db, loc.parent).name.display(db, edition), loc.parent .enum_variants(db) .variant_name_by_id(id) diff --git a/crates/hir-ty/src/next_solver/abi.rs b/crates/hir-ty/src/next_solver/abi.rs index 80d1ea4aa4..1813abab86 100644 --- a/crates/hir-ty/src/next_solver/abi.rs +++ b/crates/hir-ty/src/next_solver/abi.rs @@ -62,7 +62,6 @@ impl<'db> rustc_type_ir::inherent::Abi<DbInterner<'db>> for FnAbi { } fn is_rust(self) -> bool { - // TODO: rustc does not consider `RustCall` to be true here, but Chalk does - matches!(self, FnAbi::Rust | FnAbi::RustCall) + matches!(self, FnAbi::Rust) } } diff --git a/crates/hir-ty/src/next_solver/def_id.rs b/crates/hir-ty/src/next_solver/def_id.rs index aa6caefc4a..00161d6d08 100644 --- a/crates/hir-ty/src/next_solver/def_id.rs +++ b/crates/hir-ty/src/next_solver/def_id.rs @@ -1,9 +1,13 @@ //! Definition of `SolverDefId` use hir_def::{ - AdtId, AttrDefId, BuiltinDeriveImplId, CallableDefId, ConstId, DefWithBodyId, EnumId, - EnumVariantId, FunctionId, GeneralConstId, GenericDefId, ImplId, StaticId, StructId, TraitId, - TypeAliasId, UnionId, + AdtId, AnonConstId, AttrDefId, BuiltinDeriveImplId, CallableDefId, ConstId, DefWithBodyId, + EnumId, EnumVariantId, ExpressionStoreOwnerId, FunctionId, GeneralConstId, GenericDefId, + ImplId, StaticId, StructId, TraitId, TypeAliasId, UnionId, VariantId, + signatures::{ + ConstSignature, EnumSignature, FunctionSignature, StaticSignature, StructSignature, + TraitSignature, TypeAliasSignature, UnionSignature, + }, }; use rustc_type_ir::inherent; use stdx::impl_from; @@ -12,13 +16,13 @@ use crate::db::{InternedClosureId, InternedCoroutineId, InternedOpaqueTyId}; use super::DbInterner; -#[derive(Debug, PartialOrd, Ord, Clone, Copy, PartialEq, Eq, Hash, salsa::Supertype)] +#[derive(Debug, PartialOrd, Ord, Clone, Copy, PartialEq, Eq, Hash)] pub enum Ctor { Struct(StructId), Enum(EnumVariantId), } -#[derive(PartialOrd, Ord, Clone, Copy, PartialEq, Eq, Hash, salsa::Supertype)] +#[derive(PartialOrd, Ord, Clone, Copy, PartialEq, Eq, Hash)] pub enum SolverDefId { AdtId(AdtId), ConstId(ConstId), @@ -26,13 +30,13 @@ pub enum SolverDefId { ImplId(ImplId), BuiltinDeriveImplId(BuiltinDeriveImplId), StaticId(StaticId), + AnonConstId(AnonConstId), TraitId(TraitId), TypeAliasId(TypeAliasId), InternedClosureId(InternedClosureId), InternedCoroutineId(InternedCoroutineId), InternedOpaqueTyId(InternedOpaqueTyId), EnumVariantId(EnumVariantId), - // FIXME(next-solver): Do we need the separation of `Ctor`? It duplicates some variants. Ctor(Ctor), } @@ -42,32 +46,33 @@ impl std::fmt::Debug for SolverDefId { let db = interner.db; match *self { SolverDefId::AdtId(AdtId::StructId(id)) => { - f.debug_tuple("AdtId").field(&db.struct_signature(id).name.as_str()).finish() + f.debug_tuple("AdtId").field(&StructSignature::of(db, id).name.as_str()).finish() } SolverDefId::AdtId(AdtId::EnumId(id)) => { - f.debug_tuple("AdtId").field(&db.enum_signature(id).name.as_str()).finish() + f.debug_tuple("AdtId").field(&EnumSignature::of(db, id).name.as_str()).finish() } SolverDefId::AdtId(AdtId::UnionId(id)) => { - f.debug_tuple("AdtId").field(&db.union_signature(id).name.as_str()).finish() + f.debug_tuple("AdtId").field(&UnionSignature::of(db, id).name.as_str()).finish() } SolverDefId::ConstId(id) => f .debug_tuple("ConstId") - .field(&db.const_signature(id).name.as_ref().map_or("_", |name| name.as_str())) + .field(&ConstSignature::of(db, id).name.as_ref().map_or("_", |name| name.as_str())) + .finish(), + SolverDefId::FunctionId(id) => f + .debug_tuple("FunctionId") + .field(&FunctionSignature::of(db, id).name.as_str()) .finish(), - SolverDefId::FunctionId(id) => { - f.debug_tuple("FunctionId").field(&db.function_signature(id).name.as_str()).finish() - } SolverDefId::ImplId(id) => f.debug_tuple("ImplId").field(&id).finish(), SolverDefId::BuiltinDeriveImplId(id) => f.debug_tuple("ImplId").field(&id).finish(), SolverDefId::StaticId(id) => { - f.debug_tuple("StaticId").field(&db.static_signature(id).name.as_str()).finish() + f.debug_tuple("StaticId").field(&StaticSignature::of(db, id).name.as_str()).finish() } SolverDefId::TraitId(id) => { - f.debug_tuple("TraitId").field(&db.trait_signature(id).name.as_str()).finish() + f.debug_tuple("TraitId").field(&TraitSignature::of(db, id).name.as_str()).finish() } SolverDefId::TypeAliasId(id) => f .debug_tuple("TypeAliasId") - .field(&db.type_alias_signature(id).name.as_str()) + .field(&TypeAliasSignature::of(db, id).name.as_str()) .finish(), SolverDefId::InternedClosureId(id) => { f.debug_tuple("InternedClosureId").field(&id).finish() @@ -83,20 +88,21 @@ impl std::fmt::Debug for SolverDefId { f.debug_tuple("EnumVariantId") .field(&format_args!( "\"{}::{}\"", - db.enum_signature(parent_enum).name.as_str(), + EnumSignature::of(db, parent_enum).name.as_str(), parent_enum.enum_variants(db).variant_name_by_id(id).unwrap().as_str() )) .finish() } + SolverDefId::AnonConstId(id) => f.debug_tuple("AnonConstId").field(&id).finish(), SolverDefId::Ctor(Ctor::Struct(id)) => { - f.debug_tuple("Ctor").field(&db.struct_signature(id).name.as_str()).finish() + f.debug_tuple("Ctor").field(&StructSignature::of(db, id).name.as_str()).finish() } SolverDefId::Ctor(Ctor::Enum(id)) => { let parent_enum = id.loc(db).parent; f.debug_tuple("Ctor") .field(&format_args!( "\"{}::{}\"", - db.enum_signature(parent_enum).name.as_str(), + EnumSignature::of(db, parent_enum).name.as_str(), parent_enum.enum_variants(db).variant_name_by_id(id).unwrap().as_str() )) .finish() @@ -112,6 +118,7 @@ impl_from!( ImplId, BuiltinDeriveImplId, StaticId, + AnonConstId, TraitId, TypeAliasId, InternedClosureId, @@ -142,6 +149,7 @@ impl From<GeneralConstId> for SolverDefId { match value { GeneralConstId::ConstId(const_id) => SolverDefId::ConstId(const_id), GeneralConstId::StaticId(static_id) => SolverDefId::StaticId(static_id), + GeneralConstId::AnonConstId(anon_const_id) => SolverDefId::AnonConstId(anon_const_id), } } } @@ -158,6 +166,28 @@ impl From<DefWithBodyId> for SolverDefId { } } +impl From<VariantId> for SolverDefId { + #[inline] + fn from(value: VariantId) -> Self { + match value { + VariantId::EnumVariantId(id) => id.into(), + VariantId::StructId(id) => id.into(), + VariantId::UnionId(id) => id.into(), + } + } +} + +impl From<ExpressionStoreOwnerId> for SolverDefId { + #[inline] + fn from(value: ExpressionStoreOwnerId) -> Self { + match value { + ExpressionStoreOwnerId::Body(body_id) => body_id.into(), + ExpressionStoreOwnerId::Signature(sig_id) => sig_id.into(), + ExpressionStoreOwnerId::VariantFields(variant_id) => variant_id.into(), + } + } +} + impl TryFrom<SolverDefId> for AttrDefId { type Error = (); #[inline] @@ -176,7 +206,8 @@ impl TryFrom<SolverDefId> for AttrDefId { SolverDefId::BuiltinDeriveImplId(_) | SolverDefId::InternedClosureId(_) | SolverDefId::InternedCoroutineId(_) - | SolverDefId::InternedOpaqueTyId(_) => Err(()), + | SolverDefId::InternedOpaqueTyId(_) + | SolverDefId::AnonConstId(_) => Err(()), } } } @@ -199,6 +230,7 @@ impl TryFrom<SolverDefId> for DefWithBodyId { | SolverDefId::InternedClosureId(_) | SolverDefId::InternedCoroutineId(_) | SolverDefId::Ctor(Ctor::Struct(_)) + | SolverDefId::AnonConstId(_) | SolverDefId::AdtId(_) => return Err(()), }; Ok(id) @@ -222,6 +254,7 @@ impl TryFrom<SolverDefId> for GenericDefId { | SolverDefId::InternedOpaqueTyId(_) | SolverDefId::EnumVariantId(_) | SolverDefId::BuiltinDeriveImplId(_) + | SolverDefId::AnonConstId(_) | SolverDefId::Ctor(_) => return Err(()), }) } @@ -343,6 +376,7 @@ impl From<GeneralConstIdWrapper> for SolverDefId { match value.0 { GeneralConstId::ConstId(id) => SolverDefId::ConstId(id), GeneralConstId::StaticId(id) => SolverDefId::StaticId(id), + GeneralConstId::AnonConstId(id) => SolverDefId::AnonConstId(id), } } } @@ -353,6 +387,7 @@ impl TryFrom<SolverDefId> for GeneralConstIdWrapper { match value { SolverDefId::ConstId(it) => Ok(Self(it.into())), SolverDefId::StaticId(it) => Ok(Self(it.into())), + SolverDefId::AnonConstId(it) => Ok(Self(it.into())), _ => Err(()), } } diff --git a/crates/hir-ty/src/next_solver/fulfill/errors.rs b/crates/hir-ty/src/next_solver/fulfill/errors.rs index 8f798b4ade..0e8218b33a 100644 --- a/crates/hir-ty/src/next_solver/fulfill/errors.rs +++ b/crates/hir-ty/src/next_solver/fulfill/errors.rs @@ -617,6 +617,7 @@ impl<'db> NextSolverError<'db> { } mod wf { + use hir_def::signatures::ImplSignature; use hir_def::{GeneralConstId, ItemContainerId}; use rustc_type_ir::inherent::{ AdtDef, BoundExistentialPredicates, GenericArgs as _, IntoKind, SliceLike, Term as _, @@ -1054,7 +1055,7 @@ mod wf { if let GeneralConstId::ConstId(uv_def) = uv.def.0 && let ItemContainerId::ImplId(impl_) = uv_def.loc(self.interner().db).container - && self.interner().db.impl_signature(impl_).target_trait.is_none() + && ImplSignature::of(self.interner().db, impl_).target_trait.is_none() { return; // Subtree is handled by above function } else { diff --git a/crates/hir-ty/src/next_solver/generics.rs b/crates/hir-ty/src/next_solver/generics.rs index a8288b4e82..f31de21796 100644 --- a/crates/hir-ty/src/next_solver/generics.rs +++ b/crates/hir-ty/src/next_solver/generics.rs @@ -55,7 +55,7 @@ pub(crate) fn generics(interner: DbInterner<'_>, def: SolverDefId) -> Generics { let (parent, own_params) = match (def.try_into(), def) { (Ok(def), _) => ( parent_generic_def(db, def), - own_params_for_generic_params(def, &db.generic_params(def)), + own_params_for_generic_params(def, GenericParams::of(db, def)), ), (_, SolverDefId::InternedOpaqueTyId(id)) => { match db.lookup_intern_impl_trait_id(id) { diff --git a/crates/hir-ty/src/next_solver/interner.rs b/crates/hir-ty/src/next_solver/interner.rs index e17bdac68c..5b81c7675d 100644 --- a/crates/hir-ty/src/next_solver/interner.rs +++ b/crates/hir-ty/src/next_solver/interner.rs @@ -10,11 +10,15 @@ pub use tls_db::{attach_db, attach_db_allow_change, with_attached_db}; use base_db::Crate; use hir_def::{ - AdtId, CallableDefId, DefWithBodyId, EnumVariantId, HasModule, ItemContainerId, StructId, - UnionId, VariantId, + AdtId, CallableDefId, DefWithBodyId, EnumVariantId, ExpressionStoreOwnerId, HasModule, + ItemContainerId, StructId, UnionId, VariantId, attrs::AttrFlags, + expr_store::{Body, ExpressionStore}, lang_item::LangItems, - signatures::{FieldData, FnFlags, ImplFlags, StructFlags, TraitFlags}, + signatures::{ + EnumSignature, FieldData, FnFlags, FunctionSignature, ImplFlags, ImplSignature, + StructFlags, StructSignature, TraitFlags, TraitSignature, UnionSignature, + }, }; use la_arena::Idx; use rustc_abi::{ReprFlags, ReprOptions}; @@ -548,7 +552,7 @@ impl AdtDef { let db = interner.db(); let (flags, variants, repr) = match def_id { AdtId::StructId(struct_id) => { - let data = db.struct_signature(struct_id); + let data = StructSignature::of(db, struct_id); let flags = AdtFlags { is_enum: false, @@ -775,15 +779,15 @@ impl fmt::Debug for AdtDef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { crate::with_attached_db(|db| match self.inner().id { AdtId::StructId(struct_id) => { - let data = db.struct_signature(struct_id); + let data = StructSignature::of(db, struct_id); f.write_str(data.name.as_str()) } AdtId::UnionId(union_id) => { - let data = db.union_signature(union_id); + let data = UnionSignature::of(db, union_id); f.write_str(data.name.as_str()) } AdtId::EnumId(enum_id) => { - let data = db.enum_signature(enum_id); + let data = EnumSignature::of(db, enum_id); f.write_str(data.name.as_str()) } }) @@ -1193,7 +1197,8 @@ impl<'db> Interner for DbInterner<'db> { | SolverDefId::ImplId(_) | SolverDefId::BuiltinDeriveImplId(_) | SolverDefId::InternedClosureId(_) - | SolverDefId::InternedCoroutineId(_) => { + | SolverDefId::InternedCoroutineId(_) + | SolverDefId::AnonConstId(_) => { return VariancesOf::empty(self); } }; @@ -1230,7 +1235,7 @@ impl<'db> Interner for DbInterner<'db> { SolverDefId::InternedOpaqueTyId(_) => AliasTyKind::Opaque, SolverDefId::TypeAliasId(type_alias) => match type_alias.loc(self.db).container { ItemContainerId::ImplId(impl_) - if self.db.impl_signature(impl_).target_trait.is_none() => + if ImplSignature::of(self.db, impl_).target_trait.is_none() => { AliasTyKind::Inherent } @@ -1249,7 +1254,7 @@ impl<'db> Interner for DbInterner<'db> { SolverDefId::InternedOpaqueTyId(_) => AliasTermKind::OpaqueTy, SolverDefId::TypeAliasId(type_alias) => match type_alias.loc(self.db).container { ItemContainerId::ImplId(impl_) - if self.db.impl_signature(impl_).target_trait.is_none() => + if ImplSignature::of(self.db, impl_).target_trait.is_none() => { AliasTermKind::InherentTy } @@ -1260,7 +1265,9 @@ impl<'db> Interner for DbInterner<'db> { }, // rustc creates an `AnonConst` for consts, and evaluates them with CTFE (normalizing projections // via selection, similar to ours `find_matching_impl()`, and not with the trait solver), so mimic it. - SolverDefId::ConstId(_) => AliasTermKind::UnevaluatedConst, + SolverDefId::ConstId(_) | SolverDefId::AnonConstId(_) => { + AliasTermKind::UnevaluatedConst + } _ => unimplemented!("Unexpected alias: {:?}", alias.def_id), } } @@ -1308,22 +1315,10 @@ impl<'db> Interner for DbInterner<'db> { SolverDefId::TypeAliasId(it) => it.lookup(self.db()).container, SolverDefId::ConstId(it) => it.lookup(self.db()).container, SolverDefId::InternedClosureId(it) => { - return self - .db() - .lookup_intern_closure(it) - .0 - .as_generic_def_id(self.db()) - .unwrap() - .into(); + return self.db().lookup_intern_closure(it).0.generic_def(self.db()).into(); } SolverDefId::InternedCoroutineId(it) => { - return self - .db() - .lookup_intern_coroutine(it) - .0 - .as_generic_def_id(self.db()) - .unwrap() - .into(); + return self.db().lookup_intern_coroutine(it).0.generic_def(self.db()).into(); } SolverDefId::StaticId(_) | SolverDefId::AdtId(_) @@ -1332,7 +1327,8 @@ impl<'db> Interner for DbInterner<'db> { | SolverDefId::BuiltinDeriveImplId(_) | SolverDefId::EnumVariantId(..) | SolverDefId::Ctor(..) - | SolverDefId::InternedOpaqueTyId(..) => panic!(), + | SolverDefId::InternedOpaqueTyId(..) + | SolverDefId::AnonConstId(_) => panic!(), }; match container { @@ -1361,8 +1357,8 @@ impl<'db> Interner for DbInterner<'db> { // FIXME: Make this a query? I don't believe this can be accessed from bodies other than // the current infer query, except with revealed opaques - is it rare enough to not matter? let InternedCoroutine(owner, expr_id) = def_id.0.loc(self.db); - let body = self.db.body(owner); - let expr = &body[expr_id]; + let store = ExpressionStore::of(self.db, owner); + let expr = &store[expr_id]; match *expr { hir_def::hir::Expr::Closure { closure_kind, .. } => match closure_kind { hir_def::hir::ClosureKind::Coroutine(movability) => match movability { @@ -1795,6 +1791,7 @@ impl<'db> Interner for DbInterner<'db> { | SolverDefId::InternedCoroutineId(_) | SolverDefId::InternedOpaqueTyId(_) | SolverDefId::EnumVariantId(_) + | SolverDefId::AnonConstId(_) | SolverDefId::Ctor(_) => return None, }; module.block(self.db) @@ -1933,7 +1930,7 @@ impl<'db> Interner for DbInterner<'db> { fn impl_is_default(self, impl_def_id: Self::ImplId) -> bool { match impl_def_id { - AnyImplId::ImplId(impl_id) => self.db.impl_signature(impl_id).is_default(), + AnyImplId::ImplId(impl_id) => ImplSignature::of(self.db, impl_id).is_default(), AnyImplId::BuiltinDeriveImplId(_) => false, } } @@ -1960,7 +1957,7 @@ impl<'db> Interner for DbInterner<'db> { let AnyImplId::ImplId(impl_id) = impl_id else { return ImplPolarity::Positive; }; - let impl_data = self.db().impl_signature(impl_id); + let impl_data = ImplSignature::of(self.db(), impl_id); if impl_data.flags.contains(ImplFlags::NEGATIVE) { ImplPolarity::Negative } else { @@ -1969,12 +1966,12 @@ impl<'db> Interner for DbInterner<'db> { } fn trait_is_auto(self, trait_: Self::TraitId) -> bool { - let trait_data = self.db().trait_signature(trait_.0); + let trait_data = TraitSignature::of(self.db(), trait_.0); trait_data.flags.contains(TraitFlags::AUTO) } fn trait_is_alias(self, trait_: Self::TraitId) -> bool { - let trait_data = self.db().trait_signature(trait_.0); + let trait_data = TraitSignature::of(self.db(), trait_.0); trait_data.flags.contains(TraitFlags::ALIAS) } @@ -1983,7 +1980,7 @@ impl<'db> Interner for DbInterner<'db> { } fn trait_is_fundamental(self, trait_: Self::TraitId) -> bool { - let trait_data = self.db().trait_signature(trait_.0); + let trait_data = TraitSignature::of(self.db(), trait_.0); trait_data.flags.contains(TraitFlags::FUNDAMENTAL) } @@ -2006,9 +2003,9 @@ impl<'db> Interner for DbInterner<'db> { // FIXME: Make this a query? I don't believe this can be accessed from bodies other than // the current infer query, except with revealed opaques - is it rare enough to not matter? let InternedCoroutine(owner, expr_id) = def_id.0.loc(self.db); - let body = self.db.body(owner); + let store = ExpressionStore::of(self.db, owner); matches!( - body[expr_id], + store[expr_id], hir_def::hir::Expr::Closure { closure_kind: hir_def::hir::ClosureKind::Coroutine(_), .. @@ -2020,9 +2017,9 @@ impl<'db> Interner for DbInterner<'db> { // FIXME: Make this a query? I don't believe this can be accessed from bodies other than // the current infer query, except with revealed opaques - is it rare enough to not matter? let InternedCoroutine(owner, expr_id) = def_id.0.loc(self.db); - let body = self.db.body(owner); + let store = ExpressionStore::of(self.db, owner); matches!( - body[expr_id], + store[expr_id], hir_def::hir::Expr::Closure { closure_kind: hir_def::hir::ClosureKind::Async, .. } | hir_def::hir::Expr::Async { .. } ) @@ -2143,7 +2140,7 @@ impl<'db> Interner for DbInterner<'db> { crate::opaques::opaque_types_defined_by(self.db, def_id, &mut result); // Collect coroutines. - let body = self.db.body(def_id); + let body = Body::of(self.db, def_id); body.exprs().for_each(|(expr_id, expr)| { if matches!( expr, @@ -2154,8 +2151,10 @@ impl<'db> Interner for DbInterner<'db> { .. } ) { - let coroutine = - InternedCoroutineId::new(self.db, InternedCoroutine(def_id, expr_id)); + let coroutine = InternedCoroutineId::new( + self.db, + InternedCoroutine(ExpressionStoreOwnerId::Body(def_id), expr_id), + ); result.push(coroutine.into()); } }); @@ -2184,7 +2183,7 @@ impl<'db> Interner for DbInterner<'db> { CallableDefId::FunctionId(id) => id, _ => return false, }; - self.db().function_signature(id).flags.contains(FnFlags::CONST) + FunctionSignature::of(self.db(), id).flags.contains(FnFlags::CONST) } fn impl_is_const(self, _def_id: Self::ImplId) -> bool { @@ -2232,11 +2231,11 @@ impl<'db> Interner for DbInterner<'db> { } fn trait_is_coinductive(self, trait_: Self::TraitId) -> bool { - self.db().trait_signature(trait_.0).flags.contains(TraitFlags::COINDUCTIVE) + TraitSignature::of(self.db(), trait_.0).flags.contains(TraitFlags::COINDUCTIVE) } fn trait_is_unsafe(self, trait_: Self::TraitId) -> bool { - self.db().trait_signature(trait_.0).flags.contains(TraitFlags::UNSAFE) + TraitSignature::of(self.db(), trait_.0).flags.contains(TraitFlags::UNSAFE) } fn impl_self_is_guaranteed_unsized(self, _def_id: Self::ImplId) -> bool { diff --git a/crates/hir-ty/src/next_solver/ir_print.rs b/crates/hir-ty/src/next_solver/ir_print.rs index 998aab5a3f..e0732b3473 100644 --- a/crates/hir-ty/src/next_solver/ir_print.rs +++ b/crates/hir-ty/src/next_solver/ir_print.rs @@ -1,7 +1,6 @@ //! Things related to IR printing in the next-trait-solver. -use std::any::type_name_of_val; - +use hir_def::signatures::{TraitSignature, TypeAliasSignature}; use rustc_type_ir::{self as ty, ir_print::IrPrint}; use super::SolverDefId; @@ -16,7 +15,7 @@ impl<'db> IrPrint<ty::AliasTy<Self>> for DbInterner<'db> { crate::with_attached_db(|db| match t.def_id { SolverDefId::TypeAliasId(id) => fmt.write_str(&format!( "AliasTy({:?}[{:?}])", - db.type_alias_signature(id).name.as_str(), + TypeAliasSignature::of(db, id).name.as_str(), t.args )), SolverDefId::InternedOpaqueTyId(id) => { @@ -36,7 +35,7 @@ impl<'db> IrPrint<ty::AliasTerm<Self>> for DbInterner<'db> { crate::with_attached_db(|db| match t.def_id { SolverDefId::TypeAliasId(id) => fmt.write_str(&format!( "AliasTerm({:?}[{:?}])", - db.type_alias_signature(id).name.as_str(), + TypeAliasSignature::of(db, id).name.as_str(), t.args )), SolverDefId::InternedOpaqueTyId(id) => { @@ -60,13 +59,13 @@ impl<'db> IrPrint<ty::TraitRef<Self>> for DbInterner<'db> { fmt.write_str(&format!( "{:?}: {}", self_ty, - db.trait_signature(trait_).name.as_str() + TraitSignature::of(db, trait_).name.as_str() )) } else { fmt.write_str(&format!( "{:?}: {}<{:?}>", self_ty, - db.trait_signature(trait_).name.as_str(), + TraitSignature::of(db, trait_).name.as_str(), trait_args )) } @@ -82,7 +81,10 @@ impl<'db> IrPrint<ty::TraitPredicate<Self>> for DbInterner<'db> { t: &ty::TraitPredicate<Self>, fmt: &mut std::fmt::Formatter<'_>, ) -> std::fmt::Result { - fmt.write_str(&format!("TODO: {:?}", type_name_of_val(t))) + match t.polarity { + ty::PredicatePolarity::Positive => write!(fmt, "{:?}", t.trait_ref), + ty::PredicatePolarity::Negative => write!(fmt, "!{:?}", t.trait_ref), + } } } impl<'db> IrPrint<rustc_type_ir::HostEffectPredicate<Self>> for DbInterner<'db> { @@ -97,7 +99,11 @@ impl<'db> IrPrint<rustc_type_ir::HostEffectPredicate<Self>> for DbInterner<'db> t: &rustc_type_ir::HostEffectPredicate<Self>, fmt: &mut std::fmt::Formatter<'_>, ) -> std::fmt::Result { - fmt.write_str(&format!("TODO: {:?}", type_name_of_val(t))) + let prefix = match t.constness { + ty::BoundConstness::Const => "const", + ty::BoundConstness::Maybe => "[const]", + }; + write!(fmt, "{prefix} {:?}", t.trait_ref) } } impl<'db> IrPrint<ty::ExistentialTraitRef<Self>> for DbInterner<'db> { @@ -116,7 +122,7 @@ impl<'db> IrPrint<ty::ExistentialTraitRef<Self>> for DbInterner<'db> { let trait_ = t.def_id.0; fmt.write_str(&format!( "ExistentialTraitRef({:?}[{:?}])", - db.trait_signature(trait_).name.as_str(), + TraitSignature::of(db, trait_).name.as_str(), t.args )) }) @@ -141,7 +147,7 @@ impl<'db> IrPrint<ty::ExistentialProjection<Self>> for DbInterner<'db> { }; fmt.write_str(&format!( "ExistentialProjection(({:?}[{:?}]) -> {:?})", - db.type_alias_signature(id).name.as_str(), + TypeAliasSignature::of(db, id).name.as_str(), t.args, t.term )) @@ -167,7 +173,7 @@ impl<'db> IrPrint<ty::ProjectionPredicate<Self>> for DbInterner<'db> { }; fmt.write_str(&format!( "ProjectionPredicate(({:?}[{:?}]) -> {:?})", - db.type_alias_signature(id).name.as_str(), + TypeAliasSignature::of(db, id).name.as_str(), t.projection_term.args, t.term )) @@ -183,7 +189,7 @@ impl<'db> IrPrint<ty::NormalizesTo<Self>> for DbInterner<'db> { t: &ty::NormalizesTo<Self>, fmt: &mut std::fmt::Formatter<'_>, ) -> std::fmt::Result { - fmt.write_str(&format!("TODO: {:?}", type_name_of_val(t))) + write!(fmt, "NormalizesTo({} -> {:?})", t.alias, t.term) } } impl<'db> IrPrint<ty::SubtypePredicate<Self>> for DbInterner<'db> { @@ -198,7 +204,7 @@ impl<'db> IrPrint<ty::SubtypePredicate<Self>> for DbInterner<'db> { t: &ty::SubtypePredicate<Self>, fmt: &mut std::fmt::Formatter<'_>, ) -> std::fmt::Result { - fmt.write_str(&format!("TODO: {:?}", type_name_of_val(t))) + write!(fmt, "{:?} <: {:?}", t.a, t.b) } } impl<'db> IrPrint<ty::CoercePredicate<Self>> for DbInterner<'db> { @@ -210,7 +216,7 @@ impl<'db> IrPrint<ty::CoercePredicate<Self>> for DbInterner<'db> { t: &ty::CoercePredicate<Self>, fmt: &mut std::fmt::Formatter<'_>, ) -> std::fmt::Result { - fmt.write_str(&format!("TODO: {:?}", type_name_of_val(t))) + write!(fmt, "CoercePredicate({:?} -> {:?})", t.a, t.b) } } impl<'db> IrPrint<ty::FnSig<Self>> for DbInterner<'db> { @@ -219,7 +225,9 @@ impl<'db> IrPrint<ty::FnSig<Self>> for DbInterner<'db> { } fn print_debug(t: &ty::FnSig<Self>, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - fmt.write_str(&format!("TODO: {:?}", type_name_of_val(t))) + let tys = t.inputs_and_output.as_slice(); + let (output, inputs) = tys.split_last().unwrap(); + write!(fmt, "fn({:?}) -> {:?}", inputs, output) } } @@ -235,6 +243,10 @@ impl<'db> IrPrint<rustc_type_ir::PatternKind<DbInterner<'db>>> for DbInterner<'d t: &rustc_type_ir::PatternKind<DbInterner<'db>>, fmt: &mut std::fmt::Formatter<'_>, ) -> std::fmt::Result { - fmt.write_str(&format!("TODO: {:?}", type_name_of_val(t))) + match t { + ty::PatternKind::Range { start, end } => write!(fmt, "{:?}..={:?}", start, end), + ty::PatternKind::Or(list) => write!(fmt, "or({:?})", list), + ty::PatternKind::NotNull => fmt.write_str("!null"), + } } } diff --git a/crates/hir-ty/src/next_solver/predicate.rs b/crates/hir-ty/src/next_solver/predicate.rs index 6f4fae7073..8658d03a9e 100644 --- a/crates/hir-ty/src/next_solver/predicate.rs +++ b/crates/hir-ty/src/next_solver/predicate.rs @@ -714,9 +714,9 @@ impl<'db> rustc_type_ir::inherent::Predicate<DbInterner<'db>> for Predicate<'db> fn allow_normalization(self) -> bool { // TODO: this should probably live in rustc_type_ir match self.inner().as_ref().skip_binder() { - PredicateKind::Clause(ClauseKind::WellFormed(_)) - | PredicateKind::AliasRelate(..) - | PredicateKind::NormalizesTo(..) => false, + PredicateKind::Clause(ClauseKind::WellFormed(_)) | PredicateKind::AliasRelate(..) => { + false + } PredicateKind::Clause(ClauseKind::Trait(_)) | PredicateKind::Clause(ClauseKind::RegionOutlives(_)) | PredicateKind::Clause(ClauseKind::TypeOutlives(_)) @@ -729,6 +729,7 @@ impl<'db> rustc_type_ir::inherent::Predicate<DbInterner<'db>> for Predicate<'db> | PredicateKind::Coerce(_) | PredicateKind::Clause(ClauseKind::ConstEvaluatable(_)) | PredicateKind::ConstEquate(_, _) + | PredicateKind::NormalizesTo(..) | PredicateKind::Ambiguous => true, } } diff --git a/crates/hir-ty/src/next_solver/solver.rs b/crates/hir-ty/src/next_solver/solver.rs index 15d6e2e451..848bb110af 100644 --- a/crates/hir-ty/src/next_solver/solver.rs +++ b/crates/hir-ty/src/next_solver/solver.rs @@ -1,6 +1,9 @@ //! Defining `SolverContext` for next-trait-solver. -use hir_def::{AssocItemId, GeneralConstId}; +use hir_def::{ + AssocItemId, GeneralConstId, + signatures::{ConstSignature, TypeAliasSignature}, +}; use rustc_next_trait_solver::delegate::SolverDelegate; use rustc_type_ir::{ AliasTyKind, GenericArgKind, InferCtxtLike, Interner, PredicatePolarity, TypeFlags, @@ -18,7 +21,7 @@ use crate::next_solver::{ }; use super::{ - DbInterner, ErrorGuaranteed, GenericArg, SolverDefId, Span, + Const, DbInterner, ErrorGuaranteed, GenericArg, SolverDefId, Span, infer::{DbInternerInferExt, InferCtxt, canonical::instantiate::CanonicalExt}, }; @@ -181,52 +184,53 @@ impl<'db> SolverDelegate for SolverContext<'db> { return Ok(None); }; let impl_items = impl_id.impl_items(self.0.interner.db()); - let id = - match trait_assoc_def_id { - SolverDefId::TypeAliasId(trait_assoc_id) => { - let trait_assoc_data = self.0.interner.db.type_alias_signature(trait_assoc_id); - impl_items - .items - .iter() - .find_map(|(impl_assoc_name, impl_assoc_id)| { - if let AssocItemId::TypeAliasId(impl_assoc_id) = *impl_assoc_id - && *impl_assoc_name == trait_assoc_data.name - { - Some(impl_assoc_id) - } else { - None - } - }) - .or_else(|| { - if trait_assoc_data.ty.is_some() { Some(trait_assoc_id) } else { None } - }) - .map(SolverDefId::TypeAliasId) - } - SolverDefId::ConstId(trait_assoc_id) => { - let trait_assoc_data = self.0.interner.db.const_signature(trait_assoc_id); - let trait_assoc_name = trait_assoc_data - .name - .as_ref() - .expect("unnamed consts should not get passed to the solver"); - impl_items - .items - .iter() - .find_map(|(impl_assoc_name, impl_assoc_id)| { - if let AssocItemId::ConstId(impl_assoc_id) = *impl_assoc_id - && impl_assoc_name == trait_assoc_name - { - Some(impl_assoc_id) - } else { - None - } - }) - .or_else(|| { + let id = match trait_assoc_def_id { + SolverDefId::TypeAliasId(trait_assoc_id) => { + let trait_assoc_data = TypeAliasSignature::of(self.0.interner.db, trait_assoc_id); + impl_items + .items + .iter() + .find_map(|(impl_assoc_name, impl_assoc_id)| { + if let AssocItemId::TypeAliasId(impl_assoc_id) = *impl_assoc_id + && *impl_assoc_name == trait_assoc_data.name + { + Some(impl_assoc_id) + } else { + None + } + }) + .or_else(|| { + if trait_assoc_data.ty.is_some() { Some(trait_assoc_id) } else { None } + }) + .map(SolverDefId::TypeAliasId) + } + SolverDefId::ConstId(trait_assoc_id) => { + let trait_assoc_data = ConstSignature::of(self.0.interner.db, trait_assoc_id); + let trait_assoc_name = trait_assoc_data + .name + .as_ref() + .expect("unnamed consts should not get passed to the solver"); + impl_items + .items + .iter() + .find_map(|(impl_assoc_name, impl_assoc_id)| { + if let AssocItemId::ConstId(impl_assoc_id) = *impl_assoc_id + && impl_assoc_name == trait_assoc_name + { + Some(impl_assoc_id) + } else { + None + } + }) + .or_else( + || { if trait_assoc_data.has_body() { Some(trait_assoc_id) } else { None } - }) - .map(SolverDefId::ConstId) - } - _ => panic!("Unexpected SolverDefId"), - }; + }, + ) + .map(SolverDefId::ConstId) + } + _ => panic!("Unexpected SolverDefId"), + }; Ok(id) } @@ -256,6 +260,11 @@ impl<'db> SolverDelegate for SolverContext<'db> { let ec = self.cx().db.const_eval_static(c).ok()?; Some(ec) } + // TODO: Wire up const_eval_anon query in Phase 5. + // For now, return an error const so normalization resolves the + // unevaluated const to Error (matching the old behavior where + // complex expressions produced ConstKind::Error directly). + GeneralConstId::AnonConstId(_) => Some(Const::error(self.cx())), } } diff --git a/crates/hir-ty/src/next_solver/ty.rs b/crates/hir-ty/src/next_solver/ty.rs index 1173028a10..192cdb70ae 100644 --- a/crates/hir-ty/src/next_solver/ty.rs +++ b/crates/hir-ty/src/next_solver/ty.rs @@ -4,7 +4,7 @@ use std::ops::ControlFlow; use hir_def::{ AdtId, HasModule, TypeParamId, - hir::generics::{TypeOrConstParamData, TypeParamProvenance}, + hir::generics::{GenericParams, TypeOrConstParamData, TypeParamProvenance}, }; use hir_def::{TraitId, type_ref::Rawness}; use intern::{Interned, InternedRef, impl_internable}; @@ -690,7 +690,7 @@ impl<'db> Ty<'db> { ), TyKind::Param(param) => { // FIXME: We shouldn't use `param.id` here. - let generic_params = db.generic_params(param.id.parent()); + let generic_params = GenericParams::of(db, param.id.parent()); let param_data = &generic_params[param.id.local_id()]; match param_data { TypeOrConstParamData::TypeParamData(p) => match p.provenance { @@ -714,7 +714,7 @@ impl<'db> Ty<'db> { } TyKind::Coroutine(coroutine_id, _args) => { let InternedCoroutine(owner, _) = coroutine_id.0.loc(db); - let krate = owner.module(db).krate(db); + let krate = owner.krate(db); if let Some(future_trait) = hir_def::lang_item::lang_items(db, krate).Future { // This is only used by type walking. // Parameters will be walked outside, and projection predicate is not used. diff --git a/crates/hir-ty/src/next_solver/util.rs b/crates/hir-ty/src/next_solver/util.rs index 9a1b476976..c175062bda 100644 --- a/crates/hir-ty/src/next_solver/util.rs +++ b/crates/hir-ty/src/next_solver/util.rs @@ -13,13 +13,16 @@ use rustc_type_ir::{ solve::SizedTraitKind, }; -use crate::next_solver::{ - BoundConst, FxIndexMap, ParamEnv, Placeholder, PlaceholderConst, PlaceholderRegion, - PolyTraitRef, - infer::{ - InferCtxt, - traits::{Obligation, ObligationCause, PredicateObligation}, +use crate::{ + next_solver::{ + BoundConst, FxIndexMap, ParamEnv, Placeholder, PlaceholderConst, PlaceholderRegion, + PolyTraitRef, + infer::{ + InferCtxt, + traits::{Obligation, ObligationCause, PredicateObligation}, + }, }, + representability::Representability, }; use super::{ @@ -419,10 +422,18 @@ pub fn sizedness_constraint_for_ty<'db>( .next_back() .and_then(|ty| sizedness_constraint_for_ty(interner, sizedness, ty)), - Adt(adt, args) => adt.struct_tail_ty(interner).and_then(|tail_ty| { - let tail_ty = tail_ty.instantiate(interner, args); - sizedness_constraint_for_ty(interner, sizedness, tail_ty) - }), + Adt(adt, args) => { + if crate::representability::representability(interner.db, adt.def_id().0) + == Representability::Infinite + { + return None; + } + + adt.struct_tail_ty(interner).and_then(|tail_ty| { + let tail_ty = tail_ty.instantiate(interner, args); + sizedness_constraint_for_ty(interner, sizedness, tail_ty) + }) + } Placeholder(..) | Bound(..) | Infer(..) => { panic!("unexpected type `{ty:?}` in sizedness_constraint_for_ty") diff --git a/crates/hir-ty/src/opaques.rs b/crates/hir-ty/src/opaques.rs index 27ae5e39d5..ce93a33422 100644 --- a/crates/hir-ty/src/opaques.rs +++ b/crates/hir-ty/src/opaques.rs @@ -1,7 +1,8 @@ //! Handling of opaque types, detection of defining scope and hidden type. use hir_def::{ - AssocItemId, AssocItemLoc, DefWithBodyId, FunctionId, HasModule, ItemContainerId, TypeAliasId, + AssocItemId, AssocItemLoc, DefWithBodyId, ExpressionStoreOwnerId, FunctionId, GenericDefId, + HasModule, ItemContainerId, TypeAliasId, signatures::ImplSignature, }; use hir_expand::name::Name; use la_arena::ArenaMap; @@ -55,7 +56,7 @@ pub(crate) fn opaque_types_defined_by( }; let extend_with_atpit_from_container = |container| match container { ItemContainerId::ImplId(impl_id) => { - if db.impl_signature(impl_id).target_trait.is_some() { + if ImplSignature::of(db, impl_id).target_trait.is_some() { extend_with_atpit_from_assoc_items(&impl_id.impl_items(db).items); } } @@ -94,7 +95,7 @@ pub(crate) fn rpit_hidden_types<'db>( db: &'db dyn HirDatabase, function: FunctionId, ) -> ArenaMap<ImplTraitIdx, StoredEarlyBinder<StoredTy>> { - let infer = InferenceResult::for_body(db, function.into()); + let infer = InferenceResult::of(db, DefWithBodyId::from(function)); let mut result = ArenaMap::new(); for (opaque, hidden_type) in infer.return_position_impl_trait_types(db) { result.insert(opaque, StoredEarlyBinder::bind(hidden_type.store())); @@ -122,13 +123,14 @@ pub(crate) fn tait_hidden_types<'db>( let infcx = interner.infer_ctxt().build(TypingMode::non_body_analysis()); let mut ocx = ObligationCtxt::new(&infcx); let cause = ObligationCause::dummy(); - let param_env = db.trait_environment(type_alias.into()); + let param_env = + db.trait_environment(ExpressionStoreOwnerId::from(GenericDefId::from(type_alias))); let defining_bodies = tait_defining_bodies(db, &loc); let mut result = ArenaMap::with_capacity(taits_count); for defining_body in defining_bodies { - let infer = InferenceResult::for_body(db, defining_body); + let infer = InferenceResult::of(db, defining_body); for (&opaque, hidden_type) in &infer.type_of_opaque { let ImplTraitId::TypeAliasImplTrait(opaque_owner, opaque_idx) = opaque.loc(db) else { continue; @@ -195,7 +197,7 @@ fn tait_defining_bodies( }; match loc.container { ItemContainerId::ImplId(impl_id) => { - if db.impl_signature(impl_id).target_trait.is_some() { + if ImplSignature::of(db, impl_id).target_trait.is_some() { return from_assoc_items(&impl_id.impl_items(db).items); } } diff --git a/crates/hir-ty/src/representability.rs b/crates/hir-ty/src/representability.rs new file mode 100644 index 0000000000..bae204c4ef --- /dev/null +++ b/crates/hir-ty/src/representability.rs @@ -0,0 +1,131 @@ +//! Detecting whether a type is infinitely-sized. + +use hir_def::{AdtId, VariantId, hir::generics::GenericParams}; +use rustc_type_ir::inherent::{AdtDef, IntoKind}; + +use crate::{ + db::HirDatabase, + next_solver::{GenericArgKind, GenericArgs, Ty, TyKind}, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum Representability { + Representable, + Infinite, +} + +macro_rules! rtry { + ($e:expr) => { + match $e { + e @ Representability::Infinite => return e, + Representability::Representable => {} + } + }; +} + +#[salsa::tracked(cycle_result = representability_cycle)] +pub(crate) fn representability(db: &dyn HirDatabase, id: AdtId) -> Representability { + match id { + AdtId::StructId(id) => variant_representability(db, id.into()), + AdtId::UnionId(id) => variant_representability(db, id.into()), + AdtId::EnumId(id) => { + for &(variant, ..) in &id.enum_variants(db).variants { + rtry!(variant_representability(db, variant.into())); + } + Representability::Representable + } + } +} + +pub(crate) fn representability_cycle( + _db: &dyn HirDatabase, + _: salsa::Id, + _id: AdtId, +) -> Representability { + Representability::Infinite +} + +fn variant_representability(db: &dyn HirDatabase, id: VariantId) -> Representability { + for ty in db.field_types(id).values() { + rtry!(representability_ty(db, ty.get().instantiate_identity())); + } + Representability::Representable +} + +fn representability_ty<'db>(db: &'db dyn HirDatabase, ty: Ty<'db>) -> Representability { + match ty.kind() { + TyKind::Adt(adt_id, args) => representability_adt_ty(db, adt_id.def_id().0, args), + // FIXME(#11924) allow zero-length arrays? + TyKind::Array(ty, _) => representability_ty(db, ty), + TyKind::Tuple(tys) => { + for ty in tys { + rtry!(representability_ty(db, ty)); + } + Representability::Representable + } + _ => Representability::Representable, + } +} + +fn representability_adt_ty<'db>( + db: &'db dyn HirDatabase, + def_id: AdtId, + args: GenericArgs<'db>, +) -> Representability { + rtry!(representability(db, def_id)); + + // At this point, we know that the item of the ADT type is representable; + // but the type parameters may cause a cycle with an upstream type + let params_in_repr = params_in_repr(db, def_id); + for (i, arg) in args.iter().enumerate() { + if let GenericArgKind::Type(ty) = arg.kind() + && params_in_repr[i] + { + rtry!(representability_ty(db, ty)); + } + } + Representability::Representable +} + +fn params_in_repr(db: &dyn HirDatabase, def_id: AdtId) -> Box<[bool]> { + let generics = GenericParams::of(db, def_id.into()); + let mut params_in_repr = (0..generics.len_lifetimes() + generics.len_type_or_consts()) + .map(|_| false) + .collect::<Box<[bool]>>(); + let mut handle_variant = |variant| { + for field in db.field_types(variant).values() { + params_in_repr_ty(db, field.get().instantiate_identity(), &mut params_in_repr); + } + }; + match def_id { + AdtId::StructId(def_id) => handle_variant(def_id.into()), + AdtId::UnionId(def_id) => handle_variant(def_id.into()), + AdtId::EnumId(def_id) => { + for &(variant, ..) in &def_id.enum_variants(db).variants { + handle_variant(variant.into()); + } + } + } + params_in_repr +} + +fn params_in_repr_ty<'db>(db: &'db dyn HirDatabase, ty: Ty<'db>, params_in_repr: &mut [bool]) { + match ty.kind() { + TyKind::Adt(adt, args) => { + let inner_params_in_repr = self::params_in_repr(db, adt.def_id().0); + for (i, arg) in args.iter().enumerate() { + if let GenericArgKind::Type(ty) = arg.kind() + && inner_params_in_repr[i] + { + params_in_repr_ty(db, ty, params_in_repr); + } + } + } + TyKind::Array(ty, _) => params_in_repr_ty(db, ty, params_in_repr), + TyKind::Tuple(tys) => tys.iter().for_each(|ty| params_in_repr_ty(db, ty, params_in_repr)), + TyKind::Param(param) => { + params_in_repr[param.index as usize] = true; + } + _ => {} + } +} diff --git a/crates/hir-ty/src/specialization.rs b/crates/hir-ty/src/specialization.rs index d97a35549c..90cbcfea6a 100644 --- a/crates/hir-ty/src/specialization.rs +++ b/crates/hir-ty/src/specialization.rs @@ -1,6 +1,9 @@ //! Impl specialization related things -use hir_def::{HasModule, ImplId, nameres::crate_def_map}; +use hir_def::{ + ExpressionStoreOwnerId, GenericDefId, HasModule, ImplId, nameres::crate_def_map, + signatures::ImplSignature, +}; use intern::sym; use tracing::debug; @@ -45,11 +48,13 @@ fn specializes_query( specializing_impl_def_id: ImplId, parent_impl_def_id: ImplId, ) -> bool { - let trait_env = db.trait_environment(specializing_impl_def_id.into()); + let trait_env = db.trait_environment(ExpressionStoreOwnerId::from(GenericDefId::from( + specializing_impl_def_id, + ))); let interner = DbInterner::new_with(db, specializing_impl_def_id.krate(db)); - let specializing_impl_signature = db.impl_signature(specializing_impl_def_id); - let parent_impl_signature = db.impl_signature(parent_impl_def_id); + let specializing_impl_signature = ImplSignature::of(db, specializing_impl_def_id); + let parent_impl_signature = ImplSignature::of(db, parent_impl_def_id); // We determine whether there's a subset relationship by: // diff --git a/crates/hir-ty/src/target_feature.rs b/crates/hir-ty/src/target_feature.rs index 2bd675ba12..29a933f922 100644 --- a/crates/hir-ty/src/target_feature.rs +++ b/crates/hir-ty/src/target_feature.rs @@ -217,6 +217,7 @@ const TARGET_FEATURE_IMPLICATIONS_RAW: &[(&str, &[&str])] = &[ ("relaxed-simd", &["simd128"]), // BPF ("alu32", &[]), + ("allows-misaligned-mem-access", &[]), // CSKY ("10e60", &["7e10"]), ("2e3", &["e2"]), diff --git a/crates/hir-ty/src/tests.rs b/crates/hir-ty/src/tests.rs index 67ab89f5ec..430a570444 100644 --- a/crates/hir-ty/src/tests.rs +++ b/crates/hir-ty/src/tests.rs @@ -16,9 +16,9 @@ mod traits; use base_db::{Crate, SourceDatabase}; use expect_test::Expect; use hir_def::{ - AssocItemId, DefWithBodyId, HasModule, Lookup, ModuleDefId, ModuleId, SyntheticSyntax, - db::DefDatabase, - expr_store::{Body, BodySourceMap}, + AssocItemId, DefWithBodyId, GenericDefId, HasModule, Lookup, ModuleDefId, ModuleId, + SyntheticSyntax, + expr_store::{Body, BodySourceMap, ExpressionStore, ExpressionStoreSourceMap}, hir::{ExprId, Pat, PatId}, item_scope::ItemScope, nameres::DefMap, @@ -34,7 +34,6 @@ use syntax::{ ast::{self, AstNode, HasName}, }; use test_fixture::WithFixture; -use triomphe::Arc; use crate::{ InferenceResult, @@ -146,15 +145,15 @@ fn check_impl( let mut unexpected_type_mismatches = String::new(); for (def, krate) in defs { let display_target = DisplayTarget::from_crate(&db, krate); - let (body, body_source_map) = db.body_with_source_map(def); - let inference_result = InferenceResult::for_body(&db, def); + let (body, body_source_map) = Body::with_source_map(&db, def); + let inference_result = InferenceResult::of(&db, def); for (pat, ty) in inference_result.type_of_pat.iter() { let mut ty = ty.as_ref(); if let Pat::Bind { id, .. } = body[pat] { ty = inference_result.type_of_binding[id].as_ref(); } - let node = match pat_node(&body_source_map, pat, &db) { + let node = match pat_node(body_source_map, pat, &db) { Some(value) => value, None => continue, }; @@ -171,7 +170,7 @@ fn check_impl( for (expr, ty) in inference_result.type_of_expr.iter() { let ty = ty.as_ref(); - let node = match expr_node(&body_source_map, expr, &db) { + let node = match expr_node(body_source_map, expr, &db) { Some(value) => value, None => continue, }; @@ -202,9 +201,9 @@ fn check_impl( for (expr_or_pat, mismatch) in inference_result.type_mismatches() { let Some(node) = (match expr_or_pat { hir_def::hir::ExprOrPatId::ExprId(expr) => { - expr_node(&body_source_map, expr, &db) + expr_node(body_source_map, expr, &db) } - hir_def::hir::ExprOrPatId::PatId(pat) => pat_node(&body_source_map, pat, &db), + hir_def::hir::ExprOrPatId::PatId(pat) => pat_node(body_source_map, pat, &db), }) else { continue; }; @@ -223,7 +222,7 @@ fn check_impl( } for (type_ref, ty) in inference_result.placeholder_types() { - let node = match type_node(&body_source_map, type_ref, &db) { + let node = match type_node(body_source_map, type_ref, &db) { Some(value) => value, None => continue, }; @@ -321,16 +320,20 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String { let mut buf = String::new(); let mut infer_def = |inference_result: &InferenceResult, - body: Arc<Body>, - body_source_map: Arc<BodySourceMap>, + store: &ExpressionStore, + source_map: &ExpressionStoreSourceMap, + self_param: Option<( + hir_def::hir::BindingId, + Option<InFile<hir_def::expr_store::SelfParamPtr>>, + )>, krate: Crate| { let display_target = DisplayTarget::from_crate(&db, krate); let mut types: Vec<(InFile<SyntaxNode>, Ty<'_>)> = Vec::new(); let mut mismatches: Vec<(InFile<SyntaxNode>, &TypeMismatch)> = Vec::new(); - if let Some(self_param) = body.self_param { - let ty = &inference_result.type_of_binding[self_param]; - if let Some(syntax_ptr) = body_source_map.self_param_syntax() { + if let Some((binding_id, syntax_ptr)) = self_param { + let ty = &inference_result.type_of_binding[binding_id]; + if let Some(syntax_ptr) = syntax_ptr { let root = db.parse_or_expand(syntax_ptr.file_id); let node = syntax_ptr.map(|ptr| ptr.to_node(&root).syntax().clone()); types.push((node, ty.as_ref())); @@ -338,10 +341,10 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String { } for (pat, mut ty) in inference_result.type_of_pat.iter() { - if let Pat::Bind { id, .. } = body[pat] { + if let Pat::Bind { id, .. } = store[pat] { ty = &inference_result.type_of_binding[id]; } - let node = match body_source_map.pat_syntax(pat) { + let node = match source_map.pat_syntax(pat) { Ok(sp) => { let root = db.parse_or_expand(sp.file_id); sp.map(|ptr| ptr.to_node(&root).syntax().clone()) @@ -355,7 +358,7 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String { } for (expr, ty) in inference_result.type_of_expr.iter() { - let node = match body_source_map.expr_syntax(expr) { + let node = match source_map.expr_syntax(expr) { Ok(sp) => { let root = db.parse_or_expand(sp.file_id); sp.map(|ptr| ptr.to_node(&root).syntax().clone()) @@ -414,16 +417,56 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String { let def_map = module.def_map(&db); let mut defs: Vec<(DefWithBodyId, Crate)> = Vec::new(); + let mut generic_defs: Vec<(GenericDefId, Crate)> = Vec::new(); visit_module(&db, def_map, module, &mut |it| { - let def = match it { - ModuleDefId::FunctionId(it) => it.into(), - ModuleDefId::EnumVariantId(it) => it.into(), - ModuleDefId::ConstId(it) => it.into(), - ModuleDefId::StaticId(it) => it.into(), - _ => return, - }; - defs.push((def, module.krate(&db))) + let krate = module.krate(&db); + match it { + ModuleDefId::FunctionId(it) => { + defs.push((it.into(), krate)); + generic_defs.push((it.into(), krate)); + } + ModuleDefId::EnumVariantId(it) => { + defs.push((it.into(), krate)); + } + ModuleDefId::ConstId(it) => { + defs.push((it.into(), krate)); + generic_defs.push((it.into(), krate)); + } + ModuleDefId::StaticId(it) => { + defs.push((it.into(), krate)); + generic_defs.push((it.into(), krate)); + } + ModuleDefId::AdtId(it) => { + generic_defs.push((it.into(), krate)); + } + ModuleDefId::TraitId(it) => { + generic_defs.push((it.into(), krate)); + } + ModuleDefId::TypeAliasId(it) => { + generic_defs.push((it.into(), krate)); + } + _ => {} + } }); + // Also collect impls + for impl_id in def_map[module].scope.impls() { + generic_defs.push((impl_id.into(), module.krate(&db))); + let impl_data = impl_id.impl_items(&db); + for &(_, item) in impl_data.items.iter() { + match item { + AssocItemId::FunctionId(it) => { + generic_defs.push((it.into(), module.krate(&db))); + } + AssocItemId::ConstId(it) => { + generic_defs.push((it.into(), module.krate(&db))); + } + AssocItemId::TypeAliasId(it) => { + generic_defs.push((it.into(), module.krate(&db))); + } + } + } + } + defs.sort_by_key(|(def, _)| match def { DefWithBodyId::FunctionId(it) => { let loc = it.lookup(&db); @@ -443,9 +486,22 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String { } }); for (def, krate) in defs { - let (body, source_map) = db.body_with_source_map(def); - let infer = InferenceResult::for_body(&db, def); - infer_def(infer, body, source_map, krate); + let (body, source_map) = Body::with_source_map(&db, def); + let infer = InferenceResult::of(&db, def); + let self_param = body.self_param.map(|id| (id, source_map.self_param_syntax())); + infer_def(infer, body, source_map, self_param, krate); + } + + // Also infer signature const expressions (array lengths, const generic args, etc.) + generic_defs.dedup(); + for (def, krate) in generic_defs { + let (store, source_map) = ExpressionStore::with_source_map(&db, def.into()); + // Skip if there are no const expressions in the signature + if store.const_expr_origins().is_empty() { + continue; + } + let infer = InferenceResult::of(&db, def); + infer_def(infer, store, source_map, None, krate); } buf.truncate(buf.trim_end().len()); @@ -465,14 +521,14 @@ pub(crate) fn visit_module( for &(_, item) in impl_data.items.iter() { match item { AssocItemId::FunctionId(it) => { - let body = db.body(it.into()); + let body = Body::of(db, it.into()); cb(it.into()); - visit_body(db, &body, cb); + visit_body(db, body, cb); } AssocItemId::ConstId(it) => { - let body = db.body(it.into()); + let body = Body::of(db, it.into()); cb(it.into()); - visit_body(db, &body, cb); + visit_body(db, body, cb); } AssocItemId::TypeAliasId(it) => { cb(it.into()); @@ -491,22 +547,22 @@ pub(crate) fn visit_module( cb(decl); match decl { ModuleDefId::FunctionId(it) => { - let body = db.body(it.into()); - visit_body(db, &body, cb); + let body = Body::of(db, it.into()); + visit_body(db, body, cb); } ModuleDefId::ConstId(it) => { - let body = db.body(it.into()); - visit_body(db, &body, cb); + let body = Body::of(db, it.into()); + visit_body(db, body, cb); } ModuleDefId::StaticId(it) => { - let body = db.body(it.into()); - visit_body(db, &body, cb); + let body = Body::of(db, it.into()); + visit_body(db, body, cb); } ModuleDefId::AdtId(hir_def::AdtId::EnumId(it)) => { it.enum_variants(db).variants.iter().for_each(|&(it, _, _)| { - let body = db.body(it.into()); + let body = Body::of(db, it.into()); cb(it.into()); - visit_body(db, &body, cb); + visit_body(db, body, cb); }); } ModuleDefId::TraitId(it) => { @@ -596,16 +652,14 @@ fn salsa_bug() { let module = db.module_for_file(pos.file_id.file_id(&db)); let crate_def_map = module.def_map(&db); visit_module(&db, crate_def_map, module, &mut |def| { - InferenceResult::for_body( - &db, - match def { - ModuleDefId::FunctionId(it) => it.into(), - ModuleDefId::EnumVariantId(it) => it.into(), - ModuleDefId::ConstId(it) => it.into(), - ModuleDefId::StaticId(it) => it.into(), - _ => return, - }, - ); + let body_def: DefWithBodyId = match def { + ModuleDefId::FunctionId(it) => it.into(), + ModuleDefId::EnumVariantId(it) => it.into(), + ModuleDefId::ConstId(it) => it.into(), + ModuleDefId::StaticId(it) => it.into(), + _ => return, + }; + InferenceResult::of(&db, body_def); }); }); @@ -640,16 +694,14 @@ fn salsa_bug() { let module = db.module_for_file(pos.file_id.file_id(&db)); let crate_def_map = module.def_map(&db); visit_module(&db, crate_def_map, module, &mut |def| { - InferenceResult::for_body( - &db, - match def { - ModuleDefId::FunctionId(it) => it.into(), - ModuleDefId::EnumVariantId(it) => it.into(), - ModuleDefId::ConstId(it) => it.into(), - ModuleDefId::StaticId(it) => it.into(), - _ => return, - }, - ); + let body_def: DefWithBodyId = match def { + ModuleDefId::FunctionId(it) => it.into(), + ModuleDefId::EnumVariantId(it) => it.into(), + ModuleDefId::ConstId(it) => it.into(), + ModuleDefId::StaticId(it) => it.into(), + _ => return, + }; + InferenceResult::of(&db, body_def); }); }) } diff --git a/crates/hir-ty/src/tests/closure_captures.rs b/crates/hir-ty/src/tests/closure_captures.rs index f089120cd7..9e68756821 100644 --- a/crates/hir-ty/src/tests/closure_captures.rs +++ b/crates/hir-ty/src/tests/closure_captures.rs @@ -1,5 +1,8 @@ use expect_test::{Expect, expect}; -use hir_def::db::DefDatabase; +use hir_def::{ + DefWithBodyId, + expr_store::{Body, ExpressionStore}, +}; use hir_expand::{HirFileId, files::InFileWrapper}; use itertools::Itertools; use span::TextRange; @@ -28,19 +31,20 @@ fn check_closure_captures(#[rust_analyzer::rust_fixture] ra_fixture: &str, expec let mut captures_info = Vec::new(); for def in defs { - let def = match def { + let def: DefWithBodyId = match def { hir_def::ModuleDefId::FunctionId(it) => it.into(), hir_def::ModuleDefId::EnumVariantId(it) => it.into(), hir_def::ModuleDefId::ConstId(it) => it.into(), hir_def::ModuleDefId::StaticId(it) => it.into(), _ => continue, }; - let infer = InferenceResult::for_body(&db, def); + let infer = InferenceResult::of(&db, def); let db = &db; captures_info.extend(infer.closure_info.iter().flat_map( |(closure_id, (captures, _))| { let closure = db.lookup_intern_closure(*closure_id); - let source_map = db.body_with_source_map(closure.0).1; + let body_owner = closure.0; + let source_map = ExpressionStore::with_source_map(db, body_owner).1; let closure_text_range = source_map .expr_syntax(closure.1) .expect("failed to map closure to SyntaxNode") @@ -56,7 +60,8 @@ fn check_closure_captures(#[rust_analyzer::rust_fixture] ra_fixture: &str, expec } // FIXME: Deduplicate this with hir::Local::sources(). - let (body, source_map) = db.body_with_source_map(closure.0); + let (body, source_map) = + Body::with_source_map(db, body_owner.as_def_with_body().unwrap()); let local_text_range = match body.self_param.zip(source_map.self_param_syntax()) { Some((param, source)) if param == capture.local() => { @@ -71,7 +76,7 @@ fn check_closure_captures(#[rust_analyzer::rust_fixture] ra_fixture: &str, expec .map(|it| format!("{it:?}")) .join(", "), }; - let place = capture.display_place(closure.0, db); + let place = capture.display_place(body_owner, db); let capture_ty = capture .ty .get() diff --git a/crates/hir-ty/src/tests/coercion.rs b/crates/hir-ty/src/tests/coercion.rs index 36630ab587..438699b409 100644 --- a/crates/hir-ty/src/tests/coercion.rs +++ b/crates/hir-ty/src/tests/coercion.rs @@ -608,7 +608,7 @@ trait Foo {} fn test(f: impl Foo, g: &(impl Foo + ?Sized)) { let _: &dyn Foo = &f; let _: &dyn Foo = g; - //^ expected &'? (dyn Foo + 'static), got &'? impl Foo + ?Sized + //^ expected &'? (dyn Foo + 'static), got &'? (impl Foo + ?Sized) } "#, ); diff --git a/crates/hir-ty/src/tests/display_source_code.rs b/crates/hir-ty/src/tests/display_source_code.rs index dc3869930d..37da7fc875 100644 --- a/crates/hir-ty/src/tests/display_source_code.rs +++ b/crates/hir-ty/src/tests/display_source_code.rs @@ -111,7 +111,7 @@ fn test( b; //^ impl Foo c; - //^ &impl Foo + ?Sized + //^ &(impl Foo + ?Sized) d; //^ S<impl Foo> ref_any; @@ -192,7 +192,7 @@ fn test( b; //^ fn(impl Foo) -> impl Foo c; -} //^ fn(&impl Foo + ?Sized) -> &impl Foo + ?Sized +} //^ fn(&(impl Foo + ?Sized)) -> &(impl Foo + ?Sized) "#, ); } diff --git a/crates/hir-ty/src/tests/incremental.rs b/crates/hir-ty/src/tests/incremental.rs index cf7ff6f7ec..e806999cb4 100644 --- a/crates/hir-ty/src/tests/incremental.rs +++ b/crates/hir-ty/src/tests/incremental.rs @@ -24,30 +24,31 @@ fn foo() -> i32 { let crate_def_map = module.def_map(&db); visit_module(&db, crate_def_map, module, &mut |def| { if let ModuleDefId::FunctionId(it) = def { - InferenceResult::for_body(&db, it.into()); + InferenceResult::of(&db, DefWithBodyId::from(it)); } }); }, &[("InferenceResult::for_body_", 1)], expect_test::expect![[r#" [ + "source_root_crates_shim", "crate_local_def_map", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "InferenceResult::for_body_", - "function_signature_shim", - "function_signature_with_source_map_shim", + "FunctionSignature::of_", + "FunctionSignature::with_source_map_", "AttrFlags::query_", - "body_shim", - "body_with_source_map_shim", + "Body::of_", + "Body::with_source_map_", "trait_environment_query", "lang_items", "crate_lang_items", "GenericPredicates::query_with_diagnostics_", "ImplTraits::return_type_impl_traits_", - "expr_scopes_shim", + "ExprScopes::body_expr_scopes_", ] "#]], ); @@ -68,7 +69,7 @@ fn foo() -> i32 { let crate_def_map = module.def_map(&db); visit_module(&db, crate_def_map, module, &mut |def| { if let ModuleDefId::FunctionId(it) = def { - InferenceResult::for_body(&db, it.into()); + InferenceResult::of(&db, DefWithBodyId::from(it)); } }); }, @@ -76,14 +77,14 @@ fn foo() -> i32 { expect_test::expect![[r#" [ "parse_shim", - "ast_id_map_shim", + "ast_id_map", "file_item_tree_query", "real_span_map_shim", "AttrFlags::query_", - "function_signature_with_source_map_shim", - "function_signature_shim", - "body_with_source_map_shim", - "body_shim", + "FunctionSignature::with_source_map_", + "FunctionSignature::of_", + "Body::with_source_map_", + "Body::of_", ] "#]], ); @@ -111,50 +112,51 @@ fn baz() -> i32 { let crate_def_map = module.def_map(&db); visit_module(&db, crate_def_map, module, &mut |def| { if let ModuleDefId::FunctionId(it) = def { - InferenceResult::for_body(&db, it.into()); + InferenceResult::of(&db, DefWithBodyId::from(it)); } }); }, &[("InferenceResult::for_body_", 3)], expect_test::expect![[r#" [ + "source_root_crates_shim", "crate_local_def_map", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "InferenceResult::for_body_", - "function_signature_shim", - "function_signature_with_source_map_shim", + "FunctionSignature::of_", + "FunctionSignature::with_source_map_", "AttrFlags::query_", - "body_shim", - "body_with_source_map_shim", + "Body::of_", + "Body::with_source_map_", "trait_environment_query", "lang_items", "crate_lang_items", "GenericPredicates::query_with_diagnostics_", "ImplTraits::return_type_impl_traits_", - "expr_scopes_shim", + "ExprScopes::body_expr_scopes_", "InferenceResult::for_body_", - "function_signature_shim", - "function_signature_with_source_map_shim", + "FunctionSignature::of_", + "FunctionSignature::with_source_map_", "AttrFlags::query_", - "body_shim", - "body_with_source_map_shim", + "Body::of_", + "Body::with_source_map_", "trait_environment_query", "GenericPredicates::query_with_diagnostics_", "ImplTraits::return_type_impl_traits_", - "expr_scopes_shim", + "ExprScopes::body_expr_scopes_", "InferenceResult::for_body_", - "function_signature_shim", - "function_signature_with_source_map_shim", + "FunctionSignature::of_", + "FunctionSignature::with_source_map_", "AttrFlags::query_", - "body_shim", - "body_with_source_map_shim", + "Body::of_", + "Body::with_source_map_", "trait_environment_query", "GenericPredicates::query_with_diagnostics_", "ImplTraits::return_type_impl_traits_", - "expr_scopes_shim", + "ExprScopes::body_expr_scopes_", ] "#]], ); @@ -180,7 +182,7 @@ fn baz() -> i32 { let crate_def_map = module.def_map(&db); visit_module(&db, crate_def_map, module, &mut |def| { if let ModuleDefId::FunctionId(it) = def { - InferenceResult::for_body(&db, it.into()); + InferenceResult::of(&db, DefWithBodyId::from(it)); } }); }, @@ -188,26 +190,26 @@ fn baz() -> i32 { expect_test::expect![[r#" [ "parse_shim", - "ast_id_map_shim", + "ast_id_map", "file_item_tree_query", "real_span_map_shim", "AttrFlags::query_", - "function_signature_with_source_map_shim", - "function_signature_shim", - "body_with_source_map_shim", - "body_shim", + "FunctionSignature::with_source_map_", + "FunctionSignature::of_", + "Body::with_source_map_", + "Body::of_", "AttrFlags::query_", - "function_signature_with_source_map_shim", - "function_signature_shim", - "body_with_source_map_shim", - "body_shim", + "FunctionSignature::with_source_map_", + "FunctionSignature::of_", + "Body::with_source_map_", + "Body::of_", "InferenceResult::for_body_", - "expr_scopes_shim", + "ExprScopes::body_expr_scopes_", "AttrFlags::query_", - "function_signature_with_source_map_shim", - "function_signature_shim", - "body_with_source_map_shim", - "body_shim", + "FunctionSignature::with_source_map_", + "FunctionSignature::of_", + "Body::with_source_map_", + "Body::of_", ] "#]], ); @@ -237,9 +239,10 @@ $0", &[("TraitImpls::for_crate_", 1)], expect_test::expect![[r#" [ + "source_root_crates_shim", "crate_local_def_map", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "TraitImpls::for_crate_", @@ -276,7 +279,7 @@ pub struct NewStruct { expect_test::expect![[r#" [ "parse_shim", - "ast_id_map_shim", + "ast_id_map", "file_item_tree_query", "real_span_map_shim", "crate_local_def_map", @@ -311,9 +314,10 @@ $0", &[("TraitImpls::for_crate_", 1)], expect_test::expect![[r#" [ + "source_root_crates_shim", "crate_local_def_map", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "TraitImpls::for_crate_", @@ -351,7 +355,7 @@ pub enum SomeEnum { expect_test::expect![[r#" [ "parse_shim", - "ast_id_map_shim", + "ast_id_map", "file_item_tree_query", "real_span_map_shim", "crate_local_def_map", @@ -386,9 +390,10 @@ $0", &[("TraitImpls::for_crate_", 1)], expect_test::expect![[r#" [ + "source_root_crates_shim", "crate_local_def_map", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "TraitImpls::for_crate_", @@ -423,7 +428,7 @@ fn bar() -> f32 { expect_test::expect![[r#" [ "parse_shim", - "ast_id_map_shim", + "ast_id_map", "file_item_tree_query", "real_span_map_shim", "crate_local_def_map", @@ -462,9 +467,10 @@ $0", &[("TraitImpls::for_crate_", 1)], expect_test::expect![[r#" [ + "source_root_crates_shim", "crate_local_def_map", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "TraitImpls::for_crate_", @@ -507,7 +513,7 @@ impl SomeStruct { expect_test::expect![[r#" [ "parse_shim", - "ast_id_map_shim", + "ast_id_map", "file_item_tree_query", "real_span_map_shim", "crate_local_def_map", @@ -556,31 +562,32 @@ fn main() { }); for def in defs { - let _inference_result = InferenceResult::for_body(&db, def); + let _inference_result = InferenceResult::of(&db, def); } }, &[("trait_solve_shim", 0)], expect_test::expect![[r#" [ + "source_root_crates_shim", "crate_local_def_map", "file_item_tree_query", - "ast_id_map_shim", + "ast_id_map", "parse_shim", "real_span_map_shim", "TraitItems::query_with_diagnostics_", - "body_shim", - "body_with_source_map_shim", + "Body::of_", + "Body::with_source_map_", "AttrFlags::query_", "ImplItems::of_", "InferenceResult::for_body_", - "trait_signature_shim", - "trait_signature_with_source_map_shim", + "TraitSignature::of_", + "TraitSignature::with_source_map_", "AttrFlags::query_", - "function_signature_shim", - "function_signature_with_source_map_shim", + "FunctionSignature::of_", + "FunctionSignature::with_source_map_", "AttrFlags::query_", - "body_shim", - "body_with_source_map_shim", + "Body::of_", + "Body::with_source_map_", "trait_environment_query", "lang_items", "crate_lang_items", @@ -588,14 +595,14 @@ fn main() { "GenericPredicates::query_with_diagnostics_", "ImplTraits::return_type_impl_traits_", "InferenceResult::for_body_", - "function_signature_shim", - "function_signature_with_source_map_shim", + "FunctionSignature::of_", + "FunctionSignature::with_source_map_", "trait_environment_query", "GenericPredicates::query_with_diagnostics_", "ImplTraits::return_type_impl_traits_", - "expr_scopes_shim", - "struct_signature_shim", - "struct_signature_with_source_map_shim", + "ExprScopes::body_expr_scopes_", + "StructSignature::of_", + "StructSignature::with_source_map_", "AttrFlags::query_", "GenericPredicates::query_with_diagnostics_", "value_ty_query", @@ -604,8 +611,8 @@ fn main() { "TraitImpls::for_crate_and_deps_", "TraitImpls::for_crate_", "impl_trait_with_diagnostics_query", - "impl_signature_shim", - "impl_signature_with_source_map_shim", + "ImplSignature::of_", + "ImplSignature::with_source_map_", "impl_self_ty_with_diagnostics_query", "AttrFlags::query_", "GenericPredicates::query_with_diagnostics_", @@ -651,47 +658,47 @@ fn main() { }); for def in defs { - let _inference_result = InferenceResult::for_body(&db, def); + let _inference_result = InferenceResult::of(&db, def); } }, &[("trait_solve_shim", 0)], expect_test::expect![[r#" [ "parse_shim", - "ast_id_map_shim", + "ast_id_map", "file_item_tree_query", "real_span_map_shim", "crate_local_def_map", "TraitItems::query_with_diagnostics_", - "body_with_source_map_shim", + "Body::with_source_map_", "AttrFlags::query_", - "body_shim", + "Body::of_", "ImplItems::of_", "InferenceResult::for_body_", "AttrFlags::query_", - "trait_signature_with_source_map_shim", + "TraitSignature::with_source_map_", "AttrFlags::query_", - "function_signature_with_source_map_shim", - "function_signature_shim", - "body_with_source_map_shim", - "body_shim", + "FunctionSignature::with_source_map_", + "FunctionSignature::of_", + "Body::with_source_map_", + "Body::of_", "crate_lang_items", "GenericPredicates::query_with_diagnostics_", "GenericPredicates::query_with_diagnostics_", "ImplTraits::return_type_impl_traits_", "InferenceResult::for_body_", - "function_signature_with_source_map_shim", + "FunctionSignature::with_source_map_", "GenericPredicates::query_with_diagnostics_", "ImplTraits::return_type_impl_traits_", - "expr_scopes_shim", - "struct_signature_with_source_map_shim", + "ExprScopes::body_expr_scopes_", + "StructSignature::with_source_map_", "AttrFlags::query_", "GenericPredicates::query_with_diagnostics_", "InherentImpls::for_crate_", "callable_item_signature_query", "TraitImpls::for_crate_", - "impl_signature_with_source_map_shim", - "impl_signature_shim", + "ImplSignature::with_source_map_", + "ImplSignature::of_", "impl_trait_with_diagnostics_query", "impl_self_ty_with_diagnostics_query", "AttrFlags::query_", @@ -709,6 +716,7 @@ fn execute_assert_events( ) { crate::attach_db(db, || { let (executed, events) = db.log_executed(f); + expect.assert_debug_eq(&executed); for (event, count) in required { let n = executed.iter().filter(|it| it.contains(event)).count(); assert_eq!( @@ -724,6 +732,5 @@ fn execute_assert_events( .collect::<Vec<_>>(), ); } - expect.assert_debug_eq(&executed); }); } diff --git a/crates/hir-ty/src/tests/macros.rs b/crates/hir-ty/src/tests/macros.rs index 2f41de64cb..28a688d4a3 100644 --- a/crates/hir-ty/src/tests/macros.rs +++ b/crates/hir-ty/src/tests/macros.rs @@ -9,16 +9,16 @@ use super::{check_infer, check_types}; fn cfg_impl_def() { check_types( r#" -//- /main.rs crate:main deps:foo cfg:test +//- /main.rs crate:main deps:foo cfg:some_cfg use foo::S as T; struct S; -#[cfg(test)] +#[cfg(some_cfg)] impl S { fn foo1(&self) -> i32 { 0 } } -#[cfg(not(test))] +#[cfg(not(some_cfg))] impl S { fn foo2(&self) -> i32 { 0 } } @@ -31,12 +31,12 @@ fn test() { //- /foo.rs crate:foo pub struct S; -#[cfg(not(test))] +#[cfg(not(some_cfg))] impl S { pub fn foo3(&self) -> i32 { 0 } } -#[cfg(test)] +#[cfg(some_cfg)] impl S { pub fn foo4(&self) -> i32 { 0 } } diff --git a/crates/hir-ty/src/tests/never_type.rs b/crates/hir-ty/src/tests/never_type.rs index 4d68179a88..993293bb56 100644 --- a/crates/hir-ty/src/tests/never_type.rs +++ b/crates/hir-ty/src/tests/never_type.rs @@ -761,6 +761,79 @@ fn coerce_ref_binding() -> ! { } #[test] +fn diverging_place_match_ref_mut() { + check_infer_with_mismatches( + r#" +//- minicore: sized +fn coerce_ref_mut_binding() -> ! { + unsafe { + let x: *mut ! = 0 as _; + let ref mut _x: () = *x; + } +} +"#, + expect![[r#" + 33..120 '{ ... } }': ! + 39..118 'unsafe... }': ! + 60..61 'x': *mut ! + 72..73 '0': i32 + 72..78 '0 as _': *mut ! + 92..102 'ref mut _x': &'? mut () + 109..111 '*x': ! + 110..111 'x': *mut ! + 109..111: expected (), got ! + "#]], + ) +} + +#[test] +fn assign_never_place_no_mismatch() { + check_no_mismatches( + r#" +//- minicore: sized +fn foo() { + unsafe { + let p: *mut ! = 0 as _; + let mut x: () = (); + x = *p; + } +} +"#, + ); +} + +#[test] +fn binop_rhs_never_place_diverges() { + check_no_mismatches( + r#" +//- minicore: sized, add +fn foo() -> i32 { + unsafe { + let p: *mut ! = 0 as _; + let mut x: i32 = 0; + x += *p; + } +} +"#, + ); +} + +#[test] +fn binop_lhs_never_place_diverges() { + check_no_mismatches( + r#" +//- minicore: sized, add +fn foo() { + unsafe { + let p: *mut ! = 0 as _; + *p += 1; + } +} +"#, + ); +} + +#[test] fn never_place_isnt_diverging() { check_infer_with_mismatches( r#" @@ -813,3 +886,18 @@ fn foo() { "#]], ); } + +#[test] +fn never_coercion_in_struct_update_syntax() { + check_no_mismatches( + r#" +struct Struct { + field: i32, +} + +fn example() -> Struct { + Struct { ..loop {} } +} + "#, + ); +} diff --git a/crates/hir-ty/src/tests/patterns.rs b/crates/hir-ty/src/tests/patterns.rs index 8c7d29f993..42dc074309 100644 --- a/crates/hir-ty/src/tests/patterns.rs +++ b/crates/hir-ty/src/tests/patterns.rs @@ -421,6 +421,8 @@ fn infer_pattern_match_byte_string_literal() { 254..256 '&v': &'? [u8; 3] 255..256 'v': [u8; 3] 257..259 '{}': () + 199..200 '3': usize + 62..63 'N': usize "#]], ); } diff --git a/crates/hir-ty/src/tests/regression.rs b/crates/hir-ty/src/tests/regression.rs index 4f1480c393..e4fc7e56c6 100644 --- a/crates/hir-ty/src/tests/regression.rs +++ b/crates/hir-ty/src/tests/regression.rs @@ -2645,3 +2645,199 @@ where "#, ); } + +#[test] +fn issue_21560() { + check_no_mismatches( + r#" +mod bindings { + use super::*; + pub type HRESULT = i32; +} +use bindings::*; + + +mod error { + use super::*; + pub fn nonzero_hresult(hr: HRESULT) -> crate::HRESULT { + hr + } +} +pub use error::*; + +mod hresult { + use super::*; + pub struct HRESULT(pub i32); +} +pub use hresult::HRESULT; + + "#, + ); +} + +#[test] +fn regression_21577() { + check_no_mismatches( + r#" +pub trait FilterT<F: FilterT<F, V = Self::V> = Self> { + type V; + + fn foo() {} +} + "#, + ); +} + +#[test] +fn regression_21605() { + check_infer( + r#" +//- minicore: fn, coerce_unsized, dispatch_from_dyn, iterator, iterators +pub struct Filter<'a, 'b, T> +where + T: 'b, + 'a: 'b, +{ + filter_fn: dyn Fn(&'a T) -> bool, + t: Option<T>, + b: &'b (), +} + +impl<'a, 'b, T> Filter<'a, 'b, T> +where + T: 'b, + 'a: 'b, +{ + pub fn new(filter_fn: dyn Fn(&T) -> bool) -> Self { + Self { + filter_fn: filter_fn, + t: None, + b: &(), + } + } +} + +pub trait FilterExt<T> { + type Output; + fn filter(&self, filter: &Filter<T>) -> Self::Output; +} + +impl<const N: usize, T> FilterExt<T> for [T; N] +where + T: IntoIterator, +{ + type Output = T; + fn filter(&self, filter: &Filter<T>) -> Self::Output { + let _ = self.into_iter().filter(filter.filter_fn); + loop {} + } +} +"#, + expect![[r#" + 214..223 'filter_fn': dyn Fn(&'? T) -> bool + 'static + 253..360 '{ ... }': Filter<'a, 'b, T> + 263..354 'Self {... }': Filter<'a, 'b, T> + 293..302 'filter_fn': dyn Fn(&'? T) -> bool + 'static + 319..323 'None': Option<T> + 340..343 '&()': &'? () + 341..343 '()': () + 421..425 'self': &'? Self + 427..433 'filter': &'? Filter<'?, '?, T> + 580..584 'self': &'? [T; N] + 586..592 'filter': &'? Filter<'?, '?, T> + 622..704 '{ ... }': T + 636..637 '_': Filter<Iter<'?, T>, dyn Fn(&'? T) -> bool + '?> + 640..644 'self': &'? [T; N] + 640..656 'self.i...iter()': Iter<'?, T> + 640..681 'self.i...er_fn)': Filter<Iter<'?, T>, dyn Fn(&'? T) -> bool + '?> + 664..670 'filter': &'? Filter<'?, '?, T> + 664..680 'filter...ter_fn': dyn Fn(&'? T) -> bool + 'static + 691..698 'loop {}': ! + 696..698 '{}': () + 512..513 'N': usize + "#]], + ); +} + +#[test] +fn extern_fns_cannot_have_param_patterns() { + check_no_mismatches( + r#" +pub(crate) struct Builder<'a>(&'a ()); + +unsafe extern "C" { + pub(crate) fn foo<'a>(Builder: &Builder<'a>); +} + "#, + ); +} + +#[test] +fn infinitely_sized_type() { + check_infer( + r#" +//- minicore: sized + +pub struct Recursive { + pub content: Recursive, +} + +fn is_sized<T: Sized>() {} + +fn foo() { + is_sized::<Recursive>(); +} + "#, + expect![[r#" + 79..81 '{}': () + 92..124 '{ ...>(); }': () + 98..119 'is_siz...rsive>': fn is_sized<Recursive>() + 98..121 'is_siz...ive>()': () + "#]], + ); +} + +#[test] +fn regression_21742() { + check_no_mismatches( + r#" +pub trait IntoIterator { + type Item; +} + +pub trait Collection: IntoIterator<Item = <Self as Collection>::Item> { + type Item; + fn contains(&self, item: &<Self as Collection>::Item); +} + +fn contains_0<S: Collection<Item = i32>>(points: &S) { + points.contains(&0) +} + "#, + ); +} + +#[test] +fn regression_21773() { + check_no_mismatches( + r#" +trait Neg { + type Output; +} + +trait Abs: Neg { + fn abs(&self) -> Self::Output; +} + +trait SelfAbs: Abs + Neg +where + Self::Output: Neg<Output = Self::Output> + Abs, +{ +} + +fn wrapped_abs<T: SelfAbs<Output = T>>(v: T) -> T { + v.abs() +} + "#, + ); +} diff --git a/crates/hir-ty/src/tests/regression/new_solver.rs b/crates/hir-ty/src/tests/regression/new_solver.rs index f47a26d429..e6b3244cda 100644 --- a/crates/hir-ty/src/tests/regression/new_solver.rs +++ b/crates/hir-ty/src/tests/regression/new_solver.rs @@ -34,6 +34,7 @@ impl Space for [u8; 1] { 223..227 'iter': IntoIter<u8> 230..231 'a': Vec<u8> 230..243 'a.into_iter()': IntoIter<u8> + 322..323 '1': usize "#]], ); } @@ -472,6 +473,8 @@ fn foo() { 249..257 'to_bytes': fn to_bytes() -> [u8; _] 249..259 'to_bytes()': [u8; _] 249..268 'to_byt..._vec()': Vec<<[u8; _] as Foo>::Item> + 205..206 '_': usize + 156..157 'N': usize "#]], ); } @@ -541,6 +544,11 @@ fn test_at_most() { 617..620 'num': Between<0, 1, char> 623..626 ''9'': char 623..641 ''9'.at...:<1>()': Between<0, 1, char> + 320..335 '{ Consts::MAX }': usize + 322..333 'Consts::MAX': usize + 421..422 '0': i32 + 144..159 '{ Consts::MAX }': usize + 146..157 'Consts::MAX': usize "#]], ); } diff --git a/crates/hir-ty/src/tests/simple.rs b/crates/hir-ty/src/tests/simple.rs index 98503452d3..3ea21f8265 100644 --- a/crates/hir-ty/src/tests/simple.rs +++ b/crates/hir-ty/src/tests/simple.rs @@ -2152,10 +2152,11 @@ async fn main() { let z: core::ops::ControlFlow<(), _> = try { () }; let w = const { 92 }; let t = 'a: { 92 }; + let u = try bikeshed core::ops::ControlFlow<(), _> { () }; } "#, expect![[r#" - 16..193 '{ ...2 }; }': () + 16..256 '{ ...) }; }': () 26..27 'x': i32 30..43 'unsafe { 92 }': i32 39..41 '92': i32 @@ -2176,6 +2177,13 @@ async fn main() { 176..177 't': i32 180..190 ''a: { 92 }': i32 186..188 '92': i32 + 200..201 'u': ControlFlow<(), ()> + 204..253 'try bi...{ () }': ControlFlow<(), ()> + 204..253 'try bi...{ () }': fn from_output<ControlFlow<(), ()>>(<ControlFlow<(), ()> as Try>::Output) -> ControlFlow<(), ()> + 204..253 'try bi...{ () }': ControlFlow<(), ()> + 204..253 'try bi...{ () }': ControlFlow<(), ()> + 204..253 'try bi...{ () }': ControlFlow<(), ()> + 249..251 '()': () "#]], ) } @@ -3860,6 +3868,8 @@ fn main() { 208..209 'c': u8 213..214 'a': A 213..221 'a.into()': [u8; 2] + 33..34 '2': usize + 111..112 '3': usize "#]], ); } @@ -4053,6 +4063,88 @@ fn foo() { 248..282 'LazyLo..._LOCK)': &'? [u32; _] 264..281 '&VALUE...Y_LOCK': &'? LazyLock<[u32; _]> 265..281 'VALUES...Y_LOCK': LazyLock<[u32; _]> + 197..202 '{ 0 }': usize + 199..200 '0': usize + "#]], + ); +} + +#[test] +fn include_bytes_len_mismatch() { + check_no_mismatches( + r#" +//- minicore: include_bytes +static S: &[u8; 158] = include_bytes!("/foo/bar/baz.txt"); + "#, + ); +} + +#[test] +fn proc_macros_are_functions_inside_defining_crate_and_macros_outside() { + check_types( + r#" +//- /pm.rs crate:pm +#![crate_type = "proc-macro"] + +#[proc_macro_attribute] +pub fn proc_macro() {} + +fn foo() { + proc_macro; + // ^^^^^^^^^^ fn proc_macro() +} + +mod bar { + use super::proc_macro; + + fn baz() { + super::proc_macro; + // ^^^^^^^^^^^^^^^^^ fn proc_macro() + proc_macro; + // ^^^^^^^^^^ fn proc_macro() + } +} + +//- /lib.rs crate:lib deps:pm +fn foo() { + pm::proc_macro; + // ^^^^^^^^^^^^^^ {unknown} +} + "#, + ); +} + +#[test] +fn signature_inference() { + check_infer( + r#" +trait Trait<const A: u8> {} +struct S<T: Trait<2>, const C: f32 = 0.0> +where + (): Trait<2> +{ + field: [(); { C as usize }], + field2: *mut S<T, 5.0> +} + +struct S2<const C: u16>; + +type Alias = S2<0>; +impl S2<0> {} +enum E { + V(S2<0>) = 0, +} +union U { + field: S2<0> +} + "#, + expect![[r#" + 242..243 '0': isize + 46..47 '2': i32 + 65..68 '0.0': f32 + 90..91 '2': i32 + 200..201 '0': i32 + 212..213 '0': i32 "#]], ); } diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs index 390553c0d7..22359d8f1f 100644 --- a/crates/hir-ty/src/tests/traits.rs +++ b/crates/hir-ty/src/tests/traits.rs @@ -219,14 +219,16 @@ fn test() { #[test] fn infer_try_block() { - // FIXME: We should test more cases, but it currently doesn't work, since - // our labeled block type inference is broken. check_types( r#" -//- minicore: try, option +//- minicore: try, option, result, from fn test() { let x: Option<_> = try { Some(2)?; }; //^ Option<()> + let homogeneous = try { Ok::<(), u32>(())?; "hi" }; + //^^^^^^^^^^^ Result<&'? str, u32> + let heterogeneous = try bikeshed Result<_, u64> { 1 }; + //^^^^^^^^^^^^^ Result<i32, u64> } "#, ); @@ -1269,6 +1271,7 @@ fn bar() { 241..245 'R::B': fn B<(), i32>(i32) -> R<(), i32> 241..248 'R::B(7)': R<(), i32> 246..247 '7': i32 + 46..47 '2': usize "#]], ); } @@ -2216,6 +2219,40 @@ fn test() { } #[test] +fn tuple_struct_constructor_as_fn_trait() { + check_types( + r#" +//- minicore: fn +struct S(u32, u64); + +fn takes_fn<F: Fn(u32, u64) -> S>(f: F) -> S { f(1, 2) } + +fn test() { + takes_fn(S); + //^^^^^^^^^^^ S +} +"#, + ); +} + +#[test] +fn enum_variant_constructor_as_fn_trait() { + check_types( + r#" +//- minicore: fn +enum E { A(u32) } + +fn takes_fn<F: Fn(u32) -> E>(f: F) -> E { f(1) } + +fn test() { + takes_fn(E::A); + //^^^^^^^^^^^^^^ E +} +"#, + ); +} + +#[test] fn fn_item_fn_trait() { check_types( r#" @@ -3745,6 +3782,8 @@ fn main() { 371..373 'v4': usize 376..378 'v3': [u8; 4] 376..389 'v3.do_thing()': usize + 86..87 '4': usize + 192..193 '2': usize "#]], ) } @@ -3784,6 +3823,9 @@ fn main() { 240..242 'v2': [u8; 2] 245..246 'v': [u8; 2] 245..257 'v.do_thing()': [u8; 2] + 130..131 'L': usize + 102..103 'L': usize + 130..131 'L': usize "#]], ) } @@ -4819,7 +4861,7 @@ fn allowed3(baz: impl Baz<Assoc = Qux<impl Foo>>) {} 431..433 '{}': () 447..450 'baz': impl Baz<Assoc = impl Foo> 480..482 '{}': () - 500..503 'baz': impl Baz<Assoc = &'a impl Foo + 'a> + 500..503 'baz': impl Baz<Assoc = &'a (impl Foo + 'a)> 544..546 '{}': () 560..563 'baz': impl Baz<Assoc = Qux<impl Foo>> 598..600 '{}': () diff --git a/crates/hir-ty/src/traits.rs b/crates/hir-ty/src/traits.rs index fb598fe5ac..878696c721 100644 --- a/crates/hir-ty/src/traits.rs +++ b/crates/hir-ty/src/traits.rs @@ -7,7 +7,11 @@ use hir_def::{ AdtId, AssocItemId, HasModule, ImplId, Lookup, TraitId, lang_item::LangItems, nameres::DefMap, - signatures::{ConstFlags, EnumFlags, FnFlags, StructFlags, TraitFlags, TypeAliasFlags}, + signatures::{ + ConstFlags, ConstSignature, EnumFlags, EnumSignature, FnFlags, FunctionSignature, + StructFlags, StructSignature, TraitFlags, TraitSignature, TypeAliasFlags, + TypeAliasSignature, UnionSignature, + }, }; use hir_expand::name::Name; use intern::sym; @@ -279,21 +283,18 @@ pub fn is_inherent_impl_coherent(db: &dyn HirDatabase, def_map: &DefMap, impl_id | TyKind::Float(_) => true, TyKind::Adt(adt_def, _) => match adt_def.def_id().0 { - hir_def::AdtId::StructId(id) => db - .struct_signature(id) + hir_def::AdtId::StructId(id) => StructSignature::of(db, id) .flags .contains(StructFlags::RUSTC_HAS_INCOHERENT_INHERENT_IMPLS), - hir_def::AdtId::UnionId(id) => db - .union_signature(id) + hir_def::AdtId::UnionId(id) => UnionSignature::of(db, id) .flags .contains(StructFlags::RUSTC_HAS_INCOHERENT_INHERENT_IMPLS), - hir_def::AdtId::EnumId(it) => db - .enum_signature(it) + hir_def::AdtId::EnumId(it) => EnumSignature::of(db, it) .flags .contains(EnumFlags::RUSTC_HAS_INCOHERENT_INHERENT_IMPLS), }, TyKind::Dynamic(it, _) => it.principal_def_id().is_some_and(|trait_id| { - db.trait_signature(trait_id.0) + TraitSignature::of(db, trait_id.0) .flags .contains(TraitFlags::RUSTC_HAS_INCOHERENT_INHERENT_IMPLS) }), @@ -304,14 +305,13 @@ pub fn is_inherent_impl_coherent(db: &dyn HirDatabase, def_map: &DefMap, impl_id rustc_has_incoherent_inherent_impls && !items.items.is_empty() && items.items.iter().all(|&(_, assoc)| match assoc { - AssocItemId::FunctionId(it) => { - db.function_signature(it).flags.contains(FnFlags::RUSTC_ALLOW_INCOHERENT_IMPL) - } - AssocItemId::ConstId(it) => { - db.const_signature(it).flags.contains(ConstFlags::RUSTC_ALLOW_INCOHERENT_IMPL) - } - AssocItemId::TypeAliasId(it) => db - .type_alias_signature(it) + AssocItemId::FunctionId(it) => FunctionSignature::of(db, it) + .flags + .contains(FnFlags::RUSTC_ALLOW_INCOHERENT_IMPL), + AssocItemId::ConstId(it) => ConstSignature::of(db, it) + .flags + .contains(ConstFlags::RUSTC_ALLOW_INCOHERENT_IMPL), + AssocItemId::TypeAliasId(it) => TypeAliasSignature::of(db, it) .flags .contains(TypeAliasFlags::RUSTC_ALLOW_INCOHERENT_IMPL), }) @@ -350,7 +350,7 @@ pub fn check_orphan_rules<'db>(db: &'db dyn HirDatabase, impl_: ImplId) -> bool let AdtId::StructId(s) = adt_def.def_id().0 else { break ty; }; - let struct_signature = db.struct_signature(s); + let struct_signature = StructSignature::of(db, s); if struct_signature.flags.contains(StructFlags::FUNDAMENTAL) { let next = subs.types().next(); match next { diff --git a/crates/hir-ty/src/upvars.rs b/crates/hir-ty/src/upvars.rs index ee864ab068..489895fe3c 100644 --- a/crates/hir-ty/src/upvars.rs +++ b/crates/hir-ty/src/upvars.rs @@ -44,10 +44,10 @@ pub fn upvars_mentioned( db: &dyn HirDatabase, owner: DefWithBodyId, ) -> Option<Box<FxHashMap<ExprId, Upvars>>> { - let body = db.body(owner); + let body = Body::of(db, owner); let mut resolver = owner.resolver(db); let mut result = FxHashMap::default(); - handle_expr_outside_closure(db, &mut resolver, owner, &body, body.body_expr, &mut result); + handle_expr_outside_closure(db, &mut resolver, owner, body, body.root_expr(), &mut result); return if result.is_empty() { None } else { @@ -198,7 +198,7 @@ fn resolve_maybe_upvar<'db>( #[cfg(test)] mod tests { use expect_test::{Expect, expect}; - use hir_def::{ModuleDefId, db::DefDatabase, nameres::crate_def_map}; + use hir_def::{ModuleDefId, expr_store::Body, nameres::crate_def_map}; use itertools::Itertools; use span::Edition; use test_fixture::WithFixture; @@ -219,7 +219,7 @@ mod tests { }) .exactly_one() .unwrap_or_else(|_| panic!("expected one function")); - let (body, source_map) = db.body_with_source_map(func.into()); + let (body, source_map) = Body::with_source_map(&db, func.into()); let Some(upvars) = upvars_mentioned(&db, func.into()) else { expectation.assert_eq(""); return; diff --git a/crates/hir-ty/src/utils.rs b/crates/hir-ty/src/utils.rs index 7dd73f1e7a..509109543c 100644 --- a/crates/hir-ty/src/utils.rs +++ b/crates/hir-ty/src/utils.rs @@ -1,27 +1,20 @@ //! Helper functions for working with def, which don't need to be a separate //! query, but can't be computed directly from `*Data` (ie, which need a `db`). -use std::cell::LazyCell; - use base_db::target::{self, TargetData}; use hir_def::{ - EnumId, EnumVariantId, FunctionId, Lookup, TraitId, - attrs::AttrFlags, - db::DefDatabase, - hir::generics::WherePredicate, - lang_item::LangItems, - resolver::{HasResolver, TypeNs}, - type_ref::{TraitBoundModifier, TypeRef}, + EnumId, EnumVariantId, FunctionId, Lookup, TraitId, attrs::AttrFlags, lang_item::LangItems, + signatures::FunctionSignature, }; use intern::sym; use rustc_abi::TargetDataLayout; -use smallvec::{SmallVec, smallvec}; use span::Edition; use crate::{ TargetFeatures, db::HirDatabase, layout::{Layout, TagEncoding}, + lower::SupertraitsInfo, mir::pad16, }; @@ -50,65 +43,13 @@ pub(crate) fn fn_traits(lang_items: &LangItems) -> impl Iterator<Item = TraitId> } /// Returns an iterator over the direct super traits (including the trait itself). -pub fn direct_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> SmallVec<[TraitId; 4]> { - let mut result = smallvec![trait_]; - direct_super_traits_cb(db, trait_, |tt| { - if !result.contains(&tt) { - result.push(tt); - } - }); - result +pub fn direct_super_traits(db: &dyn HirDatabase, trait_: TraitId) -> &[TraitId] { + &SupertraitsInfo::query(db, trait_).direct_supertraits } -/// Returns an iterator over the whole super trait hierarchy (including the -/// trait itself). -pub fn all_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> SmallVec<[TraitId; 4]> { - // we need to take care a bit here to avoid infinite loops in case of cycles - // (i.e. if we have `trait A: B; trait B: A;`) - - let mut result = smallvec![trait_]; - let mut i = 0; - while let Some(&t) = result.get(i) { - // yeah this is quadratic, but trait hierarchies should be flat - // enough that this doesn't matter - direct_super_traits_cb(db, t, |tt| { - if !result.contains(&tt) { - result.push(tt); - } - }); - i += 1; - } - result -} - -fn direct_super_traits_cb(db: &dyn DefDatabase, trait_: TraitId, cb: impl FnMut(TraitId)) { - let resolver = LazyCell::new(|| trait_.resolver(db)); - let (generic_params, store) = db.generic_params_and_store(trait_.into()); - let trait_self = generic_params.trait_self_param(); - generic_params - .where_predicates() - .iter() - .filter_map(|pred| match pred { - WherePredicate::ForLifetime { target, bound, .. } - | WherePredicate::TypeBound { target, bound } => { - let is_trait = match &store[*target] { - TypeRef::Path(p) => p.is_self_type(), - TypeRef::TypeParam(p) => Some(p.local_id()) == trait_self, - _ => false, - }; - match is_trait { - true => bound.as_path(&store), - false => None, - } - } - WherePredicate::Lifetime { .. } => None, - }) - .filter(|(_, bound_modifier)| matches!(bound_modifier, TraitBoundModifier::None)) - .filter_map(|(path, _)| match resolver.resolve_path_in_type_ns_fully(db, path) { - Some(TypeNs::TraitId(t)) => Some(t), - _ => None, - }) - .for_each(cb); +/// Returns the whole super trait hierarchy (including the trait itself). +pub fn all_super_traits(db: &dyn HirDatabase, trait_: TraitId) -> &[TraitId] { + &SupertraitsInfo::query(db, trait_).all_supertraits } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -139,7 +80,7 @@ pub fn is_fn_unsafe_to_call( call_edition: Edition, target_feature_is_safe: TargetFeatureIsSafeInTarget, ) -> Unsafety { - let data = db.function_signature(func); + let data = FunctionSignature::of(db, func); if data.is_unsafe() { return Unsafety::Unsafe; } diff --git a/crates/hir-ty/src/variance.rs b/crates/hir-ty/src/variance.rs index 6f415a5289..1945b04bb3 100644 --- a/crates/hir-ty/src/variance.rs +++ b/crates/hir-ty/src/variance.rs @@ -13,7 +13,10 @@ //! by the next salsa version. If not, we will likely have to adapt and go with the rustc approach //! while installing firewall per item queries to prevent invalidation issues. -use hir_def::{AdtId, GenericDefId, GenericParamId, VariantId, signatures::StructFlags}; +use hir_def::{ + AdtId, GenericDefId, GenericParamId, VariantId, + signatures::{StructFlags, StructSignature}, +}; use rustc_ast_ir::Mutability; use rustc_type_ir::{ Variance, @@ -45,7 +48,7 @@ fn variances_of_query(db: &dyn HirDatabase, def: GenericDefId) -> StoredVariance GenericDefId::FunctionId(_) => (), GenericDefId::AdtId(adt) => { if let AdtId::StructId(id) = adt { - let flags = &db.struct_signature(id).flags; + let flags = &StructSignature::of(db, id).flags; let types = || crate::next_solver::default_types(db); if flags.contains(StructFlags::IS_UNSAFE_CELL) { return types().one_invariant.store(); @@ -113,7 +116,7 @@ pub(crate) fn variances_of_cycle_initial( struct Context<'db> { db: &'db dyn HirDatabase, - generics: Generics, + generics: Generics<'db>, variances: Box<[Variance]>, } diff --git a/crates/hir/src/attrs.rs b/crates/hir/src/attrs.rs index cfb95e07c3..27e7985146 100644 --- a/crates/hir/src/attrs.rs +++ b/crates/hir/src/attrs.rs @@ -7,6 +7,7 @@ use hir_def::{ TraitId, TypeOrConstParamId, attrs::{AttrFlags, Docs, IsInnerDoc}, expr_store::path::Path, + hir::generics::GenericParams, item_scope::ItemInNs, per_ns::Namespace, resolver::{HasResolver, Resolver, TypeNs}, @@ -26,9 +27,9 @@ use intern::Symbol; use stdx::never; use crate::{ - Adt, AsAssocItem, AssocItem, BuiltinType, Const, ConstParam, DocLinkDef, Enum, ExternCrateDecl, - Field, Function, GenericParam, HasCrate, Impl, LangItem, LifetimeParam, Macro, Module, - ModuleDef, Static, Struct, Trait, Type, TypeAlias, TypeParam, Union, Variant, VariantDef, + Adt, AsAssocItem, AssocItem, BuiltinType, Const, ConstParam, DocLinkDef, Enum, EnumVariant, + ExternCrateDecl, Field, Function, GenericParam, HasCrate, Impl, LangItem, LifetimeParam, Macro, + Module, ModuleDef, Static, Struct, Trait, Type, TypeAlias, TypeParam, Union, Variant, }; #[derive(Debug, Clone, Copy)] @@ -199,7 +200,7 @@ macro_rules! impl_has_attrs { } impl_has_attrs![ - (Variant, EnumVariantId), + (EnumVariant, EnumVariantId), (Static, StaticId), (Const, ConstId), (Trait, TraitId), @@ -377,7 +378,7 @@ fn resolve_assoc_or_field( let ty = match base_def { TypeNs::SelfType(id) => Impl::from(id).self_ty(db), TypeNs::GenericParam(param) => { - let generic_params = db.generic_params(param.parent()); + let generic_params = GenericParams::of(db, param.parent()); if generic_params[param.local_id()].is_trait_self() { // `Self::assoc` in traits should refer to the trait itself. let parent_trait = |container| match container { @@ -406,7 +407,7 @@ fn resolve_assoc_or_field( TypeNs::AdtId(id) | TypeNs::AdtSelfType(id) => Adt::from(id).ty(db), TypeNs::EnumVariantId(id) => { // Enum variants don't have path candidates. - let variant = Variant::from(id); + let variant = EnumVariant::from(id); return resolve_field(db, variant.into(), name, ns); } TypeNs::TypeAliasId(id) => { @@ -443,7 +444,7 @@ fn resolve_assoc_or_field( .id .enum_variants(db) .variant(&name) - .map(|variant| DocLinkDef::ModuleDef(ModuleDef::Variant(variant.into()))); + .map(|variant| DocLinkDef::ModuleDef(ModuleDef::EnumVariant(variant.into()))); } }; resolve_field(db, variant_def, name, ns) @@ -505,7 +506,7 @@ fn resolve_impl_trait_item<'db>( fn resolve_field( db: &dyn HirDatabase, - def: VariantDef, + def: Variant, name: Name, ns: Option<Namespace>, ) -> Option<DocLinkDef> { diff --git a/crates/hir/src/diagnostics.rs b/crates/hir/src/diagnostics.rs index 050777a480..7f672a697c 100644 --- a/crates/hir/src/diagnostics.rs +++ b/crates/hir/src/diagnostics.rs @@ -51,20 +51,6 @@ macro_rules! diagnostics { )* }; } -// FIXME Accept something like the following in the macro call instead -// diagnostics![ -// pub struct BreakOutsideOfLoop { -// pub expr: InFile<AstPtr<ast::Expr>>, -// pub is_break: bool, -// pub bad_value_break: bool, -// }, ... -// or more concisely -// BreakOutsideOfLoop { -// expr: InFile<AstPtr<ast::Expr>>, -// is_break: bool, -// bad_value_break: bool, -// }, ... -// ] diagnostics![AnyDiagnostic<'db> -> AwaitOutsideOfAsync, diff --git a/crates/hir/src/display.rs b/crates/hir/src/display.rs index 1f9af564c3..4bfdd239f9 100644 --- a/crates/hir/src/display.rs +++ b/crates/hir/src/display.rs @@ -4,10 +4,13 @@ use either::Either; use hir_def::{ AdtId, BuiltinDeriveImplId, FunctionId, GenericDefId, ImplId, ItemContainerId, builtin_derive::BuiltinDeriveImplMethod, - expr_store::ExpressionStore, + expr_store::{Body, ExpressionStore}, hir::generics::{GenericParams, TypeOrConstParamData, TypeParamProvenance, WherePredicate}, item_tree::FieldsShape, - signatures::{StaticFlags, TraitFlags}, + signatures::{ + ConstSignature, FunctionSignature, ImplSignature, StaticFlags, StaticSignature, TraitFlags, + TraitSignature, TypeAliasSignature, + }, type_ref::{TypeBound, TypeRef, TypeRefId}, }; use hir_expand::name::Name; @@ -26,9 +29,9 @@ use rustc_type_ir::inherent::IntoKind; use crate::{ Adt, AnyFunctionId, AsAssocItem, AssocItem, AssocItemContainer, Const, ConstParam, Crate, Enum, - ExternCrateDecl, Field, Function, GenericParam, HasCrate, HasVisibility, Impl, LifetimeParam, - Macro, Module, SelfParam, Static, Struct, StructKind, Trait, TraitRef, TupleField, Type, - TypeAlias, TypeNs, TypeOrConstParam, TypeParam, Union, Variant, + EnumVariant, ExternCrateDecl, Field, Function, GenericParam, HasCrate, HasVisibility, Impl, + LifetimeParam, Macro, Module, SelfParam, Static, Struct, StructKind, Trait, TraitRef, + TupleField, Type, TypeAlias, TypeNs, TypeOrConstParam, TypeParam, Union, }; fn write_builtin_derive_impl_method<'db>( @@ -38,7 +41,7 @@ fn write_builtin_derive_impl_method<'db>( ) -> Result { let db = f.db; let loc = impl_.loc(db); - let (adt_params, _adt_params_store) = db.generic_params_and_store(loc.adt.into()); + let adt_params = GenericParams::of(db, loc.adt.into()); if f.show_container_bounds() && !adt_params.is_empty() { f.write_str("impl")?; @@ -94,22 +97,22 @@ impl<'db> HirDisplay<'db> for Function { // Write container (trait or impl) let container_params = match container { ItemContainerId::TraitId(trait_) => { - let (params, params_store) = f.db.generic_params_and_store(trait_.into()); + let (params, params_store) = GenericParams::with_store(f.db, trait_.into()); if f.show_container_bounds() && !params.is_empty() { write_trait_header(trait_.into(), f)?; f.write_char('\n')?; - has_disaplayable_predicates(f.db, ¶ms, ¶ms_store) + has_disaplayable_predicates(f.db, params, params_store) .then_some((params, params_store)) } else { None } } ItemContainerId::ImplId(impl_) => { - let (params, params_store) = f.db.generic_params_and_store(impl_.into()); + let (params, params_store) = GenericParams::with_store(f.db, impl_.into()); if f.show_container_bounds() && !params.is_empty() { write_impl_header(impl_, f)?; f.write_char('\n')?; - has_disaplayable_predicates(f.db, ¶ms, ¶ms_store) + has_disaplayable_predicates(f.db, params, params_store) .then_some((params, params_store)) } else { None @@ -131,7 +134,7 @@ impl<'db> HirDisplay<'db> for Function { _ => unreachable!(), }; write!(f, "\n // Bounds from {container_name}:",)?; - write_where_predicates(&container_params, &container_params_store, f)?; + write_where_predicates(container_params, container_params_store, f)?; } Ok(()) } @@ -140,7 +143,7 @@ impl<'db> HirDisplay<'db> for Function { fn write_function<'db>(f: &mut HirFormatter<'_, 'db>, func_id: FunctionId) -> Result<bool> { let db = f.db; let func = Function::from(func_id); - let data = db.function_signature(func_id); + let data = FunctionSignature::of(db, func_id); let mut module = func.module(db); // Block-local impls are "hoisted" to the nearest (non-block) module. @@ -172,8 +175,13 @@ fn write_function<'db>(f: &mut HirFormatter<'_, 'db>, func_id: FunctionId) -> Re write_generic_params(GenericDefId::FunctionId(func_id), f)?; + let too_long_param = data.params.len() > 4; f.write_char('(')?; + if too_long_param { + f.write_str("\n ")?; + } + let mut first = true; let mut skip_self = 0; if let Some(self_param) = func.self_param(db) { @@ -182,11 +190,12 @@ fn write_function<'db>(f: &mut HirFormatter<'_, 'db>, func_id: FunctionId) -> Re skip_self = 1; } + let comma = if too_long_param { ",\n " } else { ", " }; // FIXME: Use resolved `param.ty` once we no longer discard lifetimes - let body = db.body(func_id.into()); + let body = Body::of(db, func_id.into()); for (type_ref, param) in data.params.iter().zip(func.assoc_fn_params(db)).skip(skip_self) { if !first { - f.write_str(", ")?; + f.write_str(comma)?; } else { first = false; } @@ -201,11 +210,14 @@ fn write_function<'db>(f: &mut HirFormatter<'_, 'db>, func_id: FunctionId) -> Re if data.is_varargs() { if !first { - f.write_str(", ")?; + f.write_str(comma)?; } f.write_str("...")?; } + if too_long_param { + f.write_char('\n')?; + } f.write_char(')')?; // `FunctionData::ret_type` will be `::core::future::Future<Output = ...>` for async fns. @@ -259,7 +271,7 @@ fn write_impl_header<'db>(impl_: ImplId, f: &mut HirFormatter<'_, 'db>) -> Resul let def_id = GenericDefId::ImplId(impl_); write_generic_params(def_id, f)?; - let impl_data = db.impl_signature(impl_); + let impl_data = ImplSignature::of(db, impl_); if let Some(target_trait) = &impl_data.target_trait { f.write_char(' ')?; hir_display_with_store(&impl_data.store[target_trait.path], &impl_data.store).hir_fmt(f)?; @@ -288,7 +300,7 @@ impl<'db> HirDisplay<'db> for SelfParam { } }, }; - let data = f.db.function_signature(func); + let data = FunctionSignature::of(f.db, func); let param = *data.params.first().unwrap(); match &data.store[param] { TypeRef::Path(p) if p.is_self_type() => f.write_str("self"), @@ -431,7 +443,7 @@ fn write_fields<'db>( } fn write_variants<'db>( - variants: &[Variant], + variants: &[EnumVariant], has_where_clause: bool, limit: usize, f: &mut HirFormatter<'_, 'db>, @@ -485,7 +497,7 @@ impl<'db> HirDisplay<'db> for TupleField { } } -impl<'db> HirDisplay<'db> for Variant { +impl<'db> HirDisplay<'db> for EnumVariant { fn hir_fmt(&self, f: &mut HirFormatter<'_, 'db>) -> Result { write!(f, "{}", self.name(f.db).display(f.db, f.edition()))?; let data = self.id.fields(f.db); @@ -560,7 +572,7 @@ impl<'db> HirDisplay<'db> for TypeOrConstParam { impl<'db> HirDisplay<'db> for TypeParam { fn hir_fmt(&self, f: &mut HirFormatter<'_, 'db>) -> Result { - let params = f.db.generic_params(self.id.parent()); + let params = GenericParams::of(f.db, self.id.parent()); let param_data = ¶ms[self.id.local_id()]; let krate = self.id.parent().krate(f.db).id; let ty = self.ty(f.db).ty; @@ -587,6 +599,7 @@ impl<'db> HirDisplay<'db> for TypeParam { Either::Left(ty), &predicates, SizedByDefault::Sized { anchor: krate }, + false, ); } }, @@ -614,6 +627,7 @@ impl<'db> HirDisplay<'db> for TypeParam { Either::Left(ty), &predicates, default_sized, + false, )?; } Ok(()) @@ -646,7 +660,7 @@ fn write_generic_params_or_args<'db>( f: &mut HirFormatter<'_, 'db>, include_defaults: bool, ) -> Result { - let (params, store) = f.db.generic_params_and_store(def); + let (params, store) = GenericParams::with_store(f.db, def); if params.iter_lt().next().is_none() && params.iter_type_or_consts().all(|it| it.1.const_param().is_none()) && params @@ -682,17 +696,17 @@ fn write_generic_params_or_args<'db>( write!(f, "{}", name.display(f.db, f.edition()))?; if include_defaults && let Some(default) = &ty.default { f.write_str(" = ")?; - default.hir_fmt(f, &store)?; + default.hir_fmt(f, store)?; } } TypeOrConstParamData::ConstParamData(c) => { delim(f)?; write!(f, "const {}: ", name.display(f.db, f.edition()))?; - c.ty.hir_fmt(f, &store)?; + c.ty.hir_fmt(f, store)?; if include_defaults && let Some(default) = &c.default { f.write_str(" = ")?; - default.hir_fmt(f, &store)?; + default.hir_fmt(f, store)?; } } } @@ -704,13 +718,13 @@ fn write_generic_params_or_args<'db>( } fn write_where_clause<'db>(def: GenericDefId, f: &mut HirFormatter<'_, 'db>) -> Result<bool> { - let (params, store) = f.db.generic_params_and_store(def); - if !has_disaplayable_predicates(f.db, ¶ms, &store) { + let (params, store) = GenericParams::with_store(f.db, def); + if !has_disaplayable_predicates(f.db, params, store) { return Ok(false); } f.write_str("\nwhere")?; - write_where_predicates(¶ms, &store, f)?; + write_where_predicates(params, store, f)?; Ok(true) } @@ -725,7 +739,7 @@ fn has_disaplayable_predicates( pred, WherePredicate::TypeBound { target, .. } if matches!(store[*target], - TypeRef::TypeParam(id) if db.generic_params(id.parent())[id.local_id()].name().is_none() + TypeRef::TypeParam(id) if GenericParams::of(db,id.parent())[id.local_id()].name().is_none() ) ) }) @@ -741,7 +755,7 @@ fn write_where_predicates<'db>( // unnamed type targets are displayed inline with the argument itself, e.g. `f: impl Y`. let is_unnamed_type_target = |target: TypeRefId| { matches!(store[target], - TypeRef::TypeParam(id) if f.db.generic_params(id.parent())[id.local_id()].name().is_none() + TypeRef::TypeParam(id) if GenericParams::of(f.db,id.parent())[id.local_id()].name().is_none() ) }; @@ -805,7 +819,7 @@ impl<'db> HirDisplay<'db> for Const { module = module.nearest_non_block_module(db); } write_visibility(module.id, self.visibility(db), f)?; - let data = db.const_signature(self.id); + let data = ConstSignature::of(db, self.id); f.write_str("const ")?; match &data.name { Some(name) => write!(f, "{}: ", name.display(f.db, f.edition()))?, @@ -819,7 +833,7 @@ impl<'db> HirDisplay<'db> for Const { impl<'db> HirDisplay<'db> for Static { fn hir_fmt(&self, f: &mut HirFormatter<'_, 'db>) -> Result { write_visibility(self.module(f.db).id, self.visibility(f.db), f)?; - let data = f.db.static_signature(self.id); + let data = StaticSignature::of(f.db, self.id); f.write_str("static ")?; if data.flags.contains(StaticFlags::MUTABLE) { f.write_str("mut ")?; @@ -878,7 +892,7 @@ impl<'db> HirDisplay<'db> for Trait { fn write_trait_header<'db>(trait_: Trait, f: &mut HirFormatter<'_, 'db>) -> Result { write_visibility(trait_.module(f.db).id, trait_.visibility(f.db), f)?; - let data = f.db.trait_signature(trait_.id); + let data = TraitSignature::of(f.db, trait_.id); if data.flags.contains(TraitFlags::UNSAFE) { f.write_str("unsafe ")?; } @@ -893,7 +907,7 @@ fn write_trait_header<'db>(trait_: Trait, f: &mut HirFormatter<'_, 'db>) -> Resu impl<'db> HirDisplay<'db> for TypeAlias { fn hir_fmt(&self, f: &mut HirFormatter<'_, 'db>) -> Result { write_visibility(self.module(f.db).id, self.visibility(f.db), f)?; - let data = f.db.type_alias_signature(self.id); + let data = TypeAliasSignature::of(f.db, self.id); write!(f, "type {}", data.name.display(f.db, f.edition()))?; let def_id = GenericDefId::TypeAliasId(self.id); write_generic_params(def_id, f)?; diff --git a/crates/hir/src/from_id.rs b/crates/hir/src/from_id.rs index fc20f4b46b..0a48be5473 100644 --- a/crates/hir/src/from_id.rs +++ b/crates/hir/src/from_id.rs @@ -4,15 +4,15 @@ //! are splitting the hir. use hir_def::{ - AdtId, AssocItemId, BuiltinDeriveImplId, DefWithBodyId, EnumVariantId, FieldId, GenericDefId, - GenericParamId, ModuleDefId, VariantId, + AdtId, AssocItemId, BuiltinDeriveImplId, DefWithBodyId, EnumVariantId, ExpressionStoreOwnerId, + FieldId, FunctionId, GenericDefId, GenericParamId, ImplId, ModuleDefId, VariantId, hir::{BindingId, LabelId}, }; use hir_ty::next_solver::AnyImplId; use crate::{ - Adt, AnyFunctionId, AssocItem, BuiltinType, DefWithBody, Field, GenericDef, GenericParam, - ItemInNs, Label, Local, ModuleDef, Variant, VariantDef, + Adt, AnyFunctionId, AssocItem, BuiltinType, DefWithBody, EnumVariant, ExpressionStoreOwner, + Field, Function, GenericDef, GenericParam, Impl, ItemInNs, Label, Local, ModuleDef, Variant, }; macro_rules! from_id { @@ -71,6 +71,15 @@ impl From<Adt> for AdtId { } } +impl From<VariantId> for Variant { + fn from(v: VariantId) -> Self { + match v { + VariantId::EnumVariantId(it) => Variant::EnumVariant(it.into()), + VariantId::StructId(it) => Variant::Struct(it.into()), + VariantId::UnionId(it) => Variant::Union(it.into()), + } + } +} impl From<GenericParamId> for GenericParam { fn from(id: GenericParamId) -> Self { match id { @@ -91,14 +100,14 @@ impl From<GenericParam> for GenericParamId { } } -impl From<EnumVariantId> for Variant { +impl From<EnumVariantId> for EnumVariant { fn from(id: EnumVariantId) -> Self { - Variant { id } + EnumVariant { id } } } -impl From<Variant> for EnumVariantId { - fn from(def: Variant) -> Self { +impl From<EnumVariant> for EnumVariantId { + fn from(def: EnumVariant) -> Self { def.id } } @@ -109,7 +118,7 @@ impl From<ModuleDefId> for ModuleDef { ModuleDefId::ModuleId(it) => ModuleDef::Module(it.into()), ModuleDefId::FunctionId(it) => ModuleDef::Function(it.into()), ModuleDefId::AdtId(it) => ModuleDef::Adt(it.into()), - ModuleDefId::EnumVariantId(it) => ModuleDef::Variant(it.into()), + ModuleDefId::EnumVariantId(it) => ModuleDef::EnumVariant(it.into()), ModuleDefId::ConstId(it) => ModuleDef::Const(it.into()), ModuleDefId::StaticId(it) => ModuleDef::Static(it.into()), ModuleDefId::TraitId(it) => ModuleDef::Trait(it.into()), @@ -130,7 +139,7 @@ impl TryFrom<ModuleDef> for ModuleDefId { AnyFunctionId::BuiltinDeriveImplMethod { .. } => return Err(()), }, ModuleDef::Adt(it) => ModuleDefId::AdtId(it.into()), - ModuleDef::Variant(it) => ModuleDefId::EnumVariantId(it.into()), + ModuleDef::EnumVariant(it) => ModuleDefId::EnumVariantId(it.into()), ModuleDef::Const(it) => ModuleDefId::ConstId(it.into()), ModuleDef::Static(it) => ModuleDefId::StaticId(it.into()), ModuleDef::Trait(it) => ModuleDefId::TraitId(it.into()), @@ -151,7 +160,7 @@ impl TryFrom<DefWithBody> for DefWithBodyId { }, DefWithBody::Static(it) => DefWithBodyId::StaticId(it.id), DefWithBody::Const(it) => DefWithBodyId::ConstId(it.id), - DefWithBody::Variant(it) => DefWithBodyId::VariantId(it.into()), + DefWithBody::EnumVariant(it) => DefWithBodyId::VariantId(it.into()), }) } } @@ -162,7 +171,7 @@ impl From<DefWithBodyId> for DefWithBody { DefWithBodyId::FunctionId(it) => DefWithBody::Function(it.into()), DefWithBodyId::StaticId(it) => DefWithBody::Static(it.into()), DefWithBodyId::ConstId(it) => DefWithBody::Const(it.into()), - DefWithBodyId::VariantId(it) => DefWithBody::Variant(it.into()), + DefWithBodyId::VariantId(it) => DefWithBody::EnumVariant(it.into()), } } } @@ -209,22 +218,12 @@ impl From<Adt> for GenericDefId { } } -impl From<VariantId> for VariantDef { - fn from(def: VariantId) -> Self { - match def { - VariantId::StructId(it) => VariantDef::Struct(it.into()), - VariantId::EnumVariantId(it) => VariantDef::Variant(it.into()), - VariantId::UnionId(it) => VariantDef::Union(it.into()), - } - } -} - -impl From<VariantDef> for VariantId { - fn from(def: VariantDef) -> Self { +impl From<Variant> for VariantId { + fn from(def: Variant) -> Self { match def { - VariantDef::Struct(it) => VariantId::StructId(it.id), - VariantDef::Variant(it) => VariantId::EnumVariantId(it.into()), - VariantDef::Union(it) => VariantId::UnionId(it.id), + Variant::Struct(it) => VariantId::StructId(it.id), + Variant::EnumVariant(it) => VariantId::EnumVariantId(it.into()), + Variant::Union(it) => VariantId::UnionId(it.id), } } } @@ -255,14 +254,19 @@ impl TryFrom<AssocItem> for GenericDefId { } } +impl From<(ExpressionStoreOwnerId, BindingId)> for Local { + fn from((parent, binding_id): (ExpressionStoreOwnerId, BindingId)) -> Self { + Local { parent, binding_id } + } +} impl From<(DefWithBodyId, BindingId)> for Local { fn from((parent, binding_id): (DefWithBodyId, BindingId)) -> Self { - Local { parent, binding_id } + Local { parent: parent.into(), binding_id } } } -impl From<(DefWithBodyId, LabelId)> for Label { - fn from((parent, label_id): (DefWithBodyId, LabelId)) -> Self { +impl From<(ExpressionStoreOwnerId, LabelId)> for Label { + fn from((parent, label_id): (ExpressionStoreOwnerId, LabelId)) -> Self { Label { parent, label_id } } } @@ -317,3 +321,43 @@ impl From<hir_def::FunctionId> for crate::Function { crate::Function { id: AnyFunctionId::FunctionId(value) } } } + +impl TryFrom<ExpressionStoreOwner> for ExpressionStoreOwnerId { + type Error = (); + + fn try_from(v: ExpressionStoreOwner) -> Result<Self, Self::Error> { + match v { + ExpressionStoreOwner::Signature(generic_def_id) => { + Ok(Self::Signature(generic_def_id.try_into()?)) + } + ExpressionStoreOwner::Body(def_with_body_id) => { + Ok(Self::Body(def_with_body_id.try_into()?)) + } + ExpressionStoreOwner::VariantFields(variant_id) => { + Ok(Self::VariantFields(variant_id.into())) + } + } + } +} + +impl TryFrom<Function> for FunctionId { + type Error = (); + + fn try_from(v: Function) -> Result<Self, Self::Error> { + match v.id { + AnyFunctionId::FunctionId(id) => Ok(id), + _ => Err(()), + } + } +} + +impl TryFrom<Impl> for ImplId { + type Error = (); + + fn try_from(v: Impl) -> Result<Self, Self::Error> { + match v.id { + AnyImplId::ImplId(id) => Ok(id), + _ => Err(()), + } + } +} diff --git a/crates/hir/src/has_source.rs b/crates/hir/src/has_source.rs index e032a16989..f9badc0b79 100644 --- a/crates/hir/src/has_source.rs +++ b/crates/hir/src/has_source.rs @@ -3,6 +3,7 @@ use either::Either; use hir_def::{ CallableDefId, Lookup, MacroId, VariantId, + expr_store::ExpressionStore, nameres::{ModuleOrigin, ModuleSource}, src::{HasChildSource, HasSource as _}, }; @@ -12,9 +13,9 @@ use syntax::{AstNode, ast}; use tt::TextRange; use crate::{ - Adt, AnyFunctionId, Callee, Const, Enum, ExternCrateDecl, Field, FieldSource, Function, Impl, - InlineAsmOperand, Label, LifetimeParam, LocalSource, Macro, Module, Param, SelfParam, Static, - Struct, Trait, TypeAlias, TypeOrConstParam, Union, Variant, VariantDef, db::HirDatabase, + Adt, AnyFunctionId, Callee, Const, Enum, EnumVariant, ExternCrateDecl, Field, FieldSource, + Function, Impl, InlineAsmOperand, Label, LifetimeParam, LocalSource, Macro, Module, Param, + SelfParam, Static, Struct, Trait, TypeAlias, TypeOrConstParam, Union, Variant, db::HirDatabase, }; pub trait HasSource: Sized { @@ -123,13 +124,13 @@ impl HasSource for Adt { } } } -impl HasSource for VariantDef { +impl HasSource for Variant { type Ast = ast::VariantDef; fn source(self, db: &dyn HirDatabase) -> Option<InFile<Self::Ast>> { match self { - VariantDef::Struct(s) => Some(s.source(db)?.map(ast::VariantDef::Struct)), - VariantDef::Union(u) => Some(u.source(db)?.map(ast::VariantDef::Union)), - VariantDef::Variant(v) => Some(v.source(db)?.map(ast::VariantDef::Variant)), + Variant::Struct(s) => Some(s.source(db)?.map(ast::VariantDef::Struct)), + Variant::Union(u) => Some(u.source(db)?.map(ast::VariantDef::Union)), + Variant::EnumVariant(v) => Some(v.source(db)?.map(ast::VariantDef::Variant)), } } } @@ -151,7 +152,7 @@ impl HasSource for Enum { Some(self.id.lookup(db).source(db)) } } -impl HasSource for Variant { +impl HasSource for EnumVariant { type Ast = ast::Variant; fn source(self, db: &dyn HirDatabase) -> Option<InFile<ast::Variant>> { Some(self.id.lookup(db).source(db)) @@ -293,7 +294,7 @@ impl HasSource for Param<'_> { } Callee::Closure(closure, _) => { let InternedClosure(owner, expr_id) = db.lookup_intern_closure(closure); - let (_, source_map) = db.body_with_source_map(owner); + let (_, source_map) = ExpressionStore::with_source_map(db, owner); let ast @ InFile { file_id, value } = source_map.expr_syntax(expr_id).ok()?; let root = db.parse_or_expand(file_id); match value.to_node(&root) { @@ -327,8 +328,7 @@ impl HasSource for Label { type Ast = ast::Label; fn source(self, db: &dyn HirDatabase) -> Option<InFile<Self::Ast>> { - let (_body, source_map) = db.body_with_source_map(self.parent); - let src = source_map.label_syntax(self.label_id); + let src = ExpressionStore::with_source_map(db, self.parent).1.label_syntax(self.label_id); let root = src.file_syntax(db); src.map(|ast| ast.to_node(&root).left()).transpose() } @@ -345,7 +345,7 @@ impl HasSource for ExternCrateDecl { impl HasSource for InlineAsmOperand { type Ast = ast::AsmOperandNamed; fn source(self, db: &dyn HirDatabase) -> Option<InFile<Self::Ast>> { - let source_map = db.body_with_source_map(self.owner).1; + let (_, source_map) = ExpressionStore::with_source_map(db, self.owner); if let Ok(src) = source_map.expr_syntax(self.expr) { let root = src.file_syntax(db); return src diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index 252d71fb80..bc5e164830 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -49,16 +49,16 @@ use base_db::{CrateDisplayName, CrateOrigin, LangCrateOrigin}; use either::Either; use hir_def::{ AdtId, AssocItemId, AssocItemLoc, BuiltinDeriveImplId, CallableDefId, ConstId, ConstParamId, - DefWithBodyId, EnumId, EnumVariantId, ExternBlockId, ExternCrateId, FunctionId, GenericDefId, - HasModule, ImplId, ItemContainerId, LifetimeParamId, LocalFieldId, Lookup, MacroExpander, - MacroId, StaticId, StructId, SyntheticSyntax, TupleId, TypeAliasId, TypeOrConstParamId, - TypeParamId, UnionId, + DefWithBodyId, EnumId, EnumVariantId, ExpressionStoreOwnerId, ExternBlockId, ExternCrateId, + FunctionId, GenericDefId, HasModule, ImplId, ItemContainerId, LifetimeParamId, LocalFieldId, + Lookup, MacroExpander, MacroId, StaticId, StructId, SyntheticSyntax, TupleId, TypeAliasId, + TypeOrConstParamId, TypeParamId, UnionId, attrs::AttrFlags, builtin_derive::BuiltinDeriveImplMethod, - expr_store::{ExpressionStoreDiagnostics, ExpressionStoreSourceMap}, + expr_store::{ExpressionStore, ExpressionStoreDiagnostics, ExpressionStoreSourceMap}, hir::{ BindingAnnotation, BindingId, Expr, ExprId, ExprOrPatId, LabelId, Pat, - generics::{LifetimeParamData, TypeOrConstParamData, TypeParamProvenance}, + generics::{GenericParams, LifetimeParamData, TypeOrConstParamData, TypeParamProvenance}, }, item_tree::ImportAlias, lang_item::LangItemTarget, @@ -69,7 +69,11 @@ use hir_def::{ }, per_ns::PerNs, resolver::{HasResolver, Resolver}, - signatures::{EnumSignature, ImplFlags, StaticFlags, StructFlags, TraitFlags, VariantFields}, + signatures::{ + ConstSignature, EnumSignature, FunctionSignature, ImplFlags, ImplSignature, StaticFlags, + StaticSignature, StructFlags, StructSignature, TraitFlags, TraitSignature, + TypeAliasSignature, UnionSignature, VariantFields, + }, src::HasSource as _, visibility::visibility_from_ast, }; @@ -141,6 +145,7 @@ pub use { Complete, FindPathConfig, attrs::{Docs, IsInnerDoc}, + expr_store::Body, find_path::PrefixKind, import_map, lang_item::{LangItemEnum as LangItem, crate_lang_items}, @@ -351,8 +356,7 @@ pub enum ModuleDef { Function(Function), Adt(Adt), // Can't be directly declared, but can be imported. - // FIXME: Rename to `EnumVariant` - Variant(Variant), + EnumVariant(EnumVariant), Const(Const), Static(Static), Trait(Trait), @@ -364,7 +368,7 @@ impl_from!( Module, Function, Adt(Struct, Enum, Union), - Variant, + EnumVariant, Const, Static, Trait, @@ -374,12 +378,12 @@ impl_from!( for ModuleDef ); -impl From<VariantDef> for ModuleDef { - fn from(var: VariantDef) -> Self { +impl From<Variant> for ModuleDef { + fn from(var: Variant) -> Self { match var { - VariantDef::Struct(t) => Adt::from(t).into(), - VariantDef::Union(t) => Adt::from(t).into(), - VariantDef::Variant(t) => t.into(), + Variant::Struct(t) => Adt::from(t).into(), + Variant::Union(t) => Adt::from(t).into(), + Variant::EnumVariant(t) => t.into(), } } } @@ -390,7 +394,7 @@ impl ModuleDef { ModuleDef::Module(it) => it.parent(db), ModuleDef::Function(it) => Some(it.module(db)), ModuleDef::Adt(it) => Some(it.module(db)), - ModuleDef::Variant(it) => Some(it.module(db)), + ModuleDef::EnumVariant(it) => Some(it.module(db)), ModuleDef::Const(it) => Some(it.module(db)), ModuleDef::Static(it) => Some(it.module(db)), ModuleDef::Trait(it) => Some(it.module(db)), @@ -423,7 +427,7 @@ impl ModuleDef { ModuleDef::Adt(it) => it.name(db), ModuleDef::Trait(it) => it.name(db), ModuleDef::Function(it) => it.name(db), - ModuleDef::Variant(it) => it.name(db), + ModuleDef::EnumVariant(it) => it.name(db), ModuleDef::TypeAlias(it) => it.name(db), ModuleDef::Static(it) => it.name(db), ModuleDef::Macro(it) => it.name(db), @@ -452,7 +456,7 @@ impl ModuleDef { ModuleDef::Module(it) => it.id.into(), ModuleDef::Const(it) => it.id.into(), ModuleDef::Static(it) => it.id.into(), - ModuleDef::Variant(it) => it.id.into(), + ModuleDef::EnumVariant(it) => it.id.into(), ModuleDef::BuiltinType(_) | ModuleDef::Macro(_) => return Vec::new(), }; @@ -481,7 +485,7 @@ impl ModuleDef { ModuleDef::Function(it) => Some(it.into()), ModuleDef::Const(it) => Some(it.into()), ModuleDef::Static(it) => Some(it.into()), - ModuleDef::Variant(it) => Some(it.into()), + ModuleDef::EnumVariant(it) => Some(it.into()), ModuleDef::Module(_) | ModuleDef::Adt(_) @@ -500,7 +504,7 @@ impl ModuleDef { ModuleDef::Trait(it) => Some(it.into()), ModuleDef::TypeAlias(it) => Some(it.into()), ModuleDef::Module(_) - | ModuleDef::Variant(_) + | ModuleDef::EnumVariant(_) | ModuleDef::Static(_) | ModuleDef::Const(_) | ModuleDef::BuiltinType(_) @@ -508,12 +512,27 @@ impl ModuleDef { } } + pub fn as_generic_def(self) -> Option<GenericDef> { + match self { + ModuleDef::Function(it) => Some(it.into()), + ModuleDef::Adt(it) => Some(it.into()), + ModuleDef::Trait(it) => Some(it.into()), + ModuleDef::TypeAlias(it) => Some(it.into()), + ModuleDef::Static(it) => Some(it.into()), + ModuleDef::Const(it) => Some(it.into()), + ModuleDef::EnumVariant(_) + | ModuleDef::Module(_) + | ModuleDef::BuiltinType(_) + | ModuleDef::Macro(_) => None, + } + } + pub fn attrs(&self, db: &dyn HirDatabase) -> Option<AttrsWithOwner> { Some(match self { ModuleDef::Module(it) => it.attrs(db), ModuleDef::Function(it) => HasAttrs::attrs(*it, db), ModuleDef::Adt(it) => it.attrs(db), - ModuleDef::Variant(it) => it.attrs(db), + ModuleDef::EnumVariant(it) => it.attrs(db), ModuleDef::Const(it) => it.attrs(db), ModuleDef::Static(it) => it.attrs(db), ModuleDef::Trait(it) => it.attrs(db), @@ -543,7 +562,7 @@ impl HasVisibility for ModuleDef { ModuleDef::Static(it) => it.visibility(db), ModuleDef::Trait(it) => it.visibility(db), ModuleDef::TypeAlias(it) => it.visibility(db), - ModuleDef::Variant(it) => it.visibility(db), + ModuleDef::EnumVariant(it) => it.visibility(db), ModuleDef::Macro(it) => it.visibility(db), ModuleDef::BuiltinType(_) => Visibility::Public, } @@ -714,8 +733,8 @@ impl Module { ModuleDef::Adt(adt) => { match adt { Adt::Struct(s) => { - let source_map = db.struct_signature_with_source_map(s.id).1; - expr_store_diagnostics(db, acc, &source_map); + let source_map = &StructSignature::with_source_map(db, s.id).1; + expr_store_diagnostics(db, acc, source_map); let source_map = &s.id.fields_with_source_map(db).1; expr_store_diagnostics(db, acc, source_map); push_ty_diagnostics( @@ -726,8 +745,8 @@ impl Module { ); } Adt::Union(u) => { - let source_map = db.union_signature_with_source_map(u.id).1; - expr_store_diagnostics(db, acc, &source_map); + let source_map = &UnionSignature::with_source_map(db, u.id).1; + expr_store_diagnostics(db, acc, source_map); let source_map = &u.id.fields_with_source_map(db).1; expr_store_diagnostics(db, acc, source_map); push_ty_diagnostics( @@ -738,8 +757,8 @@ impl Module { ); } Adt::Enum(e) => { - let source_map = db.enum_signature_with_source_map(e.id).1; - expr_store_diagnostics(db, acc, &source_map); + let source_map = &EnumSignature::with_source_map(db, e.id).1; + expr_store_diagnostics(db, acc, source_map); let (variants, diagnostics) = e.id.enum_variants_with_diagnostics(db); let file = e.id.lookup(db).id.file_id; let ast_id_map = db.ast_id_map(file); @@ -774,13 +793,13 @@ impl Module { } ModuleDef::Macro(m) => emit_macro_def_diagnostics(db, acc, m), ModuleDef::TypeAlias(type_alias) => { - let source_map = db.type_alias_signature_with_source_map(type_alias.id).1; - expr_store_diagnostics(db, acc, &source_map); + let source_map = &TypeAliasSignature::with_source_map(db, type_alias.id).1; + expr_store_diagnostics(db, acc, source_map); push_ty_diagnostics( db, acc, db.type_for_type_alias_with_diagnostics(type_alias.id).1, - &source_map, + source_map, ); acc.extend(def.diagnostics(db, style_lints)); } @@ -800,8 +819,8 @@ impl Module { continue; }; let loc = impl_id.lookup(db); - let (impl_signature, source_map) = db.impl_signature_with_source_map(impl_id); - expr_store_diagnostics(db, acc, &source_map); + let (impl_signature, source_map) = ImplSignature::with_source_map(db, impl_id); + expr_store_diagnostics(db, acc, source_map); let file_id = loc.id.file_id; if file_id.macro_file().is_some_and(|it| it.kind(db) == MacroKind::DeriveBuiltIn) { @@ -873,9 +892,9 @@ impl Module { if let (false, Some(trait_)) = (impl_is_negative, trait_) { let items = &trait_.id.trait_items(db).items; let required_items = items.iter().filter(|&(_, assoc)| match *assoc { - AssocItemId::FunctionId(it) => !db.function_signature(it).has_body(), - AssocItemId::ConstId(id) => !db.const_signature(id).has_body(), - AssocItemId::TypeAliasId(it) => db.type_alias_signature(it).ty.is_none(), + AssocItemId::FunctionId(it) => !FunctionSignature::of(db, it).has_body(), + AssocItemId::ConstId(id) => !ConstSignature::of(db, id).has_body(), + AssocItemId::TypeAliasId(it) => TypeAliasSignature::of(db, it).ty.is_none(), }); impl_assoc_items_scratch.extend(impl_id.impl_items(db).items.iter().cloned()); @@ -913,7 +932,7 @@ impl Module { let self_ty = structurally_normalize_ty( &infcx, self_ty, - db.trait_environment(impl_id.into()), + db.trait_environment(GenericDefId::from(impl_id).into()), ); let self_ty_is_guaranteed_unsized = matches!( self_ty.kind(), @@ -968,7 +987,7 @@ impl Module { continue; } - if db.function_signature(*fn_).is_default() { + if FunctionSignature::of(db, *fn_).is_default() { return false; } } @@ -992,12 +1011,12 @@ impl Module { impl_assoc_items_scratch.clear(); } - push_ty_diagnostics(db, acc, db.impl_self_ty_with_diagnostics(impl_id).1, &source_map); + push_ty_diagnostics(db, acc, db.impl_self_ty_with_diagnostics(impl_id).1, source_map); push_ty_diagnostics( db, acc, db.impl_trait_with_diagnostics(impl_id).and_then(|it| it.1), - &source_map, + source_map, ); for &(_, item) in impl_id.impl_items(db).items.iter() { @@ -1091,7 +1110,7 @@ fn macro_call_diagnostics<'db>( let file_id = loc.kind.file_id(); let mut range = precise_macro_call_location(&loc.kind, db, loc.krate); let RenderedExpandError { message, error, kind } = err.render_to_string(db); - if Some(err.span().anchor.file_id) == file_id.file_id().map(|it| it.editioned_file_id(db)) { + if Some(err.span().anchor.file_id) == file_id.file_id().map(|it| it.span_file_id(db)) { range.value = err.span().range + db.ast_id_map(file_id).get_erased(err.span().anchor.ast_id).text_range().start(); } @@ -1280,7 +1299,7 @@ impl HasVisibility for Module { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct Field { - pub(crate) parent: VariantDef, + pub(crate) parent: Variant, pub(crate) id: LocalFieldId, } @@ -1304,7 +1323,7 @@ impl<'db> InstantiatedField<'db> { #[derive(Debug, PartialEq, Eq, Copy, Clone, Hash)] pub struct TupleField { - pub owner: DefWithBodyId, + pub owner: ExpressionStoreOwnerId, pub tuple: TupleId, pub index: u32, } @@ -1316,7 +1335,7 @@ impl TupleField { pub fn ty<'db>(&self, db: &'db dyn HirDatabase) -> Type<'db> { let interner = DbInterner::new_no_crate(db); - let ty = InferenceResult::for_body(db, self.owner) + let ty = InferenceResult::of(db, self.owner) .tuple_field_access_type(self.tuple) .as_slice() .get(self.index as usize) @@ -1386,9 +1405,9 @@ impl Field { ) -> Type<'db> { let var_id = self.parent.into(); let def_id: AdtId = match self.parent { - VariantDef::Struct(it) => it.id.into(), - VariantDef::Union(it) => it.id.into(), - VariantDef::Variant(it) => it.parent_enum(db).id.into(), + Variant::Struct(it) => it.id.into(), + Variant::Union(it) => it.id.into(), + Variant::EnumVariant(it) => it.parent_enum(db).id.into(), }; let interner = DbInterner::new_no_crate(db); let args = generic_args_from_tys(interner, def_id.into(), generics.map(|ty| ty.ty)); @@ -1414,7 +1433,7 @@ impl Field { .map(|layout| Layout(layout, db.target_data_layout(self.krate(db).into()).unwrap())) } - pub fn parent_def(&self, _db: &dyn HirDatabase) -> VariantDef { + pub fn parent_def(&self, _db: &dyn HirDatabase) -> Variant { self.parent } } @@ -1440,7 +1459,7 @@ impl Struct { } pub fn name(self, db: &dyn HirDatabase) -> Name { - db.struct_signature(self.id).name.clone() + StructSignature::of(db, self.id).name.clone() } pub fn fields(self, db: &dyn HirDatabase) -> Vec<Field> { @@ -1533,7 +1552,7 @@ pub struct Union { impl Union { pub fn name(self, db: &dyn HirDatabase) -> Name { - db.union_signature(self.id).name.clone() + UnionSignature::of(db, self.id).name.clone() } pub fn module(self, db: &dyn HirDatabase) -> Module { @@ -1592,11 +1611,11 @@ impl Enum { } pub fn name(self, db: &dyn HirDatabase) -> Name { - db.enum_signature(self.id).name.clone() + EnumSignature::of(db, self.id).name.clone() } - pub fn variants(self, db: &dyn HirDatabase) -> Vec<Variant> { - self.id.enum_variants(db).variants.iter().map(|&(id, _, _)| Variant { id }).collect() + pub fn variants(self, db: &dyn HirDatabase) -> Vec<EnumVariant> { + self.id.enum_variants(db).variants.iter().map(|&(id, _, _)| EnumVariant { id }).collect() } pub fn num_variants(self, db: &dyn HirDatabase) -> usize { @@ -1688,19 +1707,18 @@ impl<'db> InstantiatedEnum<'db> { } } -impl From<&Variant> for DefWithBodyId { - fn from(&v: &Variant) -> Self { +impl From<&EnumVariant> for DefWithBodyId { + fn from(&v: &EnumVariant) -> Self { DefWithBodyId::VariantId(v.into()) } } -// FIXME: Rename to `EnumVariant` #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct Variant { +pub struct EnumVariant { pub(crate) id: EnumVariantId, } -impl Variant { +impl EnumVariant { pub fn module(self, db: &dyn HirDatabase) -> Module { Module { id: self.id.module(db) } } @@ -1737,7 +1755,7 @@ impl Variant { } pub fn value(self, db: &dyn HirDatabase) -> Option<ast::Expr> { - self.source(db)?.value.expr() + self.source(db)?.value.const_arg()?.expr() } pub fn eval(self, db: &dyn HirDatabase) -> Result<i128, ConstEvalError> { @@ -1774,7 +1792,7 @@ impl Variant { // FIXME: Rename to `EnumVariant` #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct InstantiatedVariant<'db> { - pub(crate) inner: Variant, + pub(crate) inner: EnumVariant, pub(crate) args: GenericArgs<'db>, } @@ -1805,7 +1823,7 @@ pub enum StructKind { } /// Variants inherit visibility from the parent enum. -impl HasVisibility for Variant { +impl HasVisibility for EnumVariant { fn visibility(&self, db: &dyn HirDatabase) -> Visibility { self.parent_enum(db).visibility(db) } @@ -1914,35 +1932,78 @@ impl HasVisibility for Adt { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub enum VariantDef { +pub enum Variant { Struct(Struct), Union(Union), - Variant(Variant), + EnumVariant(EnumVariant), } -impl_from!(Struct, Union, Variant for VariantDef); +impl_from!(Struct, Union, EnumVariant for Variant); -impl VariantDef { +impl Variant { pub fn fields(self, db: &dyn HirDatabase) -> Vec<Field> { match self { - VariantDef::Struct(it) => it.fields(db), - VariantDef::Union(it) => it.fields(db), - VariantDef::Variant(it) => it.fields(db), + Variant::Struct(it) => it.fields(db), + Variant::Union(it) => it.fields(db), + Variant::EnumVariant(it) => it.fields(db), } } pub fn module(self, db: &dyn HirDatabase) -> Module { match self { - VariantDef::Struct(it) => it.module(db), - VariantDef::Union(it) => it.module(db), - VariantDef::Variant(it) => it.module(db), + Variant::Struct(it) => it.module(db), + Variant::Union(it) => it.module(db), + Variant::EnumVariant(it) => it.module(db), } } pub fn name(&self, db: &dyn HirDatabase) -> Name { match self { - VariantDef::Struct(s) => (*s).name(db), - VariantDef::Union(u) => (*u).name(db), - VariantDef::Variant(e) => (*e).name(db), + Variant::Struct(s) => (*s).name(db), + Variant::Union(u) => (*u).name(db), + Variant::EnumVariant(e) => (*e).name(db), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ExpressionStoreOwner { + Body(DefWithBody), + Signature(GenericDef), + VariantFields(Variant), +} + +impl From<GenericDef> for ExpressionStoreOwner { + fn from(v: GenericDef) -> Self { + Self::Signature(v) + } +} + +impl From<DefWithBody> for ExpressionStoreOwner { + fn from(v: DefWithBody) -> Self { + Self::Body(v) + } +} + +impl From<ExpressionStoreOwnerId> for ExpressionStoreOwner { + fn from(v: ExpressionStoreOwnerId) -> Self { + match v { + ExpressionStoreOwnerId::Signature(generic_def_id) => { + Self::Signature(generic_def_id.into()) + } + ExpressionStoreOwnerId::Body(def_with_body_id) => Self::Body(def_with_body_id.into()), + ExpressionStoreOwnerId::VariantFields(variant_id) => { + Self::VariantFields(variant_id.into()) + } + } + } +} + +impl ExpressionStoreOwner { + pub fn module(self, db: &dyn HirDatabase) -> Module { + match self { + Self::Body(body) => body.module(db), + Self::Signature(generic_def) => generic_def.module(db), + Self::VariantFields(variant) => variant.module(db), } } } @@ -1953,9 +2014,9 @@ pub enum DefWithBody { Function(Function), Static(Static), Const(Const), - Variant(Variant), + EnumVariant(EnumVariant), } -impl_from!(Function, Const, Static, Variant for DefWithBody); +impl_from!(Function, Const, Static, EnumVariant for DefWithBody); impl DefWithBody { pub fn module(self, db: &dyn HirDatabase) -> Module { @@ -1963,7 +2024,7 @@ impl DefWithBody { DefWithBody::Const(c) => c.module(db), DefWithBody::Function(f) => f.module(db), DefWithBody::Static(s) => s.module(db), - DefWithBody::Variant(v) => v.module(db), + DefWithBody::EnumVariant(v) => v.module(db), } } @@ -1972,7 +2033,7 @@ impl DefWithBody { DefWithBody::Function(f) => Some(f.name(db)), DefWithBody::Static(s) => Some(s.name(db)), DefWithBody::Const(c) => c.name(db), - DefWithBody::Variant(v) => Some(v.name(db)), + DefWithBody::EnumVariant(v) => Some(v.name(db)), } } @@ -1982,7 +2043,7 @@ impl DefWithBody { DefWithBody::Function(it) => it.ret_type(db), DefWithBody::Static(it) => it.ty(db), DefWithBody::Const(it) => it.ty(db), - DefWithBody::Variant(it) => it.parent_enum(db).variant_body_ty(db), + DefWithBody::EnumVariant(it) => it.parent_enum(db).variant_body_ty(db), } } @@ -1994,7 +2055,7 @@ impl DefWithBody { }, DefWithBody::Static(it) => it.id.into(), DefWithBody::Const(it) => it.id.into(), - DefWithBody::Variant(it) => it.into(), + DefWithBody::EnumVariant(it) => it.into(), }) } @@ -2003,7 +2064,7 @@ impl DefWithBody { let Some(id) = self.id() else { return String::new(); }; - let body = db.body(id); + let body = Body::of(db, id); body.pretty_print(db, id, Edition::CURRENT) } @@ -2030,17 +2091,17 @@ impl DefWithBody { }; let krate = self.module(db).id.krate(db); - let (body, source_map) = db.body_with_source_map(id); + let (body, source_map) = Body::with_source_map(db, id); let sig_source_map = match self { DefWithBody::Function(id) => match id.id { - AnyFunctionId::FunctionId(id) => db.function_signature_with_source_map(id).1, + AnyFunctionId::FunctionId(id) => &FunctionSignature::with_source_map(db, id).1, AnyFunctionId::BuiltinDeriveImplMethod { .. } => return, }, - DefWithBody::Static(id) => db.static_signature_with_source_map(id.into()).1, - DefWithBody::Const(id) => db.const_signature_with_source_map(id.into()).1, - DefWithBody::Variant(variant) => { + DefWithBody::Static(id) => &StaticSignature::with_source_map(db, id.into()).1, + DefWithBody::Const(id) => &ConstSignature::with_source_map(db, id.into()).1, + DefWithBody::EnumVariant(variant) => { let enum_id = variant.parent_enum(db).id; - db.enum_signature_with_source_map(enum_id).1 + &EnumSignature::with_source_map(db, enum_id).1 } }; @@ -2048,17 +2109,11 @@ impl DefWithBody { Module { id: def_map.root_module_id() }.diagnostics(db, acc, style_lints); } - expr_store_diagnostics(db, acc, &source_map); + expr_store_diagnostics(db, acc, source_map); - let infer = InferenceResult::for_body(db, id); + let infer = InferenceResult::of(db, id); for d in infer.diagnostics() { - acc.extend(AnyDiagnostic::inference_diagnostic( - db, - id, - d, - &source_map, - &sig_source_map, - )); + acc.extend(AnyDiagnostic::inference_diagnostic(db, id, d, source_map, sig_source_map)); } for (pat_or_expr, mismatch) in infer.type_mismatches() { @@ -2180,7 +2235,7 @@ impl DefWithBody { { need_mut = &mir::MutabilityReason::Not; } - let local = Local { parent: id, binding_id }; + let local = Local { parent: id.into(), binding_id }; let is_mut = body[binding_id].mode == BindingAnnotation::Mutable; match (need_mut, is_mut) { @@ -2237,13 +2292,46 @@ impl DefWithBody { } for diagnostic in BodyValidationDiagnostic::collect(db, id, style_lints) { - acc.extend(AnyDiagnostic::body_validation_diagnostic(db, diagnostic, &source_map)); + acc.extend(AnyDiagnostic::body_validation_diagnostic(db, diagnostic, source_map)); } for diag in hir_ty::diagnostics::incorrect_case(db, id.into()) { acc.push(diag.into()) } } + + /// Returns an iterator over the inferred types of all expressions in this body. + pub fn expression_types<'db>( + self, + db: &'db dyn HirDatabase, + ) -> impl Iterator<Item = Type<'db>> { + self.id().into_iter().flat_map(move |def_id| { + let infer = InferenceResult::of(db, def_id); + let resolver = def_id.resolver(db); + + infer.expression_types().map(move |(_, ty)| Type::new_with_resolver(db, &resolver, ty)) + }) + } + + /// Returns an iterator over the inferred types of all patterns in this body. + pub fn pattern_types<'db>(self, db: &'db dyn HirDatabase) -> impl Iterator<Item = Type<'db>> { + self.id().into_iter().flat_map(move |def_id| { + let infer = InferenceResult::of(db, def_id); + let resolver = def_id.resolver(db); + + infer.pattern_types().map(move |(_, ty)| Type::new_with_resolver(db, &resolver, ty)) + }) + } + + /// Returns an iterator over the inferred types of all bindings in this body. + pub fn binding_types<'db>(self, db: &'db dyn HirDatabase) -> impl Iterator<Item = Type<'db>> { + self.id().into_iter().flat_map(move |def_id| { + let infer = InferenceResult::of(db, def_id); + let resolver = def_id.resolver(db); + + infer.binding_types().map(move |(_, ty)| Type::new_with_resolver(db, &resolver, ty)) + }) + } } fn expr_store_diagnostics<'db>( @@ -2306,7 +2394,7 @@ impl Function { pub fn name(self, db: &dyn HirDatabase) -> Name { match self.id { - AnyFunctionId::FunctionId(id) => db.function_signature(id).name.clone(), + AnyFunctionId::FunctionId(id) => FunctionSignature::of(db, id).name.clone(), AnyFunctionId::BuiltinDeriveImplMethod { method, .. } => { Name::new_symbol_root(method.name()) } @@ -2508,7 +2596,7 @@ impl Function { pub fn has_self_param(self, db: &dyn HirDatabase) -> bool { match self.id { - AnyFunctionId::FunctionId(id) => db.function_signature(id).has_self_param(), + AnyFunctionId::FunctionId(id) => FunctionSignature::of(db, id).has_self_param(), AnyFunctionId::BuiltinDeriveImplMethod { method, .. } => match method { BuiltinDeriveImplMethod::clone | BuiltinDeriveImplMethod::fmt @@ -2543,7 +2631,7 @@ impl Function { pub fn num_params(self, db: &dyn HirDatabase) -> usize { match self.id { - AnyFunctionId::FunctionId(id) => db.function_signature(id).params.len(), + AnyFunctionId::FunctionId(id) => FunctionSignature::of(db, id).params.len(), AnyFunctionId::BuiltinDeriveImplMethod { .. } => { self.fn_sig(db).1.skip_binder().inputs().len() } @@ -2587,21 +2675,21 @@ impl Function { pub fn is_const(self, db: &dyn HirDatabase) -> bool { match self.id { - AnyFunctionId::FunctionId(id) => db.function_signature(id).is_const(), + AnyFunctionId::FunctionId(id) => FunctionSignature::of(db, id).is_const(), AnyFunctionId::BuiltinDeriveImplMethod { .. } => false, } } pub fn is_async(self, db: &dyn HirDatabase) -> bool { match self.id { - AnyFunctionId::FunctionId(id) => db.function_signature(id).is_async(), + AnyFunctionId::FunctionId(id) => FunctionSignature::of(db, id).is_async(), AnyFunctionId::BuiltinDeriveImplMethod { .. } => false, } } pub fn is_varargs(self, db: &dyn HirDatabase) -> bool { match self.id { - AnyFunctionId::FunctionId(id) => db.function_signature(id).is_varargs(), + AnyFunctionId::FunctionId(id) => FunctionSignature::of(db, id).is_varargs(), AnyFunctionId::BuiltinDeriveImplMethod { .. } => false, } } @@ -2654,7 +2742,7 @@ impl Function { AnyFunctionId::FunctionId(id) => { self.exported_main(db) || self.module(db).is_crate_root(db) - && db.function_signature(id).name == sym::main + && FunctionSignature::of(db, id).name == sym::main } AnyFunctionId::BuiltinDeriveImplMethod { .. } => false, } @@ -2731,7 +2819,7 @@ impl Function { /// This is false in the case of required (not provided) trait methods. pub fn has_body(self, db: &dyn HirDatabase) -> bool { match self.id { - AnyFunctionId::FunctionId(id) => db.function_signature(id).has_body(), + AnyFunctionId::FunctionId(id) => FunctionSignature::of(db, id).has_body(), AnyFunctionId::BuiltinDeriveImplMethod { .. } => true, } } @@ -2759,7 +2847,7 @@ impl Function { id.into(), GenericArgs::empty(interner).store(), ParamEnvAndCrate { - param_env: db.trait_environment(id.into()), + param_env: db.trait_environment(GenericDefId::from(id).into()), krate: id.module(db).krate(db), } .store(), @@ -2845,24 +2933,26 @@ impl<'db> Param<'db> { match self.func { Callee::Def(CallableDefId::FunctionId(it)) => { let parent = DefWithBodyId::FunctionId(it); - let body = db.body(parent); + let body = Body::of(db, parent); if let Some(self_param) = body.self_param.filter(|_| self.idx == 0) { - Some(Local { parent, binding_id: self_param }) + Some(Local { parent: parent.into(), binding_id: self_param }) } else if let Pat::Bind { id, .. } = &body[body.params[self.idx - body.self_param.is_some() as usize]] { - Some(Local { parent, binding_id: *id }) + Some(Local { parent: parent.into(), binding_id: *id }) } else { None } } Callee::Closure(closure, _) => { let c = db.lookup_intern_closure(closure); - let body = db.body(c.0); - if let Expr::Closure { args, .. } = &body[c.1] - && let Pat::Bind { id, .. } = &body[args[self.idx]] + let body_owner = c.0; + let store = ExpressionStore::of(db, c.0); + + if let Expr::Closure { args, .. } = &store[c.1] + && let Pat::Bind { id, .. } = &store[args[self.idx]] { - return Some(Local { parent: c.0, binding_id: *id }); + return Some(Local { parent: body_owner, binding_id: *id }); } None } @@ -2884,7 +2974,7 @@ impl SelfParam { pub fn access(self, db: &dyn HirDatabase) -> Access { match self.func.id { AnyFunctionId::FunctionId(id) => { - let func_data = db.function_signature(id); + let func_data = FunctionSignature::of(db, id); func_data .params .first() @@ -3013,7 +3103,7 @@ impl Const { } pub fn name(self, db: &dyn HirDatabase) -> Option<Name> { - db.const_signature(self.id).name.clone() + ConstSignature::of(db, self.id).name.clone() } pub fn value(self, db: &dyn HirDatabase) -> Option<ast::Expr> { @@ -3086,11 +3176,11 @@ impl Static { } pub fn name(self, db: &dyn HirDatabase) -> Name { - db.static_signature(self.id).name.clone() + StaticSignature::of(db, self.id).name.clone() } pub fn is_mut(self, db: &dyn HirDatabase) -> bool { - db.static_signature(self.id).flags.contains(StaticFlags::MUTABLE) + StaticSignature::of(db, self.id).flags.contains(StaticFlags::MUTABLE) } pub fn value(self, db: &dyn HirDatabase) -> Option<ast::Expr> { @@ -3146,7 +3236,7 @@ impl Trait { } pub fn name(self, db: &dyn HirDatabase) -> Name { - db.trait_signature(self.id).name.clone() + TraitSignature::of(db, self.id).name.clone() } pub fn direct_supertraits(self, db: &dyn HirDatabase) -> Vec<Trait> { @@ -3176,11 +3266,11 @@ impl Trait { } pub fn is_auto(self, db: &dyn HirDatabase) -> bool { - db.trait_signature(self.id).flags.contains(TraitFlags::AUTO) + TraitSignature::of(db, self.id).flags.contains(TraitFlags::AUTO) } pub fn is_unsafe(&self, db: &dyn HirDatabase) -> bool { - db.trait_signature(self.id).flags.contains(TraitFlags::UNSAFE) + TraitSignature::of(db, self.id).flags.contains(TraitFlags::UNSAFE) } pub fn type_or_const_param_count( @@ -3188,7 +3278,7 @@ impl Trait { db: &dyn HirDatabase, count_required_only: bool, ) -> usize { - db.generic_params(self.id.into()) + GenericParams::of(db,self.id.into()) .iter_type_or_consts() .filter(|(_, ty)| !matches!(ty, TypeOrConstParamData::TypeParamData(ty) if ty.provenance != TypeParamProvenance::TypeParamList)) .filter(|(_, ty)| !count_required_only || !ty.has_default()) @@ -3256,7 +3346,7 @@ impl TypeAlias { } pub fn name(self, db: &dyn HirDatabase) -> Name { - db.type_alias_signature(self.id).name.clone() + TypeAliasSignature::of(db, self.id).name.clone() } } @@ -3698,7 +3788,18 @@ impl AsAssocItem for DefWithBody { match self { DefWithBody::Function(it) => it.as_assoc_item(db), DefWithBody::Const(it) => it.as_assoc_item(db), - DefWithBody::Static(_) | DefWithBody::Variant(_) => None, + DefWithBody::Static(_) | DefWithBody::EnumVariant(_) => None, + } + } +} + +impl AsAssocItem for GenericDef { + fn as_assoc_item(self, db: &dyn HirDatabase) -> Option<AssocItem> { + match self { + GenericDef::Function(it) => it.as_assoc_item(db), + GenericDef::Const(it) => it.as_assoc_item(db), + GenericDef::TypeAlias(it) => it.as_assoc_item(db), + _ => None, } } } @@ -3885,7 +3986,7 @@ impl AssocItem { db, acc, db.type_for_type_alias_with_diagnostics(type_alias.id).1, - &db.type_alias_signature_with_source_map(type_alias.id).1, + &TypeAliasSignature::with_source_map(db, type_alias.id).1, ); for diag in hir_ty::diagnostics::incorrect_case(db, type_alias.id.into()) { acc.push(diag.into()); @@ -3938,12 +4039,36 @@ impl_from!( ); impl GenericDef { + pub fn name(self, db: &dyn HirDatabase) -> Option<Name> { + match self { + GenericDef::Function(it) => Some(it.name(db)), + GenericDef::Adt(it) => Some(it.name(db)), + GenericDef::Trait(it) => Some(it.name(db)), + GenericDef::TypeAlias(it) => Some(it.name(db)), + GenericDef::Impl(_) => None, + GenericDef::Const(it) => it.name(db), + GenericDef::Static(it) => Some(it.name(db)), + } + } + + pub fn module(self, db: &dyn HirDatabase) -> Module { + match self { + GenericDef::Function(it) => it.module(db), + GenericDef::Adt(it) => it.module(db), + GenericDef::Trait(it) => it.module(db), + GenericDef::TypeAlias(it) => it.module(db), + GenericDef::Impl(it) => it.module(db), + GenericDef::Const(it) => it.module(db), + GenericDef::Static(it) => it.module(db), + } + } + pub fn params(self, db: &dyn HirDatabase) -> Vec<GenericParam> { let Ok(id) = self.try_into() else { // Let's pretend builtin derive impls don't have generic parameters. return Vec::new(); }; - let generics = db.generic_params(id); + let generics = GenericParams::of(db, id); let ty_params = generics.iter_type_or_consts().map(|(local_id, _)| { let toc = TypeOrConstParam { id: TypeOrConstParamId { parent: id, local_id } }; match toc.split(db) { @@ -3963,7 +4088,7 @@ impl GenericDef { // Let's pretend builtin derive impls don't have generic parameters. return Vec::new(); }; - let generics = db.generic_params(id); + let generics = GenericParams::of(db, id); generics .iter_lt() .map(|(local_id, _)| LifetimeParam { id: LifetimeParamId { parent: id, local_id } }) @@ -3975,7 +4100,7 @@ impl GenericDef { // Let's pretend builtin derive impls don't have generic parameters. return Vec::new(); }; - let generics = db.generic_params(id); + let generics = GenericParams::of(db, id); generics .iter_type_or_consts() .map(|(local_id, _)| TypeOrConstParam { @@ -4005,31 +4130,31 @@ impl GenericDef { pub fn diagnostics<'db>(self, db: &'db dyn HirDatabase, acc: &mut Vec<AnyDiagnostic<'db>>) { let Some(def) = self.id() else { return }; - let generics = db.generic_params(def); + let generics = GenericParams::of(db, def); if generics.is_empty() && generics.has_no_predicates() { return; } let source_map = match def { - GenericDefId::AdtId(AdtId::EnumId(it)) => db.enum_signature_with_source_map(it).1, - GenericDefId::AdtId(AdtId::StructId(it)) => db.struct_signature_with_source_map(it).1, - GenericDefId::AdtId(AdtId::UnionId(it)) => db.union_signature_with_source_map(it).1, + GenericDefId::AdtId(AdtId::EnumId(it)) => &EnumSignature::with_source_map(db, it).1, + GenericDefId::AdtId(AdtId::StructId(it)) => &StructSignature::with_source_map(db, it).1, + GenericDefId::AdtId(AdtId::UnionId(it)) => &UnionSignature::with_source_map(db, it).1, GenericDefId::ConstId(_) => return, - GenericDefId::FunctionId(it) => db.function_signature_with_source_map(it).1, - GenericDefId::ImplId(it) => db.impl_signature_with_source_map(it).1, + GenericDefId::FunctionId(it) => &FunctionSignature::with_source_map(db, it).1, + GenericDefId::ImplId(it) => &ImplSignature::with_source_map(db, it).1, GenericDefId::StaticId(_) => return, - GenericDefId::TraitId(it) => db.trait_signature_with_source_map(it).1, - GenericDefId::TypeAliasId(it) => db.type_alias_signature_with_source_map(it).1, + GenericDefId::TraitId(it) => &TraitSignature::with_source_map(db, it).1, + GenericDefId::TypeAliasId(it) => &TypeAliasSignature::with_source_map(db, it).1, }; - expr_store_diagnostics(db, acc, &source_map); - push_ty_diagnostics(db, acc, db.generic_defaults_with_diagnostics(def).1, &source_map); + expr_store_diagnostics(db, acc, source_map); + push_ty_diagnostics(db, acc, db.generic_defaults_with_diagnostics(def).1, source_map); push_ty_diagnostics( db, acc, GenericPredicates::query_with_diagnostics(db, def).1.clone(), - &source_map, + source_map, ); for (param_id, param) in generics.iter_type_or_consts() { if let TypeOrConstParamData::ConstParamData(_) = param { @@ -4040,7 +4165,7 @@ impl GenericDef { TypeOrConstParamId { parent: def, local_id: param_id }, )) .1, - &source_map, + source_map, ); } } @@ -4101,7 +4226,7 @@ impl<'db> GenericSubstitution<'db> { _ => None, }) .map(|container| { - db.generic_params(container) + GenericParams::of(db, container) .iter_type_or_consts() .filter_map(|param| match param.1 { TypeOrConstParamData::TypeParamData(param) => Some(param.name.clone()), @@ -4109,7 +4234,7 @@ impl<'db> GenericSubstitution<'db> { }) .collect::<Vec<_>>() }); - let generics = db.generic_params(self.def); + let generics = GenericParams::of(db, self.def); let type_params = generics.iter_type_or_consts().filter_map(|param| match param.1 { TypeOrConstParamData::TypeParamData(param) => Some(param.name.clone()), TypeOrConstParamData::ConstParamData(_) => None, @@ -4137,7 +4262,7 @@ impl<'db> GenericSubstitution<'db> { /// A single local definition. #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct Local { - pub(crate) parent: DefWithBodyId, + pub(crate) parent: ExpressionStoreOwnerId, pub(crate) binding_id: BindingId, } @@ -4199,7 +4324,7 @@ impl Local { pub fn as_self_param(self, db: &dyn HirDatabase) -> Option<SelfParam> { match self.parent { - DefWithBodyId::FunctionId(func) if self.is_self(db) => { + ExpressionStoreOwnerId::Body(DefWithBodyId::FunctionId(func)) if self.is_self(db) => { Some(SelfParam { func: func.into() }) } _ => None, @@ -4207,8 +4332,7 @@ impl Local { } pub fn name(self, db: &dyn HirDatabase) -> Name { - let body = db.body(self.parent); - body[self.binding_id].name.clone() + ExpressionStore::of(db, self.parent)[self.binding_id].name.clone() } pub fn is_self(self, db: &dyn HirDatabase) -> bool { @@ -4216,16 +4340,17 @@ impl Local { } pub fn is_mut(self, db: &dyn HirDatabase) -> bool { - let body = db.body(self.parent); - body[self.binding_id].mode == BindingAnnotation::Mutable + ExpressionStore::of(db, self.parent)[self.binding_id].mode == BindingAnnotation::Mutable } pub fn is_ref(self, db: &dyn HirDatabase) -> bool { - let body = db.body(self.parent); - matches!(body[self.binding_id].mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) + matches!( + ExpressionStore::of(db, self.parent)[self.binding_id].mode, + BindingAnnotation::Ref | BindingAnnotation::RefMut + ) } - pub fn parent(self, _db: &dyn HirDatabase) -> DefWithBody { + pub fn parent(self, _db: &dyn HirDatabase) -> ExpressionStoreOwner { self.parent.into() } @@ -4233,69 +4358,97 @@ impl Local { self.parent(db).module(db) } + pub fn as_id(self) -> u32 { + self.binding_id.into_raw().into_u32() + } + pub fn ty(self, db: &dyn HirDatabase) -> Type<'_> { let def = self.parent; - let infer = InferenceResult::for_body(db, def); + let infer = InferenceResult::of(db, def); let ty = infer.binding_ty(self.binding_id); Type::new(db, def, ty) } /// All definitions for this local. Example: `let (a$0, _) | (_, a$0) = it;` pub fn sources(self, db: &dyn HirDatabase) -> Vec<LocalSource> { - let (body, source_map) = db.body_with_source_map(self.parent); - match body.self_param.zip(source_map.self_param_syntax()) { - Some((param, source)) if param == self.binding_id => { - let root = source.file_syntax(db); - vec![LocalSource { - local: self, - source: source.map(|ast| Either::Right(ast.to_node(&root))), - }] + let b; + let (_, source_map) = match self.parent { + ExpressionStoreOwnerId::Signature(generic_def_id) => { + ExpressionStore::with_source_map(db, generic_def_id.into()) } - _ => source_map - .patterns_for_binding(self.binding_id) - .iter() - .map(|&definition| { - let src = source_map.pat_syntax(definition).unwrap(); // Hmm... - let root = src.file_syntax(db); - LocalSource { + ExpressionStoreOwnerId::Body(def_with_body_id) => { + b = Body::with_source_map(db, def_with_body_id); + if let Some((param, source)) = b.0.self_param.zip(b.1.self_param_syntax()) + && param == self.binding_id + { + let root = source.file_syntax(db); + return vec![LocalSource { local: self, - source: src.map(|ast| match ast.to_node(&root) { - Either::Right(ast::Pat::IdentPat(it)) => Either::Left(it), - _ => unreachable!("local with non ident-pattern"), - }), - } - }) - .collect(), - } + source: source.map(|ast| Either::Right(ast.to_node(&root))), + }]; + } + (&b.0.store, &b.1.store) + } + ExpressionStoreOwnerId::VariantFields(def) => { + ExpressionStore::with_source_map(db, def.into()) + } + }; + source_map + .patterns_for_binding(self.binding_id) + .iter() + .map(|&definition| { + let src = source_map.pat_syntax(definition).unwrap(); // Hmm... + let root = src.file_syntax(db); + LocalSource { + local: self, + source: src.map(|ast| match ast.to_node(&root) { + Either::Right(ast::Pat::IdentPat(it)) => Either::Left(it), + _ => unreachable!("local with non ident-pattern"), + }), + } + }) + .collect() } /// The leftmost definition for this local. Example: `let (a$0, _) | (_, a) = it;` pub fn primary_source(self, db: &dyn HirDatabase) -> LocalSource { - let (body, source_map) = db.body_with_source_map(self.parent); - match body.self_param.zip(source_map.self_param_syntax()) { - Some((param, source)) if param == self.binding_id => { - let root = source.file_syntax(db); + let b; + let (_, source_map) = match self.parent { + ExpressionStoreOwnerId::Signature(generic_def_id) => { + ExpressionStore::with_source_map(db, generic_def_id.into()) + } + ExpressionStoreOwnerId::Body(def_with_body_id) => { + b = Body::with_source_map(db, def_with_body_id); + if let Some((param, source)) = b.0.self_param.zip(b.1.self_param_syntax()) + && param == self.binding_id + { + let root = source.file_syntax(db); + return LocalSource { + local: self, + source: source.map(|ast| Either::Right(ast.to_node(&root))), + }; + } + (&b.0.store, &b.1.store) + } + ExpressionStoreOwnerId::VariantFields(def) => { + ExpressionStore::with_source_map(db, def.into()) + } + }; + source_map + .patterns_for_binding(self.binding_id) + .first() + .map(|&definition| { + let src = source_map.pat_syntax(definition).unwrap(); // Hmm... + let root = src.file_syntax(db); LocalSource { local: self, - source: source.map(|ast| Either::Right(ast.to_node(&root))), + source: src.map(|ast| match ast.to_node(&root) { + Either::Right(ast::Pat::IdentPat(it)) => Either::Left(it), + _ => unreachable!("local with non ident-pattern"), + }), } - } - _ => source_map - .patterns_for_binding(self.binding_id) - .first() - .map(|&definition| { - let src = source_map.pat_syntax(definition).unwrap(); // Hmm... - let root = src.file_syntax(db); - LocalSource { - local: self, - source: src.map(|ast| match ast.to_node(&root) { - Either::Right(ast::Pat::IdentPat(it)) => Either::Left(it), - _ => unreachable!("local with non ident-pattern"), - }), - } - }) - .unwrap(), - } + }) + .unwrap() } } @@ -4380,7 +4533,7 @@ impl ToolModule { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct Label { - pub(crate) parent: DefWithBodyId, + pub(crate) parent: ExpressionStoreOwnerId, pub(crate) label_id: LabelId, } @@ -4389,13 +4542,12 @@ impl Label { self.parent(db).module(db) } - pub fn parent(self, _db: &dyn HirDatabase) -> DefWithBody { + pub fn parent(self, _db: &dyn HirDatabase) -> ExpressionStoreOwner { self.parent.into() } pub fn name(self, db: &dyn HirDatabase) -> Name { - let body = db.body(self.parent); - body[self.label_id].name.clone() + ExpressionStore::of(db, self.parent)[self.label_id].name.clone() } } @@ -4506,7 +4658,7 @@ impl TypeParam { /// Is this type parameter implicitly introduced (eg. `Self` in a trait or an `impl Trait` /// argument)? pub fn is_implicit(self, db: &dyn HirDatabase) -> bool { - let params = db.generic_params(self.id.parent()); + let params = GenericParams::of(db, self.id.parent()); let data = ¶ms[self.id.local_id()]; match data.type_param().unwrap().provenance { TypeParamProvenance::TypeParamList => false, @@ -4561,7 +4713,7 @@ pub struct LifetimeParam { impl LifetimeParam { pub fn name(self, db: &dyn HirDatabase) -> Name { - let params = db.generic_params(self.id.parent); + let params = GenericParams::of(db, self.id.parent); params[self.id.local_id].name.clone() } @@ -4585,7 +4737,7 @@ impl ConstParam { } pub fn name(self, db: &dyn HirDatabase) -> Name { - let params = db.generic_params(self.id.parent()); + let params = GenericParams::of(db, self.id.parent()); match params[self.id.local_id()].name() { Some(it) => it.clone(), None => { @@ -4632,7 +4784,7 @@ pub struct TypeOrConstParam { impl TypeOrConstParam { pub fn name(self, db: &dyn HirDatabase) -> Name { - let params = db.generic_params(self.id.parent); + let params = GenericParams::of(db, self.id.parent); match params[self.id.local_id].name() { Some(n) => n.clone(), _ => Name::missing(), @@ -4648,7 +4800,7 @@ impl TypeOrConstParam { } pub fn split(self, db: &dyn HirDatabase) -> Either<ConstParam, TypeParam> { - let params = db.generic_params(self.id.parent); + let params = GenericParams::of(db, self.id.parent); match ¶ms[self.id.local_id] { TypeOrConstParamData::TypeParamData(_) => { Either::Right(TypeParam { id: TypeParamId::from_unchecked(self.id) }) @@ -4667,7 +4819,7 @@ impl TypeOrConstParam { } pub fn as_type_param(self, db: &dyn HirDatabase) -> Option<TypeParam> { - let params = db.generic_params(self.id.parent); + let params = GenericParams::of(db, self.id.parent); match ¶ms[self.id.local_id] { TypeOrConstParamData::TypeParamData(_) => { Some(TypeParam { id: TypeParamId::from_unchecked(self.id) }) @@ -4677,7 +4829,7 @@ impl TypeOrConstParam { } pub fn as_const_param(self, db: &dyn HirDatabase) -> Option<ConstParam> { - let params = db.generic_params(self.id.parent); + let params = GenericParams::of(db, self.id.parent); match ¶ms[self.id.local_id] { TypeOrConstParamData::TypeParamData(_) => None, TypeOrConstParamData::ConstParamData(_) => { @@ -4704,7 +4856,7 @@ impl Impl { result.extend(module.scope.builtin_derive_impls().map(Impl::from)); for unnamed_const in module.scope.unnamed_consts() { - for (_, block_def_map) in db.body(unnamed_const.into()).blocks(db) { + for (_, block_def_map) in Body::of(db, unnamed_const.into()).blocks(db) { extend_with_def_map(db, block_def_map, result); } } @@ -4861,14 +5013,14 @@ impl Impl { pub fn is_negative(self, db: &dyn HirDatabase) -> bool { match self.id { - AnyImplId::ImplId(id) => db.impl_signature(id).flags.contains(ImplFlags::NEGATIVE), + AnyImplId::ImplId(id) => ImplSignature::of(db, id).flags.contains(ImplFlags::NEGATIVE), AnyImplId::BuiltinDeriveImplId(_) => false, } } pub fn is_unsafe(self, db: &dyn HirDatabase) -> bool { match self.id { - AnyImplId::ImplId(id) => db.impl_signature(id).flags.contains(ImplFlags::UNSAFE), + AnyImplId::ImplId(id) => ImplSignature::of(db, id).flags.contains(ImplFlags::UNSAFE), AnyImplId::BuiltinDeriveImplId(_) => false, } } @@ -4975,7 +5127,7 @@ impl<'db> Closure<'db> { return Vec::new(); }; let owner = db.lookup_intern_closure(id).0; - let infer = InferenceResult::for_body(db, owner); + let infer = InferenceResult::of(db, owner); let info = infer.closure_info(id); info.0 .iter() @@ -4995,9 +5147,12 @@ impl<'db> Closure<'db> { return Vec::new(); }; let owner = db.lookup_intern_closure(id).0; - let infer = InferenceResult::for_body(db, owner); + let Some(body_owner) = owner.as_def_with_body() else { + return Vec::new(); + }; + let infer = InferenceResult::of(db, body_owner); let (captures, _) = infer.closure_info(id); - let env = body_param_env_from_has_crate(db, owner); + let env = body_param_env_from_has_crate(db, body_owner); captures.iter().map(|capture| Type { env, ty: capture.ty(db, self.subst) }).collect() } @@ -5005,7 +5160,10 @@ impl<'db> Closure<'db> { match self.id { AnyClosureId::ClosureId(id) => { let owner = db.lookup_intern_closure(id).0; - let infer = InferenceResult::for_body(db, owner); + let Some(body_owner) = owner.as_def_with_body() else { + return FnTrait::FnOnce; + }; + let infer = InferenceResult::of(db, body_owner); let info = infer.closure_info(id); info.1.into() } @@ -5088,7 +5246,7 @@ impl FnTrait { #[derive(Clone, Debug, PartialEq, Eq)] pub struct ClosureCapture<'db> { - owner: DefWithBodyId, + owner: ExpressionStoreOwnerId, closure: InternedClosureId, capture: hir_ty::CapturedItem, _marker: PhantomCovariantLifetime<'db>, @@ -5147,17 +5305,16 @@ pub enum CaptureKind { #[derive(Debug, Clone)] pub struct CaptureUsages { - parent: DefWithBodyId, + parent: ExpressionStoreOwnerId, spans: SmallVec<[mir::MirSpan; 3]>, } impl CaptureUsages { pub fn sources(&self, db: &dyn HirDatabase) -> Vec<CaptureUsageSource> { - let (body, source_map) = db.body_with_source_map(self.parent); - + let (body, source_map) = ExpressionStore::with_source_map(db, self.parent); let mut result = Vec::with_capacity(self.spans.len()); for &span in self.spans.iter() { - let is_ref = span.is_ref_span(&body); + let is_ref = span.is_ref_span(body); match span { mir::MirSpan::ExprId(expr) => { if let Ok(expr) = source_map.expr_syntax(expr) { @@ -5325,7 +5482,7 @@ impl<'db> Type<'db> { fn is_phantom_data(db: &dyn HirDatabase, adt_id: AdtId) -> bool { match adt_id { AdtId::StructId(s) => { - let flags = db.struct_signature(s).flags; + let flags = StructSignature::of(db, s).flags; flags.contains(StructFlags::IS_PHANTOM_DATA) } AdtId::UnionId(_) | AdtId::EnumId(_) => false, @@ -5915,7 +6072,16 @@ impl<'db> Type<'db> { ) -> R { let module = resolver.module(); let interner = DbInterner::new_with(db, module.krate(db)); - let infcx = interner.infer_ctxt().build(TypingMode::PostAnalysis); + // Most IDE operations want to operate in PostAnalysis mode, revealing opaques. This makes + // for a nicer IDE experience. However, method resolution is always done on real code (either + // existing code or code to be inserted), and there using PostAnalysis is dangerous - we may + // suggest invalid methods. So we're using the TypingMode of the body we're in. + let typing_mode = if let Some(store_owner) = resolver.expression_store_owner() { + TypingMode::analysis_in_body(interner, store_owner.into()) + } else { + TypingMode::non_body_analysis() + }; + let infcx = interner.infer_ctxt().build(typing_mode); let unstable_features = MethodResolutionUnstableFeatures::from_def_map(resolver.top_level_def_map()); let environment = param_env_from_resolver(db, resolver); @@ -6064,11 +6230,7 @@ impl<'db> Type<'db> { match name { Some(name) => { - match ctx.probe_for_name( - method_resolution::Mode::MethodCall, - name.clone(), - self_ty, - ) { + match ctx.probe_for_name(method_resolution::Mode::Path, name.clone(), self_ty) { Ok(candidate) | Err(method_resolution::MethodError::PrivateMatch(candidate)) => { let id = candidate.item.into(); @@ -6116,6 +6278,13 @@ impl<'db> Type<'db> { Some(adt.into()) } + /// Holes in the args can come from lifetime/const params. + pub fn as_adt_with_args(&self) -> Option<(Adt, Vec<Option<Type<'db>>>)> { + let (adt, args) = self.ty.as_adt()?; + let args = args.iter().map(|arg| Some(self.derived(arg.ty()?))).collect(); + Some((adt.into(), args)) + } + pub fn as_builtin(&self) -> Option<BuiltinType> { self.ty.as_builtin().map(|inner| BuiltinType { inner }) } @@ -6134,6 +6303,7 @@ impl<'db> Type<'db> { self.autoderef_(db) .filter_map(|ty| ty.dyn_trait()) .flat_map(move |dyn_trait_id| hir_ty::all_super_traits(db, dyn_trait_id)) + .copied() .map(Trait::from) } @@ -6151,6 +6321,7 @@ impl<'db> Type<'db> { _ => None, }) .flat_map(|t| hir_ty::all_super_traits(db, t)) + .copied() }) .map(Trait::from) } @@ -6311,18 +6482,19 @@ impl<'db> TypeNs<'db> { #[derive(Debug, PartialEq, Eq, Copy, Clone, Hash)] pub struct InlineAsmOperand { - owner: DefWithBodyId, + owner: ExpressionStoreOwnerId, expr: ExprId, index: usize, } impl InlineAsmOperand { - pub fn parent(self, _db: &dyn HirDatabase) -> DefWithBody { + pub fn parent(self, _db: &dyn HirDatabase) -> ExpressionStoreOwner { self.owner.into() } pub fn name(&self, db: &dyn HirDatabase) -> Option<Name> { - match &db.body(self.owner)[self.expr] { + let body = ExpressionStore::of(db, self.owner); + match &body[self.expr] { hir_def::hir::Expr::InlineAsm(e) => e.operands.get(self.index)?.0.clone(), _ => None, } @@ -6352,7 +6524,7 @@ enum Callee<'db> { pub enum CallableKind<'db> { Function(Function), TupleStruct(Struct), - TupleEnumVariant(Variant), + TupleEnumVariant(EnumVariant), Closure(Closure<'db>), FnPtr, FnImpl(FnTrait), @@ -6702,7 +6874,7 @@ impl HasCrate for Field { } } -impl HasCrate for Variant { +impl HasCrate for EnumVariant { fn krate(&self, db: &dyn HirDatabase) -> Crate { self.module(db).krate(db) } @@ -6871,9 +7043,9 @@ impl_has_name!( Struct, Union, Enum, - Variant, + EnumVariant, Adt, - VariantDef, + Variant, DefWithBody, Function, ExternCrateDecl, @@ -7062,7 +7234,7 @@ fn generic_args_from_tys<'db>( } fn has_non_default_type_params(db: &dyn HirDatabase, generic_def: GenericDefId) -> bool { - let params = db.generic_params(generic_def); + let params = GenericParams::of(db, generic_def); let defaults = db.generic_defaults(generic_def); params .iter_type_or_consts() @@ -7083,7 +7255,7 @@ fn param_env_from_resolver<'db>( ParamEnvAndCrate { param_env: resolver .generic_def() - .map_or_else(ParamEnv::empty, |generic_def| db.trait_environment(generic_def)), + .map_or_else(ParamEnv::empty, |generic_def| db.trait_environment(generic_def.into())), krate: resolver.krate(), } } @@ -7092,14 +7264,14 @@ fn param_env_from_has_crate<'db>( db: &'db dyn HirDatabase, id: impl hir_def::HasModule + Into<GenericDefId> + Copy, ) -> ParamEnvAndCrate<'db> { - ParamEnvAndCrate { param_env: db.trait_environment(id.into()), krate: id.krate(db) } + ParamEnvAndCrate { param_env: db.trait_environment(id.into().into()), krate: id.krate(db) } } fn body_param_env_from_has_crate<'db>( db: &'db dyn HirDatabase, - id: impl hir_def::HasModule + Into<DefWithBodyId> + Copy, + id: impl hir_def::HasModule + Into<ExpressionStoreOwnerId> + Copy, ) -> ParamEnvAndCrate<'db> { - ParamEnvAndCrate { param_env: db.trait_environment_for_body(id.into()), krate: id.krate(db) } + ParamEnvAndCrate { param_env: db.trait_environment(id.into()), krate: id.krate(db) } } fn empty_param_env<'db>(krate: base_db::Crate) -> ParamEnvAndCrate<'db> { diff --git a/crates/hir/src/semantics.rs b/crates/hir/src/semantics.rs index 4bc757da44..4e9e3c44be 100644 --- a/crates/hir/src/semantics.rs +++ b/crates/hir/src/semantics.rs @@ -13,9 +13,10 @@ use std::{ use base_db::FxIndexSet; use either::Either; use hir_def::{ - BuiltinDeriveImplId, DefWithBodyId, HasModule, MacroId, StructId, TraitId, VariantId, + BuiltinDeriveImplId, DefWithBodyId, ExpressionStoreOwnerId, HasModule, MacroId, StructId, + TraitId, VariantId, attrs::parse_extra_crate_attrs, - expr_store::{Body, ExprOrPatSource, HygieneId, path::Path}, + expr_store::{Body, ExprOrPatSource, ExpressionStore, HygieneId, path::Path}, hir::{BindingId, Expr, ExprId, ExprOrPatId, Pat}, nameres::{ModuleOrigin, crate_def_map}, resolver::{self, HasResolver, Resolver, TypeNs, ValueNs}, @@ -31,7 +32,7 @@ use hir_expand::{ }; use hir_ty::{ InferenceResult, - diagnostics::{unsafe_operations, unsafe_operations_for_body}, + diagnostics::unsafe_operations, infer_query_with_inspect, next_solver::{ AnyImplId, DbInterner, Span, @@ -54,10 +55,10 @@ use syntax::{ use crate::{ Adjust, Adjustment, Adt, AnyFunctionId, AutoBorrow, BindingMode, BuiltinAttr, Callable, Const, - ConstParam, Crate, DefWithBody, DeriveHelper, Enum, Field, Function, GenericSubstitution, - HasSource, Impl, InFile, InlineAsmOperand, ItemInNs, Label, LifetimeParam, Local, Macro, - Module, ModuleDef, Name, OverloadedDeref, ScopeDef, Static, Struct, ToolModule, Trait, - TupleField, Type, TypeAlias, TypeParam, Union, Variant, VariantDef, + ConstParam, Crate, DeriveHelper, Enum, EnumVariant, ExpressionStoreOwner, Field, Function, + GenericSubstitution, HasSource, Impl, InFile, InlineAsmOperand, ItemInNs, Label, LifetimeParam, + Local, Macro, Module, ModuleDef, Name, OverloadedDeref, ScopeDef, Static, Struct, ToolModule, + Trait, TupleField, Type, TypeAlias, TypeParam, Union, Variant, db::HirDatabase, semantics::source_to_def::{ChildContainer, SourceToDefCache, SourceToDefCtx}, source_analyzer::{SourceAnalyzer, resolve_hir_path}, @@ -90,7 +91,7 @@ impl PathResolution { } PathResolution::Def( ModuleDef::Const(_) - | ModuleDef::Variant(_) + | ModuleDef::EnumVariant(_) | ModuleDef::Macro(_) | ModuleDef::Function(_) | ModuleDef::Module(_) @@ -367,8 +368,8 @@ impl<DB: HirDatabase + ?Sized> Semantics<'_, DB> { self.imp.resolve_try_expr(try_expr) } - pub fn resolve_variant(&self, record_lit: ast::RecordExpr) -> Option<VariantDef> { - self.imp.resolve_variant(record_lit).map(VariantDef::from) + pub fn resolve_variant(&self, record_lit: ast::RecordExpr) -> Option<Variant> { + self.imp.resolve_variant(record_lit).map(Variant::from) } pub fn file_to_module_def(&self, file: impl Into<FileId>) -> Option<Module> { @@ -409,7 +410,7 @@ impl<DB: HirDatabase + ?Sized> Semantics<'_, DB> { self.imp.to_def(e) } - pub fn to_enum_variant_def(&self, v: &ast::Variant) -> Option<Variant> { + pub fn to_enum_variant_def(&self, v: &ast::Variant) -> Option<EnumVariant> { self.imp.to_def(v) } @@ -472,12 +473,12 @@ impl<'db> SemanticsImpl<'db> { pub fn attach_first_edition_opt(&self, file: FileId) -> Option<EditionedFileId> { let krate = self.file_to_module_defs(file).next()?.krate(self.db); - Some(EditionedFileId::new(self.db, file, krate.edition(self.db), krate.id)) + Some(EditionedFileId::new(self.db, file, krate.edition(self.db))) } pub fn attach_first_edition(&self, file: FileId) -> EditionedFileId { self.attach_first_edition_opt(file) - .unwrap_or_else(|| EditionedFileId::current_edition_guess_origin(self.db, file)) + .unwrap_or_else(|| EditionedFileId::current_edition(self.db, file)) } pub fn parse_guess_edition(&self, file_id: FileId) -> ast::SourceFile { @@ -785,16 +786,21 @@ impl<'db> SemanticsImpl<'db> { /// Checks if renaming `renamed` to `new_name` may introduce conflicts with other locals, /// and returns the conflicting locals. pub fn rename_conflicts(&self, to_be_renamed: &Local, new_name: &Name) -> Vec<Local> { - let body = self.db.body(to_be_renamed.parent); + // FIXME: signatures + let Some(def) = to_be_renamed.parent.as_def_with_body() else { + return Vec::new(); + }; + let body = Body::of(self.db, def); let resolver = to_be_renamed.parent.resolver(self.db); - let starting_expr = body.binding_owner(to_be_renamed.binding_id).unwrap_or(body.body_expr); + let starting_expr = + body.binding_owner(to_be_renamed.binding_id).unwrap_or(body.root_expr()); let mut visitor = RenameConflictsVisitor { - body: &body, + body, conflicts: FxHashSet::default(), db: self.db, new_name: new_name.symbol().clone(), old_name: to_be_renamed.name(self.db).symbol().clone(), - owner: to_be_renamed.parent, + owner: def, to_be_renamed: to_be_renamed.binding_id, resolver, }; @@ -1819,6 +1825,28 @@ impl<'db> SemanticsImpl<'db> { self.analyze(try_expr.syntax())?.resolve_try_expr(self.db, try_expr) } + /// The type that the associated `try` block, closure or function expects. + pub fn try_expr_returned_type(&self, try_expr: &ast::TryExpr) -> Option<Type<'db>> { + self.ancestors_with_macros(try_expr.syntax().clone()).find_map(|parent| { + if let Some(try_block) = ast::BlockExpr::cast(parent.clone()) + && try_block.try_block_modifier().is_some() + { + Some(self.type_of_expr(&try_block.into())?.original) + } else if let Some(closure) = ast::ClosureExpr::cast(parent.clone()) { + Some( + self.type_of_expr(&closure.into())? + .original + .as_callable(self.db)? + .return_type(), + ) + } else if let Some(function) = ast::Fn::cast(parent) { + Some(self.to_def(&function)?.ret_type(self.db)) + } else { + None + } + }) + } + // This does not resolve the method call to the correct trait impl! // We should probably fix that. pub fn resolve_method_call_as_callable( @@ -1891,36 +1919,32 @@ impl<'db> SemanticsImpl<'db> { self.db.parse_macro_expansion(file_id).value.1.matched_arm } - pub fn get_unsafe_ops(&self, def: DefWithBody) -> FxHashSet<ExprOrPatSource> { - let Ok(def) = DefWithBodyId::try_from(def) else { - return FxHashSet::default(); - }; - let (body, source_map) = self.db.body_with_source_map(def); - let infer = InferenceResult::for_body(self.db, def); + pub fn get_unsafe_ops(&self, def: ExpressionStoreOwner) -> FxHashSet<ExprOrPatSource> { + let Ok(def) = ExpressionStoreOwnerId::try_from(def) else { return Default::default() }; + let (body, source_map) = ExpressionStore::with_source_map(self.db, def); + let infer = InferenceResult::of(self.db, def); let mut res = FxHashSet::default(); - unsafe_operations_for_body(self.db, infer, def, &body, &mut |node| { - if let Ok(node) = source_map.expr_or_pat_syntax(node) { - res.insert(node); - } - }); + for root in body.expr_roots() { + unsafe_operations(self.db, infer, def, body, root, &mut |node, _| { + if let Ok(node) = source_map.expr_or_pat_syntax(node) { + res.insert(node); + } + }); + } res } pub fn get_unsafe_ops_for_unsafe_block(&self, block: ast::BlockExpr) -> Vec<ExprOrPatSource> { always!(block.unsafe_token().is_some()); + let Some(sa) = self.analyze(block.syntax()) else { return vec![] }; + let Some((def, store, sm, Some(infer))) = sa.def() else { return vec![] }; let block = self.wrap_node_infile(ast::Expr::from(block)); - let Some(def) = self.body_for(block.syntax()) else { return Vec::new() }; - let Ok(def) = def.try_into() else { - return Vec::new(); - }; - let (body, source_map) = self.db.body_with_source_map(def); - let infer = InferenceResult::for_body(self.db, def); - let Some(ExprOrPatId::ExprId(block)) = source_map.node_expr(block.as_ref()) else { + let Some(ExprOrPatId::ExprId(block)) = sm.node_expr(block.as_ref()) else { return Vec::new(); }; let mut res = Vec::default(); - unsafe_operations(self.db, infer, def, &body, block, &mut |node, _| { - if let Ok(node) = source_map.expr_or_pat_syntax(node) { + unsafe_operations(self.db, infer, def, store, block, &mut |node, _| { + if let Ok(node) = sm.expr_or_pat_syntax(node) { res.push(node); } }); @@ -1972,7 +1996,7 @@ impl<'db> SemanticsImpl<'db> { pub fn resolve_offset_of_field( &self, name_ref: &ast::NameRef, - ) -> Option<(Either<Variant, Field>, GenericSubstitution<'db>)> { + ) -> Option<(Either<EnumVariant, Field>, GenericSubstitution<'db>)> { self.analyze_no_infer(name_ref.syntax())?.resolve_offset_of_field(self.db, name_ref) } @@ -2092,13 +2116,9 @@ impl<'db> SemanticsImpl<'db> { Some(res) } - pub fn body_for(&self, node: InFile<&SyntaxNode>) -> Option<DefWithBody> { + pub fn store_owner_for(&self, node: InFile<&SyntaxNode>) -> Option<ExpressionStoreOwner> { let container = self.with_ctx(|ctx| ctx.find_container(node))?; - - match container { - ChildContainer::DefWithBodyId(def) => Some(def.into()), - _ => None, - } + container.as_expression_store_owner().map(|id| id.into()) } /// Returns none if the file of the node is not part of a crate. @@ -2127,7 +2147,7 @@ impl<'db> SemanticsImpl<'db> { node: InFile<&SyntaxNode>, offset: Option<TextSize>, // replace this, just make the inference result a `LazyCell` - infer_body: bool, + infer: bool, ) -> Option<SourceAnalyzer<'db>> { let _p = tracing::info_span!("SemanticsImpl::analyze_impl").entered(); @@ -2135,26 +2155,42 @@ impl<'db> SemanticsImpl<'db> { let resolver = match container { ChildContainer::DefWithBodyId(def) => { - return Some(if infer_body { + return Some(if infer { SourceAnalyzer::new_for_body(self.db, def, node, offset) } else { SourceAnalyzer::new_for_body_no_infer(self.db, def, node, offset) }); } ChildContainer::VariantId(def) => { - return Some(SourceAnalyzer::new_variant_body(self.db, def, node, offset)); + return Some(SourceAnalyzer::new_variant_body(self.db, def, node, offset, infer)); } ChildContainer::TraitId(it) => { - return Some(SourceAnalyzer::new_generic_def(self.db, it.into(), node, offset)); + return Some(if infer { + SourceAnalyzer::new_generic_def(self.db, it.into(), node, offset) + } else { + SourceAnalyzer::new_generic_def_no_infer(self.db, it.into(), node, offset) + }); } ChildContainer::ImplId(it) => { - return Some(SourceAnalyzer::new_generic_def(self.db, it.into(), node, offset)); + return Some(if infer { + SourceAnalyzer::new_generic_def(self.db, it.into(), node, offset) + } else { + SourceAnalyzer::new_generic_def_no_infer(self.db, it.into(), node, offset) + }); } ChildContainer::EnumId(it) => { - return Some(SourceAnalyzer::new_generic_def(self.db, it.into(), node, offset)); + return Some(if infer { + SourceAnalyzer::new_generic_def(self.db, it.into(), node, offset) + } else { + SourceAnalyzer::new_generic_def_no_infer(self.db, it.into(), node, offset) + }); } ChildContainer::GenericDefId(it) => { - return Some(SourceAnalyzer::new_generic_def(self.db, it, node, offset)); + return Some(if infer { + SourceAnalyzer::new_generic_def(self.db, it, node, offset) + } else { + SourceAnalyzer::new_generic_def_no_infer(self.db, it, node, offset) + }); } ChildContainer::ModuleId(it) => it.resolver(self.db), }; @@ -2237,7 +2273,7 @@ impl<'db> SemanticsImpl<'db> { let Some(def) = def else { return false }; let enclosing_node = enclosing_item.as_ref().either(|i| i.syntax(), |v| v.syntax()); - let (body, source_map) = self.db.body_with_source_map(def); + let (body, source_map) = Body::with_source_map(self.db, def); let file_id = self.find_file(expr.syntax()).file_id; @@ -2288,7 +2324,7 @@ impl<'db> SemanticsImpl<'db> { let sa = self.analyze(element.either(|e| e.syntax(), |s| s.syntax()))?; let store = sa.store()?; let mut resolver = sa.resolver.clone(); - let def = resolver.body_owner()?; + let def = resolver.expression_store_owner()?; let is_not_generated = |path: &Path| { !path.mod_path().and_then(|path| path.as_ident()).is_some_and(Name::is_generated) @@ -2479,7 +2515,7 @@ to_def_impls![ (crate::Function, ast::Fn, fn_to_def), (crate::Field, ast::RecordField, record_field_to_def), (crate::Field, ast::TupleField, tuple_field_to_def), - (crate::Variant, ast::Variant, enum_variant_to_def), + (crate::EnumVariant, ast::Variant, enum_variant_to_def), (crate::TypeParam, ast::TypeParam, type_param_to_def), (crate::LifetimeParam, ast::LifetimeParam, lifetime_param_to_def), (crate::ConstParam, ast::ConstParam, const_param_to_def), @@ -2538,13 +2574,18 @@ impl<'db> SemanticsScope<'db> { Crate { id: self.resolver.krate() } } + // FIXME: This is a weird function, we shouldn't have this? pub fn containing_function(&self) -> Option<Function> { - self.resolver.body_owner().and_then(|owner| match owner { - DefWithBodyId::FunctionId(id) => Some(id.into()), + self.resolver.expression_store_owner().and_then(|owner| match owner { + ExpressionStoreOwnerId::Body(DefWithBodyId::FunctionId(id)) => Some(id.into()), _ => None, }) } + pub fn expression_store_owner(&self) -> Option<ExpressionStoreOwner> { + self.resolver.expression_store_owner().map(Into::into) + } + pub(crate) fn resolver(&self) -> &Resolver<'db> { &self.resolver } @@ -2566,14 +2607,18 @@ impl<'db> SemanticsScope<'db> { resolver::ScopeDef::ImplSelfType(it) => ScopeDef::ImplSelfType(it.into()), resolver::ScopeDef::AdtSelfType(it) => ScopeDef::AdtSelfType(it.into()), resolver::ScopeDef::GenericParam(id) => ScopeDef::GenericParam(id.into()), - resolver::ScopeDef::Local(binding_id) => match self.resolver.body_owner() { - Some(parent) => ScopeDef::Local(Local { parent, binding_id }), - None => continue, - }, - resolver::ScopeDef::Label(label_id) => match self.resolver.body_owner() { - Some(parent) => ScopeDef::Label(Label { parent, label_id }), - None => continue, - }, + resolver::ScopeDef::Local(binding_id) => { + match self.resolver.expression_store_owner() { + Some(parent) => ScopeDef::Local(Local { parent, binding_id }), + None => continue, + } + } + resolver::ScopeDef::Label(label_id) => { + match self.resolver.expression_store_owner() { + Some(parent) => ScopeDef::Label(Label { parent, label_id }), + None => continue, + } + } }; f(name.clone(), def) } diff --git a/crates/hir/src/semantics/child_by_source.rs b/crates/hir/src/semantics/child_by_source.rs index c1f72debe5..f6d1bec575 100644 --- a/crates/hir/src/semantics/child_by_source.rs +++ b/crates/hir/src/semantics/child_by_source.rs @@ -18,8 +18,10 @@ use hir_def::{ DynMap, keys::{self, Key}, }, + expr_store::Body, hir::generics::GenericParams, item_scope::ItemScope, + signatures::{EnumSignature, ImplSignature, TraitSignature}, src::{HasChildSource, HasSource}, }; @@ -49,7 +51,7 @@ impl ChildBySource for TraitId { data.items.iter().for_each(|&(_, item)| { add_assoc_item(db, res, file_id, item); }); - let (_, source_map) = db.trait_signature_with_source_map(*self); + let (_, source_map) = TraitSignature::with_source_map(db, *self); source_map.expansions().filter(|(ast, _)| ast.file_id == file_id).for_each( |(ast, &exp_id)| { res[keys::MACRO_CALL].insert(ast.value, exp_id); @@ -74,7 +76,7 @@ impl ChildBySource for ImplId { data.items.iter().for_each(|&(_, item)| { add_assoc_item(db, res, file_id, item); }); - let (_, source_map) = db.impl_signature_with_source_map(*self); + let (_, source_map) = ImplSignature::with_source_map(db, *self); source_map.expansions().filter(|(ast, _)| ast.file_id == file_id).for_each( |(ast, &exp_id)| { res[keys::MACRO_CALL].insert(ast.value, exp_id); @@ -93,7 +95,6 @@ impl ChildBySource for ModuleId { impl ChildBySource for ItemScope { fn child_by_source_to(&self, db: &dyn DefDatabase, res: &mut DynMap, file_id: HirFileId) { - let krate = file_id.krate(db); self.declarations().for_each(|item| add_module_def(db, res, file_id, item)); self.impls().for_each(|imp| insert_item_loc(db, res, file_id, imp, keys::IMPL)); self.extern_blocks().for_each(|extern_block| { @@ -123,6 +124,8 @@ impl ChildBySource for ItemScope { |(ast_id, calls)| { let adt = ast_id.to_node(db); calls.for_each(|(attr_id, call_id, calls)| { + // FIXME: Is this the right crate? + let krate = call_id.lookup(db).krate; // FIXME: Fix cfg_attr handling. let (attr, _, _, _) = attr_id.find_attr_range_with_source(db, krate, &adt); res[keys::DERIVE_MACRO_CALL] @@ -203,7 +206,7 @@ impl ChildBySource for EnumId { self.enum_variants(db).variants.iter().for_each(|&(variant, _, _)| { res[keys::ENUM_VARIANT].insert(ast_id_map.get(variant.lookup(db).id.value), variant); }); - let (_, source_map) = db.enum_signature_with_source_map(*self); + let (_, source_map) = EnumSignature::with_source_map(db, *self); source_map .expansions() .filter(|(ast, _)| ast.file_id == file_id) @@ -213,7 +216,7 @@ impl ChildBySource for EnumId { impl ChildBySource for DefWithBodyId { fn child_by_source_to(&self, db: &dyn DefDatabase, res: &mut DynMap, file_id: HirFileId) { - let (body, sm) = db.body_with_source_map(*self); + let (body, sm) = Body::with_source_map(db, *self); if let &DefWithBodyId::VariantId(v) = self { VariantId::EnumVariantId(v).child_by_source_to(db, res, file_id) } @@ -238,8 +241,7 @@ impl ChildBySource for GenericDefId { return; } - let (generic_params, _, source_map) = - GenericParams::generic_params_and_store_and_source_map(db, *self); + let (generic_params, _, source_map) = GenericParams::with_source_map(db, *self); let mut toc_idx_iter = generic_params.iter_type_or_consts().map(|(idx, _)| idx); let lts_idx_iter = generic_params.iter_lt().map(|(idx, _)| idx); diff --git a/crates/hir/src/semantics/source_to_def.rs b/crates/hir/src/semantics/source_to_def.rs index d222c3dc7e..a9a779a287 100644 --- a/crates/hir/src/semantics/source_to_def.rs +++ b/crates/hir/src/semantics/source_to_def.rs @@ -88,13 +88,14 @@ use either::Either; use hir_def::{ AdtId, BlockId, BuiltinDeriveImplId, ConstId, ConstParamId, DefWithBodyId, EnumId, - EnumVariantId, ExternBlockId, ExternCrateId, FieldId, FunctionId, GenericDefId, GenericParamId, - ImplId, LifetimeParamId, Lookup, MacroId, ModuleId, StaticId, StructId, TraitId, TypeAliasId, - TypeParamId, UnionId, UseId, VariantId, + EnumVariantId, ExpressionStoreOwnerId, ExternBlockId, ExternCrateId, FieldId, FunctionId, + GenericDefId, GenericParamId, ImplId, LifetimeParamId, Lookup, MacroId, ModuleId, StaticId, + StructId, TraitId, TypeAliasId, TypeParamId, UnionId, UseId, VariantId, dyn_map::{ DynMap, keys::{self, Key}, }, + expr_store::{Body, ExpressionStore}, hir::{BindingId, Expr, LabelId}, nameres::{block_def_map, crate_def_map}, }; @@ -334,8 +335,8 @@ impl SourceToDefCtx<'_, '_> { _ => None, }) .position(|it| it == *src.value)?; - let container = self.find_pat_or_label_container(src.syntax_ref())?; - let source_map = self.db.body_with_source_map(container).1; + let container = self.find_container(src.syntax_ref())?.as_expression_store_owner()?; + let (_, source_map) = ExpressionStore::with_source_map(self.db, container); let expr = source_map.node_expr(src.with_value(&ast::Expr::AsmExpr(asm)))?.as_expr()?; Some(InlineAsmOperand { owner: container, expr, index }) } @@ -343,13 +344,13 @@ impl SourceToDefCtx<'_, '_> { pub(super) fn bind_pat_to_def( &mut self, src: InFile<&ast::IdentPat>, - ) -> Option<(DefWithBodyId, BindingId)> { - let container = self.find_pat_or_label_container(src.syntax_ref())?; - let (body, source_map) = self.db.body_with_source_map(container); + ) -> Option<(ExpressionStoreOwnerId, BindingId)> { + let container = self.find_container(src.syntax_ref())?.as_expression_store_owner()?; + let (store, source_map) = ExpressionStore::with_source_map(self.db, container); let src = src.cloned().map(ast::Pat::from); let pat_id = source_map.node_pat(src.as_ref())?; // the pattern could resolve to a constant, verify that this is not the case - if let crate::Pat::Bind { id, .. } = body[pat_id.as_pat()?] { + if let crate::Pat::Bind { id, .. } = store[pat_id.as_pat()?] { Some((container, id)) } else { None @@ -359,17 +360,19 @@ impl SourceToDefCtx<'_, '_> { &mut self, src: InFile<&ast::SelfParam>, ) -> Option<(DefWithBodyId, BindingId)> { - let container = self.find_pat_or_label_container(src.syntax_ref())?; - let body = self.db.body(container); + let container = self + .find_container(src.syntax_ref())? + .as_expression_store_owner()? + .as_def_with_body()?; + let body = Body::of(self.db, container); Some((container, body.self_param?)) } pub(super) fn label_to_def( &mut self, src: InFile<&ast::Label>, - ) -> Option<(DefWithBodyId, LabelId)> { - let container = self.find_pat_or_label_container(src.syntax_ref())?; - let source_map = self.db.body_with_source_map(container).1; - + ) -> Option<(ExpressionStoreOwnerId, LabelId)> { + let container = self.find_container(src.syntax_ref())?.as_expression_store_owner()?; + let (_, source_map) = ExpressionStore::with_source_map(self.db, container); let label_id = source_map.node_label(src)?; Some((container, label_id)) } @@ -377,13 +380,14 @@ impl SourceToDefCtx<'_, '_> { pub(super) fn label_ref_to_def( &mut self, src: InFile<&ast::Lifetime>, - ) -> Option<(DefWithBodyId, LabelId)> { + ) -> Option<(ExpressionStoreOwnerId, LabelId)> { let break_or_continue = ast::Expr::cast(src.value.syntax().parent()?)?; - let container = self.find_pat_or_label_container(src.syntax_ref())?; - let (body, source_map) = self.db.body_with_source_map(container); + let container = self.find_container(src.syntax_ref())?.as_expression_store_owner()?; + let (store, source_map) = ExpressionStore::with_source_map(self.db, container); let break_or_continue = source_map.node_expr(src.with_value(&break_or_continue))?.as_expr()?; - let (Expr::Break { label, .. } | Expr::Continue { label }) = body[break_or_continue] else { + let (Expr::Break { label, .. } | Expr::Continue { label }) = store[break_or_continue] + else { return None; }; Some((container, label?)) @@ -557,29 +561,6 @@ impl SourceToDefCtx<'_, '_> { }) } - // FIXME: Remove this when we do inference in signatures - fn find_pat_or_label_container(&mut self, src: InFile<&SyntaxNode>) -> Option<DefWithBodyId> { - self.parent_ancestors_with_macros(src, |this, InFile { file_id, value }, _| { - let item = match ast::Item::cast(value.clone()) { - Some(it) => it, - None => { - let variant = ast::Variant::cast(value)?; - return this - .enum_variant_to_def(InFile::new(file_id, &variant)) - .map(Into::into); - } - }; - match &item { - ast::Item::Fn(it) => this.fn_to_def(InFile::new(file_id, it)).map(Into::into), - ast::Item::Const(it) => this.const_to_def(InFile::new(file_id, it)).map(Into::into), - ast::Item::Static(it) => { - this.static_to_def(InFile::new(file_id, it)).map(Into::into) - } - _ => None, - } - }) - } - /// Skips the attributed item that caused the macro invocation we are climbing up fn parent_ancestors_with_macros<T>( &mut self, @@ -756,4 +737,22 @@ impl ChildContainer { ChildContainer::GenericDefId(it) => it.child_by_source(db, file_id), } } + + pub(crate) fn as_expression_store_owner(self) -> Option<ExpressionStoreOwnerId> { + match self { + ChildContainer::DefWithBodyId(it) => Some(it.into()), + ChildContainer::ModuleId(_) => None, + ChildContainer::TraitId(it) => { + Some(ExpressionStoreOwnerId::Signature(GenericDefId::TraitId(it))) + } + ChildContainer::EnumId(it) => { + Some(ExpressionStoreOwnerId::Signature(GenericDefId::AdtId(it.into()))) + } + ChildContainer::ImplId(it) => { + Some(ExpressionStoreOwnerId::Signature(GenericDefId::ImplId(it))) + } + ChildContainer::VariantId(_) => None, + ChildContainer::GenericDefId(it) => Some(it.into()), + } + } } diff --git a/crates/hir/src/source_analyzer.rs b/crates/hir/src/source_analyzer.rs index c6f2d151f5..1a34fa9134 100644 --- a/crates/hir/src/source_analyzer.rs +++ b/crates/hir/src/source_analyzer.rs @@ -9,18 +9,18 @@ use std::iter::{self, once}; use either::Either; use hir_def::{ - AdtId, AssocItemId, CallableDefId, ConstId, DefWithBodyId, FieldId, FunctionId, GenericDefId, - LocalFieldId, ModuleDefId, StructId, TraitId, VariantId, + AdtId, AssocItemId, CallableDefId, ConstId, DefWithBodyId, ExpressionStoreOwnerId, FieldId, + FunctionId, GenericDefId, LocalFieldId, ModuleDefId, StructId, TraitId, VariantId, expr_store::{ Body, BodySourceMap, ExpressionStore, ExpressionStoreSourceMap, HygieneId, lower::ExprCollector, path::Path, scope::{ExprScopes, ScopeId}, }, - hir::{BindingId, Expr, ExprId, ExprOrPatId, Pat, PatId}, + hir::{BindingId, Expr, ExprId, ExprOrPatId, Pat, PatId, generics::GenericParams}, lang_item::LangItems, nameres::MacroSubNs, - resolver::{HasResolver, Resolver, TypeNs, ValueNs, resolver_for_scope}, + resolver::{Resolver, TypeNs, ValueNs, resolver_for_scope}, type_ref::{Mutability, TypeRef, TypeRefId}, }; use hir_expand::{ @@ -55,12 +55,11 @@ use syntax::{ SyntaxKind, SyntaxNode, TextRange, TextSize, ast::{self, AstNode, RangeItem, RangeOp}, }; -use triomphe::Arc; use crate::{ Adt, AnyFunctionId, AssocItem, BindingMode, BuiltinAttr, BuiltinType, Callable, Const, - DeriveHelper, Field, Function, GenericSubstitution, Local, Macro, ModuleDef, Static, Struct, - ToolModule, Trait, TupleField, Type, TypeAlias, Variant, + DeriveHelper, EnumVariant, Field, Function, GenericSubstitution, Local, Macro, ModuleDef, + Static, Struct, ToolModule, Trait, TupleField, Type, TypeAlias, db::HirDatabase, semantics::{PathResolution, PathResolutionPerNs}, }; @@ -78,21 +77,23 @@ pub(crate) struct SourceAnalyzer<'db> { pub(crate) enum BodyOrSig<'db> { Body { def: DefWithBodyId, - body: Arc<Body>, - source_map: Arc<BodySourceMap>, + body: &'db Body, + source_map: &'db BodySourceMap, infer: Option<&'db InferenceResult>, }, - // To be folded into body once it is considered one VariantFields { def: VariantId, - store: Arc<ExpressionStore>, - source_map: Arc<ExpressionStoreSourceMap>, + store: &'db ExpressionStore, + source_map: &'db ExpressionStoreSourceMap, + infer: Option<&'db InferenceResult>, }, Sig { def: GenericDefId, - store: Arc<ExpressionStore>, - source_map: Arc<ExpressionStoreSourceMap>, - // infer: Option<Arc<InferenceResult>>, + store: &'db ExpressionStore, + source_map: &'db ExpressionStoreSourceMap, + infer: Option<&'db InferenceResult>, + #[expect(dead_code)] + generics: &'db GenericParams, }, } @@ -103,7 +104,7 @@ impl<'db> SourceAnalyzer<'db> { node: InFile<&SyntaxNode>, offset: Option<TextSize>, ) -> SourceAnalyzer<'db> { - Self::new_for_body_(db, def, node, offset, Some(InferenceResult::for_body(db, def))) + Self::new_for_body_(db, def, node, offset, Some(InferenceResult::of(db, def))) } pub(crate) fn new_for_body_no_infer( @@ -122,10 +123,10 @@ impl<'db> SourceAnalyzer<'db> { offset: Option<TextSize>, infer: Option<&'db InferenceResult>, ) -> SourceAnalyzer<'db> { - let (body, source_map) = db.body_with_source_map(def); - let scopes = db.expr_scopes(def); + let (body, source_map) = Body::with_source_map(db, def); + let scopes = ExprScopes::of(db, def); let scope = match offset { - None => scope_for(db, &scopes, &source_map, node), + None => scope_for(db, scopes, source_map, node), Some(offset) => { debug_assert!( node.text_range().contains_inclusive(offset), @@ -133,7 +134,7 @@ impl<'db> SourceAnalyzer<'db> { offset, node.text_range() ); - scope_for_offset(db, &scopes, &source_map, node.file_id, offset) + scope_for_offset(db, scopes, source_map, node.file_id, offset) } }; let resolver = resolver_for_scope(db, def, scope); @@ -147,14 +148,47 @@ impl<'db> SourceAnalyzer<'db> { pub(crate) fn new_generic_def( db: &'db dyn HirDatabase, def: GenericDefId, - InFile { file_id, .. }: InFile<&SyntaxNode>, - _offset: Option<TextSize>, + node: InFile<&SyntaxNode>, + offset: Option<TextSize>, + ) -> SourceAnalyzer<'db> { + Self::new_generic_def_(db, def, node, offset, true) + } + + pub(crate) fn new_generic_def_no_infer( + db: &'db dyn HirDatabase, + def: GenericDefId, + node: InFile<&SyntaxNode>, + offset: Option<TextSize>, ) -> SourceAnalyzer<'db> { - let (_params, store, source_map) = db.generic_params_and_store_and_source_map(def); - let resolver = def.resolver(db); + Self::new_generic_def_(db, def, node, offset, false) + } + + pub(crate) fn new_generic_def_( + db: &'db dyn HirDatabase, + def: GenericDefId, + node @ InFile { file_id, .. }: InFile<&SyntaxNode>, + offset: Option<TextSize>, + infer: bool, + ) -> SourceAnalyzer<'db> { + let (generics, store, source_map) = GenericParams::with_source_map(db, def); + let scopes = ExprScopes::of(db, def); + let scope = match offset { + None => scope_for(db, scopes, source_map, node), + Some(offset) => { + debug_assert!( + node.text_range().contains_inclusive(offset), + "{:?} not in {:?}", + offset, + node.text_range() + ); + scope_for_offset(db, scopes, source_map, node.file_id, offset) + } + }; + let resolver = resolver_for_scope(db, def, scope); + let infer = if infer { Some(InferenceResult::of(db, def)) } else { None }; SourceAnalyzer { resolver, - body_or_sig: Some(BodyOrSig::Sig { def, store, source_map }), + body_or_sig: Some(BodyOrSig::Sig { def, store, source_map, generics, infer }), file_id, } } @@ -162,17 +196,33 @@ impl<'db> SourceAnalyzer<'db> { pub(crate) fn new_variant_body( db: &'db dyn HirDatabase, def: VariantId, - InFile { file_id, .. }: InFile<&SyntaxNode>, - _offset: Option<TextSize>, + node @ InFile { file_id, .. }: InFile<&SyntaxNode>, + offset: Option<TextSize>, + infer: bool, ) -> SourceAnalyzer<'db> { let (fields, source_map) = def.fields_with_source_map(db); - let resolver = def.resolver(db); + let scopes = ExprScopes::of(db, def); + let scope = match offset { + None => scope_for(db, scopes, source_map, node), + Some(offset) => { + debug_assert!( + node.text_range().contains_inclusive(offset), + "{:?} not in {:?}", + offset, + node.text_range() + ); + scope_for_offset(db, scopes, source_map, node.file_id, offset) + } + }; + let resolver = resolver_for_scope(db, def, scope); + let infer = if infer { Some(InferenceResult::of(db, def)) } else { None }; SourceAnalyzer { resolver, body_or_sig: Some(BodyOrSig::VariantFields { def, - store: fields.store.clone(), - source_map: source_map.clone(), + store: &fields.store, + source_map, + infer, }), file_id, } @@ -185,29 +235,40 @@ impl<'db> SourceAnalyzer<'db> { SourceAnalyzer { resolver, body_or_sig: None, file_id: node.file_id } } - // FIXME: Remove this - fn body_(&self) -> Option<(DefWithBodyId, &Body, &BodySourceMap, Option<&InferenceResult>)> { - self.body_or_sig.as_ref().and_then(|it| match it { - BodyOrSig::Body { def, body, source_map, infer } => { - Some((*def, &**body, &**source_map, infer.as_deref())) - } - _ => None, + fn owner(&self) -> Option<ExpressionStoreOwnerId> { + self.body_or_sig.as_ref().map(|it| match *it { + BodyOrSig::VariantFields { def, .. } => def.into(), + BodyOrSig::Sig { def, .. } => def.into(), + BodyOrSig::Body { def, .. } => def.into(), }) } fn infer(&self) -> Option<&InferenceResult> { self.body_or_sig.as_ref().and_then(|it| match it { - BodyOrSig::Sig { .. } => None, - BodyOrSig::VariantFields { .. } => None, - BodyOrSig::Body { infer, .. } => infer.as_deref(), + BodyOrSig::VariantFields { infer, .. } + | BodyOrSig::Sig { infer, .. } + | BodyOrSig::Body { infer, .. } => infer.as_deref(), }) } - fn body(&self) -> Option<&Body> { - self.body_or_sig.as_ref().and_then(|it| match it { - BodyOrSig::Sig { .. } => None, - BodyOrSig::VariantFields { .. } => None, - BodyOrSig::Body { body, .. } => Some(&**body), + pub(crate) fn def( + &self, + ) -> Option<( + ExpressionStoreOwnerId, + &ExpressionStore, + &ExpressionStoreSourceMap, + Option<&InferenceResult>, + )> { + self.body_or_sig.as_ref().map(|it| match *it { + BodyOrSig::VariantFields { def, store, source_map, infer, .. } => { + (def.into(), store, source_map, infer) + } + BodyOrSig::Sig { def, store, source_map, infer, .. } => { + (def.into(), store, source_map, infer) + } + BodyOrSig::Body { def, body, source_map, infer, .. } => { + (def.into(), &body.store, &source_map.store, infer) + } }) } @@ -232,11 +293,13 @@ impl<'db> SourceAnalyzer<'db> { } fn trait_environment(&self, db: &'db dyn HirDatabase) -> ParamEnvAndCrate<'db> { - self.param_and( - self.body_() - .map(|(def, ..)| def) - .map_or_else(ParamEnv::empty, |def| db.trait_environment_for_body(def)), - ) + self.param_and(self.body_or_sig.as_ref().map_or_else(ParamEnv::empty, |body_or_sig| { + match *body_or_sig { + BodyOrSig::Body { def, .. } => db.trait_environment(def.into()), + BodyOrSig::VariantFields { def, .. } => db.trait_environment(def.into()), + BodyOrSig::Sig { def, .. } => db.trait_environment(def.into()), + } + })) } pub(crate) fn expr_id(&self, expr: ast::Expr) -> Option<ExprOrPatId> { @@ -371,7 +434,10 @@ impl<'db> SourceAnalyzer<'db> { db: &'db dyn HirDatabase, _param: &ast::SelfParam, ) -> Option<Type<'db>> { - let binding = self.body()?.self_param?; + let binding = match self.body_or_sig.as_ref()? { + BodyOrSig::Sig { .. } | BodyOrSig::VariantFields { .. } => return None, + BodyOrSig::Body { body, .. } => body.self_param?, + }; let ty = self.infer()?.binding_ty(binding); Some(Type::new_with_resolver(db, &self.resolver, ty)) } @@ -472,7 +538,7 @@ impl<'db> SourceAnalyzer<'db> { &self, field: &ast::FieldExpr, ) -> Option<Either<Field, TupleField>> { - let (def, ..) = self.body_()?; + let def = self.owner()?; let expr_id = self.expr_id(field.clone().into())?.as_expr()?; self.infer()?.field_resolution(expr_id).map(|it| { it.map_either(Into::into, |f| TupleField { owner: def, tuple: f.tuple, index: f.index }) @@ -499,7 +565,7 @@ impl<'db> SourceAnalyzer<'db> { field: &ast::FieldExpr, ) -> Option<(Either<Either<Field, TupleField>, Function>, Option<GenericSubstitution<'db>>)> { - let (def, ..) = self.body_()?; + let def = self.owner()?; let expr_id = self.expr_id(field.clone().into())?.as_expr()?; let inference_result = self.infer()?; match inference_result.field_resolution(expr_id) { @@ -771,7 +837,7 @@ impl<'db> SourceAnalyzer<'db> { name_hygiene(db, InFile::new(self.file_id, ast_name.syntax())), ) { Some(ValueNs::LocalBinding(binding_id)) => { - Some(Local { binding_id, parent: self.resolver.body_owner()? }) + Some(Local { binding_id, parent: self.resolver.expression_store_owner()? }) } _ => None, } @@ -831,8 +897,8 @@ impl<'db> SourceAnalyzer<'db> { }, }; - let body_owner = self.resolver.body_owner(); - let res = resolve_hir_value_path(db, &self.resolver, body_owner, path, HygieneId::ROOT)?; + let store_owner = self.resolver.expression_store_owner(); + let res = resolve_hir_value_path(db, &self.resolver, store_owner, path, HygieneId::ROOT)?; match res { PathResolution::Def(def) => Some(def), _ => None, @@ -843,7 +909,7 @@ impl<'db> SourceAnalyzer<'db> { let name = name.as_name(); self.resolver .all_generic_params() - .find_map(|(params, parent)| params.find_type_by_name(&name, *parent)) + .find_map(|(params, parent)| params.find_type_by_name(&name, parent)) .map(crate::TypeParam::from) } @@ -851,7 +917,7 @@ impl<'db> SourceAnalyzer<'db> { &self, db: &'db dyn HirDatabase, name_ref: &ast::NameRef, - ) -> Option<(Either<crate::Variant, crate::Field>, GenericSubstitution<'db>)> { + ) -> Option<(Either<crate::EnumVariant, crate::Field>, GenericSubstitution<'db>)> { let offset_of_expr = ast::OffsetOfExpr::cast(name_ref.syntax().parent()?)?; let container = offset_of_expr.ty()?; let container = self.type_of_type(db, &container)?; @@ -902,7 +968,7 @@ impl<'db> SourceAnalyzer<'db> { let variants = id.enum_variants(db); let variant = variants.variant(&field_name.as_name())?; container = Either::Left((variant, subst)); - (Either::Left(Variant { id: variant }), id.into(), subst) + (Either::Left(EnumVariant { id: variant }), id.into(), subst) } }, _ => return None, @@ -982,7 +1048,10 @@ impl<'db> SourceAnalyzer<'db> { if let Some(VariantId::EnumVariantId(variant)) = infer.variant_resolution_for_expr_or_pat(expr_id) { - return Some((PathResolution::Def(ModuleDef::Variant(variant.into())), None)); + return Some(( + PathResolution::Def(ModuleDef::EnumVariant(variant.into())), + None, + )); } prefer_value_ns = true; } else if let Some(path_pat) = parent().and_then(ast::PathPat::cast) { @@ -1014,14 +1083,20 @@ impl<'db> SourceAnalyzer<'db> { if let Some(VariantId::EnumVariantId(variant)) = infer.variant_resolution_for_expr_or_pat(expr_or_pat_id) { - return Some((PathResolution::Def(ModuleDef::Variant(variant.into())), None)); + return Some(( + PathResolution::Def(ModuleDef::EnumVariant(variant.into())), + None, + )); } } else if let Some(rec_lit) = parent().and_then(ast::RecordExpr::cast) { let expr_id = self.expr_id(rec_lit.into())?; if let Some(VariantId::EnumVariantId(variant)) = infer.variant_resolution_for_expr_or_pat(expr_id) { - return Some((PathResolution::Def(ModuleDef::Variant(variant.into())), None)); + return Some(( + PathResolution::Def(ModuleDef::EnumVariant(variant.into())), + None, + )); } } else { let record_pat = parent().and_then(ast::RecordPat::cast).map(ast::Pat::from); @@ -1032,7 +1107,7 @@ impl<'db> SourceAnalyzer<'db> { let variant_res_for_pat = infer.variant_resolution_for_pat(pat_id.as_pat()?); if let Some(VariantId::EnumVariantId(variant)) = variant_res_for_pat { return Some(( - PathResolution::Def(ModuleDef::Variant(variant.into())), + PathResolution::Def(ModuleDef::EnumVariant(variant.into())), None, )); } @@ -1045,7 +1120,7 @@ impl<'db> SourceAnalyzer<'db> { } // FIXME: collectiong here shouldnt be necessary? - let mut collector = ExprCollector::new(db, self.resolver.module(), self.file_id); + let mut collector = ExprCollector::body(db, self.resolver.module(), self.file_id); let hir_path = collector.lower_path(path.clone(), &mut ExprCollector::impl_trait_error_allocator)?; let parent_hir_path = path @@ -1253,7 +1328,7 @@ impl<'db> SourceAnalyzer<'db> { db: &dyn HirDatabase, path: &ast::Path, ) -> Option<PathResolutionPerNs> { - let mut collector = ExprCollector::new(db, self.resolver.module(), self.file_id); + let mut collector = ExprCollector::body(db, self.resolver.module(), self.file_id); let hir_path = collector.lower_path(path.clone(), &mut ExprCollector::impl_trait_error_allocator)?; let (store, _) = collector.store.finish(); @@ -1366,7 +1441,7 @@ impl<'db> SourceAnalyzer<'db> { db: &'db dyn HirDatabase, macro_expr: InFile<&ast::MacroExpr>, ) -> bool { - if let Some((def, body, sm, Some(infer))) = self.body_() + if let Some((def, body, sm, Some(infer))) = self.def() && let Some(expanded_expr) = sm.macro_expansion_expr(macro_expr) { let mut is_unsafe = false; @@ -1400,7 +1475,7 @@ impl<'db> SourceAnalyzer<'db> { resolve_hir_value_path( db, &self.resolver, - self.resolver.body_owner(), + self.resolver.expression_store_owner(), &Path::from_known_path_with_no_generic(ModPath::from_segments( PathKind::Plain, Some(name.clone()), @@ -1416,9 +1491,9 @@ impl<'db> SourceAnalyzer<'db> { asm: InFile<&ast::AsmExpr>, line: usize, offset: TextSize, - ) -> Option<(DefWithBodyId, (ExprId, TextRange, usize))> { - let (def, _, body_source_map, _) = self.body_()?; - let (expr, args) = body_source_map.asm_template_args(asm)?; + ) -> Option<(ExpressionStoreOwnerId, (ExprId, TextRange, usize))> { + let (def, _, sm, _) = self.def()?; + let (expr, args) = sm.asm_template_args(asm)?; Some(def).zip( args.get(line)? .iter() @@ -1439,7 +1514,7 @@ impl<'db> SourceAnalyzer<'db> { resolve_hir_value_path( db, &self.resolver, - self.resolver.body_owner(), + self.resolver.expression_store_owner(), &Path::from_known_path_with_no_generic(ModPath::from_segments( PathKind::Plain, Some(name.clone()), @@ -1453,9 +1528,9 @@ impl<'db> SourceAnalyzer<'db> { pub(crate) fn as_asm_parts( &self, asm: InFile<&ast::AsmExpr>, - ) -> Option<(DefWithBodyId, (ExprId, &[Vec<(TextRange, usize)>]))> { - let (def, _, body_source_map, _) = self.body_()?; - Some(def).zip(body_source_map.asm_template_args(asm)) + ) -> Option<(ExpressionStoreOwnerId, (ExprId, &[Vec<(TextRange, usize)>]))> { + let (def, _, sm, _) = self.def()?; + Some(def).zip(sm.asm_template_args(asm)) } fn resolve_impl_method_or_trait_def( @@ -1473,11 +1548,11 @@ impl<'db> SourceAnalyzer<'db> { func: FunctionId, substs: GenericArgs<'db>, ) -> (Function, GenericArgs<'db>) { - let owner = match self.resolver.body_owner() { + let owner = match self.resolver.expression_store_owner() { Some(it) => it, None => return (func.into(), substs), }; - let env = self.param_and(db.trait_environment_for_body(owner)); + let env = self.param_and(db.trait_environment(owner)); let (func, args) = db.lookup_impl_method(env, func, substs); match func { Either::Left(func) => (func.into(), args), @@ -1493,11 +1568,11 @@ impl<'db> SourceAnalyzer<'db> { const_id: ConstId, subs: GenericArgs<'db>, ) -> (ConstId, GenericArgs<'db>) { - let owner = match self.resolver.body_owner() { + let owner = match self.resolver.expression_store_owner() { Some(it) => it, None => return (const_id, subs), }; - let env = self.param_and(db.trait_environment_for_body(owner)); + let env = self.param_and(db.trait_environment(owner)); let interner = DbInterner::new_with(db, env.krate); let infcx = interner.infer_ctxt().build(TypingMode::PostAnalysis); method_resolution::lookup_impl_const(&infcx, env.param_env, const_id, subs) @@ -1526,7 +1601,7 @@ impl<'db> SourceAnalyzer<'db> { fn scope_for( db: &dyn HirDatabase, scopes: &ExprScopes, - source_map: &BodySourceMap, + source_map: &ExpressionStoreSourceMap, node: InFile<&SyntaxNode>, ) -> Option<ScopeId> { node.ancestors_with_macros(db) @@ -1545,7 +1620,7 @@ fn scope_for( fn scope_for_offset( db: &dyn HirDatabase, scopes: &ExprScopes, - source_map: &BodySourceMap, + source_map: &ExpressionStoreSourceMap, from_file: HirFileId, offset: TextSize, ) -> Option<ScopeId> { @@ -1579,7 +1654,7 @@ fn scope_for_offset( fn adjust( db: &dyn HirDatabase, scopes: &ExprScopes, - source_map: &BodySourceMap, + source_map: &ExpressionStoreSourceMap, expr_range: TextRange, from_file: HirFileId, offset: TextSize, @@ -1684,7 +1759,7 @@ fn resolve_hir_path_( TypeNs::AdtSelfType(it) | TypeNs::AdtId(it) => { PathResolution::Def(Adt::from(it).into()) } - TypeNs::EnumVariantId(it) => PathResolution::Def(Variant::from(it).into()), + TypeNs::EnumVariantId(it) => PathResolution::Def(EnumVariant::from(it).into()), TypeNs::TypeAliasId(it) => PathResolution::Def(TypeAlias::from(it).into()), TypeNs::BuiltinType(it) => PathResolution::Def(BuiltinType::from(it).into()), TypeNs::TraitId(it) => PathResolution::Def(Trait::from(it).into()), @@ -1708,7 +1783,7 @@ fn resolve_hir_path_( } }; - let body_owner = resolver.body_owner(); + let body_owner = resolver.expression_store_owner(); let values = || resolve_hir_value_path(db, resolver, body_owner, path, hygiene); let items = || { @@ -1754,21 +1829,21 @@ fn resolve_hir_path_( fn resolve_hir_value_path( db: &dyn HirDatabase, resolver: &Resolver<'_>, - body_owner: Option<DefWithBodyId>, + store_owner: Option<ExpressionStoreOwnerId>, path: &Path, hygiene: HygieneId, ) -> Option<PathResolution> { resolver.resolve_path_in_value_ns_fully(db, path, hygiene).and_then(|val| { let res = match val { ValueNs::LocalBinding(binding_id) => { - let var = Local { parent: body_owner?, binding_id }; + let var = Local { parent: store_owner?, binding_id }; PathResolution::Local(var) } ValueNs::FunctionId(it) => PathResolution::Def(Function::from(it).into()), ValueNs::ConstId(it) => PathResolution::Def(Const::from(it).into()), ValueNs::StaticId(it) => PathResolution::Def(Static::from(it).into()), ValueNs::StructId(it) => PathResolution::Def(Struct::from(it).into()), - ValueNs::EnumVariantId(it) => PathResolution::Def(Variant::from(it).into()), + ValueNs::EnumVariantId(it) => PathResolution::Def(EnumVariant::from(it).into()), ValueNs::ImplSelf(impl_id) => PathResolution::SelfType(impl_id.into()), ValueNs::GenericParam(id) => PathResolution::ConstParam(id.into()), }; @@ -1833,7 +1908,7 @@ fn resolve_hir_path_qualifier( TypeNs::AdtSelfType(it) | TypeNs::AdtId(it) => { PathResolution::Def(Adt::from(it).into()) } - TypeNs::EnumVariantId(it) => PathResolution::Def(Variant::from(it).into()), + TypeNs::EnumVariantId(it) => PathResolution::Def(EnumVariant::from(it).into()), TypeNs::TypeAliasId(it) => PathResolution::Def(TypeAlias::from(it).into()), TypeNs::BuiltinType(it) => PathResolution::Def(BuiltinType::from(it).into()), TypeNs::TraitId(it) => PathResolution::Def(Trait::from(it).into()), diff --git a/crates/hir/src/symbols.rs b/crates/hir/src/symbols.rs index c088f3aa0c..ff56544d82 100644 --- a/crates/hir/src/symbols.rs +++ b/crates/hir/src/symbols.rs @@ -8,9 +8,11 @@ use hir_def::{ AdtId, AssocItemId, AstIdLoc, Complete, DefWithBodyId, ExternCrateId, HasModule, ImplId, Lookup, MacroId, ModuleDefId, ModuleId, TraitId, db::DefDatabase, + expr_store::Body, item_scope::{ImportId, ImportOrExternCrate, ImportOrGlob}, nameres::crate_def_map, per_ns::Item, + signatures::{EnumSignature, ImplSignature, TraitSignature}, src::{HasChildSource, HasSource}, visibility::{Visibility, VisibilityExplicitness}, }; @@ -185,7 +187,7 @@ impl<'a> SymbolCollector<'a> { } ModuleDefId::AdtId(AdtId::EnumId(id)) => { this.push_decl(id, name, false, None); - let enum_name = Symbol::intern(this.db.enum_signature(id).name.as_str()); + let enum_name = Symbol::intern(EnumSignature::of(this.db, id).name.as_str()); this.with_container_name(Some(enum_name), |this| { let variants = id.enum_variants(this.db); for (variant_id, variant_name, _) in &variants.variants { @@ -386,7 +388,7 @@ impl<'a> SymbolCollector<'a> { return; } let body_id = body_id.into(); - let body = self.db.body(body_id); + let body = Body::of(self.db, body_id); // Descend into the blocks and enqueue collection of all modules within. for (_, def_map) in body.blocks(self.db) { @@ -397,7 +399,7 @@ impl<'a> SymbolCollector<'a> { } fn collect_from_impl(&mut self, impl_id: ImplId) { - let impl_data = self.db.impl_signature(impl_id); + let impl_data = ImplSignature::of(self.db, impl_id); let impl_name = Some( hir_display_with_store(impl_data.self_ty, &impl_data.store) .display( @@ -419,7 +421,7 @@ impl<'a> SymbolCollector<'a> { } fn collect_from_trait(&mut self, trait_id: TraitId, trait_do_not_complete: Complete) { - let trait_data = self.db.trait_signature(trait_id); + let trait_data = TraitSignature::of(self.db, trait_id); self.with_container_name(Some(Symbol::intern(trait_data.name.as_str())), |s| { for &(ref name, assoc_item_id) in &trait_id.trait_items(self.db).items { s.push_assoc_item(assoc_item_id, name, Some(trait_do_not_complete)); diff --git a/crates/hir/src/term_search/expr.rs b/crates/hir/src/term_search/expr.rs index e56f9e91e3..e3d0121e49 100644 --- a/crates/hir/src/term_search/expr.rs +++ b/crates/hir/src/term_search/expr.rs @@ -10,8 +10,8 @@ use itertools::Itertools; use span::Edition; use crate::{ - Adt, AsAssocItem, AssocItemContainer, Const, ConstParam, Field, Function, Local, ModuleDef, - SemanticsScope, Static, Struct, StructKind, Trait, Type, Variant, + Adt, AsAssocItem, AssocItemContainer, Const, ConstParam, EnumVariant, Field, Function, Local, + ModuleDef, SemanticsScope, Static, Struct, StructKind, Trait, Type, }; /// Helper function to get path to `ModuleDef` @@ -80,7 +80,7 @@ pub enum Expr<'db> { params: Vec<Expr<'db>>, }, /// Enum variant construction - Variant { variant: Variant, generics: Vec<Type<'db>>, params: Vec<Expr<'db>> }, + Variant { variant: EnumVariant, generics: Vec<Type<'db>>, params: Vec<Expr<'db>> }, /// Struct construction Struct { strukt: Struct, generics: Vec<Type<'db>>, params: Vec<Expr<'db>> }, /// Tuple construction @@ -222,7 +222,7 @@ impl<'db> Expr<'db> { StructKind::Unit => String::new(), }; - let prefix = mod_item_path_str(sema_scope, &ModuleDef::Variant(*variant))?; + let prefix = mod_item_path_str(sema_scope, &ModuleDef::EnumVariant(*variant))?; Ok(format!("{prefix}{inner}")) } Expr::Struct { strukt, params, .. } => { diff --git a/crates/hir/src/term_search/tactics.rs b/crates/hir/src/term_search/tactics.rs index 05a89e7652..c7ef4e5d5d 100644 --- a/crates/hir/src/term_search/tactics.rs +++ b/crates/hir/src/term_search/tactics.rs @@ -18,7 +18,6 @@ use hir_ty::{ use itertools::Itertools; use rustc_hash::FxHashSet; use rustc_type_ir::inherent::Ty as _; -use span::Edition; use crate::{ Adt, AssocItem, GenericDef, GenericParam, HasAttrs, HasVisibility, Impl, ModuleDef, ScopeDef, @@ -54,7 +53,7 @@ pub(super) fn trivial<'a, 'lt, 'db, DB: HirDatabase>( ScopeDef::GenericParam(GenericParam::ConstParam(it)) => Some(Expr::ConstParam(*it)), ScopeDef::Local(it) => { if ctx.config.enable_borrowcheck { - let borrowck = db.borrowck(it.parent).ok()?; + let borrowck = db.borrowck(it.parent.as_def_with_body()?).ok()?; let invalid = borrowck.iter().any(|b| { b.partially_moved.iter().any(|moved| { @@ -367,7 +366,11 @@ pub(super) fn free_function<'a, 'lt, 'db, DB: HirDatabase>( let ret_ty = it.ret_type_with_args(db, generics.iter().cloned()); // Filter out private and unsafe functions if !it.is_visible_from(db, module) - || it.is_unsafe_to_call(db, None, Edition::CURRENT_FIXME) + || it.is_unsafe_to_call( + db, + None, + crate::Crate::from(ctx.scope.resolver().krate()).edition(db), + ) || it.is_unstable(db) || ctx.config.enable_borrowcheck && ret_ty.contains_reference(db) || ret_ty.is_raw_ptr() @@ -473,7 +476,11 @@ pub(super) fn impl_method<'a, 'lt, 'db, DB: HirDatabase>( // Filter out private and unsafe functions if !it.is_visible_from(db, module) - || it.is_unsafe_to_call(db, None, Edition::CURRENT_FIXME) + || it.is_unsafe_to_call( + db, + None, + crate::Crate::from(ctx.scope.resolver().krate()).edition(db), + ) || it.is_unstable(db) { return None; @@ -667,7 +674,11 @@ pub(super) fn impl_static_method<'a, 'lt, 'db, DB: HirDatabase>( // Filter out private and unsafe functions if !it.is_visible_from(db, module) - || it.is_unsafe_to_call(db, None, Edition::CURRENT_FIXME) + || it.is_unsafe_to_call( + db, + None, + crate::Crate::from(ctx.scope.resolver().krate()).edition(db), + ) || it.is_unstable(db) { return None; diff --git a/crates/ide-assists/src/handlers/add_braces.rs b/crates/ide-assists/src/handlers/add_braces.rs index 99ee50fa58..da1322de4b 100644 --- a/crates/ide-assists/src/handlers/add_braces.rs +++ b/crates/ide-assists/src/handlers/add_braces.rs @@ -84,6 +84,7 @@ fn get_replacement_node(ctx: &AssistContext<'_>) -> Option<(ParentType, ast::Exp match parent { ast::LetStmt(it) => it.initializer()?, ast::LetExpr(it) => it.expr()?, + ast::BinExpr(it) => it.rhs()?, ast::Static(it) => it.body()?, ast::Const(it) => it.body()?, _ => return None, @@ -176,6 +177,70 @@ fn foo() { } "#, ); + + check_assist( + add_braces, + r#" +fn foo() { + let x; + x =$0 n + 100; +} +"#, + r#" +fn foo() { + let x; + x = { + n + 100 + }; +} +"#, + ); + + check_assist( + add_braces, + r#" +fn foo() { + if let x =$0 n + 100 {} +} +"#, + r#" +fn foo() { + if let x = { + n + 100 + } {} +} +"#, + ); + } + + #[test] + fn suggest_add_braces_for_const_initializer() { + check_assist( + add_braces, + r#" +const X: i32 =$0 1 + 2; +"#, + r#" +const X: i32 = { + 1 + 2 +}; +"#, + ); + } + + #[test] + fn suggest_add_braces_for_static_initializer() { + check_assist( + add_braces, + r#" +static X: i32 $0= 1 + 2; +"#, + r#" +static X: i32 = { + 1 + 2 +}; +"#, + ); } #[test] diff --git a/crates/ide-assists/src/handlers/add_explicit_enum_discriminant.rs b/crates/ide-assists/src/handlers/add_explicit_enum_discriminant.rs index 7960373e61..75c5f84b85 100644 --- a/crates/ide-assists/src/handlers/add_explicit_enum_discriminant.rs +++ b/crates/ide-assists/src/handlers/add_explicit_enum_discriminant.rs @@ -47,7 +47,7 @@ pub(crate) fn add_explicit_enum_discriminant( // Don't offer the assist if the enum has no variants or if all variants already have an // explicit discriminant. - if variant_list.variants().all(|variant_node| variant_node.expr().is_some()) { + if variant_list.variants().all(|variant_node| variant_node.const_arg().is_some()) { return None; } @@ -72,7 +72,9 @@ fn add_variant_discriminant( variant_node: &ast::Variant, radix: &mut Radix, ) { - if let Some(expr) = variant_node.expr() { + if let Some(expr) = variant_node.const_arg() + && let Some(expr) = expr.expr() + { *radix = expr_radix(&expr).unwrap_or(*radix); return; } diff --git a/crates/ide-assists/src/handlers/add_label_to_loop.rs b/crates/ide-assists/src/handlers/add_label_to_loop.rs index b84ad24cfc..6a408e5254 100644 --- a/crates/ide-assists/src/handlers/add_label_to_loop.rs +++ b/crates/ide-assists/src/handlers/add_label_to_loop.rs @@ -3,11 +3,7 @@ use ide_db::{ }; use syntax::{ SyntaxToken, T, - ast::{ - self, AstNode, HasLoopBody, - make::{self, tokens}, - syntax_factory::SyntaxFactory, - }, + ast::{self, AstNode, HasLoopBody, syntax_factory::SyntaxFactory}, syntax_editor::{Position, SyntaxEditor}, }; @@ -35,9 +31,9 @@ use crate::{AssistContext, AssistId, Assists}; // } // ``` pub(crate) fn add_label_to_loop(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { - let loop_kw = ctx.find_token_syntax_at_offset(T![loop])?; - let loop_expr = loop_kw.parent().and_then(ast::LoopExpr::cast)?; - if loop_expr.label().is_some() { + let loop_expr = ctx.find_node_at_offset::<ast::AnyHasLoopBody>()?; + let loop_kw = loop_token(&loop_expr)?; + if loop_expr.label().is_some() || !loop_kw.text_range().contains_inclusive(ctx.offset()) { return None; } @@ -52,8 +48,8 @@ pub(crate) fn add_label_to_loop(acc: &mut Assists, ctx: &AssistContext<'_>) -> O let label = make.lifetime("'l"); let elements = vec![ label.syntax().clone().into(), - make::token(T![:]).into(), - tokens::single_space().into(), + make.token(T![:]).into(), + make.whitespace(" ").into(), ]; editor.insert_all(Position::before(&loop_kw), elements); @@ -80,6 +76,14 @@ pub(crate) fn add_label_to_loop(acc: &mut Assists, ctx: &AssistContext<'_>) -> O ) } +fn loop_token(loop_expr: &ast::AnyHasLoopBody) -> Option<syntax::SyntaxToken> { + loop_expr + .syntax() + .children_with_tokens() + .filter_map(|it| it.into_token()) + .find(|it| matches!(it.kind(), T![for] | T![loop] | T![while])) +} + fn insert_label_after_token( editor: &mut SyntaxEditor, make: &SyntaxFactory, @@ -88,7 +92,7 @@ fn insert_label_after_token( builder: &mut SourceChangeBuilder, ) { let label = make.lifetime("'l"); - let elements = vec![tokens::single_space().into(), label.syntax().clone().into()]; + let elements = vec![make.whitespace(" ").into(), label.syntax().clone().into()]; editor.insert_all(Position::after(token), elements); if let Some(cap) = ctx.config.snippet_cap { @@ -124,6 +128,48 @@ fn main() { } #[test] + fn add_label_to_while_expr() { + check_assist( + add_label_to_loop, + r#" +fn main() { + while$0 true { + break; + continue; + } +}"#, + r#" +fn main() { + ${1:'l}: while true { + break ${2:'l}; + continue ${0:'l}; + } +}"#, + ); + } + + #[test] + fn add_label_to_for_expr() { + check_assist( + add_label_to_loop, + r#" +fn main() { + for$0 _ in 0..5 { + break; + continue; + } +}"#, + r#" +fn main() { + ${1:'l}: for _ in 0..5 { + break ${2:'l}; + continue ${0:'l}; + } +}"#, + ); + } + + #[test] fn add_label_to_outer_loop() { check_assist( add_label_to_loop, @@ -194,4 +240,29 @@ fn main() { }"#, ); } + + #[test] + fn do_not_add_label_if_outside_keyword() { + check_assist_not_applicable( + add_label_to_loop, + r#" +fn main() { + 'l: loop {$0 + break 'l; + continue 'l; + } +}"#, + ); + + check_assist_not_applicable( + add_label_to_loop, + r#" +fn main() { + 'l: while true {$0 + break 'l; + continue 'l; + } +}"#, + ); + } } diff --git a/crates/ide-assists/src/handlers/add_lifetime_to_type.rs b/crates/ide-assists/src/handlers/add_lifetime_to_type.rs index 27dbdcf2c4..265ee3d2d4 100644 --- a/crates/ide-assists/src/handlers/add_lifetime_to_type.rs +++ b/crates/ide-assists/src/handlers/add_lifetime_to_type.rs @@ -1,4 +1,7 @@ -use syntax::ast::{self, AstNode, HasGenericParams, HasName}; +use syntax::{ + SyntaxKind, SyntaxNode, SyntaxToken, + ast::{self, AstNode, HasGenericParams, HasName}, +}; use crate::{AssistContext, AssistId, Assists}; @@ -21,7 +24,7 @@ use crate::{AssistContext, AssistId, Assists}; // ``` pub(crate) fn add_lifetime_to_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let ref_type_focused = ctx.find_node_at_offset::<ast::RefType>()?; - if ref_type_focused.lifetime().is_some() { + if ref_type_focused.lifetime().is_some_and(|lifetime| lifetime.text() != "'_") { return None; } @@ -34,10 +37,10 @@ pub(crate) fn add_lifetime_to_type(acc: &mut Assists, ctx: &AssistContext<'_>) - return None; } - let ref_types = fetch_borrowed_types(&node)?; + let changes = fetch_borrowed_types(&node)?; let target = node.syntax().text_range(); - acc.add(AssistId::generate("add_lifetime_to_type"), "Add lifetime", target, |builder| { + acc.add(AssistId::quick_fix("add_lifetime_to_type"), "Add lifetime", target, |builder| { match node.generic_param_list() { Some(gen_param) => { if let Some(left_angle) = gen_param.l_angle_token() { @@ -51,16 +54,21 @@ pub(crate) fn add_lifetime_to_type(acc: &mut Assists, ctx: &AssistContext<'_>) - } } - for ref_type in ref_types { - if let Some(amp_token) = ref_type.amp_token() { - builder.insert(amp_token.text_range().end(), "'a "); + for change in changes { + match change { + Change::Replace(it) => { + builder.replace(it.text_range(), "'a"); + } + Change::Insert(it) => { + builder.insert(it.text_range().end(), "'a "); + } } } }) } -fn fetch_borrowed_types(node: &ast::Adt) -> Option<Vec<ast::RefType>> { - let ref_types: Vec<ast::RefType> = match node { +fn fetch_borrowed_types(node: &ast::Adt) -> Option<Vec<Change>> { + let ref_types: Vec<_> = match node { ast::Adt::Enum(enum_) => { let variant_list = enum_.variant_list()?; variant_list @@ -79,55 +87,50 @@ fn fetch_borrowed_types(node: &ast::Adt) -> Option<Vec<ast::RefType>> { } ast::Adt::Union(un) => { let record_field_list = un.record_field_list()?; - record_field_list - .fields() - .filter_map(|r_field| { - if let ast::Type::RefType(ref_type) = r_field.ty()? - && ref_type.lifetime().is_none() - { - return Some(ref_type); - } - - None - }) - .collect() + find_ref_types_from_field_list(&record_field_list.into())? } }; if ref_types.is_empty() { None } else { Some(ref_types) } } -fn find_ref_types_from_field_list(field_list: &ast::FieldList) -> Option<Vec<ast::RefType>> { - let ref_types: Vec<ast::RefType> = match field_list { - ast::FieldList::RecordFieldList(record_list) => record_list - .fields() - .filter_map(|f| { - if let ast::Type::RefType(ref_type) = f.ty()? - && ref_type.lifetime().is_none() - { - return Some(ref_type); - } - - None - }) - .collect(), - ast::FieldList::TupleFieldList(tuple_field_list) => tuple_field_list - .fields() - .filter_map(|f| { - if let ast::Type::RefType(ref_type) = f.ty()? - && ref_type.lifetime().is_none() - { - return Some(ref_type); - } - - None - }) - .collect(), +fn find_ref_types_from_field_list(field_list: &ast::FieldList) -> Option<Vec<Change>> { + let ref_types: Vec<_> = match field_list { + ast::FieldList::RecordFieldList(record_list) => { + record_list.fields().flat_map(|f| infer_lifetimes(f.syntax())).collect() + } + ast::FieldList::TupleFieldList(tuple_field_list) => { + tuple_field_list.fields().flat_map(|f| infer_lifetimes(f.syntax())).collect() + } }; if ref_types.is_empty() { None } else { Some(ref_types) } } +enum Change { + Replace(SyntaxToken), + Insert(SyntaxToken), +} + +fn infer_lifetimes(node: &SyntaxNode) -> Vec<Change> { + node.children() + .filter(|it| !matches!(it.kind(), SyntaxKind::FN_PTR_TYPE | SyntaxKind::TYPE_BOUND_LIST)) + .flat_map(|it| { + infer_lifetimes(&it) + .into_iter() + .chain(ast::Lifetime::cast(it.clone()).and_then(|lt| { + lt.lifetime_ident_token().filter(|lt| lt.text() == "'_").map(Change::Replace) + })) + .chain( + ast::RefType::cast(it) + .filter(|ty| ty.lifetime().is_none()) + .and_then(|ty| ty.amp_token()) + .map(Change::Insert), + ) + }) + .collect() +} + #[cfg(test)] mod tests { use crate::tests::{check_assist, check_assist_not_applicable}; @@ -165,6 +168,24 @@ mod tests { } #[test] + fn add_lifetime_to_nested_types() { + check_assist( + add_lifetime_to_type, + r#"struct Foo { a: &$0i32, b: &(&i32, fn(&str) -> &str) }"#, + r#"struct Foo<'a> { a: &'a i32, b: &'a (&'a i32, fn(&str) -> &str) }"#, + ); + } + + #[test] + fn add_lifetime_to_explicit_infer_lifetime() { + check_assist( + add_lifetime_to_type, + r#"struct Foo { a: &'_ $0i32, b: &'_ (&'_ i32, fn(&str) -> &str) }"#, + r#"struct Foo<'a> { a: &'a i32, b: &'a (&'a i32, fn(&str) -> &str) }"#, + ); + } + + #[test] fn add_lifetime_to_enum() { check_assist( add_lifetime_to_type, diff --git a/crates/ide-assists/src/handlers/add_missing_impl_members.rs b/crates/ide-assists/src/handlers/add_missing_impl_members.rs index 65ca1ceae1..afdced4215 100644 --- a/crates/ide-assists/src/handlers/add_missing_impl_members.rs +++ b/crates/ide-assists/src/handlers/add_missing_impl_members.rs @@ -2537,4 +2537,84 @@ impl Test for () { "#, ); } + + #[test] + fn test_param_name_not_qualified() { + check_assist( + add_missing_impl_members, + r#" +mod ptr { + pub struct NonNull<T>(T); +} +mod alloc { + use super::ptr::NonNull; + pub trait Allocator { + unsafe fn deallocate(&self, ptr: NonNull<u8>); + } +} + +struct System; + +unsafe impl alloc::Allocator for System { + $0 +} +"#, + r#" +mod ptr { + pub struct NonNull<T>(T); +} +mod alloc { + use super::ptr::NonNull; + pub trait Allocator { + unsafe fn deallocate(&self, ptr: NonNull<u8>); + } +} + +struct System; + +unsafe impl alloc::Allocator for System { + unsafe fn deallocate(&self, ptr: ptr::NonNull<u8>) { + ${0:todo!()} + } +} +"#, + ); + } + + #[test] + fn test_param_name_shadows_module() { + check_assist( + add_missing_impl_members, + r#" +mod m { } +use m as p; + +pub trait Allocator { + fn deallocate(&self, p: u8); +} + +struct System; + +impl Allocator for System { + $0 +} +"#, + r#" +mod m { } +use m as p; + +pub trait Allocator { + fn deallocate(&self, p: u8); +} + +struct System; + +impl Allocator for System { + fn deallocate(&self, p: u8) { + ${0:todo!()} + } +} +"#, + ); + } } diff --git a/crates/ide-assists/src/handlers/add_missing_match_arms.rs b/crates/ide-assists/src/handlers/add_missing_match_arms.rs index 10c3ff0e4d..b063e5ffce 100644 --- a/crates/ide-assists/src/handlers/add_missing_match_arms.rs +++ b/crates/ide-assists/src/handlers/add_missing_match_arms.rs @@ -3,14 +3,14 @@ use std::iter::{self, Peekable}; use either::Either; use hir::{Adt, AsAssocItem, Crate, FindPathConfig, HasAttrs, ModuleDef, Semantics}; use ide_db::RootDatabase; -use ide_db::assists::ExprFillDefaultMode; use ide_db::syntax_helpers::suggest_name; use ide_db::{famous_defs::FamousDefs, helpers::mod_path_to_ast}; use itertools::Itertools; -use syntax::ToSmolStr; -use syntax::ast::edit::{AstNodeEdit, IndentLevel}; +use syntax::ast::edit::IndentLevel; use syntax::ast::syntax_factory::SyntaxFactory; use syntax::ast::{self, AstNode, MatchArmList, MatchExpr, Pat, make}; +use syntax::syntax_editor::{Position, SyntaxEditor}; +use syntax::{SyntaxKind, SyntaxNode, ToSmolStr}; use crate::{AssistContext, AssistId, Assists, utils}; @@ -44,8 +44,7 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>) let arm_list_range = ctx.sema.original_range_opt(match_arm_list.syntax())?; if cursor_at_trivial_match_arm_list(ctx, &match_expr, &match_arm_list).is_none() { - let arm_list_range = ctx.sema.original_range(match_arm_list.syntax()).range; - let cursor_in_range = arm_list_range.contains_range(ctx.selection_trimmed()); + let cursor_in_range = arm_list_range.range.contains_range(ctx.selection_trimmed()); if cursor_in_range { cov_mark::hit!(not_applicable_outside_of_range_right); return None; @@ -81,9 +80,16 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>) let module = scope.module(); let cfg = ctx.config.find_path_config(ctx.sema.is_nightly(scope.krate())); let self_ty = if ctx.config.prefer_self_ty { - scope - .containing_function() - .and_then(|function| function.as_assoc_item(ctx.db())?.implementing_ty(ctx.db())) + scope.expression_store_owner().and_then(|def| { + match def { + hir::ExpressionStoreOwner::Body(def_with_body) => { + def_with_body.as_assoc_item(ctx.db()) + } + hir::ExpressionStoreOwner::Signature(def) => def.as_assoc_item(ctx.db()), + hir::ExpressionStoreOwner::VariantFields(_) => None, + }? + .implementing_ty(ctx.db()) + }) } else { None }; @@ -225,68 +231,26 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>) // having any hidden variants means that we need a catch-all arm needs_catch_all_arm |= has_hidden_variants; - let missing_arms = missing_pats + let mut missing_arms = missing_pats .filter(|(_, hidden)| { // filter out hidden patterns because they're handled by the catch-all arm !hidden }) - .map(|(pat, _)| { - make.match_arm( - pat, - None, - match ctx.config.expr_fill_default { - ExprFillDefaultMode::Todo => make::ext::expr_todo(), - ExprFillDefaultMode::Underscore => make::ext::expr_underscore(), - ExprFillDefaultMode::Default => make::ext::expr_todo(), - }, - ) - }); - - let mut arms: Vec<_> = match_arm_list - .arms() - .filter(|arm| { - if matches!(arm.pat(), Some(ast::Pat::WildcardPat(_))) { - let is_empty_expr = arm.expr().is_none_or(|e| match e { - ast::Expr::BlockExpr(b) => { - b.statements().next().is_none() && b.tail_expr().is_none() - } - ast::Expr::TupleExpr(t) => t.fields().next().is_none(), - _ => false, - }); - if is_empty_expr { - false - } else { - cov_mark::hit!(add_missing_match_arms_empty_expr); - true - } - } else { - true - } - }) - .map(|arm| arm.reset_indent().indent(IndentLevel(1))) - .collect(); - - let first_new_arm_idx = arms.len(); - arms.extend(missing_arms); + .map(|(pat, _)| make.match_arm(pat, None, utils::expr_fill_default(ctx.config))) + .collect::<Vec<_>>(); if needs_catch_all_arm && !has_catch_all_arm { cov_mark::hit!(added_wildcard_pattern); let arm = make.match_arm( make.wildcard_pat().into(), None, - match ctx.config.expr_fill_default { - ExprFillDefaultMode::Todo => make::ext::expr_todo(), - ExprFillDefaultMode::Underscore => make::ext::expr_underscore(), - ExprFillDefaultMode::Default => make::ext::expr_todo(), - }, + utils::expr_fill_default(ctx.config), ); - arms.push(arm); + missing_arms.push(arm); } - let new_match_arm_list = make.match_arm_list(arms); - // FIXME: Hack for syntax trees not having great support for macros - // Just replace the element that the original range came from + // Just edit the element that the original range came from let old_place = { // Find the original element let file = ctx.sema.parse(arm_list_range.file_id); @@ -303,25 +267,27 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>) }; let mut editor = builder.make_editor(&old_place); - let new_match_arm_list = new_match_arm_list.indent(IndentLevel::from_node(&old_place)); - editor.replace(old_place, new_match_arm_list.syntax()); + let mut arms_edit = ArmsEdit { match_arm_list, place: old_place, last_arm: None }; + + arms_edit.remove_wildcard_arms(ctx, &mut editor); + arms_edit.add_comma_after_last_arm(ctx, &make, &mut editor); + arms_edit.append_arms(&missing_arms, &make, &mut editor); if let Some(cap) = ctx.config.snippet_cap { - if let Some(it) = new_match_arm_list - .arms() - .nth(first_new_arm_idx) + if let Some(it) = missing_arms + .first() .and_then(|arm| arm.syntax().descendants().find_map(ast::WildcardPat::cast)) { editor.add_annotation(it.syntax(), builder.make_placeholder_snippet(cap)); } - for arm in new_match_arm_list.arms().skip(first_new_arm_idx) { + for arm in &missing_arms { if let Some(expr) = arm.expr() { editor.add_annotation(expr.syntax(), builder.make_placeholder_snippet(cap)); } } - if let Some(arm) = new_match_arm_list.arms().skip(first_new_arm_idx).last() { + if let Some(arm) = missing_arms.last() { editor.add_annotation(arm.syntax(), builder.make_tabstop_after(cap)); } } @@ -348,12 +314,24 @@ fn cursor_at_trivial_match_arm_list( // $0 // } if let Some(last_arm) = match_arm_list.arms().last() { - let last_arm_range = last_arm.syntax().text_range(); - let match_expr_range = match_expr.syntax().text_range(); - if last_arm_range.end() <= ctx.offset() && ctx.offset() < match_expr_range.end() { + let last_node = match last_arm.expr() { + Some(expr) => expr.syntax().clone(), + None => last_arm.syntax().clone(), + }; + let last_node_range = ctx.sema.original_range_opt(&last_node)?.range; + let match_expr_range = ctx.sema.original_range_opt(match_expr.syntax())?.range; + if last_node_range.end() <= ctx.offset() && ctx.offset() < match_expr_range.end() { cov_mark::hit!(add_missing_match_arms_end_of_last_arm); return Some(()); } + + if ast::Expr::cast(last_node.clone()).is_some_and(is_empty_expr) + && last_node_range.contains(ctx.offset()) + && !last_node.text().contains_char('\n') + { + cov_mark::hit!(add_missing_match_arms_end_of_last_empty_arm); + return Some(()); + } } // match { _$0 => {...} } @@ -368,10 +346,113 @@ fn cursor_at_trivial_match_arm_list( None } +struct ArmsEdit { + match_arm_list: MatchArmList, + place: SyntaxNode, + last_arm: Option<ast::MatchArm>, +} + +impl ArmsEdit { + fn remove_wildcard_arms(&mut self, ctx: &AssistContext<'_>, editor: &mut SyntaxEditor) { + for arm in self.match_arm_list.arms() { + if !matches!(arm.pat(), Some(Pat::WildcardPat(_))) { + self.last_arm = Some(arm); + continue; + } + if !arm.expr().is_none_or(is_empty_expr) { + cov_mark::hit!(add_missing_match_arms_empty_expr); + self.last_arm = Some(arm); + continue; + } + let Some(range) = self.cover_edit_range(ctx, &arm) else { continue }; + + let prev = match range.start() { + syntax::NodeOrToken::Node(node) => { + node.first_token().and_then(|it| it.prev_token()) + } + syntax::NodeOrToken::Token(tok) => tok.prev_token(), + }; + if let Some(prev) = prev + && prev.kind() == SyntaxKind::WHITESPACE + { + editor.delete(prev); + } + + editor.delete_all(range); + } + } + + fn append_arms(&self, arms: &[ast::MatchArm], make: &SyntaxFactory, editor: &mut SyntaxEditor) { + let Some(mut before) = self.place.last_token() else { + stdx::never!("match arm list not contain any token"); + return; + }; + if let Some(prev) = before.prev_token() + && prev.kind() == SyntaxKind::WHITESPACE + { + before = prev; + } + let open_curly = + !self.place.text().contains_char('\n') || before.kind() == SyntaxKind::WHITESPACE; + let indent = IndentLevel::from_node(&self.place); + let arm_indent = indent + 1; + let indent = make.whitespace(&format!("\n{indent}")); + let arm_indent = make.whitespace(&format!("\n{arm_indent}")); + let elements = arms + .iter() + .flat_map(|arm| [arm_indent.clone().into(), arm.syntax().clone().into()]) + .chain(open_curly.then(|| indent.clone().into())) + .collect(); + + if before.kind() == SyntaxKind::WHITESPACE { + editor.replace_with_many(before, elements); + } else { + editor.insert_all(Position::before(before), elements); + } + } + + fn add_comma_after_last_arm( + &self, + ctx: &AssistContext<'_>, + make: &SyntaxFactory, + editor: &mut SyntaxEditor, + ) { + if let Some(last_arm) = &self.last_arm + && last_arm.comma_token().is_none() + && last_arm.expr().is_none_or(|it| !it.is_block_like()) + && let Some(range) = self.cover_edit_range(ctx, last_arm) + { + editor.insert(Position::after(range.end()), make.token(syntax::T![,])); + } + } + + fn cover_edit_range( + &self, + ctx: &AssistContext<'_>, + node: &impl AstNode, + ) -> Option<std::ops::RangeInclusive<syntax::SyntaxElement>> { + let range = ctx.sema.original_range_opt(node.syntax())?; + + if !self.place.text_range().contains_range(range.range) { + return None; + } + + Some(utils::cover_edit_range(&self.place, range.range)) + } +} + fn is_variant_missing(existing_pats: &[Pat], var: &Pat) -> bool { !existing_pats.iter().any(|pat| does_pat_match_variant(pat, var)) } +fn is_empty_expr(e: ast::Expr) -> bool { + match e { + ast::Expr::BlockExpr(b) => b.statements().next().is_none() && b.tail_expr().is_none(), + ast::Expr::TupleExpr(t) => t.fields().next().is_none(), + _ => false, + } +} + // Fixme: this is still somewhat limited, use hir_ty::diagnostics::match_check? fn does_pat_match_variant(pat: &Pat, var: &Pat) -> bool { match (pat, var) { @@ -397,7 +478,7 @@ enum ExtendedEnum { enum ExtendedVariant { True, False, - Variant { variant: hir::Variant, use_self: bool }, + Variant { variant: hir::EnumVariant, use_self: bool }, } impl ExtendedVariant { @@ -1067,7 +1148,7 @@ fn main() { #[test] fn add_missing_match_arms_end_of_last_arm() { - cov_mark::check!(add_missing_match_arms_end_of_last_arm); + cov_mark::check_count!(add_missing_match_arms_end_of_last_arm, 2); check_assist( add_missing_match_arms, r#" @@ -1098,6 +1179,103 @@ fn main() { } "#, ); + + check_assist( + add_missing_match_arms, + r#" +enum A { One, Two } +enum B { One, Two } + +fn main() { + let a = A::One; + let b = B::One; + match (a, b) { + (A::Two, B::One) => 2$0, + } +} +"#, + r#" +enum A { One, Two } +enum B { One, Two } + +fn main() { + let a = A::One; + let b = B::One; + match (a, b) { + (A::Two, B::One) => 2, + (A::One, B::One) => ${1:todo!()}, + (A::One, B::Two) => ${2:todo!()}, + (A::Two, B::Two) => ${3:todo!()},$0 + } +} +"#, + ); + } + + #[test] + fn add_missing_match_arms_end_of_last_empty_arm() { + cov_mark::check_count!(add_missing_match_arms_end_of_last_empty_arm, 2); + check_assist( + add_missing_match_arms, + r#" +enum A { One, Two } +enum B { One, Two } + +fn main() { + let a = A::One; + let b = B::One; + match (a, b) { + (A::Two, B::One) => {$0} + } +} +"#, + r#" +enum A { One, Two } +enum B { One, Two } + +fn main() { + let a = A::One; + let b = B::One; + match (a, b) { + (A::Two, B::One) => {} + (A::One, B::One) => ${1:todo!()}, + (A::One, B::Two) => ${2:todo!()}, + (A::Two, B::Two) => ${3:todo!()},$0 + } +} +"#, + ); + + check_assist( + add_missing_match_arms, + r#" +enum A { One, Two } +enum B { One, Two } + +fn main() { + let a = A::One; + let b = B::One; + match (a, b) { + (A::Two, B::One) => ($0) + } +} +"#, + r#" +enum A { One, Two } +enum B { One, Two } + +fn main() { + let a = A::One; + let b = B::One; + match (a, b) { + (A::Two, B::One) => (), + (A::One, B::One) => ${1:todo!()}, + (A::One, B::Two) => ${2:todo!()}, + (A::Two, B::Two) => ${3:todo!()},$0 + } +} +"#, + ); } #[test] @@ -1614,6 +1792,38 @@ fn foo(t: Test) { }); }"#, ); + + check_assist( + add_missing_match_arms, + r#" +macro_rules! m { ($expr:expr) => {$expr}} +enum Test { + A, + B, + C, +} + +fn foo(t: Test) { + m!(match t { + Test::A => (), + $0}); +}"#, + r#" +macro_rules! m { ($expr:expr) => {$expr}} +enum Test { + A, + B, + C, +} + +fn foo(t: Test) { + m!(match t { + Test::A => (), + Test::B => ${1:todo!()}, + Test::C => ${2:todo!()},$0 + }); +}"#, + ); } #[test] @@ -2047,6 +2257,35 @@ fn foo(t: E) { } #[test] + fn keep_comments() { + check_assist( + add_missing_match_arms, + r#" +enum E { A, B, C } + +fn foo(t: E) -> i32 { + match $0t { + // variant a + E::A => 2 + // comment on end + } +}"#, + r#" +enum E { A, B, C } + +fn foo(t: E) -> i32 { + match t { + // variant a + E::A => 2, + // comment on end + E::B => ${1:todo!()}, + E::C => ${2:todo!()},$0 + } +}"#, + ); + } + + #[test] fn not_applicable_when_match_arm_list_cannot_be_upmapped() { check_assist_not_applicable( add_missing_match_arms, diff --git a/crates/ide-assists/src/handlers/add_turbo_fish.rs b/crates/ide-assists/src/handlers/add_turbo_fish.rs index be13b04873..c5e722d87e 100644 --- a/crates/ide-assists/src/handlers/add_turbo_fish.rs +++ b/crates/ide-assists/src/handlers/add_turbo_fish.rs @@ -2,7 +2,7 @@ use either::Either; use ide_db::defs::{Definition, NameRefClass}; use syntax::{ AstNode, - ast::{self, HasArgList, HasGenericArgs, make, syntax_factory::SyntaxFactory}, + ast::{self, HasArgList, HasGenericArgs, syntax_factory::SyntaxFactory}, syntax_editor::Position, }; @@ -94,20 +94,21 @@ pub(crate) fn add_turbo_fish(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti ident.text_range(), |builder| { let mut editor = builder.make_editor(let_stmt.syntax()); + let make = SyntaxFactory::without_mappings(); if let_stmt.semicolon_token().is_none() { editor.insert( Position::last_child_of(let_stmt.syntax()), - make::tokens::semicolon(), + make.token(syntax::SyntaxKind::SEMICOLON), ); } - let placeholder_ty = make::ty_placeholder().clone_for_update(); + let placeholder_ty = make.ty_placeholder(); if let Some(pat) = let_stmt.pat() { let elements = vec![ - make::token(syntax::SyntaxKind::COLON).into(), - make::token(syntax::SyntaxKind::WHITESPACE).into(), + make.token(syntax::SyntaxKind::COLON).into(), + make.whitespace(" ").into(), placeholder_ty.syntax().clone().into(), ]; editor.insert_all(Position::after(pat.syntax()), elements); @@ -188,7 +189,7 @@ pub(crate) fn add_turbo_fish(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti /// This will create a turbofish generic arg list corresponding to the number of arguments fn get_fish_head(make: &SyntaxFactory, number_of_arguments: usize) -> ast::GenericArgList { - let args = (0..number_of_arguments).map(|_| make::type_arg(make::ty_placeholder()).into()); + let args = (0..number_of_arguments).map(|_| make.type_arg(make.ty_placeholder()).into()); make.generic_arg_list(args, true) } diff --git a/crates/ide-assists/src/handlers/apply_demorgan.rs b/crates/ide-assists/src/handlers/apply_demorgan.rs index 80d0a6da12..4ee4970248 100644 --- a/crates/ide-assists/src/handlers/apply_demorgan.rs +++ b/crates/ide-assists/src/handlers/apply_demorgan.rs @@ -197,7 +197,7 @@ pub(crate) fn apply_demorgan_iterator(acc: &mut Assists, ctx: &AssistContext<'_> let (name, arg_expr) = validate_method_call_expr(ctx, &method_call)?; let ast::Expr::ClosureExpr(closure_expr) = arg_expr else { return None }; - let closure_body = closure_expr.body()?.clone_for_update(); + let closure_body = closure_expr.body()?; let op_range = method_call.syntax().text_range(); let label = format!("Apply De Morgan's law to `Iterator::{}`", name.text().as_str()); diff --git a/crates/ide-assists/src/handlers/auto_import.rs b/crates/ide-assists/src/handlers/auto_import.rs index cc2bf81749..de5dfdf4d9 100644 --- a/crates/ide-assists/src/handlers/auto_import.rs +++ b/crates/ide-assists/src/handlers/auto_import.rs @@ -155,6 +155,7 @@ pub(crate) fn auto_import(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option< &scope, mod_path_to_ast(&import_path, edition), &ctx.config.insert_use, + edition, ); }, ); @@ -267,7 +268,7 @@ pub(crate) fn relevance_score( hir::Adt::Union(it) => it.ty(ctx.db()), hir::Adt::Enum(it) => it.ty(ctx.db()), }), - hir::ModuleDef::Variant(variant) => Some(variant.constructor_ty(ctx.db())), + hir::ModuleDef::EnumVariant(variant) => Some(variant.constructor_ty(ctx.db())), hir::ModuleDef::Const(it) => Some(it.ty(ctx.db())), hir::ModuleDef::Static(it) => Some(it.ty(ctx.db())), hir::ModuleDef::TypeAlias(it) => Some(it.ty(ctx.db())), @@ -351,7 +352,7 @@ mod tests { let config = TEST_CONFIG; let ctx = AssistContext::new(sema, &config, frange); let mut acc = Assists::new(&ctx, AssistResolveStrategy::All); - auto_import(&mut acc, &ctx); + hir::attach_db(&db, || auto_import(&mut acc, &ctx)); let assists = acc.finish(); let labels = assists.iter().map(|assist| assist.label.to_string()).collect::<Vec<_>>(); @@ -1896,4 +1897,35 @@ fn foo(_: S) {} "#, ); } + + #[test] + fn with_after_segments() { + let before = r#" +mod foo { + pub mod wanted { + pub fn abc() {} + } +} + +mod bar { + pub mod wanted {} +} + +mod baz { + pub fn wanted() {} +} + +mod quux { + pub struct wanted; +} +impl quux::wanted { + fn abc() {} +} + +fn f() { + wanted$0::abc; +} + "#; + check_auto_import_order(before, &["Import `foo::wanted`", "Import `quux::wanted`"]); + } } diff --git a/crates/ide-assists/src/handlers/bind_unused_param.rs b/crates/ide-assists/src/handlers/bind_unused_param.rs index 771e80bb92..0e85a77822 100644 --- a/crates/ide-assists/src/handlers/bind_unused_param.rs +++ b/crates/ide-assists/src/handlers/bind_unused_param.rs @@ -2,7 +2,7 @@ use crate::assist_context::{AssistContext, Assists}; use ide_db::{LineIndexDatabase, assists::AssistId, defs::Definition}; use syntax::{ AstNode, - ast::{self, HasName, edit_in_place::Indent}, + ast::{self, HasName, edit::AstNodeEdit}, }; // Assist: bind_unused_param diff --git a/crates/ide-assists/src/handlers/convert_bool_then.rs b/crates/ide-assists/src/handlers/convert_bool_then.rs index 91cee59ad8..b3bfe5b8c4 100644 --- a/crates/ide-assists/src/handlers/convert_bool_then.rs +++ b/crates/ide-assists/src/handlers/convert_bool_then.rs @@ -102,6 +102,7 @@ pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext<'_> ast::Expr::BlockExpr(block) => unwrap_trivial_block(block), e => e, }; + let cond = if invert_cond { invert_boolean_expression(&make, cond) } else { cond }; let parenthesize = matches!( cond, @@ -123,11 +124,7 @@ pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext<'_> | ast::Expr::WhileExpr(_) | ast::Expr::YieldExpr(_) ); - let cond = if invert_cond { - invert_boolean_expression(&make, cond) - } else { - cond.clone_for_update() - }; + let cond = if parenthesize { make.expr_paren(cond).into() } else { cond }; let arg_list = make.arg_list(Some(make.expr_closure(None, closure_body).into())); let mcall = make.expr_method_call(cond, make.name_ref("then"), arg_list); @@ -240,7 +237,7 @@ pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext<'_> fn option_variants( sema: &Semantics<'_, RootDatabase>, expr: &SyntaxNode, -) -> Option<(hir::Variant, hir::Variant)> { +) -> Option<(hir::EnumVariant, hir::EnumVariant)> { let fam = FamousDefs(sema, sema.scope(expr)?.krate()); let option_variants = fam.core_option_Option()?.variants(sema.db); match &*option_variants { @@ -257,7 +254,7 @@ fn option_variants( /// If any of these conditions are met it is impossible to rewrite this as a `bool::then` call. fn is_invalid_body( sema: &Semantics<'_, RootDatabase>, - some_variant: hir::Variant, + some_variant: hir::EnumVariant, expr: &ast::Expr, ) -> bool { let mut invalid = false; @@ -280,7 +277,7 @@ fn is_invalid_body( && let Some(ast::Expr::PathExpr(p)) = call.expr() { let res = p.path().and_then(|p| sema.resolve_path(&p)); - if let Some(hir::PathResolution::Def(hir::ModuleDef::Variant(v))) = res { + if let Some(hir::PathResolution::Def(hir::ModuleDef::EnumVariant(v))) = res { return invalid |= v != some_variant; } } @@ -293,11 +290,11 @@ fn is_invalid_body( fn block_is_none_variant( sema: &Semantics<'_, RootDatabase>, block: &ast::BlockExpr, - none_variant: hir::Variant, + none_variant: hir::EnumVariant, ) -> bool { block_as_lone_tail(block).and_then(|e| match e { ast::Expr::PathExpr(pat) => match sema.resolve_path(&pat.path()?)? { - hir::PathResolution::Def(hir::ModuleDef::Variant(v)) => Some(v), + hir::PathResolution::Def(hir::ModuleDef::EnumVariant(v)) => Some(v), _ => None, }, _ => None, @@ -591,4 +588,23 @@ fn main() { ", ); } + #[test] + fn convert_if_to_bool_then_invert_method_call() { + check_assist( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + let test = &[()]; + let value = if$0 test.is_empty() { None } else { Some(()) }; +} +", + r" +fn main() { + let test = &[()]; + let value = (!test.is_empty()).then(|| ()); +} +", + ); + } } diff --git a/crates/ide-assists/src/handlers/convert_bool_to_enum.rs b/crates/ide-assists/src/handlers/convert_bool_to_enum.rs index 1ae5f64917..e88778a62e 100644 --- a/crates/ide-assists/src/handlers/convert_bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/convert_bool_to_enum.rs @@ -11,9 +11,11 @@ use ide_db::{ source_change::SourceChangeBuilder, }; use itertools::Itertools; +use syntax::ast::edit::AstNodeEdit; +use syntax::ast::syntax_factory::SyntaxFactory; use syntax::{ AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T, - ast::{self, HasName, edit::IndentLevel, edit_in_place::Indent, make}, + ast::{self, HasName, edit::IndentLevel}, }; use crate::{ @@ -61,19 +63,28 @@ pub(crate) fn convert_bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) - "Convert boolean to enum", target, |edit| { + let make = SyntaxFactory::without_mappings(); if let Some(ty) = &ty_annotation { cov_mark::hit!(replaces_ty_annotation); edit.replace(ty.syntax().text_range(), "Bool"); } if let Some(initializer) = initializer { - replace_bool_expr(edit, initializer); + replace_bool_expr(edit, initializer, &make); } let usages = definition.usages(&ctx.sema).all(); - add_enum_def(edit, ctx, &usages, target_node, &target_module); + add_enum_def(edit, ctx, &usages, target_node, &target_module, &make); let mut delayed_mutations = Vec::new(); - replace_usages(edit, ctx, usages, definition, &target_module, &mut delayed_mutations); + replace_usages( + edit, + ctx, + usages, + definition, + &target_module, + &mut delayed_mutations, + &make, + ); for (scope, path) in delayed_mutations { insert_use(&scope, path, &ctx.config.insert_use); } @@ -167,16 +178,16 @@ fn find_bool_node(ctx: &AssistContext<'_>) -> Option<BoolNodeData> { } } -fn replace_bool_expr(edit: &mut SourceChangeBuilder, expr: ast::Expr) { +fn replace_bool_expr(edit: &mut SourceChangeBuilder, expr: ast::Expr, make: &SyntaxFactory) { let expr_range = expr.syntax().text_range(); - let enum_expr = bool_expr_to_enum_expr(expr); + let enum_expr = bool_expr_to_enum_expr(expr, make); edit.replace(expr_range, enum_expr.syntax().text()) } /// Converts an expression of type `bool` to one of the new enum type. -fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr { - let true_expr = make::expr_path(make::path_from_text("Bool::True")); - let false_expr = make::expr_path(make::path_from_text("Bool::False")); +fn bool_expr_to_enum_expr(expr: ast::Expr, make: &SyntaxFactory) -> ast::Expr { + let true_expr = make.expr_path(make.path_from_text("Bool::True")); + let false_expr = make.expr_path(make.path_from_text("Bool::False")); if let ast::Expr::Literal(literal) = &expr { match literal.kind() { @@ -185,10 +196,10 @@ fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr { _ => expr, } } else { - make::expr_if( + make.expr_if( expr, - make::tail_only_block_expr(true_expr), - Some(ast::ElseBranch::Block(make::tail_only_block_expr(false_expr))), + make.tail_only_block_expr(true_expr), + Some(ast::ElseBranch::Block(make.tail_only_block_expr(false_expr))), ) .into() } @@ -202,11 +213,13 @@ fn replace_usages( target_definition: Definition, target_module: &hir::Module, delayed_mutations: &mut Vec<(ImportScope, ast::Path)>, + make: &SyntaxFactory, ) { for (file_id, references) in usages { edit.edit_file(file_id.file_id(ctx.db())); - let refs_with_imports = augment_references_with_imports(ctx, references, target_module); + let refs_with_imports = + augment_references_with_imports(ctx, references, target_module, make); refs_with_imports.into_iter().rev().for_each( |FileReferenceWithImport { range, name, import_data }| { @@ -223,12 +236,13 @@ fn replace_usages( target_definition, target_module, delayed_mutations, + make, ) } } else if let Some(initializer) = find_assignment_usage(&name) { cov_mark::hit!(replaces_assignment); - replace_bool_expr(edit, initializer); + replace_bool_expr(edit, initializer, make); } else if let Some((prefix_expr, inner_expr)) = find_negated_usage(&name) { cov_mark::hit!(replaces_negation); @@ -246,7 +260,7 @@ fn replace_usages( { cov_mark::hit!(replaces_record_expr); - let enum_expr = bool_expr_to_enum_expr(initializer); + let enum_expr = bool_expr_to_enum_expr(initializer, make); utils::replace_record_field_expr(ctx, edit, record_field, enum_expr); } else if let Some(pat) = find_record_pat_field_usage(&name) { match pat { @@ -262,6 +276,7 @@ fn replace_usages( target_definition, target_module, delayed_mutations, + make, ) } } @@ -271,14 +286,14 @@ fn replace_usages( if let Some(expr) = literal_pat.literal().and_then(|literal| { literal.syntax().ancestors().find_map(ast::Expr::cast) }) { - replace_bool_expr(edit, expr); + replace_bool_expr(edit, expr, make); } } _ => (), } } else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&name) { edit.replace(ty_annotation.syntax().text_range(), "Bool"); - replace_bool_expr(edit, initializer); + replace_bool_expr(edit, initializer, make); } else if let Some(receiver) = find_method_call_expr_usage(&name) { edit.replace( receiver.syntax().text_range(), @@ -295,10 +310,10 @@ fn replace_usages( ctx, edit, record_field, - make::expr_bin_op( + make.expr_bin_op( expr, ast::BinaryOp::CmpOp(ast::CmpOp::Eq { negated: false }), - make::expr_path(make::path_from_text("Bool::True")), + make.expr_path(make.path_from_text("Bool::True")), ), ); } else { @@ -326,6 +341,7 @@ fn augment_references_with_imports( ctx: &AssistContext<'_>, references: Vec<FileReference>, target_module: &hir::Module, + make: &SyntaxFactory, ) -> Vec<FileReferenceWithImport> { let mut visited_modules = FxHashSet::default(); @@ -356,9 +372,9 @@ fn augment_references_with_imports( cfg, ) .map(|mod_path| { - make::path_concat( + make.path_concat( mod_path_to_ast(&mod_path, edition), - make::path_from_text("Bool"), + make.path_from_text("Bool"), ) })?; @@ -457,6 +473,7 @@ fn add_enum_def( usages: &UsageSearchResult, target_node: SyntaxNode, target_module: &hir::Module, + make: &SyntaxFactory, ) -> Option<()> { let insert_before = node_to_insert_before(target_node); @@ -479,10 +496,9 @@ fn add_enum_def( ctx.sema.scope(name.syntax()).map(|scope| scope.module()) }) .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module); - let enum_def = make_bool_enum(make_enum_pub); let indent = IndentLevel::from_node(&insert_before); - enum_def.reindent_to(indent); + let enum_def = make_bool_enum(make_enum_pub, make).reset_indent().indent(indent); edit.insert( insert_before.text_range().start(), @@ -504,31 +520,30 @@ fn node_to_insert_before(target_node: SyntaxNode) -> SyntaxNode { .unwrap_or(target_node) } -fn make_bool_enum(make_pub: bool) -> ast::Enum { - let derive_eq = make::attr_outer(make::meta_token_tree( - make::ext::ident_path("derive"), - make::token_tree( +fn make_bool_enum(make_pub: bool, make: &SyntaxFactory) -> ast::Enum { + let derive_eq = make.attr_outer(make.meta_token_tree( + make.ident_path("derive"), + make.token_tree( T!['('], vec![ - NodeOrToken::Token(make::tokens::ident("PartialEq")), - NodeOrToken::Token(make::token(T![,])), - NodeOrToken::Token(make::tokens::single_space()), - NodeOrToken::Token(make::tokens::ident("Eq")), + NodeOrToken::Token(make.ident("PartialEq")), + NodeOrToken::Token(make.token(T![,])), + NodeOrToken::Token(make.whitespace(" ")), + NodeOrToken::Token(make.ident("Eq")), ], ), )); - make::enum_( + make.item_enum( [derive_eq], - if make_pub { Some(make::visibility_pub()) } else { None }, - make::name("Bool"), + if make_pub { Some(make.visibility_pub()) } else { None }, + make.name("Bool"), None, None, - make::variant_list(vec![ - make::variant(None, make::name("True"), None, None), - make::variant(None, make::name("False"), None, None), + make.variant_list(vec![ + make.variant(None, make.name("True"), None, None), + make.variant(None, make.name("False"), None, None), ]), ) - .clone_for_update() } #[cfg(test)] diff --git a/crates/ide-assists/src/handlers/convert_closure_to_fn.rs b/crates/ide-assists/src/handlers/convert_closure_to_fn.rs index ca142332d9..9f9ced98d7 100644 --- a/crates/ide-assists/src/handlers/convert_closure_to_fn.rs +++ b/crates/ide-assists/src/handlers/convert_closure_to_fn.rs @@ -132,7 +132,7 @@ pub(crate) fn convert_closure_to_fn(acc: &mut Assists, ctx: &AssistContext<'_>) ); } - if block.try_token().is_none() + if block.try_block_modifier().is_none() && block.unsafe_token().is_none() && block.label().is_none() && block.const_token().is_none() @@ -220,7 +220,7 @@ pub(crate) fn convert_closure_to_fn(acc: &mut Assists, ctx: &AssistContext<'_>) } let body = if wrap_body_in_block { - make::block_expr([], Some(body)) + make::block_expr([], Some(body.reset_indent().indent(1.into()))) } else { ast::BlockExpr::cast(body.syntax().clone()).unwrap() }; @@ -971,6 +971,32 @@ fn foo() { } "#, ); + check_assist( + convert_closure_to_fn, + r#" +//- minicore: copy +fn foo() { + { + let closure = |$0| match () { + () => {}, + }; + closure(); + } +} +"#, + r#" +fn foo() { + { + fn closure() { + match () { + () => {}, + } + } + closure(); + } +} +"#, + ); } #[test] diff --git a/crates/ide-assists/src/handlers/convert_for_to_while_let.rs b/crates/ide-assists/src/handlers/convert_for_to_while_let.rs index d64e9ceda2..15f324eff3 100644 --- a/crates/ide-assists/src/handlers/convert_for_to_while_let.rs +++ b/crates/ide-assists/src/handlers/convert_for_to_while_let.rs @@ -2,7 +2,7 @@ use hir::{Name, sym}; use ide_db::{famous_defs::FamousDefs, syntax_helpers::suggest_name}; use syntax::{ AstNode, - ast::{self, HasAttrs, HasLoopBody, edit::IndentLevel, make, syntax_factory::SyntaxFactory}, + ast::{self, HasAttrs, HasLoopBody, edit::IndentLevel, syntax_factory::SyntaxFactory}, syntax_editor::Position, }; @@ -24,8 +24,8 @@ use crate::{AssistContext, AssistId, Assists}; // ``` // fn main() { // let x = vec![1, 2, 3]; -// let mut tmp = x.into_iter(); -// while let Some(v) = tmp.next() { +// let mut iter = x.into_iter(); +// while let Some(v) = iter.next() { // let y = v * 2; // }; // } @@ -57,13 +57,13 @@ pub(crate) fn convert_for_loop_to_while_let( { (expr, Some(make.name_ref(method.as_str()))) } else if let ast::Expr::RefExpr(_) = iterable { - (make::expr_paren(iterable).into(), Some(make.name_ref("into_iter"))) + (make.expr_paren(iterable).into(), Some(make.name_ref("into_iter"))) } else { (iterable, Some(make.name_ref("into_iter"))) }; let iterable = if let Some(method) = method { - make::expr_method_call(iterable, method, make::arg_list([])).into() + make.expr_method_call(iterable, method, make.arg_list([])).into() } else { iterable }; @@ -71,7 +71,7 @@ pub(crate) fn convert_for_loop_to_while_let( let mut new_name = suggest_name::NameGenerator::new_from_scope_locals( ctx.sema.scope(for_loop.syntax()), ); - let tmp_var = new_name.suggest_name("tmp"); + let tmp_var = new_name.suggest_name("iter"); let mut_expr = make.let_stmt( make.ident_pat(false, true, make.name(&tmp_var)).into(), @@ -89,17 +89,18 @@ pub(crate) fn convert_for_loop_to_while_let( for_loop.syntax(), &mut editor, for_loop.attrs().map(|it| it.clone_for_update()), + &make, ); editor.insert( Position::before(for_loop.syntax()), - make::tokens::whitespace(format!("\n{indent}").as_str()), + make.whitespace(format!("\n{indent}").as_str()), ); editor.insert(Position::before(for_loop.syntax()), mut_expr.syntax()); - let opt_pat = make.tuple_struct_pat(make::ext::ident_path("Some"), [pat]); + let opt_pat = make.tuple_struct_pat(make.ident_path("Some"), [pat]); let iter_next_expr = make.expr_method_call( - make.expr_path(make::ext::ident_path(&tmp_var)), + make.expr_path(make.ident_path(&tmp_var)), make.name_ref("next"), make.arg_list([]), ); @@ -187,8 +188,8 @@ fn main() { r" fn main() { let mut x = vec![1, 2, 3]; - let mut tmp = x.into_iter(); - while let Some(v) = tmp.next() { + let mut iter = x.into_iter(); + while let Some(v) = iter.next() { v *= 2; }; }", @@ -210,8 +211,8 @@ fn main() { r" fn main() { let mut x = vec![1, 2, 3]; - let mut tmp = x.into_iter(); - 'a: while let Some(v) = tmp.next() { + let mut iter = x.into_iter(); + 'a: while let Some(v) = iter.next() { v *= 2; break 'a; }; @@ -235,10 +236,10 @@ fn main() { r" fn main() { let mut x = vec![1, 2, 3]; - let mut tmp = x.into_iter(); + let mut iter = x.into_iter(); #[allow(unused)] #[deny(unsafe_code)] - while let Some(v) = tmp.next() { + while let Some(v) = iter.next() { v *= 2; }; }", @@ -274,8 +275,8 @@ impl<T> core::iter::Iterator for core::ops::Range<T> { } fn main() { - let mut tmp = 0..92; - while let Some(x) = tmp.next() { + let mut iter = 0..92; + while let Some(x) = iter.next() { print!("{}", x); } }"#, @@ -329,8 +330,8 @@ impl S { fn main() { let x = S; - let mut tmp = x.iter(); - while let Some(v) = tmp.next() { + let mut iter = x.iter(); + while let Some(v) = iter.next() { let a = v * 2; } } @@ -355,8 +356,8 @@ fn main() { struct NoIterMethod; fn main() { let x = NoIterMethod; - let mut tmp = (&x).into_iter(); - while let Some(v) = tmp.next() { + let mut iter = (&x).into_iter(); + while let Some(v) = iter.next() { let a = v * 2; } } @@ -381,8 +382,8 @@ fn main() { struct NoIterMethod; fn main() { let x = NoIterMethod; - let mut tmp = (&mut x).into_iter(); - while let Some(v) = tmp.next() { + let mut iter = (&mut x).into_iter(); + while let Some(v) = iter.next() { let a = v * 2; } } @@ -422,8 +423,8 @@ impl S { fn main() { let x = S; - let mut tmp = x.iter_mut(); - while let Some(v) = tmp.next() { + let mut iter = x.iter_mut(); + while let Some(v) = iter.next() { let a = v * 2; } } @@ -447,8 +448,8 @@ fn main() { fn main() { let mut x = vec![1, 2, 3]; let y = &mut x; - let mut tmp = y.into_iter(); - while let Some(v) = tmp.next() { + let mut iter = y.into_iter(); + while let Some(v) = iter.next() { *v *= 2; } }", @@ -470,8 +471,8 @@ fn main() { "#, r#" fn main() { - let mut tmp = core::iter::repeat(92).take(1); - while let Some(a) = tmp.next() { + let mut iter = core::iter::repeat(92).take(1); + while let Some(a) = iter.next() { println!("{}", a); } } diff --git a/crates/ide-assists/src/handlers/convert_from_to_tryfrom.rs b/crates/ide-assists/src/handlers/convert_from_to_tryfrom.rs index 6a74d21451..66ccd2a4e4 100644 --- a/crates/ide-assists/src/handlers/convert_from_to_tryfrom.rs +++ b/crates/ide-assists/src/handlers/convert_from_to_tryfrom.rs @@ -1,6 +1,6 @@ use ide_db::{famous_defs::FamousDefs, traits::resolve_target_trait}; use syntax::ast::edit::IndentLevel; -use syntax::ast::{self, AstNode, HasGenericArgs, HasName, make}; +use syntax::ast::{self, AstNode, HasGenericArgs, HasName, syntax_factory::SyntaxFactory}; use syntax::syntax_editor::{Element, Position}; use crate::{AssistContext, AssistId, Assists}; @@ -74,36 +74,25 @@ pub(crate) fn convert_from_to_tryfrom(acc: &mut Assists, ctx: &AssistContext<'_> "Convert From to TryFrom", impl_.syntax().text_range(), |builder| { + let make = SyntaxFactory::with_mappings(); let mut editor = builder.make_editor(impl_.syntax()); - editor.replace( - trait_ty.syntax(), - make::ty(&format!("TryFrom<{from_type}>")).syntax().clone_for_update(), - ); + + editor.replace(trait_ty.syntax(), make.ty(&format!("TryFrom<{from_type}>")).syntax()); editor.replace( from_fn_return_type.syntax(), - make::ty("Result<Self, Self::Error>").syntax().clone_for_update(), - ); - editor - .replace(from_fn_name.syntax(), make::name("try_from").syntax().clone_for_update()); - editor.replace( - tail_expr.syntax(), - wrap_ok(tail_expr.clone()).syntax().clone_for_update(), + make.ty("Result<Self, Self::Error>").syntax(), ); + editor.replace(from_fn_name.syntax(), make.name("try_from").syntax()); + editor.replace(tail_expr.syntax(), wrap_ok(&make, tail_expr.clone()).syntax()); for r in return_exprs { - let t = r.expr().unwrap_or_else(make::ext::expr_unit); - editor.replace(t.syntax(), wrap_ok(t.clone()).syntax().clone_for_update()); + let t = r.expr().unwrap_or_else(|| make.expr_unit()); + editor.replace(t.syntax(), wrap_ok(&make, t.clone()).syntax()); } - let error_type = ast::AssocItem::TypeAlias(make::ty_alias( - None, - "Error", - None, - None, - None, - Some((make::ty_unit(), None)), - )) - .clone_for_update(); + let error_type_alias = + make.ty_alias(None, "Error", None, None, None, Some((make.ty("()"), None))); + let error_type = ast::AssocItem::TypeAlias(error_type_alias); if let Some(cap) = ctx.config.snippet_cap && let ast::AssocItem::TypeAlias(type_alias) = &error_type @@ -117,22 +106,19 @@ pub(crate) fn convert_from_to_tryfrom(acc: &mut Assists, ctx: &AssistContext<'_> editor.insert_all( Position::after(associated_l_curly), vec![ - make::tokens::whitespace(&format!("\n{indent}")).syntax_element(), + make.whitespace(&format!("\n{indent}")).syntax_element(), error_type.syntax().syntax_element(), - make::tokens::whitespace("\n").syntax_element(), + make.whitespace("\n").syntax_element(), ], ); + editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } -fn wrap_ok(expr: ast::Expr) -> ast::Expr { - make::expr_call( - make::expr_path(make::ext::ident_path("Ok")), - make::arg_list(std::iter::once(expr)), - ) - .into() +fn wrap_ok(make: &SyntaxFactory, expr: ast::Expr) -> ast::Expr { + make.expr_call(make.expr_path(make.path_from_text("Ok")), make.arg_list([expr])).into() } #[cfg(test)] diff --git a/crates/ide-assists/src/handlers/convert_let_else_to_match.rs b/crates/ide-assists/src/handlers/convert_let_else_to_match.rs index ebfed9f9ca..d2336a4a5d 100644 --- a/crates/ide-assists/src/handlers/convert_let_else_to_match.rs +++ b/crates/ide-assists/src/handlers/convert_let_else_to_match.rs @@ -1,7 +1,6 @@ use syntax::T; use syntax::ast::RangeItem; -use syntax::ast::edit::IndentLevel; -use syntax::ast::edit_in_place::Indent; +use syntax::ast::edit::AstNodeEdit; use syntax::ast::syntax_factory::SyntaxFactory; use syntax::ast::{self, AstNode, HasName, LetStmt, Pat}; @@ -93,7 +92,8 @@ pub(crate) fn convert_let_else_to_match(acc: &mut Assists, ctx: &AssistContext<' ); let else_arm = make.match_arm(make.wildcard_pat().into(), None, else_expr); let match_ = make.expr_match(init, make.match_arm_list([binding_arm, else_arm])); - match_.reindent_to(IndentLevel::from_node(let_stmt.syntax())); + let match_ = match_.reset_indent(); + let match_ = match_.indent(let_stmt.indent_level()); if bindings.is_empty() { editor.replace(let_stmt.syntax(), match_.syntax()); diff --git a/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs b/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs index e518c39dab..aaf727058c 100644 --- a/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs +++ b/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs @@ -1,14 +1,18 @@ use either::Either; use ide_db::{defs::Definition, search::FileReference}; -use itertools::Itertools; use syntax::{ - SyntaxKind, - ast::{self, AstNode, HasAttrs, HasGenericParams, HasVisibility}, + NodeOrToken, SyntaxKind, SyntaxNode, T, + algo::next_non_trivia_token, + ast::{ + self, AstNode, HasAttrs, HasGenericParams, HasVisibility, syntax_factory::SyntaxFactory, + }, match_ast, - syntax_editor::{Position, SyntaxEditor}, + syntax_editor::{Element, Position, SyntaxEditor}, }; -use crate::{AssistContext, AssistId, Assists, assist_context::SourceChangeBuilder}; +use crate::{ + AssistContext, AssistId, Assists, assist_context::SourceChangeBuilder, utils::cover_edit_range, +}; // Assist: convert_named_struct_to_tuple_struct // @@ -81,17 +85,17 @@ pub(crate) fn convert_named_struct_to_tuple_struct( AssistId::refactor_rewrite("convert_named_struct_to_tuple_struct"), "Convert to tuple struct", strukt_or_variant.syntax().text_range(), - |edit| { - edit_field_references(ctx, edit, record_fields.fields()); - edit_struct_references(ctx, edit, strukt_def); - edit_struct_def(ctx, edit, &strukt_or_variant, record_fields); + |builder| { + edit_field_references(ctx, builder, record_fields.fields()); + edit_struct_references(ctx, builder, strukt_def); + edit_struct_def(ctx, builder, &strukt_or_variant, record_fields); }, ) } fn edit_struct_def( ctx: &AssistContext<'_>, - edit: &mut SourceChangeBuilder, + builder: &mut SourceChangeBuilder, strukt: &Either<ast::Struct, ast::Variant>, record_fields: ast::RecordFieldList, ) { @@ -108,24 +112,23 @@ fn edit_struct_def( let field = ast::TupleField::cast(field_syntax)?; Some(field) }); - let tuple_fields = ast::make::tuple_field_list(tuple_fields); - let record_fields_text_range = record_fields.syntax().text_range(); - edit.edit_file(ctx.vfs_file_id()); - edit.replace(record_fields_text_range, tuple_fields.syntax().text()); + let make = SyntaxFactory::without_mappings(); + let mut edit = builder.make_editor(strukt.syntax()); + + let tuple_fields = make.tuple_field_list(tuple_fields); + let mut elements = vec![tuple_fields.syntax().clone().into()]; if let Either::Left(strukt) = strukt { if let Some(w) = strukt.where_clause() { - let mut where_clause = w.to_string(); - if where_clause.ends_with(',') { - where_clause.pop(); - } - where_clause.push(';'); + edit.delete(w.syntax()); - edit.delete(w.syntax().text_range()); - edit.insert(record_fields_text_range.end(), ast::make::tokens::single_newline().text()); - edit.insert(record_fields_text_range.end(), where_clause); - edit.insert(record_fields_text_range.end(), ast::make::tokens::single_newline().text()); + elements.extend([ + make.whitespace("\n").into(), + remove_trailing_comma(w).into(), + make.token(T![;]).into(), + make.whitespace("\n").into(), + ]); if let Some(tok) = strukt .generic_param_list() @@ -133,45 +136,51 @@ fn edit_struct_def( .and_then(|tok| tok.next_token()) .filter(|tok| tok.kind() == SyntaxKind::WHITESPACE) { - edit.delete(tok.text_range()); + edit.delete(tok); } } else { - edit.insert(record_fields_text_range.end(), ";"); + elements.push(make.token(T![;]).into()); } } + edit.replace_with_many(record_fields.syntax(), elements); if let Some(tok) = record_fields .l_curly_token() .and_then(|tok| tok.prev_token()) .filter(|tok| tok.kind() == SyntaxKind::WHITESPACE) { - edit.delete(tok.text_range()) + edit.delete(tok) } + + builder.add_file_edits(ctx.vfs_file_id(), edit); } fn edit_struct_references( ctx: &AssistContext<'_>, - edit: &mut SourceChangeBuilder, - strukt: Either<hir::Struct, hir::Variant>, + builder: &mut SourceChangeBuilder, + strukt: Either<hir::Struct, hir::EnumVariant>, ) { let strukt_def = match strukt { Either::Left(s) => Definition::Adt(hir::Adt::Struct(s)), - Either::Right(v) => Definition::Variant(v), + Either::Right(v) => Definition::EnumVariant(v), }; let usages = strukt_def.usages(&ctx.sema).include_self_refs().all(); for (file_id, refs) in usages { - edit.edit_file(file_id.file_id(ctx.db())); + let source = ctx.sema.parse(file_id); + let mut edit = builder.make_editor(source.syntax()); for r in refs { - process_struct_name_reference(ctx, r, edit); + process_struct_name_reference(ctx, r, &mut edit, &source); } + builder.add_file_edits(file_id.file_id(ctx.db()), edit); } } fn process_struct_name_reference( ctx: &AssistContext<'_>, r: FileReference, - edit: &mut SourceChangeBuilder, + edit: &mut SyntaxEditor, + source: &ast::SourceFile, ) -> Option<()> { // First check if it's the last semgnet of a path that directly belongs to a record // expression/pattern. @@ -192,36 +201,26 @@ fn process_struct_name_reference( match_ast! { match parent { ast::RecordPat(record_struct_pat) => { - // When we failed to get the original range for the whole struct expression node, + // When we failed to get the original range for the whole struct pattern node, // we can't provide any reasonable edit. Leave it untouched. - let file_range = ctx.sema.original_range_opt(record_struct_pat.syntax())?; - edit.replace( - file_range.range, - ast::make::tuple_struct_pat( - record_struct_pat.path()?, - record_struct_pat - .record_pat_field_list()? - .fields() - .filter_map(|pat| pat.pat()) - .chain(record_struct_pat.record_pat_field_list()? - .rest_pat() - .map(Into::into)) - ) - .to_string() + record_to_tuple_struct_like( + ctx, + source, + edit, + record_struct_pat.record_pat_field_list()?, + |it| it.fields().filter_map(|it| it.name_ref()), ); }, ast::RecordExpr(record_expr) => { - // When we failed to get the original range for the whole struct pattern node, + // When we failed to get the original range for the whole struct expression node, // we can't provide any reasonable edit. Leave it untouched. - let file_range = ctx.sema.original_range_opt(record_expr.syntax())?; - let path = record_expr.path()?; - let args = record_expr - .record_expr_field_list()? - .fields() - .filter_map(|f| f.expr()) - .join(", "); - - edit.replace(file_range.range, format!("{path}({args})")); + record_to_tuple_struct_like( + ctx, + source, + edit, + record_expr.record_expr_field_list()?, + |it| it.fields().filter_map(|it| it.name_ref()), + ); }, _ => {} } @@ -230,11 +229,67 @@ fn process_struct_name_reference( Some(()) } +fn record_to_tuple_struct_like<T, I>( + ctx: &AssistContext<'_>, + source: &ast::SourceFile, + edit: &mut SyntaxEditor, + field_list: T, + fields: impl FnOnce(&T) -> I, +) -> Option<()> +where + T: AstNode, + I: IntoIterator<Item = ast::NameRef>, +{ + let make = SyntaxFactory::without_mappings(); + let orig = ctx.sema.original_range_opt(field_list.syntax())?; + let list_range = cover_edit_range(source.syntax(), orig.range); + + let l_curly = match list_range.start() { + NodeOrToken::Node(node) => node.first_token()?, + NodeOrToken::Token(t) => t.clone(), + }; + let r_curly = match list_range.end() { + NodeOrToken::Node(node) => node.last_token()?, + NodeOrToken::Token(t) => t.clone(), + }; + + if l_curly.kind() == T!['{'] { + delete_whitespace(edit, l_curly.prev_token()); + delete_whitespace(edit, l_curly.next_token()); + edit.replace(l_curly, make.token(T!['('])); + } + if r_curly.kind() == T!['}'] { + delete_whitespace(edit, r_curly.prev_token()); + edit.replace(r_curly, make.token(T![')'])); + } + + for name_ref in fields(&field_list) { + let Some(orig) = ctx.sema.original_range_opt(name_ref.syntax()) else { continue }; + let name_range = cover_edit_range(source.syntax(), orig.range); + + if let Some(colon) = next_non_trivia_token(name_range.end().clone()) + && colon.kind() == T![:] + { + edit.delete(&colon); + edit.delete_all(name_range); + + if let Some(next) = next_non_trivia_token(colon.clone()) + && next.kind() != T!['}'] + { + // Avoid overlapping delete whitespace on `{ field: }` + delete_whitespace(edit, colon.next_token()); + } + } + } + Some(()) +} + fn edit_field_references( ctx: &AssistContext<'_>, - edit: &mut SourceChangeBuilder, + builder: &mut SourceChangeBuilder, fields: impl Iterator<Item = ast::RecordField>, ) { + let make = SyntaxFactory::without_mappings(); for (index, field) in fields.enumerate() { let field = match ctx.sema.to_def(&field) { Some(it) => it, @@ -243,19 +298,46 @@ fn edit_field_references( let def = Definition::Field(field); let usages = def.usages(&ctx.sema).all(); for (file_id, refs) in usages { - edit.edit_file(file_id.file_id(ctx.db())); + let source = ctx.sema.parse(file_id); + let mut edit = builder.make_editor(source.syntax()); + for r in refs { if let Some(name_ref) = r.name.as_name_ref() { // Only edit the field reference if it's part of a `.field` access if name_ref.syntax().parent().and_then(ast::FieldExpr::cast).is_some() { - edit.replace(r.range, index.to_string()); + edit.replace_all( + cover_edit_range(source.syntax(), r.range), + vec![make.name_ref(&index.to_string()).syntax().clone().into()], + ); } } } + + builder.add_file_edits(file_id.file_id(ctx.db()), edit); } } } +fn delete_whitespace(edit: &mut SyntaxEditor, whitespace: Option<impl Element>) { + let Some(whitespace) = whitespace else { return }; + let NodeOrToken::Token(token) = whitespace.syntax_element() else { return }; + + if token.kind() == SyntaxKind::WHITESPACE && !token.text().contains('\n') { + edit.delete(token); + } +} + +fn remove_trailing_comma(w: ast::WhereClause) -> SyntaxNode { + let w = w.syntax().clone_subtree(); + let mut editor = SyntaxEditor::new(w.clone()); + if let Some(last) = w.last_child_or_token() + && last.kind() == T![,] + { + editor.delete(last); + } + editor.finish().new_root().clone() +} + #[cfg(test)] mod tests { use crate::tests::{check_assist, check_assist_not_applicable}; @@ -678,6 +760,102 @@ where } #[test] + fn convert_constructor_expr_uses_self() { + // regression test for #21595 + check_assist( + convert_named_struct_to_tuple_struct, + r#" +struct $0Foo { field1: u32 } +impl Foo { + fn clone(&self) -> Self { + Self { field1: self.field1 } + } +}"#, + r#" +struct Foo(u32); +impl Foo { + fn clone(&self) -> Self { + Self(self.0) + } +}"#, + ); + + check_assist( + convert_named_struct_to_tuple_struct, + r#" +macro_rules! id { + ($($t:tt)*) => { $($t)* } +} +struct $0Foo { field1: u32 } +impl Foo { + fn clone(&self) -> Self { + id!(Self { field1: self.field1 }) + } +}"#, + r#" +macro_rules! id { + ($($t:tt)*) => { $($t)* } +} +struct Foo(u32); +impl Foo { + fn clone(&self) -> Self { + id!(Self(self.0)) + } +}"#, + ); + } + + #[test] + fn convert_pat_uses_self() { + // regression test for #21595 + check_assist( + convert_named_struct_to_tuple_struct, + r#" +enum Foo { + $0Value { field: &'static Foo }, + Nil, +} +fn foo(foo: &Foo) { + if let Foo::Value { field: Foo::Value { field } } = foo {} +}"#, + r#" +enum Foo { + Value(&'static Foo), + Nil, +} +fn foo(foo: &Foo) { + if let Foo::Value(Foo::Value(field)) = foo {} +}"#, + ); + + check_assist( + convert_named_struct_to_tuple_struct, + r#" +macro_rules! id { + ($($t:tt)*) => { $($t)* } +} +enum Foo { + $0Value { field: &'static Foo }, + Nil, +} +fn foo(foo: &Foo) { + if let id!(Foo::Value { field: Foo::Value { field } }) = foo {} +}"#, + r#" +macro_rules! id { + ($($t:tt)*) => { $($t)* } +} +enum Foo { + Value(&'static Foo), + Nil, +} +fn foo(foo: &Foo) { + if let id!(Foo::Value(Foo::Value(field))) = foo {} +}"#, + ); + } + + #[test] fn not_applicable_other_than_record_variant() { check_assist_not_applicable( convert_named_struct_to_tuple_struct, @@ -1042,7 +1220,9 @@ struct Struct(i32); fn test() { id! { - let s = Struct(42); + let s = Struct( + 42, + ); let Struct(value) = s; let Struct(inner) = s; } diff --git a/crates/ide-assists/src/handlers/convert_to_guarded_return.rs b/crates/ide-assists/src/handlers/convert_to_guarded_return.rs index ea5c1637b7..db45916792 100644 --- a/crates/ide-assists/src/handlers/convert_to_guarded_return.rs +++ b/crates/ide-assists/src/handlers/convert_to_guarded_return.rs @@ -10,14 +10,14 @@ use syntax::{ ast::{ self, edit::{AstNodeEdit, IndentLevel}, - make, + syntax_factory::SyntaxFactory, }, }; use crate::{ AssistId, assist_context::{AssistContext, Assists}, - utils::{invert_boolean_expression_legacy, is_never_block}, + utils::{invert_boolean_expression, is_never_block}, }; // Assist: convert_to_guarded_return @@ -69,6 +69,7 @@ fn if_expr_to_guarded_return( acc: &mut Assists, ctx: &AssistContext<'_>, ) -> Option<()> { + let make = SyntaxFactory::without_mappings(); let else_block = match if_expr.else_branch() { Some(ast::ElseBranch::Block(block_expr)) if is_never_block(&ctx.sema, &block_expr) => { Some(block_expr) @@ -88,7 +89,7 @@ fn if_expr_to_guarded_return( return None; } - let let_chains = flat_let_chain(cond); + let let_chains = flat_let_chain(cond, &make); let then_branch = if_expr.then_branch()?; let then_block = then_branch.stmt_list()?; @@ -110,7 +111,8 @@ fn if_expr_to_guarded_return( let early_expression = else_block .or_else(|| { - early_expression(parent_container, &ctx.sema).map(ast::make::tail_only_block_expr) + early_expression(parent_container, &ctx.sema, &make) + .map(ast::make::tail_only_block_expr) })? .reset_indent(); @@ -133,6 +135,7 @@ fn if_expr_to_guarded_return( "Convert to guarded return", target, |edit| { + let make = SyntaxFactory::without_mappings(); let if_indent_level = IndentLevel::from_node(if_expr.syntax()); let replacement = let_chains.into_iter().map(|expr| { if let ast::Expr::LetExpr(let_expr) = &expr @@ -140,15 +143,15 @@ fn if_expr_to_guarded_return( { // If-let. let let_else_stmt = - make::let_else_stmt(pat, None, expr, early_expression.clone()); + make.let_else_stmt(pat, None, expr, early_expression.clone()); let let_else_stmt = let_else_stmt.indent(if_indent_level); let_else_stmt.syntax().clone() } else { // If. let new_expr = { - let then_branch = clean_stmt_block(&early_expression); - let cond = invert_boolean_expression_legacy(expr); - make::expr_if(cond, then_branch, None).indent(if_indent_level) + let then_branch = clean_stmt_block(&early_expression, &make); + let cond = invert_boolean_expression(&make, expr); + make.expr_if(cond, then_branch, None).indent(if_indent_level) }; new_expr.syntax().clone() } @@ -159,7 +162,7 @@ fn if_expr_to_guarded_return( .enumerate() .flat_map(|(i, node)| { (i != 0) - .then(|| make::tokens::whitespace(newline).into()) + .then(|| make.whitespace(newline).into()) .into_iter() .chain(node.children_with_tokens()) }) @@ -201,12 +204,13 @@ fn let_stmt_to_guarded_return( let happy_pattern = try_enum.happy_pattern(pat); let target = let_stmt.syntax().text_range(); + let make = SyntaxFactory::without_mappings(); let early_expression: ast::Expr = { let parent_block = let_stmt.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?; let parent_container = parent_block.syntax().parent()?; - early_expression(parent_container, &ctx.sema)? + early_expression(parent_container, &ctx.sema, &make)? }; acc.add( @@ -215,9 +219,10 @@ fn let_stmt_to_guarded_return( target, |edit| { let let_indent_level = IndentLevel::from_node(let_stmt.syntax()); + let make = SyntaxFactory::without_mappings(); let replacement = { - let let_else_stmt = make::let_else_stmt( + let let_else_stmt = make.let_else_stmt( happy_pattern, let_stmt.ty(), expr.reset_indent(), @@ -228,6 +233,7 @@ fn let_stmt_to_guarded_return( }; let mut editor = edit.make_editor(let_stmt.syntax()); editor.replace(let_stmt.syntax(), replacement); + editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -236,38 +242,39 @@ fn let_stmt_to_guarded_return( fn early_expression( parent_container: SyntaxNode, sema: &Semantics<'_, RootDatabase>, + make: &SyntaxFactory, ) -> Option<ast::Expr> { let return_none_expr = || { - let none_expr = make::expr_path(make::ext::ident_path("None")); - make::expr_return(Some(none_expr)) + let none_expr = make.expr_path(make.ident_path("None")); + make.expr_return(Some(none_expr)) }; if let Some(fn_) = ast::Fn::cast(parent_container.clone()) && let Some(fn_def) = sema.to_def(&fn_) && let Some(TryEnum::Option) = TryEnum::from_ty(sema, &fn_def.ret_type(sema.db)) { - return Some(return_none_expr()); + return Some(return_none_expr().into()); } if let Some(body) = ast::ClosureExpr::cast(parent_container.clone()).and_then(|it| it.body()) && let Some(ret_ty) = sema.type_of_expr(&body).map(TypeInfo::original) && let Some(TryEnum::Option) = TryEnum::from_ty(sema, &ret_ty) { - return Some(return_none_expr()); + return Some(return_none_expr().into()); } Some(match parent_container.kind() { - WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None), - FN | CLOSURE_EXPR => make::expr_return(None), + WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make.expr_continue(None).into(), + FN | CLOSURE_EXPR => make.expr_return(None).into(), _ => return None, }) } -fn flat_let_chain(mut expr: ast::Expr) -> Vec<ast::Expr> { +fn flat_let_chain(mut expr: ast::Expr, make: &SyntaxFactory) -> Vec<ast::Expr> { let mut chains = vec![]; let mut reduce_cond = |rhs| { if !matches!(rhs, ast::Expr::LetExpr(_)) && let Some(last) = chains.pop_if(|last| !matches!(last, ast::Expr::LetExpr(_))) { - chains.push(make::expr_bin_op(rhs, ast::BinaryOp::LogicOp(ast::LogicOp::And), last)); + chains.push(make.expr_bin_op(rhs, ast::BinaryOp::LogicOp(ast::LogicOp::And), last)); } else { chains.push(rhs); } @@ -286,12 +293,12 @@ fn flat_let_chain(mut expr: ast::Expr) -> Vec<ast::Expr> { chains } -fn clean_stmt_block(block: &ast::BlockExpr) -> ast::BlockExpr { +fn clean_stmt_block(block: &ast::BlockExpr, make: &SyntaxFactory) -> ast::BlockExpr { if block.statements().next().is_none() && let Some(tail_expr) = block.tail_expr() && block.modifier().is_none() { - make::block_expr(once(make::expr_stmt(tail_expr).into()), None) + make.block_expr(once(make.expr_stmt(tail_expr).into()), None) } else { block.clone() } @@ -942,6 +949,32 @@ fn main() { } #[test] + fn convert_let_ref_stmt_inside_fn() { + check_assist( + convert_to_guarded_return, + r#" +//- minicore: option +fn foo() -> &'static Option<i32> { + &None +} + +fn main() { + let x$0 = foo(); +} +"#, + r#" +fn foo() -> &'static Option<i32> { + &None +} + +fn main() { + let Some(x) = foo() else { return }; +} +"#, + ); + } + + #[test] fn convert_let_stmt_inside_fn_return_option() { check_assist( convert_to_guarded_return, diff --git a/crates/ide-assists/src/handlers/convert_tuple_return_type_to_struct.rs b/crates/ide-assists/src/handlers/convert_tuple_return_type_to_struct.rs index 0e5e6185d0..1740cd024a 100644 --- a/crates/ide-assists/src/handlers/convert_tuple_return_type_to_struct.rs +++ b/crates/ide-assists/src/handlers/convert_tuple_return_type_to_struct.rs @@ -5,15 +5,20 @@ use ide_db::{ assists::AssistId, defs::Definition, helpers::mod_path_to_ast, - imports::insert_use::{ImportScope, insert_use}, + imports::insert_use::{ImportScope, insert_use_with_editor}, search::{FileReference, UsageSearchResult}, source_change::SourceChangeBuilder, syntax_helpers::node_ext::{for_each_tail_expr, walk_expr}, }; use syntax::{ AstNode, SyntaxNode, - ast::{self, HasName, edit::IndentLevel, edit_in_place::Indent, make}, - match_ast, ted, + ast::{ + self, HasName, + edit::{AstNodeEdit, IndentLevel}, + syntax_factory::SyntaxFactory, + }, + match_ast, + syntax_editor::SyntaxEditor, }; use crate::assist_context::{AssistContext, Assists}; @@ -67,14 +72,15 @@ pub(crate) fn convert_tuple_return_type_to_struct( "Convert tuple return type to tuple struct", target, move |edit| { - let ret_type = edit.make_mut(ret_type); - let fn_ = edit.make_mut(fn_); + let mut syntax_editor = edit.make_editor(ret_type.syntax()); + let syntax_factory = SyntaxFactory::with_mappings(); let usages = Definition::Function(fn_def).usages(&ctx.sema).all(); let struct_name = format!("{}Result", stdx::to_camel_case(&fn_name.to_string())); let parent = fn_.syntax().ancestors().find_map(<Either<ast::Impl, ast::Trait>>::cast); add_tuple_struct_def( edit, + &syntax_factory, ctx, &usages, parent.as_ref().map(|it| it.syntax()).unwrap_or(fn_.syntax()), @@ -83,15 +89,23 @@ pub(crate) fn convert_tuple_return_type_to_struct( &target_module, ); - ted::replace( + syntax_editor.replace( ret_type.syntax(), - make::ret_type(make::ty(&struct_name)).syntax().clone_for_update(), + syntax_factory.ret_type(syntax_factory.ty(&struct_name)).syntax(), ); if let Some(fn_body) = fn_.body() { - replace_body_return_values(ast::Expr::BlockExpr(fn_body), &struct_name); + replace_body_return_values( + &mut syntax_editor, + &syntax_factory, + ast::Expr::BlockExpr(fn_body), + &struct_name, + ); } + syntax_editor.add_mappings(syntax_factory.finish_with_mappings()); + edit.add_file_edits(ctx.vfs_file_id(), syntax_editor); + replace_usages(edit, ctx, &usages, &struct_name, &target_module); }, ) @@ -106,24 +120,37 @@ fn replace_usages( target_module: &hir::Module, ) { for (file_id, references) in usages.iter() { - edit.edit_file(file_id.file_id(ctx.db())); + let Some(first_ref) = references.first() else { continue }; + + let mut editor = edit.make_editor(first_ref.name.syntax().as_node().unwrap()); + let syntax_factory = SyntaxFactory::with_mappings(); - let refs_with_imports = - augment_references_with_imports(edit, ctx, references, struct_name, target_module); + let refs_with_imports = augment_references_with_imports( + &syntax_factory, + ctx, + references, + struct_name, + target_module, + ); refs_with_imports.into_iter().rev().for_each(|(name, import_data)| { if let Some(fn_) = name.syntax().parent().and_then(ast::Fn::cast) { cov_mark::hit!(replace_trait_impl_fns); if let Some(ret_type) = fn_.ret_type() { - ted::replace( + editor.replace( ret_type.syntax(), - make::ret_type(make::ty(struct_name)).syntax().clone_for_update(), + syntax_factory.ret_type(syntax_factory.ty(struct_name)).syntax(), ); } if let Some(fn_body) = fn_.body() { - replace_body_return_values(ast::Expr::BlockExpr(fn_body), struct_name); + replace_body_return_values( + &mut editor, + &syntax_factory, + ast::Expr::BlockExpr(fn_body), + struct_name, + ); } } else { // replace tuple patterns @@ -143,22 +170,30 @@ fn replace_usages( _ => None, }); for tuple_pat in tuple_pats { - ted::replace( + editor.replace( tuple_pat.syntax(), - make::tuple_struct_pat( - make::path_from_text(struct_name), - tuple_pat.fields(), - ) - .clone_for_update() - .syntax(), + syntax_factory + .tuple_struct_pat( + syntax_factory.path_from_text(struct_name), + tuple_pat.fields(), + ) + .syntax(), ); } } - // add imports across modules where needed if let Some((import_scope, path)) = import_data { - insert_use(&import_scope, path, &ctx.config.insert_use); + insert_use_with_editor( + &import_scope, + path, + &ctx.config.insert_use, + &mut editor, + &syntax_factory, + ); } - }) + }); + + editor.add_mappings(syntax_factory.finish_with_mappings()); + edit.add_file_edits(file_id.file_id(ctx.db()), editor); } } @@ -176,7 +211,7 @@ fn node_to_pats(node: SyntaxNode) -> Option<Vec<ast::Pat>> { } fn augment_references_with_imports( - edit: &mut SourceChangeBuilder, + syntax_factory: &SyntaxFactory, ctx: &AssistContext<'_>, references: &[FileReference], struct_name: &str, @@ -191,8 +226,6 @@ fn augment_references_with_imports( ctx.sema.scope(name.syntax()).map(|scope| (name, scope.module())) }) .map(|(name, ref_module)| { - let new_name = edit.make_mut(name); - // if the referenced module is not the same as the target one and has not been seen before, add an import let import_data = if ref_module.nearest_non_block_module(ctx.db()) != *target_module && !visited_modules.contains(&ref_module) @@ -201,8 +234,7 @@ fn augment_references_with_imports( let cfg = ctx.config.find_path_config(ctx.sema.is_nightly(ref_module.krate(ctx.sema.db))); - let import_scope = - ImportScope::find_insert_use_container(new_name.syntax(), &ctx.sema); + let import_scope = ImportScope::find_insert_use_container(name.syntax(), &ctx.sema); let path = ref_module .find_use_path( ctx.sema.db, @@ -211,12 +243,12 @@ fn augment_references_with_imports( cfg, ) .map(|mod_path| { - make::path_concat( + syntax_factory.path_concat( mod_path_to_ast( &mod_path, target_module.krate(ctx.db()).edition(ctx.db()), ), - make::path_from_text(struct_name), + syntax_factory.path_from_text(struct_name), ) }); @@ -225,7 +257,7 @@ fn augment_references_with_imports( None }; - (new_name, import_data) + (name, import_data) }) .collect() } @@ -233,6 +265,7 @@ fn augment_references_with_imports( // Adds the definition of the tuple struct before the parent function. fn add_tuple_struct_def( edit: &mut SourceChangeBuilder, + syntax_factory: &SyntaxFactory, ctx: &AssistContext<'_>, usages: &UsageSearchResult, parent: &SyntaxNode, @@ -248,22 +281,27 @@ fn add_tuple_struct_def( ctx.sema.scope(name.syntax()).map(|scope| scope.module()) }) .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module); - let visibility = if make_struct_pub { Some(make::visibility_pub()) } else { None }; + let visibility = if make_struct_pub { Some(syntax_factory.visibility_pub()) } else { None }; - let field_list = ast::FieldList::TupleFieldList(make::tuple_field_list( - tuple_ty.fields().map(|ty| make::tuple_field(visibility.clone(), ty)), + let field_list = ast::FieldList::TupleFieldList(syntax_factory.tuple_field_list( + tuple_ty.fields().map(|ty| syntax_factory.tuple_field(visibility.clone(), ty)), )); - let struct_name = make::name(struct_name); - let struct_def = make::struct_(visibility, struct_name, None, field_list).clone_for_update(); + let struct_name = syntax_factory.name(struct_name); + let struct_def = syntax_factory.struct_(visibility, struct_name, None, field_list); let indent = IndentLevel::from_node(parent); - struct_def.reindent_to(indent); + let struct_def = struct_def.indent(indent); edit.insert(parent.text_range().start(), format!("{struct_def}\n\n{indent}")); } /// Replaces each returned tuple in `body` with the constructor of the tuple struct named `struct_name`. -fn replace_body_return_values(body: ast::Expr, struct_name: &str) { +fn replace_body_return_values( + syntax_editor: &mut SyntaxEditor, + syntax_factory: &SyntaxFactory, + body: ast::Expr, + struct_name: &str, +) { let mut exprs_to_wrap = Vec::new(); let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_wrap, e); @@ -278,12 +316,11 @@ fn replace_body_return_values(body: ast::Expr, struct_name: &str) { for ret_expr in exprs_to_wrap { if let ast::Expr::TupleExpr(tuple_expr) = &ret_expr { - let struct_constructor = make::expr_call( - make::expr_path(make::ext::ident_path(struct_name)), - make::arg_list(tuple_expr.fields()), - ) - .clone_for_update(); - ted::replace(ret_expr.syntax(), struct_constructor.syntax()); + let struct_constructor = syntax_factory.expr_call( + syntax_factory.expr_path(syntax_factory.ident_path(struct_name)), + syntax_factory.arg_list(tuple_expr.fields()), + ); + syntax_editor.replace(ret_expr.syntax(), struct_constructor.syntax()); } } } diff --git a/crates/ide-assists/src/handlers/convert_tuple_struct_to_named_struct.rs b/crates/ide-assists/src/handlers/convert_tuple_struct_to_named_struct.rs index f8b9bb68db..ae41e6c015 100644 --- a/crates/ide-assists/src/handlers/convert_tuple_struct_to_named_struct.rs +++ b/crates/ide-assists/src/handlers/convert_tuple_struct_to_named_struct.rs @@ -1,17 +1,21 @@ use either::Either; -use hir::FileRangeWrapper; -use ide_db::defs::{Definition, NameRefClass}; -use std::ops::RangeInclusive; +use ide_db::{ + defs::{Definition, NameRefClass}, + search::FileReference, +}; use syntax::{ - SyntaxElement, SyntaxKind, SyntaxNode, T, TextSize, + SyntaxKind, T, ast::{ - self, AstNode, HasAttrs, HasGenericParams, HasVisibility, syntax_factory::SyntaxFactory, + self, AstNode, HasArgList, HasAttrs, HasGenericParams, HasVisibility, + syntax_factory::SyntaxFactory, }, match_ast, syntax_editor::{Element, Position, SyntaxEditor}, }; -use crate::{AssistContext, AssistId, Assists, assist_context::SourceChangeBuilder}; +use crate::{ + AssistContext, AssistId, Assists, assist_context::SourceChangeBuilder, utils::cover_edit_range, +}; // Assist: convert_tuple_struct_to_named_struct // @@ -138,102 +142,130 @@ fn edit_struct_def( fn edit_struct_references( ctx: &AssistContext<'_>, edit: &mut SourceChangeBuilder, - strukt: Either<hir::Struct, hir::Variant>, + strukt: Either<hir::Struct, hir::EnumVariant>, names: &[ast::Name], ) { let strukt_def = match strukt { Either::Left(s) => Definition::Adt(hir::Adt::Struct(s)), - Either::Right(v) => Definition::Variant(v), + Either::Right(v) => Definition::EnumVariant(v), }; let usages = strukt_def.usages(&ctx.sema).include_self_refs().all(); - let edit_node = |node: SyntaxNode| -> Option<SyntaxNode> { - let make = SyntaxFactory::without_mappings(); - match_ast! { - match node { - ast::TupleStructPat(tuple_struct_pat) => { - Some(make.record_pat_with_fields( - tuple_struct_pat.path()?, - generate_record_pat_list(&tuple_struct_pat, names), - ).syntax().clone()) - }, - // for tuple struct creations like Foo(42) - ast::CallExpr(call_expr) => { - let path = call_expr.syntax().descendants().find_map(ast::PathExpr::cast).and_then(|expr| expr.path())?; - - // this also includes method calls like Foo::new(42), we should skip them - if let Some(name_ref) = path.segment().and_then(|s| s.name_ref()) { - match NameRefClass::classify(&ctx.sema, &name_ref) { - Some(NameRefClass::Definition(Definition::SelfType(_), _)) => {}, - Some(NameRefClass::Definition(def, _)) if def == strukt_def => {}, - _ => return None, - }; - } + for (file_id, refs) in usages { + let source = ctx.sema.parse(file_id); + let mut editor = edit.make_editor(source.syntax()); - let arg_list = call_expr.syntax().descendants().find_map(ast::ArgList::cast)?; - Some( - make.record_expr( - path, - ast::make::record_expr_field_list(arg_list.args().zip(names).map( - |(expr, name)| { - ast::make::record_expr_field( - ast::make::name_ref(&name.to_string()), - Some(expr), - ) - }, - )), - ).syntax().clone() - ) - }, - _ => None, - } + for r in refs { + process_struct_name_reference(ctx, r, &mut editor, &source, &strukt_def, names); } - }; - for (file_id, refs) in usages { - let source = ctx.sema.parse(file_id); - let source = source.syntax(); - - let mut editor = edit.make_editor(source); - for r in refs.iter().rev() { - if let Some((old_node, new_node)) = r - .name - .syntax() - .ancestors() - .find_map(|node| Some((node.clone(), edit_node(node.clone())?))) - { - if let Some(old_node) = ctx.sema.original_syntax_node_rooted(&old_node) { - editor.replace(old_node, new_node); - } else { - let FileRangeWrapper { file_id: _, range } = ctx.sema.original_range(&old_node); - let parent = source.covering_element(range); - match parent { - SyntaxElement::Token(token) => { - editor.replace(token, new_node.syntax_element()); - } - SyntaxElement::Node(parent_node) => { - // replace the part of macro - // ``` - // foo!(a, Test::A(0)); - // ^^^^^^^^^^^^^^^ // parent_node - // ^^^^^^^^^^ // replace_range - // ``` - let start = parent_node - .children_with_tokens() - .find(|t| t.text_range().contains(range.start())); - let end = parent_node - .children_with_tokens() - .find(|t| t.text_range().contains(range.end() - TextSize::new(1))); - if let (Some(start), Some(end)) = (start, end) { - let replace_range = RangeInclusive::new(start, end); - editor.replace_all(replace_range, vec![new_node.into()]); - } - } + edit.add_file_edits(file_id.file_id(ctx.db()), editor); + } +} + +fn process_struct_name_reference( + ctx: &AssistContext<'_>, + r: FileReference, + editor: &mut SyntaxEditor, + source: &ast::SourceFile, + strukt_def: &Definition, + names: &[ast::Name], +) -> Option<()> { + let make = SyntaxFactory::without_mappings(); + let name_ref = r.name.as_name_ref()?; + let path_segment = name_ref.syntax().parent().and_then(ast::PathSegment::cast)?; + let full_path = path_segment.syntax().parent().and_then(ast::Path::cast)?.top_path(); + + if full_path.segment()?.name_ref()? != *name_ref { + // `name_ref` isn't the last segment of the path, so `full_path` doesn't point to the + // struct we want to edit. + return None; + } + + let parent = full_path.syntax().parent()?; + match_ast! { + match parent { + ast::TupleStructPat(tuple_struct_pat) => { + let range = ctx.sema.original_range_opt(tuple_struct_pat.syntax())?.range; + let new = make.record_pat_with_fields( + full_path, + generate_record_pat_list(&tuple_struct_pat, names), + ); + editor.replace_all(cover_edit_range(source.syntax(), range), vec![new.syntax().clone().into()]); + }, + ast::PathExpr(path_expr) => { + let call_expr = path_expr.syntax().parent().and_then(ast::CallExpr::cast)?; + + // this also includes method calls like Foo::new(42), we should skip them + match NameRefClass::classify(&ctx.sema, name_ref) { + Some(NameRefClass::Definition(Definition::SelfType(_), _)) => {}, + Some(NameRefClass::Definition(def, _)) if def == *strukt_def => {}, + _ => return None, + } + + let arg_list = call_expr.arg_list()?; + let mut first_insert = vec![]; + for (expr, name) in arg_list.args().zip(names) { + let range = ctx.sema.original_range_opt(expr.syntax())?.range; + let place = cover_edit_range(source.syntax(), range); + let elements = vec![ + make.name_ref(&name.text()).syntax().clone().into(), + make.token(T![:]).into(), + make.whitespace(" ").into(), + ]; + if first_insert.is_empty() { + // XXX: SyntaxEditor cannot insert after deleted element + first_insert = elements; + } else { + editor.insert_all(Position::before(place.start()), elements); } } - } + process_delimiter(ctx, source, editor, &arg_list, first_insert); + }, + _ => {} } - edit.add_file_edits(file_id.file_id(ctx.db()), editor); + } + Some(()) +} + +fn process_delimiter( + ctx: &AssistContext<'_>, + source: &ast::SourceFile, + editor: &mut SyntaxEditor, + list: &impl AstNode, + first_insert: Vec<syntax::SyntaxElement>, +) { + let Some(range) = ctx.sema.original_range_opt(list.syntax()) else { return }; + let place = cover_edit_range(source.syntax(), range.range); + + let l_paren = match place.start() { + syntax::NodeOrToken::Node(node) => node.first_token(), + syntax::NodeOrToken::Token(t) => Some(t.clone()), + }; + let r_paren = match place.end() { + syntax::NodeOrToken::Node(node) => node.last_token(), + syntax::NodeOrToken::Token(t) => Some(t.clone()), + }; + + let make = SyntaxFactory::without_mappings(); + if let Some(l_paren) = l_paren + && l_paren.kind() == T!['('] + { + let mut open_delim = vec![ + make.whitespace(" ").into(), + make.token(T!['{']).into(), + make.whitespace(" ").into(), + ]; + open_delim.extend(first_insert); + editor.replace_with_many(l_paren, open_delim); + } + if let Some(r_paren) = r_paren + && r_paren.kind() == T![')'] + { + editor.replace_with_many( + r_paren, + vec![make.whitespace(" ").into(), make.token(T!['}']).into()], + ); } } @@ -252,13 +284,15 @@ fn edit_field_references( let usages = def.usages(&ctx.sema).all(); for (file_id, refs) in usages { let source = ctx.sema.parse(file_id); - let source = source.syntax(); - let mut editor = edit.make_editor(source); + let mut editor = edit.make_editor(source.syntax()); for r in refs { if let Some(name_ref) = r.name.as_name_ref() - && let Some(original) = ctx.sema.original_ast_node(name_ref.clone()) + && let Some(original) = ctx.sema.original_range_opt(name_ref.syntax()) { - editor.replace(original.syntax(), name.syntax()); + editor.replace_all( + cover_edit_range(source.syntax(), original.range), + vec![name.syntax().clone().into()], + ); } } edit.add_file_edits(file_id.file_id(ctx.db()), editor); @@ -739,6 +773,64 @@ where "#, ); } + + #[test] + fn convert_expr_uses_self() { + check_assist( + convert_tuple_struct_to_named_struct, + r#" +macro_rules! id { + ($($t:tt)*) => { $($t)* } +} +struct T$0(u8); +fn test(t: T) { + T(t.0); + id!(T(t.0)); +}"#, + r#" +macro_rules! id { + ($($t:tt)*) => { $($t)* } +} +struct T { field1: u8 } +fn test(t: T) { + T { field1: t.field1 }; + id!(T { field1: t.field1 }); +}"#, + ); + } + + #[test] + #[ignore = "FIXME overlap edits in nested uses self"] + fn convert_pat_uses_self() { + check_assist( + convert_tuple_struct_to_named_struct, + r#" +macro_rules! id { + ($($t:tt)*) => { $($t)* } +} +enum T { + $0Value(&'static T), + Nil, +} +fn test(t: T) { + if let T::Value(T::Value(t)) = t {} + if let id!(T::Value(T::Value(t))) = t {} +}"#, + r#" +macro_rules! id { + ($($t:tt)*) => { $($t)* } +} +enum T { + Value { field1: &'static T }, + Nil, +} +fn test(t: T) { + if let T::Value { field1: T::Value { field1: t } } = t {} + if let id!(T::Value { field1: T::Value { field1: t } }) = t {} +}"#, + ); + } + #[test] fn not_applicable_other_than_tuple_variant() { check_assist_not_applicable( diff --git a/crates/ide-assists/src/handlers/convert_while_to_loop.rs b/crates/ide-assists/src/handlers/convert_while_to_loop.rs index 9fd8b4b315..f8215d6723 100644 --- a/crates/ide-assists/src/handlers/convert_while_to_loop.rs +++ b/crates/ide-assists/src/handlers/convert_while_to_loop.rs @@ -6,7 +6,7 @@ use syntax::{ ast::{ self, HasLoopBody, edit::{AstNodeEdit, IndentLevel}, - make, + syntax_factory::SyntaxFactory, }, syntax_editor::{Element, Position}, }; @@ -14,7 +14,7 @@ use syntax::{ use crate::{ AssistId, assist_context::{AssistContext, Assists}, - utils::invert_boolean_expression_legacy, + utils::invert_boolean_expression, }; // Assist: convert_while_to_loop @@ -52,44 +52,47 @@ pub(crate) fn convert_while_to_loop(acc: &mut Assists, ctx: &AssistContext<'_>) "Convert while to loop", target, |builder| { + let make = SyntaxFactory::without_mappings(); let mut edit = builder.make_editor(while_expr.syntax()); let while_indent_level = IndentLevel::from_node(while_expr.syntax()); - let break_block = make::block_expr( - iter::once(make::expr_stmt(make::expr_break(None, None)).into()), - None, - ) - .indent(IndentLevel(1)); + let break_block = make + .block_expr( + iter::once(make.expr_stmt(make.expr_break(None, None).into()).into()), + None, + ) + .indent(IndentLevel(1)); edit.replace_all( while_kw.syntax_element()..=while_cond.syntax().syntax_element(), - vec![make::token(T![loop]).syntax_element()], + vec![make.token(T![loop]).syntax_element()], ); if is_pattern_cond(while_cond.clone()) { let then_branch = while_body.reset_indent().indent(IndentLevel(1)); - let if_expr = make::expr_if(while_cond, then_branch, Some(break_block.into())); - let stmts = iter::once(make::expr_stmt(if_expr.into()).into()); - let block_expr = make::block_expr(stmts, None); + let if_expr = make.expr_if(while_cond, then_branch, Some(break_block.into())); + let stmts = iter::once(make.expr_stmt(if_expr.into()).into()); + let block_expr = make.block_expr(stmts, None); edit.replace(while_body.syntax(), block_expr.indent(while_indent_level).syntax()); } else { - let if_cond = invert_boolean_expression_legacy(while_cond); - let if_expr = make::expr_if(if_cond, break_block, None).indent(while_indent_level); + let if_cond = invert_boolean_expression(&make, while_cond); + let if_expr = make.expr_if(if_cond, break_block, None).indent(while_indent_level); if !while_body.syntax().text().contains_char('\n') { edit.insert( Position::after(&l_curly), - make::tokens::whitespace(&format!("\n{while_indent_level}")), + make.whitespace(&format!("\n{while_indent_level}")), ); } edit.insert_all( Position::after(&l_curly), vec![ - make::tokens::whitespace(&format!("\n{}", while_indent_level + 1)).into(), + make.whitespace(&format!("\n{}", while_indent_level + 1)).into(), if_expr.syntax().syntax_element(), ], ); }; + edit.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), edit); }, ) diff --git a/crates/ide-assists/src/handlers/destructure_struct_binding.rs b/crates/ide-assists/src/handlers/destructure_struct_binding.rs index bb5d112210..ec4a83b642 100644 --- a/crates/ide-assists/src/handlers/destructure_struct_binding.rs +++ b/crates/ide-assists/src/handlers/destructure_struct_binding.rs @@ -1,19 +1,23 @@ -use hir::HasVisibility; +use hir::{HasVisibility, Semantics}; use ide_db::{ - FxHashMap, FxHashSet, + FxHashMap, FxHashSet, RootDatabase, assists::AssistId, defs::Definition, helpers::mod_path_to_ast, search::{FileReference, SearchScope}, }; use itertools::Itertools; -use syntax::ast::{HasName, syntax_factory::SyntaxFactory}; use syntax::syntax_editor::SyntaxEditor; use syntax::{AstNode, Edition, SmolStr, SyntaxNode, ToSmolStr, ast}; +use syntax::{ + SyntaxToken, + ast::{HasName, edit::IndentLevel, syntax_factory::SyntaxFactory}, + syntax_editor::Position, +}; use crate::{ assist_context::{AssistContext, Assists, SourceChangeBuilder}, - utils::ref_field_expr::determine_ref_and_parens, + utils::{cover_edit_range, ref_field_expr::determine_ref_and_parens}, }; // Assist: destructure_struct_binding @@ -44,33 +48,90 @@ use crate::{ // } // ``` pub(crate) fn destructure_struct_binding(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { - let ident_pat = ctx.find_node_at_offset::<ast::IdentPat>()?; - let data = collect_data(ident_pat, ctx)?; + let target = ctx.find_node_at_offset::<Target>()?; + let data = collect_data(target, ctx)?; acc.add( AssistId::refactor_rewrite("destructure_struct_binding"), "Destructure struct binding", - data.ident_pat.syntax().text_range(), + data.target.syntax().text_range(), |edit| destructure_struct_binding_impl(ctx, edit, &data), ); Some(()) } +enum Target { + IdentPat(ast::IdentPat), + SelfParam { param: ast::SelfParam, insert_after: SyntaxToken }, +} + +impl Target { + fn ty<'db>(&self, sema: &Semantics<'db, RootDatabase>) -> Option<hir::Type<'db>> { + match self { + Target::IdentPat(pat) => sema.type_of_binding_in_pat(pat), + Target::SelfParam { param, .. } => sema.type_of_self(param), + } + } + + fn is_ref(&self) -> bool { + match self { + Target::IdentPat(ident_pat) => ident_pat.ref_token().is_some(), + Target::SelfParam { .. } => false, + } + } + + fn is_mut(&self) -> bool { + match self { + Target::IdentPat(ident_pat) => ident_pat.mut_token().is_some(), + Target::SelfParam { param, .. } => { + param.mut_token().is_some() && param.amp_token().is_none() + } + } + } +} + +impl HasName for Target {} + +impl AstNode for Target { + fn cast(node: SyntaxNode) -> Option<Self> { + if ast::IdentPat::can_cast(node.kind()) { + ast::IdentPat::cast(node).map(Self::IdentPat) + } else { + let param = ast::SelfParam::cast(node)?; + let param_list = param.syntax().parent().and_then(ast::ParamList::cast)?; + let block = param_list.syntax().parent()?.children().find_map(ast::BlockExpr::cast)?; + let insert_after = block.stmt_list()?.l_curly_token()?; + Some(Self::SelfParam { param, insert_after }) + } + } + + fn can_cast(kind: syntax::SyntaxKind) -> bool { + ast::IdentPat::can_cast(kind) || ast::SelfParam::can_cast(kind) + } + + fn syntax(&self) -> &SyntaxNode { + match self { + Target::IdentPat(ident_pat) => ident_pat.syntax(), + Target::SelfParam { param, .. } => param.syntax(), + } + } +} + fn destructure_struct_binding_impl( ctx: &AssistContext<'_>, builder: &mut SourceChangeBuilder, data: &StructEditData, ) { let field_names = generate_field_names(ctx, data); - let mut editor = builder.make_editor(data.ident_pat.syntax()); + let mut editor = builder.make_editor(data.target.syntax()); destructure_pat(ctx, &mut editor, data, &field_names); update_usages(ctx, &mut editor, data, &field_names.into_iter().collect()); builder.add_file_edits(ctx.vfs_file_id(), editor); } struct StructEditData { - ident_pat: ast::IdentPat, + target: Target, name: ast::Name, kind: hir::StructKind, struct_def_path: hir::ModPath, @@ -83,11 +144,44 @@ struct StructEditData { edition: Edition, } -fn collect_data(ident_pat: ast::IdentPat, ctx: &AssistContext<'_>) -> Option<StructEditData> { - let ty = ctx.sema.type_of_binding_in_pat(&ident_pat)?; +impl StructEditData { + fn apply_to_destruct( + &self, + new_pat: ast::Pat, + editor: &mut SyntaxEditor, + make: &SyntaxFactory, + ) { + match &self.target { + Target::IdentPat(pat) => { + // If the binding is nested inside a record, we need to wrap the new + // destructured pattern in a non-shorthand record field + if self.need_record_field_name { + let new_pat = + make.record_pat_field(make.name_ref(&self.name.to_string()), new_pat); + editor.replace(pat.syntax(), new_pat.syntax()) + } else { + editor.replace(pat.syntax(), new_pat.syntax()) + } + } + Target::SelfParam { insert_after, .. } => { + let indent = IndentLevel::from_token(insert_after) + 1; + let newline = make.whitespace(&format!("\n{indent}")); + let initializer = make.expr_path(make.ident_path("self")); + let let_stmt = make.let_stmt(new_pat, None, Some(initializer)); + editor.insert_all( + Position::after(insert_after), + vec![newline.into(), let_stmt.syntax().clone().into()], + ); + } + } + } +} + +fn collect_data(target: Target, ctx: &AssistContext<'_>) -> Option<StructEditData> { + let ty = target.ty(&ctx.sema)?; let hir::Adt::Struct(struct_type) = ty.strip_references().as_adt()? else { return None }; - let module = ctx.sema.scope(ident_pat.syntax())?.module(); + let module = ctx.sema.scope(target.syntax())?.module(); let cfg = ctx.config.find_path_config(ctx.sema.is_nightly(module.krate(ctx.db()))); let struct_def = hir::ModuleDef::from(struct_type); let kind = struct_type.kind(ctx.db()); @@ -116,15 +210,17 @@ fn collect_data(ident_pat: ast::IdentPat, ctx: &AssistContext<'_>) -> Option<Str } let is_ref = ty.is_reference(); - let need_record_field_name = ident_pat + let need_record_field_name = target .syntax() .parent() .and_then(ast::RecordPatField::cast) .is_some_and(|field| field.colon_token().is_none()); - let usages = ctx - .sema - .to_def(&ident_pat) + let def = match &target { + Target::IdentPat(pat) => ctx.sema.to_def(pat), + Target::SelfParam { param, .. } => ctx.sema.to_def(param), + }; + let usages = def .and_then(|def| { Definition::Local(def) .usages(&ctx.sema) @@ -136,11 +232,11 @@ fn collect_data(ident_pat: ast::IdentPat, ctx: &AssistContext<'_>) -> Option<Str }) .unwrap_or_default(); - let names_in_scope = get_names_in_scope(ctx, &ident_pat, &usages).unwrap_or_default(); + let names_in_scope = get_names_in_scope(ctx, &target, &usages).unwrap_or_default(); Some(StructEditData { - name: ident_pat.name()?, - ident_pat, + name: target.name()?, + target, kind, struct_def_path, usages, @@ -155,7 +251,7 @@ fn collect_data(ident_pat: ast::IdentPat, ctx: &AssistContext<'_>) -> Option<Str fn get_names_in_scope( ctx: &AssistContext<'_>, - ident_pat: &ast::IdentPat, + target: &Target, usages: &[FileReference], ) -> Option<FxHashSet<SmolStr>> { fn last_usage(usages: &[FileReference]) -> Option<SyntaxNode> { @@ -165,7 +261,7 @@ fn get_names_in_scope( // If available, find names visible to the last usage of the binding // else, find names visible to the binding itself let last_usage = last_usage(usages); - let node = last_usage.as_ref().unwrap_or(ident_pat.syntax()); + let node = last_usage.as_ref().unwrap_or(target.syntax()); let scope = ctx.sema.scope(node)?; let mut names = FxHashSet::default(); @@ -183,12 +279,9 @@ fn destructure_pat( data: &StructEditData, field_names: &[(SmolStr, SmolStr)], ) { - let ident_pat = &data.ident_pat; - let name = &data.name; - let struct_path = mod_path_to_ast(&data.struct_def_path, data.edition); - let is_ref = ident_pat.ref_token().is_some(); - let is_mut = ident_pat.mut_token().is_some(); + let is_ref = data.target.is_ref(); + let is_mut = data.target.is_mut(); let make = SyntaxFactory::with_mappings(); let new_pat = match data.kind { @@ -221,16 +314,8 @@ fn destructure_pat( hir::StructKind::Unit => make.path_pat(struct_path), }; - // If the binding is nested inside a record, we need to wrap the new - // destructured pattern in a non-shorthand record field - let destructured_pat = if data.need_record_field_name { - make.record_pat_field(make.name_ref(&name.to_string()), new_pat).syntax().clone() - } else { - new_pat.syntax().clone() - }; - + data.apply_to_destruct(new_pat, editor, &make); editor.add_mappings(make.finish_with_mappings()); - editor.replace(data.ident_pat.syntax(), destructured_pat); } fn generate_field_names(ctx: &AssistContext<'_>, data: &StructEditData) -> Vec<(SmolStr, SmolStr)> { @@ -273,6 +358,7 @@ fn update_usages( data: &StructEditData, field_names: &FxHashMap<SmolStr, SmolStr>, ) { + let source = ctx.source_file().syntax(); let make = SyntaxFactory::with_mappings(); let edits = data .usages @@ -281,7 +367,9 @@ fn update_usages( .collect_vec(); editor.add_mappings(make.finish_with_mappings()); for (old, new) in edits { - editor.replace(old, new); + if let Some(range) = ctx.sema.original_range_opt(&old) { + editor.replace_all(cover_edit_range(source, range.range), vec![new.into()]); + } } } @@ -296,23 +384,20 @@ fn build_usage_edit( Some(field_expr) => Some({ let field_name: SmolStr = field_expr.name_ref()?.to_string().into(); let new_field_name = field_names.get(&field_name)?; - let new_expr = ast::make::expr_path(ast::make::ext::ident_path(new_field_name)); + let new_expr = make.expr_path(make.ident_path(new_field_name)); // If struct binding is a reference, we might need to deref field usages if data.is_ref { let (replace_expr, ref_data) = determine_ref_and_parens(ctx, &field_expr); - ( - replace_expr.syntax().clone_for_update(), - ref_data.wrap_expr(new_expr).syntax().clone_for_update(), - ) + (replace_expr.syntax().clone(), ref_data.wrap_expr(new_expr, make).syntax().clone()) } else { - (field_expr.syntax().clone(), new_expr.syntax().clone_for_update()) + (field_expr.syntax().clone(), new_expr.syntax().clone()) } }), None => Some(( usage.name.syntax().as_node().unwrap().clone(), make.expr_macro( - ast::make::ext::ident_path("todo"), + make.ident_path("todo"), make.token_tree(syntax::SyntaxKind::L_PAREN, []), ) .syntax() @@ -699,6 +784,84 @@ mod tests { } #[test] + fn mut_self_param() { + check_assist( + destructure_struct_binding, + r#" + struct Foo { bar: i32, baz: i32 } + + impl Foo { + fn foo(mut $0self) { + self.bar = 5; + } + } + "#, + r#" + struct Foo { bar: i32, baz: i32 } + + impl Foo { + fn foo(mut self) { + let Foo { mut bar, mut baz } = self; + bar = 5; + } + } + "#, + ) + } + + #[test] + fn ref_mut_self_param() { + check_assist( + destructure_struct_binding, + r#" + struct Foo { bar: i32, baz: i32 } + + impl Foo { + fn foo(&mut $0self) { + self.bar = 5; + } + } + "#, + r#" + struct Foo { bar: i32, baz: i32 } + + impl Foo { + fn foo(&mut self) { + let Foo { bar, baz } = self; + *bar = 5; + } + } + "#, + ) + } + + #[test] + fn ref_self_param() { + check_assist( + destructure_struct_binding, + r#" + struct Foo { bar: i32, baz: i32 } + + impl Foo { + fn foo(&$0self) -> &i32 { + &self.bar + } + } + "#, + r#" + struct Foo { bar: i32, baz: i32 } + + impl Foo { + fn foo(&self) -> &i32 { + let Foo { bar, baz } = self; + bar + } + } + "#, + ) + } + + #[test] fn ref_not_add_parenthesis_and_deref_record() { check_assist( destructure_struct_binding, @@ -846,4 +1009,33 @@ mod tests { "#, ) } + + #[test] + fn record_struct_usage_in_macro_call() { + // exact repro from #20716: struct field access inside write! must not panic + check_assist( + destructure_struct_binding, + r#" +//- minicore: write, fmt +use core::fmt::Write; +struct Foo { y: i8 } + +fn main() { + let mut s = String::new(); + let $0x = Foo { y: 8 }; + write!(s, "{}", x.y).unwrap(); +} +"#, + r#" +use core::fmt::Write; +struct Foo { y: i8 } + +fn main() { + let mut s = String::new(); + let Foo { y } = Foo { y: 8 }; + write!(s, "{}", y).unwrap(); +} +"#, + ) + } } diff --git a/crates/ide-assists/src/handlers/destructure_tuple_binding.rs b/crates/ide-assists/src/handlers/destructure_tuple_binding.rs index e2afc0bf13..23c11b258c 100644 --- a/crates/ide-assists/src/handlers/destructure_tuple_binding.rs +++ b/crates/ide-assists/src/handlers/destructure_tuple_binding.rs @@ -8,13 +8,13 @@ use ide_db::{ use itertools::Itertools; use syntax::{ T, - ast::{self, AstNode, FieldExpr, HasName, IdentPat, make}, - ted, + ast::{self, AstNode, FieldExpr, HasName, IdentPat, syntax_factory::SyntaxFactory}, + syntax_editor::{Position, SyntaxEditor}, }; use crate::{ assist_context::{AssistContext, Assists, SourceChangeBuilder}, - utils::ref_field_expr::determine_ref_and_parens, + utils::{cover_edit_range, ref_field_expr::determine_ref_and_parens}, }; // Assist: destructure_tuple_binding @@ -89,13 +89,22 @@ fn destructure_tuple_edit_impl( data: &TupleData, in_sub_pattern: bool, ) { - let assignment_edit = edit_tuple_assignment(ctx, edit, data, in_sub_pattern); - let current_file_usages_edit = edit_tuple_usages(data, edit, ctx, in_sub_pattern); + let mut syntax_editor = edit.make_editor(data.ident_pat.syntax()); + let syntax_factory = SyntaxFactory::with_mappings(); - assignment_edit.apply(); + let assignment_edit = + edit_tuple_assignment(ctx, edit, &mut syntax_editor, &syntax_factory, data, in_sub_pattern); + let current_file_usages_edit = edit_tuple_usages(data, ctx, &syntax_factory, in_sub_pattern); + + assignment_edit.apply(&mut syntax_editor, &syntax_factory); if let Some(usages_edit) = current_file_usages_edit { - usages_edit.into_iter().for_each(|usage_edit| usage_edit.apply(edit)) + usages_edit + .into_iter() + .for_each(|usage_edit| usage_edit.apply(ctx, edit, &mut syntax_editor)) } + + syntax_editor.add_mappings(syntax_factory.finish_with_mappings()); + edit.add_file_edits(ctx.vfs_file_id(), syntax_editor); } fn collect_data(ident_pat: IdentPat, ctx: &AssistContext<'_>) -> Option<TupleData> { @@ -157,6 +166,7 @@ enum RefType { Mutable, } struct TupleData { + // FIXME: After removing ted, it may be possible to reuse destructure_struct_binding::Target ident_pat: IdentPat, ref_type: Option<RefType>, field_names: Vec<String>, @@ -165,11 +175,11 @@ struct TupleData { fn edit_tuple_assignment( ctx: &AssistContext<'_>, edit: &mut SourceChangeBuilder, + editor: &mut SyntaxEditor, + make: &SyntaxFactory, data: &TupleData, in_sub_pattern: bool, ) -> AssignmentEdit { - let ident_pat = edit.make_mut(data.ident_pat.clone()); - let tuple_pat = { let original = &data.ident_pat; let is_ref = original.ref_token().is_some(); @@ -177,10 +187,11 @@ fn edit_tuple_assignment( let fields = data .field_names .iter() - .map(|name| ast::Pat::from(make::ident_pat(is_ref, is_mut, make::name(name)))); - make::tuple_pat(fields).clone_for_update() + .map(|name| ast::Pat::from(make.ident_pat(is_ref, is_mut, make.name(name)))); + make.tuple_pat(fields) }; - let is_shorthand_field = ident_pat + let is_shorthand_field = data + .ident_pat .name() .as_ref() .and_then(ast::RecordPatField::for_field_name) @@ -189,14 +200,20 @@ fn edit_tuple_assignment( if let Some(cap) = ctx.config.snippet_cap { // place cursor on first tuple name if let Some(ast::Pat::IdentPat(first_pat)) = tuple_pat.fields().next() { - edit.add_tabstop_before( - cap, - first_pat.name().expect("first ident pattern should have a name"), - ) + let annotation = edit.make_tabstop_before(cap); + editor.add_annotation( + first_pat.name().expect("first ident pattern should have a name").syntax(), + annotation, + ); } } - AssignmentEdit { ident_pat, tuple_pat, in_sub_pattern, is_shorthand_field } + AssignmentEdit { + ident_pat: data.ident_pat.clone(), + tuple_pat, + in_sub_pattern, + is_shorthand_field, + } } struct AssignmentEdit { ident_pat: ast::IdentPat, @@ -206,23 +223,30 @@ struct AssignmentEdit { } impl AssignmentEdit { - fn apply(self) { + fn apply(self, syntax_editor: &mut SyntaxEditor, syntax_mapping: &SyntaxFactory) { // with sub_pattern: keep original tuple and add subpattern: `tup @ (_0, _1)` if self.in_sub_pattern { - self.ident_pat.set_pat(Some(self.tuple_pat.into())) + self.ident_pat.set_pat_with_editor( + Some(self.tuple_pat.into()), + syntax_editor, + syntax_mapping, + ) } else if self.is_shorthand_field { - ted::insert(ted::Position::after(self.ident_pat.syntax()), self.tuple_pat.syntax()); - ted::insert_raw(ted::Position::after(self.ident_pat.syntax()), make::token(T![:])); + syntax_editor.insert(Position::after(self.ident_pat.syntax()), self.tuple_pat.syntax()); + syntax_editor + .insert(Position::after(self.ident_pat.syntax()), syntax_mapping.whitespace(" ")); + syntax_editor + .insert(Position::after(self.ident_pat.syntax()), syntax_mapping.token(T![:])); } else { - ted::replace(self.ident_pat.syntax(), self.tuple_pat.syntax()) + syntax_editor.replace(self.ident_pat.syntax(), self.tuple_pat.syntax()) } } } fn edit_tuple_usages( data: &TupleData, - edit: &mut SourceChangeBuilder, ctx: &AssistContext<'_>, + make: &SyntaxFactory, in_sub_pattern: bool, ) -> Option<Vec<EditTupleUsage>> { // We need to collect edits first before actually applying them @@ -238,20 +262,20 @@ fn edit_tuple_usages( .as_ref()? .as_slice() .iter() - .filter_map(|r| edit_tuple_usage(ctx, edit, r, data, in_sub_pattern)) + .filter_map(|r| edit_tuple_usage(ctx, make, r, data, in_sub_pattern)) .collect_vec(); Some(edits) } fn edit_tuple_usage( ctx: &AssistContext<'_>, - builder: &mut SourceChangeBuilder, + make: &SyntaxFactory, usage: &FileReference, data: &TupleData, in_sub_pattern: bool, ) -> Option<EditTupleUsage> { match detect_tuple_index(usage, data) { - Some(index) => Some(edit_tuple_field_usage(ctx, builder, data, index)), + Some(index) => Some(edit_tuple_field_usage(ctx, make, data, index)), None if in_sub_pattern => { cov_mark::hit!(destructure_tuple_call_with_subpattern); None @@ -262,20 +286,18 @@ fn edit_tuple_usage( fn edit_tuple_field_usage( ctx: &AssistContext<'_>, - builder: &mut SourceChangeBuilder, + make: &SyntaxFactory, data: &TupleData, index: TupleIndex, ) -> EditTupleUsage { let field_name = &data.field_names[index.index]; - let field_name = make::expr_path(make::ext::ident_path(field_name)); + let field_name = make.expr_path(make.ident_path(field_name)); if data.ref_type.is_some() { let (replace_expr, ref_data) = determine_ref_and_parens(ctx, &index.field_expr); - let replace_expr = builder.make_mut(replace_expr); - EditTupleUsage::ReplaceExpr(replace_expr, ref_data.wrap_expr(field_name)) + EditTupleUsage::ReplaceExpr(replace_expr, ref_data.wrap_expr_with_factory(field_name, make)) } else { - let field_expr = builder.make_mut(index.field_expr); - EditTupleUsage::ReplaceExpr(field_expr.into(), field_name) + EditTupleUsage::ReplaceExpr(index.field_expr.into(), field_name) } } enum EditTupleUsage { @@ -291,14 +313,25 @@ enum EditTupleUsage { } impl EditTupleUsage { - fn apply(self, edit: &mut SourceChangeBuilder) { + fn apply( + self, + ctx: &AssistContext<'_>, + edit: &mut SourceChangeBuilder, + syntax_editor: &mut SyntaxEditor, + ) { match self { EditTupleUsage::NoIndex(range) => { edit.insert(range.start(), "/*"); edit.insert(range.end(), "*/"); } EditTupleUsage::ReplaceExpr(target_expr, replace_with) => { - ted::replace(target_expr.syntax(), replace_with.clone_for_update().syntax()) + if let Some(range) = ctx.sema.original_range_opt(target_expr.syntax()) { + let source = ctx.source_file().syntax(); + syntax_editor.replace_all( + cover_edit_range(source, range.range), + vec![replace_with.syntax().clone().into()], + ); + } } } } @@ -329,24 +362,6 @@ fn detect_tuple_index(usage: &FileReference, data: &TupleData) -> Option<TupleIn if let Some(field_expr) = ast::FieldExpr::cast(node) { let idx = field_expr.name_ref()?.as_tuple_field()?; if idx < data.field_names.len() { - // special case: in macro call -> range of `field_expr` in applied macro, NOT range in actual file! - if field_expr.syntax().ancestors().any(|a| ast::MacroStmts::can_cast(a.kind())) { - cov_mark::hit!(destructure_tuple_macro_call); - - // issue: cannot differentiate between tuple index passed into macro or tuple index as result of macro: - // ```rust - // macro_rules! m { - // ($t1:expr, $t2:expr) => { $t1; $t2.0 } - // } - // let t = (1,2); - // m!(t.0, t) - // ``` - // -> 2 tuple index usages detected! - // - // -> only handle `t` - return None; - } - Some(TupleIndex { index: idx, field_expr }) } else { // tuple index out of range @@ -1417,7 +1432,6 @@ fn main() { #[test] fn detect_macro_call() { - cov_mark::check!(destructure_tuple_macro_call); check_in_place_assist( r#" macro_rules! m { @@ -1436,7 +1450,7 @@ macro_rules! m { fn main() { let ($0_0, _1) = (1,2); - m!(/*t*/.0); + m!(_0); } "#, ) @@ -1528,7 +1542,6 @@ fn main() { m!(t.0); } "#, - // FIXME: replace `t.0` with `_0` (cannot detect range of tuple index in macro call) r#" macro_rules! m { ($e:expr) => { "foo"; $e }; @@ -1536,10 +1549,9 @@ macro_rules! m { fn main() { let ($0_0, _1) = (1,2); - m!(/*t*/.0); + m!(_0); } "#, - // FIXME: replace `t.0` with `_0` r#" macro_rules! m { ($e:expr) => { "foo"; $e }; @@ -1547,7 +1559,7 @@ macro_rules! m { fn main() { let t @ ($0_0, _1) = (1,2); - m!(t.0); + m!(_0); } "#, ) @@ -1566,7 +1578,6 @@ fn main() { m!((t).0); } "#, - // FIXME: replace `(t).0` with `_0` r#" macro_rules! m { ($e:expr) => { "foo"; $e }; @@ -1574,10 +1585,9 @@ macro_rules! m { fn main() { let ($0_0, _1) = (1,2); - m!((/*t*/).0); + m!(_0); } "#, - // FIXME: replace `(t).0` with `_0` r#" macro_rules! m { ($e:expr) => { "foo"; $e }; @@ -1585,7 +1595,7 @@ macro_rules! m { fn main() { let t @ ($0_0, _1) = (1,2); - m!((t).0); + m!(_0); } "#, ) @@ -1633,7 +1643,6 @@ fn main() { m!(t, t.0); } "#, - // FIXME: replace `t.0` in macro call (not IN macro) with `_0` r#" macro_rules! m { ($t:expr, $i:expr) => { $t.0 + $i }; @@ -1641,10 +1650,9 @@ macro_rules! m { fn main() { let ($0_0, _1) = (1,2); - m!(/*t*/, /*t*/.0); + m!(t, _0); } "#, - // FIXME: replace `t.0` in macro call with `_0` r#" macro_rules! m { ($t:expr, $i:expr) => { $t.0 + $i }; @@ -1652,13 +1660,41 @@ macro_rules! m { fn main() { let t @ ($0_0, _1) = (1,2); - m!(t, t.0); + m!(t, _0); } "#, ) } } + mod in_macro_expr { + use super::assist::*; + + // exact repro from #20716: tuple index inside write! must not panic + #[test] + fn tuple_index_in_write_macro() { + check_in_place_assist( + r#" +//- minicore: write, fmt +use core::fmt::Write; +fn main() { + let mut s = String::new(); + let $0x = (2i32, 3i32); + write!(s, "{}", x.0).unwrap(); +} +"#, + r#" +use core::fmt::Write; +fn main() { + let mut s = String::new(); + let ($0_0, _1) = (2i32, 3i32); + write!(s, "{}", _0).unwrap(); +} +"#, + ) + } + } + mod refs { use super::assist::*; diff --git a/crates/ide-assists/src/handlers/desugar_try_expr.rs b/crates/ide-assists/src/handlers/desugar_try_expr.rs index 9976e34e73..865dc86221 100644 --- a/crates/ide-assists/src/handlers/desugar_try_expr.rs +++ b/crates/ide-assists/src/handlers/desugar_try_expr.rs @@ -1,15 +1,11 @@ use std::iter; -use ide_db::{ - assists::{AssistId, ExprFillDefaultMode}, - ty_filter::TryEnum, -}; +use ide_db::{assists::AssistId, ty_filter::TryEnum}; use syntax::{ AstNode, T, ast::{ self, edit::{AstNodeEdit, IndentLevel}, - make, syntax_factory::SyntaxFactory, }, }; @@ -68,41 +64,39 @@ pub(crate) fn desugar_try_expr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op AssistId::refactor_rewrite("desugar_try_expr_match"), "Replace try expression with match", target, - |edit| { + |builder| { + let make = SyntaxFactory::with_mappings(); + let mut editor = builder.make_editor(try_expr.syntax()); + let sad_pat = match try_enum { - TryEnum::Option => make::path_pat(make::ext::ident_path("None")), - TryEnum::Result => make::tuple_struct_pat( - make::ext::ident_path("Err"), - iter::once(make::path_pat(make::ext::ident_path("err"))), - ) - .into(), - }; - let sad_expr = match try_enum { - TryEnum::Option => { - make::expr_return(Some(make::expr_path(make::ext::ident_path("None")))) - } - TryEnum::Result => make::expr_return(Some( - make::expr_call( - make::expr_path(make::ext::ident_path("Err")), - make::arg_list(iter::once(make::expr_path(make::ext::ident_path("err")))), + TryEnum::Option => make.path_pat(make.ident_path("None")), + TryEnum::Result => make + .tuple_struct_pat( + make.ident_path("Err"), + iter::once(make.path_pat(make.ident_path("err"))), ) .into(), - )), }; + let sad_expr = make.expr_return(Some(sad_expr(try_enum, &make, || { + make.expr_path(make.ident_path("err")) + }))); - let happy_arm = make::match_arm( - try_enum.happy_pattern(make::ident_pat(false, false, make::name("it")).into()), + let happy_arm = make.match_arm( + try_enum.happy_pattern(make.ident_pat(false, false, make.name("it")).into()), None, - make::expr_path(make::ext::ident_path("it")), + make.expr_path(make.ident_path("it")), ); - let sad_arm = make::match_arm(sad_pat, None, sad_expr); + let sad_arm = make.match_arm(sad_pat, None, sad_expr.into()); - let match_arm_list = make::match_arm_list([happy_arm, sad_arm]); + let match_arm_list = make.match_arm_list([happy_arm, sad_arm]); - let expr_match = make::expr_match(expr.clone(), match_arm_list) + let expr_match = make + .expr_match(expr.clone(), match_arm_list) .indent(IndentLevel::from_node(try_expr.syntax())); - edit.replace_ast::<ast::Expr>(try_expr.clone().into(), expr_match.into()); + editor.replace(try_expr.syntax(), expr_match.syntax()); + editor.add_mappings(make.finish_with_mappings()); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ); @@ -119,48 +113,18 @@ pub(crate) fn desugar_try_expr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op let mut editor = builder.make_editor(let_stmt.syntax()); let indent_level = IndentLevel::from_node(let_stmt.syntax()); + let fill_expr = || crate::utils::expr_fill_default(ctx.config); let new_let_stmt = make.let_else_stmt( try_enum.happy_pattern(pat), - let_stmt.ty(), + let_stmt.ty().map(|ty| match try_enum { + TryEnum::Option => make.ty_option(ty).into(), + TryEnum::Result => make.ty_result(ty, make.ty_infer().into()).into(), + }), expr, make.block_expr( iter::once( make.expr_stmt( - make.expr_return(Some(match try_enum { - TryEnum::Option => make.expr_path(make.ident_path("None")), - TryEnum::Result => make - .expr_call( - make.expr_path(make.ident_path("Err")), - make.arg_list(iter::once( - match ctx.config.expr_fill_default { - ExprFillDefaultMode::Todo => make - .expr_macro( - make.ident_path("todo"), - make.token_tree( - syntax::SyntaxKind::L_PAREN, - [], - ), - ) - .into(), - ExprFillDefaultMode::Underscore => { - make.expr_underscore().into() - } - ExprFillDefaultMode::Default => make - .expr_macro( - make.ident_path("todo"), - make.token_tree( - syntax::SyntaxKind::L_PAREN, - [], - ), - ) - .into(), - }, - )), - ) - .into(), - })) - .indent(indent_level + 1) - .into(), + make.expr_return(Some(sad_expr(try_enum, &make, fill_expr))).into(), ) .into(), ), @@ -177,6 +141,15 @@ pub(crate) fn desugar_try_expr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op Some(()) } +fn sad_expr(try_enum: TryEnum, make: &SyntaxFactory, err: impl Fn() -> ast::Expr) -> ast::Expr { + match try_enum { + TryEnum::Option => make.expr_path(make.ident_path("None")), + TryEnum::Result => make + .expr_call(make.expr_path(make.ident_path("Err")), make.arg_list(iter::once(err()))) + .into(), + } +} + #[cfg(test)] mod tests { use super::*; @@ -278,4 +251,46 @@ fn test() { "Replace try expression with let else", ); } + + #[test] + fn test_desugar_try_expr_option_let_else_with_type() { + check_assist_by_label( + desugar_try_expr, + r#" +//- minicore: try, option +fn test() { + let pat: bool = Some(true)$0?; +} + "#, + r#" +fn test() { + let Some(pat): Option<bool> = Some(true) else { + return None; + }; +} + "#, + "Replace try expression with let else", + ); + } + + #[test] + fn test_desugar_try_expr_result_let_else_with_type() { + check_assist_by_label( + desugar_try_expr, + r#" +//- minicore: try, result +fn test() { + let pat: bool = Ok(true)$0?; +} + "#, + r#" +fn test() { + let Ok(pat): Result<bool, _> = Ok(true) else { + return Err(todo!()); + }; +} + "#, + "Replace try expression with let else", + ); + } } diff --git a/crates/ide-assists/src/handlers/expand_glob_import.rs b/crates/ide-assists/src/handlers/expand_glob_import.rs index 7eca4d3f2a..6c5c21bfc9 100644 --- a/crates/ide-assists/src/handlers/expand_glob_import.rs +++ b/crates/ide-assists/src/handlers/expand_glob_import.rs @@ -317,7 +317,7 @@ fn find_refs_in_mod( .into_iter() .map(|v| Ref { visible_name: v.name(ctx.db()), - def: Definition::Variant(v), + def: Definition::EnumVariant(v), is_pub: true, }) .collect(), @@ -379,7 +379,7 @@ fn find_imported_defs(ctx: &AssistContext<'_>, use_item: Use) -> Vec<Definition> | Definition::Module(_) | Definition::Function(_) | Definition::Adt(_) - | Definition::Variant(_) + | Definition::EnumVariant(_) | Definition::Const(_) | Definition::Static(_) | Definition::Trait(_) diff --git a/crates/ide-assists/src/handlers/expand_rest_pattern.rs b/crates/ide-assists/src/handlers/expand_rest_pattern.rs index 867ac48518..a7e78dfc9c 100644 --- a/crates/ide-assists/src/handlers/expand_rest_pattern.rs +++ b/crates/ide-assists/src/handlers/expand_rest_pattern.rs @@ -102,7 +102,7 @@ fn expand_tuple_struct_rest_pattern( let fields = match ctx.sema.type_of_pat(&pat.clone().into())?.original.as_adt()? { hir::Adt::Struct(s) if s.kind(ctx.sema.db) == StructKind::Tuple => s.fields(ctx.sema.db), hir::Adt::Enum(_) => match ctx.sema.resolve_path(&path)? { - PathResolution::Def(hir::ModuleDef::Variant(v)) + PathResolution::Def(hir::ModuleDef::EnumVariant(v)) if v.kind(ctx.sema.db) == StructKind::Tuple => { v.fields(ctx.sema.db) diff --git a/crates/ide-assists/src/handlers/extract_expressions_from_format_string.rs b/crates/ide-assists/src/handlers/extract_expressions_from_format_string.rs index 61af2de6ec..35e8baa18a 100644 --- a/crates/ide-assists/src/handlers/extract_expressions_from_format_string.rs +++ b/crates/ide-assists/src/handlers/extract_expressions_from_format_string.rs @@ -8,7 +8,7 @@ use syntax::{ AstNode, AstToken, NodeOrToken, SyntaxKind::WHITESPACE, SyntaxToken, T, - ast::{self, TokenTree, make, syntax_factory::SyntaxFactory}, + ast::{self, TokenTree, syntax_factory::SyntaxFactory}, }; // Assist: extract_expressions_from_format_string @@ -57,6 +57,7 @@ pub(crate) fn extract_expressions_from_format_string( "Extract format expressions", tt.syntax().text_range(), |edit| { + let make = SyntaxFactory::without_mappings(); // Extract existing arguments in macro let mut raw_tokens = tt.token_trees_and_tokens().skip(1).collect_vec(); let format_string_index = format_str_index(&raw_tokens, &fmt_string); @@ -94,14 +95,14 @@ pub(crate) fn extract_expressions_from_format_string( let mut new_tt_bits = raw_tokens; let mut placeholder_indexes = vec![]; - new_tt_bits.push(NodeOrToken::Token(make::tokens::literal(&new_fmt))); + new_tt_bits.push(NodeOrToken::Token(make.expr_literal(&new_fmt).token().clone())); for arg in extracted_args { if matches!(arg, Arg::Expr(_) | Arg::Placeholder) { // insert ", " before each arg new_tt_bits.extend_from_slice(&[ - NodeOrToken::Token(make::token(T![,])), - NodeOrToken::Token(make::tokens::single_space()), + NodeOrToken::Token(make.token(T![,])), + NodeOrToken::Token(make.whitespace(" ")), ]); } @@ -109,7 +110,7 @@ pub(crate) fn extract_expressions_from_format_string( Arg::Expr(s) => { // insert arg let expr = ast::Expr::parse(&s, ctx.edition()).syntax_node(); - let mut expr_tt = utils::tt_from_syntax(expr); + let mut expr_tt = utils::tt_from_syntax(expr, &make); new_tt_bits.append(&mut expr_tt); } Arg::Placeholder => { @@ -120,7 +121,7 @@ pub(crate) fn extract_expressions_from_format_string( } None => { placeholder_indexes.push(new_tt_bits.len()); - new_tt_bits.push(NodeOrToken::Token(make::token(T![_]))); + new_tt_bits.push(NodeOrToken::Token(make.token(T![_]))); } } } @@ -129,7 +130,6 @@ pub(crate) fn extract_expressions_from_format_string( } // Insert new args - let make = SyntaxFactory::with_mappings(); let new_tt = make.token_tree(tt_delimiter, new_tt_bits); let mut editor = edit.make_editor(tt.syntax()); editor.replace(tt.syntax(), new_tt.syntax()); diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs index f2363c6f7b..124ef509fb 100644 --- a/crates/ide-assists/src/handlers/extract_function.rs +++ b/crates/ide-assists/src/handlers/extract_function.rs @@ -859,7 +859,7 @@ impl FunctionBody { ast::BlockExpr(block_expr) => { let (constness, block) = match block_expr.modifier() { Some(ast::BlockModifier::Const(_)) => (true, block_expr), - Some(ast::BlockModifier::Try(_)) => (false, block_expr), + Some(ast::BlockModifier::Try { .. }) => (false, block_expr), Some(ast::BlockModifier::Label(label)) if label.lifetime().is_some() => (false, block_expr), _ => continue, }; diff --git a/crates/ide-assists/src/handlers/extract_module.rs b/crates/ide-assists/src/handlers/extract_module.rs index a17ae4885e..dcbeaefa21 100644 --- a/crates/ide-assists/src/handlers/extract_module.rs +++ b/crates/ide-assists/src/handlers/extract_module.rs @@ -728,7 +728,7 @@ fn check_def_in_mod_and_out_sel( } Definition::Function(x) => check_item!(x), Definition::Adt(x) => check_item!(x), - Definition::Variant(x) => check_item!(x), + Definition::EnumVariant(x) => check_item!(x), Definition::Const(x) => check_item!(x), Definition::Static(x) => check_item!(x), Definition::Trait(x) => check_item!(x), diff --git a/crates/ide-assists/src/handlers/extract_struct_from_enum_variant.rs b/crates/ide-assists/src/handlers/extract_struct_from_enum_variant.rs index 386652a422..4c46a51bef 100644 --- a/crates/ide-assists/src/handlers/extract_struct_from_enum_variant.rs +++ b/crates/ide-assists/src/handlers/extract_struct_from_enum_variant.rs @@ -1,7 +1,7 @@ use std::iter; use either::Either; -use hir::{HasCrate, Module, ModuleDef, Name, Variant}; +use hir::{EnumVariant, HasCrate, Module, ModuleDef, Name}; use ide_db::{ FxHashSet, RootDatabase, defs::Definition, @@ -16,9 +16,7 @@ use syntax::{ SyntaxKind::*, SyntaxNode, T, ast::{ - self, AstNode, HasAttrs, HasGenericParams, HasName, HasVisibility, - edit::{AstNodeEdit, IndentLevel}, - make, + self, AstNode, HasAttrs, HasGenericParams, HasName, HasVisibility, edit::AstNodeEdit, make, }, match_ast, ted, }; @@ -63,7 +61,7 @@ pub(crate) fn extract_struct_from_enum_variant( let edition = enum_hir.krate(ctx.db()).edition(ctx.db()); let variant_hir_name = variant_hir.name(ctx.db()); let enum_module_def = ModuleDef::from(enum_hir); - let usages = Definition::Variant(variant_hir).usages(&ctx.sema).all(); + let usages = Definition::EnumVariant(variant_hir).usages(&ctx.sema).all(); let mut visited_modules_set = FxHashSet::default(); let current_module = enum_hir.module(ctx.db()); @@ -163,7 +161,7 @@ fn extract_field_list_if_applicable( } } -fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &Variant) -> bool { +fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &EnumVariant) -> bool { variant .parent_enum(db) .module(db) @@ -175,7 +173,7 @@ fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &Va def, ModuleDef::Module(_) | ModuleDef::Adt(_) - | ModuleDef::Variant(_) + | ModuleDef::EnumVariant(_) | ModuleDef::Trait(_) | ModuleDef::TypeAlias(_) | ModuleDef::BuiltinType(_) @@ -290,7 +288,6 @@ fn create_struct_def( field_list.clone().into() } }; - let field_list = field_list.indent(IndentLevel::single()); let strukt = make::struct_(enum_vis, name, generics, field_list).clone_for_update(); diff --git a/crates/ide-assists/src/handlers/extract_type_alias.rs b/crates/ide-assists/src/handlers/extract_type_alias.rs index 769bbd976a..e4fdac27f4 100644 --- a/crates/ide-assists/src/handlers/extract_type_alias.rs +++ b/crates/ide-assists/src/handlers/extract_type_alias.rs @@ -2,7 +2,10 @@ use either::Either; use hir::HirDisplay; use ide_db::syntax_helpers::node_ext::walk_ty; use syntax::{ - ast::{self, AstNode, HasGenericArgs, HasGenericParams, HasName, edit::IndentLevel, make}, + ast::{ + self, AstNode, HasGenericArgs, HasGenericParams, HasName, edit::IndentLevel, + syntax_factory::SyntaxFactory, + }, syntax_editor, }; @@ -43,10 +46,9 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> let resolved_ty = ctx.sema.resolve_type(&ty)?; let resolved_ty = if !resolved_ty.contains_unknown() { let module = ctx.sema.scope(ty.syntax())?.module(); - let resolved_ty = resolved_ty.display_source_code(ctx.db(), module.into(), false).ok()?; - make::ty(&resolved_ty) + resolved_ty.display_source_code(ctx.db(), module.into(), false).ok()? } else { - ty.clone() + ty.to_string() }; acc.add( @@ -55,6 +57,9 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> target, |builder| { let mut edit = builder.make_editor(node); + let make = SyntaxFactory::without_mappings(); + + let resolved_ty = make.ty(&resolved_ty); let mut known_generics = match item.generic_param_list() { Some(it) => it.generic_params().collect(), @@ -68,22 +73,20 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> } let generics = collect_used_generics(&ty, &known_generics); let generic_params = - generics.map(|it| make::generic_param_list(it.into_iter().cloned())); + generics.map(|it| make.generic_param_list(it.into_iter().cloned())); // Replace original type with the alias let ty_args = generic_params.as_ref().map(|it| it.to_generic_args().generic_args()); let new_ty = if let Some(ty_args) = ty_args { - make::generic_ty_path_segment(make::name_ref("Type"), ty_args) + make.generic_ty_path_segment(make.name_ref("Type"), ty_args) } else { - make::path_segment(make::name_ref("Type")) - } - .clone_for_update(); + make.path_segment(make.name_ref("Type")) + }; edit.replace(ty.syntax(), new_ty.syntax()); // Insert new alias let ty_alias = - make::ty_alias(None, "Type", generic_params, None, None, Some((resolved_ty, None))) - .clone_for_update(); + make.ty_alias(None, "Type", generic_params, None, None, Some((resolved_ty, None))); if let Some(cap) = ctx.config.snippet_cap && let Some(name) = ty_alias.name() @@ -96,7 +99,7 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> syntax_editor::Position::before(node), vec![ ty_alias.syntax().clone().into(), - make::tokens::whitespace(&format!("\n\n{indent}")).into(), + make.whitespace(&format!("\n\n{indent}")).into(), ], ); diff --git a/crates/ide-assists/src/handlers/extract_variable.rs b/crates/ide-assists/src/handlers/extract_variable.rs index 7c60184142..e5ce02cf53 100644 --- a/crates/ide-assists/src/handlers/extract_variable.rs +++ b/crates/ide-assists/src/handlers/extract_variable.rs @@ -9,7 +9,6 @@ use syntax::{ ast::{ self, AstNode, edit::{AstNodeEdit, IndentLevel}, - make, syntax_factory::SyntaxFactory, }, syntax_editor::Position, @@ -75,7 +74,7 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op .next() .and_then(ast::Expr::cast) { - expr.syntax().ancestors().find_map(valid_target_expr)?.syntax().clone() + expr.syntax().ancestors().find_map(valid_target_expr(ctx))?.syntax().clone() } else { return None; } @@ -96,7 +95,7 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op let to_extract = node .descendants() .take_while(|it| range.contains_range(it.text_range())) - .find_map(valid_target_expr)?; + .find_map(valid_target_expr(ctx))?; let ty = ctx.sema.type_of_expr(&to_extract).map(TypeInfo::adjusted); if matches!(&ty, Some(ty_info) if ty_info.is_unit()) { @@ -176,7 +175,7 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op let mut editor = edit.make_editor(&expr_replace); let pat_name = make.name(&var_name); - let name_expr = make.expr_path(make::ext::ident_path(&var_name)); + let name_expr = make.expr_path(make.ident_path(&var_name)); if let Some(cap) = ctx.config.snippet_cap { let tabstop = edit.make_tabstop_before(cap); @@ -233,7 +232,7 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op Position::before(place), vec![ new_stmt.syntax().clone().into(), - make::tokens::whitespace(&trailing_ws).into(), + make.whitespace(&trailing_ws).into(), ], ); @@ -283,14 +282,19 @@ fn peel_parens(mut expr: ast::Expr) -> ast::Expr { /// Check whether the node is a valid expression which can be extracted to a variable. /// In general that's true for any expression, but in some cases that would produce invalid code. -fn valid_target_expr(node: SyntaxNode) -> Option<ast::Expr> { - match node.kind() { - SyntaxKind::PATH_EXPR | SyntaxKind::LOOP_EXPR | SyntaxKind::LET_EXPR => None, +fn valid_target_expr(ctx: &AssistContext<'_>) -> impl Fn(SyntaxNode) -> Option<ast::Expr> { + |node| match node.kind() { + SyntaxKind::LOOP_EXPR | SyntaxKind::LET_EXPR => None, SyntaxKind::BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()), SyntaxKind::RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()), SyntaxKind::BLOCK_EXPR => { ast::BlockExpr::cast(node).filter(|it| it.is_standalone()).map(ast::Expr::from) } + SyntaxKind::PATH_EXPR => { + let path_expr = ast::PathExpr::cast(node)?; + let path_resolution = ctx.sema.resolve_path(&path_expr.path()?)?; + like_const_value(ctx, path_resolution).then_some(path_expr.into()) + } _ => ast::Expr::cast(node), } } @@ -455,6 +459,31 @@ impl Anchor { } } +fn like_const_value(ctx: &AssistContext<'_>, path_resolution: hir::PathResolution) -> bool { + let db = ctx.db(); + let adt_like_const_value = |adt: Option<hir::Adt>| matches!(adt, Some(hir::Adt::Struct(s)) if s.kind(db) == hir::StructKind::Unit); + match path_resolution { + hir::PathResolution::Def(def) => match def { + hir::ModuleDef::Adt(adt) => adt_like_const_value(Some(adt)), + hir::ModuleDef::EnumVariant(variant) => variant.kind(db) == hir::StructKind::Unit, + hir::ModuleDef::TypeAlias(ty) => adt_like_const_value(ty.ty(db).as_adt()), + hir::ModuleDef::Const(_) | hir::ModuleDef::Static(_) => true, + hir::ModuleDef::Trait(_) + | hir::ModuleDef::BuiltinType(_) + | hir::ModuleDef::Macro(_) + | hir::ModuleDef::Module(_) => false, + hir::ModuleDef::Function(_) => false, // no extract named function + }, + hir::PathResolution::SelfType(ty) => adt_like_const_value(ty.self_ty(db).as_adt()), + hir::PathResolution::ConstParam(_) => true, + hir::PathResolution::Local(_) + | hir::PathResolution::TypeParam(_) + | hir::PathResolution::BuiltinAttr(_) + | hir::PathResolution::ToolModule(_) + | hir::PathResolution::DeriveHelper(_) => false, + } +} + #[cfg(test)] mod tests { // NOTE: We use check_assist_by_label, but not check_assist_not_applicable_by_label @@ -1748,6 +1777,27 @@ fn main() { } #[test] + fn extract_non_local_path_expr() { + check_assist_by_label( + extract_variable, + r#" +struct Foo; +fn foo() -> Foo { + $0Foo$0 +} +"#, + r#" +struct Foo; +fn foo() -> Foo { + let $0foo = Foo; + foo +} +"#, + "Extract into variable", + ); + } + + #[test] fn extract_var_for_return_not_applicable() { check_assist_not_applicable(extract_variable, "fn foo() { $0return$0; } "); } diff --git a/crates/ide-assists/src/handlers/fix_visibility.rs b/crates/ide-assists/src/handlers/fix_visibility.rs index 0fd8057a39..440f2d5f17 100644 --- a/crates/ide-assists/src/handlers/fix_visibility.rs +++ b/crates/ide-assists/src/handlers/fix_visibility.rs @@ -2,7 +2,7 @@ use hir::{HasSource, HasVisibility, ModuleDef, PathResolution, ScopeDef, db::Hir use ide_db::FileId; use syntax::{ AstNode, TextRange, - ast::{self, HasVisibility as _, edit_in_place::HasVisibilityEdit, make}, + ast::{self, HasVisibility as _, syntax_factory::SyntaxFactory}, }; use crate::{AssistContext, AssistId, Assists}; @@ -59,10 +59,12 @@ fn add_vis_to_referenced_module_def(acc: &mut Assists, ctx: &AssistContext<'_>) let (vis_owner, target, target_file, target_name) = target_data_for_def(ctx.db(), def)?; + let make = SyntaxFactory::without_mappings(); + let missing_visibility = if current_module.krate(ctx.db()) == target_module.krate(ctx.db()) { - make::visibility_pub_crate() + make.visibility_pub_crate() } else { - make::visibility_pub() + make.visibility_pub() }; let assist_label = match target_name { @@ -75,15 +77,36 @@ fn add_vis_to_referenced_module_def(acc: &mut Assists, ctx: &AssistContext<'_>) } }; - acc.add(AssistId::quick_fix("fix_visibility"), assist_label, target, |edit| { - edit.edit_file(target_file); - - let vis_owner = edit.make_mut(vis_owner); - vis_owner.set_visibility(Some(missing_visibility.clone_for_update())); + acc.add(AssistId::quick_fix("fix_visibility"), assist_label, target, |builder| { + let mut editor = builder.make_editor(vis_owner.syntax()); + + if let Some(current_visibility) = vis_owner.visibility() { + editor.replace(current_visibility.syntax(), missing_visibility.syntax()); + } else { + let vis_before = vis_owner + .syntax() + .children_with_tokens() + .find(|it| { + !matches!( + it.kind(), + syntax::SyntaxKind::WHITESPACE + | syntax::SyntaxKind::COMMENT + | syntax::SyntaxKind::ATTR + ) + }) + .unwrap_or_else(|| vis_owner.syntax().first_child_or_token().unwrap()); + + editor.insert_all( + syntax::syntax_editor::Position::before(vis_before), + vec![missing_visibility.syntax().clone().into(), make.whitespace(" ").into()], + ); + } - if let Some((cap, vis)) = ctx.config.snippet_cap.zip(vis_owner.visibility()) { - edit.add_tabstop_before(cap, vis); + if let Some(cap) = ctx.config.snippet_cap { + editor.add_annotation(missing_visibility.syntax(), builder.make_tabstop_before(cap)); } + + builder.add_file_edits(target_file, editor); }) } @@ -150,7 +173,7 @@ fn target_data_for_def( // FIXME hir::ModuleDef::Macro(_) => return None, // Enum variants can't be private, we can't modify builtin types - hir::ModuleDef::Variant(_) | hir::ModuleDef::BuiltinType(_) => return None, + hir::ModuleDef::EnumVariant(_) | hir::ModuleDef::BuiltinType(_) => return None, }; Some((offset, target, target_file, target_name)) diff --git a/crates/ide-assists/src/handlers/generate_blanket_trait_impl.rs b/crates/ide-assists/src/handlers/generate_blanket_trait_impl.rs index b0fa9e6b3e..e022a27e51 100644 --- a/crates/ide-assists/src/handlers/generate_blanket_trait_impl.rs +++ b/crates/ide-assists/src/handlers/generate_blanket_trait_impl.rs @@ -5,15 +5,15 @@ use crate::{ use hir::{HasCrate, Semantics}; use ide_db::{ RootDatabase, - assists::{AssistId, AssistKind, ExprFillDefaultMode}, + assists::{AssistId, AssistKind}, famous_defs::FamousDefs, syntax_helpers::suggest_name, }; use syntax::{ AstNode, ast::{ - self, AssocItem, BlockExpr, GenericParam, HasAttrs, HasGenericParams, HasName, - HasTypeBounds, HasVisibility, edit::AstNodeEdit, make, + self, AssocItem, GenericParam, HasAttrs, HasGenericParams, HasName, HasTypeBounds, + HasVisibility, edit::AstNodeEdit, make, }, syntax_editor::Position, }; @@ -269,7 +269,7 @@ fn todo_fn(f: &ast::Fn, config: &AssistConfig) -> ast::Fn { f.generic_param_list(), f.where_clause(), params, - default_block(config), + make::block_expr(None, Some(crate::utils::expr_fill_default(config))), f.ret_type(), f.async_token().is_some(), f.const_token().is_some(), @@ -278,15 +278,6 @@ fn todo_fn(f: &ast::Fn, config: &AssistConfig) -> ast::Fn { ) } -fn default_block(config: &AssistConfig) -> BlockExpr { - let expr = match config.expr_fill_default { - ExprFillDefaultMode::Todo => make::ext::expr_todo(), - ExprFillDefaultMode::Underscore => make::ext::expr_underscore(), - ExprFillDefaultMode::Default => make::ext::expr_todo(), - }; - make::block_expr(None, Some(expr)) -} - fn cfg_attrs(node: &impl HasAttrs) -> impl Iterator<Item = ast::Attr> { node.attrs().filter(|attr| attr.as_simple_call().is_some_and(|(name, _arg)| name == "cfg")) } diff --git a/crates/ide-assists/src/handlers/generate_default_from_new.rs b/crates/ide-assists/src/handlers/generate_default_from_new.rs index 48400d436a..2d92bf5146 100644 --- a/crates/ide-assists/src/handlers/generate_default_from_new.rs +++ b/crates/ide-assists/src/handlers/generate_default_from_new.rs @@ -1,8 +1,12 @@ use ide_db::famous_defs::FamousDefs; -use stdx::format_to; use syntax::{ AstNode, - ast::{self, HasGenericParams, HasName, HasTypeBounds, Impl, make}, + ast::{ + self, HasGenericParams, HasName, HasTypeBounds, Impl, + edit::{AstNodeEdit, IndentLevel}, + syntax_factory::SyntaxFactory, + }, + syntax_editor::Position, }; use crate::{ @@ -62,29 +66,32 @@ pub(crate) fn generate_default_from_new(acc: &mut Assists, ctx: &AssistContext<' return None; } - let insert_location = impl_.syntax().text_range(); + let target = impl_.syntax().text_range(); acc.add( AssistId::generate("generate_default_from_new"), "Generate a Default impl from a new fn", - insert_location, + target, move |builder| { - let default_code = " fn default() -> Self { - Self::new() - }"; - let code = generate_trait_impl_text_from_impl(&impl_, self_ty, "Default", default_code); - builder.insert(insert_location.end(), code); + let make = SyntaxFactory::without_mappings(); + let default_impl = generate_default_impl(&make, &impl_, self_ty); + let indent = IndentLevel::from_node(impl_.syntax()); + let default_impl = default_impl.indent(indent); + + let mut editor = builder.make_editor(impl_.syntax()); + editor.insert_all( + Position::after(impl_.syntax()), + vec![ + make.whitespace(&format!("\n\n{indent}")).into(), + default_impl.syntax().clone().into(), + ], + ); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } -// FIXME: based on from utils::generate_impl_text_inner -fn generate_trait_impl_text_from_impl( - impl_: &ast::Impl, - self_ty: ast::Type, - trait_text: &str, - code: &str, -) -> String { +fn generate_default_impl(make: &SyntaxFactory, impl_: &ast::Impl, self_ty: ast::Type) -> ast::Impl { let generic_params = impl_.generic_param_list().map(|generic_params| { let lifetime_params = generic_params.lifetime_params().map(ast::GenericParam::LifetimeParam); @@ -92,40 +99,59 @@ fn generate_trait_impl_text_from_impl( // remove defaults since they can't be specified in impls let param = match param { ast::TypeOrConstParam::Type(param) => { - let param = make::type_param(param.name()?, param.type_bound_list()); + let param = make.type_param(param.name()?, param.type_bound_list()); ast::GenericParam::TypeParam(param) } ast::TypeOrConstParam::Const(param) => { - let param = make::const_param(param.name()?, param.ty()?); + let param = make.const_param(param.name()?, param.ty()?); ast::GenericParam::ConstParam(param) } }; Some(param) }); - make::generic_param_list(itertools::chain(lifetime_params, ty_or_const_params)) + make.generic_param_list(itertools::chain(lifetime_params, ty_or_const_params)) }); - let mut buf = String::with_capacity(code.len()); - buf.push_str("\n\n"); - - // `impl{generic_params} {trait_text} for {impl_.self_ty()}` - buf.push_str("impl"); - if let Some(generic_params) = &generic_params { - format_to!(buf, "{generic_params}") - } - format_to!(buf, " {trait_text} for {self_ty}"); - - match impl_.where_clause() { - Some(where_clause) => { - format_to!(buf, "\n{where_clause}\n{{\n{code}\n}}"); - } - None => { - format_to!(buf, " {{\n{code}\n}}"); - } - } - - buf + let trait_ty: ast::Type = make.ty_path(make.ident_path("Default")).into(); + + let self_new_path = make.path_concat(make.ident_path("Self"), make.ident_path("new")); + let self_new_call = + make.expr_call(make.expr_path(self_new_path), make.arg_list(std::iter::empty())); + let fn_body = make.block_expr(std::iter::empty(), Some(self_new_call.into())); + let self_ty_ret: ast::Type = make.ty_path(make.ident_path("Self")).into(); + let default_fn = make + .fn_( + [], + None, + make.name("default"), + None, + None, + make.param_list(None, std::iter::empty()), + fn_body, + Some(make.ret_type(self_ty_ret)), + false, + false, + false, + false, + ) + .indent(1.into()); + let body = make.assoc_item_list(Some(ast::AssocItem::from(default_fn))); + + make.impl_trait( + [], + false, + None, + None, + generic_params, + None, + false, + trait_ty, + self_ty, + None, + impl_.where_clause(), + Some(body), + ) } fn is_default_implemented(ctx: &AssistContext<'_>, impl_: &Impl) -> bool { @@ -628,12 +654,12 @@ mod test { } } -impl Default for Example { - fn default() -> Self { - Self::new() + impl Default for Example { + fn default() -> Self { + Self::new() + } } } -} "#, ); } diff --git a/crates/ide-assists/src/handlers/generate_delegate_methods.rs b/crates/ide-assists/src/handlers/generate_delegate_methods.rs index c1eb1a74ec..63033c7d5e 100644 --- a/crates/ide-assists/src/handlers/generate_delegate_methods.rs +++ b/crates/ide-assists/src/handlers/generate_delegate_methods.rs @@ -4,7 +4,7 @@ use syntax::{ ast::{ self, AstNode, HasGenericParams, HasName, HasVisibility as _, edit::{AstNodeEdit, IndentLevel}, - make, + syntax_factory::SyntaxFactory, }, syntax_editor::Position, }; @@ -100,7 +100,6 @@ pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<' let Some(impl_def) = find_struct_impl(ctx, &adt, std::slice::from_ref(&name)) else { continue; }; - let field = make::ext::field_from_idents(["self", &field_name])?; acc.add_group( &GroupLabel("Generate delegate methods…".to_owned()), @@ -108,10 +107,14 @@ pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<' format!("Generate delegate for `{field_name}.{name}()`",), target, |edit| { + let make = SyntaxFactory::without_mappings(); + let field = make + .field_from_idents(["self", &field_name]) + .expect("always be a valid expression"); // Create the function let method_source = match ctx.sema.source(method) { Some(source) => { - let v = source.value.clone_for_update(); + let v = source.value; let source_scope = ctx.sema.scope(v.syntax()); let target_scope = ctx.sema.scope(strukt.syntax()); if let (Some(s), Some(t)) = (source_scope, target_scope) { @@ -132,42 +135,42 @@ pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<' let is_unsafe = method_source.unsafe_token().is_some(); let is_gen = method_source.gen_token().is_some(); - let fn_name = make::name(&name); + let fn_name = make.name(&name); let type_params = method_source.generic_param_list(); let where_clause = method_source.where_clause(); let params = - method_source.param_list().unwrap_or_else(|| make::param_list(None, [])); + method_source.param_list().unwrap_or_else(|| make.param_list(None, [])); // compute the `body` let arg_list = method_source .param_list() - .map(convert_param_list_to_arg_list) - .unwrap_or_else(|| make::arg_list([])); + .map(|v| convert_param_list_to_arg_list(v, &make)) + .unwrap_or_else(|| make.arg_list([])); - let tail_expr = - make::expr_method_call(field, make::name_ref(&name), arg_list).into(); + let tail_expr = make.expr_method_call(field, make.name_ref(&name), arg_list).into(); let tail_expr_finished = - if is_async { make::expr_await(tail_expr) } else { tail_expr }; - let body = make::block_expr([], Some(tail_expr_finished)); + if is_async { make.expr_await(tail_expr).into() } else { tail_expr }; + let body = make.block_expr([], Some(tail_expr_finished)); let ret_type = method_source.ret_type(); - let f = make::fn_( - None, - vis, - fn_name, - type_params, - where_clause, - params, - body, - ret_type, - is_async, - is_const, - is_unsafe, - is_gen, - ) - .indent(IndentLevel(1)); + let f = make + .fn_( + None, + vis, + fn_name, + type_params, + where_clause, + params, + body, + ret_type, + is_async, + is_const, + is_unsafe, + is_gen, + ) + .indent(IndentLevel(1)); let item = ast::AssocItem::Fn(f.clone()); let mut editor = edit.make_editor(strukt.syntax()); @@ -179,7 +182,7 @@ pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<' Some(item) } None => { - let assoc_item_list = make::assoc_item_list(Some(vec![item])); + let assoc_item_list = make.assoc_item_list(vec![item]); editor.insert( Position::last_child_of(impl_def.syntax()), assoc_item_list.syntax(), @@ -192,17 +195,16 @@ pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<' let ty_params = strukt.generic_param_list(); let ty_args = ty_params.as_ref().map(|it| it.to_generic_args()); let where_clause = strukt.where_clause(); - let assoc_item_list = make::assoc_item_list(Some(vec![item])); + let assoc_item_list = make.assoc_item_list(vec![item]); - let impl_def = make::impl_( + let impl_def = make.impl_( None, ty_params, ty_args, - make::ty_path(make::ext::ident_path(name)), + syntax::ast::Type::PathType(make.ty_path(make.ident_path(name))), where_clause, Some(assoc_item_list), - ) - .clone_for_update(); + ); // Fixup impl_def indentation let indent = strukt.indent_level(); @@ -213,7 +215,7 @@ pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<' editor.insert_all( Position::after(strukt.syntax()), vec![ - make::tokens::whitespace(&format!("\n\n{indent}")).into(), + make.whitespace(&format!("\n\n{indent}")).into(), impl_def.syntax().clone().into(), ], ); @@ -227,6 +229,7 @@ pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<' let tabstop = edit.make_tabstop_before(cap); editor.add_annotation(fn_.syntax(), tabstop); } + editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }, )?; diff --git a/crates/ide-assists/src/handlers/generate_delegate_trait.rs b/crates/ide-assists/src/handlers/generate_delegate_trait.rs index 921f04f2a5..f703e4dc4a 100644 --- a/crates/ide-assists/src/handlers/generate_delegate_trait.rs +++ b/crates/ide-assists/src/handlers/generate_delegate_trait.rs @@ -782,7 +782,7 @@ fn func_assoc_item( }; // Build argument list with self expression prepended - let other_args = convert_param_list_to_arg_list(l); + let other_args = convert_param_list_to_arg_list(l, &make); let all_args: Vec<ast::Expr> = std::iter::once(tail_expr_self).chain(other_args.args()).collect(); let args = make.arg_list(all_args); @@ -790,13 +790,13 @@ fn func_assoc_item( make.expr_call(make.expr_path(qualified_path), args).into() } None => make - .expr_call(make.expr_path(qualified_path), convert_param_list_to_arg_list(l)) + .expr_call(make.expr_path(qualified_path), convert_param_list_to_arg_list(l, &make)) .into(), }, None => make .expr_call( make.expr_path(qualified_path), - convert_param_list_to_arg_list(make.param_list(None, Vec::new())), + convert_param_list_to_arg_list(make.param_list(None, Vec::new()), &make), ) .into(), }; diff --git a/crates/ide-assists/src/handlers/generate_deref.rs b/crates/ide-assists/src/handlers/generate_deref.rs index 494c87e6d1..5534dc1cd3 100644 --- a/crates/ide-assists/src/handlers/generate_deref.rs +++ b/crates/ide-assists/src/handlers/generate_deref.rs @@ -1,16 +1,15 @@ -use std::fmt::Display; - use hir::{ModPath, ModuleDef}; -use ide_db::{RootDatabase, famous_defs::FamousDefs}; +use ide_db::{FileId, RootDatabase, famous_defs::FamousDefs}; use syntax::{ - AstNode, Edition, SyntaxNode, - ast::{self, HasName}, + Edition, + ast::{self, AstNode, HasName, edit::AstNodeEdit, syntax_factory::SyntaxFactory}, + syntax_editor::Position, }; use crate::{ AssistId, assist_context::{AssistContext, Assists, SourceChangeBuilder}, - utils::generate_trait_impl_text_intransitive, + utils::generate_trait_impl_intransitive_with_item, }; // Assist: generate_deref @@ -64,6 +63,7 @@ fn generate_record_deref(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<( let field_type = field.ty()?; let field_name = field.name()?; let target = field.syntax().text_range(); + let file_id = ctx.vfs_file_id(); acc.add( AssistId::generate("generate_deref"), format!("Generate `{deref_type_to_generate:?}` impl using `{field_name}`"), @@ -72,9 +72,10 @@ fn generate_record_deref(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<( generate_edit( ctx.db(), edit, + file_id, strukt, - field_type.syntax(), - field_name.syntax(), + field_type, + &field_name.to_string(), deref_type_to_generate, trait_path, module.krate(ctx.db()).edition(ctx.db()), @@ -105,6 +106,7 @@ fn generate_tuple_deref(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<() let field_type = field.ty()?; let target = field.syntax().text_range(); + let file_id = ctx.vfs_file_id(); acc.add( AssistId::generate("generate_deref"), format!("Generate `{deref_type_to_generate:?}` impl using `{field}`"), @@ -113,9 +115,10 @@ fn generate_tuple_deref(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<() generate_edit( ctx.db(), edit, + file_id, strukt, - field_type.syntax(), - field_list_index, + field_type, + &field_list_index.to_string(), deref_type_to_generate, trait_path, module.krate(ctx.db()).edition(ctx.db()), @@ -127,35 +130,81 @@ fn generate_tuple_deref(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<() fn generate_edit( db: &RootDatabase, edit: &mut SourceChangeBuilder, + file_id: FileId, strukt: ast::Struct, - field_type_syntax: &SyntaxNode, - field_name: impl Display, + field_type: ast::Type, + field_name: &str, deref_type: DerefType, trait_path: ModPath, edition: Edition, ) { - let start_offset = strukt.syntax().text_range().end(); - let impl_code = match deref_type { - DerefType::Deref => format!( - r#" type Target = {field_type_syntax}; - - fn deref(&self) -> &Self::Target {{ - &self.{field_name} - }}"#, - ), - DerefType::DerefMut => format!( - r#" fn deref_mut(&mut self) -> &mut Self::Target {{ - &mut self.{field_name} - }}"#, - ), + let make = SyntaxFactory::with_mappings(); + let strukt_adt = ast::Adt::Struct(strukt.clone()); + let trait_ty = make.ty(&trait_path.display(db, edition).to_string()); + + let assoc_items: Vec<ast::AssocItem> = match deref_type { + DerefType::Deref => { + let target_alias = + make.ty_alias([], "Target", None, None, None, Some((field_type, None))); + let ret_ty = + make.ty_ref(make.ty_path(make.path_from_text("Self::Target")).into(), false); + let field_expr = make.expr_field(make.expr_path(make.ident_path("self")), field_name); + let body = make.block_expr([], Some(make.expr_ref(field_expr.into(), false))); + let fn_ = make + .fn_( + [], + None, + make.name("deref"), + None, + None, + make.param_list(Some(make.self_param()), []), + body, + Some(make.ret_type(ret_ty)), + false, + false, + false, + false, + ) + .indent(1.into()); + vec![ast::AssocItem::TypeAlias(target_alias), ast::AssocItem::Fn(fn_)] + } + DerefType::DerefMut => { + let ret_ty = + make.ty_ref(make.ty_path(make.path_from_text("Self::Target")).into(), true); + let field_expr = make.expr_field(make.expr_path(make.ident_path("self")), field_name); + let body = make.block_expr([], Some(make.expr_ref(field_expr.into(), true))); + let fn_ = make + .fn_( + [], + None, + make.name("deref_mut"), + None, + None, + make.param_list(Some(make.mut_self_param()), []), + body, + Some(make.ret_type(ret_ty)), + false, + false, + false, + false, + ) + .indent(1.into()); + vec![ast::AssocItem::Fn(fn_)] + } }; - let strukt_adt = ast::Adt::Struct(strukt); - let deref_impl = generate_trait_impl_text_intransitive( - &strukt_adt, - &trait_path.display(db, edition).to_string(), - &impl_code, + + let body = make.assoc_item_list(assoc_items); + let indent = strukt.indent_level(); + let impl_ = generate_trait_impl_intransitive_with_item(&make, &strukt_adt, trait_ty, body) + .indent(indent); + + let mut editor = edit.make_editor(strukt.syntax()); + editor.insert_all( + Position::after(strukt.syntax()), + vec![make.whitespace(&format!("\n\n{indent}")).into(), impl_.syntax().clone().into()], ); - edit.insert(start_offset, deref_impl); + editor.add_mappings(make.finish_with_mappings()); + edit.add_file_edits(file_id, editor); } fn existing_deref_impl( diff --git a/crates/ide-assists/src/handlers/generate_derive.rs b/crates/ide-assists/src/handlers/generate_derive.rs index 06fef4af22..3ef68f06e4 100644 --- a/crates/ide-assists/src/handlers/generate_derive.rs +++ b/crates/ide-assists/src/handlers/generate_derive.rs @@ -1,7 +1,7 @@ use syntax::{ SyntaxKind::{ATTR, COMMENT, WHITESPACE}, T, - ast::{self, AstNode, HasAttrs, edit::IndentLevel, make}, + ast::{self, AstNode, HasAttrs, edit::IndentLevel, syntax_factory::SyntaxFactory}, syntax_editor::{Element, Position}, }; @@ -42,13 +42,15 @@ pub(crate) fn generate_derive(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opt }; acc.add(AssistId::generate("generate_derive"), "Add `#[derive]`", target, |edit| { + let make = SyntaxFactory::without_mappings(); + match derive_attr { None => { - let derive = make::attr_outer(make::meta_token_tree( - make::ext::ident_path("derive"), - make::token_tree(T!['('], vec![]).clone_for_update(), - )) - .clone_for_update(); + let derive = + make.attr_outer(make.meta_token_tree( + make.ident_path("derive"), + make.token_tree(T!['('], vec![]), + )); let mut editor = edit.make_editor(nominal.syntax()); let indent = IndentLevel::from_node(nominal.syntax()); @@ -57,11 +59,12 @@ pub(crate) fn generate_derive(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opt .children_with_tokens() .find(|it| !matches!(it.kind(), WHITESPACE | COMMENT | ATTR)) .map_or(Position::first_child_of(nominal.syntax()), Position::before); + editor.insert_all( after_attrs_and_comments, vec![ derive.syntax().syntax_element(), - make::tokens::whitespace(&format!("\n{indent}")).syntax_element(), + make.whitespace(&format!("\n{indent}")).syntax_element(), ], ); @@ -72,7 +75,9 @@ pub(crate) fn generate_derive(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opt .expect("failed to get token tree out of Meta") .r_paren_token() .expect("make::attr_outer was expected to have a R_PAREN"); + let tabstop_before = edit.make_tabstop_before(cap); + editor.add_annotation(delimiter, tabstop_before); edit.add_file_edits(ctx.vfs_file_id(), editor); } diff --git a/crates/ide-assists/src/handlers/generate_fn_type_alias.rs b/crates/ide-assists/src/handlers/generate_fn_type_alias.rs index 7fd94b4bed..6bcbd9b0cc 100644 --- a/crates/ide-assists/src/handlers/generate_fn_type_alias.rs +++ b/crates/ide-assists/src/handlers/generate_fn_type_alias.rs @@ -2,7 +2,7 @@ use either::Either; use ide_db::assists::{AssistId, GroupLabel}; use syntax::{ AstNode, - ast::{self, HasGenericParams, HasName, edit::IndentLevel, make}, + ast::{self, HasGenericParams, HasName, edit::IndentLevel, syntax_factory::SyntaxFactory}, syntax_editor, }; @@ -56,6 +56,7 @@ pub(crate) fn generate_fn_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) func_node.syntax().text_range(), |builder| { let mut edit = builder.make_editor(func); + let make = SyntaxFactory::without_mappings(); let alias_name = format!("{}Fn", stdx::to_camel_case(&name.to_string())); @@ -68,24 +69,24 @@ pub(crate) fn generate_fn_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) let is_mut = self_ty.is_mutable_reference(); if let Some(adt) = self_ty.strip_references().as_adt() { - let inner_type = make::ty(adt.name(ctx.db()).as_str()); + let inner_type = make.ty(adt.name(ctx.db()).as_str()); let ast_self_ty = - if is_ref { make::ty_ref(inner_type, is_mut) } else { inner_type }; + if is_ref { make.ty_ref(inner_type, is_mut) } else { inner_type }; - fn_params_vec.push(make::unnamed_param(ast_self_ty)); + fn_params_vec.push(make.unnamed_param(ast_self_ty)); } } fn_params_vec.extend(param_list.params().filter_map(|p| match style { ParamStyle::Named => Some(p), - ParamStyle::Unnamed => p.ty().map(make::unnamed_param), + ParamStyle::Unnamed => p.ty().map(|ty| make.unnamed_param(ty)), })); let generic_params = func_node.generic_param_list(); let is_unsafe = func_node.unsafe_token().is_some(); - let ty = make::ty_fn_ptr( + let ty = make.ty_fn_ptr( is_unsafe, func_node.abi(), fn_params_vec.into_iter(), @@ -93,22 +94,21 @@ pub(crate) fn generate_fn_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) ); // Insert new alias - let ty_alias = make::ty_alias( + let ty_alias = make.ty_alias( None, &alias_name, generic_params, None, None, Some((ast::Type::FnPtrType(ty), None)), - ) - .clone_for_update(); + ); let indent = IndentLevel::from_node(insertion_node); edit.insert_all( syntax_editor::Position::before(insertion_node), vec![ ty_alias.syntax().clone().into(), - make::tokens::whitespace(&format!("\n\n{indent}")).into(), + make.whitespace(&format!("\n\n{indent}")).into(), ], ); diff --git a/crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs b/crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs index 24f271ded8..1adb3f4fe4 100644 --- a/crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs +++ b/crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs @@ -1,10 +1,11 @@ use hir::next_solver::{DbInterner, TypingMode}; use ide_db::{RootDatabase, famous_defs::FamousDefs}; -use syntax::ast::{self, AstNode, HasName}; +use syntax::ast::{self, AstNode, HasName, edit::AstNodeEdit, syntax_factory::SyntaxFactory}; +use syntax::syntax_editor::Position; use crate::{ AssistContext, AssistId, Assists, - utils::{generate_trait_impl_text_intransitive, is_selected}, + utils::{generate_trait_impl_intransitive_with_item, is_selected}, }; // Assist: generate_from_impl_for_enum @@ -33,39 +34,72 @@ pub(crate) fn generate_from_impl_for_enum( let variants = selected_variants(ctx, &variant)?; let target = variant.syntax().text_range(); + let file_id = ctx.vfs_file_id(); acc.add( AssistId::generate("generate_from_impl_for_enum"), "Generate `From` impl for this enum variant(s)", target, |edit| { - let start_offset = variant.parent_enum().syntax().text_range().end(); - let from_impl = variants - .into_iter() - .map(|variant_info| { - let from_trait = format!("From<{}>", variant_info.ty); - let impl_code = generate_impl_code(variant_info); - generate_trait_impl_text_intransitive(&adt, &from_trait, &impl_code) - }) - .collect::<String>(); - edit.insert(start_offset, from_impl); + let make = SyntaxFactory::with_mappings(); + let indent = adt.indent_level(); + let mut elements = Vec::new(); + + for variant_info in variants { + let impl_ = build_from_impl(&make, &adt, variant_info).indent(indent); + elements.push(make.whitespace(&format!("\n\n{indent}")).into()); + elements.push(impl_.syntax().clone().into()); + } + + let mut editor = edit.make_editor(adt.syntax()); + editor.insert_all(Position::after(adt.syntax()), elements); + editor.add_mappings(make.finish_with_mappings()); + edit.add_file_edits(file_id, editor); }, ) } -fn generate_impl_code(VariantInfo { name, field_name, ty }: VariantInfo) -> String { - if let Some(field) = field_name { - format!( - r#" fn from({field}: {ty}) -> Self {{ - Self::{name} {{ {field} }} - }}"# - ) +fn build_from_impl(make: &SyntaxFactory, adt: &ast::Adt, variant_info: VariantInfo) -> ast::Impl { + let VariantInfo { name, field_name, ty } = variant_info; + let trait_ty = make.ty(&format!("From<{ty}>")); + let ret_ty = make.ret_type(make.ty_path(make.ident_path("Self")).into()); + + let (params, body_expr) = if let Some(field) = field_name { + let field_str = field.to_string(); + let param = make.param(make.ident_pat(false, false, make.name(&field_str)).into(), ty); + let field_item = make.record_expr_field(make.name_ref(&field_str), None); + let record = make.record_expr( + make.path_from_text(&format!("Self::{name}")), + make.record_expr_field_list([field_item]), + ); + (make.param_list(None, [param]), ast::Expr::from(record)) } else { - format!( - r#" fn from(v: {ty}) -> Self {{ - Self::{name}(v) - }}"# + let param = make.param(make.ident_pat(false, false, make.name("v")).into(), ty); + let call = make.expr_call( + make.expr_path(make.path_from_text(&format!("Self::{name}"))), + make.arg_list([make.expr_path(make.ident_path("v"))]), + ); + (make.param_list(None, [param]), ast::Expr::from(call)) + }; + + let from_fn = make + .fn_( + [], + None, + make.name("from"), + None, + None, + params, + make.block_expr([], Some(body_expr)), + Some(ret_ty), + false, + false, + false, + false, ) - } + .indent(1.into()); + + let body = make.assoc_item_list([ast::AssocItem::Fn(from_fn)]); + generate_trait_impl_intransitive_with_item(make, adt, trait_ty, body) } struct VariantInfo { diff --git a/crates/ide-assists/src/handlers/generate_function.rs b/crates/ide-assists/src/handlers/generate_function.rs index bd66c02b41..fbf6241e43 100644 --- a/crates/ide-assists/src/handlers/generate_function.rs +++ b/crates/ide-assists/src/handlers/generate_function.rs @@ -4,7 +4,6 @@ use hir::{ }; use ide_db::{ FileId, FxHashMap, FxHashSet, RootDatabase, SnippetCap, - assists::ExprFillDefaultMode, defs::{Definition, NameRefClass}, famous_defs::FamousDefs, helpers::is_editable_crate, @@ -24,7 +23,7 @@ use syntax::{ use crate::{ AssistContext, AssistId, Assists, - utils::{convert_reference_type, find_struct_impl}, + utils::{convert_reference_type, expr_fill_default, find_struct_impl}, }; // Assist: generate_function @@ -147,7 +146,16 @@ fn gen_method(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { return None; } - let (impl_, file) = get_adt_source(ctx, &adt, fn_name.text().as_str())?; + let enclosing_impl = ctx.find_node_at_offset::<ast::Impl>(); + let cursor_impl = enclosing_impl.filter(|impl_| { + ctx.sema.to_def(impl_).map_or(false, |def| def.self_ty(ctx.sema.db).as_adt() == Some(adt)) + }); + + let (impl_, file) = if let Some(impl_) = cursor_impl { + (Some(impl_), ctx.vfs_file_id()) + } else { + get_adt_source(ctx, &adt, fn_name.text().as_str())? + }; let target = get_method_target(ctx, &impl_, &adt)?; let function_builder = FunctionBuilder::from_method_call( @@ -277,11 +285,7 @@ impl FunctionBuilder { target_module, &mut necessary_generic_params, ); - let placeholder_expr = match ctx.config.expr_fill_default { - ExprFillDefaultMode::Todo => make::ext::expr_todo(), - ExprFillDefaultMode::Underscore => make::ext::expr_underscore(), - ExprFillDefaultMode::Default => make::ext::expr_todo(), - }; + let placeholder_expr = expr_fill_default(ctx.config); fn_body = make::block_expr(vec![], Some(placeholder_expr)); }; @@ -336,11 +340,7 @@ impl FunctionBuilder { let (generic_param_list, where_clause) = fn_generic_params(ctx, necessary_generic_params, &target)?; - let placeholder_expr = match ctx.config.expr_fill_default { - ExprFillDefaultMode::Todo => make::ext::expr_todo(), - ExprFillDefaultMode::Underscore => make::ext::expr_underscore(), - ExprFillDefaultMode::Default => make::ext::expr_todo(), - }; + let placeholder_expr = expr_fill_default(ctx.config); let fn_body = make::block_expr(vec![], Some(placeholder_expr)); Some(Self { @@ -456,11 +456,7 @@ fn make_fn_body_as_new_function( let adt_info = adt_info.as_ref()?; let path_self = make::ext::ident_path("Self"); - let placeholder_expr = match ctx.config.expr_fill_default { - ExprFillDefaultMode::Todo => make::ext::expr_todo(), - ExprFillDefaultMode::Underscore => make::ext::expr_underscore(), - ExprFillDefaultMode::Default => make::ext::expr_todo(), - }; + let placeholder_expr = expr_fill_default(ctx.config); let tail_expr = if let Some(strukt) = adt_info.adt.as_struct() { match strukt.kind(ctx.db()) { StructKind::Record => { @@ -1147,14 +1143,7 @@ fn fn_arg_type( if ty.is_reference() || ty.is_mutable_reference() { let famous_defs = &FamousDefs(&ctx.sema, ctx.sema.scope(fn_arg.syntax())?.krate()); convert_reference_type(ty.strip_references(), ctx.db(), famous_defs) - .map(|conversion| { - conversion - .convert_type( - ctx.db(), - target_module.krate(ctx.db()).to_display_target(ctx.db()), - ) - .to_string() - }) + .map(|conversion| conversion.convert_type(ctx.db(), target_module).to_string()) .or_else(|| ty.display_source_code(ctx.db(), target_module.into(), true).ok()) } else { ty.display_source_code(ctx.db(), target_module.into(), true).ok() @@ -3191,4 +3180,66 @@ fn main() { "#, ); } + + #[test] + fn regression_21288() { + check_assist( + generate_function, + r#" +//- minicore: copy +fn foo() { + $0bar(&|x| true) +} + "#, + r#" +fn foo() { + bar(&|x| true) +} + +fn bar(arg: impl Fn(_) -> bool) { + ${0:todo!()} +} + "#, + ); + } + #[test] + fn generate_method_uses_current_impl_block() { + check_assist( + generate_function, + r" +struct Foo; + +impl Foo { + fn new() -> Self { + Foo + } +} + +impl Foo { + fn method1(&self) { + self.method2$0(42) + } +} +", + r" +struct Foo; + +impl Foo { + fn new() -> Self { + Foo + } +} + +impl Foo { + fn method1(&self) { + self.method2(42) + } + + fn method2(&self, arg: i32) { + ${0:todo!()} + } +} +", + ) + } } diff --git a/crates/ide-assists/src/handlers/generate_getter_or_setter.rs b/crates/ide-assists/src/handlers/generate_getter_or_setter.rs index e42d0ed1b0..62ffd3d965 100644 --- a/crates/ide-assists/src/handlers/generate_getter_or_setter.rs +++ b/crates/ide-assists/src/handlers/generate_getter_or_setter.rs @@ -2,13 +2,16 @@ use ide_db::{famous_defs::FamousDefs, source_change::SourceChangeBuilder}; use stdx::{format_to, to_lower_snake_case}; use syntax::{ TextRange, - ast::{self, AstNode, HasName, HasVisibility, edit_in_place::Indent, make}, - ted, + ast::{ + self, AstNode, HasGenericParams, HasName, HasVisibility, edit::AstNodeEdit, + syntax_factory::SyntaxFactory, + }, + syntax_editor::Position, }; use crate::{ AssistContext, AssistId, Assists, GroupLabel, - utils::{convert_reference_type, find_struct_impl, generate_impl}, + utils::{convert_reference_type, find_struct_impl, is_selected}, }; // Assist: generate_setter @@ -215,35 +218,41 @@ fn generate_getter_from_info( ctx: &AssistContext<'_>, info: &AssistInfo, record_field_info: &RecordFieldInfo, + make: &SyntaxFactory, ) -> ast::Fn { let (ty, body) = if matches!(info.assist_type, AssistType::MutGet) { + let self_expr = make.expr_path(make.ident_path("self")); ( - make::ty_ref(record_field_info.field_ty.clone(), true), - make::expr_ref( - make::expr_field(make::ext::expr_self(), &record_field_info.field_name.text()), + make.ty_ref(record_field_info.field_ty.clone(), true), + make.expr_ref( + make.expr_field(self_expr, &record_field_info.field_name.text()).into(), true, ), ) } else { (|| { - let krate = ctx.sema.scope(record_field_info.field_ty.syntax())?.krate(); - let famous_defs = &FamousDefs(&ctx.sema, krate); + let module = ctx.sema.scope(record_field_info.field_ty.syntax())?.module(); + let famous_defs = &FamousDefs(&ctx.sema, module.krate(ctx.db())); ctx.sema .resolve_type(&record_field_info.field_ty) .and_then(|ty| convert_reference_type(ty, ctx.db(), famous_defs)) .map(|conversion| { cov_mark::hit!(convert_reference_type); ( - conversion.convert_type(ctx.db(), krate.to_display_target(ctx.db())), - conversion.getter(record_field_info.field_name.to_string()), + conversion.convert_type_with_factory(make, ctx.db(), module), + conversion.getter(make, record_field_info.field_name.to_string()), ) }) })() .unwrap_or_else(|| { ( - make::ty_ref(record_field_info.field_ty.clone(), false), - make::expr_ref( - make::expr_field(make::ext::expr_self(), &record_field_info.field_name.text()), + make.ty_ref(record_field_info.field_ty.clone(), false), + make.expr_ref( + make.expr_field( + make.expr_path(make.ident_path("self")), + &record_field_info.field_name.text(), + ) + .into(), false, ), ) @@ -251,18 +260,18 @@ fn generate_getter_from_info( }; let self_param = if matches!(info.assist_type, AssistType::MutGet) { - make::mut_self_param() + make.mut_self_param() } else { - make::self_param() + make.self_param() }; let strukt = &info.strukt; - let fn_name = make::name(&record_field_info.fn_name); - let params = make::param_list(Some(self_param), []); - let ret_type = Some(make::ret_type(ty)); - let body = make::block_expr([], Some(body)); + let fn_name = make.name(&record_field_info.fn_name); + let params = make.param_list(Some(self_param), []); + let ret_type = Some(make.ret_type(ty)); + let body = make.block_expr([], Some(body)); - make::fn_( + make.fn_( None, strukt.visibility(), fn_name, @@ -278,28 +287,32 @@ fn generate_getter_from_info( ) } -fn generate_setter_from_info(info: &AssistInfo, record_field_info: &RecordFieldInfo) -> ast::Fn { +fn generate_setter_from_info( + info: &AssistInfo, + record_field_info: &RecordFieldInfo, + make: &SyntaxFactory, +) -> ast::Fn { let strukt = &info.strukt; let field_name = &record_field_info.fn_name; - let fn_name = make::name(&format!("set_{field_name}")); + let fn_name = make.name(&format!("set_{field_name}")); let field_ty = &record_field_info.field_ty; // Make the param list // `(&mut self, $field_name: $field_ty)` let field_param = - make::param(make::ident_pat(false, false, make::name(field_name)).into(), field_ty.clone()); - let params = make::param_list(Some(make::mut_self_param()), [field_param]); + make.param(make.ident_pat(false, false, make.name(field_name)).into(), field_ty.clone()); + let params = make.param_list(Some(make.mut_self_param()), [field_param]); // Make the assignment body // `self.$field_name = $field_name` - let self_expr = make::ext::expr_self(); - let lhs = make::expr_field(self_expr, field_name); - let rhs = make::expr_path(make::ext::ident_path(field_name)); - let assign_stmt = make::expr_stmt(make::expr_assignment(lhs, rhs).into()); - let body = make::block_expr([assign_stmt.into()], None); + let self_expr = make.expr_path(make.ident_path("self")); + let lhs = make.expr_field(self_expr, field_name); + let rhs = make.expr_path(make.ident_path(field_name)); + let assign_stmt = make.expr_stmt(make.expr_assignment(lhs.into(), rhs).into()); + let body = make.block_expr([assign_stmt.into()], None); // Make the setter fn - make::fn_( + make.fn_( None, strukt.visibility(), fn_name, @@ -360,7 +373,7 @@ fn extract_and_parse_record_fields( let info_of_record_fields_in_selection = ele .fields() .filter_map(|record_field| { - if selection_range.contains_range(record_field.syntax().text_range()) { + if is_selected(&record_field, selection_range, false) { let record_field_info = parse_record_field(record_field, assist_type)?; field_names.push(record_field_info.fn_name.clone()); return Some(record_field_info); @@ -403,47 +416,69 @@ fn build_source_change( info_of_record_fields: Vec<RecordFieldInfo>, assist_info: AssistInfo, ) { - let record_fields_count = info_of_record_fields.len(); - - let impl_def = if let Some(impl_def) = &assist_info.impl_def { - // We have an existing impl to add to - builder.make_mut(impl_def.clone()) - } else { - // Generate a new impl to add the methods to - let impl_def = generate_impl(&ast::Adt::Struct(assist_info.strukt.clone())); + let syntax_factory = SyntaxFactory::without_mappings(); - // Insert it after the adt - let strukt = builder.make_mut(assist_info.strukt.clone()); - - ted::insert_all_raw( - ted::Position::after(strukt.syntax()), - vec![make::tokens::blank_line().into(), impl_def.syntax().clone().into()], - ); - - impl_def - }; - - let assoc_item_list = impl_def.get_or_create_assoc_item_list(); + let items: Vec<ast::AssocItem> = info_of_record_fields + .iter() + .map(|record_field_info| { + let method = match assist_info.assist_type { + AssistType::Set => { + generate_setter_from_info(&assist_info, record_field_info, &syntax_factory) + } + _ => { + generate_getter_from_info(ctx, &assist_info, record_field_info, &syntax_factory) + } + }; + let new_fn = method.clone_for_update(); + let new_fn = new_fn.indent(1.into()); + new_fn.into() + }) + .collect(); - for (i, record_field_info) in info_of_record_fields.iter().enumerate() { - // Make the new getter or setter fn - let new_fn = match assist_info.assist_type { - AssistType::Set => generate_setter_from_info(&assist_info, record_field_info), - _ => generate_getter_from_info(ctx, &assist_info, record_field_info), - } - .clone_for_update(); - new_fn.indent(1.into()); + if let Some(impl_def) = &assist_info.impl_def { + // We have an existing impl to add to + let mut editor = builder.make_editor(impl_def.syntax()); + impl_def.assoc_item_list().unwrap().add_items(&mut editor, items.clone()); - // Insert a tabstop only for last method we generate - if i == record_fields_count - 1 - && let Some(cap) = ctx.config.snippet_cap - && let Some(name) = new_fn.name() + if let Some(cap) = ctx.config.snippet_cap + && let Some(ast::AssocItem::Fn(fn_)) = items.last() + && let Some(name) = fn_.name() { - builder.add_tabstop_before(cap, name); + let tabstop = builder.make_tabstop_before(cap); + editor.add_annotation(name.syntax(), tabstop); } - assoc_item_list.add_item(new_fn.clone().into()); + builder.add_file_edits(ctx.vfs_file_id(), editor); + return; + } + let ty_params = assist_info.strukt.generic_param_list(); + let ty_args = ty_params.as_ref().map(|it| it.to_generic_args()); + let impl_def = syntax_factory.impl_( + None, + ty_params, + ty_args, + syntax_factory + .ty_path(syntax_factory.ident_path(&assist_info.strukt.name().unwrap().to_string())) + .into(), + None, + Some(syntax_factory.assoc_item_list(items)), + ); + let mut editor = builder.make_editor(assist_info.strukt.syntax()); + editor.insert_all( + Position::after(assist_info.strukt.syntax()), + vec![syntax_factory.whitespace("\n\n").into(), impl_def.syntax().clone().into()], + ); + + if let Some(cap) = ctx.config.snippet_cap + && let Some(assoc_list) = impl_def.assoc_item_list() + && let Some(ast::AssocItem::Fn(fn_)) = assoc_list.assoc_items().last() + && let Some(name) = fn_.name() + { + let tabstop = builder.make_tabstop_before(cap); + editor.add_annotation(name.syntax().clone(), tabstop); } + + builder.add_file_edits(ctx.vfs_file_id(), editor); } #[cfg(test)] @@ -909,6 +944,37 @@ impl Context { } #[test] + fn test_generate_multiple_getters_from_partial_selection() { + check_assist( + generate_getter, + r#" +struct Context { + data$0: Data, + count$0: usize, + other: usize, +} + "#, + r#" +struct Context { + data: Data, + count: usize, + other: usize, +} + +impl Context { + fn data(&self) -> &Data { + &self.data + } + + fn $0count(&self) -> &usize { + &self.count + } +} + "#, + ); + } + + #[test] fn test_generate_multiple_getters_from_selection_one_already_exists() { // As impl for one of the fields already exist, skip it check_assist_not_applicable( diff --git a/crates/ide-assists/src/handlers/generate_impl.rs b/crates/ide-assists/src/handlers/generate_impl.rs index 77eb8efc6f..2d1235792d 100644 --- a/crates/ide-assists/src/handlers/generate_impl.rs +++ b/crates/ide-assists/src/handlers/generate_impl.rs @@ -1,25 +1,37 @@ use syntax::{ - ast::{self, AstNode, HasGenericParams, HasName, edit_in_place::Indent, make}, + ast::{ + self, AstNode, HasGenericParams, HasName, edit::AstNodeEdit, syntax_factory::SyntaxFactory, + }, syntax_editor::{Position, SyntaxEditor}, }; use crate::{ AssistContext, AssistId, Assists, - utils::{self, DefaultMethods, IgnoreAssocItems}, + utils::{ + self, DefaultMethods, IgnoreAssocItems, generate_impl_with_factory, + generate_trait_impl_intransitive, + }, }; -fn insert_impl(editor: &mut SyntaxEditor, impl_: &ast::Impl, nominal: &impl Indent) { +fn insert_impl( + editor: &mut SyntaxEditor, + make: &SyntaxFactory, + impl_: &ast::Impl, + nominal: &impl AstNodeEdit, +) -> ast::Impl { let indent = nominal.indent_level(); - impl_.indent(indent); + let impl_ = impl_.indent(indent); editor.insert_all( Position::after(nominal.syntax()), vec![ // Add a blank line after the ADT, and indentation for the impl to match the ADT - make::tokens::whitespace(&format!("\n\n{indent}")).into(), + make.whitespace(&format!("\n\n{indent}")).into(), impl_.syntax().clone().into(), ], ); + + impl_ } // Assist: generate_impl @@ -53,10 +65,13 @@ pub(crate) fn generate_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> Optio format!("Generate impl for `{name}`"), target, |edit| { + let make = SyntaxFactory::with_mappings(); // Generate the impl - let impl_ = utils::generate_impl(&nominal); + let impl_ = generate_impl_with_factory(&make, &nominal); let mut editor = edit.make_editor(nominal.syntax()); + + let impl_ = insert_impl(&mut editor, &make, &impl_, &nominal); // Add a tabstop after the left curly brace if let Some(cap) = ctx.config.snippet_cap && let Some(l_curly) = impl_.assoc_item_list().and_then(|it| it.l_curly_token()) @@ -65,7 +80,7 @@ pub(crate) fn generate_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> Optio editor.add_annotation(l_curly, tabstop); } - insert_impl(&mut editor, &impl_, &nominal); + editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -102,10 +117,13 @@ pub(crate) fn generate_trait_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> format!("Generate trait impl for `{name}`"), target, |edit| { + let make = SyntaxFactory::with_mappings(); // Generate the impl - let impl_ = utils::generate_trait_impl_intransitive(&nominal, make::ty_placeholder()); + let impl_ = generate_trait_impl_intransitive(&make, &nominal, make.ty_placeholder()); let mut editor = edit.make_editor(nominal.syntax()); + + let impl_ = insert_impl(&mut editor, &make, &impl_, &nominal); // Make the trait type a placeholder snippet if let Some(cap) = ctx.config.snippet_cap { if let Some(trait_) = impl_.trait_() { @@ -119,7 +137,7 @@ pub(crate) fn generate_trait_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> } } - insert_impl(&mut editor, &impl_, &nominal); + editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -158,9 +176,10 @@ pub(crate) fn generate_impl_trait(acc: &mut Assists, ctx: &AssistContext<'_>) -> format!("Generate `{name}` impl for type"), target, |edit| { + let make = SyntaxFactory::with_mappings(); let mut editor = edit.make_editor(trait_.syntax()); - let holder_arg = ast::GenericArg::TypeArg(make::type_arg(make::ty_placeholder())); + let holder_arg = ast::GenericArg::TypeArg(make.type_arg(make.ty_placeholder())); let missing_items = utils::filter_assoc_items( &ctx.sema, &hir_trait.items(ctx.db()), @@ -169,11 +188,11 @@ pub(crate) fn generate_impl_trait(acc: &mut Assists, ctx: &AssistContext<'_>) -> ); let trait_gen_args = trait_.generic_param_list().map(|list| { - make::generic_arg_list(list.generic_params().map(|_| holder_arg.clone())) + make.generic_arg_list(list.generic_params().map(|_| holder_arg.clone()), false) }); let make_impl_ = |body| { - make::impl_trait( + make.impl_trait( None, trait_.unsafe_token().is_some(), None, @@ -181,13 +200,12 @@ pub(crate) fn generate_impl_trait(acc: &mut Assists, ctx: &AssistContext<'_>) -> None, None, false, - make::ty(&name.text()), - make::ty_placeholder(), + make.ty(&name.text()), + make.ty_placeholder(), None, None, body, ) - .clone_for_update() }; let impl_ = if missing_items.is_empty() { @@ -202,10 +220,13 @@ pub(crate) fn generate_impl_trait(acc: &mut Assists, ctx: &AssistContext<'_>) -> &impl_, &target_scope, ); - let assoc_item_list = make::assoc_item_list(Some(assoc_items)); + let assoc_item_list = make.assoc_item_list(assoc_items); make_impl_(Some(assoc_item_list)) }; + let impl_ = insert_impl(&mut editor, &make, &impl_, &trait_); + editor.add_mappings(make.finish_with_mappings()); + if let Some(cap) = ctx.config.snippet_cap { if let Some(generics) = impl_.trait_().and_then(|it| it.generic_arg_list()) { for generic in generics.generic_args() { @@ -232,7 +253,6 @@ pub(crate) fn generate_impl_trait(acc: &mut Assists, ctx: &AssistContext<'_>) -> } } - insert_impl(&mut editor, &impl_, &trait_); edit.add_file_edits(ctx.vfs_file_id(), editor); }, ) diff --git a/crates/ide-assists/src/handlers/generate_mut_trait_impl.rs b/crates/ide-assists/src/handlers/generate_mut_trait_impl.rs index 53f6f4883f..3a62a8853e 100644 --- a/crates/ide-assists/src/handlers/generate_mut_trait_impl.rs +++ b/crates/ide-assists/src/handlers/generate_mut_trait_impl.rs @@ -1,7 +1,7 @@ use ide_db::{famous_defs::FamousDefs, traits::resolve_target_trait}; use syntax::{ AstNode, SyntaxElement, SyntaxNode, T, - ast::{self, edit::AstNodeEdit, edit_in_place::Indent, syntax_factory::SyntaxFactory}, + ast::{self, edit::AstNodeEdit, syntax_factory::SyntaxFactory}, syntax_editor::{Element, Position, SyntaxEditor}, }; @@ -46,7 +46,7 @@ use crate::{AssistContext, AssistId, Assists}; // ``` pub(crate) fn generate_mut_trait_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let impl_def = ctx.find_node_at_offset::<ast::Impl>()?; - let indent = Indent::indent_level(&impl_def); + let indent = impl_def.indent_level(); let ast::Type::PathType(path) = impl_def.trait_()? else { return None; @@ -78,7 +78,7 @@ pub(crate) fn generate_mut_trait_impl(acc: &mut Assists, ctx: &AssistContext<'_> let new_impl = ast::Impl::cast(new_root.clone()).unwrap(); - Indent::indent(&new_impl, indent); + let new_impl = new_impl.indent(indent); let mut editor = edit.make_editor(impl_def.syntax()); editor.insert_all( diff --git a/crates/ide-assists/src/handlers/generate_new.rs b/crates/ide-assists/src/handlers/generate_new.rs index 4b923ab556..301d13c095 100644 --- a/crates/ide-assists/src/handlers/generate_new.rs +++ b/crates/ide-assists/src/handlers/generate_new.rs @@ -3,7 +3,10 @@ use ide_db::{ use_trivial_constructor::use_trivial_constructor, }; use syntax::{ - ast::{self, AstNode, HasName, HasVisibility, StructKind, edit_in_place::Indent, make}, + ast::{ + self, AstNode, HasName, HasVisibility, StructKind, edit::AstNodeEdit, + syntax_factory::SyntaxFactory, + }, syntax_editor::Position, }; @@ -36,6 +39,7 @@ use crate::{ pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let strukt = ctx.find_node_at_offset::<ast::Struct>()?; + let make = SyntaxFactory::without_mappings(); let field_list = match strukt.kind() { StructKind::Record(named) => { named.fields().filter_map(|f| Some((f.name()?, f.ty()?))).collect::<Vec<_>>() @@ -55,7 +59,7 @@ pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option Some(name) => name, None => name_generator.suggest_name(&format!("_{i}")), }; - Some((make::name(name.as_str()), f.ty()?)) + Some((make.name(name.as_str()), f.ty()?)) }) .collect::<Vec<_>>() } @@ -70,6 +74,7 @@ pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option let target = strukt.syntax().text_range(); acc.add(AssistId::generate("generate_new"), "Generate `new`", target, |builder| { + let make = SyntaxFactory::with_mappings(); let trivial_constructors = field_list .iter() .map(|(name, ty)| { @@ -95,102 +100,99 @@ pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option edition, )?; - Some((make::name_ref(&name.text()), Some(expr))) + Some((make.name_ref(&name.text()), Some(expr))) }) .collect::<Vec<_>>(); let params = field_list.iter().enumerate().filter_map(|(i, (name, ty))| { if trivial_constructors[i].is_none() { - Some(make::param(make::ident_pat(false, false, name.clone()).into(), ty.clone())) + Some(make.param(make.ident_pat(false, false, name.clone()).into(), ty.clone())) } else { None } }); - let params = make::param_list(None, params); + let params = make.param_list(None, params); let fields = field_list.iter().enumerate().map(|(i, (name, _))| { if let Some(constructor) = trivial_constructors[i].clone() { constructor } else { - (make::name_ref(&name.text()), None) + (make.name_ref(&name.text()), None) } }); let tail_expr: ast::Expr = match strukt.kind() { StructKind::Record(_) => { - let fields = fields.map(|(name, expr)| make::record_expr_field(name, expr)); - let fields = make::record_expr_field_list(fields); - make::record_expr(make::ext::ident_path("Self"), fields).into() + let fields = fields.map(|(name, expr)| make.record_expr_field(name, expr)); + let fields = make.record_expr_field_list(fields); + make.record_expr(make.ident_path("Self"), fields).into() } StructKind::Tuple(_) => { let args = fields.map(|(arg, expr)| { - let arg = || make::expr_path(make::path_unqualified(make::path_segment(arg))); + let arg = || make.expr_path(make.path_unqualified(make.path_segment(arg))); expr.unwrap_or_else(arg) }); - let arg_list = make::arg_list(args); - make::expr_call(make::expr_path(make::ext::ident_path("Self")), arg_list).into() + let arg_list = make.arg_list(args); + make.expr_call(make.expr_path(make.ident_path("Self")), arg_list).into() } StructKind::Unit => unreachable!(), }; - let body = make::block_expr(None, tail_expr.into()); - - let ret_type = make::ret_type(make::ty_path(make::ext::ident_path("Self"))); - - let fn_ = make::fn_( - None, - strukt.visibility(), - make::name("new"), - None, - None, - params, - body, - Some(ret_type), - false, - false, - false, - false, - ) - .clone_for_update(); - fn_.indent(1.into()); + let body = make.block_expr(None, tail_expr.into()); + + let ret_type = make.ret_type(make.ty_path(make.ident_path("Self")).into()); + + let fn_ = make + .fn_( + [], + strukt.visibility(), + make.name("new"), + None, + None, + params, + body, + Some(ret_type), + false, + false, + false, + false, + ) + .indent(1.into()); let mut editor = builder.make_editor(strukt.syntax()); // Get the node for set annotation let contain_fn = if let Some(impl_def) = impl_def { - fn_.indent(impl_def.indent_level()); + let fn_ = fn_.indent(impl_def.indent_level()); if let Some(l_curly) = impl_def.assoc_item_list().and_then(|list| list.l_curly_token()) { editor.insert_all( Position::after(l_curly), vec![ - make::tokens::whitespace(&format!("\n{}", impl_def.indent_level() + 1)) - .into(), + make.whitespace(&format!("\n{}", impl_def.indent_level() + 1)).into(), fn_.syntax().clone().into(), - make::tokens::whitespace("\n").into(), + make.whitespace("\n").into(), ], ); fn_.syntax().clone() } else { - let items = vec![ast::AssocItem::Fn(fn_)]; - let list = make::assoc_item_list(Some(items)); + let list = make.assoc_item_list([ast::AssocItem::Fn(fn_)]); editor.insert(Position::after(impl_def.syntax()), list.syntax()); list.syntax().clone() } } else { // Generate a new impl to add the method to let indent_level = strukt.indent_level(); - let body = vec![ast::AssocItem::Fn(fn_)]; - let list = make::assoc_item_list(Some(body)); - let impl_def = generate_impl_with_item(&ast::Adt::Struct(strukt.clone()), Some(list)); - - impl_def.indent(strukt.indent_level()); + let list = make.assoc_item_list([ast::AssocItem::Fn(fn_)]); + let impl_def = + generate_impl_with_item(&make, &ast::Adt::Struct(strukt.clone()), Some(list)) + .indent(strukt.indent_level()); // Insert it after the adt editor.insert_all( Position::after(strukt.syntax()), vec![ - make::tokens::whitespace(&format!("\n\n{indent_level}")).into(), + make.whitespace(&format!("\n\n{indent_level}")).into(), impl_def.syntax().clone().into(), ], ); @@ -234,6 +236,7 @@ pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option } } + editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }) } diff --git a/crates/ide-assists/src/handlers/generate_trait_from_impl.rs b/crates/ide-assists/src/handlers/generate_trait_from_impl.rs index 56500cf068..1286abe356 100644 --- a/crates/ide-assists/src/handlers/generate_trait_from_impl.rs +++ b/crates/ide-assists/src/handlers/generate_trait_from_impl.rs @@ -1,8 +1,10 @@ use crate::assist_context::{AssistContext, Assists}; use ide_db::assists::AssistId; use syntax::{ - AstNode, SyntaxKind, T, - ast::{self, HasGenericParams, HasName, HasVisibility, edit_in_place::Indent, make}, + AstNode, AstToken, SyntaxKind, T, + ast::{ + self, HasDocComments, HasGenericParams, HasName, HasVisibility, edit::AstNodeEdit, make, + }, syntax_editor::{Position, SyntaxEditor}, }; @@ -45,7 +47,7 @@ use syntax::{ // }; // } // -// trait ${0:NewTrait}<const N: usize> { +// trait ${0:Create}<const N: usize> { // // Used as an associated constant. // const CONST_ASSOC: usize = N * 4; // @@ -54,7 +56,7 @@ use syntax::{ // const_maker! {i32, 7} // } // -// impl<const N: usize> ${0:NewTrait}<N> for Foo<N> { +// impl<const N: usize> ${0:Create}<N> for Foo<N> { // // Used as an associated constant. // const CONST_ASSOC: usize = N * 4; // @@ -107,7 +109,7 @@ pub(crate) fn generate_trait_from_impl(acc: &mut Assists, ctx: &AssistContext<'_ }; let trait_ast = make::trait_( false, - "NewTrait", + &trait_name(&impl_assoc_items).text(), impl_ast.generic_param_list(), impl_ast.where_clause(), trait_items, @@ -133,6 +135,7 @@ pub(crate) fn generate_trait_from_impl(acc: &mut Assists, ctx: &AssistContext<'_ let mut editor = builder.make_editor(impl_ast.syntax()); impl_assoc_items.assoc_items().for_each(|item| { remove_items_visibility(&mut editor, &item); + remove_doc_comments(&mut editor, &item); }); editor.insert_all(Position::before(impl_name.syntax()), elements); @@ -160,6 +163,18 @@ pub(crate) fn generate_trait_from_impl(acc: &mut Assists, ctx: &AssistContext<'_ Some(()) } +fn trait_name(items: &ast::AssocItemList) -> ast::Name { + let mut fn_names = items + .assoc_items() + .filter_map(|x| if let ast::AssocItem::Fn(f) = x { f.name() } else { None }); + fn_names + .next() + .and_then(|name| { + fn_names.next().is_none().then(|| make::name(&stdx::to_camel_case(&name.text()))) + }) + .unwrap_or_else(|| make::name("NewTrait")) +} + /// `E0449` Trait items always share the visibility of their trait fn remove_items_visibility(editor: &mut SyntaxEditor, item: &ast::AssocItem) { if let Some(has_vis) = ast::AnyHasVisibility::cast(item.syntax().clone()) { @@ -175,6 +190,17 @@ fn remove_items_visibility(editor: &mut SyntaxEditor, item: &ast::AssocItem) { } } +fn remove_doc_comments(editor: &mut SyntaxEditor, item: &ast::AssocItem) { + for doc in item.doc_comments() { + if let Some(next) = doc.syntax().next_token() + && next.kind() == SyntaxKind::WHITESPACE + { + editor.delete(next); + } + editor.delete(doc.syntax()); + } +} + fn strip_body(editor: &mut SyntaxEditor, item: &ast::AssocItem) { if let ast::AssocItem::Fn(f) = item && let Some(body) = f.body() @@ -226,11 +252,47 @@ impl F$0oo { r#" struct Foo(f64); -trait NewTrait { +trait Add { fn add(&mut self, x: f64); } -impl NewTrait for Foo { +impl Add for Foo { + fn add(&mut self, x: f64) { + self.0 += x; + } +}"#, + ) + } + + #[test] + fn test_remove_doc_comments() { + check_assist_no_snippet_cap( + generate_trait_from_impl, + r#" +struct Foo(f64); + +impl F$0oo { + /// Add `x` + /// + /// # Examples + #[cfg(true)] + fn add(&mut self, x: f64) { + self.0 += x; + } +}"#, + r#" +struct Foo(f64); + +trait Add { + /// Add `x` + /// + /// # Examples + #[cfg(true)] + fn add(&mut self, x: f64); +} + +impl Add for Foo { + #[cfg(true)] fn add(&mut self, x: f64) { self.0 += x; } @@ -339,11 +401,11 @@ impl F$0oo { r#" struct Foo; -trait NewTrait { +trait AFunc { fn a_func() -> Option<()>; } -impl NewTrait for Foo { +impl AFunc for Foo { fn a_func() -> Option<()> { Some(()) } @@ -373,11 +435,11 @@ mod a { }"#, r#" mod a { - trait NewTrait { + trait Foo { fn foo(); } - impl NewTrait for S { + impl Foo for S { fn foo() {} } }"#, @@ -385,6 +447,28 @@ mod a { } #[test] + fn test_multi_fn_impl_not_suggest_trait_name() { + check_assist_no_snippet_cap( + generate_trait_from_impl, + r#" +impl S$0 { + fn foo() {} + fn bar() {} +}"#, + r#" +trait NewTrait { + fn foo(); + fn bar(); +} + +impl NewTrait for S { + fn foo() {} + fn bar() {} +}"#, + ) + } + + #[test] fn test_snippet_cap_is_some() { check_assist( generate_trait_from_impl, diff --git a/crates/ide-assists/src/handlers/inline_call.rs b/crates/ide-assists/src/handlers/inline_call.rs index fa4f2a78c8..21f2249a19 100644 --- a/crates/ide-assists/src/handlers/inline_call.rs +++ b/crates/ide-assists/src/handlers/inline_call.rs @@ -403,6 +403,12 @@ fn inline( .find(|tok| tok.kind() == SyntaxKind::SELF_TYPE_KW) { let replace_with = t.clone_subtree().syntax().clone_for_update(); + if !is_in_type_path(&self_tok) + && let Some(ty) = ast::Type::cast(replace_with.clone()) + && let Some(generic_arg_list) = ty.generic_arg_list() + { + ted::remove(generic_arg_list.syntax()); + } ted::replace(self_tok, replace_with); } } @@ -588,6 +594,17 @@ fn inline( } } +fn is_in_type_path(self_tok: &syntax::SyntaxToken) -> bool { + self_tok + .parent_ancestors() + .skip_while(|it| !ast::Path::can_cast(it.kind())) + .map_while(ast::Path::cast) + .last() + .and_then(|it| it.syntax().parent()) + .and_then(ast::PathType::cast) + .is_some() +} + fn path_expr_as_record_field(usage: &PathExpr) -> Option<ast::RecordExprField> { let path = usage.path()?; let name_ref = path.as_single_name_ref()?; @@ -1695,6 +1712,41 @@ fn main() { } #[test] + fn inline_trait_method_call_with_lifetimes() { + check_assist( + inline_call, + r#" +trait Trait { + fn f() -> Self; +} +struct Foo<'a>(&'a ()); +impl<'a> Trait for Foo<'a> { + fn f() -> Self { Self(&()) } +} +impl Foo<'_> { + fn new() -> Self { + Self::$0f() + } +} +"#, + r#" +trait Trait { + fn f() -> Self; +} +struct Foo<'a>(&'a ()); +impl<'a> Trait for Foo<'a> { + fn f() -> Self { Self(&()) } +} +impl Foo<'_> { + fn new() -> Self { + Foo(&()) + } +} +"#, + ) + } + + #[test] fn method_by_reborrow() { check_assist( inline_call, diff --git a/crates/ide-assists/src/handlers/inline_local_variable.rs b/crates/ide-assists/src/handlers/inline_local_variable.rs index 5d4bdc6ec7..f55ef4229e 100644 --- a/crates/ide-assists/src/handlers/inline_local_variable.rs +++ b/crates/ide-assists/src/handlers/inline_local_variable.rs @@ -1,3 +1,4 @@ +use either::{Either, for_both}; use hir::{PathResolution, Semantics}; use ide_db::{ EditionedFileId, RootDatabase, @@ -5,8 +6,9 @@ use ide_db::{ search::{FileReference, FileReferenceNode, UsageSearchResult}, }; use syntax::{ - SyntaxElement, TextRange, + Direction, TextRange, ast::{self, AstNode, AstToken, HasName, syntax_factory::SyntaxFactory}, + syntax_editor::{Element, SyntaxEditor}, }; use crate::{ @@ -36,12 +38,15 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext<'_>) let InlineData { let_stmt, delete_let, references, target } = if let Some(path_expr) = ctx.find_node_at_offset::<ast::PathExpr>() { inline_usage(&ctx.sema, path_expr, range, file_id) - } else if let Some(let_stmt) = ctx.find_node_at_offset::<ast::LetStmt>() { + } else if let Some(let_stmt) = ctx.find_node_at_offset() { inline_let(&ctx.sema, let_stmt, range, file_id) } else { None }?; - let initializer_expr = let_stmt.initializer()?; + let initializer_expr = match &let_stmt { + either::Either::Left(it) => it.initializer()?, + either::Either::Right(it) => it.expr()?, + }; let wrap_in_parens = references .into_iter() @@ -81,13 +86,15 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext<'_>) let mut editor = builder.make_editor(&target); if delete_let { editor.delete(let_stmt.syntax()); - if let Some(whitespace) = let_stmt - .syntax() - .next_sibling_or_token() - .and_then(SyntaxElement::into_token) - .and_then(ast::Whitespace::cast) + + if let Some(bin_expr) = let_stmt.syntax().parent().and_then(ast::BinExpr::cast) + && let Some(op_token) = bin_expr.op_token() { - editor.delete(whitespace.syntax()); + editor.delete(&op_token); + remove_whitespace(op_token, Direction::Prev, &mut editor); + remove_whitespace(let_stmt.syntax(), Direction::Prev, &mut editor); + } else { + remove_whitespace(let_stmt.syntax(), Direction::Next, &mut editor); } } @@ -116,7 +123,7 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext<'_>) } struct InlineData { - let_stmt: ast::LetStmt, + let_stmt: Either<ast::LetStmt, ast::LetExpr>, delete_let: bool, target: ast::NameOrNameRef, references: Vec<FileReference>, @@ -124,11 +131,11 @@ struct InlineData { fn inline_let( sema: &Semantics<'_, RootDatabase>, - let_stmt: ast::LetStmt, + let_stmt: Either<ast::LetStmt, ast::LetExpr>, range: TextRange, file_id: EditionedFileId, ) -> Option<InlineData> { - let bind_pat = match let_stmt.pat()? { + let bind_pat = match for_both!(&let_stmt, it => it.pat())? { ast::Pat::IdentPat(pat) => pat, _ => return None, }; @@ -187,7 +194,7 @@ fn inline_usage( let bind_pat = source.as_ident_pat()?; - let let_stmt = ast::LetStmt::cast(bind_pat.syntax().parent()?)?; + let let_stmt = AstNode::cast(bind_pat.syntax().parent()?)?; let UsageSearchResult { mut references } = Definition::Local(local).usages(sema).all(); let mut references = references.remove(&file_id)?; @@ -197,6 +204,23 @@ fn inline_usage( Some(InlineData { let_stmt, delete_let, target: ast::NameOrNameRef::NameRef(name), references }) } +fn remove_whitespace(elem: impl Element, dir: Direction, editor: &mut SyntaxEditor) { + let token = match elem.syntax_element() { + syntax::NodeOrToken::Node(node) => match dir { + Direction::Next => node.last_token(), + Direction::Prev => node.first_token(), + }, + syntax::NodeOrToken::Token(t) => Some(t), + }; + let next_token = match dir { + Direction::Next => token.and_then(|it| it.next_token()), + Direction::Prev => token.and_then(|it| it.prev_token()), + }; + if let Some(whitespace) = next_token.and_then(ast::Whitespace::cast) { + editor.delete(whitespace.syntax()); + } +} + #[cfg(test)] mod tests { use crate::tests::{check_assist, check_assist_not_applicable}; @@ -405,6 +429,38 @@ fn foo() { } #[test] + fn test_inline_let_expr() { + check_assist( + inline_local_variable, + r" +fn bar(a: usize) {} +fn foo() { + if let a$0 = 1 + && true + { + a + 1; + if a > 10 {} + while a > 10 {} + let b = a * 10; + bar(a); + } +}", + r" +fn bar(a: usize) {} +fn foo() { + if true + { + 1 + 1; + if 1 > 10 {} + while 1 > 10 {} + let b = 1 * 10; + bar(1); + } +}", + ); + } + + #[test] fn test_not_inline_mut_variable() { cov_mark::check!(test_not_inline_mut_variable); check_assist_not_applicable( @@ -817,6 +873,70 @@ fn f() { } #[test] + fn let_expr_works_on_local_usage() { + check_assist( + inline_local_variable, + r#" +fn f() { + if let xyz = 0 + && true + { + xyz$0; + } +} +"#, + r#" +fn f() { + if true + { + 0; + } +} +"#, + ); + + check_assist( + inline_local_variable, + r#" +fn f() { + if let xyz = true + && xyz$0 + { + } +} +"#, + r#" +fn f() { + if true + { + } +} +"#, + ); + + check_assist( + inline_local_variable, + r#" +fn f() { + if true + && let xyz = 0 + { + xyz$0; + } +} +"#, + r#" +fn f() { + if true + { + 0; + } +} +"#, + ); + } + + #[test] fn does_not_remove_let_when_multiple_usages() { check_assist( inline_local_variable, diff --git a/crates/ide-assists/src/handlers/inline_type_alias.rs b/crates/ide-assists/src/handlers/inline_type_alias.rs index c7a48f3261..f3ebe61078 100644 --- a/crates/ide-assists/src/handlers/inline_type_alias.rs +++ b/crates/ide-assists/src/handlers/inline_type_alias.rs @@ -12,7 +12,7 @@ use itertools::Itertools; use syntax::ast::syntax_factory::SyntaxFactory; use syntax::syntax_editor::SyntaxEditor; use syntax::{ - AstNode, NodeOrToken, SyntaxNode, + AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T, ast::{self, HasGenericParams, HasName}, }; @@ -322,12 +322,42 @@ fn create_replacement( if let Some(old_lifetime) = ast::Lifetime::cast(syntax.clone()) { if let Some(new_lifetime) = lifetime_map.0.get(&old_lifetime.to_string()) { if new_lifetime.text() == "'_" { + // Check if this lifetime is inside a LifetimeArg (in angle brackets) + if let Some(lifetime_arg) = + old_lifetime.syntax().parent().and_then(ast::LifetimeArg::cast) + { + // Remove LifetimeArg and associated comma/whitespace + let lifetime_arg_syntax = lifetime_arg.syntax(); + removals.push(NodeOrToken::Node(lifetime_arg_syntax.clone())); + + // Remove comma and whitespace (look forward then backward) + let comma_and_ws: Vec<_> = lifetime_arg_syntax + .siblings_with_tokens(syntax::Direction::Next) + .skip(1) + .take_while(|it| it.as_token().is_some()) + .take_while_inclusive(|it| it.kind() == T![,]) + .collect(); + + if comma_and_ws.iter().any(|it| it.kind() == T![,]) { + removals.extend(comma_and_ws); + } else { + // No comma after, try before + let comma_and_ws: Vec<_> = lifetime_arg_syntax + .siblings_with_tokens(syntax::Direction::Prev) + .skip(1) + .take_while(|it| it.as_token().is_some()) + .take_while_inclusive(|it| it.kind() == T![,]) + .collect(); + removals.extend(comma_and_ws); + } + continue; + } removals.push(NodeOrToken::Node(syntax.clone())); - - if let Some(ws) = syntax.next_sibling_or_token() { - removals.push(ws.clone()); + if let Some(ws) = syntax.next_sibling_or_token() + && ws.kind() == SyntaxKind::WHITESPACE + { + removals.push(ws); } - continue; } @@ -349,6 +379,34 @@ fn create_replacement( } } + // Deduplicate removals to avoid intersecting changes + removals.sort_by_key(|n| n.text_range().start()); + removals.dedup(); + + // Remove GenericArgList entirely if all its args are being removed (avoids empty angle brackets) + let generic_arg_lists_to_check: Vec<_> = + updated_concrete_type.descendants().filter_map(ast::GenericArgList::cast).collect(); + + for generic_arg_list in generic_arg_lists_to_check { + let will_be_empty = generic_arg_list.generic_args().all(|arg| match arg { + ast::GenericArg::LifetimeArg(lt_arg) => removals.iter().any(|removal| { + if let NodeOrToken::Node(node) = removal { node == lt_arg.syntax() } else { false } + }), + _ => false, + }); + + if will_be_empty && generic_arg_list.generic_args().next().is_some() { + removals.retain(|removal| { + if let NodeOrToken::Node(node) = removal { + !node.ancestors().any(|anc| anc == *generic_arg_list.syntax()) + } else { + true + } + }); + removals.push(NodeOrToken::Node(generic_arg_list.syntax().clone())); + } + } + for (old, new) in replacements { editor.replace(old, new); } @@ -946,6 +1004,48 @@ trait Tr { ); } + #[test] + fn inline_types_with_lifetime() { + check_assist( + inline_type_alias_uses, + r#" +struct A<'a, 'b>(pub &'a mut &'b mut ()); + +type $0T<'a, 'b> = A<'a, 'b>; + +fn foo(_: T) {} +"#, + r#" +struct A<'a, 'b>(pub &'a mut &'b mut ()); + + + +fn foo(_: A) {} +"#, + ); + } + + #[test] + fn mixed_lifetime_and_type_args() { + check_assist( + inline_type_alias, + r#" +type Foo<'a, T> = Bar<'a, T>; +struct Bar<'a, T>(&'a T); +fn main() { + let a: $0Foo<u32>; +} +"#, + r#" +type Foo<'a, T> = Bar<'a, T>; +struct Bar<'a, T>(&'a T); +fn main() { + let a: Bar<u32>; +} +"#, + ); + } + mod inline_type_alias_uses { use crate::{handlers::inline_type_alias::inline_type_alias_uses, tests::check_assist}; diff --git a/crates/ide-assists/src/handlers/introduce_named_lifetime.rs b/crates/ide-assists/src/handlers/introduce_named_lifetime.rs index 264e3767a2..854e9561d2 100644 --- a/crates/ide-assists/src/handlers/introduce_named_lifetime.rs +++ b/crates/ide-assists/src/handlers/introduce_named_lifetime.rs @@ -1,11 +1,12 @@ -use ide_db::FxHashSet; +use ide_db::{FileId, FxHashSet}; use syntax::{ - AstNode, TextRange, - ast::{self, HasGenericParams, edit_in_place::GenericParamsOwnerEdit, make}, - ted::{self, Position}, + AstNode, SmolStr, T, TextRange, ToSmolStr, + ast::{self, HasGenericParams, HasName, syntax_factory::SyntaxFactory}, + format_smolstr, + syntax_editor::{Element, Position, SyntaxEditor}, }; -use crate::{AssistContext, AssistId, Assists, assist_context::SourceChangeBuilder}; +use crate::{AssistContext, AssistId, Assists}; static ASSIST_NAME: &str = "introduce_named_lifetime"; static ASSIST_LABEL: &str = "Introduce named lifetime"; @@ -38,100 +39,108 @@ pub(crate) fn introduce_named_lifetime(acc: &mut Assists, ctx: &AssistContext<'_ // FIXME: should also add support for the case fun(f: &Foo) -> &$0Foo let lifetime = ctx.find_node_at_offset::<ast::Lifetime>().filter(|lifetime| lifetime.text() == "'_")?; + let file_id = ctx.vfs_file_id(); let lifetime_loc = lifetime.lifetime_ident_token()?.text_range(); if let Some(fn_def) = lifetime.syntax().ancestors().find_map(ast::Fn::cast) { - generate_fn_def_assist(acc, fn_def, lifetime_loc, lifetime) + generate_fn_def_assist(acc, fn_def, lifetime_loc, lifetime, file_id) } else if let Some(impl_def) = lifetime.syntax().ancestors().find_map(ast::Impl::cast) { - generate_impl_def_assist(acc, impl_def, lifetime_loc, lifetime) + generate_impl_def_assist(acc, impl_def, lifetime_loc, lifetime, file_id) } else { None } } -/// Generate the assist for the fn def case +/// Given a type parameter list, generate a unique lifetime parameter name +/// which is not in the list +fn generate_unique_lifetime_param_name( + existing_params: Option<ast::GenericParamList>, +) -> Option<SmolStr> { + let used_lifetime_param: FxHashSet<SmolStr> = existing_params + .iter() + .flat_map(|params| params.lifetime_params()) + .map(|p| p.syntax().text().to_smolstr()) + .collect(); + ('a'..='z').map(|c| format_smolstr!("'{c}")).find(|lt| !used_lifetime_param.contains(lt)) +} + fn generate_fn_def_assist( acc: &mut Assists, fn_def: ast::Fn, lifetime_loc: TextRange, lifetime: ast::Lifetime, + file_id: FileId, ) -> Option<()> { - let param_list: ast::ParamList = fn_def.param_list()?; - let new_lifetime_param = generate_unique_lifetime_param_name(fn_def.generic_param_list())?; + let param_list = fn_def.param_list()?; + let new_lifetime_name = generate_unique_lifetime_param_name(fn_def.generic_param_list())?; let self_param = - // use the self if it's a reference and has no explicit lifetime param_list.self_param().filter(|p| p.lifetime().is_none() && p.amp_token().is_some()); - // compute the location which implicitly has the same lifetime as the anonymous lifetime + let loc_needing_lifetime = if let Some(self_param) = self_param { - // if we have a self reference, use that Some(NeedsLifetime::SelfParam(self_param)) } else { - // otherwise, if there's a single reference parameter without a named lifetime, use that - let fn_params_without_lifetime: Vec<_> = param_list + let unnamed_refs: Vec<_> = param_list .params() .filter_map(|param| match param.ty() { - Some(ast::Type::RefType(ascribed_type)) if ascribed_type.lifetime().is_none() => { - Some(NeedsLifetime::RefType(ascribed_type)) + Some(ast::Type::RefType(ref_type)) if ref_type.lifetime().is_none() => { + Some(NeedsLifetime::RefType(ref_type)) } _ => None, }) .collect(); - match fn_params_without_lifetime.len() { - 1 => Some(fn_params_without_lifetime.into_iter().next()?), + + match unnamed_refs.len() { + 1 => Some(unnamed_refs.into_iter().next()?), 0 => None, - // multiple unnamed is invalid. assist is not applicable _ => return None, } }; - acc.add(AssistId::refactor(ASSIST_NAME), ASSIST_LABEL, lifetime_loc, |builder| { - let fn_def = builder.make_mut(fn_def); - let lifetime = builder.make_mut(lifetime); - let loc_needing_lifetime = - loc_needing_lifetime.and_then(|it| it.make_mut(builder).to_position()); - - fn_def.get_or_create_generic_param_list().add_generic_param( - make::lifetime_param(new_lifetime_param.clone()).clone_for_update().into(), - ); - ted::replace(lifetime.syntax(), new_lifetime_param.clone_for_update().syntax()); - if let Some(position) = loc_needing_lifetime { - ted::insert(position, new_lifetime_param.clone_for_update().syntax()); + + acc.add(AssistId::refactor(ASSIST_NAME), ASSIST_LABEL, lifetime_loc, |edit| { + let root = fn_def.syntax().ancestors().last().unwrap().clone(); + let mut editor = SyntaxEditor::new(root); + let factory = SyntaxFactory::with_mappings(); + + if let Some(generic_list) = fn_def.generic_param_list() { + insert_lifetime_param(&mut editor, &factory, &generic_list, &new_lifetime_name); + } else { + insert_new_generic_param_list_fn(&mut editor, &factory, &fn_def, &new_lifetime_name); } + + editor.replace(lifetime.syntax(), factory.lifetime(&new_lifetime_name).syntax()); + + if let Some(pos) = loc_needing_lifetime.and_then(|l| l.to_position()) { + editor.insert_all( + pos, + vec![ + factory.lifetime(&new_lifetime_name).syntax().clone().into(), + factory.whitespace(" ").into(), + ], + ); + } + + edit.add_file_edits(file_id, editor); }) } -/// Generate the assist for the impl def case -fn generate_impl_def_assist( - acc: &mut Assists, - impl_def: ast::Impl, - lifetime_loc: TextRange, - lifetime: ast::Lifetime, +fn insert_new_generic_param_list_fn( + editor: &mut SyntaxEditor, + factory: &SyntaxFactory, + fn_def: &ast::Fn, + lifetime_name: &str, ) -> Option<()> { - let new_lifetime_param = generate_unique_lifetime_param_name(impl_def.generic_param_list())?; - acc.add(AssistId::refactor(ASSIST_NAME), ASSIST_LABEL, lifetime_loc, |builder| { - let impl_def = builder.make_mut(impl_def); - let lifetime = builder.make_mut(lifetime); + let name = fn_def.name()?; - impl_def.get_or_create_generic_param_list().add_generic_param( - make::lifetime_param(new_lifetime_param.clone()).clone_for_update().into(), - ); - ted::replace(lifetime.syntax(), new_lifetime_param.clone_for_update().syntax()); - }) -} + editor.insert_all( + Position::after(name.syntax()), + vec![ + factory.token(T![<]).syntax_element(), + factory.lifetime(lifetime_name).syntax().syntax_element(), + factory.token(T![>]).syntax_element(), + ], + ); -/// Given a type parameter list, generate a unique lifetime parameter name -/// which is not in the list -fn generate_unique_lifetime_param_name( - existing_type_param_list: Option<ast::GenericParamList>, -) -> Option<ast::Lifetime> { - match existing_type_param_list { - Some(type_params) => { - let used_lifetime_params: FxHashSet<_> = - type_params.lifetime_params().map(|p| p.syntax().text().to_string()).collect(); - ('a'..='z').map(|it| format!("'{it}")).find(|it| !used_lifetime_params.contains(it)) - } - None => Some("'a".to_owned()), - } - .map(|it| make::lifetime(&it)) + Some(()) } enum NeedsLifetime { @@ -140,13 +149,6 @@ enum NeedsLifetime { } impl NeedsLifetime { - fn make_mut(self, builder: &mut SourceChangeBuilder) -> Self { - match self { - Self::SelfParam(it) => Self::SelfParam(builder.make_mut(it)), - Self::RefType(it) => Self::RefType(builder.make_mut(it)), - } - } - fn to_position(self) -> Option<Position> { match self { Self::SelfParam(it) => Some(Position::after(it.amp_token()?)), @@ -155,6 +157,75 @@ impl NeedsLifetime { } } +fn generate_impl_def_assist( + acc: &mut Assists, + impl_def: ast::Impl, + lifetime_loc: TextRange, + lifetime: ast::Lifetime, + file_id: FileId, +) -> Option<()> { + let new_lifetime_name = generate_unique_lifetime_param_name(impl_def.generic_param_list())?; + + acc.add(AssistId::refactor(ASSIST_NAME), ASSIST_LABEL, lifetime_loc, |edit| { + let root = impl_def.syntax().ancestors().last().unwrap().clone(); + let mut editor = SyntaxEditor::new(root); + let factory = SyntaxFactory::without_mappings(); + + if let Some(generic_list) = impl_def.generic_param_list() { + insert_lifetime_param(&mut editor, &factory, &generic_list, &new_lifetime_name); + } else { + insert_new_generic_param_list_imp(&mut editor, &factory, &impl_def, &new_lifetime_name); + } + + editor.replace(lifetime.syntax(), factory.lifetime(&new_lifetime_name).syntax()); + + edit.add_file_edits(file_id, editor); + }) +} + +fn insert_new_generic_param_list_imp( + editor: &mut SyntaxEditor, + factory: &SyntaxFactory, + impl_: &ast::Impl, + lifetime_name: &str, +) -> Option<()> { + let impl_kw = impl_.impl_token()?; + + editor.insert_all( + Position::after(impl_kw), + vec![ + factory.token(T![<]).syntax_element(), + factory.lifetime(lifetime_name).syntax().syntax_element(), + factory.token(T![>]).syntax_element(), + ], + ); + + Some(()) +} + +fn insert_lifetime_param( + editor: &mut SyntaxEditor, + factory: &SyntaxFactory, + generic_list: &ast::GenericParamList, + lifetime_name: &str, +) -> Option<()> { + let r_angle = generic_list.r_angle_token()?; + let needs_comma = generic_list.generic_params().next().is_some(); + + let mut elements = Vec::new(); + + if needs_comma { + elements.push(factory.token(T![,]).syntax_element()); + elements.push(factory.whitespace(" ").syntax_element()); + } + + let lifetime = factory.lifetime(lifetime_name); + elements.push(lifetime.syntax().clone().into()); + + editor.insert_all(Position::before(r_angle), elements); + Some(()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/ide-assists/src/handlers/invert_if.rs b/crates/ide-assists/src/handlers/invert_if.rs index bf82d8df9b..c8cb7bb60f 100644 --- a/crates/ide-assists/src/handlers/invert_if.rs +++ b/crates/ide-assists/src/handlers/invert_if.rs @@ -1,13 +1,13 @@ use ide_db::syntax_helpers::node_ext::is_pattern_cond; use syntax::{ T, - ast::{self, AstNode}, + ast::{self, AstNode, syntax_factory::SyntaxFactory}, }; use crate::{ AssistId, assist_context::{AssistContext, Assists}, - utils::invert_boolean_expression_legacy, + utils::invert_boolean_expression, }; // Assist: invert_if @@ -50,7 +50,8 @@ pub(crate) fn invert_if(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<() }; acc.add(AssistId::refactor_rewrite("invert_if"), "Invert if", if_range, |edit| { - let flip_cond = invert_boolean_expression_legacy(cond.clone()); + let make = SyntaxFactory::without_mappings(); + let flip_cond = invert_boolean_expression(&make, cond.clone()); edit.replace_ast(cond, flip_cond); let else_node = else_block.syntax(); diff --git a/crates/ide-assists/src/handlers/merge_imports.rs b/crates/ide-assists/src/handlers/merge_imports.rs index 9ba73d23dd..42bc05811f 100644 --- a/crates/ide-assists/src/handlers/merge_imports.rs +++ b/crates/ide-assists/src/handlers/merge_imports.rs @@ -49,8 +49,9 @@ pub(crate) fn merge_imports(acc: &mut Assists, ctx: &AssistContext<'_>) -> Optio SyntaxElement::Node(n) => n, SyntaxElement::Token(t) => t.parent()?, }; - let mut selected_nodes = - parent_node.children().filter(|it| selection_range.contains_range(it.text_range())); + let mut selected_nodes = parent_node.children().filter(|it| { + selection_range.intersect(it.text_range()).is_some_and(|it| !it.is_empty()) + }); let first_selected = selected_nodes.next()?; let edits = match_ast! { @@ -678,6 +679,25 @@ use std::fmt::Result; } #[test] + fn merge_partial_selection_uses() { + cov_mark::check!(merge_with_selected_use_item_neighbors); + check_assist( + merge_imports, + r" +use std::fmt::Error; +$0use std::fmt::Display; +use std::fmt::Debug; +use std::fmt::Write; +use$0 std::fmt::Result; +", + r" +use std::fmt::Error; +use std::fmt::{Debug, Display, Result, Write}; +", + ); + } + + #[test] fn merge_selection_use_trees() { cov_mark::check!(merge_with_selected_use_tree_neighbors); check_assist( diff --git a/crates/ide-assists/src/handlers/merge_match_arms.rs b/crates/ide-assists/src/handlers/merge_match_arms.rs index 08170f81b2..6e84af5f91 100644 --- a/crates/ide-assists/src/handlers/merge_match_arms.rs +++ b/crates/ide-assists/src/handlers/merge_match_arms.rs @@ -160,8 +160,11 @@ fn get_arm_types<'db>( } } ast::Pat::IdentPat(ident_pat) => { - if let Some(name) = ident_pat.name() { + if let Some(name) = ident_pat.name() + && ctx.sema.to_def(ident_pat).is_some() + { let pat_type = ctx.sema.type_of_binding_in_pat(ident_pat); + map.insert(name.text().to_string(), pat_type); } } @@ -213,6 +216,40 @@ fn main() { } #[test] + fn merge_match_arms_ambiguous_ident_patterns() { + check_assist( + merge_match_arms, + r#" +#[derive(Debug)] +enum X { A, B, C } +use X::*; + +fn main() { + let x = A; + let y = match x { + A => { 1i32$0 } + B => { 1i32 } + C => { 2i32 } + } +} +"#, + r#" +#[derive(Debug)] +enum X { A, B, C } +use X::*; + +fn main() { + let x = A; + let y = match x { + A | B => { 1i32 }, + C => { 2i32 } + } +} +"#, + ); + } + + #[test] fn merge_match_arms_multiple_patterns() { check_assist( merge_match_arms, diff --git a/crates/ide-assists/src/handlers/move_bounds.rs b/crates/ide-assists/src/handlers/move_bounds.rs index e5425abab0..79b8bd5d3d 100644 --- a/crates/ide-assists/src/handlers/move_bounds.rs +++ b/crates/ide-assists/src/handlers/move_bounds.rs @@ -1,11 +1,8 @@ use either::Either; use syntax::{ - ast::{ - self, AstNode, HasName, HasTypeBounds, - edit_in_place::{GenericParamsOwnerEdit, Removable}, - make, - }, + ast::{self, AstNode, HasName, HasTypeBounds, syntax_factory::SyntaxFactory}, match_ast, + syntax_editor::{GetOrCreateWhereClause, Removable}, }; use crate::{AssistContext, AssistId, Assists}; @@ -47,18 +44,23 @@ pub(crate) fn move_bounds_to_where_clause( AssistId::refactor_rewrite("move_bounds_to_where_clause"), "Move to where clause", target, - |edit| { - let type_param_list = edit.make_mut(type_param_list); - let parent = edit.make_syntax_mut(parent); - - let where_clause: ast::WhereClause = match_ast! { - match parent { - ast::Fn(it) => it.get_or_create_where_clause(), - ast::Trait(it) => it.get_or_create_where_clause(), - ast::Impl(it) => it.get_or_create_where_clause(), - ast::Enum(it) => it.get_or_create_where_clause(), - ast::Struct(it) => it.get_or_create_where_clause(), - ast::TypeAlias(it) => it.get_or_create_where_clause(), + |builder| { + let mut edit = builder.make_editor(&parent); + let make = SyntaxFactory::without_mappings(); + + let new_preds: Vec<ast::WherePred> = type_param_list + .generic_params() + .filter_map(|param| build_predicate(param, &make)) + .collect(); + + match_ast! { + match (&parent) { + ast::Fn(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), + ast::Trait(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), + ast::Impl(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), + ast::Enum(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), + ast::Struct(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), + ast::TypeAlias(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), _ => return, } }; @@ -70,25 +72,22 @@ pub(crate) fn move_bounds_to_where_clause( ast::GenericParam::ConstParam(_) => continue, }; if let Some(tbl) = param.type_bound_list() { - if let Some(predicate) = build_predicate(generic_param) { - where_clause.add_predicate(predicate) - } - tbl.remove() + tbl.remove(&mut edit); } } + + builder.add_file_edits(ctx.vfs_file_id(), edit); }, ) } -fn build_predicate(param: ast::GenericParam) -> Option<ast::WherePred> { +fn build_predicate(param: ast::GenericParam, make: &SyntaxFactory) -> Option<ast::WherePred> { let target = match ¶m { - ast::GenericParam::TypeParam(t) => { - Either::Right(make::ty_path(make::ext::ident_path(&t.name()?.to_string()))) - } + ast::GenericParam::TypeParam(t) => Either::Right(make.ty(&t.name()?.to_string())), ast::GenericParam::LifetimeParam(l) => Either::Left(l.lifetime()?), ast::GenericParam::ConstParam(_) => return None, }; - let predicate = make::where_pred( + let predicate = make.where_pred( target, match param { ast::GenericParam::TypeParam(t) => t.type_bound_list()?, @@ -97,7 +96,7 @@ fn build_predicate(param: ast::GenericParam) -> Option<ast::WherePred> { } .bounds(), ); - Some(predicate.clone_for_update()) + Some(predicate) } #[cfg(test)] diff --git a/crates/ide-assists/src/handlers/move_const_to_impl.rs b/crates/ide-assists/src/handlers/move_const_to_impl.rs index 102d7e6d53..b3e79e4663 100644 --- a/crates/ide-assists/src/handlers/move_const_to_impl.rs +++ b/crates/ide-assists/src/handlers/move_const_to_impl.rs @@ -2,7 +2,10 @@ use hir::{AsAssocItem, AssocItemContainer, FileRange, HasSource}; use ide_db::{assists::AssistId, defs::Definition, search::SearchScope}; use syntax::{ SyntaxKind, - ast::{self, AstNode, edit::IndentLevel, edit_in_place::Indent}, + ast::{ + self, AstNode, + edit::{AstNodeEdit, IndentLevel}, + }, }; use crate::assist_context::{AssistContext, Assists}; @@ -136,7 +139,8 @@ pub(crate) fn move_const_to_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> let indent = IndentLevel::from_node(parent_fn.syntax()); let const_ = const_.clone_for_update(); - const_.reindent_to(indent); + let const_ = const_.reset_indent(); + let const_ = const_.indent(indent); builder.insert(insert_offset, format!("\n{indent}{const_}{fixup}")); }, ) diff --git a/crates/ide-assists/src/handlers/move_guard.rs b/crates/ide-assists/src/handlers/move_guard.rs index 84f02bdfdb..80587372e5 100644 --- a/crates/ide-assists/src/handlers/move_guard.rs +++ b/crates/ide-assists/src/handlers/move_guard.rs @@ -1,8 +1,9 @@ -use itertools::Itertools; +use itertools::{Itertools, chain}; use syntax::{ SyntaxKind::WHITESPACE, + TextRange, ast::{ - AstNode, BlockExpr, ElseBranch, Expr, IfExpr, MatchArm, Pat, edit::AstNodeEdit, make, + AstNode, BlockExpr, ElseBranch, Expr, IfExpr, MatchArm, Pat, edit::AstNodeEdit, prec::ExprPrecedence, syntax_factory::SyntaxFactory, }, syntax_editor::Element, @@ -44,13 +45,27 @@ pub(crate) fn move_guard_to_arm_body(acc: &mut Assists, ctx: &AssistContext<'_>) cov_mark::hit!(move_guard_inapplicable_in_arm_body); return None; } - let space_before_guard = guard.syntax().prev_sibling_or_token(); + let rest_arms = rest_arms(&match_arm, ctx.selection_trimmed())?; + let space_before_delete = chain( + guard.syntax().prev_sibling_or_token(), + rest_arms.iter().filter_map(|it| it.syntax().prev_sibling_or_token()), + ); let space_after_arrow = match_arm.fat_arrow_token()?.next_sibling_or_token(); - let guard_condition = guard.condition()?.reset_indent(); let arm_expr = match_arm.expr()?; - let then_branch = crate::utils::wrap_block(&arm_expr); - let if_expr = make::expr_if(guard_condition, then_branch, None).indent(arm_expr.indent_level()); + let make = SyntaxFactory::without_mappings(); + let if_branch = chain([&match_arm], &rest_arms) + .rfold(None, |else_branch, arm| { + if let Some(guard) = arm.guard() { + let then_branch = crate::utils::wrap_block(&arm.expr()?, &make); + let guard_condition = guard.condition()?.reset_indent(); + Some(make.expr_if(guard_condition, then_branch, else_branch).into()) + } else { + arm.expr().map(|it| crate::utils::wrap_block(&it, &make).into()) + } + })? + .indent(arm_expr.indent_level()); + let ElseBranch::IfExpr(if_expr) = if_branch else { return None }; let target = guard.syntax().text_range(); acc.add( @@ -59,15 +74,18 @@ pub(crate) fn move_guard_to_arm_body(acc: &mut Assists, ctx: &AssistContext<'_>) target, |builder| { let mut edit = builder.make_editor(match_arm.syntax()); - if let Some(element) = space_before_guard - && element.kind() == WHITESPACE - { - edit.delete(element); + for element in space_before_delete { + if element.kind() == WHITESPACE { + edit.delete(element); + } + } + for rest_arm in &rest_arms { + edit.delete(rest_arm.syntax()); } if let Some(element) = space_after_arrow && element.kind() == WHITESPACE { - edit.replace(element, make::tokens::single_space()); + edit.replace(element, make.whitespace(" ")); } edit.delete(guard.syntax()); @@ -221,6 +239,25 @@ pub(crate) fn move_arm_cond_to_match_guard( ) } +fn rest_arms(match_arm: &MatchArm, selection: TextRange) -> Option<Vec<MatchArm>> { + match_arm + .parent_match() + .match_arm_list()? + .arms() + .skip_while(|it| it != match_arm) + .skip(1) + .take_while(move |it| { + selection.is_empty() || crate::utils::is_selected(it, selection, false) + }) + .take_while(move |it| { + it.pat() + .zip(match_arm.pat()) + .is_some_and(|(a, b)| a.syntax().text() == b.syntax().text()) + }) + .collect::<Vec<_>>() + .into() +} + // Parses an if-else-if chain to get the conditions and the then branches until we encounter an else // branch or the end. fn parse_if_chain(if_expr: IfExpr) -> Option<(Vec<(Expr, BlockExpr)>, Option<BlockExpr>)> { @@ -345,6 +382,115 @@ fn main() { } #[test] + fn move_multiple_guard_to_arm_body_works() { + check_assist( + move_guard_to_arm_body, + r#" +fn main() { + match 92 { + x @ 0..30 $0if x % 3 == 0 => false, + x @ 0..30 if x % 2 == 0 => true, + _ => false + } +} +"#, + r#" +fn main() { + match 92 { + x @ 0..30 => if x % 3 == 0 { + false + } else if x % 2 == 0 { + true + }, + _ => false + } +} +"#, + ); + + check_assist( + move_guard_to_arm_body, + r#" +fn main() { + match 92 { + x @ 0..30 $0if x % 3 == 0 => false, + x @ 0..30 if x % 2 == 0 => true, + x @ 0..30 => false, + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x @ 0..30 => if x % 3 == 0 { + false + } else if x % 2 == 0 { + true + } else { + false + }, + _ => true + } +} +"#, + ); + + check_assist( + move_guard_to_arm_body, + r#" +fn main() { + match 92 { + x @ 0..30 if x % 3 == 0 => false, + x @ 0..30 $0if x % 2 == 0$0 => true, + x @ 0..30 => false, + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x @ 0..30 if x % 3 == 0 => false, + x @ 0..30 => if x % 2 == 0 { + true + }, + x @ 0..30 => false, + _ => true + } +} +"#, + ); + + check_assist( + move_guard_to_arm_body, + r#" +fn main() { + match 92 { + x @ 0..30 $0if x % 3 == 0 => false, + x @ 0..30 $0if x % 2 == 0 => true, + x @ 0..30 => false, + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x @ 0..30 => if x % 3 == 0 { + false + } else if x % 2 == 0 { + true + }, + x @ 0..30 => false, + _ => true + } +} +"#, + ); + } + + #[test] fn move_guard_to_block_arm_body_works() { check_assist( move_guard_to_arm_body, @@ -422,7 +568,8 @@ fn main() { match 92 { x => if true && true - && true { + && true + { { false } diff --git a/crates/ide-assists/src/handlers/promote_local_to_const.rs b/crates/ide-assists/src/handlers/promote_local_to_const.rs index 547d3686e3..483c90d103 100644 --- a/crates/ide-assists/src/handlers/promote_local_to_const.rs +++ b/crates/ide-assists/src/handlers/promote_local_to_const.rs @@ -8,7 +8,7 @@ use syntax::{ use crate::{ assist_context::{AssistContext, Assists}, - utils::{self}, + utils, }; // Assist: promote_local_to_const diff --git a/crates/ide-assists/src/handlers/qualify_method_call.rs b/crates/ide-assists/src/handlers/qualify_method_call.rs index 495a84d62b..8b9e6570e9 100644 --- a/crates/ide-assists/src/handlers/qualify_method_call.rs +++ b/crates/ide-assists/src/handlers/qualify_method_call.rs @@ -1,6 +1,6 @@ use hir::{AsAssocItem, AssocItem, AssocItemContainer, ItemInNs, ModuleDef, db::HirDatabase}; use ide_db::assists::AssistId; -use syntax::{AstNode, ast}; +use syntax::{AstNode, ast, ast::syntax_factory::SyntaxFactory}; use crate::{ assist_context::{AssistContext, Assists}, @@ -52,19 +52,25 @@ pub(crate) fn qualify_method_call(acc: &mut Assists, ctx: &AssistContext<'_>) -> cfg, )?; - let qualify_candidate = QualifyCandidate::ImplMethod(ctx.sema.db, call, resolved_call); + let qualify_candidate = QualifyCandidate::ImplMethod(ctx.sema.db, call.clone(), resolved_call); acc.add( AssistId::refactor_rewrite("qualify_method_call"), format!("Qualify `{ident}` method call"), range, |builder| { + let make = SyntaxFactory::with_mappings(); + let mut editor = builder.make_editor(call.syntax()); qualify_candidate.qualify( - |replace_with: String| builder.replace(range, replace_with), + |_| {}, + &mut editor, + &make, &receiver_path, item_in_ns, current_edition, - ) + ); + editor.add_mappings(make.finish_with_mappings()); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ); Some(()) diff --git a/crates/ide-assists/src/handlers/qualify_path.rs b/crates/ide-assists/src/handlers/qualify_path.rs index b3cf296965..c059f758c4 100644 --- a/crates/ide-assists/src/handlers/qualify_path.rs +++ b/crates/ide-assists/src/handlers/qualify_path.rs @@ -11,7 +11,8 @@ use syntax::Edition; use syntax::ast::HasGenericArgs; use syntax::{ AstNode, ast, - ast::{HasArgList, make}, + ast::{HasArgList, syntax_factory::SyntaxFactory}, + syntax_editor::SyntaxEditor, }; use crate::{ @@ -54,25 +55,25 @@ pub(crate) fn qualify_path(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option let qualify_candidate = match candidate { ImportCandidate::Path(candidate) if !candidate.qualifier.is_empty() => { cov_mark::hit!(qualify_path_qualifier_start); - let path = ast::Path::cast(syntax_under_caret)?; + let path = ast::Path::cast(syntax_under_caret.clone())?; let (prev_segment, segment) = (path.qualifier()?.segment()?, path.segment()?); QualifyCandidate::QualifierStart(segment, prev_segment.generic_arg_list()) } ImportCandidate::Path(_) => { cov_mark::hit!(qualify_path_unqualified_name); - let path = ast::Path::cast(syntax_under_caret)?; + let path = ast::Path::cast(syntax_under_caret.clone())?; let generics = path.segment()?.generic_arg_list(); QualifyCandidate::UnqualifiedName(generics) } ImportCandidate::TraitAssocItem(_) => { cov_mark::hit!(qualify_path_trait_assoc_item); - let path = ast::Path::cast(syntax_under_caret)?; + let path = ast::Path::cast(syntax_under_caret.clone())?; let (qualifier, segment) = (path.qualifier()?, path.segment()?); QualifyCandidate::TraitAssocItem(qualifier, segment) } ImportCandidate::TraitMethod(_) => { cov_mark::hit!(qualify_path_trait_method); - let mcall_expr = ast::MethodCallExpr::cast(syntax_under_caret)?; + let mcall_expr = ast::MethodCallExpr::cast(syntax_under_caret.clone())?; QualifyCandidate::TraitMethod(ctx.sema.db, mcall_expr) } }; @@ -101,12 +102,18 @@ pub(crate) fn qualify_path(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option label(ctx.db(), candidate, &import, current_edition), range, |builder| { + let make = SyntaxFactory::with_mappings(); + let mut editor = builder.make_editor(&syntax_under_caret); qualify_candidate.qualify( |replace_with: String| builder.replace(range, replace_with), + &mut editor, + &make, &import.import_path, import.item_to_import, current_edition, - ) + ); + editor.add_mappings(make.finish_with_mappings()); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ); } @@ -124,6 +131,8 @@ impl QualifyCandidate<'_> { pub(crate) fn qualify( &self, mut replacer: impl FnMut(String), + editor: &mut SyntaxEditor, + make: &SyntaxFactory, import: &hir::ModPath, item: hir::ItemInNs, edition: Edition, @@ -142,10 +151,10 @@ impl QualifyCandidate<'_> { replacer(format!("<{qualifier} as {import}>::{segment}")); } QualifyCandidate::TraitMethod(db, mcall_expr) => { - Self::qualify_trait_method(db, mcall_expr, replacer, import, item); + Self::qualify_trait_method(db, mcall_expr, editor, make, import, item); } QualifyCandidate::ImplMethod(db, mcall_expr, hir_fn) => { - Self::qualify_fn_call(db, mcall_expr, replacer, import, hir_fn); + Self::qualify_fn_call(db, mcall_expr, editor, make, import, hir_fn); } } } @@ -153,7 +162,8 @@ impl QualifyCandidate<'_> { fn qualify_fn_call( db: &RootDatabase, mcall_expr: &ast::MethodCallExpr, - mut replacer: impl FnMut(String), + editor: &mut SyntaxEditor, + make: &SyntaxFactory, import: ast::Path, hir_fn: &hir::Function, ) -> Option<()> { @@ -165,15 +175,17 @@ impl QualifyCandidate<'_> { if let Some(self_access) = hir_fn.self_param(db).map(|sp| sp.access(db)) { let receiver = match self_access { - hir::Access::Shared => make::expr_ref(receiver, false), - hir::Access::Exclusive => make::expr_ref(receiver, true), + hir::Access::Shared => make.expr_ref(receiver, false), + hir::Access::Exclusive => make.expr_ref(receiver, true), hir::Access::Owned => receiver, }; let arg_list = match arg_list { - Some(args) => make::arg_list(iter::once(receiver).chain(args)), - None => make::arg_list(iter::once(receiver)), + Some(args) => make.arg_list(iter::once(receiver).chain(args)), + None => make.arg_list(iter::once(receiver)), }; - replacer(format!("{import}::{method_name}{generics}{arg_list}")); + let call_path = make.path_from_text(&format!("{import}::{method_name}{generics}")); + let call_expr = make.expr_call(make.expr_path(call_path), arg_list); + editor.replace(mcall_expr.syntax(), call_expr.syntax()); } Some(()) } @@ -181,14 +193,15 @@ impl QualifyCandidate<'_> { fn qualify_trait_method( db: &RootDatabase, mcall_expr: &ast::MethodCallExpr, - replacer: impl FnMut(String), + editor: &mut SyntaxEditor, + make: &SyntaxFactory, import: ast::Path, item: hir::ItemInNs, ) -> Option<()> { let trait_method_name = mcall_expr.name_ref()?; let trait_ = item_as_trait(db, item)?; let method = find_trait_method(db, trait_, &trait_method_name)?; - Self::qualify_fn_call(db, mcall_expr, replacer, import, &method) + Self::qualify_fn_call(db, mcall_expr, editor, make, import, &method) } } diff --git a/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs b/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs index 11b3fd22fa..f54f7a02d2 100644 --- a/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs +++ b/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs @@ -4,7 +4,7 @@ use itertools::Itertools; use syntax::{ SyntaxKind::WHITESPACE, T, - ast::{self, AstNode, HasName, make}, + ast::{self, AstNode, HasName, syntax_factory::SyntaxFactory}, syntax_editor::{Position, SyntaxEditor}, }; @@ -12,8 +12,8 @@ use crate::{ AssistConfig, AssistId, assist_context::{AssistContext, Assists}, utils::{ - DefaultMethods, IgnoreAssocItems, add_trait_assoc_items_to_impl, filter_assoc_items, - gen_trait_fn_body, generate_trait_impl, + DefaultMethods, IgnoreAssocItems, add_trait_assoc_items_to_impl_with_factory, + filter_assoc_items, gen_trait_fn_body, generate_trait_impl, }, }; @@ -127,6 +127,7 @@ fn add_assist( let label = format!("Convert to manual `impl {replace_trait_path} for {annotated_name}`"); acc.add(AssistId::refactor("replace_derive_with_manual_impl"), label, target, |builder| { + let make = SyntaxFactory::without_mappings(); let insert_after = Position::after(adt.syntax()); let impl_is_unsafe = trait_.map(|s| s.is_unsafe(ctx.db())).unwrap_or(false); let impl_def = impl_def_from_trait( @@ -142,7 +143,7 @@ fn add_assist( let mut editor = builder.make_editor(attr.syntax()); update_attribute(&mut editor, old_derives, old_tree, old_trait_path, attr); - let trait_path = make::ty_path(replace_trait_path.clone()); + let trait_path = make.ty_path(replace_trait_path.clone()).into(); let (impl_def, first_assoc_item) = if let Some(impl_def) = impl_def { ( @@ -150,7 +151,7 @@ fn add_assist( impl_def.assoc_item_list().and_then(|list| list.assoc_items().next()), ) } else { - (generate_trait_impl(impl_is_unsafe, adt, trait_path), None) + (generate_trait_impl(&make, impl_is_unsafe, adt, trait_path), None) }; if let Some(cap) = ctx.config.snippet_cap { @@ -174,7 +175,7 @@ fn add_assist( editor.insert_all( insert_after, - vec![make::tokens::blank_line().into(), impl_def.syntax().clone().into()], + vec![make.whitespace("\n\n").into(), impl_def.syntax().clone().into()], ); builder.add_file_edits(ctx.vfs_file_id(), editor); }) @@ -205,10 +206,19 @@ fn impl_def_from_trait( if trait_items.is_empty() { return None; } - let impl_def = generate_trait_impl(impl_is_unsafe, adt, make::ty_path(trait_path.clone())); - - let assoc_items = - add_trait_assoc_items_to_impl(sema, config, &trait_items, trait_, &impl_def, &target_scope); + let make = SyntaxFactory::without_mappings(); + let trait_ty = make.ty_path(trait_path.clone()).into(); + let impl_def = generate_trait_impl(&make, impl_is_unsafe, adt, trait_ty); + + let assoc_items = add_trait_assoc_items_to_impl_with_factory( + &make, + sema, + config, + &trait_items, + trait_, + &impl_def, + &target_scope, + ); let assoc_item_list = if let Some((first, other)) = assoc_items.split_first().map(|(first, other)| (first.clone_subtree(), other)) { @@ -222,12 +232,12 @@ fn impl_def_from_trait( } else { Some(first.clone()) }; - let items = first_item.into_iter().chain(other.iter().cloned()).collect(); - make::assoc_item_list(Some(items)) + let items: Vec<ast::AssocItem> = + first_item.into_iter().chain(other.iter().cloned()).collect(); + make.assoc_item_list(items) } else { - make::assoc_item_list(None) - } - .clone_for_update(); + make.assoc_item_list_empty() + }; let impl_def = impl_def.clone_subtree(); let mut editor = SyntaxEditor::new(impl_def.syntax().clone()); @@ -243,6 +253,7 @@ fn update_attribute( old_trait_path: &ast::Path, attr: &ast::Attr, ) { + let make = SyntaxFactory::without_mappings(); let new_derives = old_derives .iter() .filter(|t| t.to_string() != old_trait_path.to_string()) @@ -257,13 +268,13 @@ fn update_attribute( .collect::<Vec<_>>() }); // ...which are interspersed with ", " - let tt = Itertools::intersperse(tt, vec![make::token(T![,]), make::tokens::single_space()]); + let tt = Itertools::intersperse(tt, vec![make.token(T![,]), make.whitespace(" ")]); // ...wrap them into the appropriate `NodeOrToken` variant let tt = tt.flatten().map(syntax::NodeOrToken::Token); // ...and make them into a flat list of tokens let tt = tt.collect::<Vec<_>>(); - let new_tree = make::token_tree(T!['('], tt).clone_for_update(); + let new_tree = make.token_tree(T!['('], tt); editor.replace(old_tree.syntax(), new_tree.syntax()); } else { // Remove the attr and any trailing whitespace @@ -1308,6 +1319,29 @@ impl<T: Clone> Clone for Foo<T> { } #[test] + fn add_custom_impl_clone_generic_tuple_struct_with_associated() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: clone, derive, deref +#[derive(Clo$0ne)] +struct Foo<T: core::ops::Deref>(T::Target); +"#, + r#" +struct Foo<T: core::ops::Deref>(T::Target); + +impl<T: core::ops::Deref + Clone> Clone for Foo<T> +where T::Target: Clone +{ + $0fn clone(&self) -> Self { + Self(self.0.clone()) + } +} +"#, + ) + } + + #[test] fn test_ignore_derive_macro_without_input() { check_assist_not_applicable( replace_derive_with_manual_impl, diff --git a/crates/ide-assists/src/handlers/replace_if_let_with_match.rs b/crates/ide-assists/src/handlers/replace_if_let_with_match.rs index b7e5344712..8ff30fce5b 100644 --- a/crates/ide-assists/src/handlers/replace_if_let_with_match.rs +++ b/crates/ide-assists/src/handlers/replace_if_let_with_match.rs @@ -3,7 +3,11 @@ use std::iter::successors; use ide_db::{RootDatabase, defs::NameClass, ty_filter::TryEnum}; use syntax::{ AstNode, Edition, SyntaxKind, T, TextRange, - ast::{self, HasName, edit::IndentLevel, edit_in_place::Indent, syntax_factory::SyntaxFactory}, + ast::{ + self, HasName, + edit::{AstNodeEdit, IndentLevel}, + syntax_factory::SyntaxFactory, + }, syntax_editor::SyntaxEditor, }; @@ -53,9 +57,8 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext<' let if_exprs = successors(Some(if_expr.clone()), |expr| match expr.else_branch()? { ast::ElseBranch::IfExpr(expr) => Some(expr), ast::ElseBranch::Block(block) => { - let block = unwrap_trivial_block(block).clone_for_update(); - block.reindent_to(IndentLevel(1)); - else_block = Some(block); + let block = unwrap_trivial_block(block); + else_block = Some(block.reset_indent().indent(IndentLevel(1))); None } }); @@ -82,12 +85,13 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext<' (Some(pat), guard) } }; - if let Some(guard) = &guard { - guard.dedent(indent); - guard.indent(IndentLevel(1)); - } - let body = if_expr.then_branch()?.clone_for_update(); - body.indent(IndentLevel(1)); + let guard = if let Some(guard) = &guard { + Some(guard.dedent(indent).indent(IndentLevel(1))) + } else { + guard + }; + + let body = if_expr.then_branch()?.indent(IndentLevel(1)); cond_bodies.push((cond, guard, body)); } @@ -109,7 +113,8 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext<' let else_arm = make_else_arm(ctx, &make, else_block, &cond_bodies); let make_match_arm = |(pat, guard, body): (_, Option<ast::Expr>, ast::BlockExpr)| { - body.reindent_to(IndentLevel::single()); + // Dedent from original position, then indent for match arm + let body = body.dedent(indent); let body = unwrap_trivial_block(body); match (pat, guard.map(|it| make.match_guard(it))) { (Some(pat), guard) => make.match_arm(pat, guard, body), @@ -122,8 +127,8 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext<' } }; let arms = cond_bodies.into_iter().map(make_match_arm).chain([else_arm]); - let match_expr = make.expr_match(scrutinee_to_be_expr, make.match_arm_list(arms)); - match_expr.indent(indent); + let expr = scrutinee_to_be_expr.reset_indent(); + let match_expr = make.expr_match(expr, make.match_arm_list(arms)).indent(indent); match_expr.into() }; @@ -131,10 +136,9 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext<' if_expr.syntax().parent().is_some_and(|it| ast::IfExpr::can_cast(it.kind())); let expr = if has_preceding_if_expr { // make sure we replace the `else if let ...` with a block so we don't end up with `else expr` - match_expr.dedent(indent); - match_expr.indent(IndentLevel(1)); - let block_expr = make.block_expr([], Some(match_expr)); - block_expr.indent(indent); + let block_expr = make + .block_expr([], Some(match_expr.dedent(indent).indent(IndentLevel(1)))) + .indent(indent); block_expr.into() } else { match_expr @@ -242,7 +246,7 @@ pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext<' first_arm.guard(), second_arm.guard(), )?; - let scrutinee = match_expr.expr()?; + let scrutinee = match_expr.expr()?.reset_indent(); let guard = guard.and_then(|it| it.condition()); let let_ = match &if_let_pat { @@ -267,10 +271,7 @@ pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext<' // wrap them in another BlockExpr. match expr { ast::Expr::BlockExpr(block) if block.modifier().is_none() => block, - expr => { - expr.indent(IndentLevel(1)); - make.block_expr([], Some(expr)) - } + expr => make.block_expr([], Some(expr.indent(IndentLevel(1)))), } }; @@ -292,18 +293,17 @@ pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext<' } else { condition }; - let then_expr = then_expr.clone_for_update(); - let else_expr = else_expr.clone_for_update(); - then_expr.reindent_to(IndentLevel::single()); - else_expr.reindent_to(IndentLevel::single()); + let then_expr = then_expr.reset_indent(); + let else_expr = else_expr.reset_indent(); let then_block = make_block_expr(then_expr); let else_expr = if is_empty_expr(&else_expr) { None } else { Some(else_expr) }; - let if_let_expr = make.expr_if( - condition, - then_block, - else_expr.map(make_block_expr).map(ast::ElseBranch::Block), - ); - if_let_expr.indent(IndentLevel::from_node(match_expr.syntax())); + let if_let_expr = make + .expr_if( + condition, + then_block, + else_expr.map(make_block_expr).map(ast::ElseBranch::Block), + ) + .indent(IndentLevel::from_node(match_expr.syntax())); let mut editor = builder.make_editor(match_expr.syntax()); editor.replace(match_expr.syntax(), if_let_expr.syntax()); @@ -848,6 +848,31 @@ fn foo(x: Option<i32>) { } #[test] + fn special_case_option_ref() { + check_assist( + replace_if_let_with_match, + r#" +//- minicore: option +fn foo(x: &Option<i32>) { + $0if let Some(x) = x { + println!("{}", x) + } else { + println!("none") + } +} +"#, + r#" +fn foo(x: &Option<i32>) { + match x { + Some(x) => println!("{}", x), + None => println!("none"), + } +} +"#, + ); + } + + #[test] fn special_case_inverted_option() { check_assist( replace_if_let_with_match, @@ -929,7 +954,9 @@ fn foo(x: Result<i32, ()>) { r#" fn main() { if true { - $0if let Ok(rel_path) = path.strip_prefix(root_path) { + $0if let Ok(rel_path) = path.strip_prefix(root_path) + .and(x) + { let rel_path = RelativePathBuf::from_path(rel_path) .ok()?; Some((*id, rel_path)) @@ -944,7 +971,9 @@ fn main() { r#" fn main() { if true { - match path.strip_prefix(root_path) { + match path.strip_prefix(root_path) + .and(x) + { Ok(rel_path) => { let rel_path = RelativePathBuf::from_path(rel_path) .ok()?; @@ -966,7 +995,9 @@ fn main() { r#" fn main() { if true { - $0if let Ok(rel_path) = path.strip_prefix(root_path) { + $0if let Ok(rel_path) = path.strip_prefix(root_path) + .and(x) + { Foo { x: 1 } @@ -981,7 +1012,9 @@ fn main() { r#" fn main() { if true { - match path.strip_prefix(root_path) { + match path.strip_prefix(root_path) + .and(x) + { Ok(rel_path) => { Foo { x: 1 @@ -996,7 +1029,34 @@ fn main() { } } "#, - ) + ); + + check_assist( + replace_if_let_with_match, + r#" +fn main() { + if true { + $0if true + && false + { + foo() + } + } +} +"#, + r#" +fn main() { + if true { + match true + && false + { + true => foo(), + false => (), + } + } +} +"#, + ); } #[test] @@ -1851,7 +1911,9 @@ fn foo(x: Result<i32, ()>) { r#" fn main() { if true { - $0match path.strip_prefix(root_path) { + $0match path.strip_prefix(root_path) + .and(x) + { Ok(rel_path) => Foo { x: 2 } @@ -1865,7 +1927,9 @@ fn main() { r#" fn main() { if true { - if let Ok(rel_path) = path.strip_prefix(root_path) { + if let Ok(rel_path) = path.strip_prefix(root_path) + .and(x) + { Foo { x: 2 } @@ -1884,7 +1948,9 @@ fn main() { r#" fn main() { if true { - $0match path.strip_prefix(root_path) { + $0match path.strip_prefix(root_path) + .and(x) + { Ok(rel_path) => { let rel_path = RelativePathBuf::from_path(rel_path) .ok()?; @@ -1902,7 +1968,9 @@ fn main() { r#" fn main() { if true { - if let Ok(rel_path) = path.strip_prefix(root_path) { + if let Ok(rel_path) = path.strip_prefix(root_path) + .and(x) + { let rel_path = RelativePathBuf::from_path(rel_path) .ok()?; Some((*id, rel_path)) diff --git a/crates/ide-assists/src/handlers/replace_is_method_with_if_let_method.rs b/crates/ide-assists/src/handlers/replace_is_method_with_if_let_method.rs index 5a2307739c..38d8c38ef2 100644 --- a/crates/ide-assists/src/handlers/replace_is_method_with_if_let_method.rs +++ b/crates/ide-assists/src/handlers/replace_is_method_with_if_let_method.rs @@ -1,8 +1,11 @@ use either::Either; use ide_db::syntax_helpers::suggest_name; -use syntax::ast::{self, AstNode, syntax_factory::SyntaxFactory}; +use syntax::ast::{self, AstNode, HasArgList, prec::ExprPrecedence, syntax_factory::SyntaxFactory}; -use crate::{AssistContext, AssistId, Assists, utils::cover_let_chain}; +use crate::{ + AssistContext, AssistId, Assists, + utils::{cover_let_chain, wrap_paren, wrap_paren_in_call}, +}; // Assist: replace_is_some_with_if_let_some // @@ -34,10 +37,12 @@ pub(crate) fn replace_is_method_with_if_let_method( _ => return None, }; - let name_ref = call_expr.name_ref()?; - match name_ref.text().as_str() { + let token = call_expr.name_ref()?.ident_token()?; + let method_kind = token.text().strip_suffix("_and").unwrap_or(token.text()); + match method_kind { "is_some" | "is_ok" => { let receiver = call_expr.receiver()?; + let make = SyntaxFactory::with_mappings(); let mut name_generator = suggest_name::NameGenerator::new_from_scope_locals( ctx.sema.scope(has_cond.syntax()), @@ -47,8 +52,9 @@ pub(crate) fn replace_is_method_with_if_let_method( } else { name_generator.for_variable(&receiver, &ctx.sema) }; + let (pat, predicate) = method_predicate(&call_expr, &var_name, &make); - let (assist_id, message, text) = if name_ref.text() == "is_some" { + let (assist_id, message, text) = if method_kind == "is_some" { ("replace_is_some_with_if_let_some", "Replace `is_some` with `let Some`", "Some") } else { ("replace_is_ok_with_if_let_ok", "Replace `is_ok` with `let Ok`", "Ok") @@ -59,22 +65,29 @@ pub(crate) fn replace_is_method_with_if_let_method( message, call_expr.syntax().text_range(), |edit| { - let make = SyntaxFactory::with_mappings(); let mut editor = edit.make_editor(call_expr.syntax()); - let var_pat = make.ident_pat(false, false, make.name(&var_name)); - let pat = make.tuple_struct_pat(make.ident_path(text), [var_pat.into()]); - let let_expr = make.expr_let(pat.into(), receiver); + let pat = make.tuple_struct_pat(make.ident_path(text), [pat]).into(); + let let_expr = make.expr_let(pat, receiver); if let Some(cap) = ctx.config.snippet_cap && let Some(ast::Pat::TupleStructPat(pat)) = let_expr.pat() && let Some(first_var) = pat.fields().next() + && predicate.is_none() { let placeholder = edit.make_placeholder_snippet(cap); editor.add_annotation(first_var.syntax(), placeholder); } - editor.replace(call_expr.syntax(), let_expr.syntax()); + let new_expr = if let Some(predicate) = predicate { + let op = ast::BinaryOp::LogicOp(ast::LogicOp::And); + let predicate = wrap_paren(predicate, &make, ExprPrecedence::LAnd); + make.expr_bin(let_expr.into(), op, predicate).into() + } else { + ast::Expr::from(let_expr) + }; + editor.replace(call_expr.syntax(), new_expr.syntax()); + editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }, @@ -84,6 +97,26 @@ pub(crate) fn replace_is_method_with_if_let_method( } } +fn method_predicate( + call_expr: &ast::MethodCallExpr, + name: &str, + make: &SyntaxFactory, +) -> (ast::Pat, Option<ast::Expr>) { + let argument = call_expr.arg_list().and_then(|it| it.args().next()); + if let Some(ast::Expr::ClosureExpr(it)) = argument.clone() + && let Some(pat) = it.param_list().and_then(|it| it.params().next()?.pat()) + { + (pat, it.body()) + } else { + let pat = make.ident_pat(false, false, make.name(name)); + let expr = argument.map(|expr| { + let arg_list = make.arg_list([make.expr_path(make.ident_path(name))]); + make.expr_call(wrap_paren_in_call(expr, make), arg_list).into() + }); + (pat.into(), expr) + } +} + #[cfg(test)] mod tests { use crate::tests::{check_assist, check_assist_not_applicable}; @@ -195,6 +228,73 @@ fn main() { } #[test] + fn replace_is_some_and_with_if_let_chain_some_works() { + check_assist( + replace_is_method_with_if_let_method, + r#" +fn main() { + let x = Some(1); + if x.is_som$0e_and(|it| it != 3) {} +} +"#, + r#" +fn main() { + let x = Some(1); + if let Some(it) = x && it != 3 {} +} +"#, + ); + + check_assist( + replace_is_method_with_if_let_method, + r#" +fn main() { + let x = Some(1); + if x.is_som$0e_and(|it| it != 3 || it > 10) {} +} +"#, + r#" +fn main() { + let x = Some(1); + if let Some(it) = x && (it != 3 || it > 10) {} +} +"#, + ); + + check_assist( + replace_is_method_with_if_let_method, + r#" +fn main() { + let x = Some(1); + if x.is_som$0e_and(predicate) {} +} +"#, + r#" +fn main() { + let x = Some(1); + if let Some(x1) = x && predicate(x1) {} +} +"#, + ); + + check_assist( + replace_is_method_with_if_let_method, + r#" +fn main() { + let x = Some(1); + if x.is_som$0e_and(func.f) {} +} +"#, + r#" +fn main() { + let x = Some(1); + if let Some(x1) = x && (func.f)(x1) {} +} +"#, + ); + } + + #[test] fn replace_is_some_with_if_let_some_in_let_chain() { check_assist( replace_is_method_with_if_let_method, diff --git a/crates/ide-assists/src/handlers/replace_let_with_if_let.rs b/crates/ide-assists/src/handlers/replace_let_with_if_let.rs index b95e9b52b0..6ff5f0bbd3 100644 --- a/crates/ide-assists/src/handlers/replace_let_with_if_let.rs +++ b/crates/ide-assists/src/handlers/replace_let_with_if_let.rs @@ -1,7 +1,11 @@ use ide_db::ty_filter::TryEnum; use syntax::{ AstNode, T, - ast::{self, edit::IndentLevel, edit_in_place::Indent, syntax_factory::SyntaxFactory}, + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + syntax_factory::SyntaxFactory, + }, }; use crate::{AssistContext, AssistId, Assists}; @@ -64,7 +68,7 @@ pub(crate) fn replace_let_with_if_let(acc: &mut Assists, ctx: &AssistContext<'_> if let_expr_needs_paren(&init) { make.expr_paren(init).into() } else { init }; let block = make.block_expr([], None); - block.indent(IndentLevel::from_node(let_stmt.syntax())); + let block = block.indent(IndentLevel::from_node(let_stmt.syntax())); let if_expr = make.expr_if( make.expr_let(pat, init_expr).into(), block, @@ -82,8 +86,8 @@ pub(crate) fn replace_let_with_if_let(acc: &mut Assists, ctx: &AssistContext<'_> } fn let_expr_needs_paren(expr: &ast::Expr) -> bool { - let fake_expr_let = - ast::make::expr_let(ast::make::tuple_pat(None).into(), ast::make::ext::expr_unit()); + let make = SyntaxFactory::without_mappings(); + let fake_expr_let = make.expr_let(make.tuple_pat(None).into(), make.expr_unit()); let Some(fake_expr) = fake_expr_let.expr() else { stdx::never!(); return false; @@ -98,6 +102,29 @@ mod tests { use super::*; #[test] + fn replace_let_try_enum_ref() { + check_assist( + replace_let_with_if_let, + r" +//- minicore: option +fn main(action: Action) { + $0let x = compute(); +} + +fn compute() -> &'static Option<i32> { &None } + ", + r" +fn main(action: Action) { + if let Some(x) = compute() { + } +} + +fn compute() -> &'static Option<i32> { &None } + ", + ) + } + + #[test] fn replace_let_unknown_enum() { check_assist( replace_let_with_if_let, diff --git a/crates/ide-assists/src/handlers/replace_method_eager_lazy.rs b/crates/ide-assists/src/handlers/replace_method_eager_lazy.rs index 6ca3e26ca0..6e4dd8cb73 100644 --- a/crates/ide-assists/src/handlers/replace_method_eager_lazy.rs +++ b/crates/ide-assists/src/handlers/replace_method_eager_lazy.rs @@ -2,10 +2,10 @@ use hir::Semantics; use ide_db::{RootDatabase, assists::AssistId, defs::Definition}; use syntax::{ AstNode, - ast::{self, Expr, HasArgList, make}, + ast::{self, Expr, HasArgList, make, syntax_factory::SyntaxFactory}, }; -use crate::{AssistContext, Assists}; +use crate::{AssistContext, Assists, utils::wrap_paren_in_call}; // Assist: replace_with_lazy_method // @@ -177,11 +177,7 @@ fn into_call(param: &Expr, sema: &Semantics<'_, RootDatabase>) -> Expr { } })() .unwrap_or_else(|| { - let callable = if needs_parens_in_call(param) { - make::expr_paren(param.clone()).into() - } else { - param.clone() - }; + let callable = wrap_paren_in_call(param.clone(), &SyntaxFactory::without_mappings()); make::expr_call(callable, make::arg_list(Vec::new())).into() }) } @@ -200,12 +196,6 @@ fn ends_is(name: &str, end: &str) -> bool { name.strip_suffix(end).is_some_and(|s| s.is_empty() || s.ends_with('_')) } -fn needs_parens_in_call(param: &Expr) -> bool { - let call = make::expr_call(make::ext::expr_unit(), make::arg_list(Vec::new())); - let callable = call.expr().expect("invalid make call"); - param.needs_parens_in_place_of(call.syntax(), callable.syntax()) -} - #[cfg(test)] mod tests { use crate::tests::check_assist; diff --git a/crates/ide-assists/src/handlers/replace_named_generic_with_impl.rs b/crates/ide-assists/src/handlers/replace_named_generic_with_impl.rs index df7057835c..018642a047 100644 --- a/crates/ide-assists/src/handlers/replace_named_generic_with_impl.rs +++ b/crates/ide-assists/src/handlers/replace_named_generic_with_impl.rs @@ -5,9 +5,10 @@ use ide_db::{ defs::Definition, search::{SearchScope, UsageSearchResult}, }; +use syntax::ast::syntax_factory::SyntaxFactory; use syntax::{ AstNode, - ast::{self, HasGenericParams, HasName, HasTypeBounds, Name, NameLike, PathType, make}, + ast::{self, HasGenericParams, HasName, HasTypeBounds, Name, NameLike, PathType}, match_ast, }; @@ -72,6 +73,7 @@ pub(crate) fn replace_named_generic_with_impl( target, |edit| { let mut editor = edit.make_editor(type_param.syntax()); + let make = SyntaxFactory::without_mappings(); // remove trait from generic param list if let Some(generic_params) = fn_.generic_param_list() { @@ -83,17 +85,14 @@ pub(crate) fn replace_named_generic_with_impl( if params.is_empty() { editor.delete(generic_params.syntax()); } else { - let new_generic_param_list = make::generic_param_list(params); - editor.replace( - generic_params.syntax(), - new_generic_param_list.syntax().clone_for_update(), - ); + let new_generic_param_list = make.generic_param_list(params); + editor.replace(generic_params.syntax(), new_generic_param_list.syntax()); } } - let new_bounds = make::impl_trait_type(type_bound_list); + let new_bounds = make.impl_trait_type(type_bound_list); for path_type in path_types_to_replace.iter().rev() { - editor.replace(path_type.syntax(), new_bounds.clone_for_update().syntax()); + editor.replace(path_type.syntax(), new_bounds.syntax()); } edit.add_file_edits(ctx.vfs_file_id(), editor); }, diff --git a/crates/ide-assists/src/handlers/replace_qualified_name_with_use.rs b/crates/ide-assists/src/handlers/replace_qualified_name_with_use.rs index 009fc077ce..cdf20586ef 100644 --- a/crates/ide-assists/src/handlers/replace_qualified_name_with_use.rs +++ b/crates/ide-assists/src/handlers/replace_qualified_name_with_use.rs @@ -102,7 +102,7 @@ fn target_path(ctx: &AssistContext<'_>, mut original_path: ast::Path) -> Option< } match ctx.sema.resolve_path(&original_path)? { - PathResolution::Def(ModuleDef::Variant(_)) if on_first => original_path.qualifier(), + PathResolution::Def(ModuleDef::EnumVariant(_)) if on_first => original_path.qualifier(), PathResolution::Def(def) if def.as_assoc_item(ctx.db()).is_some() => { on_first.then_some(original_path.qualifier()?) } diff --git a/crates/ide-assists/src/handlers/toggle_macro_delimiter.rs b/crates/ide-assists/src/handlers/toggle_macro_delimiter.rs index 60b0797f02..15143575e7 100644 --- a/crates/ide-assists/src/handlers/toggle_macro_delimiter.rs +++ b/crates/ide-assists/src/handlers/toggle_macro_delimiter.rs @@ -1,6 +1,7 @@ use ide_db::assists::AssistId; use syntax::{ - AstNode, SyntaxToken, T, + AstNode, SyntaxKind, SyntaxToken, T, + algo::{previous_non_trivia_token, skip_trivia_token}, ast::{self, syntax_factory::SyntaxFactory}, }; @@ -36,15 +37,18 @@ pub(crate) fn toggle_macro_delimiter(acc: &mut Assists, ctx: &AssistContext<'_>) RCur, } - let makro = ctx.find_node_at_offset::<ast::MacroCall>()?; + let token_tree = ctx.find_node_at_offset::<ast::TokenTree>()?; let cursor_offset = ctx.offset(); - let semicolon = macro_semicolon(&makro); - let token_tree = makro.token_tree()?; + let semicolon = macro_semicolon(&token_tree); let ltoken = token_tree.left_delimiter_token()?; let rtoken = token_tree.right_delimiter_token()?; + if !is_macro_call(&token_tree)? { + return None; + } + if !ltoken.text_range().contains(cursor_offset) && !rtoken.text_range().contains(cursor_offset) { return None; @@ -70,7 +74,7 @@ pub(crate) fn toggle_macro_delimiter(acc: &mut Assists, ctx: &AssistContext<'_>) token_tree.syntax().text_range(), |builder| { let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(makro.syntax()); + let mut editor = builder.make_editor(token_tree.syntax()); match token { MacroDelims::LPar | MacroDelims::RPar => { @@ -102,12 +106,21 @@ pub(crate) fn toggle_macro_delimiter(acc: &mut Assists, ctx: &AssistContext<'_>) ) } -fn macro_semicolon(makro: &ast::MacroCall) -> Option<SyntaxToken> { - makro.semicolon_token().or_else(|| { - let macro_expr = ast::MacroExpr::cast(makro.syntax().parent()?)?; - let expr_stmt = ast::ExprStmt::cast(macro_expr.syntax().parent()?)?; - expr_stmt.semicolon_token() - }) +fn is_macro_call(token_tree: &ast::TokenTree) -> Option<bool> { + let parent = token_tree.syntax().parent()?; + if ast::MacroCall::can_cast(parent.kind()) { + return Some(true); + } + + let token_tree = ast::TokenTree::cast(parent)?; + let prev = previous_non_trivia_token(token_tree.syntax().clone())?; + let prev_prev = previous_non_trivia_token(prev.clone())?; + Some(prev.kind() == T![!] && prev_prev.kind() == SyntaxKind::IDENT) +} + +fn macro_semicolon(token_tree: &ast::TokenTree) -> Option<SyntaxToken> { + let next_token = token_tree.syntax().last_token()?.next_token()?; + skip_trivia_token(next_token, syntax::Direction::Next).filter(|it| it.kind() == T![;]) } fn needs_semicolon(tt: ast::TokenTree) -> bool { @@ -402,10 +415,9 @@ prt!{(3 + 5)} ) } - // FIXME @alibektas : Inner macro_call is not seen as such. So this doesn't work. #[test] fn test_nested_macros() { - check_assist_not_applicable( + check_assist( toggle_macro_delimiter, r#" macro_rules! prt { @@ -420,7 +432,22 @@ macro_rules! abc { }}; } -prt!{abc!($03 + 5)}; +prt!{abc!$0(3 + 5)}; +"#, + r#" +macro_rules! prt { + ($e:expr) => {{ + println!("{}", stringify!{$e}); + }}; +} + +macro_rules! abc { + ($e:expr) => {{ + println!("{}", stringify!{$e}); + }}; +} + +prt!{abc!{3 + 5}}; "#, ) } diff --git a/crates/ide-assists/src/handlers/unmerge_match_arm.rs b/crates/ide-assists/src/handlers/unmerge_match_arm.rs index 7b0f2dc65a..c4c03d3e35 100644 --- a/crates/ide-assists/src/handlers/unmerge_match_arm.rs +++ b/crates/ide-assists/src/handlers/unmerge_match_arm.rs @@ -38,11 +38,18 @@ pub(crate) fn unmerge_match_arm(acc: &mut Assists, ctx: &AssistContext<'_>) -> O } let match_arm = ast::MatchArm::cast(or_pat.syntax().parent()?)?; let match_arm_body = match_arm.expr()?; + let pats_after = pipe_token + .siblings_with_tokens(Direction::Next) + .filter_map(|it| ast::Pat::cast(it.into_node()?)) + .collect::<Vec<_>>(); // We don't need to check for leading pipe because it is directly under `MatchArm` // without `OrPat`. let new_parent = match_arm.syntax().parent()?; + if pats_after.is_empty() { + return None; + } acc.add( AssistId::refactor_rewrite("unmerge_match_arm"), @@ -51,10 +58,6 @@ pub(crate) fn unmerge_match_arm(acc: &mut Assists, ctx: &AssistContext<'_>) -> O |edit| { let make = SyntaxFactory::with_mappings(); let mut editor = edit.make_editor(&new_parent); - let pats_after = pipe_token - .siblings_with_tokens(Direction::Next) - .filter_map(|it| ast::Pat::cast(it.into_node()?)) - .collect::<Vec<_>>(); // It is guaranteed that `pats_after` has at least one element let new_pat = if pats_after.len() == 1 { pats_after[0].clone() @@ -191,6 +194,21 @@ fn main() { } #[test] + fn unmerge_match_arm_trailing_pipe() { + check_assist_not_applicable( + unmerge_match_arm, + r#" +fn main() { + let y = match 0 { + 0 |$0 => { 1i32 } + 1 => { 2i32 } + }; +} +"#, + ); + } + + #[test] fn unmerge_match_arm_multiple_pipes() { check_assist( unmerge_match_arm, diff --git a/crates/ide-assists/src/handlers/unqualify_method_call.rs b/crates/ide-assists/src/handlers/unqualify_method_call.rs index a58b1da621..ef395791e2 100644 --- a/crates/ide-assists/src/handlers/unqualify_method_call.rs +++ b/crates/ide-assists/src/handlers/unqualify_method_call.rs @@ -1,8 +1,5 @@ use hir::AsAssocItem; -use syntax::{ - TextRange, - ast::{self, AstNode, HasArgList, prec::ExprPrecedence}, -}; +use syntax::ast::{self, AstNode, HasArgList, prec::ExprPrecedence, syntax_factory::SyntaxFactory}; use crate::{AssistContext, AssistId, Assists}; @@ -36,10 +33,7 @@ pub(crate) fn unqualify_method_call(acc: &mut Assists, ctx: &AssistContext<'_>) } let args = call.arg_list()?; - let l_paren = args.l_paren_token()?; - let mut args_iter = args.args(); - let first_arg = args_iter.next()?; - let second_arg = args_iter.next(); + let first_arg = args.args().next()?; let qualifier = path.qualifier()?; let method_name = path.segment()?.name_ref()?; @@ -51,43 +45,33 @@ pub(crate) fn unqualify_method_call(acc: &mut Assists, ctx: &AssistContext<'_>) return None; } - // `core::ops::Add::add(` -> `` - let delete_path = - TextRange::new(path.syntax().text_range().start(), l_paren.text_range().end()); - - // Parens around `expr` if needed - let parens = first_arg.precedence().needs_parentheses_in(ExprPrecedence::Postfix).then(|| { - let range = first_arg.syntax().text_range(); - (range.start(), range.end()) - }); - - // `, ` -> `.add(` - let replace_comma = TextRange::new( - first_arg.syntax().text_range().end(), - second_arg - .map(|a| a.syntax().text_range().start()) - .unwrap_or_else(|| first_arg.syntax().text_range().end()), - ); - acc.add( AssistId::refactor_rewrite("unqualify_method_call"), "Unqualify method call", call.syntax().text_range(), - |edit| { - edit.delete(delete_path); - if let Some((open, close)) = parens { - edit.insert(open, "("); - edit.insert(close, ")"); - } - edit.replace(replace_comma, format!(".{method_name}(")); + |builder| { + let make = SyntaxFactory::with_mappings(); + let mut editor = builder.make_editor(call.syntax()); + + let new_arg_list = make.arg_list(args.args().skip(1)); + let receiver = if first_arg.precedence().needs_parentheses_in(ExprPrecedence::Postfix) { + ast::Expr::from(make.expr_paren(first_arg.clone())) + } else { + first_arg.clone() + }; + let method_call = make.expr_method_call(receiver, method_name, new_arg_list); + + editor.replace(call.syntax(), method_call.syntax()); if let Some(fun) = fun.as_assoc_item(ctx.db()) && let Some(trait_) = fun.container_or_implemented_trait(ctx.db()) && !scope.can_use_trait_methods(trait_) { - // Only add an import for trait methods that are not already imported. - add_import(qualifier, ctx, edit); + add_import(qualifier, ctx, &make, &mut editor); } + + editor.add_mappings(make.finish_with_mappings()); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } @@ -95,7 +79,8 @@ pub(crate) fn unqualify_method_call(acc: &mut Assists, ctx: &AssistContext<'_>) fn add_import( qualifier: ast::Path, ctx: &AssistContext<'_>, - edit: &mut ide_db::source_change::SourceChangeBuilder, + make: &SyntaxFactory, + editor: &mut syntax::syntax_editor::SyntaxEditor, ) { if let Some(path_segment) = qualifier.segment() { // for `<i32 as std::ops::Add>` @@ -122,8 +107,13 @@ fn add_import( ); if let Some(scope) = scope { - let scope = edit.make_import_scope_mut(scope); - ide_db::imports::insert_use::insert_use(&scope, import, &ctx.config.insert_use); + ide_db::imports::insert_use::insert_use_with_editor( + &scope, + import, + &ctx.config.insert_use, + editor, + make, + ); } } } diff --git a/crates/ide-assists/src/handlers/unwrap_block.rs b/crates/ide-assists/src/handlers/unwrap_block.rs index e4f5e3523b..e029d7884f 100644 --- a/crates/ide-assists/src/handlers/unwrap_block.rs +++ b/crates/ide-assists/src/handlers/unwrap_block.rs @@ -45,6 +45,7 @@ pub(crate) fn unwrap_block(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option ast::LoopExpr(it) => it.syntax().clone(), ast::WhileExpr(it) => it.syntax().clone(), ast::MatchArm(it) => it.parent_match().syntax().clone(), + ast::LetElse(it) => it.syntax().parent()?, ast::LetStmt(it) => { replacement = wrap_let(&it, replacement); prefer_container = Some(it.syntax().clone()); @@ -557,6 +558,40 @@ fn main() { } #[test] + fn simple_let_else() { + check_assist( + unwrap_block, + r#" +fn main() { + let Some(2) = None else {$0 + return; + }; +} +"#, + r#" +fn main() { + return; +} +"#, + ); + check_assist( + unwrap_block, + r#" +fn main() { + let Some(2) = None else {$0 + return + }; +} +"#, + r#" +fn main() { + return +} +"#, + ); + } + + #[test] fn unwrap_match_arm() { check_assist( unwrap_block, diff --git a/crates/ide-assists/src/handlers/unwrap_tuple.rs b/crates/ide-assists/src/handlers/unwrap_tuple.rs index 46f3e85e12..e03274bbb3 100644 --- a/crates/ide-assists/src/handlers/unwrap_tuple.rs +++ b/crates/ide-assists/src/handlers/unwrap_tuple.rs @@ -1,3 +1,6 @@ +use std::iter; + +use either::Either; use syntax::{ AstNode, T, ast::{self, edit::AstNodeEdit}, @@ -24,11 +27,16 @@ use crate::{AssistContext, AssistId, Assists}; // ``` pub(crate) fn unwrap_tuple(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let let_kw = ctx.find_token_syntax_at_offset(T![let])?; - let let_stmt = let_kw.parent().and_then(ast::LetStmt::cast)?; - let indent_level = let_stmt.indent_level().0 as usize; - let pat = let_stmt.pat()?; - let ty = let_stmt.ty(); - let init = let_stmt.initializer()?; + let let_stmt = let_kw.parent().and_then(Either::<ast::LetStmt, ast::LetExpr>::cast)?; + let mut indent_level = let_stmt.indent_level(); + let pat = either::for_both!(&let_stmt, it => it.pat())?; + let (ty, init, prefix, suffix) = match &let_stmt { + Either::Left(let_stmt) => (let_stmt.ty(), let_stmt.initializer()?, "", ";"), + Either::Right(let_expr) => { + indent_level += 1; + (None, let_expr.expr()?, "&& ", "") + } + }; // This only applies for tuple patterns, types, and initializers. let tuple_pat = match pat { @@ -60,25 +68,19 @@ pub(crate) fn unwrap_tuple(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option "Unwrap tuple", let_kw.text_range(), |edit| { - let indents = " ".repeat(indent_level); + let mut decls = String::new(); // If there is an ascribed type, insert that type for each declaration, // otherwise, omit that type. - if let Some(tys) = tuple_ty { - let mut zipped_decls = String::new(); - for (pat, ty, expr) in - itertools::izip!(tuple_pat.fields(), tys.fields(), tuple_init.fields()) - { - zipped_decls.push_str(&format!("{indents}let {pat}: {ty} = {expr};\n")) - } - edit.replace(parent.text_range(), zipped_decls.trim()); - } else { - let mut zipped_decls = String::new(); - for (pat, expr) in itertools::izip!(tuple_pat.fields(), tuple_init.fields()) { - zipped_decls.push_str(&format!("{indents}let {pat} = {expr};\n")); - } - edit.replace(parent.text_range(), zipped_decls.trim()); + let tys = + tuple_ty.into_iter().flat_map(|it| it.fields().map(Some)).chain(iter::repeat(None)); + for (pat, ty, expr) in itertools::izip!(tuple_pat.fields(), tys, tuple_init.fields()) { + let ty = ty.map_or_else(String::new, |ty| format!(": {ty}")); + decls.push_str(&format!("{prefix}let {pat}{ty} = {expr}{suffix}\n{indent_level}")) } + + let s = decls.trim(); + edit.replace(parent.text_range(), s.strip_prefix(prefix).unwrap_or(s)); }, ) } @@ -124,6 +126,28 @@ fn main() { } #[test] + fn unwrap_tuples_in_let_expr() { + check_assist( + unwrap_tuple, + r#" +fn main() { + if $0let (foo, bar) = ("Foo", "Bar") { + code(); + } +} +"#, + r#" +fn main() { + if let foo = "Foo" + && let bar = "Bar" { + code(); + } +} +"#, + ); + } + + #[test] fn unwrap_tuple_with_types() { check_assist( unwrap_tuple, diff --git a/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs b/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs index 7d5740b748..36df4af31d 100644 --- a/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs +++ b/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs @@ -2,7 +2,7 @@ use ide_db::source_change::SourceChangeBuilder; use itertools::Itertools; use syntax::{ NodeOrToken, SyntaxToken, T, TextRange, algo, - ast::{self, AstNode, make, syntax_factory::SyntaxFactory}, + ast::{self, AstNode, edit::AstNodeEdit, make, syntax_factory::SyntaxFactory}, }; use crate::{AssistContext, AssistId, Assists}; @@ -27,7 +27,7 @@ use crate::{AssistContext, AssistId, Assists}; enum WrapUnwrapOption { WrapDerive { derive: TextRange, attr: ast::Attr }, - WrapAttr(ast::Attr), + WrapAttr(Vec<ast::Attr>), } /// Attempts to get the derive attribute from a derive attribute list @@ -102,9 +102,9 @@ fn attempt_get_derive(attr: ast::Attr, ident: SyntaxToken) -> WrapUnwrapOption { if ident.parent().and_then(ast::TokenTree::cast).is_none() || !attr.simple_name().map(|v| v.eq("derive")).unwrap_or_default() { - WrapUnwrapOption::WrapAttr(attr) + WrapUnwrapOption::WrapAttr(vec![attr]) } else { - attempt_attr().unwrap_or(WrapUnwrapOption::WrapAttr(attr)) + attempt_attr().unwrap_or_else(|| WrapUnwrapOption::WrapAttr(vec![attr])) } } pub(crate) fn wrap_unwrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { @@ -118,13 +118,27 @@ pub(crate) fn wrap_unwrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_>) - Some(attempt_get_derive(attr, ident)) } - (Some(attr), _) => Some(WrapUnwrapOption::WrapAttr(attr)), + (Some(attr), _) => Some(WrapUnwrapOption::WrapAttr(vec![attr])), _ => None, } } else { let covering_element = ctx.covering_element(); match covering_element { - NodeOrToken::Node(node) => ast::Attr::cast(node).map(WrapUnwrapOption::WrapAttr), + NodeOrToken::Node(node) => { + if let Some(attr) = ast::Attr::cast(node.clone()) { + Some(WrapUnwrapOption::WrapAttr(vec![attr])) + } else { + let attrs = node + .children() + .filter(|it| it.text_range().intersect(ctx.selection_trimmed()).is_some()) + .map(ast::Attr::cast) + .collect::<Option<Vec<_>>>()?; + if attrs.is_empty() { + return None; + } + Some(WrapUnwrapOption::WrapAttr(attrs)) + } + } NodeOrToken::Token(ident) if ident.kind() == syntax::T![ident] => { let attr = ident.parent_ancestors().find_map(ast::Attr::cast)?; Some(attempt_get_derive(attr, ident)) @@ -133,10 +147,12 @@ pub(crate) fn wrap_unwrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_>) - } }?; match option { - WrapUnwrapOption::WrapAttr(attr) if attr.simple_name().as_deref() == Some("cfg_attr") => { - unwrap_cfg_attr(acc, attr) - } - WrapUnwrapOption::WrapAttr(attr) => wrap_cfg_attr(acc, ctx, attr), + WrapUnwrapOption::WrapAttr(attrs) => match &attrs[..] { + [attr] if attr.simple_name().as_deref() == Some("cfg_attr") => { + unwrap_cfg_attr(acc, attrs.into_iter().next().unwrap()) + } + _ => wrap_cfg_attrs(acc, ctx, attrs), + }, WrapUnwrapOption::WrapDerive { derive, attr } => wrap_derive(acc, ctx, attr, derive), } } @@ -220,40 +236,51 @@ fn wrap_derive( ); Some(()) } -fn wrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_>, attr: ast::Attr) -> Option<()> { - let range = attr.syntax().text_range(); - let path = attr.path()?; +fn wrap_cfg_attrs(acc: &mut Assists, ctx: &AssistContext<'_>, attrs: Vec<ast::Attr>) -> Option<()> { + let (first_attr, last_attr) = (attrs.first()?, attrs.last()?); + let range = first_attr.syntax().text_range().cover(last_attr.syntax().text_range()); + let path_attrs = + attrs.iter().map(|attr| Some((attr.path()?, attr.clone()))).collect::<Option<Vec<_>>>()?; let handle_source_change = |edit: &mut SourceChangeBuilder| { let make = SyntaxFactory::with_mappings(); - let mut editor = edit.make_editor(attr.syntax()); - let mut raw_tokens = - vec![NodeOrToken::Token(make.token(T![,])), NodeOrToken::Token(make.whitespace(" "))]; - path.syntax().descendants_with_tokens().for_each(|it| { - if let NodeOrToken::Token(token) = it { - raw_tokens.push(NodeOrToken::Token(token)); - } - }); - if let Some(meta) = attr.meta() { - if let (Some(eq), Some(expr)) = (meta.eq_token(), meta.expr()) { - raw_tokens.push(NodeOrToken::Token(make.whitespace(" "))); - raw_tokens.push(NodeOrToken::Token(eq)); - raw_tokens.push(NodeOrToken::Token(make.whitespace(" "))); + let mut editor = edit.make_editor(first_attr.syntax()); + let mut raw_tokens = vec![]; + for (path, attr) in path_attrs { + raw_tokens.extend([ + NodeOrToken::Token(make.token(T![,])), + NodeOrToken::Token(make.whitespace(" ")), + ]); + path.syntax().descendants_with_tokens().for_each(|it| { + if let NodeOrToken::Token(token) = it { + raw_tokens.push(NodeOrToken::Token(token)); + } + }); + if let Some(meta) = attr.meta() { + if let (Some(eq), Some(expr)) = (meta.eq_token(), meta.expr()) { + raw_tokens.push(NodeOrToken::Token(make.whitespace(" "))); + raw_tokens.push(NodeOrToken::Token(eq)); + raw_tokens.push(NodeOrToken::Token(make.whitespace(" "))); - expr.syntax().descendants_with_tokens().for_each(|it| { - if let NodeOrToken::Token(token) = it { - raw_tokens.push(NodeOrToken::Token(token)); - } - }); - } else if let Some(tt) = meta.token_tree() { - raw_tokens.extend(tt.token_trees_and_tokens()); + expr.syntax().descendants_with_tokens().for_each(|it| { + if let NodeOrToken::Token(token) = it { + raw_tokens.push(NodeOrToken::Token(token)); + } + }); + } else if let Some(tt) = meta.token_tree() { + raw_tokens.extend(tt.token_trees_and_tokens()); + } } } let meta = make.meta_token_tree(make.ident_path("cfg_attr"), make.token_tree(T!['('], raw_tokens)); - let cfg_attr = - if attr.excl_token().is_some() { make.attr_inner(meta) } else { make.attr_outer(meta) }; + let cfg_attr = if first_attr.excl_token().is_some() { + make.attr_inner(meta) + } else { + make.attr_outer(meta) + }; - editor.replace(attr.syntax(), cfg_attr.syntax()); + let syntax_range = first_attr.syntax().clone().into()..=last_attr.syntax().clone().into(); + editor.replace_all(syntax_range, vec![cfg_attr.syntax().clone().into()]); if let Some(snippet_cap) = ctx.config.snippet_cap && let Some(first_meta) = @@ -332,7 +359,8 @@ fn unwrap_cfg_attr(acc: &mut Assists, attr: ast::Attr) -> Option<()> { return None; } let handle_source_change = |f: &mut SourceChangeBuilder| { - let inner_attrs = inner_attrs.iter().map(|it| it.to_string()).join("\n"); + let inner_attrs = + inner_attrs.iter().map(|it| it.to_string()).join(&format!("\n{}", attr.indent_level())); f.replace(range, inner_attrs); }; acc.add( @@ -414,6 +442,42 @@ mod tests { } "#, ); + check_assist( + wrap_unwrap_cfg_attr, + r#" + pub struct Test { + #[other_attr] + $0#[foo] + #[bar]$0 + #[other_attr] + test: u32, + } + "#, + r#" + pub struct Test { + #[other_attr] + #[cfg_attr($0, foo, bar)] + #[other_attr] + test: u32, + } + "#, + ); + check_assist( + wrap_unwrap_cfg_attr, + r#" + pub struct Test { + #[cfg_attr(debug_assertions$0, foo, bar)] + test: u32, + } + "#, + r#" + pub struct Test { + #[foo] + #[bar] + test: u32, + } + "#, + ); } #[test] fn to_from_eq_attr() { diff --git a/crates/ide-assists/src/tests/generated.rs b/crates/ide-assists/src/tests/generated.rs index 3040509000..66d5cf834f 100644 --- a/crates/ide-assists/src/tests/generated.rs +++ b/crates/ide-assists/src/tests/generated.rs @@ -507,8 +507,8 @@ fn main() { r#####" fn main() { let x = vec![1, 2, 3]; - let mut tmp = x.into_iter(); - while let Some(v) = tmp.next() { + let mut iter = x.into_iter(); + while let Some(v) = iter.next() { let y = v * 2; }; } @@ -2282,7 +2282,7 @@ macro_rules! const_maker { }; } -trait ${0:NewTrait}<const N: usize> { +trait ${0:Create}<const N: usize> { // Used as an associated constant. const CONST_ASSOC: usize = N * 4; @@ -2291,7 +2291,7 @@ trait ${0:NewTrait}<const N: usize> { const_maker! {i32, 7} } -impl<const N: usize> ${0:NewTrait}<N> for Foo<N> { +impl<const N: usize> ${0:Create}<N> for Foo<N> { // Used as an associated constant. const CONST_ASSOC: usize = N * 4; diff --git a/crates/ide-assists/src/utils.rs b/crates/ide-assists/src/utils.rs index 4b8c193057..10057f8681 100644 --- a/crates/ide-assists/src/utils.rs +++ b/crates/ide-assists/src/utils.rs @@ -4,8 +4,7 @@ use std::slice; pub(crate) use gen_trait_fn_body::gen_trait_fn_body; use hir::{ - DisplayTarget, HasAttrs as HirHasAttrs, HirDisplay, InFile, ModuleDef, PathResolution, - Semantics, + HasAttrs as HirHasAttrs, HirDisplay, InFile, ModuleDef, PathResolution, Semantics, db::{ExpandDatabase, HirDatabase}, }; use ide_db::{ @@ -15,6 +14,7 @@ use ide_db::{ path_transform::PathTransform, syntax_helpers::{node_ext::preorder_expr, prettify_macro_expansion}, }; +use itertools::Itertools; use stdx::format_to; use syntax::{ AstNode, AstToken, Direction, NodeOrToken, SourceFile, @@ -25,6 +25,7 @@ use syntax::{ edit::{AstNodeEdit, IndentLevel}, edit_in_place::AttrsOwnerEdit, make, + prec::ExprPrecedence, syntax_factory::SyntaxFactory, }, syntax_editor::{Element, Removable, SyntaxEditor}, @@ -86,17 +87,31 @@ pub fn extract_trivial_expression(block_expr: &ast::BlockExpr) -> Option<ast::Ex None } -pub(crate) fn wrap_block(expr: &ast::Expr) -> ast::BlockExpr { +pub(crate) fn wrap_block(expr: &ast::Expr, make: &SyntaxFactory) -> ast::BlockExpr { if let ast::Expr::BlockExpr(block) = expr && let Some(first) = block.syntax().first_token() && first.kind() == T!['{'] { block.reset_indent() } else { - make::block_expr(None, Some(expr.reset_indent().indent(1.into()))) + make.block_expr(None, Some(expr.reset_indent().indent(1.into()))) } } +pub(crate) fn wrap_paren(expr: ast::Expr, make: &SyntaxFactory, prec: ExprPrecedence) -> ast::Expr { + if expr.precedence().needs_parentheses_in(prec) { make.expr_paren(expr).into() } else { expr } +} + +pub(crate) fn wrap_paren_in_call(expr: ast::Expr, make: &SyntaxFactory) -> ast::Expr { + if needs_parens_in_call(make, &expr) { make.expr_paren(expr).into() } else { expr } +} + +fn needs_parens_in_call(make: &SyntaxFactory, param: &ast::Expr) -> bool { + let call = make.expr_call(make.expr_unit(), make.arg_list(Vec::new())); + let callable = call.expr().expect("invalid make call"); + param.needs_parens_in_place_of(call.syntax(), callable.syntax()) +} + /// This is a method with a heuristics to support test methods annotated with custom test annotations, such as /// `#[test_case(...)]`, `#[tokio::test]` and similar. /// Also a regular `#[test]` annotation is supported. @@ -188,6 +203,9 @@ pub fn filter_assoc_items( /// [`filter_assoc_items()`]), clones each item for update and applies path transformation to it, /// then inserts into `impl_`. Returns the modified `impl_` and the first associated item that got /// inserted. +/// +/// Legacy: prefer [`add_trait_assoc_items_to_impl_with_factory`] when a [`SyntaxFactory`] is +/// available. #[must_use] pub fn add_trait_assoc_items_to_impl( sema: &Semantics<'_, RootDatabase>, @@ -233,15 +251,79 @@ pub fn add_trait_assoc_items_to_impl( .filter_map(|item| match item { ast::AssocItem::Fn(fn_) if fn_.body().is_none() => { let fn_ = fn_.clone_subtree(); - let new_body = &make::block_expr( - None, - Some(match config.expr_fill_default { - ExprFillDefaultMode::Todo => make::ext::expr_todo(), - ExprFillDefaultMode::Underscore => make::ext::expr_underscore(), - ExprFillDefaultMode::Default => make::ext::expr_todo(), - }), - ); - let new_body = AstNodeEdit::indent(new_body, IndentLevel::single()); + let new_body = make::block_expr(None, Some(expr_fill_default(config))); + let mut fn_editor = SyntaxEditor::new(fn_.syntax().clone()); + fn_.replace_or_insert_body(&mut fn_editor, new_body.clone_for_update()); + let new_fn_ = fn_editor.finish().new_root().clone(); + ast::AssocItem::cast(new_fn_) + } + ast::AssocItem::TypeAlias(type_alias) => { + let type_alias = type_alias.clone_subtree(); + if let Some(type_bound_list) = type_alias.type_bound_list() { + let mut type_alias_editor = SyntaxEditor::new(type_alias.syntax().clone()); + type_bound_list.remove(&mut type_alias_editor); + let type_alias = type_alias_editor.finish().new_root().clone(); + ast::AssocItem::cast(type_alias) + } else { + Some(ast::AssocItem::TypeAlias(type_alias)) + } + } + item => Some(item), + }) + .map(|item| AstNodeEdit::indent(&item, new_indent_level)) + .collect() +} + +/// [`SyntaxFactory`]-based variant of [`add_trait_assoc_items_to_impl`]. +#[must_use] +pub fn add_trait_assoc_items_to_impl_with_factory( + make: &SyntaxFactory, + sema: &Semantics<'_, RootDatabase>, + config: &AssistConfig, + original_items: &[InFile<ast::AssocItem>], + trait_: hir::Trait, + impl_: &ast::Impl, + target_scope: &hir::SemanticsScope<'_>, +) -> Vec<ast::AssocItem> { + let new_indent_level = IndentLevel::from_node(impl_.syntax()) + 1; + original_items + .iter() + .map(|InFile { file_id, value: original_item }| { + let mut cloned_item = { + if let Some(macro_file) = file_id.macro_file() { + let span_map = sema.db.expansion_span_map(macro_file); + let item_prettified = prettify_macro_expansion( + sema.db, + original_item.syntax().clone(), + &span_map, + target_scope.krate().into(), + ); + if let Some(formatted) = ast::AssocItem::cast(item_prettified) { + return formatted; + } else { + stdx::never!("formatted `AssocItem` could not be cast back to `AssocItem`"); + } + } + original_item + } + .reset_indent(); + + if let Some(source_scope) = sema.scope(original_item.syntax()) { + let transform = + PathTransform::trait_impl(target_scope, &source_scope, trait_, impl_.clone()); + cloned_item = ast::AssocItem::cast(transform.apply(cloned_item.syntax())).unwrap(); + } + cloned_item.remove_attrs_and_docs(); + cloned_item + }) + .filter_map(|item| match item { + ast::AssocItem::Fn(fn_) if fn_.body().is_none() => { + let fn_ = fn_.clone_subtree(); + let fill_expr: ast::Expr = match config.expr_fill_default { + ExprFillDefaultMode::Todo | ExprFillDefaultMode::Default => make.expr_todo(), + ExprFillDefaultMode::Underscore => make.expr_underscore().into(), + }; + let new_body = make.block_expr(None::<ast::Stmt>, Some(fill_expr)); let mut fn_editor = SyntaxEditor::new(fn_.syntax().clone()); fn_.replace_or_insert_body(&mut fn_editor, new_body); let new_fn_ = fn_editor.finish().new_root().clone(); @@ -275,11 +357,6 @@ pub(crate) fn invert_boolean_expression(make: &SyntaxFactory, expr: ast::Expr) - invert_special_case(make, &expr).unwrap_or_else(|| make.expr_prefix(T![!], expr).into()) } -// FIXME: Migrate usages of this function to the above function and remove this. -pub(crate) fn invert_boolean_expression_legacy(expr: ast::Expr) -> ast::Expr { - invert_special_case_legacy(&expr).unwrap_or_else(|| make::expr_prefix(T![!], expr).into()) -} - fn invert_special_case(make: &SyntaxFactory, expr: &ast::Expr) -> Option<ast::Expr> { match expr { ast::Expr::BinExpr(bin) => { @@ -343,62 +420,11 @@ fn invert_special_case(make: &SyntaxFactory, expr: &ast::Expr) -> Option<ast::Ex } } -fn invert_special_case_legacy(expr: &ast::Expr) -> Option<ast::Expr> { - match expr { - ast::Expr::BinExpr(bin) => { - let bin = bin.clone_subtree(); - let op_token = bin.op_token()?; - let rev_token = match op_token.kind() { - T![==] => T![!=], - T![!=] => T![==], - T![<] => T![>=], - T![<=] => T![>], - T![>] => T![<=], - T![>=] => T![<], - // Parenthesize other expressions before prefixing `!` - _ => { - return Some( - make::expr_prefix(T![!], make::expr_paren(expr.clone()).into()).into(), - ); - } - }; - let mut bin_editor = SyntaxEditor::new(bin.syntax().clone()); - bin_editor.replace(op_token, make::token(rev_token)); - ast::Expr::cast(bin_editor.finish().new_root().clone()) - } - ast::Expr::MethodCallExpr(mce) => { - let receiver = mce.receiver()?; - let method = mce.name_ref()?; - let arg_list = mce.arg_list()?; - - let method = match method.text().as_str() { - "is_some" => "is_none", - "is_none" => "is_some", - "is_ok" => "is_err", - "is_err" => "is_ok", - _ => return None, - }; - Some(make::expr_method_call(receiver, make::name_ref(method), arg_list).into()) - } - ast::Expr::PrefixExpr(pe) if pe.op_kind()? == ast::UnaryOp::Not => match pe.expr()? { - ast::Expr::ParenExpr(parexpr) => parexpr.expr(), - _ => pe.expr(), - }, - ast::Expr::Literal(lit) => match lit.kind() { - ast::LiteralKind::Bool(b) => match b { - true => Some(ast::Expr::Literal(make::expr_literal("false"))), - false => Some(ast::Expr::Literal(make::expr_literal("true"))), - }, - _ => None, - }, - _ => None, - } -} - pub(crate) fn insert_attributes( before: impl Element, edit: &mut SyntaxEditor, attrs: impl IntoIterator<Item = ast::Attr>, + make: &SyntaxFactory, ) { let mut attrs = attrs.into_iter().peekable(); if attrs.peek().is_none() { @@ -410,9 +436,7 @@ pub(crate) fn insert_attributes( edit.insert_all( syntax::syntax_editor::Position::before(elem), attrs - .flat_map(|attr| { - [attr.syntax().clone().into(), make::tokens::whitespace(&whitespace).into()] - }) + .flat_map(|attr| [attr.syntax().clone().into(), make.whitespace(&whitespace).into()]) .collect(), ); } @@ -508,6 +532,15 @@ fn check_pat_variant_nested_or_literal_with_depth( } } +pub(crate) fn expr_fill_default(config: &AssistConfig) -> ast::Expr { + let make = SyntaxFactory::without_mappings(); + match config.expr_fill_default { + ExprFillDefaultMode::Todo => make.expr_todo(), + ExprFillDefaultMode::Underscore => make.expr_underscore().into(), + ExprFillDefaultMode::Default => make.expr_todo(), + } +} + // Uses a syntax-driven approach to find any impl blocks for the struct that // exist within the module/file // @@ -596,29 +629,6 @@ pub(crate) fn generate_impl_text(adt: &ast::Adt, code: &str) -> String { generate_impl_text_inner(adt, None, true, code) } -/// Generates the surrounding `impl <trait> for Type { <code> }` including type -/// and lifetime parameters, with `<trait>` appended to `impl`'s generic parameters' bounds. -/// -/// This is useful for traits like `PartialEq`, since `impl<T> PartialEq for U<T>` often requires `T: PartialEq`. -// FIXME: migrate remaining uses to `generate_trait_impl` -#[allow(dead_code)] -pub(crate) fn generate_trait_impl_text(adt: &ast::Adt, trait_text: &str, code: &str) -> String { - generate_impl_text_inner(adt, Some(trait_text), true, code) -} - -/// Generates the surrounding `impl <trait> for Type { <code> }` including type -/// and lifetime parameters, with `impl`'s generic parameters' bounds kept as-is. -/// -/// This is useful for traits like `From<T>`, since `impl<T> From<T> for U<T>` doesn't require `T: From<T>`. -// FIXME: migrate remaining uses to `generate_trait_impl_intransitive` -pub(crate) fn generate_trait_impl_text_intransitive( - adt: &ast::Adt, - trait_text: &str, - code: &str, -) -> String { - generate_impl_text_inner(adt, Some(trait_text), false, code) -} - fn generate_impl_text_inner( adt: &ast::Adt, trait_text: Option<&str>, @@ -699,10 +709,15 @@ fn generate_impl_text_inner( /// Generates the corresponding `impl Type {}` including type and lifetime /// parameters. pub(crate) fn generate_impl_with_item( + make: &SyntaxFactory, adt: &ast::Adt, body: Option<ast::AssocItemList>, ) -> ast::Impl { - generate_impl_inner(false, adt, None, true, body) + generate_impl_inner_with_factory(make, false, adt, None, true, body) +} + +pub(crate) fn generate_impl_with_factory(make: &SyntaxFactory, adt: &ast::Adt) -> ast::Impl { + generate_impl_inner_with_factory(make, false, adt, None, true, None) } pub(crate) fn generate_impl(adt: &ast::Adt) -> ast::Impl { @@ -713,16 +728,34 @@ pub(crate) fn generate_impl(adt: &ast::Adt) -> ast::Impl { /// and lifetime parameters, with `<trait>` appended to `impl`'s generic parameters' bounds. /// /// This is useful for traits like `PartialEq`, since `impl<T> PartialEq for U<T>` often requires `T: PartialEq`. -pub(crate) fn generate_trait_impl(is_unsafe: bool, adt: &ast::Adt, trait_: ast::Type) -> ast::Impl { - generate_impl_inner(is_unsafe, adt, Some(trait_), true, None) +pub(crate) fn generate_trait_impl( + make: &SyntaxFactory, + is_unsafe: bool, + adt: &ast::Adt, + trait_: ast::Type, +) -> ast::Impl { + generate_impl_inner_with_factory(make, is_unsafe, adt, Some(trait_), true, None) } /// Generates the corresponding `impl <trait> for Type {}` including type /// and lifetime parameters, with `impl`'s generic parameters' bounds kept as-is. /// /// This is useful for traits like `From<T>`, since `impl<T> From<T> for U<T>` doesn't require `T: From<T>`. -pub(crate) fn generate_trait_impl_intransitive(adt: &ast::Adt, trait_: ast::Type) -> ast::Impl { - generate_impl_inner(false, adt, Some(trait_), false, None) +pub(crate) fn generate_trait_impl_intransitive( + make: &SyntaxFactory, + adt: &ast::Adt, + trait_: ast::Type, +) -> ast::Impl { + generate_impl_inner_with_factory(make, false, adt, Some(trait_), false, None) +} + +pub(crate) fn generate_trait_impl_intransitive_with_item( + make: &SyntaxFactory, + adt: &ast::Adt, + trait_: ast::Type, + body: ast::AssocItemList, +) -> ast::Impl { + generate_impl_inner_with_factory(make, false, adt, Some(trait_), false, Some(body)) } fn generate_impl_inner( @@ -766,6 +799,11 @@ fn generate_impl_inner( }); let generic_args = generic_params.as_ref().map(|params| params.to_generic_args().clone_for_update()); + let adt_assoc_bounds = trait_ + .as_ref() + .zip(generic_params.as_ref()) + .and_then(|(trait_, params)| generic_param_associated_bounds(adt, trait_, params)); + let ty = make::ty_path(make::ext::ident_path(&adt.name().unwrap().text())); let cfg_attrs = @@ -781,7 +819,7 @@ fn generate_impl_inner( false, trait_, ty, - None, + adt_assoc_bounds, adt.where_clause(), body, ), @@ -790,6 +828,167 @@ fn generate_impl_inner( .clone_for_update() } +fn generate_impl_inner_with_factory( + make: &SyntaxFactory, + is_unsafe: bool, + adt: &ast::Adt, + trait_: Option<ast::Type>, + trait_is_transitive: bool, + body: Option<ast::AssocItemList>, +) -> ast::Impl { + // Ensure lifetime params are before type & const params + let generic_params = adt.generic_param_list().map(|generic_params| { + let lifetime_params = + generic_params.lifetime_params().map(ast::GenericParam::LifetimeParam); + let ty_or_const_params = generic_params.type_or_const_params().filter_map(|param| { + let param = match param { + ast::TypeOrConstParam::Type(param) => { + // remove defaults since they can't be specified in impls + let mut bounds = + param.type_bound_list().map_or_else(Vec::new, |it| it.bounds().collect()); + if let Some(trait_) = &trait_ { + // Add the current trait to `bounds` if the trait is transitive, + // meaning `impl<T> Trait for U<T>` requires `T: Trait`. + if trait_is_transitive { + bounds.push(make.type_bound(trait_.clone())); + } + }; + // `{ty_param}: {bounds}` + let param = make.type_param(param.name()?, make.type_bound_list(bounds)); + ast::GenericParam::TypeParam(param) + } + ast::TypeOrConstParam::Const(param) => { + // remove defaults since they can't be specified in impls + let param = make.const_param(param.name()?, param.ty()?); + ast::GenericParam::ConstParam(param) + } + }; + Some(param) + }); + + make.generic_param_list(itertools::chain(lifetime_params, ty_or_const_params)) + }); + let generic_args = + generic_params.as_ref().map(|params| params.to_generic_args().clone_for_update()); + let adt_assoc_bounds = + trait_.as_ref().zip(generic_params.as_ref()).and_then(|(trait_, params)| { + generic_param_associated_bounds_with_factory(make, adt, trait_, params) + }); + + let ty: ast::Type = make.ty_path(make.ident_path(&adt.name().unwrap().text())).into(); + + let cfg_attrs = + adt.attrs().filter(|attr| attr.as_simple_call().is_some_and(|(name, _arg)| name == "cfg")); + match trait_ { + Some(trait_) => make.impl_trait( + cfg_attrs, + is_unsafe, + None, + None, + generic_params, + generic_args, + false, + trait_, + ty, + adt_assoc_bounds, + adt.where_clause(), + body, + ), + None => make.impl_(cfg_attrs, generic_params, generic_args, ty, adt.where_clause(), body), + } +} + +fn generic_param_associated_bounds( + adt: &ast::Adt, + trait_: &ast::Type, + generic_params: &ast::GenericParamList, +) -> Option<ast::WhereClause> { + let in_type_params = |name: &ast::NameRef| { + generic_params + .generic_params() + .filter_map(|param| match param { + ast::GenericParam::TypeParam(type_param) => type_param.name(), + _ => None, + }) + .any(|param| param.text() == name.text()) + }; + let adt_body = match adt { + ast::Adt::Enum(e) => e.variant_list().map(|it| it.syntax().clone()), + ast::Adt::Struct(s) => s.field_list().map(|it| it.syntax().clone()), + ast::Adt::Union(u) => u.record_field_list().map(|it| it.syntax().clone()), + }; + let mut trait_where_clause = adt_body + .into_iter() + .flat_map(|it| it.descendants()) + .filter_map(ast::Path::cast) + .filter_map(|path| { + let qualifier = path.qualifier()?.as_single_segment()?; + let qualifier = qualifier + .name_ref() + .or_else(|| match qualifier.type_anchor()?.ty()? { + ast::Type::PathType(path_type) => path_type.path()?.as_single_name_ref(), + _ => None, + }) + .filter(in_type_params)?; + Some((qualifier, path.segment()?.name_ref()?)) + }) + .map(|(qualifier, assoc_name)| { + let segments = [qualifier, assoc_name].map(make::path_segment); + let path = make::path_from_segments(segments, false); + let bounds = Some(make::type_bound(trait_.clone())); + make::where_pred(either::Either::Right(make::ty_path(path)), bounds) + }) + .unique_by(|it| it.syntax().to_string()) + .peekable(); + trait_where_clause.peek().is_some().then(|| make::where_clause(trait_where_clause)) +} + +fn generic_param_associated_bounds_with_factory( + make: &SyntaxFactory, + adt: &ast::Adt, + trait_: &ast::Type, + generic_params: &ast::GenericParamList, +) -> Option<ast::WhereClause> { + let in_type_params = |name: &ast::NameRef| { + generic_params + .generic_params() + .filter_map(|param| match param { + ast::GenericParam::TypeParam(type_param) => type_param.name(), + _ => None, + }) + .any(|param| param.text() == name.text()) + }; + let adt_body = match adt { + ast::Adt::Enum(e) => e.variant_list().map(|it| it.syntax().clone()), + ast::Adt::Struct(s) => s.field_list().map(|it| it.syntax().clone()), + ast::Adt::Union(u) => u.record_field_list().map(|it| it.syntax().clone()), + }; + let mut trait_where_clause = adt_body + .into_iter() + .flat_map(|it| it.descendants()) + .filter_map(ast::Path::cast) + .filter_map(|path| { + let qualifier = path.qualifier()?.as_single_segment()?; + let qualifier = qualifier + .name_ref() + .or_else(|| match qualifier.type_anchor()?.ty()? { + ast::Type::PathType(path_type) => path_type.path()?.as_single_name_ref(), + _ => None, + }) + .filter(in_type_params)?; + Some((qualifier, path.segment()?.name_ref()?)) + }) + .map(|(qualifier, assoc_name)| { + let segments = [qualifier, assoc_name].map(|nr| make.path_segment(nr)); + let path = make.path_from_segments(segments, false); + let bounds = [make.type_bound(trait_.clone())]; + make.where_pred(either::Either::Right(make.ty_path(path).into()), bounds) + }) + .unique_by(|it| it.syntax().to_string()) + .peekable(); + trait_where_clause.peek().is_some().then(|| make.where_clause(trait_where_clause)) +} + pub(crate) fn add_method_to_adt( builder: &mut SourceChangeBuilder, adt: &ast::Adt, @@ -836,13 +1035,12 @@ enum ReferenceConversionType { } impl<'db> ReferenceConversion<'db> { - pub(crate) fn convert_type( - &self, - db: &'db dyn HirDatabase, - display_target: DisplayTarget, - ) -> ast::Type { - let ty = match self.conversion { - ReferenceConversionType::Copy => self.ty.display(db, display_target).to_string(), + fn type_to_string(&self, db: &'db dyn HirDatabase, module: hir::Module) -> String { + match self.conversion { + ReferenceConversionType::Copy => self + .ty + .display_source_code(db, module.into(), true) + .unwrap_or_else(|_| "_".to_owned()), ReferenceConversionType::AsRefStr => "&str".to_owned(), ReferenceConversionType::AsRefSlice => { let type_argument_name = self @@ -850,8 +1048,8 @@ impl<'db> ReferenceConversion<'db> { .type_arguments() .next() .unwrap() - .display(db, display_target) - .to_string(); + .display_source_code(db, module.into(), true) + .unwrap_or_else(|_| "_".to_owned()); format!("&[{type_argument_name}]") } ReferenceConversionType::Dereferenced => { @@ -860,8 +1058,8 @@ impl<'db> ReferenceConversion<'db> { .type_arguments() .next() .unwrap() - .display(db, display_target) - .to_string(); + .display_source_code(db, module.into(), true) + .unwrap_or_else(|_| "_".to_owned()); format!("&{type_argument_name}") } ReferenceConversionType::Option => { @@ -870,37 +1068,56 @@ impl<'db> ReferenceConversion<'db> { .type_arguments() .next() .unwrap() - .display(db, display_target) - .to_string(); + .display_source_code(db, module.into(), true) + .unwrap_or_else(|_| "_".to_owned()); format!("Option<&{type_argument_name}>") } ReferenceConversionType::Result => { let mut type_arguments = self.ty.type_arguments(); - let first_type_argument_name = - type_arguments.next().unwrap().display(db, display_target).to_string(); - let second_type_argument_name = - type_arguments.next().unwrap().display(db, display_target).to_string(); + let first_type_argument_name = type_arguments + .next() + .unwrap() + .display_source_code(db, module.into(), true) + .unwrap_or_else(|_| "_".to_owned()); + let second_type_argument_name = type_arguments + .next() + .unwrap() + .display_source_code(db, module.into(), true) + .unwrap_or_else(|_| "_".to_owned()); format!("Result<&{first_type_argument_name}, &{second_type_argument_name}>") } - }; + } + } + pub(crate) fn convert_type(&self, db: &'db dyn HirDatabase, module: hir::Module) -> ast::Type { + let ty = self.type_to_string(db, module); make::ty(&ty) } - pub(crate) fn getter(&self, field_name: String) -> ast::Expr { - let expr = make::expr_field(make::ext::expr_self(), &field_name); + pub(crate) fn convert_type_with_factory( + &self, + make: &SyntaxFactory, + db: &'db dyn HirDatabase, + module: hir::Module, + ) -> ast::Type { + let ty = self.type_to_string(db, module); + make.ty(&ty) + } + + pub(crate) fn getter(&self, make: &SyntaxFactory, field_name: String) -> ast::Expr { + let expr = make.expr_field(make.expr_self(), &field_name); match self.conversion { - ReferenceConversionType::Copy => expr, + ReferenceConversionType::Copy => expr.into(), ReferenceConversionType::AsRefStr | ReferenceConversionType::AsRefSlice | ReferenceConversionType::Dereferenced | ReferenceConversionType::Option | ReferenceConversionType::Result => { if self.impls_deref { - make::expr_ref(expr, false) + make.expr_ref(expr.into(), false) } else { - make::expr_method_call(expr, make::name_ref("as_ref"), make::arg_list([])) + make.expr_method_call(expr.into(), make.name_ref("as_ref"), make.arg_list([])) .into() } } @@ -1040,18 +1257,21 @@ pub(crate) fn trimmed_text_range(source_file: &SourceFile, initial_range: TextRa /// Convert a list of function params to a list of arguments that can be passed /// into a function call. -pub(crate) fn convert_param_list_to_arg_list(list: ast::ParamList) -> ast::ArgList { +pub(crate) fn convert_param_list_to_arg_list( + list: ast::ParamList, + make: &SyntaxFactory, +) -> ast::ArgList { let mut args = vec![]; for param in list.params() { if let Some(ast::Pat::IdentPat(pat)) = param.pat() && let Some(name) = pat.name() { let name = name.to_string(); - let expr = make::expr_path(make::ext::ident_path(&name)); + let expr = make.expr_path(make.ident_path(&name)); args.push(expr); } } - make::arg_list(args) + make.arg_list(args) } /// Calculate the number of hashes required for a raw string containing `s` @@ -1136,7 +1356,10 @@ pub(crate) fn replace_record_field_expr( /// Creates a token tree list from a syntax node, creating the needed delimited sub token trees. /// Assumes that the input syntax node is a valid syntax tree. -pub(crate) fn tt_from_syntax(node: SyntaxNode) -> Vec<NodeOrToken<ast::TokenTree, SyntaxToken>> { +pub(crate) fn tt_from_syntax( + node: SyntaxNode, + make: &SyntaxFactory, +) -> Vec<NodeOrToken<ast::TokenTree, SyntaxToken>> { let mut tt_stack = vec![(None, vec![])]; for element in node.descendants_with_tokens() { @@ -1164,7 +1387,7 @@ pub(crate) fn tt_from_syntax(node: SyntaxNode) -> Vec<NodeOrToken<ast::TokenTree "mismatched opening and closing delimiters" ); - let sub_tt = make::token_tree(delimiter.expect("unbalanced delimiters"), tt); + let sub_tt = make.token_tree(delimiter.expect("unbalanced delimiters"), tt); parent_tt.push(NodeOrToken::Node(sub_tt)); } _ => { @@ -1199,6 +1422,20 @@ pub(crate) fn cover_let_chain(mut expr: ast::Expr, range: TextRange) -> Option<a } } +pub(crate) fn cover_edit_range( + source: &SyntaxNode, + range: TextRange, +) -> std::ops::RangeInclusive<syntax::SyntaxElement> { + let node = match source.covering_element(range) { + NodeOrToken::Node(node) => node, + NodeOrToken::Token(t) => t.parent().unwrap(), + }; + let mut iter = node.children_with_tokens().filter(|it| range.contains_range(it.text_range())); + let first = iter.next().unwrap_or(node.into()); + let last = iter.last().unwrap_or_else(|| first.clone()); + first..=last +} + pub(crate) fn is_selected( it: &impl AstNode, selection: syntax::TextRange, diff --git a/crates/ide-assists/src/utils/ref_field_expr.rs b/crates/ide-assists/src/utils/ref_field_expr.rs index 840b26a7ad..fc9bf210e4 100644 --- a/crates/ide-assists/src/utils/ref_field_expr.rs +++ b/crates/ide-assists/src/utils/ref_field_expr.rs @@ -5,7 +5,7 @@ //! based on the parent of the existing expression. use syntax::{ AstNode, T, - ast::{self, FieldExpr, MethodCallExpr, make}, + ast::{self, FieldExpr, MethodCallExpr, syntax_factory::SyntaxFactory}, }; use crate::AssistContext; @@ -119,13 +119,29 @@ pub(crate) struct RefData { impl RefData { /// Derefs `expr` and wraps it in parens if necessary - pub(crate) fn wrap_expr(&self, mut expr: ast::Expr) -> ast::Expr { + pub(crate) fn wrap_expr(&self, mut expr: ast::Expr, make: &SyntaxFactory) -> ast::Expr { if self.needs_deref { - expr = make::expr_prefix(T![*], expr).into(); + expr = make.expr_prefix(T![*], expr).into(); } if self.needs_parentheses { - expr = make::expr_paren(expr).into(); + expr = make.expr_paren(expr).into(); + } + + expr + } + + pub(crate) fn wrap_expr_with_factory( + &self, + mut expr: ast::Expr, + syntax_factory: &SyntaxFactory, + ) -> ast::Expr { + if self.needs_deref { + expr = syntax_factory.expr_prefix(T![*], expr).into(); + } + + if self.needs_parentheses { + expr = syntax_factory.expr_paren(expr).into(); } expr diff --git a/crates/ide-completion/src/completions.rs b/crates/ide-completion/src/completions.rs index 355687b203..1fb1fd4e57 100644 --- a/crates/ide-completion/src/completions.rs +++ b/crates/ide-completion/src/completions.rs @@ -26,7 +26,7 @@ pub(crate) mod vis; use std::iter; -use hir::{HasAttrs, Name, ScopeDef, Variant, sym}; +use hir::{EnumVariant, HasAttrs, Name, ScopeDef, sym}; use ide_db::{RootDatabase, SymbolKind, imports::import_assets::LocatedImport}; use syntax::{SmolStr, ToSmolStr, ast}; @@ -426,7 +426,7 @@ impl Completions { &mut self, ctx: &CompletionContext<'_>, path_ctx: &PathCompletionCtx<'_>, - variant: hir::Variant, + variant: hir::EnumVariant, path: hir::ModPath, ) { if !ctx.check_stability_and_hidden(variant) { @@ -443,7 +443,7 @@ impl Completions { &mut self, ctx: &CompletionContext<'_>, path_ctx: &PathCompletionCtx<'_>, - variant: hir::Variant, + variant: hir::EnumVariant, local_name: Option<hir::Name>, ) { if !ctx.check_stability_and_hidden(variant) { @@ -569,7 +569,7 @@ impl Completions { ctx: &CompletionContext<'_>, pattern_ctx: &PatternContext, path_ctx: Option<&PathCompletionCtx<'_>>, - variant: hir::Variant, + variant: hir::EnumVariant, local_name: Option<hir::Name>, ) { if !ctx.check_stability_and_hidden(variant) { @@ -589,7 +589,7 @@ impl Completions { &mut self, ctx: &CompletionContext<'_>, pattern_ctx: &PatternContext, - variant: hir::Variant, + variant: hir::EnumVariant, path: hir::ModPath, ) { if !ctx.check_stability_and_hidden(variant) { @@ -644,9 +644,9 @@ fn enum_variants_with_paths( ctx: &CompletionContext<'_>, enum_: hir::Enum, impl_: Option<&ast::Impl>, - cb: impl Fn(&mut Completions, &CompletionContext<'_>, hir::Variant, hir::ModPath), + cb: impl Fn(&mut Completions, &CompletionContext<'_>, hir::EnumVariant, hir::ModPath), ) { - let mut process_variant = |variant: Variant| { + let mut process_variant = |variant: EnumVariant| { let self_path = hir::ModPath::from_segments( hir::PathKind::Plain, iter::once(Name::new_symbol_root(sym::Self_)).chain(iter::once(variant.name(ctx.db))), diff --git a/crates/ide-completion/src/completions/fn_param.rs b/crates/ide-completion/src/completions/fn_param.rs index 34d25c9c67..bd0b69215c 100644 --- a/crates/ide-completion/src/completions/fn_param.rs +++ b/crates/ide-completion/src/completions/fn_param.rs @@ -4,9 +4,9 @@ use hir::HirDisplay; use ide_db::FxHashMap; use itertools::Either; use syntax::{ - AstNode, Direction, SyntaxKind, TextRange, TextSize, algo, + AstNode, Direction, SmolStr, SyntaxKind, TextRange, TextSize, ToSmolStr, algo, ast::{self, HasModuleItem}, - match_ast, + format_smolstr, match_ast, }; use crate::{ @@ -25,19 +25,32 @@ pub(crate) fn complete_fn_param( ctx: &CompletionContext<'_>, pattern_ctx: &PatternContext, ) -> Option<()> { - let (ParamContext { param_list, kind, .. }, impl_or_trait) = match pattern_ctx { + let (ParamContext { param_list, kind, param, .. }, impl_or_trait) = match pattern_ctx { PatternContext { param_ctx: Some(kind), impl_or_trait, .. } => (kind, impl_or_trait), _ => return None, }; + let qualifier = param_qualifier(param); let comma_wrapper = comma_wrapper(ctx); let mut add_new_item_to_acc = |label: &str| { - let mk_item = |label: &str, range: TextRange| { - CompletionItem::new(CompletionItemKind::Binding, range, label, ctx.edition) + let label = label.strip_prefix(qualifier.as_str()).unwrap_or(label); + let insert = if label.starts_with('#') { + // FIXME: `#[attr] it: i32` -> `#[attr] mut it: i32` + label.to_smolstr() + } else { + format_smolstr!("{qualifier}{label}") + }; + let mk_item = |insert_text: &str, range: TextRange| { + let mut item = + CompletionItem::new(CompletionItemKind::Binding, range, label, ctx.edition); + if insert_text != label { + item.insert_text(insert_text); + } + item }; let item = match &comma_wrapper { - Some((fmt, range)) => mk_item(&fmt(label), *range), - None => mk_item(label, ctx.source_range()), + Some((fmt, range)) => mk_item(&fmt(&insert), *range), + None => mk_item(&insert, ctx.source_range()), }; // Completion lookup is omitted intentionally here. // See the full discussion: https://github.com/rust-lang/rust-analyzer/issues/12073 @@ -46,13 +59,18 @@ pub(crate) fn complete_fn_param( match kind { ParamKind::Function(function) => { - fill_fn_params(ctx, function, param_list, impl_or_trait, add_new_item_to_acc); + fill_fn_params(ctx, function, param_list, param, impl_or_trait, add_new_item_to_acc); } ParamKind::Closure(closure) => { - let stmt_list = closure.syntax().ancestors().find_map(ast::StmtList::cast)?; - params_from_stmt_list_scope(ctx, stmt_list, |name, ty| { - add_new_item_to_acc(&format!("{}: {ty}", name.display(ctx.db, ctx.edition))); - }); + if is_simple_param(param) { + let stmt_list = closure.syntax().ancestors().find_map(ast::StmtList::cast)?; + params_from_stmt_list_scope(ctx, stmt_list, |name, ty| { + add_new_item_to_acc(&format_smolstr!( + "{}: {ty}", + name.display(ctx.db, ctx.edition) + )); + }); + } } } @@ -63,6 +81,7 @@ fn fill_fn_params( ctx: &CompletionContext<'_>, function: &ast::Fn, param_list: &ast::ParamList, + current_param: &ast::Param, impl_or_trait: &Option<Either<ast::Impl, ast::Trait>>, mut add_new_item_to_acc: impl FnMut(&str), ) { @@ -71,15 +90,17 @@ fn fill_fn_params( let mut extract_params = |f: ast::Fn| { f.param_list().into_iter().flat_map(|it| it.params()).for_each(|param| { if let Some(pat) = param.pat() { - // FIXME: We should be able to turn these into SmolStr without having to allocate a String - let whole_param = param.syntax().text().to_string(); - let binding = pat.syntax().text().to_string(); + let whole_param = param.to_smolstr(); + let binding = pat.to_smolstr(); file_params.entry(whole_param).or_insert(binding); } }); }; for node in ctx.token.parent_ancestors() { + if !is_simple_param(current_param) { + break; + } match_ast! { match node { ast::SourceFile(it) => it.items().filter_map(|item| match item { @@ -99,11 +120,13 @@ fn fill_fn_params( }; } - if let Some(stmt_list) = function.syntax().parent().and_then(ast::StmtList::cast) { + if let Some(stmt_list) = function.syntax().parent().and_then(ast::StmtList::cast) + && is_simple_param(current_param) + { params_from_stmt_list_scope(ctx, stmt_list, |name, ty| { file_params - .entry(format!("{}: {ty}", name.display(ctx.db, ctx.edition))) - .or_insert(name.display(ctx.db, ctx.edition).to_string()); + .entry(format_smolstr!("{}: {ty}", name.display(ctx.db, ctx.edition))) + .or_insert(name.display(ctx.db, ctx.edition).to_smolstr()); }); } remove_duplicated(&mut file_params, param_list.params()); @@ -139,11 +162,11 @@ fn params_from_stmt_list_scope( } fn remove_duplicated( - file_params: &mut FxHashMap<String, String>, + file_params: &mut FxHashMap<SmolStr, SmolStr>, fn_params: ast::AstChildren<ast::Param>, ) { fn_params.for_each(|param| { - let whole_param = param.syntax().text().to_string(); + let whole_param = param.to_smolstr(); file_params.remove(&whole_param); match param.pat() { @@ -151,7 +174,7 @@ fn remove_duplicated( // if the type is missing we are checking the current param to be completed // in which case this would find itself removing the suggestions due to itself Some(pattern) if param.ty().is_some() => { - let binding = pattern.syntax().text().to_string(); + let binding = pattern.to_smolstr(); file_params.retain(|_, v| v != &binding); } _ => (), @@ -173,7 +196,7 @@ fn should_add_self_completions( } } -fn comma_wrapper(ctx: &CompletionContext<'_>) -> Option<(impl Fn(&str) -> String, TextRange)> { +fn comma_wrapper(ctx: &CompletionContext<'_>) -> Option<(impl Fn(&str) -> SmolStr, TextRange)> { let param = ctx.original_token.parent_ancestors().find(|node| node.kind() == SyntaxKind::PARAM)?; @@ -196,5 +219,24 @@ fn comma_wrapper(ctx: &CompletionContext<'_>) -> Option<(impl Fn(&str) -> String matches!(prev_token_kind, SyntaxKind::COMMA | SyntaxKind::L_PAREN | SyntaxKind::PIPE); let leading = if has_leading_comma { "" } else { ", " }; - Some((move |label: &_| format!("{leading}{label}{trailing}"), param.text_range())) + Some((move |label: &_| format_smolstr!("{leading}{label}{trailing}"), param.text_range())) +} + +fn is_simple_param(param: &ast::Param) -> bool { + param + .pat() + .is_none_or(|pat| matches!(pat, ast::Pat::IdentPat(ident_pat) if ident_pat.pat().is_none())) +} + +fn param_qualifier(param: &ast::Param) -> SmolStr { + let mut b = syntax::SmolStrBuilder::new(); + if let Some(ast::Pat::IdentPat(pat)) = param.pat() { + if pat.ref_token().is_some() { + b.push_str("ref "); + } + if pat.mut_token().is_some() { + b.push_str("mut "); + } + } + b.finish() } diff --git a/crates/ide-completion/src/completions/pattern.rs b/crates/ide-completion/src/completions/pattern.rs index eeb2c65e48..e7597bf95c 100644 --- a/crates/ide-completion/src/completions/pattern.rs +++ b/crates/ide-completion/src/completions/pattern.rs @@ -91,11 +91,11 @@ pub(crate) fn complete_pattern( acc.add_struct_pat(ctx, pattern_ctx, strukt, Some(name.clone())); true } - hir::ModuleDef::Variant(variant) + hir::ModuleDef::EnumVariant(variant) if refutable || single_variant_enum(variant.parent_enum(ctx.db)) => { acc.add_variant_pat(ctx, pattern_ctx, None, variant, Some(name.clone())); - true + false } hir::ModuleDef::Adt(hir::Adt::Enum(e)) => refutable || single_variant_enum(e), hir::ModuleDef::Const(..) => refutable, @@ -190,7 +190,7 @@ pub(crate) fn complete_pattern_path( let add_completion = match res { ScopeDef::ModuleDef(hir::ModuleDef::Macro(mac)) => mac.is_fn_like(ctx.db), ScopeDef::ModuleDef(hir::ModuleDef::Adt(_)) => true, - ScopeDef::ModuleDef(hir::ModuleDef::Variant(_)) => true, + ScopeDef::ModuleDef(hir::ModuleDef::EnumVariant(_)) => true, ScopeDef::ModuleDef(hir::ModuleDef::Module(_)) => true, ScopeDef::ImplSelfType(_) => true, _ => false, diff --git a/crates/ide-completion/src/completions/postfix.rs b/crates/ide-completion/src/completions/postfix.rs index 7f67ef848e..5b91e7c456 100644 --- a/crates/ide-completion/src/completions/postfix.rs +++ b/crates/ide-completion/src/completions/postfix.rs @@ -16,7 +16,7 @@ use itertools::Itertools; use stdx::never; use syntax::{ SmolStr, - SyntaxKind::{EXPR_STMT, STMT_LIST}, + SyntaxKind::{CLOSURE_EXPR, EXPR_STMT, MATCH_ARM, STMT_LIST}, T, TextRange, TextSize, ToSmolStr, ast::{self, AstNode, AstToken}, format_smolstr, match_ast, @@ -52,6 +52,7 @@ pub(crate) fn complete_postfix( _ => return, }; let expr_ctx = &dot_access.ctx; + let receiver_accessor = receiver_accessor(dot_receiver); let receiver_text = get_receiver_text(&ctx.sema, dot_receiver, receiver_is_ambiguous_float_literal); @@ -65,6 +66,12 @@ pub(crate) fn complete_postfix( Some(it) => it, None => return, }; + let semi = + if expr_ctx.in_block_expr && ctx.token.next_token().is_none_or(|it| it.kind() != T![;]) { + ";" + } else { + "" + }; let cfg = ctx.config.find_path_config(ctx.is_nightly); @@ -90,9 +97,8 @@ pub(crate) fn complete_postfix( // The rest of the postfix completions create an expression that moves an argument, // so it's better to consider references now to avoid breaking the compilation - let (dot_receiver_including_refs, prefix) = include_references(dot_receiver); - let mut receiver_text = - get_receiver_text(&ctx.sema, dot_receiver, receiver_is_ambiguous_float_literal); + let (dot_receiver_including_refs, prefix) = include_references(&receiver_accessor); + let mut receiver_text = receiver_text; receiver_text.insert_str(0, &prefix); let postfix_snippet = match build_postfix_snippet_builder(ctx, cap, &dot_receiver_including_refs) { @@ -111,14 +117,9 @@ pub(crate) fn complete_postfix( postfix_snippet("call", "function(expr)", &format!("${{1}}({receiver_text})")) .add_to(acc, ctx.db); - let try_enum = TryEnum::from_ty(&ctx.sema, &receiver_ty.strip_references()); - let mut is_in_cond = false; - if let Some(parent) = dot_receiver_including_refs.syntax().parent() - && let Some(second_ancestor) = parent.parent() - { - if let Some(parent_expr) = ast::Expr::cast(parent) { - is_in_cond = is_in_condition(&parent_expr); - } + let try_enum = TryEnum::from_ty(&ctx.sema, receiver_ty); + let is_in_cond = is_in_condition(&dot_receiver_including_refs); + if let Some(parent) = dot_receiver_including_refs.syntax().parent() { let placeholder = suggest_receiver_name(dot_receiver, "0", &ctx.sema); match &try_enum { Some(try_enum) if is_in_cond => match try_enum { @@ -151,12 +152,30 @@ pub(crate) fn complete_postfix( .add_to(acc, ctx.db); } }, - _ if matches!(second_ancestor.kind(), STMT_LIST | EXPR_STMT) => { - postfix_snippet("let", "let", &format!("let $0 = {receiver_text};")) + _ if is_in_cond => { + postfix_snippet("let", "let", &format!("let $1 = {receiver_text}")) .add_to(acc, ctx.db); - postfix_snippet("letm", "let mut", &format!("let mut $0 = {receiver_text};")) + } + _ if matches!(parent.kind(), STMT_LIST | EXPR_STMT) => { + postfix_snippet("let", "let", &format!("let $0 = {receiver_text}{semi}")) + .add_to(acc, ctx.db); + postfix_snippet("letm", "let mut", &format!("let mut $0 = {receiver_text}{semi}")) .add_to(acc, ctx.db); } + _ if matches!(parent.kind(), MATCH_ARM | CLOSURE_EXPR) => { + postfix_snippet( + "let", + "let", + &format!("{{\n let $1 = {receiver_text};\n $0\n}}"), + ) + .add_to(acc, ctx.db); + postfix_snippet( + "letm", + "let mut", + &format!("{{\n let mut $1 = {receiver_text};\n $0\n}}"), + ) + .add_to(acc, ctx.db); + } _ => (), } } @@ -253,7 +272,6 @@ pub(crate) fn complete_postfix( &format!("while {receiver_text} {{\n $0\n}}"), ) .add_to(acc, ctx.db); - postfix_snippet("not", "!expr", &format!("!{receiver_text}")).add_to(acc, ctx.db); } else if let Some(trait_) = ctx.famous_defs().core_iter_IntoIterator() && receiver_ty.impls_trait(ctx.db, trait_, &[]) { @@ -266,6 +284,10 @@ pub(crate) fn complete_postfix( } } + if receiver_ty.is_bool() || receiver_ty.is_unknown() { + postfix_snippet("not", "!expr", &format!("!{receiver_text}")).add_to(acc, ctx.db); + } + let block_should_be_wrapped = if let ast::Expr::BlockExpr(block) = dot_receiver { block.modifier().is_some() || !block.is_standalone() } else { @@ -285,32 +307,18 @@ pub(crate) fn complete_postfix( postfix_snippet("const", "const {}", &const_completion_string).add_to(acc, ctx.db); } - if let ast::Expr::Literal(literal) = dot_receiver_including_refs.clone() + if let ast::Expr::Literal(literal) = dot_receiver.clone() && let Some(literal_text) = ast::String::cast(literal.token()) { add_format_like_completions(acc, ctx, &dot_receiver_including_refs, cap, &literal_text); } - postfix_snippet( - "return", - "return expr", - &format!( - "return {receiver_text}{semi}", - semi = if expr_ctx.in_block_expr { ";" } else { "" } - ), - ) - .add_to(acc, ctx.db); + postfix_snippet("return", "return expr", &format!("return {receiver_text}{semi}")) + .add_to(acc, ctx.db); if let Some(BreakableKind::Block | BreakableKind::Loop) = expr_ctx.in_breakable { - postfix_snippet( - "break", - "break expr", - &format!( - "break {receiver_text}{semi}", - semi = if expr_ctx.in_block_expr { ";" } else { "" } - ), - ) - .add_to(acc, ctx.db); + postfix_snippet("break", "break expr", &format!("break {receiver_text}{semi}")) + .add_to(acc, ctx.db); } } @@ -355,12 +363,20 @@ fn get_receiver_text( range.range = TextRange::at(range.range.start(), range.range.len() - TextSize::of('.')) } let file_text = sema.db.file_text(range.file_id.file_id(sema.db)); - let mut text = file_text.text(sema.db)[range.range].to_owned(); + let text = file_text.text(sema.db); + let indent_spaces = indent_of_tail_line(&text[TextRange::up_to(range.range.start())]); + let mut text = stdx::dedent_by(indent_spaces, &text[range.range]); // The receiver texts should be interpreted as-is, as they are expected to be // normal Rust expressions. escape_snippet_bits(&mut text); - text + return text; + + fn indent_of_tail_line(text: &str) -> usize { + let tail_line = text.rsplit_once('\n').map_or(text, |(_, s)| s); + let trimmed = tail_line.trim_start_matches(' '); + tail_line.len() - trimmed.len() + } } /// Escapes `\` and `$` so that they don't get interpreted as snippet-specific constructs. @@ -372,25 +388,34 @@ fn escape_snippet_bits(text: &mut String) { stdx::replace(text, '$', "\\$"); } +fn receiver_accessor(receiver: &ast::Expr) -> ast::Expr { + receiver + .syntax() + .parent() + .and_then(ast::Expr::cast) + .filter(|it| { + matches!( + it, + ast::Expr::FieldExpr(_) | ast::Expr::MethodCallExpr(_) | ast::Expr::CallExpr(_) + ) + }) + .unwrap_or_else(|| receiver.clone()) +} + +/// Given an `initial_element`, tries to expand it to include deref(s), and then references. +/// Returns the expanded expressions, and the added prefix as a string +/// +/// For example, if called with the `42` in `&&mut *42`, would return `(&&mut *42, "&&mut *")`. fn include_references(initial_element: &ast::Expr) -> (ast::Expr, String) { let mut resulting_element = initial_element.clone(); - - while let Some(field_expr) = resulting_element.syntax().parent().and_then(ast::FieldExpr::cast) - { - resulting_element = ast::Expr::from(field_expr); - } - let mut prefix = String::new(); let mut found_ref_or_deref = false; while let Some(parent_deref_element) = resulting_element.syntax().parent().and_then(ast::PrefixExpr::cast) + && parent_deref_element.op_kind() == Some(ast::UnaryOp::Deref) { - if parent_deref_element.op_kind() != Some(ast::UnaryOp::Deref) { - break; - } - found_ref_or_deref = true; resulting_element = ast::Expr::from(parent_deref_element); @@ -586,6 +611,31 @@ fn main() { } #[test] + fn postfix_completion_works_in_if_condition() { + check( + r#" +fn foo(cond: bool) { + if cond.$0 +} +"#, + expect![[r#" + sn box Box::new(expr) + sn call function(expr) + sn const const {} + sn dbg dbg!(expr) + sn dbgr dbg!(&expr) + sn deref *expr + sn let let + sn not !expr + sn ref &expr + sn refm &mut expr + sn return return expr + sn unsafe unsafe {} + "#]], + ); + } + + #[test] fn postfix_type_filtering() { check( r#" @@ -614,6 +664,22 @@ fn main() { #[test] fn let_middle_block() { + check_edit( + "let", + r#" +fn main() { + baz.l$0 + res +} +"#, + r#" +fn main() { + let $0 = baz; + res +} +"#, + ); + check( r#" fn main() { @@ -640,6 +706,118 @@ fn main() { sn while while expr {} "#]], ); + check( + r#" +fn main() { + &baz.l$0 + res +} +"#, + expect![[r#" + sn box Box::new(expr) + sn call function(expr) + sn const const {} + sn dbg dbg!(expr) + sn dbgr dbg!(&expr) + sn deref *expr + sn if if expr {} + sn let let + sn letm let mut + sn match match expr {} + sn not !expr + sn ref &expr + sn refm &mut expr + sn return return expr + sn unsafe unsafe {} + sn while while expr {} + "#]], + ); + } + + #[test] + fn let_tail_block() { + check_edit( + "let", + r#" +fn main() { + baz.l$0 +} +"#, + r#" +fn main() { + let $0 = baz; +} +"#, + ); + + check( + r#" +fn main() { + baz.l$0 +} +"#, + expect![[r#" + sn box Box::new(expr) + sn call function(expr) + sn const const {} + sn dbg dbg!(expr) + sn dbgr dbg!(&expr) + sn deref *expr + sn if if expr {} + sn let let + sn letm let mut + sn match match expr {} + sn not !expr + sn ref &expr + sn refm &mut expr + sn return return expr + sn unsafe unsafe {} + sn while while expr {} + "#]], + ); + + check( + r#" +fn main() { + &baz.l$0 +} +"#, + expect![[r#" + sn box Box::new(expr) + sn call function(expr) + sn const const {} + sn dbg dbg!(expr) + sn dbgr dbg!(&expr) + sn deref *expr + sn if if expr {} + sn let let + sn letm let mut + sn match match expr {} + sn not !expr + sn ref &expr + sn refm &mut expr + sn return return expr + sn unsafe unsafe {} + sn while while expr {} + "#]], + ); + } + + #[test] + fn let_before_semicolon() { + check_edit( + "let", + r#" +fn main() { + baz.l$0; +} +"#, + r#" +fn main() { + let $0 = baz; +} +"#, + ); } #[test] @@ -745,6 +923,119 @@ fn main() { } #[test] + fn iflet_fallback_cond() { + check_edit( + "let", + r#" +fn main() { + let bar = 2; + if bar.$0 +} +"#, + r#" +fn main() { + let bar = 2; + if let $1 = bar +} +"#, + ); + } + + #[test] + fn match_arm_let_block() { + check( + r#" +fn main() { + match 2 { + bar => bar.$0 + } +} +"#, + expect![[r#" + sn box Box::new(expr) + sn call function(expr) + sn const const {} + sn dbg dbg!(expr) + sn dbgr dbg!(&expr) + sn deref *expr + sn let let + sn letm let mut + sn match match expr {} + sn ref &expr + sn refm &mut expr + sn return return expr + sn unsafe unsafe {} + "#]], + ); + check( + r#" +fn main() { + match 2 { + bar => &bar.l$0 + } +} +"#, + expect![[r#" + sn box Box::new(expr) + sn call function(expr) + sn const const {} + sn dbg dbg!(expr) + sn dbgr dbg!(&expr) + sn deref *expr + sn let let + sn letm let mut + sn match match expr {} + sn ref &expr + sn refm &mut expr + sn return return expr + sn unsafe unsafe {} + "#]], + ); + check_edit( + "let", + r#" +fn main() { + match 2 { + bar => bar.$0 + } +} +"#, + r#" +fn main() { + match 2 { + bar => { + let $1 = bar; + $0 +} + } +} +"#, + ); + } + + #[test] + fn closure_let_block() { + check_edit( + "let", + r#" +fn main() { + let bar = 2; + let f = || bar.$0; +} +"#, + r#" +fn main() { + let bar = 2; + let f = || { + let $1 = bar; + $0 +}; +} +"#, + ); + } + + #[test] fn option_letelse() { check_edit( "lete", @@ -819,6 +1110,7 @@ fn main() { #[test] fn postfix_completion_for_references() { check_edit("dbg", r#"fn main() { &&42.$0 }"#, r#"fn main() { dbg!(&&42) }"#); + check_edit("dbg", r#"fn main() { &&*"hello".$0 }"#, r#"fn main() { dbg!(&&*"hello") }"#); check_edit("refm", r#"fn main() { &&42.$0 }"#, r#"fn main() { &&&mut 42 }"#); check_edit( "ifl", @@ -977,9 +1269,9 @@ use core::ops::ControlFlow; fn main() { ControlFlow::Break(match true { - true => "\${1:placeholder}", - false => "\\\$", - }) + true => "\${1:placeholder}", + false => "\\\$", +}) } "#, ); @@ -1219,4 +1511,31 @@ fn foo() { "#, ); } + + #[test] + fn snippet_dedent() { + check_edit( + "let", + r#" +//- minicore: option +fn foo(x: Option<i32>, y: Option<i32>) { + let _f = || { + x + .and(y) + .map(|it| it+2) + .$0 + }; +} +"#, + r#" +fn foo(x: Option<i32>, y: Option<i32>) { + let _f = || { + let $0 = x + .and(y) + .map(|it| it+2); + }; +} +"#, + ); + } } diff --git a/crates/ide-completion/src/completions/type.rs b/crates/ide-completion/src/completions/type.rs index abcf9fca6f..8ff9c3258e 100644 --- a/crates/ide-completion/src/completions/type.rs +++ b/crates/ide-completion/src/completions/type.rs @@ -23,7 +23,9 @@ pub(crate) fn complete_type_path( ScopeDef::GenericParam(LifetimeParam(_)) => location.complete_lifetimes(), ScopeDef::Label(_) => false, // no values in type places - ScopeDef::ModuleDef(Function(_) | Variant(_) | Static(_)) | ScopeDef::Local(_) => false, + ScopeDef::ModuleDef(Function(_) | EnumVariant(_) | Static(_)) | ScopeDef::Local(_) => { + false + } // unless its a constant in a generic arg list position ScopeDef::ModuleDef(Const(_)) | ScopeDef::GenericParam(ConstParam(_)) => { location.complete_consts() diff --git a/crates/ide-completion/src/context.rs b/crates/ide-completion/src/context.rs index cab8bced88..4fd0348156 100644 --- a/crates/ide-completion/src/context.rs +++ b/crates/ide-completion/src/context.rs @@ -288,7 +288,7 @@ pub(crate) struct PatternContext { pub(crate) record_pat: Option<ast::RecordPat>, pub(crate) impl_or_trait: Option<Either<ast::Impl, ast::Trait>>, /// List of missing variants in a match expr - pub(crate) missing_variants: Vec<hir::Variant>, + pub(crate) missing_variants: Vec<hir::EnumVariant>, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -527,7 +527,7 @@ impl CompletionContext<'_> { hir::ModuleDef::Module(it) => self.is_visible(it), hir::ModuleDef::Function(it) => self.is_visible(it), hir::ModuleDef::Adt(it) => self.is_visible(it), - hir::ModuleDef::Variant(it) => self.is_visible(it), + hir::ModuleDef::EnumVariant(it) => self.is_visible(it), hir::ModuleDef::Const(it) => self.is_visible(it), hir::ModuleDef::Static(it) => self.is_visible(it), hir::ModuleDef::Trait(it) => self.is_visible(it), @@ -821,7 +821,10 @@ impl<'db> CompletionContext<'db> { CompleteSemicolon::DoNotComplete } else if let Some(term_node) = sema.token_ancestors_with_macros(token.clone()).find(|node| { - matches!(node.kind(), BLOCK_EXPR | MATCH_ARM | CLOSURE_EXPR | ARG_LIST | PAREN_EXPR) + matches!( + node.kind(), + BLOCK_EXPR | MATCH_ARM | CLOSURE_EXPR | ARG_LIST | PAREN_EXPR | ARRAY_EXPR + ) }) { let next_token = iter::successors(token.next_token(), |it| it.next_token()) diff --git a/crates/ide-completion/src/context/analysis.rs b/crates/ide-completion/src/context/analysis.rs index 1c8bc656ca..a3494b964f 100644 --- a/crates/ide-completion/src/context/analysis.rs +++ b/crates/ide-completion/src/context/analysis.rs @@ -1,7 +1,7 @@ //! Module responsible for analyzing the code surrounding the cursor for completion. use std::iter; -use hir::{ExpandResult, InFile, Semantics, Type, TypeInfo, Variant}; +use hir::{EnumVariant, ExpandResult, InFile, Semantics, Type, TypeInfo}; use ide_db::{ RootDatabase, active_parameter::ActiveParameter, syntax_helpers::node_ext::find_loops, }; @@ -778,6 +778,16 @@ fn expected_type_and_name<'db>( let ty = sema.type_of_pat(&ast::Pat::from(it)).map(TypeInfo::original); (ty, None) }, + ast::TupleStructPat(it) => { + let fields = it.path().and_then(|path| match sema.resolve_path(&path)? { + hir::PathResolution::Def(hir::ModuleDef::Adt(adt)) => Some(adt.as_struct()?.fields(sema.db)), + hir::PathResolution::Def(hir::ModuleDef::EnumVariant(variant)) => Some(variant.fields(sema.db)), + _ => None, + }); + let nr = it.fields().take_while(|it| it.syntax().text_range().end() <= token.text_range().start()).count(); + let ty = fields.and_then(|fields| Some(fields.get(nr)?.ty(sema.db).to_type(sema.db))); + (ty, None) + }, ast::Fn(it) => { cov_mark::hit!(expected_type_fn_ret_with_leading_char); cov_mark::hit!(expected_type_fn_ret_without_leading_char); @@ -944,10 +954,10 @@ fn classify_name_ref<'db>( let field_expr_handle = |receiver, node| { let receiver = find_opt_node_in_file(original_file, receiver); let receiver_is_ambiguous_float_literal = match &receiver { - Some(ast::Expr::Literal(l)) => matches! { - l.kind(), - ast::LiteralKind::FloatNumber { .. } if l.syntax().last_token().is_some_and(|it| it.text().ends_with('.')) - }, + Some(ast::Expr::Literal(l)) => { + matches!(l.kind(), ast::LiteralKind::FloatNumber { .. }) + && l.syntax().last_token().is_some_and(|it| it.text().ends_with('.')) + } _ => false, }; @@ -1139,7 +1149,7 @@ fn classify_name_ref<'db>( hir::ModuleDef::Adt(adt) => { sema.source(adt)?.value.generic_param_list() } - hir::ModuleDef::Variant(variant) => { + hir::ModuleDef::EnumVariant(variant) => { sema.source(variant.parent_enum(sema.db))?.value.generic_param_list() } hir::ModuleDef::Trait(trait_) => { @@ -1501,7 +1511,7 @@ fn classify_name_ref<'db>( | SyntaxKind::RECORD_FIELD ) }) - .and_then(|_| nameref.as_ref()?.syntax().ancestors().find_map(ast::Adt::cast)) + .and_then(|_| find_node_at_offset::<ast::Adt>(original_file, original_offset)) .and_then(|adt| sema.derive_helpers_in_scope(&adt)) .unwrap_or_default(); Some(PathKind::Attr { attr_ctx: AttrCtx { kind, annotated_item_kind, derive_helpers } }) @@ -1815,7 +1825,7 @@ fn pattern_context_for( }); (!variant_already_present).then_some(*variant) - }).collect::<Vec<Variant>>()) + }).collect::<Vec<EnumVariant>>()) }); if let Some(missing_variants_) = missing_variants_opt { diff --git a/crates/ide-completion/src/context/tests.rs b/crates/ide-completion/src/context/tests.rs index e97d9720e3..94d904932a 100644 --- a/crates/ide-completion/src/context/tests.rs +++ b/crates/ide-completion/src/context/tests.rs @@ -288,6 +288,50 @@ fn foo() -> Foo { } #[test] +fn expected_type_tuple_struct_pat() { + check_expected_type_and_name( + r#" +//- minicore: option +struct Foo(Option<i32>); +fn foo(x: Foo) -> Foo { + match x { Foo($0) => () } +} +"#, + expect![[r#"ty: Option<i32>, name: ?"#]], + ); + + check_expected_type_and_name( + r#" +struct Foo(i32, u32, f32); +fn foo(x: Foo) -> Foo { + match x { Foo($0) => () } +} +"#, + expect![[r#"ty: i32, name: ?"#]], + ); + + check_expected_type_and_name( + r#" +struct Foo(i32, u32, f32); +fn foo(x: Foo) -> Foo { + match x { Foo(num,$0) => () } +} +"#, + expect![[r#"ty: u32, name: ?"#]], + ); + + check_expected_type_and_name( + r#" +struct Foo(i32, u32, f32); +fn foo(x: Foo) -> Foo { + match x { Foo(num,$0,float) => () } +} +"#, + expect![[r#"ty: u32, name: ?"#]], + ); +} + +#[test] fn expected_type_if_let_without_leading_char() { cov_mark::check!(expected_type_if_let_without_leading_char); check_expected_type_and_name( diff --git a/crates/ide-completion/src/render.rs b/crates/ide-completion/src/render.rs index 765304d818..d77e793295 100644 --- a/crates/ide-completion/src/render.rs +++ b/crates/ide-completion/src/render.rs @@ -408,7 +408,7 @@ fn render_resolution_path( let ctx = ctx.import_to_add(import_to_add); return render_fn(ctx, path_ctx, Some(local_name), func); } - ScopeDef::ModuleDef(Variant(var)) => { + ScopeDef::ModuleDef(EnumVariant(var)) => { let ctx = ctx.clone().import_to_add(import_to_add.clone()); if let Some(item) = render_variant_lit(ctx, path_ctx, Some(local_name.clone()), var, None) @@ -476,7 +476,7 @@ fn render_resolution_path( } // Filtered out above ScopeDef::ModuleDef( - ModuleDef::Function(_) | ModuleDef::Variant(_) | ModuleDef::Macro(_), + ModuleDef::Function(_) | ModuleDef::EnumVariant(_) | ModuleDef::Macro(_), ) => (), ScopeDef::ModuleDef(ModuleDef::Const(konst)) => set_item_relevance(konst.ty(db)), ScopeDef::ModuleDef(ModuleDef::Static(stat)) => set_item_relevance(stat.ty(db)), @@ -528,7 +528,7 @@ fn res_to_kind(resolution: ScopeDef) -> CompletionItemKind { match resolution { ScopeDef::Unknown => CompletionItemKind::UnresolvedReference, ScopeDef::ModuleDef(Function(_)) => CompletionItemKind::SymbolKind(SymbolKind::Function), - ScopeDef::ModuleDef(Variant(_)) => CompletionItemKind::SymbolKind(SymbolKind::Variant), + ScopeDef::ModuleDef(EnumVariant(_)) => CompletionItemKind::SymbolKind(SymbolKind::Variant), ScopeDef::ModuleDef(Macro(_)) => CompletionItemKind::SymbolKind(SymbolKind::Macro), ScopeDef::ModuleDef(Module(..)) => CompletionItemKind::SymbolKind(SymbolKind::Module), ScopeDef::ModuleDef(Adt(adt)) => CompletionItemKind::SymbolKind(match adt { @@ -559,7 +559,7 @@ fn scope_def_docs(db: &RootDatabase, resolution: ScopeDef) -> Option<Documentati match resolution { ScopeDef::ModuleDef(Module(it)) => it.docs(db), ScopeDef::ModuleDef(Adt(it)) => it.docs(db), - ScopeDef::ModuleDef(Variant(it)) => it.docs(db), + ScopeDef::ModuleDef(EnumVariant(it)) => it.docs(db), ScopeDef::ModuleDef(Const(it)) => it.docs(db), ScopeDef::ModuleDef(Static(it)) => it.docs(db), ScopeDef::ModuleDef(Trait(it)) => it.docs(db), diff --git a/crates/ide-completion/src/render/function.rs b/crates/ide-completion/src/render/function.rs index 4713b1f1af..dfa30841e7 100644 --- a/crates/ide-completion/src/render/function.rs +++ b/crates/ide-completion/src/render/function.rs @@ -678,7 +678,7 @@ fn main() { fn complete_fn_param() { // has mut kw check_edit( - "mut bar: u32", + "bar: u32", r#" fn f(foo: (), mut bar: u32) {} fn g(foo: (), mut ba$0) @@ -689,10 +689,35 @@ fn g(foo: (), mut bar: u32) "#, ); - // has type param + // has unmatched mut kw + check_edit( + "bar: u32", + r#" +fn f(foo: (), bar: u32) {} +fn g(foo: (), mut ba$0) +"#, + r#" +fn f(foo: (), bar: u32) {} +fn g(foo: (), mut bar: u32) +"#, + ); + check_edit( "mut bar: u32", r#" +fn f(foo: (), mut bar: u32) {} +fn g(foo: (), ba$0) +"#, + r#" +fn f(foo: (), mut bar: u32) {} +fn g(foo: (), mut bar: u32) +"#, + ); + + // has type param + check_edit( + "bar: u32", + r#" fn g(foo: (), mut ba$0: u32) fn f(foo: (), mut bar: u32) {} "#, @@ -707,7 +732,7 @@ fn f(foo: (), mut bar: u32) {} fn complete_fn_mut_param_add_comma() { // add leading and trailing comma check_edit( - ", mut bar: u32,", + "bar: u32", r#" fn f(foo: (), mut bar: u32) {} fn g(foo: ()mut ba$0 baz: ()) @@ -746,7 +771,7 @@ fn g(foo: (), #[baz = "qux"] mut bar: u32) ); check_edit( - r#", #[baz = "qux"] mut bar: u32"#, + r#"#[baz = "qux"] mut bar: u32"#, r#" fn f(foo: (), #[baz = "qux"] mut bar: u32) {} fn g(foo: ()#[baz = "qux"] mut ba$0) @@ -908,4 +933,23 @@ fn bar() { "#, ); } + + #[test] + fn no_semicolon_in_array() { + check_edit( + r#"foo"#, + r#" +fn foo() {} +fn bar() { + let _ = [fo$0]; +} +"#, + r#" +fn foo() {} +fn bar() { + let _ = [foo()$0]; +} +"#, + ); + } } diff --git a/crates/ide-completion/src/render/literal.rs b/crates/ide-completion/src/render/literal.rs index 8b14f05b72..6e49af980a 100644 --- a/crates/ide-completion/src/render/literal.rs +++ b/crates/ide-completion/src/render/literal.rs @@ -23,7 +23,7 @@ pub(crate) fn render_variant_lit( ctx: RenderContext<'_>, path_ctx: &PathCompletionCtx<'_>, local_name: Option<hir::Name>, - variant: hir::Variant, + variant: hir::EnumVariant, path: Option<hir::ModPath>, ) -> Option<Builder> { let _p = tracing::info_span!("render_variant_lit").entered(); @@ -150,7 +150,7 @@ fn render( #[derive(Clone, Copy)] enum Variant { Struct(hir::Struct), - EnumVariant(hir::Variant), + EnumVariant(hir::EnumVariant), } impl Variant { diff --git a/crates/ide-completion/src/render/pattern.rs b/crates/ide-completion/src/render/pattern.rs index 60474a31b4..fb35d7b9b6 100644 --- a/crates/ide-completion/src/render/pattern.rs +++ b/crates/ide-completion/src/render/pattern.rs @@ -47,7 +47,7 @@ pub(crate) fn render_variant_pat( ctx: RenderContext<'_>, pattern_ctx: &PatternContext, path_ctx: Option<&PathCompletionCtx<'_>>, - variant: hir::Variant, + variant: hir::EnumVariant, local_name: Option<Name>, path: Option<&hir::ModPath>, ) -> Option<CompletionItem> { diff --git a/crates/ide-completion/src/tests/attribute.rs b/crates/ide-completion/src/tests/attribute.rs index 3701416dfc..131911be91 100644 --- a/crates/ide-completion/src/tests/attribute.rs +++ b/crates/ide-completion/src/tests/attribute.rs @@ -68,7 +68,71 @@ pub struct Foo(#[m$0] i32); kw crate:: kw self:: "#]], - ) + ); + check( + r#" +//- /mac.rs crate:mac +#![crate_type = "proc-macro"] + +#[proc_macro_derive(MyDerive, attributes(my_cool_helper_attribute))] +pub fn my_derive() {} + +//- /lib.rs crate:lib deps:mac +#[rustc_builtin_macro] +pub macro derive($item:item) {} + +#[derive(mac::MyDerive)] +pub struct Foo(#[$0] i32); +"#, + expect![[r#" + at allow(…) + at automatically_derived + at cfg(…) + at cfg_attr(…) + at cold + at deny(…) + at deprecated + at derive macro derive + at derive(…) + at diagnostic::do_not_recommend + at diagnostic::on_unimplemented + at doc = "…" + at doc = include_str!("…") + at doc(alias = "…") + at doc(hidden) + at expect(…) + at export_name = "…" + at forbid(…) + at global_allocator + at ignore = "…" + at inline + at link + at link_name = "…" + at link_section = "…" + at macro_export + at macro_use + at must_use + at my_cool_helper_attribute derive helper of `MyDerive` + at no_mangle + at non_exhaustive + at panic_handler + at path = "…" + at proc_macro + at proc_macro_attribute + at proc_macro_derive(…) + at repr(…) + at should_panic + at target_feature(enable = "…") + at test + at track_caller + at unsafe(…) + at used + at warn(…) + md mac + kw crate:: + kw self:: + "#]], + ); } #[test] diff --git a/crates/ide-completion/src/tests/expression.rs b/crates/ide-completion/src/tests/expression.rs index df39591a33..8e50ef10ec 100644 --- a/crates/ide-completion/src/tests/expression.rs +++ b/crates/ide-completion/src/tests/expression.rs @@ -3268,6 +3268,8 @@ fn foo() { sn dbg dbg!(expr) sn dbgr dbg!(&expr) sn deref *expr + sn let let + sn letm let mut sn match match expr {} sn ref &expr sn refm &mut expr @@ -3657,3 +3659,38 @@ fn main() { "#]], ); } + +#[test] +fn rpitit_with_reference() { + check( + r#" +trait Foo { + fn foo(&self); +} + +trait Bar { + fn bar(&self) -> &impl Foo; +} + +fn baz(v: impl Bar) { + v.bar().$0 +} + "#, + expect![[r#" + me foo() (as Foo) fn(&self) + sn box Box::new(expr) + sn call function(expr) + sn const const {} + sn dbg dbg!(expr) + sn dbgr dbg!(&expr) + sn deref *expr + sn let let + sn letm let mut + sn match match expr {} + sn ref &expr + sn refm &mut expr + sn return return expr + sn unsafe unsafe {} + "#]], + ); +} diff --git a/crates/ide-completion/src/tests/fn_param.rs b/crates/ide-completion/src/tests/fn_param.rs index 02cba6b646..aaa225642c 100644 --- a/crates/ide-completion/src/tests/fn_param.rs +++ b/crates/ide-completion/src/tests/fn_param.rs @@ -43,7 +43,7 @@ fn bar(file_id: usize) {} fn baz(file$0 id: u32) {} "#, expect![[r#" - bn file_id: usize, + bn file_id: usize kw mut kw ref "#]], @@ -293,6 +293,60 @@ fn bar(bar$0) {} } #[test] +fn not_shows_fully_equal_inside_pattern_params() { + check( + r#" +fn foo(bar: u32) {} +fn bar((a, bar$0)) {} +"#, + expect![[r#" + kw mut + kw ref + "#]], + ) +} + +#[test] +fn not_shows_locals_inside_pattern_params() { + check( + r#" +fn outer() { + let foo = 3; + { + let bar = 3; + |($0)| {}; + let baz = 3; + let qux = 3; + } + let fez = 3; +} +"#, + expect![[r#" + kw mut + kw ref + "#]], + ); + check( + r#" +fn outer() { + let foo = 3; + { + let bar = 3; + fn inner(($0)) {} + let baz = 3; + let qux = 3; + } + let fez = 3; +} +"#, + expect![[r#" + kw mut + kw ref + "#]], + ); +} + +#[test] fn completes_for_params_with_attributes() { check( r#" diff --git a/crates/ide-completion/src/tests/pattern.rs b/crates/ide-completion/src/tests/pattern.rs index b8728028bb..0d85f2e9ad 100644 --- a/crates/ide-completion/src/tests/pattern.rs +++ b/crates/ide-completion/src/tests/pattern.rs @@ -122,7 +122,6 @@ fn foo() { st Record st Tuple st Unit - ev TupleV bn Record {…} Record { field$1 }$0 bn Tuple(…) Tuple($1)$0 bn TupleV(…) TupleV($1)$0 @@ -159,8 +158,6 @@ fn foo(foo: Foo) { match foo { Foo { x: $0 } } } expect![[r#" en Bar st Foo - ev Nil - ev Value bn Foo {…} Foo { x$1 }$0 bn Nil Nil$0 bn Value Value$0 @@ -189,7 +186,6 @@ fn foo() { st Record st Tuple st Unit - ev Variant bn Record {…} Record { field$1 }$0 bn Tuple(…) Tuple($1)$0 bn Variant Variant$0 @@ -355,6 +351,34 @@ fn func() { } #[test] +fn enum_unqualified() { + check_with_base_items( + r#" +use Enum::*; +fn func() { + if let $0 = unknown {} +} +"#, + expect![[r#" + ct CONST + en Enum + ma makro!(…) macro_rules! makro + md module + st Record + st Tuple + st Unit + bn Record {…} Record { field$1 }$0 + bn RecordV {…} RecordV { field$1 }$0 + bn Tuple(…) Tuple($1)$0 + bn TupleV(…) TupleV($1)$0 + bn UnitV UnitV$0 + kw mut + kw ref + "#]], + ); +} + +#[test] fn completes_in_record_field_pat() { check( r#" diff --git a/crates/ide-completion/src/tests/record.rs b/crates/ide-completion/src/tests/record.rs index 045b2d03b0..c1274f6640 100644 --- a/crates/ide-completion/src/tests/record.rs +++ b/crates/ide-completion/src/tests/record.rs @@ -61,8 +61,6 @@ fn foo(baz: Baz) { en Baz en Result md core - ev Err - ev Ok bn Baz::Bar Baz::Bar$0 bn Baz::Foo Baz::Foo$0 bn Err(…) Err($1)$0 @@ -89,10 +87,6 @@ fn foo(baz: Baz) { en Baz en Result md core - ev Bar - ev Err - ev Foo - ev Ok bn Bar Bar$0 bn Err(…) Err($1)$0 bn Foo Foo$0 diff --git a/crates/ide-db/src/active_parameter.rs b/crates/ide-db/src/active_parameter.rs index f5a5b76c33..8bd4c6c46b 100644 --- a/crates/ide-db/src/active_parameter.rs +++ b/crates/ide-db/src/active_parameter.rs @@ -113,7 +113,7 @@ pub fn generic_def_for_node( sema: &Semantics<'_, RootDatabase>, generic_arg_list: &ast::GenericArgList, token: &SyntaxToken, -) -> Option<(hir::GenericDef, usize, bool, Option<hir::Variant>)> { +) -> Option<(hir::GenericDef, usize, bool, Option<hir::EnumVariant>)> { let parent = generic_arg_list.syntax().parent()?; let mut variant = None; let def = match_ast! { @@ -125,7 +125,7 @@ pub fn generic_def_for_node( hir::PathResolution::Def(hir::ModuleDef::Function(it)) => it.into(), hir::PathResolution::Def(hir::ModuleDef::Trait(it)) => it.into(), hir::PathResolution::Def(hir::ModuleDef::TypeAlias(it)) => it.into(), - hir::PathResolution::Def(hir::ModuleDef::Variant(it)) => { + hir::PathResolution::Def(hir::ModuleDef::EnumVariant(it)) => { variant = Some(it); it.parent_enum(sema.db).into() }, diff --git a/crates/ide-db/src/defs.rs b/crates/ide-db/src/defs.rs index 788f9b73fa..82cff37296 100644 --- a/crates/ide-db/src/defs.rs +++ b/crates/ide-db/src/defs.rs @@ -14,11 +14,12 @@ use arrayvec::ArrayVec; use either::Either; use hir::{ Adt, AsAssocItem, AsExternAssocItem, AssocItem, AttributeTemplate, BuiltinAttr, BuiltinType, - Const, Crate, DefWithBody, DeriveHelper, DisplayTarget, DocLinkDef, ExternAssocItem, - ExternCrateDecl, Field, Function, GenericDef, GenericParam, GenericSubstitution, HasContainer, - HasVisibility, HirDisplay, Impl, InlineAsmOperand, ItemContainer, Label, Local, Macro, Module, - ModuleDef, Name, PathResolution, Semantics, Static, StaticLifetime, Struct, ToolModule, Trait, - TupleField, TypeAlias, Variant, VariantDef, Visibility, + Const, Crate, DefWithBody, DeriveHelper, DisplayTarget, DocLinkDef, EnumVariant, + ExpressionStoreOwner, ExternAssocItem, ExternCrateDecl, Field, Function, GenericDef, + GenericParam, GenericSubstitution, HasContainer, HasVisibility, HirDisplay, Impl, + InlineAsmOperand, ItemContainer, Label, Local, Macro, Module, ModuleDef, Name, PathResolution, + Semantics, Static, StaticLifetime, Struct, ToolModule, Trait, TupleField, TypeAlias, Variant, + Visibility, }; use span::Edition; use stdx::{format_to, impl_from}; @@ -38,7 +39,7 @@ pub enum Definition { Crate(Crate), Function(Function), Adt(Adt), - Variant(Variant), + EnumVariant(EnumVariant), Const(Const), Static(Static), Trait(Trait), @@ -85,7 +86,7 @@ impl Definition { Definition::Static(it) => it.module(db), Definition::Trait(it) => it.module(db), Definition::TypeAlias(it) => it.module(db), - Definition::Variant(it) => it.module(db), + Definition::EnumVariant(it) => it.module(db), Definition::SelfType(it) => it.module(db), Definition::Local(it) => it.module(db), Definition::GenericParam(it) => it.module(db), @@ -123,7 +124,7 @@ impl Definition { Definition::Static(it) => container_to_definition(it.container(db)), Definition::Trait(it) => container_to_definition(it.container(db)), Definition::TypeAlias(it) => container_to_definition(it.container(db)), - Definition::Variant(it) => Some(Adt::Enum(it.parent_enum(db)).into()), + Definition::EnumVariant(it) => Some(Adt::Enum(it.parent_enum(db)).into()), Definition::SelfType(it) => Some(it.module(db).into()), Definition::Local(it) => it.parent(db).try_into().ok(), Definition::GenericParam(it) => Some(it.parent().into()), @@ -151,7 +152,7 @@ impl Definition { Definition::Static(it) => it.visibility(db), Definition::Trait(it) => it.visibility(db), Definition::TypeAlias(it) => it.visibility(db), - Definition::Variant(it) => it.visibility(db), + Definition::EnumVariant(it) => it.visibility(db), Definition::ExternCrateDecl(it) => it.visibility(db), Definition::Macro(it) => it.visibility(db), Definition::BuiltinType(_) | Definition::TupleField(_) => Visibility::Public, @@ -179,7 +180,7 @@ impl Definition { } Definition::Function(it) => it.name(db), Definition::Adt(it) => it.name(db), - Definition::Variant(it) => it.name(db), + Definition::EnumVariant(it) => it.name(db), Definition::Const(it) => it.name(db)?, Definition::Static(it) => it.name(db), Definition::Trait(it) => it.name(db), @@ -227,7 +228,7 @@ impl Definition { Definition::Crate(it) => it.docs_with_rangemap(db), Definition::Function(it) => it.docs_with_rangemap(db), Definition::Adt(it) => it.docs_with_rangemap(db), - Definition::Variant(it) => it.docs_with_rangemap(db), + Definition::EnumVariant(it) => it.docs_with_rangemap(db), Definition::Const(it) => it.docs_with_rangemap(db), Definition::Static(it) => it.docs_with_rangemap(db), Definition::Trait(it) => it.docs_with_rangemap(db), @@ -315,7 +316,7 @@ impl Definition { Definition::Crate(it) => it.display(db, display_target).to_string(), Definition::Function(it) => it.display(db, display_target).to_string(), Definition::Adt(it) => it.display(db, display_target).to_string(), - Definition::Variant(it) => it.display(db, display_target).to_string(), + Definition::EnumVariant(it) => it.display(db, display_target).to_string(), Definition::Const(it) => it.display(db, display_target).to_string(), Definition::Static(it) => it.display(db, display_target).to_string(), Definition::Trait(it) => it.display(db, display_target).to_string(), @@ -556,7 +557,7 @@ impl<'db> NameClass<'db> { ast::Rename(it) => classify_rename(sema, it)?, ast::SelfParam(it) => Definition::Local(sema.to_def(&it)?), ast::RecordField(it) => Definition::Field(sema.to_def(&it)?), - ast::Variant(it) => Definition::Variant(sema.to_def(&it)?), + ast::Variant(it) => Definition::EnumVariant(sema.to_def(&it)?), ast::TypeParam(it) => Definition::GenericParam(sema.to_def(&it)?.into()), ast::ConstParam(it) => Definition::GenericParam(sema.to_def(&it)?.into()), ast::AsmOperandNamed(it) => Definition::InlineAsmOperand(sema.to_def(&it)?), @@ -848,7 +849,7 @@ impl<'db> NameRefClass<'db> { ast::OffsetOfExpr(_) => { let (def, subst) = sema.resolve_offset_of_field(name_ref)?; let def = match def { - Either::Left(variant) => Definition::Variant(variant), + Either::Left(variant) => Definition::EnumVariant(variant), Either::Right(field) => Definition::Field(field), }; Some(NameRefClass::Definition(def, Some(subst))) @@ -891,7 +892,7 @@ impl<'db> NameRefClass<'db> { } impl_from!( - Field, Module, Function, Adt, Variant, Const, Static, Trait, TypeAlias, BuiltinType, Local, + Field, Module, Function, Adt, EnumVariant, Const, Static, Trait, TypeAlias, BuiltinType, Local, GenericParam, Label, Macro, ExternCrateDecl for Definition ); @@ -967,7 +968,7 @@ impl From<ModuleDef> for Definition { ModuleDef::Module(it) => Definition::Module(it), ModuleDef::Function(it) => Definition::Function(it), ModuleDef::Adt(it) => Definition::Adt(it), - ModuleDef::Variant(it) => Definition::Variant(it), + ModuleDef::EnumVariant(it) => Definition::EnumVariant(it), ModuleDef::Const(it) => Definition::Const(it), ModuleDef::Static(it) => Definition::Static(it), ModuleDef::Trait(it) => Definition::Trait(it), @@ -988,8 +989,8 @@ impl From<DocLinkDef> for Definition { } } -impl From<VariantDef> for Definition { - fn from(def: VariantDef) -> Self { +impl From<Variant> for Definition { + fn from(def: Variant) -> Self { ModuleDef::from(def).into() } } @@ -1001,7 +1002,7 @@ impl TryFrom<DefWithBody> for Definition { DefWithBody::Function(it) => Ok(it.into()), DefWithBody::Static(it) => Ok(it.into()), DefWithBody::Const(it) => Ok(it.into()), - DefWithBody::Variant(it) => Ok(it.into()), + DefWithBody::EnumVariant(it) => Ok(it.into()), } } } @@ -1020,6 +1021,17 @@ impl From<GenericDef> for Definition { } } +impl TryFrom<ExpressionStoreOwner> for Definition { + type Error = (); + fn try_from(def: ExpressionStoreOwner) -> Result<Self, Self::Error> { + match def { + ExpressionStoreOwner::Body(def_with_body) => def_with_body.try_into(), + ExpressionStoreOwner::Signature(generic_def) => Ok(generic_def.into()), + ExpressionStoreOwner::VariantFields(it) => Ok(it.into()), + } + } +} + impl TryFrom<Definition> for GenericDef { type Error = (); fn try_from(def: Definition) -> Result<Self, Self::Error> { diff --git a/crates/ide-db/src/documentation.rs b/crates/ide-db/src/documentation.rs index 4c4691cca2..407049f4b3 100644 --- a/crates/ide-db/src/documentation.rs +++ b/crates/ide-db/src/documentation.rs @@ -58,8 +58,22 @@ macro_rules! impl_has_docs { } impl_has_docs![ - Variant, Field, Static, Const, Trait, TypeAlias, Macro, Function, Adt, Module, Impl, Crate, - AssocItem, Struct, Union, Enum, + EnumVariant, + Field, + Static, + Const, + Trait, + TypeAlias, + Macro, + Function, + Adt, + Module, + Impl, + Crate, + AssocItem, + Struct, + Union, + Enum, ]; impl HasDocs for hir::ExternCrateDecl { diff --git a/crates/ide-db/src/generated/lints.rs b/crates/ide-db/src/generated/lints.rs index dedc12aa65..9e6d586008 100644 --- a/crates/ide-db/src/generated/lints.rs +++ b/crates/ide-db/src/generated/lints.rs @@ -11576,9 +11576,9 @@ The tracking issue for this feature is: [#85731] label: "try_blocks", description: r##"# `try_blocks` -The tracking issue for this feature is: [#31436] +The tracking issue for this feature is: [#154391] -[#31436]: https://github.com/rust-lang/rust/issues/31436 +[#154391]: https://github.com/rust-lang/rust/issues/154391 ------------------------ @@ -11590,14 +11590,14 @@ block creates a new scope one can use the `?` operator in. use std::num::ParseIntError; -let result: Result<i32, ParseIntError> = try { +let result = try { "1".parse::<i32>()? + "2".parse::<i32>()? + "3".parse::<i32>()? }; assert_eq!(result, Ok(6)); -let result: Result<i32, ParseIntError> = try { +let result = try { "1".parse::<i32>()? + "foo".parse::<i32>()? + "3".parse::<i32>()? diff --git a/crates/ide-db/src/imports/import_assets.rs b/crates/ide-db/src/imports/import_assets.rs index 35579eb259..1c48527027 100644 --- a/crates/ide-db/src/imports/import_assets.rs +++ b/crates/ide-db/src/imports/import_assets.rs @@ -9,7 +9,7 @@ use hir::{ }; use itertools::Itertools; use rustc_hash::{FxHashMap, FxHashSet}; -use smallvec::SmallVec; +use smallvec::{SmallVec, smallvec}; use syntax::{ AstNode, SyntaxNode, ast::{self, HasName, make}, @@ -68,6 +68,8 @@ pub struct PathImportCandidate { pub qualifier: Vec<Name>, /// The name the item (struct, trait, enum, etc.) should have. pub name: NameToImport, + /// Potentially more segments that should resolve in the candidate. + pub after: Vec<Name>, } /// A name that will be used during item lookups. @@ -376,7 +378,7 @@ fn path_applicable_imports( ) -> FxIndexSet<LocatedImport> { let _p = tracing::info_span!("ImportAssets::path_applicable_imports").entered(); - match &*path_candidate.qualifier { + let mut result = match &*path_candidate.qualifier { [] => { items_locator::items_with_name( db, @@ -433,6 +435,75 @@ fn path_applicable_imports( }) .take(DEFAULT_QUERY_SEARCH_LIMIT) .collect(), + }; + + filter_candidates_by_after_path(db, scope, path_candidate, &mut result); + + result +} + +fn filter_candidates_by_after_path( + db: &RootDatabase, + scope: &SemanticsScope<'_>, + path_candidate: &PathImportCandidate, + imports: &mut FxIndexSet<LocatedImport>, +) { + if imports.len() <= 1 { + // Short-circuit, as even if it doesn't match fully we want it. + return; + } + + let Some((last_after, after_except_last)) = path_candidate.after.split_last() else { + return; + }; + + let original_imports = imports.clone(); + + let traits_in_scope = scope.visible_traits(); + imports.retain(|import| { + let items = if after_except_last.is_empty() { + smallvec![import.original_item] + } else { + let ItemInNs::Types(ModuleDef::Module(item)) = import.original_item else { + return false; + }; + // FIXME: This doesn't consider visibilities. + item.resolve_mod_path(db, after_except_last.iter().cloned()) + .into_iter() + .flatten() + .collect::<SmallVec<[_; 3]>>() + }; + items.into_iter().any(|item| { + let has_last_method = |ty: hir::Type<'_>| { + ty.iterate_path_candidates(db, scope, &traits_in_scope, Some(last_after), |_| { + Some(()) + }) + .is_some() + }; + // FIXME: A trait can have an assoc type that has a function/const, that's two segments before last. + match item { + // A module? Can we resolve one more segment? + ItemInNs::Types(ModuleDef::Module(module)) => module + .resolve_mod_path(db, [last_after.clone()]) + .is_some_and(|mut it| it.any(|_| true)), + // And ADT/Type Alias? That might be a method. + ItemInNs::Types(ModuleDef::Adt(it)) => has_last_method(it.ty(db)), + ItemInNs::Types(ModuleDef::BuiltinType(it)) => has_last_method(it.ty(db)), + ItemInNs::Types(ModuleDef::TypeAlias(it)) => has_last_method(it.ty(db)), + // A trait? Might have an associated item. + ItemInNs::Types(ModuleDef::Trait(it)) => it + .items(db) + .into_iter() + .any(|assoc_item| assoc_item.name(db) == Some(last_after.clone())), + // Other items? can't resolve one more segment. + _ => false, + } + }) + }); + + if imports.is_empty() { + // Better one half-match than zero full matches. + *imports = original_imports; } } @@ -759,10 +830,14 @@ impl<'db> ImportCandidate<'db> { if sema.resolve_path(path).is_some() { return None; } + let after = std::iter::successors(path.parent_path(), |it| it.parent_path()) + .map(|seg| seg.segment()?.name_ref().map(|name| Name::new_root(&name.text()))) + .collect::<Option<_>>()?; path_import_candidate( sema, path.qualifier(), NameToImport::exact_case_sensitive(path.segment()?.name_ref()?.to_string()), + after, ) } @@ -777,6 +852,7 @@ impl<'db> ImportCandidate<'db> { Some(ImportCandidate::Path(PathImportCandidate { qualifier: vec![], name: NameToImport::exact_case_sensitive(name.to_string()), + after: vec![], })) } @@ -785,7 +861,8 @@ impl<'db> ImportCandidate<'db> { fuzzy_name: String, sema: &Semantics<'db, RootDatabase>, ) -> Option<Self> { - path_import_candidate(sema, qualifier, NameToImport::fuzzy(fuzzy_name)) + // Assume a fuzzy match does not want the segments after. Because... I guess why not? + path_import_candidate(sema, qualifier, NameToImport::fuzzy(fuzzy_name), Vec::new()) } } @@ -793,6 +870,7 @@ fn path_import_candidate<'db>( sema: &Semantics<'db, RootDatabase>, qualifier: Option<ast::Path>, name: NameToImport, + after: Vec<Name>, ) -> Option<ImportCandidate<'db>> { Some(match qualifier { Some(qualifier) => match sema.resolve_path(&qualifier) { @@ -802,7 +880,7 @@ fn path_import_candidate<'db>( .segments() .map(|seg| seg.name_ref().map(|name| Name::new_root(&name.text()))) .collect::<Option<Vec<_>>>()?; - ImportCandidate::Path(PathImportCandidate { qualifier, name }) + ImportCandidate::Path(PathImportCandidate { qualifier, name, after }) } else { return None; } @@ -826,7 +904,7 @@ fn path_import_candidate<'db>( } Some(_) => return None, }, - None => ImportCandidate::Path(PathImportCandidate { qualifier: vec![], name }), + None => ImportCandidate::Path(PathImportCandidate { qualifier: vec![], name, after }), }) } diff --git a/crates/ide-db/src/imports/insert_use.rs b/crates/ide-db/src/imports/insert_use.rs index 4444ef5d81..da8525d1fb 100644 --- a/crates/ide-db/src/imports/insert_use.rs +++ b/crates/ide-db/src/imports/insert_use.rs @@ -9,8 +9,9 @@ use syntax::{ Direction, NodeOrToken, SyntaxKind, SyntaxNode, algo, ast::{ self, AstNode, HasAttrs, HasModuleItem, HasVisibility, PathSegmentKind, - edit_in_place::Removable, make, + edit_in_place::Removable, make, syntax_factory::SyntaxFactory, }, + syntax_editor::{Position, SyntaxEditor}, ted, }; @@ -93,7 +94,7 @@ impl ImportScope { .item_list() .map(ImportScopeKind::Module) .map(|kind| ImportScope { kind, required_cfgs }); - } else if let Some(has_attrs) = ast::AnyHasAttrs::cast(syntax) { + } else if let Some(has_attrs) = ast::AnyHasAttrs::cast(syntax.clone()) { if block.is_none() && let Some(b) = ast::BlockExpr::cast(has_attrs.syntax().clone()) && let Some(b) = sema.original_ast_node(b) @@ -104,11 +105,34 @@ impl ImportScope { .attrs() .any(|attr| attr.as_simple_call().is_some_and(|(ident, _)| ident == "cfg")) { - if let Some(b) = block { - return Some(ImportScope { - kind: ImportScopeKind::Block(b), - required_cfgs, + if let Some(b) = block.clone() { + let current_cfgs = has_attrs.attrs().filter(|attr| { + attr.as_simple_call().is_some_and(|(ident, _)| ident == "cfg") }); + + let total_cfgs: Vec<_> = + required_cfgs.iter().cloned().chain(current_cfgs).collect(); + + let parent = syntax.parent(); + let mut can_merge = false; + if let Some(parent) = parent { + can_merge = parent.children().filter_map(ast::Use::cast).any(|u| { + let u_attrs = u.attrs().filter(|attr| { + attr.as_simple_call().is_some_and(|(ident, _)| ident == "cfg") + }); + crate::imports::merge_imports::eq_attrs( + u_attrs, + total_cfgs.iter().cloned(), + ) + }); + } + + if !can_merge { + return Some(ImportScope { + kind: ImportScopeKind::Block(b), + required_cfgs, + }); + } } required_cfgs.extend(has_attrs.attrs().filter(|attr| { attr.as_simple_call().is_some_and(|(ident, _)| ident == "cfg") @@ -146,9 +170,25 @@ pub fn insert_use(scope: &ImportScope, path: ast::Path, cfg: &InsertUseConfig) { insert_use_with_alias_option(scope, path, cfg, None); } -pub fn insert_use_as_alias(scope: &ImportScope, path: ast::Path, cfg: &InsertUseConfig) { +/// Insert an import path into the given file/node. A `merge` value of none indicates that no import merging is allowed to occur. +pub fn insert_use_with_editor( + scope: &ImportScope, + path: ast::Path, + cfg: &InsertUseConfig, + syntax_editor: &mut SyntaxEditor, + syntax_factory: &SyntaxFactory, +) { + insert_use_with_alias_option_with_editor(scope, path, cfg, None, syntax_editor, syntax_factory); +} + +pub fn insert_use_as_alias( + scope: &ImportScope, + path: ast::Path, + cfg: &InsertUseConfig, + edition: span::Edition, +) { let text: &str = "use foo as _"; - let parse = syntax::SourceFile::parse(text, span::Edition::CURRENT_FIXME); + let parse = syntax::SourceFile::parse(text, edition); let node = parse .tree() .syntax() @@ -224,6 +264,71 @@ fn insert_use_with_alias_option( insert_use_(scope, use_item, cfg.group); } +fn insert_use_with_alias_option_with_editor( + scope: &ImportScope, + path: ast::Path, + cfg: &InsertUseConfig, + alias: Option<ast::Rename>, + syntax_editor: &mut SyntaxEditor, + syntax_factory: &SyntaxFactory, +) { + let _p = tracing::info_span!("insert_use_with_alias_option").entered(); + let mut mb = match cfg.granularity { + ImportGranularity::Crate => Some(MergeBehavior::Crate), + ImportGranularity::Module => Some(MergeBehavior::Module), + ImportGranularity::One => Some(MergeBehavior::One), + ImportGranularity::Item => None, + }; + if !cfg.enforce_granularity { + let file_granularity = guess_granularity_from_scope(scope); + mb = match file_granularity { + ImportGranularityGuess::Unknown => mb, + ImportGranularityGuess::Item => None, + ImportGranularityGuess::Module => Some(MergeBehavior::Module), + // We use the user's setting to infer if this is module or item. + ImportGranularityGuess::ModuleOrItem => match mb { + Some(MergeBehavior::Module) | None => mb, + // There isn't really a way to decide between module or item here, so we just pick one. + // FIXME: Maybe it is possible to infer based on semantic analysis? + Some(MergeBehavior::One | MergeBehavior::Crate) => Some(MergeBehavior::Module), + }, + ImportGranularityGuess::Crate => Some(MergeBehavior::Crate), + ImportGranularityGuess::CrateOrModule => match mb { + Some(MergeBehavior::Crate | MergeBehavior::Module) => mb, + Some(MergeBehavior::One) | None => Some(MergeBehavior::Crate), + }, + ImportGranularityGuess::One => Some(MergeBehavior::One), + }; + } + + let use_tree = syntax_factory.use_tree(path, None, alias, false); + if mb == Some(MergeBehavior::One) && use_tree.path().is_some() { + use_tree.wrap_in_tree_list(); + } + let use_item = make::use_(None, None, use_tree).clone_for_update(); + for attr in + scope.required_cfgs.iter().map(|attr| attr.syntax().clone_subtree().clone_for_update()) + { + syntax_editor.insert(Position::first_child_of(use_item.syntax()), attr); + } + + // merge into existing imports if possible + if let Some(mb) = mb { + let filter = |it: &_| !(cfg.skip_glob_imports && ast::Use::is_simple_glob(it)); + for existing_use in + scope.as_syntax_node().children().filter_map(ast::Use::cast).filter(filter) + { + if let Some(merged) = try_merge_imports(&existing_use, &use_item, mb) { + syntax_editor.replace(existing_use.syntax(), merged.syntax()); + return; + } + } + } + // either we weren't allowed to merge or there is no import that fits the merge conditions + // so look for the place we have to insert to + insert_use_with_editor_(scope, use_item, cfg.group, syntax_editor, syntax_factory); +} + pub fn ast_to_remove_for_path_in_use_stmt(path: &ast::Path) -> Option<Box<dyn Removable>> { // FIXME: improve this if path.parent_path().is_some() { @@ -464,7 +569,9 @@ fn insert_use_(scope: &ImportScope, use_item: ast::Use, group_imports: bool) { // skip the curly brace .skip(l_curly.is_some() as usize) .take_while(|child| match child { - NodeOrToken::Node(node) => is_inner_attribute(node.clone()), + NodeOrToken::Node(node) => { + is_inner_attribute(node.clone()) && ast::Item::cast(node.clone()).is_none() + } NodeOrToken::Token(token) => { [SyntaxKind::WHITESPACE, SyntaxKind::COMMENT, SyntaxKind::SHEBANG] .contains(&token.kind()) @@ -495,6 +602,129 @@ fn insert_use_(scope: &ImportScope, use_item: ast::Use, group_imports: bool) { } } +fn insert_use_with_editor_( + scope: &ImportScope, + use_item: ast::Use, + group_imports: bool, + syntax_editor: &mut SyntaxEditor, + syntax_factory: &SyntaxFactory, +) { + let scope_syntax = scope.as_syntax_node(); + let insert_use_tree = + use_item.use_tree().expect("`use_item` should have a use tree for `insert_path`"); + let group = ImportGroup::new(&insert_use_tree); + let path_node_iter = scope_syntax + .children() + .filter_map(|node| ast::Use::cast(node.clone()).zip(Some(node))) + .flat_map(|(use_, node)| { + let tree = use_.use_tree()?; + Some((tree, node)) + }); + + if group_imports { + // Iterator that discards anything that's not in the required grouping + // This implementation allows the user to rearrange their import groups as this only takes the first group that fits + let group_iter = path_node_iter + .clone() + .skip_while(|(use_tree, ..)| ImportGroup::new(use_tree) != group) + .take_while(|(use_tree, ..)| ImportGroup::new(use_tree) == group); + + // track the last element we iterated over, if this is still None after the iteration then that means we never iterated in the first place + let mut last = None; + // find the element that would come directly after our new import + let post_insert: Option<(_, SyntaxNode)> = group_iter + .inspect(|(.., node)| last = Some(node.clone())) + .find(|(use_tree, _)| use_tree_cmp(&insert_use_tree, use_tree) != Ordering::Greater); + + if let Some((.., node)) = post_insert { + cov_mark::hit!(insert_group); + // insert our import before that element + return syntax_editor.insert(Position::before(node), use_item.syntax()); + } + if let Some(node) = last { + cov_mark::hit!(insert_group_last); + // there is no element after our new import, so append it to the end of the group + return syntax_editor.insert(Position::after(node), use_item.syntax()); + } + + // the group we were looking for actually doesn't exist, so insert + + let mut last = None; + // find the group that comes after where we want to insert + let post_group = path_node_iter + .inspect(|(.., node)| last = Some(node.clone())) + .find(|(use_tree, ..)| ImportGroup::new(use_tree) > group); + if let Some((.., node)) = post_group { + cov_mark::hit!(insert_group_new_group); + syntax_editor.insert(Position::before(&node), use_item.syntax()); + if let Some(node) = algo::non_trivia_sibling(node.into(), Direction::Prev) { + syntax_editor.insert(Position::after(node), syntax_factory.whitespace("\n")); + } + return; + } + // there is no such group, so append after the last one + if let Some(node) = last { + cov_mark::hit!(insert_group_no_group); + syntax_editor.insert(Position::after(&node), use_item.syntax()); + syntax_editor.insert(Position::after(node), syntax_factory.whitespace("\n")); + return; + } + } else { + // There exists a group, so append to the end of it + if let Some((_, node)) = path_node_iter.last() { + cov_mark::hit!(insert_no_grouping_last); + syntax_editor.insert(Position::after(node), use_item.syntax()); + return; + } + } + + let l_curly = match &scope.kind { + ImportScopeKind::File(_) => None, + // don't insert the imports before the item list/block expr's opening curly brace + ImportScopeKind::Module(item_list) => item_list.l_curly_token(), + // don't insert the imports before the item list's opening curly brace + ImportScopeKind::Block(block) => block.l_curly_token(), + }; + // there are no imports in this file at all + // so put the import after all inner module attributes and possible license header comments + if let Some(last_inner_element) = scope_syntax + .children_with_tokens() + // skip the curly brace + .skip(l_curly.is_some() as usize) + .take_while(|child| match child { + NodeOrToken::Node(node) => { + is_inner_attribute(node.clone()) && ast::Item::cast(node.clone()).is_none() + } + NodeOrToken::Token(token) => { + [SyntaxKind::WHITESPACE, SyntaxKind::COMMENT, SyntaxKind::SHEBANG] + .contains(&token.kind()) + } + }) + .filter(|child| child.as_token().is_none_or(|t| t.kind() != SyntaxKind::WHITESPACE)) + .last() + { + cov_mark::hit!(insert_empty_inner_attr); + syntax_editor.insert(Position::after(&last_inner_element), use_item.syntax()); + syntax_editor.insert(Position::after(last_inner_element), syntax_factory.whitespace("\n")); + } else { + match l_curly { + Some(b) => { + cov_mark::hit!(insert_empty_module); + syntax_editor.insert(Position::after(&b), syntax_factory.whitespace("\n")); + syntax_editor.insert(Position::after(&b), use_item.syntax()); + } + None => { + cov_mark::hit!(insert_empty_file); + syntax_editor.insert( + Position::first_child_of(scope_syntax), + syntax_factory.whitespace("\n\n"), + ); + syntax_editor.insert(Position::first_child_of(scope_syntax), use_item.syntax()); + } + } + } +} + fn is_inner_attribute(node: SyntaxNode) -> bool { ast::Attr::cast(node).map(|attr| attr.kind()) == Some(ast::AttrKind::Inner) } diff --git a/crates/ide-db/src/imports/insert_use/tests.rs b/crates/ide-db/src/imports/insert_use/tests.rs index 3350e1c3d2..6c7b97458d 100644 --- a/crates/ide-db/src/imports/insert_use/tests.rs +++ b/crates/ide-db/src/imports/insert_use/tests.rs @@ -1438,3 +1438,156 @@ fn check_guess(#[rust_analyzer::rust_fixture] ra_fixture: &str, expected: Import let file = ImportScope { kind: ImportScopeKind::File(syntax), required_cfgs: vec![] }; assert_eq!(super::guess_granularity_from_scope(&file), expected); } + +#[test] +fn insert_with_existing_imports_and_cfg_module() { + check( + "std::fmt", + r#" +use foo::bar; + +#[cfg(target_arch = "x86_64")] +pub mod api; +"#, + r#" +use std::fmt; + +use foo::bar; + +#[cfg(target_arch = "x86_64")] +pub mod api; +"#, + ImportGranularity::Crate, + ); +} + +#[test] +fn insert_before_cfg_module() { + check( + "std::fmt", + r#" +#[cfg(target_arch = "x86_64")] +pub mod api; +"#, + r#" +use std::fmt; + +#[cfg(target_arch = "x86_64")] +pub mod api; +"#, + ImportGranularity::Crate, + ); +} + +fn check_merge(ra_fixture0: &str, ra_fixture1: &str, last: &str, mb: MergeBehavior) { + let use0 = ast::SourceFile::parse(ra_fixture0, span::Edition::CURRENT) + .tree() + .syntax() + .descendants() + .find_map(ast::Use::cast) + .unwrap(); + + let use1 = ast::SourceFile::parse(ra_fixture1, span::Edition::CURRENT) + .tree() + .syntax() + .descendants() + .find_map(ast::Use::cast) + .unwrap(); + + let result = try_merge_imports(&use0, &use1, mb); + assert_eq!(result.map(|u| u.to_string().trim().to_owned()), Some(last.trim().to_owned())); +} + +#[test] +fn merge_gated_imports() { + check_merge( + r#"#[cfg(test)] use foo::bar;"#, + r#"#[cfg(test)] use foo::baz;"#, + r#"#[cfg(test)] use foo::{bar, baz};"#, + MergeBehavior::Crate, + ); +} + +#[test] +fn merge_gated_imports_with_different_values() { + let use0 = ast::SourceFile::parse(r#"#[cfg(a)] use foo::bar;"#, span::Edition::CURRENT) + .tree() + .syntax() + .descendants() + .find_map(ast::Use::cast) + .unwrap(); + + let use1 = ast::SourceFile::parse(r#"#[cfg(b)] use foo::baz;"#, span::Edition::CURRENT) + .tree() + .syntax() + .descendants() + .find_map(ast::Use::cast) + .unwrap(); + + let result = try_merge_imports(&use0, &use1, MergeBehavior::Crate); + assert_eq!(result, None); +} + +#[test] +fn merge_gated_imports_different_order() { + check_merge( + r#"#[cfg(a)] #[cfg(b)] use foo::bar;"#, + r#"#[cfg(b)] #[cfg(a)] use foo::baz;"#, + r#"#[cfg(a)] #[cfg(b)] use foo::{bar, baz};"#, + MergeBehavior::Crate, + ); +} + +#[test] +fn merge_into_existing_cfg_import() { + check( + r#"foo::Foo"#, + r#" +#[cfg(target_os = "windows")] +use bar::Baz; + +#[cfg(target_os = "windows")] +fn buzz() { + Foo$0; +} +"#, + r#" +#[cfg(target_os = "windows")] +use bar::Baz; +#[cfg(target_os = "windows")] +use foo::Foo; + +#[cfg(target_os = "windows")] +fn buzz() { + Foo; +} +"#, + ImportGranularity::Crate, + ); +} + +#[test] +fn reproduce_user_issue_missing_semicolon() { + check( + "std::fmt", + r#" +use { + foo +} + +#[cfg(target_arch = "x86_64")] +pub mod api; +"#, + r#" +use std::fmt; + +use { + foo +} + +#[cfg(target_arch = "x86_64")] +pub mod api; +"#, + ImportGranularity::Crate, + ); +} diff --git a/crates/ide-db/src/imports/merge_imports.rs b/crates/ide-db/src/imports/merge_imports.rs index 635ed7368c..3301719f5c 100644 --- a/crates/ide-db/src/imports/merge_imports.rs +++ b/crates/ide-db/src/imports/merge_imports.rs @@ -4,7 +4,7 @@ use std::cmp::Ordering; use itertools::{EitherOrBoth, Itertools}; use parser::T; use syntax::{ - Direction, SyntaxElement, algo, + Direction, SyntaxElement, ToSmolStr, algo, ast::{ self, AstNode, HasAttrs, HasName, HasVisibility, PathSegmentKind, edit_in_place::Removable, make, @@ -691,14 +691,12 @@ pub fn eq_attrs( attrs0: impl Iterator<Item = ast::Attr>, attrs1: impl Iterator<Item = ast::Attr>, ) -> bool { - // FIXME order of attributes should not matter - let attrs0 = attrs0 - .flat_map(|attr| attr.syntax().descendants_with_tokens()) - .flat_map(|it| it.into_token()); - let attrs1 = attrs1 - .flat_map(|attr| attr.syntax().descendants_with_tokens()) - .flat_map(|it| it.into_token()); - stdx::iter_eq_by(attrs0, attrs1, |tok, tok2| tok.text() == tok2.text()) + let mut attrs0: Vec<_> = attrs0.map(|attr| attr.syntax().text().to_smolstr()).collect(); + let mut attrs1: Vec<_> = attrs1.map(|attr| attr.syntax().text().to_smolstr()).collect(); + attrs0.sort_unstable(); + attrs1.sort_unstable(); + + attrs0 == attrs1 } fn path_is_self(path: &ast::Path) -> bool { diff --git a/crates/ide-db/src/lib.rs b/crates/ide-db/src/lib.rs index 023b32b361..cde0705d8a 100644 --- a/crates/ide-db/src/lib.rs +++ b/crates/ide-db/src/lib.rs @@ -312,7 +312,7 @@ impl SymbolKind { pub fn from_module_def(db: &dyn HirDatabase, it: hir::ModuleDef) -> Self { match it { hir::ModuleDef::Const(..) => SymbolKind::Const, - hir::ModuleDef::Variant(..) => SymbolKind::Variant, + hir::ModuleDef::EnumVariant(..) => SymbolKind::Variant, hir::ModuleDef::Function(..) => SymbolKind::Function, hir::ModuleDef::Macro(mac) if mac.is_proc_macro() => SymbolKind::ProcMacro, hir::ModuleDef::Macro(..) => SymbolKind::Macro, diff --git a/crates/ide-db/src/path_transform.rs b/crates/ide-db/src/path_transform.rs index 48305c2082..508f841340 100644 --- a/crates/ide-db/src/path_transform.rs +++ b/crates/ide-db/src/path_transform.rs @@ -553,6 +553,39 @@ impl Ctx<'_> { return None; } + // Similarly, modules cannot be used in pattern position. + if matches!(def, hir::ModuleDef::Module(_)) { + return None; + } + + if matches!( + def, + hir::ModuleDef::Function(_) + | hir::ModuleDef::Trait(_) + | hir::ModuleDef::TypeAlias(_) + ) { + return None; + } + + if let hir::ModuleDef::Adt(adt) = def { + match adt { + hir::Adt::Struct(s) + if s.kind(self.source_scope.db) != hir::StructKind::Unit => + { + return None; + } + hir::Adt::Union(_) => return None, + hir::Adt::Enum(_) => return None, + _ => (), + } + } + + if let hir::ModuleDef::EnumVariant(v) = def + && v.kind(self.source_scope.db) != hir::StructKind::Unit + { + return None; + } + let cfg = FindPathConfig { prefer_no_std: false, prefer_prelude: true, @@ -632,3 +665,87 @@ fn find_trait_for_assoc_item( None } + +#[cfg(test)] +mod tests { + use crate::RootDatabase; + use crate::path_transform::PathTransform; + use hir::Semantics; + use syntax::{AstNode, ast::HasName}; + use test_fixture::WithFixture; + use test_utils::assert_eq_text; + + #[test] + fn test_transform_ident_pat() { + let (db, file_id) = RootDatabase::with_single_file( + r#" +mod foo { + pub struct UnitStruct; + pub struct RecordStruct {} + pub enum Enum { UnitVariant, RecordVariant {} } + pub fn function() {} + pub const CONST: i32 = 0; + pub static STATIC: i32 = 0; + pub type Alias = i32; + pub union Union { f: i32 } +} + +mod bar { + fn anchor() {} +} + +fn main() { + use foo::*; + use foo::Enum::*; + let UnitStruct = (); + let RecordStruct = (); + let Enum = (); + let UnitVariant = (); + let RecordVariant = (); + let function = (); + let CONST = (); + let STATIC = (); + let Alias = (); + let Union = (); +} +"#, + ); + let sema = Semantics::new(&db); + let source_file = sema.parse(file_id); + + let function = source_file + .syntax() + .descendants() + .filter_map(syntax::ast::Fn::cast) + .find(|it| it.name().unwrap().text() == "main") + .unwrap(); + let source_scope = sema.scope(function.body().unwrap().syntax()).unwrap(); + + let anchor = source_file + .syntax() + .descendants() + .filter_map(syntax::ast::Fn::cast) + .find(|it| it.name().unwrap().text() == "anchor") + .unwrap(); + let target_scope = sema.scope(anchor.body().unwrap().syntax()).unwrap(); + + let transform = PathTransform::generic_transformation(&target_scope, &source_scope); + let transformed = transform.apply(function.body().unwrap().syntax()); + + let expected = r#"{ + use crate::foo::*; + use crate::foo::Enum::*; + let crate::foo::UnitStruct = (); + let RecordStruct = (); + let Enum = (); + let crate::foo::Enum::UnitVariant = (); + let RecordVariant = (); + let function = (); + let crate::foo::CONST = (); + let crate::foo::STATIC = (); + let Alias = (); + let Union = (); +}"#; + assert_eq_text!(expected, &transformed.to_string()); + } +} diff --git a/crates/ide-db/src/prime_caches.rs b/crates/ide-db/src/prime_caches.rs index 015b06e8e0..d264428212 100644 --- a/crates/ide-db/src/prime_caches.rs +++ b/crates/ide-db/src/prime_caches.rs @@ -4,7 +4,7 @@ //! various caches, it's not really advanced at the moment. use std::panic::AssertUnwindSafe; -use hir::{Symbol, db::DefDatabase}; +use hir::{Symbol, import_map::ImportMap}; use rustc_hash::FxHashMap; use salsa::{Cancelled, Database}; @@ -123,7 +123,7 @@ pub fn parallel_prime_caches( Ok::<_, crossbeam_channel::SendError<_>>(()) }; let handle_import_map = |crate_id| { - let cancelled = Cancelled::catch(|| _ = db.import_map(crate_id)); + let cancelled = Cancelled::catch(|| _ = ImportMap::of(&db, crate_id)); match cancelled { Ok(()) => { diff --git a/crates/ide-db/src/rename.rs b/crates/ide-db/src/rename.rs index b03a5b6efb..b18ed69d80 100644 --- a/crates/ide-db/src/rename.rs +++ b/crates/ide-db/src/rename.rs @@ -170,7 +170,7 @@ impl Definition { hir::Adt::Union(it) => name_range(it, sema).and_then(syn_ctx_is_root), hir::Adt::Enum(it) => name_range(it, sema).and_then(syn_ctx_is_root), }, - Definition::Variant(it) => name_range(it, sema).and_then(syn_ctx_is_root), + Definition::EnumVariant(it) => name_range(it, sema).and_then(syn_ctx_is_root), Definition::Const(it) => name_range(it, sema).and_then(syn_ctx_is_root), Definition::Static(it) => name_range(it, sema).and_then(syn_ctx_is_root), Definition::Trait(it) => name_range(it, sema).and_then(syn_ctx_is_root), diff --git a/crates/ide-db/src/search.rs b/crates/ide-db/src/search.rs index 1d865892a2..25acb47f7b 100644 --- a/crates/ide-db/src/search.rs +++ b/crates/ide-db/src/search.rs @@ -10,9 +10,9 @@ use std::{cell::LazyCell, cmp::Reverse}; use base_db::{RootQueryDb, SourceDatabase}; use either::Either; use hir::{ - Adt, AsAssocItem, DefWithBody, EditionedFileId, FileRange, FileRangeWrapper, HasAttrs, - HasContainer, HasSource, InFile, InFileWrapper, InRealFile, InlineAsmOperand, ItemContainer, - ModuleSource, PathResolution, Semantics, Visibility, + Adt, AsAssocItem, DefWithBody, EditionedFileId, ExpressionStoreOwner, FileRange, + FileRangeWrapper, HasAttrs, HasContainer, HasSource, InFile, InFileWrapper, InRealFile, + InlineAsmOperand, ItemContainer, ModuleSource, PathResolution, Semantics, Visibility, }; use memchr::memmem::Finder; use parser::SyntaxKind; @@ -169,7 +169,7 @@ impl SearchScope { entries.extend( source_root .iter() - .map(|id| (EditionedFileId::new(db, id, crate_data.edition, krate), None)), + .map(|id| (EditionedFileId::new(db, id, crate_data.edition), None)), ); } SearchScope { entries } @@ -183,9 +183,11 @@ impl SearchScope { let source_root = db.file_source_root(root_file).source_root_id(db); let source_root = db.source_root(source_root).source_root(db); - entries.extend(source_root.iter().map(|id| { - (EditionedFileId::new(db, id, rev_dep.edition(db), rev_dep.into()), None) - })); + entries.extend( + source_root + .iter() + .map(|id| (EditionedFileId::new(db, id, rev_dep.edition(db)), None)), + ); } SearchScope { entries } } @@ -199,7 +201,7 @@ impl SearchScope { SearchScope { entries: source_root .iter() - .map(|id| (EditionedFileId::new(db, id, of.edition(db), of.into()), None)) + .map(|id| (EditionedFileId::new(db, id, of.edition(db)), None)) .collect(), } } @@ -308,10 +310,26 @@ impl Definition { if let Definition::Local(var) = self { let def = match var.parent(db) { - DefWithBody::Function(f) => f.source(db).map(|src| src.syntax().cloned()), - DefWithBody::Const(c) => c.source(db).map(|src| src.syntax().cloned()), - DefWithBody::Static(s) => s.source(db).map(|src| src.syntax().cloned()), - DefWithBody::Variant(v) => v.source(db).map(|src| src.syntax().cloned()), + ExpressionStoreOwner::Body(def) => match def { + DefWithBody::Function(f) => f.source(db).map(|src| src.syntax().cloned()), + DefWithBody::Const(c) => c.source(db).map(|src| src.syntax().cloned()), + DefWithBody::Static(s) => s.source(db).map(|src| src.syntax().cloned()), + DefWithBody::EnumVariant(v) => v.source(db).map(|src| src.syntax().cloned()), + }, + ExpressionStoreOwner::Signature(def) => match def { + hir::GenericDef::Function(it) => it.source(db).map(|src| src.syntax().cloned()), + hir::GenericDef::Adt(it) => it.source(db).map(|src| src.syntax().cloned()), + hir::GenericDef::Trait(it) => it.source(db).map(|src| src.syntax().cloned()), + hir::GenericDef::TypeAlias(it) => { + it.source(db).map(|src| src.syntax().cloned()) + } + hir::GenericDef::Impl(it) => it.source(db).map(|src| src.syntax().cloned()), + hir::GenericDef::Const(it) => it.source(db).map(|src| src.syntax().cloned()), + hir::GenericDef::Static(it) => it.source(db).map(|src| src.syntax().cloned()), + }, + ExpressionStoreOwner::VariantFields(it) => { + it.source(db).map(|src| src.syntax().cloned()) + } }; return match def { Some(def) => SearchScope::file_range( @@ -323,10 +341,26 @@ impl Definition { if let Definition::InlineAsmOperand(op) = self { let def = match op.parent(db) { - DefWithBody::Function(f) => f.source(db).map(|src| src.syntax().cloned()), - DefWithBody::Const(c) => c.source(db).map(|src| src.syntax().cloned()), - DefWithBody::Static(s) => s.source(db).map(|src| src.syntax().cloned()), - DefWithBody::Variant(v) => v.source(db).map(|src| src.syntax().cloned()), + ExpressionStoreOwner::Body(def) => match def { + DefWithBody::Function(f) => f.source(db).map(|src| src.syntax().cloned()), + DefWithBody::Const(c) => c.source(db).map(|src| src.syntax().cloned()), + DefWithBody::Static(s) => s.source(db).map(|src| src.syntax().cloned()), + DefWithBody::EnumVariant(v) => v.source(db).map(|src| src.syntax().cloned()), + }, + ExpressionStoreOwner::Signature(def) => match def { + hir::GenericDef::Function(it) => it.source(db).map(|src| src.syntax().cloned()), + hir::GenericDef::Adt(it) => it.source(db).map(|src| src.syntax().cloned()), + hir::GenericDef::Trait(it) => it.source(db).map(|src| src.syntax().cloned()), + hir::GenericDef::TypeAlias(it) => { + it.source(db).map(|src| src.syntax().cloned()) + } + hir::GenericDef::Impl(it) => it.source(db).map(|src| src.syntax().cloned()), + hir::GenericDef::Const(it) => it.source(db).map(|src| src.syntax().cloned()), + hir::GenericDef::Static(it) => it.source(db).map(|src| src.syntax().cloned()), + }, + ExpressionStoreOwner::VariantFields(it) => { + it.source(db).map(|src| src.syntax().cloned()) + } }; return match def { Some(def) => SearchScope::file_range( @@ -1370,7 +1404,7 @@ fn is_name_ref_in_import(name_ref: &ast::NameRef) -> bool { } fn is_name_ref_in_test(sema: &Semantics<'_, RootDatabase>, name_ref: &ast::NameRef) -> bool { - name_ref.syntax().ancestors().any(|node| match ast::Fn::cast(node) { + sema.ancestors_with_macros(name_ref.syntax().clone()).any(|node| match ast::Fn::cast(node) { Some(it) => sema.to_def(&it).is_some_and(|func| func.is_test(sema.db)), None => false, }) diff --git a/crates/ide-db/src/syntax_helpers/node_ext.rs b/crates/ide-db/src/syntax_helpers/node_ext.rs index 94ecf6a02d..e30b21c139 100644 --- a/crates/ide-db/src/syntax_helpers/node_ext.rs +++ b/crates/ide-db/src/syntax_helpers/node_ext.rs @@ -49,7 +49,7 @@ pub fn is_closure_or_blk_with_modif(expr: &ast::Expr) -> bool { block_expr.modifier(), Some( ast::BlockModifier::Async(_) - | ast::BlockModifier::Try(_) + | ast::BlockModifier::Try { .. } | ast::BlockModifier::Const(_) ) ) @@ -148,7 +148,7 @@ pub fn walk_patterns_in_expr(start: &ast::Expr, cb: &mut dyn FnMut(ast::Pat)) { block_expr.modifier(), Some( ast::BlockModifier::Async(_) - | ast::BlockModifier::Try(_) + | ast::BlockModifier::Try { .. } | ast::BlockModifier::Const(_) ) ) @@ -291,7 +291,7 @@ pub fn for_each_tail_expr(expr: &ast::Expr, cb: &mut dyn FnMut(&ast::Expr)) { match b.modifier() { Some( ast::BlockModifier::Async(_) - | ast::BlockModifier::Try(_) + | ast::BlockModifier::Try { .. } | ast::BlockModifier::Const(_), ) => return cb(expr), diff --git a/crates/ide-db/src/syntax_helpers/suggest_name.rs b/crates/ide-db/src/syntax_helpers/suggest_name.rs index b8b9a7a768..3a785fbe80 100644 --- a/crates/ide-db/src/syntax_helpers/suggest_name.rs +++ b/crates/ide-db/src/syntax_helpers/suggest_name.rs @@ -89,6 +89,12 @@ const USELESS_METHODS: &[&str] = &[ /// /// assert_eq!(generator.suggest_name("b2"), "b2"); /// assert_eq!(generator.suggest_name("b"), "b3"); +/// +/// // Multi-byte UTF-8 identifiers (e.g. CJK) are handled correctly +/// assert_eq!(generator.suggest_name("日本語"), "日本語"); +/// assert_eq!(generator.suggest_name("日本語"), "日本語1"); +/// assert_eq!(generator.suggest_name("données3"), "données3"); +/// assert_eq!(generator.suggest_name("données"), "données4"); /// ``` #[derive(Debug, Default)] pub struct NameGenerator { @@ -206,17 +212,18 @@ impl NameGenerator { expr: &ast::Expr, sema: &Semantics<'_, RootDatabase>, ) -> Option<SmolStr> { + let edition = sema.scope(expr.syntax())?.krate().edition(sema.db); // `from_param` does not benefit from stripping it need the largest // context possible so we check firstmost - if let Some(name) = from_param(expr, sema) { + if let Some(name) = from_param(expr, sema, edition) { return Some(self.suggest_name(&name)); } let mut next_expr = Some(expr.clone()); while let Some(expr) = next_expr { - let name = from_call(&expr) - .or_else(|| from_type(&expr, sema)) - .or_else(|| from_field_name(&expr)); + let name = from_call(&expr, edition) + .or_else(|| from_type(&expr, sema, edition)) + .or_else(|| from_field_name(&expr, edition)); if let Some(name) = name { return Some(self.suggest_name(&name)); } @@ -261,16 +268,20 @@ impl NameGenerator { /// Remove the numeric suffix from the name /// /// # Examples - /// `a1b2c3` -> `a1b2c` + /// `a1b2c3` -> (`a1b2c`, Some(3)) fn split_numeric_suffix(name: &str) -> (&str, Option<usize>) { let pos = name.rfind(|c: char| !c.is_numeric()).expect("Name cannot be empty or all-numeric"); - let (prefix, suffix) = name.split_at(pos + 1); + // `rfind` returns the byte offset of the matched character, which may be + // multi-byte (e.g. CJK identifiers). Use `ceil_char_boundary` to advance + // past the full character to the next valid split point. + let split = name.ceil_char_boundary(pos + 1); + let (prefix, suffix) = name.split_at(split); (prefix, suffix.parse().ok()) } } -fn normalize(name: &str) -> Option<SmolStr> { +fn normalize(name: &str, edition: syntax::Edition) -> Option<SmolStr> { let name = to_lower_snake_case(name).to_smolstr(); if USELESS_NAMES.contains(&name.as_str()) { @@ -281,16 +292,16 @@ fn normalize(name: &str) -> Option<SmolStr> { return None; } - if !is_valid_name(&name) { + if !is_valid_name(&name, edition) { return None; } Some(name) } -fn is_valid_name(name: &str) -> bool { +fn is_valid_name(name: &str, edition: syntax::Edition) -> bool { matches!( - super::LexedStr::single_token(syntax::Edition::CURRENT_FIXME, name), + super::LexedStr::single_token(edition, name), Some((syntax::SyntaxKind::IDENT, _error)) ) } @@ -304,11 +315,11 @@ fn is_useless_method(method: &ast::MethodCallExpr) -> bool { } } -fn from_call(expr: &ast::Expr) -> Option<SmolStr> { - from_func_call(expr).or_else(|| from_method_call(expr)) +fn from_call(expr: &ast::Expr, edition: syntax::Edition) -> Option<SmolStr> { + from_func_call(expr, edition).or_else(|| from_method_call(expr, edition)) } -fn from_func_call(expr: &ast::Expr) -> Option<SmolStr> { +fn from_func_call(expr: &ast::Expr, edition: syntax::Edition) -> Option<SmolStr> { let call = match expr { ast::Expr::CallExpr(call) => call, _ => return None, @@ -318,10 +329,10 @@ fn from_func_call(expr: &ast::Expr) -> Option<SmolStr> { _ => return None, }; let ident = func.path()?.segment()?.name_ref()?.ident_token()?; - normalize(ident.text()) + normalize(ident.text(), edition) } -fn from_method_call(expr: &ast::Expr) -> Option<SmolStr> { +fn from_method_call(expr: &ast::Expr, edition: syntax::Edition) -> Option<SmolStr> { let method = match expr { ast::Expr::MethodCallExpr(call) => call, _ => return None, @@ -340,10 +351,14 @@ fn from_method_call(expr: &ast::Expr) -> Option<SmolStr> { } } - normalize(name) + normalize(name, edition) } -fn from_param(expr: &ast::Expr, sema: &Semantics<'_, RootDatabase>) -> Option<SmolStr> { +fn from_param( + expr: &ast::Expr, + sema: &Semantics<'_, RootDatabase>, + edition: Edition, +) -> Option<SmolStr> { let arg_list = expr.syntax().parent().and_then(ast::ArgList::cast)?; let args_parent = arg_list.syntax().parent()?; let func = match_ast! { @@ -362,7 +377,7 @@ fn from_param(expr: &ast::Expr, sema: &Semantics<'_, RootDatabase>) -> Option<Sm let param = func.params().into_iter().nth(idx)?; let pat = sema.source(param)?.value.right()?.pat()?; let name = var_name_from_pat(&pat)?; - normalize(&name.to_smolstr()) + normalize(&name.to_smolstr(), edition) } fn var_name_from_pat(pat: &ast::Pat) -> Option<ast::Name> { @@ -374,10 +389,13 @@ fn var_name_from_pat(pat: &ast::Pat) -> Option<ast::Name> { } } -fn from_type(expr: &ast::Expr, sema: &Semantics<'_, RootDatabase>) -> Option<SmolStr> { +fn from_type( + expr: &ast::Expr, + sema: &Semantics<'_, RootDatabase>, + edition: Edition, +) -> Option<SmolStr> { let ty = sema.type_of_expr(expr)?.adjusted(); let ty = ty.remove_ref().unwrap_or(ty); - let edition = sema.scope(expr.syntax())?.krate().edition(sema.db); name_of_type(&ty, sema.db, edition) } @@ -417,7 +435,7 @@ fn name_of_type<'db>( } else { return None; }; - normalize(&name) + normalize(&name, edition) } fn sequence_name<'db>( @@ -450,13 +468,13 @@ fn trait_name(trait_: &hir::Trait, db: &RootDatabase, edition: Edition) -> Optio Some(name) } -fn from_field_name(expr: &ast::Expr) -> Option<SmolStr> { +fn from_field_name(expr: &ast::Expr, edition: syntax::Edition) -> Option<SmolStr> { let field = match expr { ast::Expr::FieldExpr(field) => field, _ => return None, }; let ident = field.name_ref()?.ident_token()?; - normalize(ident.text()) + normalize(ident.text(), edition) } #[cfg(test)] diff --git a/crates/ide-db/src/test_data/test_doc_alias.txt b/crates/ide-db/src/test_data/test_doc_alias.txt index 0c28c312f8..fc98ebb069 100644 --- a/crates/ide-db/src/test_data/test_doc_alias.txt +++ b/crates/ide-db/src/test_data/test_doc_alias.txt @@ -2,7 +2,7 @@ ( Module { id: ModuleIdLt { - [salsa id]: Id(3800), + [salsa id]: Id(3400), }, }, [ diff --git a/crates/ide-db/src/test_data/test_symbol_index_collection.txt b/crates/ide-db/src/test_data/test_symbol_index_collection.txt index 4b588572d3..02a023038a 100644 --- a/crates/ide-db/src/test_data/test_symbol_index_collection.txt +++ b/crates/ide-db/src/test_data/test_symbol_index_collection.txt @@ -2,14 +2,14 @@ ( Module { id: ModuleIdLt { - [salsa id]: Id(3800), + [salsa id]: Id(3400), }, }, [ FileSymbol { name: "A", - def: Variant( - Variant { + def: EnumVariant( + EnumVariant { id: EnumVariantId( 7c00, ), @@ -80,8 +80,8 @@ }, FileSymbol { name: "B", - def: Variant( - Variant { + def: EnumVariant( + EnumVariant { id: EnumVariantId( 7c01, ), @@ -671,7 +671,7 @@ def: Module( Module { id: ModuleIdLt { - [salsa id]: Id(3801), + [salsa id]: Id(3401), }, }, ), @@ -706,7 +706,7 @@ def: Module( Module { id: ModuleIdLt { - [salsa id]: Id(3802), + [salsa id]: Id(3402), }, }, ), @@ -998,7 +998,7 @@ ( Module { id: ModuleIdLt { - [salsa id]: Id(3801), + [salsa id]: Id(3401), }, }, [ @@ -1044,7 +1044,7 @@ ( Module { id: ModuleIdLt { - [salsa id]: Id(3802), + [salsa id]: Id(3402), }, }, [ diff --git a/crates/ide-db/src/test_data/test_symbols_exclude_imports.txt b/crates/ide-db/src/test_data/test_symbols_exclude_imports.txt index 87f0c7d9a8..aff1d56c56 100644 --- a/crates/ide-db/src/test_data/test_symbols_exclude_imports.txt +++ b/crates/ide-db/src/test_data/test_symbols_exclude_imports.txt @@ -5,7 +5,7 @@ Struct( Struct { id: StructId( - 3c00, + 4000, ), }, ), diff --git a/crates/ide-db/src/test_data/test_symbols_with_imports.txt b/crates/ide-db/src/test_data/test_symbols_with_imports.txt index e96aa889ba..bf5d81cfb1 100644 --- a/crates/ide-db/src/test_data/test_symbols_with_imports.txt +++ b/crates/ide-db/src/test_data/test_symbols_with_imports.txt @@ -5,7 +5,7 @@ Struct( Struct { id: StructId( - 3c00, + 4000, ), }, ), @@ -42,7 +42,7 @@ Struct( Struct { id: StructId( - 3c00, + 4000, ), }, ), diff --git a/crates/ide-db/src/traits.rs b/crates/ide-db/src/traits.rs index 7200e7fbe5..60bdc2d82c 100644 --- a/crates/ide-db/src/traits.rs +++ b/crates/ide-db/src/traits.rs @@ -130,7 +130,8 @@ mod tests { database.apply_change(change_fixture.change); let (file_id, range_or_offset) = change_fixture.file_position.expect("expected a marker ($0)"); - let file_id = EditionedFileId::from_span_guess_origin(&database, file_id); + + let file_id = EditionedFileId::from_span_file_id(&database, file_id); let offset = range_or_offset.expect_offset(); (database, FilePosition { file_id, offset }) } diff --git a/crates/ide-db/src/ty_filter.rs b/crates/ide-db/src/ty_filter.rs index 095256d829..b5c43b3b36 100644 --- a/crates/ide-db/src/ty_filter.rs +++ b/crates/ide-db/src/ty_filter.rs @@ -19,8 +19,16 @@ pub enum TryEnum { impl TryEnum { const ALL: [TryEnum; 2] = [TryEnum::Option, TryEnum::Result]; - /// Returns `Some(..)` if the provided type is an enum that implements `std::ops::Try`. + /// Returns `Some(..)` if the provided `ty.strip_references()` is an enum that implements `std::ops::Try`. pub fn from_ty(sema: &Semantics<'_, RootDatabase>, ty: &hir::Type<'_>) -> Option<TryEnum> { + Self::from_ty_without_strip(sema, &ty.strip_references()) + } + + /// Returns `Some(..)` if the provided type is an enum that implements `std::ops::Try`. + pub fn from_ty_without_strip( + sema: &Semantics<'_, RootDatabase>, + ty: &hir::Type<'_>, + ) -> Option<TryEnum> { let enum_ = match ty.as_adt() { Some(hir::Adt::Enum(it)) => it, _ => return None, diff --git a/crates/ide-diagnostics/src/handlers/incorrect_case.rs b/crates/ide-diagnostics/src/handlers/incorrect_case.rs index c47449f259..5410f8b58a 100644 --- a/crates/ide-diagnostics/src/handlers/incorrect_case.rs +++ b/crates/ide-diagnostics/src/handlers/incorrect_case.rs @@ -263,6 +263,48 @@ struct SomeStruct { SomeField: u8 } } #[test] + fn incorrect_union_names() { + check_diagnostics( + r#" +union non_camel_case_name { field: u8 } + // ^^^^^^^^^^^^^^^^^^^ 💡 warn: Union `non_camel_case_name` should have UpperCamelCase name, e.g. `NonCamelCaseName` + +union SCREAMING_CASE { field: u8 } + // ^^^^^^^^^^^^^^ 💡 warn: Union `SCREAMING_CASE` should have UpperCamelCase name, e.g. `ScreamingCase` +"#, + ); + } + + #[test] + fn no_diagnostic_for_camel_cased_acronyms_in_union_name() { + check_diagnostics( + r#" +union AABB { field: u8 } +"#, + ); + } + + #[test] + fn no_diagnostic_for_repr_c_union() { + check_diagnostics( + r#" +#[repr(C)] +union my_union { field: u8 } +"#, + ); + } + + #[test] + fn incorrect_union_field() { + check_diagnostics( + r#" +union SomeUnion { SomeField: u8 } + // ^^^^^^^^^ 💡 warn: Field `SomeField` should have snake_case name, e.g. `some_field` +"#, + ); + } + + #[test] fn incorrect_enum_names() { check_diagnostics( r#" diff --git a/crates/ide-diagnostics/src/handlers/incorrect_generics_len.rs b/crates/ide-diagnostics/src/handlers/incorrect_generics_len.rs index 894e044642..25220704e0 100644 --- a/crates/ide-diagnostics/src/handlers/incorrect_generics_len.rs +++ b/crates/ide-diagnostics/src/handlers/incorrect_generics_len.rs @@ -224,4 +224,21 @@ fn main() { "#, ); } + + #[test] + fn type_as_trait_does_not_count() { + check_diagnostics( + r#" +pub trait Lock<T> { + fn new(b: T) -> Self; +} +pub trait LockChoice { + type Lock<T>: Lock<T>; +} +fn f<L: LockChoice>() { + <L as LockChoice>::Lock::new(()); +} + "#, + ); + } } diff --git a/crates/ide-diagnostics/src/handlers/invalid_cast.rs b/crates/ide-diagnostics/src/handlers/invalid_cast.rs index 7479f8147d..405d8df685 100644 --- a/crates/ide-diagnostics/src/handlers/invalid_cast.rs +++ b/crates/ide-diagnostics/src/handlers/invalid_cast.rs @@ -517,11 +517,13 @@ trait Trait<'a> {} fn add_auto<'a>(x: *mut dyn Trait<'a>) -> *mut (dyn Trait<'a> + Send) { x as _ + //^^^^^^ error: cannot add auto trait to dyn bound via pointer cast } // (to test diagnostic list formatting) fn add_multiple_auto<'a>(x: *mut dyn Trait<'a>) -> *mut (dyn Trait<'a> + Send + Sync + Unpin) { x as _ + //^^^^^^ error: cannot add auto trait to dyn bound via pointer cast } "#, ); diff --git a/crates/ide-diagnostics/src/handlers/missing_fields.rs b/crates/ide-diagnostics/src/handlers/missing_fields.rs index d5f25dfaf2..050d5477f6 100644 --- a/crates/ide-diagnostics/src/handlers/missing_fields.rs +++ b/crates/ide-diagnostics/src/handlers/missing_fields.rs @@ -588,14 +588,14 @@ fn test_fn() { fn test_fill_struct_fields_default() { check_fix( r#" -//- minicore: default, option +//- minicore: default, option, slice struct TestWithDefault(usize); impl Default for TestWithDefault { pub fn default() -> Self { Self(0) } } -struct TestStruct { one: i32, two: TestWithDefault } +struct TestStruct { one: i32, two: TestWithDefault, r: &'static [i32] } fn test_fn() { let s = TestStruct{ $0 }; @@ -608,10 +608,10 @@ impl Default for TestWithDefault { Self(0) } } -struct TestStruct { one: i32, two: TestWithDefault } +struct TestStruct { one: i32, two: TestWithDefault, r: &'static [i32] } fn test_fn() { - let s = TestStruct{ one: 0, two: TestWithDefault::default() }; + let s = TestStruct{ one: 0, two: TestWithDefault::default(), r: <&'static [i32]>::default() }; } ", ); diff --git a/crates/ide-diagnostics/src/handlers/missing_lifetime.rs b/crates/ide-diagnostics/src/handlers/missing_lifetime.rs index 5cb710b66b..b10cdaa14e 100644 --- a/crates/ide-diagnostics/src/handlers/missing_lifetime.rs +++ b/crates/ide-diagnostics/src/handlers/missing_lifetime.rs @@ -115,4 +115,20 @@ struct A<'a, T> { "#, ); } + + // FIXME: Ideally, should emit generic default forbidden as well + #[test] + fn regression_16280() { + check_diagnostics( + r#" +trait Traitor<'a, const M: Traitor = Traitor> { + fn crash<const Traitor: Traitor = Traitor, const M: Traitor = Traitor>(&self) -> Traitor { + // ^^^^^^^ error: missing lifetime specifier + // ^^^^^^^ error: missing lifetime specifier + Traitor + } +} +"#, + ); + } } diff --git a/crates/ide-diagnostics/src/handlers/no_such_field.rs b/crates/ide-diagnostics/src/handlers/no_such_field.rs index bcfe3a8aa5..944622bb1d 100644 --- a/crates/ide-diagnostics/src/handlers/no_such_field.rs +++ b/crates/ide-diagnostics/src/handlers/no_such_field.rs @@ -64,20 +64,20 @@ fn missing_record_expr_field_fixes( let module; let def_file_id; let record_fields = match def_id { - hir::VariantDef::Struct(s) => { + hir::Variant::Struct(s) => { module = s.module(sema.db); let source = s.source(sema.db)?; def_file_id = source.file_id; let fields = source.value.field_list()?; record_field_list(fields)? } - hir::VariantDef::Union(u) => { + hir::Variant::Union(u) => { module = u.module(sema.db); let source = u.source(sema.db)?; def_file_id = source.file_id; source.value.record_field_list()? } - hir::VariantDef::Variant(e) => { + hir::Variant::EnumVariant(e) => { module = e.module(sema.db); let source = e.source(sema.db)?; def_file_id = source.file_id; @@ -97,25 +97,37 @@ fn missing_record_expr_field_fixes( make::ty(&new_field_type.display_source_code(sema.db, module.into(), true).ok()?), ); - let last_field = record_fields.fields().last()?; - let last_field_syntax = last_field.syntax(); - let indent = IndentLevel::from_node(last_field_syntax); + let (indent, offset, postfix, needs_comma) = + if let Some(last_field) = record_fields.fields().last() { + let indent = IndentLevel::from_node(last_field.syntax()); + let offset = last_field.syntax().text_range().end(); + let needs_comma = !last_field.to_string().ends_with(','); + (indent, offset, String::new(), needs_comma) + } else { + let indent = IndentLevel::from_node(record_fields.syntax()); + let offset = record_fields.l_curly_token()?.text_range().end(); + let postfix = if record_fields.syntax().text().contains_char('\n') { + ",".into() + } else { + format!(",\n{indent}") + }; + (indent + 1, offset, postfix, false) + }; let mut new_field = new_field.to_string(); // FIXME: check submodule instead of FileId - if usage_file_id != def_file_id && !matches!(def_id, hir::VariantDef::Variant(_)) { + if usage_file_id != def_file_id && !matches!(def_id, hir::Variant::EnumVariant(_)) { new_field = format!("pub(crate) {new_field}"); } - new_field = format!("\n{indent}{new_field}"); + new_field = format!("\n{indent}{new_field}{postfix}"); - let needs_comma = !last_field_syntax.to_string().ends_with(','); if needs_comma { new_field = format!(",{new_field}"); } let source_change = SourceChange::from_text_edit( def_file_id.file_id(sema.db), - TextEdit::insert(last_field_syntax.text_range().end(), new_field), + TextEdit::insert(offset, new_field), ); return Some(vec![fix( @@ -335,6 +347,44 @@ struct Foo { } #[test] + fn test_add_field_from_usage_with_empty_struct() { + check_fix( + r" +fn main() { + Foo { bar$0: false }; +} +struct Foo {} +", + r" +fn main() { + Foo { bar: false }; +} +struct Foo { + bar: bool, +} +", + ); + + check_fix( + r" +fn main() { + Foo { bar$0: false }; +} +struct Foo { +} +", + r" +fn main() { + Foo { bar: false }; +} +struct Foo { + bar: bool, +} +", + ); + } + + #[test] fn test_add_field_in_other_file_from_usage() { check_fix( r#" diff --git a/crates/ide-diagnostics/src/handlers/non_exhaustive_let.rs b/crates/ide-diagnostics/src/handlers/non_exhaustive_let.rs index c86ecd2f03..bc10e82854 100644 --- a/crates/ide-diagnostics/src/handlers/non_exhaustive_let.rs +++ b/crates/ide-diagnostics/src/handlers/non_exhaustive_let.rs @@ -1,4 +1,11 @@ -use crate::{Diagnostic, DiagnosticCode, DiagnosticsContext}; +use either::Either; +use hir::Semantics; +use ide_db::text_edit::TextEdit; +use ide_db::ty_filter::TryEnum; +use ide_db::{RootDatabase, source_change::SourceChange}; +use syntax::{AstNode, ast}; + +use crate::{Assist, Diagnostic, DiagnosticCode, DiagnosticsContext, fix}; // Diagnostic: non-exhaustive-let // @@ -15,11 +22,74 @@ pub(crate) fn non_exhaustive_let( d.pat.map(Into::into), ) .stable() + .with_fixes(fixes(&ctx.sema, d)) +} + +fn fixes(sema: &Semantics<'_, RootDatabase>, d: &hir::NonExhaustiveLet) -> Option<Vec<Assist>> { + let root = sema.parse_or_expand(d.pat.file_id); + let pat = d.pat.value.to_node(&root); + let let_stmt = ast::LetStmt::cast(pat.syntax().parent()?)?; + let early_node = + sema.ancestors_with_macros(let_stmt.syntax().clone()).find_map(AstNode::cast)?; + let early_text = early_text(sema, &early_node); + + if let_stmt.let_else().is_some() { + return None; + } + let hir::FileRangeWrapper { file_id, range } = sema.original_range_opt(let_stmt.syntax())?; + let insert_offset = if let Some(semicolon) = let_stmt.semicolon_token() + && let Some(token) = sema.parse(file_id).syntax().token_at_offset(range.end()).left_biased() + && token.kind() == semicolon.kind() + { + token.text_range().start() + } else { + range.end() + }; + let semicolon = if let_stmt.semicolon_token().is_none() { ";" } else { "" }; + let else_block = format!(" else {{ {early_text} }}{semicolon}"); + let file_id = file_id.file_id(sema.db); + + let source_change = + SourceChange::from_text_edit(file_id, TextEdit::insert(insert_offset, else_block)); + let target = sema.original_range(let_stmt.syntax()).range; + Some(vec![fix("add_let_else_block", "Add let-else block", source_change, target)]) +} + +fn early_text( + sema: &Semantics<'_, RootDatabase>, + early_node: &Either<ast::AnyHasLoopBody, Either<ast::Fn, ast::ClosureExpr>>, +) -> &'static str { + match early_node { + Either::Left(_any_loop) => "continue", + Either::Right(Either::Left(fn_)) => sema + .to_def(fn_) + .map(|fn_def| fn_def.ret_type(sema.db)) + .map(|ty| return_text(&ty, sema)) + .unwrap_or("return"), + Either::Right(Either::Right(closure)) => closure + .body() + .and_then(|expr| sema.type_of_expr(&expr)) + .map(|ty| return_text(&ty.adjusted(), sema)) + .unwrap_or("return"), + } +} + +fn return_text(ty: &hir::Type<'_>, sema: &Semantics<'_, RootDatabase>) -> &'static str { + if ty.is_unit() { + "return" + } else if let Some(try_enum) = TryEnum::from_ty(sema, ty) { + match try_enum { + TryEnum::Option => "return None", + TryEnum::Result => "return Err($0)", + } + } else { + "return $0" + } } #[cfg(test)] mod tests { - use crate::tests::check_diagnostics; + use crate::tests::{check_diagnostics, check_fix}; #[test] fn option_nonexhaustive() { @@ -28,7 +98,7 @@ mod tests { //- minicore: option fn main() { let None = Some(5); - //^^^^ error: non-exhaustive pattern: `Some(_)` not covered + //^^^^ 💡 error: non-exhaustive pattern: `Some(_)` not covered } "#, ); @@ -54,7 +124,7 @@ fn main() { fn main() { '_a: { let None = Some(5); - //^^^^ error: non-exhaustive pattern: `Some(_)` not covered + //^^^^ 💡 error: non-exhaustive pattern: `Some(_)` not covered } } "#, @@ -66,7 +136,7 @@ fn main() { fn main() { let _ = async { let None = Some(5); - //^^^^ error: non-exhaustive pattern: `Some(_)` not covered + //^^^^ 💡 error: non-exhaustive pattern: `Some(_)` not covered }; } "#, @@ -78,7 +148,7 @@ fn main() { fn main() { unsafe { let None = Some(5); - //^^^^ error: non-exhaustive pattern: `Some(_)` not covered + //^^^^ 💡 error: non-exhaustive pattern: `Some(_)` not covered } } "#, @@ -101,7 +171,7 @@ fn test(x: Result<i32, !>) { //- minicore: result fn test(x: Result<i32, &'static !>) { let Ok(_y) = x; - //^^^^^^ error: non-exhaustive pattern: `Err(_)` not covered + //^^^^^^ 💡 error: non-exhaustive pattern: `Err(_)` not covered } "#, ); @@ -133,6 +203,136 @@ fn foo(v: Enum<()>) { } #[test] + fn fix_return_in_loop() { + check_fix( + r#" +//- minicore: option +fn foo() { + while cond { + let None$0 = Some(5); + } +} +"#, + r#" +fn foo() { + while cond { + let None = Some(5) else { continue }; + } +} +"#, + ); + } + + #[test] + fn fix_return_in_fn() { + check_fix( + r#" +//- minicore: option +fn foo() { + let None$0 = Some(5); +} +"#, + r#" +fn foo() { + let None = Some(5) else { return }; +} +"#, + ); + } + + #[test] + fn fix_return_in_macro_expanded() { + check_fix( + r#" +//- minicore: option +macro_rules! identity { ($($t:tt)*) => { $($t)* }; } +fn foo() { + identity! { + let None$0 = Some(5); + } +} +"#, + r#" +macro_rules! identity { ($($t:tt)*) => { $($t)* }; } +fn foo() { + identity! { + let None = Some(5) else { return }; + } +} +"#, + ); + } + + #[test] + fn fix_return_in_incomplete_let() { + check_fix( + r#" +//- minicore: option +fn foo() { + let None$0 = Some(5) +} +"#, + r#" +fn foo() { + let None = Some(5) else { return }; +} +"#, + ); + } + + #[test] + fn fix_return_in_closure() { + check_fix( + r#" +//- minicore: option +fn foo() -> Option<()> { + let _f = || { + let None$0 = Some(5); + }; +} +"#, + r#" +fn foo() -> Option<()> { + let _f = || { + let None = Some(5) else { return }; + }; +} +"#, + ); + } + + #[test] + fn fix_return_try_in_fn() { + check_fix( + r#" +//- minicore: option +fn foo() -> Option<()> { + let None$0 = Some(5); +} +"#, + r#" +fn foo() -> Option<()> { + let None = Some(5) else { return None }; +} +"#, + ); + + check_fix( + r#" +//- minicore: option, result +fn foo() -> Result<(), i32> { + let None$0 = Some(5); +} +"#, + r#" +fn foo() -> Result<(), i32> { + let None = Some(5) else { return Err($0) }; +} +"#, + ); + } + + #[test] fn regression_20259() { check_diagnostics( r#" diff --git a/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs b/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs index 7dc5b5b45e..04f48ae3db 100644 --- a/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs +++ b/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs @@ -48,7 +48,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &RemoveUnnecessaryElse) -> Option<Vec< let mut indent = IndentLevel::from_node(if_expr.syntax()); let has_parent_if_expr = if_expr.syntax().parent().and_then(ast::IfExpr::cast).is_some(); if has_parent_if_expr { - indent = indent + 1; + indent += 1; } let else_replacement = match if_expr.else_branch()? { ast::ElseBranch::Block(block) => block diff --git a/crates/ide-diagnostics/src/handlers/trait_impl_redundant_assoc_item.rs b/crates/ide-diagnostics/src/handlers/trait_impl_redundant_assoc_item.rs index cb3aac3717..f4054610f2 100644 --- a/crates/ide-diagnostics/src/handlers/trait_impl_redundant_assoc_item.rs +++ b/crates/ide-diagnostics/src/handlers/trait_impl_redundant_assoc_item.rs @@ -6,6 +6,7 @@ use ide_db::{ source_change::SourceChangeBuilder, }; use syntax::ToSmolStr; +use syntax::ast::edit::AstNodeEdit; use crate::{Diagnostic, DiagnosticCode, DiagnosticsContext}; @@ -23,6 +24,7 @@ pub(crate) fn trait_impl_redundant_assoc_item( let default_range = d.impl_.syntax_node_ptr().text_range(); let trait_name = d.trait_.name(db).display_no_db(ctx.edition).to_smolstr(); + let indent_level = d.trait_.source(db).map_or(0, |it| it.value.indent_level().0) + 1; let (redundant_item_name, diagnostic_range, redundant_item_def) = match assoc_item { hir::AssocItem::Function(id) => { @@ -30,7 +32,7 @@ pub(crate) fn trait_impl_redundant_assoc_item( ( format!("`fn {redundant_assoc_item_name}`"), function.source(db).map(|it| it.syntax().text_range()).unwrap_or(default_range), - format!("\n {};", function.display(db, ctx.display_target)), + format!("\n{};", function.display(db, ctx.display_target)), ) } hir::AssocItem::Const(id) => { @@ -38,7 +40,7 @@ pub(crate) fn trait_impl_redundant_assoc_item( ( format!("`const {redundant_assoc_item_name}`"), constant.source(db).map(|it| it.syntax().text_range()).unwrap_or(default_range), - format!("\n {};", constant.display(db, ctx.display_target)), + format!("\n{};", constant.display(db, ctx.display_target)), ) } hir::AssocItem::TypeAlias(id) => { @@ -46,10 +48,8 @@ pub(crate) fn trait_impl_redundant_assoc_item( ( format!("`type {redundant_assoc_item_name}`"), type_alias.source(db).map(|it| it.syntax().text_range()).unwrap_or(default_range), - format!( - "\n type {};", - type_alias.name(ctx.sema.db).display_no_db(ctx.edition).to_smolstr() - ), + // FIXME cannot generate generic parameter and bounds + format!("\ntype {};", type_alias.name(ctx.sema.db).display_no_db(ctx.edition)), ) } }; @@ -65,7 +65,7 @@ pub(crate) fn trait_impl_redundant_assoc_item( .with_fixes(quickfix_for_redundant_assoc_item( ctx, d, - redundant_item_def, + stdx::indent_string(&redundant_item_def, indent_level), diagnostic_range, )) } @@ -192,6 +192,89 @@ impl Marker for Foo { } #[test] + fn quickfix_indentations() { + check_fix( + r#" +mod indent { + trait Marker { + fn boo(); + } + struct Foo; + impl Marker for Foo { + fn$0 bar<T: Copy>(_a: i32, _b: T) -> String {} + fn boo() {} + } +} + "#, + r#" +mod indent { + trait Marker { + fn bar<T>(_a: i32, _b: T) -> String + where + T: Copy,; + fn boo(); + } + struct Foo; + impl Marker for Foo { + fn bar<T: Copy>(_a: i32, _b: T) -> String {} + fn boo() {} + } +} + "#, + ); + + check_fix( + r#" +mod indent { + trait Marker { + fn foo () {} + } + struct Foo; + impl Marker for Foo { + const FLAG: bool$0 = false; + } +} + "#, + r#" +mod indent { + trait Marker { + const FLAG: bool; + fn foo () {} + } + struct Foo; + impl Marker for Foo { + const FLAG: bool = false; + } +} + "#, + ); + + check_fix( + r#" +mod indent { + trait Marker { + } + struct Foo; + impl Marker for Foo { + type T = i32;$0 + } +} + "#, + r#" +mod indent { + trait Marker { + type T; + } + struct Foo; + impl Marker for Foo { + type T = i32; + } +} + "#, + ); + } + + #[test] fn quickfix_dont_work() { check_no_fix( r#" diff --git a/crates/ide-diagnostics/src/handlers/typed_hole.rs b/crates/ide-diagnostics/src/handlers/typed_hole.rs index 577c582a20..fd1674e2a4 100644 --- a/crates/ide-diagnostics/src/handlers/typed_hole.rs +++ b/crates/ide-diagnostics/src/handlers/typed_hole.rs @@ -442,4 +442,30 @@ fn rdtscp() -> u64 { }"#, ); } + + #[test] + fn asm_sym_with_macro_expr_fragment() { + // Regression test for issue #21582 + // When `$e:expr` captures a path and is used in `sym $e`, the path gets + // wrapped in parentheses during macro expansion due to invisible delimiters. + // This should not cause false positive typed-hole errors. + check_diagnostics( + r#" +//- minicore: asm +macro_rules! m { + ($e:expr) => { + core::arch::asm!("/*{f}*/", f = sym $e, out("ax") _) + }; +} + +fn generic<T>() {} + +fn main() { + unsafe { + m!(generic::<i32>); + } +} +"#, + ); + } } diff --git a/crates/ide-ssr/src/from_comment.rs b/crates/ide-ssr/src/from_comment.rs index de26879c29..181cc74a51 100644 --- a/crates/ide-ssr/src/from_comment.rs +++ b/crates/ide-ssr/src/from_comment.rs @@ -17,7 +17,7 @@ pub fn ssr_from_comment( frange: FileRange, ) -> Option<(MatchFinder<'_>, TextRange)> { let comment = { - let file_id = EditionedFileId::current_edition_guess_origin(db, frange.file_id); + let file_id = EditionedFileId::current_edition(db, frange.file_id); let file = db.parse(file_id); file.tree().syntax().token_at_offset(frange.range.start()).find_map(ast::Comment::cast) diff --git a/crates/ide/src/doc_links.rs b/crates/ide/src/doc_links.rs index d854c1c450..33bed9501a 100644 --- a/crates/ide/src/doc_links.rs +++ b/crates/ide/src/doc_links.rs @@ -219,7 +219,7 @@ pub(crate) fn resolve_doc_path_for_def( Definition::Crate(it) => it.resolve_doc_path(db, link, ns, is_inner_doc), Definition::Function(it) => it.resolve_doc_path(db, link, ns, is_inner_doc), Definition::Adt(it) => it.resolve_doc_path(db, link, ns, is_inner_doc), - Definition::Variant(it) => it.resolve_doc_path(db, link, ns, is_inner_doc), + Definition::EnumVariant(it) => it.resolve_doc_path(db, link, ns, is_inner_doc), Definition::Const(it) => it.resolve_doc_path(db, link, ns, is_inner_doc), Definition::Static(it) => it.resolve_doc_path(db, link, ns, is_inner_doc), Definition::Trait(it) => it.resolve_doc_path(db, link, ns, is_inner_doc), @@ -678,7 +678,7 @@ fn filename_and_frag_for_def( Definition::Function(f) => { format!("fn.{}.html", f.name(db).as_str()) } - Definition::Variant(ev) => { + Definition::EnumVariant(ev) => { let def = Definition::Adt(ev.parent_enum(db).into()); let (_, file, _) = filename_and_frag_for_def(db, def)?; return Some((def, file, Some(format!("variant.{}", ev.name(db).as_str())))); @@ -703,9 +703,9 @@ fn filename_and_frag_for_def( }, Definition::Field(field) => { let def = match field.parent_def(db) { - hir::VariantDef::Struct(it) => Definition::Adt(it.into()), - hir::VariantDef::Union(it) => Definition::Adt(it.into()), - hir::VariantDef::Variant(it) => Definition::Variant(it), + hir::Variant::Struct(it) => Definition::Adt(it.into()), + hir::Variant::Union(it) => Definition::Adt(it.into()), + hir::Variant::EnumVariant(it) => Definition::EnumVariant(it), }; let (_, file, _) = filename_and_frag_for_def(db, def)?; return Some((def, file, Some(format!("structfield.{}", field.name(db).as_str())))); diff --git a/crates/ide/src/doc_links/tests.rs b/crates/ide/src/doc_links/tests.rs index a61a6c677f..509c55a31e 100644 --- a/crates/ide/src/doc_links/tests.rs +++ b/crates/ide/src/doc_links/tests.rs @@ -113,7 +113,7 @@ fn node_to_def<'db>( ast::Struct(it) => sema.to_def(&it).map(|def| (def.docs_with_rangemap(sema.db), Definition::Adt(hir::Adt::Struct(def)))), ast::Union(it) => sema.to_def(&it).map(|def| (def.docs_with_rangemap(sema.db), Definition::Adt(hir::Adt::Union(def)))), ast::Enum(it) => sema.to_def(&it).map(|def| (def.docs_with_rangemap(sema.db), Definition::Adt(hir::Adt::Enum(def)))), - ast::Variant(it) => sema.to_def(&it).map(|def| (def.docs_with_rangemap(sema.db), Definition::Variant(def))), + ast::Variant(it) => sema.to_def(&it).map(|def| (def.docs_with_rangemap(sema.db), Definition::EnumVariant(def))), ast::Trait(it) => sema.to_def(&it).map(|def| (def.docs_with_rangemap(sema.db), Definition::Trait(def))), ast::Static(it) => sema.to_def(&it).map(|def| (def.docs_with_rangemap(sema.db), Definition::Static(def))), ast::Const(it) => sema.to_def(&it).map(|def| (def.docs_with_rangemap(sema.db), Definition::Const(def))), diff --git a/crates/ide/src/goto_definition.rs b/crates/ide/src/goto_definition.rs index c0a7438081..3890bcad7f 100644 --- a/crates/ide/src/goto_definition.rs +++ b/crates/ide/src/goto_definition.rs @@ -95,6 +95,13 @@ pub(crate) fn goto_definition( continue; } + let parent = token.value.parent()?; + + if let Some(question_mark_conversion) = goto_question_mark_conversions(sema, &parent) { + navs.extend(def_to_nav(sema, question_mark_conversion.into())); + continue; + } + if let Some(token) = ast::String::cast(token.value.clone()) && let Some(original_token) = ast::String::cast(original_token.clone()) && let Some((analysis, fixture_analysis)) = @@ -113,8 +120,6 @@ pub(crate) fn goto_definition( }); } - let parent = token.value.parent()?; - let token_file_id = token.file_id; if let Some(token) = ast::String::cast(token.value.clone()) && let Some(x) = @@ -149,6 +154,45 @@ pub(crate) fn goto_definition( Some(RangeInfo::new(original_token.text_range(), navs)) } +/// When the `?` operator is used on `Result`, go to the `From` impl if it exists as this provides more value. +fn goto_question_mark_conversions( + sema: &Semantics<'_, RootDatabase>, + node: &SyntaxNode, +) -> Option<hir::Function> { + let node = ast::TryExpr::cast(node.clone())?; + let try_expr_ty = sema.type_of_expr(&node.expr()?)?.adjusted(); + + let fd = FamousDefs(sema, try_expr_ty.krate(sema.db)); + let result_enum = fd.core_result_Result()?.into(); + + let (try_expr_ty_adt, try_expr_ty_args) = try_expr_ty.as_adt_with_args()?; + if try_expr_ty_adt != result_enum { + // FIXME: Support `Poll<Result>`. + return None; + } + let original_err_ty = try_expr_ty_args.get(1)?.clone()?; + + let returned_ty = sema.try_expr_returned_type(&node)?; + let (returned_adt, returned_ty_args) = returned_ty.as_adt_with_args()?; + if returned_adt != result_enum { + return None; + } + let returned_err_ty = returned_ty_args.get(1)?.clone()?; + + if returned_err_ty.could_unify_with_deeply(sema.db, &original_err_ty) { + return None; + } + + let from_trait = fd.core_convert_From()?; + let from_fn = from_trait.function(sema.db, sym::from)?; + sema.resolve_trait_impl_method( + returned_err_ty.clone(), + from_trait, + from_fn, + [returned_err_ty, original_err_ty], + ) +} + // If the token is into(), try_into(), search the definition of From, TryFrom. fn find_definition_for_known_blanket_dual_impls( sema: &Semantics<'_, RootDatabase>, @@ -332,7 +376,7 @@ pub(crate) fn find_fn_or_blocks( ast::BlockExpr(blk) => { match blk.modifier() { Some(ast::BlockModifier::Async(_)) => blk.syntax().clone(), - Some(ast::BlockModifier::Try(_)) if token_kind != T![return] => blk.syntax().clone(), + Some(ast::BlockModifier::Try { .. }) if token_kind != T![return] => blk.syntax().clone(), _ => continue, } }, @@ -404,8 +448,8 @@ fn nav_for_exit_points( let blk_in_file = InFile::new(file_id, blk.into()); Some(expr_to_nav(db, blk_in_file, Some(async_tok))) }, - Some(ast::BlockModifier::Try(_)) if token_kind != T![return] => { - let try_tok = blk.try_token()?.text_range(); + Some(ast::BlockModifier::Try { .. }) if token_kind != T![return] => { + let try_tok = blk.try_block_modifier()?.try_token()?.text_range(); let blk_in_file = InFile::new(file_id, blk.into()); Some(expr_to_nav(db, blk_in_file, Some(try_tok))) }, @@ -4034,4 +4078,25 @@ where "#, ) } + + #[test] + fn question_mark_on_result_goes_to_conversion() { + check( + r#" +//- minicore: try, result, from + +struct Foo; +struct Bar; +impl From<Foo> for Bar { + fn from(_: Foo) -> Bar { Bar } + // ^^^^ +} + +fn foo() -> Result<(), Bar> { + Err(Foo)?$0; + Ok(()) +} + "#, + ); + } } diff --git a/crates/ide/src/highlight_related.rs b/crates/ide/src/highlight_related.rs index fce033382b..c8e01e21ec 100644 --- a/crates/ide/src/highlight_related.rs +++ b/crates/ide/src/highlight_related.rs @@ -473,7 +473,7 @@ pub(crate) fn highlight_exit_points( }, ast::BlockExpr(blk) => match blk.modifier() { Some(ast::BlockModifier::Async(t)) => hl_exit_points(sema, Some(t), blk.into()), - Some(ast::BlockModifier::Try(t)) if token.kind() != T![return] => { + Some(ast::BlockModifier::Try { try_token: t, .. }) if token.kind() != T![return] => { hl_exit_points(sema, Some(t), blk.into()) }, _ => continue, diff --git a/crates/ide/src/hover/render.rs b/crates/ide/src/hover/render.rs index 15ea92d1c6..af78e9a40c 100644 --- a/crates/ide/src/hover/render.rs +++ b/crates/ide/src/hover/render.rs @@ -5,7 +5,7 @@ use either::Either; use hir::{ Adt, AsAssocItem, AsExternAssocItem, CaptureKind, DisplayTarget, DropGlue, DynCompatibilityViolation, HasCrate, HasSource, HirDisplay, Layout, LayoutError, - MethodViolationCode, Name, Semantics, Symbol, Trait, Type, TypeInfo, VariantDef, + MethodViolationCode, Name, Semantics, Symbol, Trait, Type, TypeInfo, Variant, db::ExpandDatabase, }; use ide_db::{ @@ -74,7 +74,7 @@ pub(super) fn try_expr( ast::Fn(fn_) => sema.to_def(&fn_)?.ret_type(sema.db), ast::Item(__) => return None, ast::ClosureExpr(closure) => sema.type_of_expr(&closure.body()?)?.original, - ast::BlockExpr(block_expr) => if matches!(block_expr.modifier(), Some(ast::BlockModifier::Async(_) | ast::BlockModifier::Try(_)| ast::BlockModifier::Const(_))) { + ast::BlockExpr(block_expr) => if matches!(block_expr.modifier(), Some(ast::BlockModifier::Async(_) | ast::BlockModifier::Try { .. } | ast::BlockModifier::Const(_))) { sema.type_of_expr(&block_expr.into())?.original } else { continue; @@ -366,14 +366,14 @@ fn definition_owner_name(db: &RootDatabase, def: Definition, edition: Edition) - let parent_name = parent.name(db); let parent_name = parent_name.display(db, edition).to_string(); return match parent { - VariantDef::Variant(variant) => { + Variant::EnumVariant(variant) => { let enum_name = variant.parent_enum(db).name(db); Some(format!("{}::{parent_name}", enum_name.display(db, edition))) } _ => Some(parent_name), }; } - Definition::Variant(e) => Some(e.parent_enum(db).name(db)), + Definition::EnumVariant(e) => Some(e.parent_enum(db).name(db)), Definition::GenericParam(generic_param) => match generic_param.parent() { hir::GenericDef::Adt(it) => Some(it.name(db)), hir::GenericDef::Trait(it) => Some(it.name(db)), @@ -470,7 +470,7 @@ pub(super) fn definition( Definition::Adt(adt @ (Adt::Struct(_) | Adt::Union(_))) => { adt.display_limited(db, config.max_fields_count, display_target).to_string() } - Definition::Variant(variant) => { + Definition::EnumVariant(variant) => { variant.display_limited(db, config.max_fields_count, display_target).to_string() } Definition::Adt(adt @ Adt::Enum(_)) => { @@ -499,7 +499,7 @@ pub(super) fn definition( }; let docs = def.docs_with_rangemap(db, famous_defs, display_target); let value = || match def { - Definition::Variant(it) => { + Definition::EnumVariant(it) => { if !it.parent_enum(db).is_data_carrying(db) { match it.eval(db) { Ok(it) => { @@ -596,7 +596,7 @@ pub(super) fn definition( |_| { let var_def = it.parent_def(db); match var_def { - hir::VariantDef::Struct(s) => { + hir::Variant::Struct(s) => { Adt::from(s).layout(db).ok().and_then(|layout| layout.field_offset(it)) } _ => None, @@ -627,7 +627,7 @@ pub(super) fn definition( |_| None, |_| None, ), - Definition::Variant(it) => render_memory_layout( + Definition::EnumVariant(it) => render_memory_layout( config.memory_layout, || it.layout(db), |_| None, @@ -710,7 +710,7 @@ pub(super) fn definition( has_dtor: Some(enum_drop_glue > fields_drop_glue), } } - Definition::Variant(variant) => { + Definition::EnumVariant(variant) => { let fields_drop_glue = variant .fields(db) .iter() diff --git a/crates/ide/src/hover/tests.rs b/crates/ide/src/hover/tests.rs index 7900a0dc99..7fbbc576dd 100644 --- a/crates/ide/src/hover/tests.rs +++ b/crates/ide/src/hover/tests.rs @@ -9720,6 +9720,99 @@ fn test_hover_function_with_pat_param() { } #[test] +fn test_hover_function_with_too_long_param() { + check( + r#" +fn fn_$0( + attrs: impl IntoIterator<Item = ast::Attr>, + visibility: Option<ast::Visibility>, + fn_name: ast::Name, + type_params: Option<ast::GenericParamList>, + where_clause: Option<ast::WhereClause>, + params: ast::ParamList, + body: ast::BlockExpr, + ret_type: Option<ast::RetType>, + is_async: bool, + is_const: bool, + is_unsafe: bool, + is_gen: bool, +) -> ast::Fn {} + "#, + expect![[r#" + *fn_* + + ```rust + ra_test_fixture + ``` + + ```rust + fn fn_( + attrs: impl IntoIterator<Item = ast::Attr>, + visibility: Option<ast::Visibility>, + fn_name: ast::Name, + type_params: Option<ast::GenericParamList>, + where_clause: Option<ast::WhereClause>, + params: ast::ParamList, + body: ast::BlockExpr, + ret_type: Option<ast::RetType>, + is_async: bool, + is_const: bool, + is_unsafe: bool, + is_gen: bool + ) -> ast::Fn + ``` + "#]], + ); + + check( + r#" +fn fn_$0( + &self, + attrs: impl IntoIterator<Item = ast::Attr>, + visibility: Option<ast::Visibility>, + fn_name: ast::Name, + type_params: Option<ast::GenericParamList>, + where_clause: Option<ast::WhereClause>, + params: ast::ParamList, + body: ast::BlockExpr, + ret_type: Option<ast::RetType>, + is_async: bool, + is_const: bool, + is_unsafe: bool, + is_gen: bool, + ... +) -> ast::Fn {} + "#, + expect![[r#" + *fn_* + + ```rust + ra_test_fixture + ``` + + ```rust + fn fn_( + &self, + attrs: impl IntoIterator<Item = ast::Attr>, + visibility: Option<ast::Visibility>, + fn_name: ast::Name, + type_params: Option<ast::GenericParamList>, + where_clause: Option<ast::WhereClause>, + params: ast::ParamList, + body: ast::BlockExpr, + ret_type: Option<ast::RetType>, + is_async: bool, + is_const: bool, + is_unsafe: bool, + is_gen: bool, + ... + ) -> ast::Fn + ``` + "#]], + ); +} + +#[test] fn hover_path_inside_block_scope() { check( r#" diff --git a/crates/ide/src/inlay_hints/bind_pat.rs b/crates/ide/src/inlay_hints/bind_pat.rs index c74e3104c1..caf7cc714d 100644 --- a/crates/ide/src/inlay_hints/bind_pat.rs +++ b/crates/ide/src/inlay_hints/bind_pat.rs @@ -1382,4 +1382,21 @@ fn f<'a>() { "#]], ); } + + #[test] + fn ref_multi_trait_impl_trait() { + check_with_config( + InlayHintsConfig { type_hints: true, ..DISABLED_CONFIG }, + r#" +//- minicore: sized +trait Eq {} +trait Ord {} + +fn foo(argument: &(impl Eq + Ord)) { + let x = argument; + // ^ &(impl Eq + Ord) +} + "#, + ); + } } diff --git a/crates/ide/src/inlay_hints/discriminant.rs b/crates/ide/src/inlay_hints/discriminant.rs index 5b9267126f..e845faec56 100644 --- a/crates/ide/src/inlay_hints/discriminant.rs +++ b/crates/ide/src/inlay_hints/discriminant.rs @@ -46,7 +46,7 @@ fn variant_hints( enum_: &ast::Enum, variant: &ast::Variant, ) -> Option<()> { - if variant.expr().is_some() { + if variant.const_arg().is_some() { return None; } diff --git a/crates/ide/src/inlay_hints/implicit_drop.rs b/crates/ide/src/inlay_hints/implicit_drop.rs index e5e4c899ec..3af529e8c5 100644 --- a/crates/ide/src/inlay_hints/implicit_drop.rs +++ b/crates/ide/src/inlay_hints/implicit_drop.rs @@ -7,7 +7,7 @@ //! ``` use hir::{ DefWithBody, - db::{DefDatabase as _, HirDatabase as _}, + db::HirDatabase as _, mir::{MirSpan, TerminatorKind}, }; use ide_db::{FileRange, famous_defs::FamousDefs}; @@ -35,7 +35,7 @@ pub(super) fn hints( let def: DefWithBody = def.into(); let def = def.try_into().ok()?; - let (hir, source_map) = sema.db.body_with_source_map(def); + let (hir, source_map) = hir::Body::with_source_map(sema.db, def); let mir = sema.db.mir_body(def).ok()?; diff --git a/crates/ide/src/inlay_hints/param_name.rs b/crates/ide/src/inlay_hints/param_name.rs index f1e62a5ab8..08588bbed0 100644 --- a/crates/ide/src/inlay_hints/param_name.rs +++ b/crates/ide/src/inlay_hints/param_name.rs @@ -374,7 +374,7 @@ fn is_adt_constructor_similar_to_param_name( hir::PathResolution::Def(hir::ModuleDef::Adt(_)) => { Some(to_lower_snake_case(&path.segment()?.name_ref()?.text()) == param_name) } - hir::PathResolution::Def(hir::ModuleDef::Function(_) | hir::ModuleDef::Variant(_)) => { + hir::PathResolution::Def(hir::ModuleDef::Function(_) | hir::ModuleDef::EnumVariant(_)) => { if to_lower_snake_case(&path.segment()?.name_ref()?.text()) == param_name { return Some(true); } diff --git a/crates/ide/src/lib.rs b/crates/ide/src/lib.rs index 930eaf2262..81a771fec8 100644 --- a/crates/ide/src/lib.rs +++ b/crates/ide/src/lib.rs @@ -339,8 +339,7 @@ impl Analysis { pub fn parse(&self, file_id: FileId) -> Cancellable<SourceFile> { // FIXME edition self.with_db(|db| { - let editioned_file_id_wrapper = - EditionedFileId::current_edition_guess_origin(&self.db, file_id); + let editioned_file_id_wrapper = EditionedFileId::current_edition(&self.db, file_id); db.parse(editioned_file_id_wrapper).tree() }) @@ -369,7 +368,7 @@ impl Analysis { /// supported). pub fn matching_brace(&self, position: FilePosition) -> Cancellable<Option<TextSize>> { self.with_db(|db| { - let file_id = EditionedFileId::current_edition_guess_origin(&self.db, position.file_id); + let file_id = EditionedFileId::current_edition(&self.db, position.file_id); let parse = db.parse(file_id); let file = parse.tree(); matching_brace::matching_brace(&file, position.offset) @@ -430,7 +429,7 @@ impl Analysis { pub fn join_lines(&self, config: &JoinLinesConfig, frange: FileRange) -> Cancellable<TextEdit> { self.with_db(|db| { let editioned_file_id_wrapper = - EditionedFileId::current_edition_guess_origin(&self.db, frange.file_id); + EditionedFileId::current_edition(&self.db, frange.file_id); let parse = db.parse(editioned_file_id_wrapper); join_lines::join_lines(config, &parse.tree(), frange.range) }) @@ -471,8 +470,7 @@ impl Analysis { ) -> Cancellable<Vec<StructureNode>> { // FIXME: Edition self.with_db(|db| { - let editioned_file_id_wrapper = - EditionedFileId::current_edition_guess_origin(&self.db, file_id); + let editioned_file_id_wrapper = EditionedFileId::current_edition(&self.db, file_id); let source_file = db.parse(editioned_file_id_wrapper).tree(); file_structure::file_structure(&source_file, config) }) @@ -503,8 +501,7 @@ impl Analysis { /// Returns the set of folding ranges. pub fn folding_ranges(&self, file_id: FileId) -> Cancellable<Vec<Fold>> { self.with_db(|db| { - let editioned_file_id_wrapper = - EditionedFileId::current_edition_guess_origin(&self.db, file_id); + let editioned_file_id_wrapper = EditionedFileId::current_edition(&self.db, file_id); folding_ranges::folding_ranges(&db.parse(editioned_file_id_wrapper).tree()) }) diff --git a/crates/ide/src/moniker.rs b/crates/ide/src/moniker.rs index 1c1389ca7a..335e1b5b13 100644 --- a/crates/ide/src/moniker.rs +++ b/crates/ide/src/moniker.rs @@ -205,7 +205,7 @@ pub(crate) fn def_to_kind(db: &RootDatabase, def: Definition) -> SymbolInformati Definition::Adt(Adt::Struct(..)) => Struct, Definition::Adt(Adt::Union(..)) => Union, Definition::Adt(Adt::Enum(..)) => Enum, - Definition::Variant(..) => EnumMember, + Definition::EnumVariant(..) => EnumMember, Definition::Const(..) => Constant, Definition::Static(..) => StaticVariable, Definition::Trait(..) => Trait, diff --git a/crates/ide/src/navigation_target.rs b/crates/ide/src/navigation_target.rs index 185df92e2d..92020321f4 100644 --- a/crates/ide/src/navigation_target.rs +++ b/crates/ide/src/navigation_target.rs @@ -276,7 +276,7 @@ impl<'db> TryToNav for FileSymbol<'db> { Some(it.display(db, display_target).to_string()) } hir::ModuleDef::Adt(it) => Some(it.display(db, display_target).to_string()), - hir::ModuleDef::Variant(it) => { + hir::ModuleDef::EnumVariant(it) => { Some(it.display(db, display_target).to_string()) } hir::ModuleDef::Const(it) => { @@ -319,7 +319,7 @@ impl TryToNav for Definition { Definition::GenericParam(it) => it.try_to_nav(sema), Definition::Function(it) => it.try_to_nav(sema), Definition::Adt(it) => it.try_to_nav(sema), - Definition::Variant(it) => it.try_to_nav(sema), + Definition::EnumVariant(it) => it.try_to_nav(sema), Definition::Const(it) => it.try_to_nav(sema), Definition::Static(it) => it.try_to_nav(sema), Definition::Trait(it) => it.try_to_nav(sema), @@ -347,7 +347,7 @@ impl TryToNav for hir::ModuleDef { hir::ModuleDef::Module(it) => Some(it.to_nav(sema.db)), hir::ModuleDef::Function(it) => it.try_to_nav(sema), hir::ModuleDef::Adt(it) => it.try_to_nav(sema), - hir::ModuleDef::Variant(it) => it.try_to_nav(sema), + hir::ModuleDef::EnumVariant(it) => it.try_to_nav(sema), hir::ModuleDef::Const(it) => it.try_to_nav(sema), hir::ModuleDef::Static(it) => it.try_to_nav(sema), hir::ModuleDef::Trait(it) => it.try_to_nav(sema), @@ -406,7 +406,7 @@ impl ToNavFromAst for hir::Enum { container_name(db, self) } } -impl ToNavFromAst for hir::Variant { +impl ToNavFromAst for hir::EnumVariant { const KIND: SymbolKind = SymbolKind::Variant; } impl ToNavFromAst for hir::Union { diff --git a/crates/ide/src/references.rs b/crates/ide/src/references.rs index 5443021988..9392651c17 100644 --- a/crates/ide/src/references.rs +++ b/crates/ide/src/references.rs @@ -299,7 +299,7 @@ fn retain_adt_literal_usages( }); usages.references.retain(|_, it| !it.is_empty()); } - Definition::Adt(_) | Definition::Variant(_) => { + Definition::Adt(_) | Definition::EnumVariant(_) => { refs.for_each(|it| { it.retain(|reference| reference.name.as_name_ref().is_some_and(is_lit_name_ref)) }); @@ -377,7 +377,7 @@ fn is_enum_lit_name_ref( let path_is_variant_of_enum = |path: ast::Path| { matches!( sema.resolve_path(&path), - Some(PathResolution::Def(hir::ModuleDef::Variant(variant))) + Some(PathResolution::Def(hir::ModuleDef::EnumVariant(variant))) if variant.parent_enum(sema.db) == enum_ ) }; @@ -513,26 +513,32 @@ fn test() { } #[test] - fn test_access() { + fn exclude_tests_macro_refs() { check( r#" -struct S { f$0: u32 } +macro_rules! my_macro { + ($e:expr) => { $e }; +} + +fn foo$0() -> i32 { 42 } + +fn bar() { + foo(); +} #[test] -fn test() { - let mut x = S { f: 92 }; - x.f = 92; +fn t2() { + my_macro!(foo()); } "#, expect![[r#" - f Field FileId(0) 11..17 11..12 + foo Function FileId(0) 52..74 55..58 - FileId(0) 61..62 read test - FileId(0) 76..77 write test + FileId(0) 91..94 + FileId(0) 133..136 test "#]], ); } - #[test] fn test_struct_literal_after_space() { check( @@ -1145,10 +1151,7 @@ pub(super) struct Foo$0 { check_with_scope( code, Some(&mut |db| { - SearchScope::single_file(EditionedFileId::current_edition_guess_origin( - db, - FileId::from_raw(2), - )) + SearchScope::single_file(EditionedFileId::current_edition(db, FileId::from_raw(2))) }), expect![[r#" quux Function FileId(0) 19..35 26..30 @@ -2073,6 +2076,7 @@ fn func() {} expect![[r#" identity Attribute FileId(1) 1..107 32..40 + FileId(0) 17..25 import FileId(0) 43..51 "#]], ); @@ -2103,6 +2107,7 @@ mirror$0! {} expect![[r#" mirror ProcMacro FileId(1) 1..77 22..28 + FileId(0) 17..23 import FileId(0) 26..32 "#]], ) diff --git a/crates/ide/src/rename.rs b/crates/ide/src/rename.rs index ae19e77509..900a885a64 100644 --- a/crates/ide/src/rename.rs +++ b/crates/ide/src/rename.rs @@ -480,7 +480,7 @@ fn rename_to_self( } let fn_def = match local.parent(sema.db) { - hir::DefWithBody::Function(func) => func, + hir::ExpressionStoreOwner::Body(hir::DefWithBody::Function(func)) => func, _ => bail!("Cannot rename local to self outside of function"), }; @@ -743,7 +743,7 @@ fn rename_self_to_param( } let fn_def = match local.parent(sema.db) { - hir::DefWithBody::Function(func) => func, + hir::ExpressionStoreOwner::Body(hir::DefWithBody::Function(func)) => func, _ => bail!("Cannot rename local to self outside of function"), }; diff --git a/crates/ide/src/runnables.rs b/crates/ide/src/runnables.rs index 42efa7142b..a0a6a24559 100644 --- a/crates/ide/src/runnables.rs +++ b/crates/ide/src/runnables.rs @@ -494,7 +494,7 @@ fn module_def_doctest(sema: &Semantics<'_, RootDatabase>, def: Definition) -> Op Definition::Module(it) => it.attrs(db), Definition::Function(it) => it.attrs(db), Definition::Adt(it) => it.attrs(db), - Definition::Variant(it) => it.attrs(db), + Definition::EnumVariant(it) => it.attrs(db), Definition::Const(it) => it.attrs(db), Definition::Static(it) => it.attrs(db), Definition::Trait(it) => it.attrs(db), diff --git a/crates/ide/src/signature_help.rs b/crates/ide/src/signature_help.rs index 9ab07565e9..9eb01b12f2 100644 --- a/crates/ide/src/signature_help.rs +++ b/crates/ide/src/signature_help.rs @@ -497,7 +497,7 @@ fn signature_help_for_tuple_struct_pat( }; let db = sema.db; - let fields: Vec<_> = if let PathResolution::Def(ModuleDef::Variant(variant)) = path_res { + let fields: Vec<_> = if let PathResolution::Def(ModuleDef::EnumVariant(variant)) = path_res { let en = variant.parent_enum(db); res.doc = en.docs(db).map(Documentation::into_owned); @@ -623,7 +623,7 @@ fn signature_help_for_record_<'db>( let db = sema.db; let path_res = sema.resolve_path(path)?; - if let PathResolution::Def(ModuleDef::Variant(variant)) = path_res { + if let PathResolution::Def(ModuleDef::EnumVariant(variant)) = path_res { fields = variant.fields(db); let en = variant.parent_enum(db); @@ -1975,8 +1975,8 @@ trait Sub: Super + Super { fn f() -> impl Sub<$0 "#, expect![[r#" - trait Sub<SubTy = …, SuperTy = …> - ^^^^^^^^^ ----------- + trait Sub<SuperTy = …, SubTy = …> + ^^^^^^^^^^^ --------- "#]], ); } diff --git a/crates/ide/src/syntax_highlighting.rs b/crates/ide/src/syntax_highlighting.rs index e64fd6488f..217b13b4ef 100644 --- a/crates/ide/src/syntax_highlighting.rs +++ b/crates/ide/src/syntax_highlighting.rs @@ -14,7 +14,9 @@ mod tests; use std::ops::ControlFlow; use either::Either; -use hir::{DefWithBody, EditionedFileId, InFile, InRealFile, MacroKind, Name, Semantics}; +use hir::{ + DefWithBody, EditionedFileId, ExpressionStoreOwner, InFile, InRealFile, MacroKind, Semantics, +}; use ide_db::{FxHashMap, FxHashSet, MiniCore, Ranker, RootDatabase, SymbolKind}; use syntax::{ AstNode, AstToken, NodeOrToken, @@ -256,9 +258,8 @@ fn traverse( let mut inside_attribute = false; // FIXME: accommodate range highlighting - let mut body_stack: Vec<Option<DefWithBody>> = vec![]; - let mut per_body_cache: FxHashMap<DefWithBody, (FxHashSet<_>, FxHashMap<Name, u32>)> = - FxHashMap::default(); + let mut body_stack: Vec<Option<ExpressionStoreOwner>> = vec![]; + let mut per_body_cache: FxHashMap<ExpressionStoreOwner, FxHashSet<_>> = FxHashMap::default(); // Walk all nodes, keeping track of whether we are inside a macro or not. // If in macro, expand it first and highlight the expanded code. @@ -289,19 +290,18 @@ fn traverse( inside_attribute = false } Enter(NodeOrToken::Node(node)) => { + // FIXME: ExpressionStore signatures and variant fields + // Maybe we can re-use child container stuff here if let Some(item) = <Either<ast::Item, ast::Variant>>::cast(node.clone()) { match item { Either::Left(item) => { match &item { - ast::Item::Fn(it) => { - body_stack.push(sema.to_def(it).map(Into::into)) - } - ast::Item::Const(it) => { - body_stack.push(sema.to_def(it).map(Into::into)) - } - ast::Item::Static(it) => { - body_stack.push(sema.to_def(it).map(Into::into)) - } + ast::Item::Fn(it) => body_stack + .push(sema.to_def(it).map(DefWithBody::from).map(Into::into)), + ast::Item::Const(it) => body_stack + .push(sema.to_def(it).map(DefWithBody::from).map(Into::into)), + ast::Item::Static(it) => body_stack + .push(sema.to_def(it).map(DefWithBody::from).map(Into::into)), _ => (), } @@ -330,7 +330,9 @@ fn traverse( } } } - Either::Right(it) => body_stack.push(sema.to_def(&it).map(Into::into)), + Either::Right(it) => { + body_stack.push(sema.to_def(&it).map(DefWithBody::from).map(Into::into)) + } } } } @@ -393,11 +395,11 @@ fn traverse( let descended = descend_token(sema, InRealFile::new(file_id, token)); let body = match &descended.value { NodeOrToken::Node(n) => { - sema.body_for(InFile::new(descended.file_id, n.syntax())) - } - NodeOrToken::Token(t) => { - t.parent().and_then(|it| sema.body_for(InFile::new(descended.file_id, &it))) + sema.store_owner_for(InFile::new(descended.file_id, n.syntax())) } + NodeOrToken::Token(t) => t + .parent() + .and_then(|it| sema.store_owner_for(InFile::new(descended.file_id, &it))), }; (descended, body) } @@ -422,14 +424,11 @@ fn traverse( } let edition = descended_element.file_id.edition(sema.db); - let (unsafe_ops, bindings_shadow_count) = match current_body { - Some(current_body) => { - let (ops, bindings) = per_body_cache - .entry(current_body) - .or_insert_with(|| (sema.get_unsafe_ops(current_body), Default::default())); - (&*ops, Some(bindings)) - } - None => (&empty, None), + let unsafe_ops = match current_body { + Some(current_body) => per_body_cache + .entry(current_body) + .or_insert_with(|| sema.get_unsafe_ops(current_body)), + None => &empty, }; let is_unsafe_node = |node| unsafe_ops.contains(&InFile::new(descended_element.file_id, node)); @@ -438,7 +437,6 @@ fn traverse( let hl = highlight::name_like( sema, krate, - bindings_shadow_count, &is_unsafe_node, config.syntactic_name_ref_highlighting, name_like, diff --git a/crates/ide/src/syntax_highlighting/highlight.rs b/crates/ide/src/syntax_highlighting/highlight.rs index dcc9a8c0d5..0e101ab235 100644 --- a/crates/ide/src/syntax_highlighting/highlight.rs +++ b/crates/ide/src/syntax_highlighting/highlight.rs @@ -5,12 +5,11 @@ use std::ops::ControlFlow; use either::Either; use hir::{AsAssocItem, HasAttrs, HasVisibility, Semantics}; use ide_db::{ - FxHashMap, RootDatabase, SymbolKind, + RootDatabase, SymbolKind, defs::{Definition, IdentClass, NameClass, NameRefClass}, syntax_helpers::node_ext::walk_pat, }; use span::Edition; -use stdx::hash_once; use syntax::{ AstNode, AstPtr, AstToken, NodeOrToken, SyntaxKind::{self, *}, @@ -64,7 +63,6 @@ pub(super) fn token( pub(super) fn name_like( sema: &Semantics<'_, RootDatabase>, krate: Option<hir::Crate>, - bindings_shadow_count: Option<&mut FxHashMap<hir::Name, u32>>, is_unsafe_node: &impl Fn(AstPtr<Either<ast::Expr, ast::Pat>>) -> bool, syntactic_name_ref_highlighting: bool, name_like: ast::NameLike, @@ -75,22 +73,15 @@ pub(super) fn name_like( ast::NameLike::NameRef(name_ref) => highlight_name_ref( sema, krate, - bindings_shadow_count, &mut binding_hash, is_unsafe_node, syntactic_name_ref_highlighting, name_ref, edition, ), - ast::NameLike::Name(name) => highlight_name( - sema, - bindings_shadow_count, - &mut binding_hash, - is_unsafe_node, - krate, - name, - edition, - ), + ast::NameLike::Name(name) => { + highlight_name(sema, &mut binding_hash, is_unsafe_node, krate, name, edition) + } ast::NameLike::Lifetime(lifetime) => match IdentClass::classify_lifetime(sema, &lifetime) { Some(IdentClass::NameClass(NameClass::Definition(def))) => { highlight_def(sema, krate, def, edition, false) | HlMod::Definition @@ -273,7 +264,6 @@ fn keyword(token: SyntaxToken, kind: SyntaxKind) -> Highlight { fn highlight_name_ref( sema: &Semantics<'_, RootDatabase>, krate: Option<hir::Crate>, - bindings_shadow_count: Option<&mut FxHashMap<hir::Name, u32>>, binding_hash: &mut Option<u64>, is_unsafe_node: &impl Fn(AstPtr<Either<ast::Expr, ast::Pat>>) -> bool, syntactic_name_ref_highlighting: bool, @@ -306,12 +296,8 @@ fn highlight_name_ref( }; let mut h = match name_class { NameRefClass::Definition(def, _) => { - if let Definition::Local(local) = &def - && let Some(bindings_shadow_count) = bindings_shadow_count - { - let name = local.name(sema.db); - let shadow_count = bindings_shadow_count.entry(name.clone()).or_default(); - *binding_hash = Some(calc_binding_hash(&name, *shadow_count)) + if let Definition::Local(local) = &def { + *binding_hash = Some(local.as_id() as u64); }; let mut h = highlight_def(sema, krate, def, edition, true); @@ -432,7 +418,6 @@ fn highlight_name_ref( fn highlight_name( sema: &Semantics<'_, RootDatabase>, - bindings_shadow_count: Option<&mut FxHashMap<hir::Name, u32>>, binding_hash: &mut Option<u64>, is_unsafe_node: &impl Fn(AstPtr<Either<ast::Expr, ast::Pat>>) -> bool, krate: Option<hir::Crate>, @@ -440,13 +425,8 @@ fn highlight_name( edition: Edition, ) -> Highlight { let name_kind = NameClass::classify(sema, &name); - if let Some(NameClass::Definition(Definition::Local(local))) = &name_kind - && let Some(bindings_shadow_count) = bindings_shadow_count - { - let name = local.name(sema.db); - let shadow_count = bindings_shadow_count.entry(name.clone()).or_default(); - *shadow_count += 1; - *binding_hash = Some(calc_binding_hash(&name, *shadow_count)) + if let Some(NameClass::Definition(Definition::Local(local))) = &name_kind { + *binding_hash = Some(local.as_id() as u64); }; match name_kind { Some(NameClass::Definition(def)) => { @@ -474,10 +454,6 @@ fn highlight_name( } } -fn calc_binding_hash(name: &hir::Name, shadow_count: u32) -> u64 { - hash_once::<ide_db::FxHasher>((name.as_str(), shadow_count)) -} - pub(super) fn highlight_def( sema: &Semantics<'_, RootDatabase>, krate: Option<hir::Crate>, @@ -562,7 +538,7 @@ pub(super) fn highlight_def( (Highlight::new(h), Some(adt.attrs(sema.db))) } - Definition::Variant(variant) => { + Definition::EnumVariant(variant) => { (Highlight::new(HlTag::Symbol(SymbolKind::Variant)), Some(variant.attrs(sema.db))) } Definition::Const(konst) => { diff --git a/crates/ide/src/syntax_highlighting/inject.rs b/crates/ide/src/syntax_highlighting/inject.rs index 291333f09c..74a8d93dfe 100644 --- a/crates/ide/src/syntax_highlighting/inject.rs +++ b/crates/ide/src/syntax_highlighting/inject.rs @@ -209,7 +209,7 @@ fn module_def_to_hl_tag(db: &dyn HirDatabase, def: Definition) -> HlTag { Definition::Adt(hir::Adt::Struct(_)) => SymbolKind::Struct, Definition::Adt(hir::Adt::Enum(_)) => SymbolKind::Enum, Definition::Adt(hir::Adt::Union(_)) => SymbolKind::Union, - Definition::Variant(_) => SymbolKind::Variant, + Definition::EnumVariant(_) => SymbolKind::Variant, Definition::Const(_) => SymbolKind::Const, Definition::Static(_) => SymbolKind::Static, Definition::Trait(_) => SymbolKind::Trait, diff --git a/crates/ide/src/syntax_highlighting/test_data/highlight_general.html b/crates/ide/src/syntax_highlighting/test_data/highlight_general.html index c6dbc435c0..1184739cc2 100644 --- a/crates/ide/src/syntax_highlighting/test_data/highlight_general.html +++ b/crates/ide/src/syntax_highlighting/test_data/highlight_general.html @@ -105,7 +105,7 @@ pre { color: #DCDCCC; background: #3F3F3F; font-size: 22px; padd <span class="keyword control">loop</span> <span class="brace">{</span><span class="brace">}</span> <span class="brace">}</span> -<span class="keyword">fn</span> <span class="function declaration">const_param</span><span class="angle"><</span><span class="keyword const">const</span> <span class="const_param const declaration">FOO</span><span class="colon">:</span> <span class="builtin_type">usize</span><span class="angle">></span><span class="parenthesis">(</span><span class="parenthesis">)</span> <span class="operator">-></span> <span class="builtin_type">usize</span> <span class="brace">{</span> +<span class="keyword">fn</span> <span class="function declaration">const_param</span><span class="angle"><</span><span class="keyword const">const</span> <span class="const_param const declaration">FOO</span><span class="colon">:</span> <span class="builtin_type">usize</span><span class="angle">></span><span class="parenthesis">(</span><span class="parenthesis">)</span> <span class="operator">-></span> <span class="builtin_type">usize</span> <span class="keyword">where</span> <span class="bracket">[</span><span class="parenthesis">(</span><span class="parenthesis">)</span><span class="semicolon">;</span> <span class="const_param const">FOO</span><span class="bracket">]</span><span class="colon">:</span> <span class="trait default_library library">Sized</span> <span class="brace">{</span> <span class="function">const_param</span><span class="operator">::</span><span class="angle"><</span><span class="brace">{</span> <span class="const_param const">FOO</span> <span class="brace">}</span><span class="angle">></span><span class="parenthesis">(</span><span class="parenthesis">)</span><span class="semicolon">;</span> <span class="const_param const">FOO</span> <span class="brace">}</span> diff --git a/crates/ide/src/syntax_highlighting/test_data/highlight_macros.html b/crates/ide/src/syntax_highlighting/test_data/highlight_macros.html index 59612634fd..740a6272a7 100644 --- a/crates/ide/src/syntax_highlighting/test_data/highlight_macros.html +++ b/crates/ide/src/syntax_highlighting/test_data/highlight_macros.html @@ -41,7 +41,8 @@ pre { color: #DCDCCC; background: #3F3F3F; font-size: 22px; padd .invalid_escape_sequence { color: #FC5555; text-decoration: wavy underline; } .unresolved_reference { color: #FC5555; text-decoration: wavy underline; } </style> -<pre><code><span class="keyword">use</span> <span class="crate_root library">proc_macros</span><span class="operator">::</span><span class="brace">{</span><span class="function library">mirror</span><span class="comma">,</span> <span class="function library">identity</span><span class="comma">,</span> <span class="derive library">DeriveIdentity</span><span class="brace">}</span><span class="semicolon">;</span> +<pre><code><span class="keyword">use</span> <span class="crate_root library">proc_macros</span><span class="operator">::</span><span class="brace">{</span><span class="proc_macro library">mirror</span><span class="comma">,</span> <span class="attribute library">identity</span><span class="comma">,</span> <span class="derive library">DeriveIdentity</span><span class="brace">}</span><span class="semicolon">;</span> +<span class="keyword">use</span> <span class="crate_root library">pm</span><span class="operator">::</span><span class="attribute library">proc_macro</span><span class="semicolon">;</span> <span class="proc_macro library">mirror</span><span class="macro_bang">!</span> <span class="brace">{</span> <span class="brace macro proc_macro">{</span> diff --git a/crates/ide/src/syntax_highlighting/test_data/highlight_rainbow.html b/crates/ide/src/syntax_highlighting/test_data/highlight_rainbow.html index d5401e7aec..7c64707ac1 100644 --- a/crates/ide/src/syntax_highlighting/test_data/highlight_rainbow.html +++ b/crates/ide/src/syntax_highlighting/test_data/highlight_rainbow.html @@ -42,14 +42,14 @@ pre { color: #DCDCCC; background: #3F3F3F; font-size: 22px; padd .unresolved_reference { color: #FC5555; text-decoration: wavy underline; } </style> <pre><code><span class="keyword">fn</span> <span class="function declaration">main</span><span class="parenthesis">(</span><span class="parenthesis">)</span> <span class="brace">{</span> - <span class="keyword">let</span> <span class="variable declaration reference" data-binding-hash="18084384843626695225" style="color: hsl(154,95%,53%);">hello</span> <span class="operator">=</span> <span class="string_literal">"hello"</span><span class="semicolon">;</span> - <span class="keyword">let</span> <span class="variable declaration" data-binding-hash="5697120079570210533" style="color: hsl(268,86%,80%);">x</span> <span class="operator">=</span> <span class="variable reference" data-binding-hash="18084384843626695225" style="color: hsl(154,95%,53%);">hello</span><span class="operator">.</span><span class="unresolved_reference">to_string</span><span class="parenthesis">(</span><span class="parenthesis">)</span><span class="semicolon">;</span> - <span class="keyword">let</span> <span class="variable declaration" data-binding-hash="4222724691718692706" style="color: hsl(156,71%,51%);">y</span> <span class="operator">=</span> <span class="variable reference" data-binding-hash="18084384843626695225" style="color: hsl(154,95%,53%);">hello</span><span class="operator">.</span><span class="unresolved_reference">to_string</span><span class="parenthesis">(</span><span class="parenthesis">)</span><span class="semicolon">;</span> + <span class="keyword">let</span> <span class="variable declaration reference" data-binding-hash="0" style="color: hsl(74,59%,48%);">hello</span> <span class="operator">=</span> <span class="string_literal">"hello"</span><span class="semicolon">;</span> + <span class="keyword">let</span> <span class="variable declaration" data-binding-hash="1" style="color: hsl(152,51%,64%);">x</span> <span class="operator">=</span> <span class="variable reference" data-binding-hash="0" style="color: hsl(74,59%,48%);">hello</span><span class="operator">.</span><span class="unresolved_reference">to_string</span><span class="parenthesis">(</span><span class="parenthesis">)</span><span class="semicolon">;</span> + <span class="keyword">let</span> <span class="variable declaration" data-binding-hash="2" style="color: hsl(272,82%,82%);">y</span> <span class="operator">=</span> <span class="variable reference" data-binding-hash="0" style="color: hsl(74,59%,48%);">hello</span><span class="operator">.</span><span class="unresolved_reference">to_string</span><span class="parenthesis">(</span><span class="parenthesis">)</span><span class="semicolon">;</span> - <span class="keyword">let</span> <span class="variable declaration reference" data-binding-hash="17855021198829413584" style="color: hsl(230,76%,79%);">x</span> <span class="operator">=</span> <span class="string_literal">"other color please!"</span><span class="semicolon">;</span> - <span class="keyword">let</span> <span class="variable declaration" data-binding-hash="16380625810977895757" style="color: hsl(262,75%,75%);">y</span> <span class="operator">=</span> <span class="variable reference" data-binding-hash="17855021198829413584" style="color: hsl(230,76%,79%);">x</span><span class="operator">.</span><span class="unresolved_reference">to_string</span><span class="parenthesis">(</span><span class="parenthesis">)</span><span class="semicolon">;</span> + <span class="keyword">let</span> <span class="variable declaration reference" data-binding-hash="3" style="color: hsl(107,98%,81%);">x</span> <span class="operator">=</span> <span class="string_literal">"other color please!"</span><span class="semicolon">;</span> + <span class="keyword">let</span> <span class="variable declaration" data-binding-hash="4" style="color: hsl(241,93%,64%);">y</span> <span class="operator">=</span> <span class="variable reference" data-binding-hash="3" style="color: hsl(107,98%,81%);">x</span><span class="operator">.</span><span class="unresolved_reference">to_string</span><span class="parenthesis">(</span><span class="parenthesis">)</span><span class="semicolon">;</span> <span class="brace">}</span> <span class="keyword">fn</span> <span class="function declaration">bar</span><span class="parenthesis">(</span><span class="parenthesis">)</span> <span class="brace">{</span> - <span class="keyword">let</span> <span class="keyword">mut</span> <span class="variable declaration mutable reference" data-binding-hash="18084384843626695225" style="color: hsl(154,95%,53%);">hello</span> <span class="operator">=</span> <span class="string_literal">"hello"</span><span class="semicolon">;</span> + <span class="keyword">let</span> <span class="keyword">mut</span> <span class="variable declaration mutable reference" data-binding-hash="0" style="color: hsl(74,59%,48%);">hello</span> <span class="operator">=</span> <span class="string_literal">"hello"</span><span class="semicolon">;</span> <span class="brace">}</span></code></pre>
\ No newline at end of file diff --git a/crates/ide/src/syntax_highlighting/tests.rs b/crates/ide/src/syntax_highlighting/tests.rs index 8b529cf10f..aecd1d3fdb 100644 --- a/crates/ide/src/syntax_highlighting/tests.rs +++ b/crates/ide/src/syntax_highlighting/tests.rs @@ -55,8 +55,9 @@ fn macros() { r#" //- proc_macros: mirror, identity, derive_identity //- minicore: fmt, include, concat -//- /lib.rs crate:lib +//- /lib.rs crate:lib deps:pm use proc_macros::{mirror, identity, DeriveIdentity}; +use pm::proc_macro; mirror! { { @@ -126,6 +127,11 @@ fn main() { //- /foo/foo.rs crate:foo mod foo {} use self::foo as bar; +//- /pm.rs crate:pm +#![crate_type = "proc-macro"] + +#[proc_macro_attribute] +pub fn proc_macro() {} "#, expect_file!["./test_data/highlight_macros.html"], false, @@ -204,7 +210,7 @@ fn never() -> ! { loop {} } -fn const_param<const FOO: usize>() -> usize { +fn const_param<const FOO: usize>() -> usize where [(); FOO]: Sized { const_param::<{ FOO }>(); FOO } diff --git a/crates/ide/src/typing.rs b/crates/ide/src/typing.rs index 0381865fed..e8b0c92dcb 100644 --- a/crates/ide/src/typing.rs +++ b/crates/ide/src/typing.rs @@ -70,15 +70,12 @@ pub(crate) fn on_char_typed( if !TRIGGER_CHARS.contains(&char_typed) { return None; } - // FIXME: We need to figure out the edition of the file here, but that means hitting the - // database for more than just parsing the file which is bad. - // FIXME: We are hitting the database here, if we are unlucky this call might block momentarily - // causing the editor to feel sluggish! - let edition = Edition::CURRENT_FIXME; - let editioned_file_id_wrapper = EditionedFileId::from_span_guess_origin( - db, - span::EditionedFileId::new(position.file_id, edition), - ); + let edition = db + .relevant_crates(position.file_id) + .first() + .copied() + .map_or(Edition::CURRENT, |krate| krate.data(db).edition); + let editioned_file_id_wrapper = EditionedFileId::new(db, position.file_id, edition); let file = &db.parse(editioned_file_id_wrapper); let char_matches_position = file.tree().syntax().text().char_at(position.offset) == Some(char_typed); @@ -457,8 +454,8 @@ mod tests { let (offset, mut before) = extract_offset(before); let edit = TextEdit::insert(offset, char_typed.to_string()); edit.apply(&mut before); - let parse = SourceFile::parse(&before, span::Edition::CURRENT_FIXME); - on_char_typed_(&parse, offset, char_typed, span::Edition::CURRENT_FIXME).map(|it| { + let parse = SourceFile::parse(&before, span::Edition::CURRENT); + on_char_typed_(&parse, offset, char_typed, span::Edition::CURRENT).map(|it| { it.apply(&mut before); before.to_string() }) diff --git a/crates/ide/src/typing/on_enter.rs b/crates/ide/src/typing/on_enter.rs index 76a2802d29..fdc583a15c 100644 --- a/crates/ide/src/typing/on_enter.rs +++ b/crates/ide/src/typing/on_enter.rs @@ -51,7 +51,7 @@ use ide_db::text_edit::TextEdit; //  pub(crate) fn on_enter(db: &RootDatabase, position: FilePosition) -> Option<TextEdit> { let editioned_file_id_wrapper = - ide_db::base_db::EditionedFileId::current_edition_guess_origin(db, position.file_id); + ide_db::base_db::EditionedFileId::current_edition(db, position.file_id); let parse = db.parse(editioned_file_id_wrapper); let file = parse.tree(); let token = file.syntax().token_at_offset(position.offset).left_biased()?; diff --git a/crates/ide/src/view_item_tree.rs b/crates/ide/src/view_item_tree.rs index e1a7e4e6ab..8d84eba7ab 100644 --- a/crates/ide/src/view_item_tree.rs +++ b/crates/ide/src/view_item_tree.rs @@ -10,6 +10,9 @@ use ide_db::{FileId, RootDatabase}; // | VS Code | **rust-analyzer: Debug ItemTree** | pub(crate) fn view_item_tree(db: &RootDatabase, file_id: FileId) -> String { let sema = Semantics::new(db); + let Some(krate) = sema.first_crate(file_id) else { + return String::new(); + }; let file_id = sema.attach_first_edition(file_id); - db.file_item_tree(file_id.into()).pretty_print(db, file_id.edition(db)) + db.file_item_tree(file_id.into(), krate.into()).pretty_print(db, file_id.edition(db)) } diff --git a/crates/intern/Cargo.toml b/crates/intern/Cargo.toml index ad73c191c0..39320ebd1c 100644 --- a/crates/intern/Cargo.toml +++ b/crates/intern/Cargo.toml @@ -22,3 +22,6 @@ rayon.workspace = true [lints] workspace = true + +[features] +prevent-gc = [] diff --git a/crates/intern/src/gc.rs b/crates/intern/src/gc.rs index 937de26831..f4e8f75e71 100644 --- a/crates/intern/src/gc.rs +++ b/crates/intern/src/gc.rs @@ -110,6 +110,10 @@ impl GarbageCollector { /// the added storages must form a DAG. /// - [`GcInternedVisit`] and [`GcInternedSliceVisit`] must mark all values reachable from the node. pub unsafe fn collect(mut self) { + if cfg!(feature = "prevent-gc") { + return; + } + let total_nodes = self.storages.iter().map(|storage| storage.len()).sum(); self.alive.clear(); self.alive.reserve(total_nodes); diff --git a/crates/intern/src/symbol/symbols.rs b/crates/intern/src/symbol/symbols.rs index 2be4e41f4f..cc09a1aae7 100644 --- a/crates/intern/src/symbol/symbols.rs +++ b/crates/intern/src/symbol/symbols.rs @@ -110,6 +110,7 @@ define_symbols! { win64_dash_unwind = "win64-unwind", x86_dash_interrupt = "x86-interrupt", rust_dash_preserve_dash_none = "preserve-none", + _0_u8 = "0_u8", @PLAIN: __ra_fixup, @@ -285,6 +286,7 @@ define_symbols! { Into, into_future, into_iter, + into_try_type, IntoFuture, IntoIter, IntoIterator, diff --git a/crates/load-cargo/src/lib.rs b/crates/load-cargo/src/lib.rs index c2935d94a8..8753eab43a 100644 --- a/crates/load-cargo/src/lib.rs +++ b/crates/load-cargo/src/lib.rs @@ -26,10 +26,7 @@ use ide_db::{ use itertools::Itertools; use proc_macro_api::{ MacroDylib, ProcMacroClient, - bidirectional_protocol::{ - msg::{SubRequest, SubResponse}, - reject_subrequests, - }, + bidirectional_protocol::msg::{ParentSpan, SubRequest, SubResponse}, }; use project_model::{CargoConfig, PackageRoot, ProjectManifest, ProjectWorkspace}; use span::{Span, SpanAnchor, SyntaxContext}; @@ -45,6 +42,7 @@ pub struct LoadCargoConfig { pub load_out_dirs_from_check: bool, pub with_proc_macro_server: ProcMacroServerChoice, pub prefill_caches: bool, + pub num_worker_threads: usize, pub proc_macro_processes: usize, } @@ -200,7 +198,7 @@ pub fn load_workspace_into_db( ); if load_config.prefill_caches { - prime_caches::parallel_prime_caches(db, 1, &|_| ()); + prime_caches::parallel_prime_caches(db, load_config.num_worker_threads, &|_| ()); } Ok((vfs, proc_macro_server.and_then(Result::ok))) @@ -446,7 +444,7 @@ pub fn load_proc_macro( ) -> ProcMacroLoadResult { let res: Result<Vec<_>, _> = (|| { let dylib = MacroDylib::new(path.to_path_buf()); - let vec = server.load_dylib(dylib, Some(&reject_subrequests)).map_err(|e| { + let vec = server.load_dylib(dylib).map_err(|e| { ProcMacroLoadingError::ProcMacroSrvError(format!("{e}").into_boxed_str()) })?; if vec.is_empty() { @@ -615,6 +613,91 @@ impl ProcMacroExpander for Expander { Ok(SubResponse::ByteRangeResult { range: range.range.into() }) } + SubRequest::SpanSource { file_id, ast_id, start, end, ctx } => { + let span = Span { + range: TextRange::new(TextSize::from(start), TextSize::from(end)), + anchor: SpanAnchor { + file_id: span::EditionedFileId::from_raw(file_id), + ast_id: span::ErasedFileAstId::from_raw(ast_id), + }, + // SAFETY: We only receive spans from the server. If someone mess up the communication UB can happen, + // but that will be their problem. + ctx: unsafe { SyntaxContext::from_u32(ctx) }, + }; + + let mut current_span = span; + let mut current_ctx = span.ctx; + + while let Some(macro_call_id) = current_ctx.outer_expn(db) { + let macro_call_loc = db.lookup_intern_macro_call(macro_call_id.into()); + + let call_site_file = macro_call_loc.kind.file_id(); + + let resolved = db.resolve_span(current_span); + + current_ctx = macro_call_loc.ctxt; + current_span = Span { + range: resolved.range, + anchor: SpanAnchor { + file_id: resolved.file_id.span_file_id(db), + ast_id: span::ROOT_ERASED_FILE_AST_ID, + }, + ctx: current_ctx, + }; + + if call_site_file.file_id().is_some() { + break; + } + } + + let resolved = db.resolve_span(current_span); + + Ok(SubResponse::SpanSourceResult { + file_id: resolved.file_id.span_file_id(db).as_u32(), + ast_id: span::ROOT_ERASED_FILE_AST_ID.into_raw(), + start: u32::from(resolved.range.start()), + end: u32::from(resolved.range.end()), + ctx: current_span.ctx.into_u32(), + }) + } + SubRequest::SpanParent { file_id, ast_id, start, end, ctx } => { + let span = Span { + range: TextRange::new(TextSize::from(start), TextSize::from(end)), + anchor: SpanAnchor { + file_id: span::EditionedFileId::from_raw(file_id), + ast_id: span::ErasedFileAstId::from_raw(ast_id), + }, + // SAFETY: We only receive spans from the server. If someone mess up the communication UB can happen, + // but that will be their problem. + ctx: unsafe { SyntaxContext::from_u32(ctx) }, + }; + + if let Some(macro_call_id) = span.ctx.outer_expn(db) { + let macro_call_loc = db.lookup_intern_macro_call(macro_call_id.into()); + + let call_site_file = macro_call_loc.kind.file_id(); + let call_site_ast_id = macro_call_loc.kind.erased_ast_id(); + + if let Some(editioned_file_id) = call_site_file.file_id() { + let range = db + .ast_id_map(editioned_file_id.into()) + .get_erased(call_site_ast_id) + .text_range(); + + let parent_span = Some(ParentSpan { + file_id: editioned_file_id.span_file_id(db).as_u32(), + ast_id: span::ROOT_ERASED_FILE_AST_ID.into_raw(), + start: u32::from(range.start()), + end: u32::from(range.end()), + ctx: macro_call_loc.ctxt.into_u32(), + }); + + return Ok(SubResponse::SpanParentResult { parent_span }); + } + } + + Ok(SubResponse::SpanParentResult { parent_span: None }) + } }; match self.0.expand( subtree.view(), @@ -662,16 +745,26 @@ mod tests { #[test] fn test_loading_rust_analyzer() { - let path = Path::new(env!("CARGO_MANIFEST_DIR")).parent().unwrap().parent().unwrap(); + let cargo_toml_path = Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .parent() + .unwrap() + .join("Cargo.toml"); + let cargo_toml_path = AbsPathBuf::assert_utf8(cargo_toml_path); + let manifest = ProjectManifest::from_manifest_file(cargo_toml_path).unwrap(); + let cargo_config = CargoConfig { set_test: true, ..CargoConfig::default() }; let load_cargo_config = LoadCargoConfig { load_out_dirs_from_check: false, with_proc_macro_server: ProcMacroServerChoice::None, prefill_caches: false, + num_worker_threads: 1, proc_macro_processes: 1, }; + let workspace = ProjectWorkspace::load(manifest, &cargo_config, &|_| {}).unwrap(); let (db, _vfs, _proc_macro) = - load_workspace_at(path, &cargo_config, &load_cargo_config, &|_| {}).unwrap(); + load_workspace(workspace, &cargo_config.extra_env, &load_cargo_config).unwrap(); let n_crates = db.all_crates().len(); // RA has quite a few crates, but the exact count doesn't matter diff --git a/crates/parser/src/grammar/expressions/atom.rs b/crates/parser/src/grammar/expressions/atom.rs index d83e2eb2b4..3214fd90f2 100644 --- a/crates/parser/src/grammar/expressions/atom.rs +++ b/crates/parser/src/grammar/expressions/atom.rs @@ -407,8 +407,18 @@ pub(crate) fn parse_asm_expr(p: &mut Parser<'_>, m: Marker) -> Option<CompletedM op.complete(p, ASM_CONST); op_n.complete(p, ASM_OPERAND_NAMED); } else if p.eat_contextual_kw(T![sym]) { + // test asm_sym_paren + // fn foo() { + // builtin#asm("", f = sym (foo::bar)); + // } dir_spec.abandon(p); - paths::type_path(p); + if p.at(T!['(']) { + p.bump(T!['(']); + paths::type_path(p); + p.expect(T![')']); + } else { + paths::type_path(p); + } op.complete(p, ASM_SYM); op_n.complete(p, ASM_OPERAND_NAMED); } else if allow_templates { @@ -976,11 +986,17 @@ fn break_expr(p: &mut Parser<'_>, r: Restrictions) -> CompletedMarker { // test try_block_expr // fn foo() { // let _ = try {}; +// let _ = try bikeshed T<U> {}; // } fn try_block_expr(p: &mut Parser<'_>, m: Option<Marker>) -> CompletedMarker { assert!(p.at(T![try])); let m = m.unwrap_or_else(|| p.start()); + let try_modifier = p.start(); p.bump(T![try]); + if p.eat_contextual_kw(T![bikeshed]) { + type_(p); + } + try_modifier.complete(p, TRY_BLOCK_MODIFIER); if p.at(T!['{']) { stmt_list(p); } else { diff --git a/crates/parser/src/grammar/items/adt.rs b/crates/parser/src/grammar/items/adt.rs index a375696140..cfba4c3a77 100644 --- a/crates/parser/src/grammar/items/adt.rs +++ b/crates/parser/src/grammar/items/adt.rs @@ -96,7 +96,9 @@ pub(crate) fn variant_list(p: &mut Parser<'_>) { // test variant_discriminant // enum E { X(i32) = 10 } if p.eat(T![=]) { + let m = p.start(); expressions::expr(p); + m.complete(p, CONST_ARG); } m.complete(p, VARIANT); } else { @@ -139,7 +141,9 @@ pub(crate) fn record_field_list(p: &mut Parser<'_>) { // test record_field_default_values // struct S { f: f32 = 0.0 } if p.eat(T![=]) { + let m = p.start(); expressions::expr(p); + m.complete(p, CONST_ARG); } m.complete(p, RECORD_FIELD); } else { diff --git a/crates/parser/src/syntax_kind/generated.rs b/crates/parser/src/syntax_kind/generated.rs index 5d22d966b2..a2295e4495 100644 --- a/crates/parser/src/syntax_kind/generated.rs +++ b/crates/parser/src/syntax_kind/generated.rs @@ -114,6 +114,7 @@ pub enum SyntaxKind { ATT_SYNTAX_KW, AUTO_KW, AWAIT_KW, + BIKESHED_KW, BUILTIN_KW, CLOBBER_ABI_KW, DEFAULT_KW, @@ -285,6 +286,7 @@ pub enum SyntaxKind { STRUCT, TOKEN_TREE, TRAIT, + TRY_BLOCK_MODIFIER, TRY_EXPR, TUPLE_EXPR, TUPLE_FIELD, @@ -458,6 +460,7 @@ impl SyntaxKind { | STRUCT | TOKEN_TREE | TRAIT + | TRY_BLOCK_MODIFIER | TRY_EXPR | TUPLE_EXPR | TUPLE_FIELD @@ -596,6 +599,7 @@ impl SyntaxKind { ASM_KW => "asm", ATT_SYNTAX_KW => "att_syntax", AUTO_KW => "auto", + BIKESHED_KW => "bikeshed", BUILTIN_KW => "builtin", CLOBBER_ABI_KW => "clobber_abi", DEFAULT_KW => "default", @@ -698,6 +702,7 @@ impl SyntaxKind { ASM_KW => true, ATT_SYNTAX_KW => true, AUTO_KW => true, + BIKESHED_KW => true, BUILTIN_KW => true, CLOBBER_ABI_KW => true, DEFAULT_KW => true, @@ -788,6 +793,7 @@ impl SyntaxKind { ASM_KW => true, ATT_SYNTAX_KW => true, AUTO_KW => true, + BIKESHED_KW => true, BUILTIN_KW => true, CLOBBER_ABI_KW => true, DEFAULT_KW => true, @@ -941,6 +947,7 @@ impl SyntaxKind { "asm" => ASM_KW, "att_syntax" => ATT_SYNTAX_KW, "auto" => AUTO_KW, + "bikeshed" => BIKESHED_KW, "builtin" => BUILTIN_KW, "clobber_abi" => CLOBBER_ABI_KW, "default" => DEFAULT_KW, @@ -1112,6 +1119,7 @@ macro_rules ! T_ { [asm] => { $ crate :: SyntaxKind :: ASM_KW }; [att_syntax] => { $ crate :: SyntaxKind :: ATT_SYNTAX_KW }; [auto] => { $ crate :: SyntaxKind :: AUTO_KW }; + [bikeshed] => { $ crate :: SyntaxKind :: BIKESHED_KW }; [builtin] => { $ crate :: SyntaxKind :: BUILTIN_KW }; [clobber_abi] => { $ crate :: SyntaxKind :: CLOBBER_ABI_KW }; [default] => { $ crate :: SyntaxKind :: DEFAULT_KW }; diff --git a/crates/parser/test_data/generated/runner.rs b/crates/parser/test_data/generated/runner.rs index 9f919f6cea..4c001104fe 100644 --- a/crates/parser/test_data/generated/runner.rs +++ b/crates/parser/test_data/generated/runner.rs @@ -25,6 +25,8 @@ mod ok { #[test] fn asm_label() { run_and_expect_no_errors("test_data/parser/inline/ok/asm_label.rs"); } #[test] + fn asm_sym_paren() { run_and_expect_no_errors("test_data/parser/inline/ok/asm_sym_paren.rs"); } + #[test] fn assoc_const_eq() { run_and_expect_no_errors("test_data/parser/inline/ok/assoc_const_eq.rs"); } diff --git a/crates/parser/test_data/parser/err/0042_weird_blocks.rast b/crates/parser/test_data/parser/err/0042_weird_blocks.rast index d6d2e75cca..9e4e9dbf9d 100644 --- a/crates/parser/test_data/parser/err/0042_weird_blocks.rast +++ b/crates/parser/test_data/parser/err/0042_weird_blocks.rast @@ -45,7 +45,8 @@ SOURCE_FILE WHITESPACE " " EXPR_STMT BLOCK_EXPR - TRY_KW "try" + TRY_BLOCK_MODIFIER + TRY_KW "try" WHITESPACE " " LITERAL INT_NUMBER "92" diff --git a/crates/parser/test_data/parser/inline/ok/asm_sym_paren.rast b/crates/parser/test_data/parser/inline/ok/asm_sym_paren.rast new file mode 100644 index 0000000000..d189f63f2a --- /dev/null +++ b/crates/parser/test_data/parser/inline/ok/asm_sym_paren.rast @@ -0,0 +1,49 @@ +SOURCE_FILE + FN + FN_KW "fn" + WHITESPACE " " + NAME + IDENT "foo" + PARAM_LIST + L_PAREN "(" + R_PAREN ")" + WHITESPACE " " + BLOCK_EXPR + STMT_LIST + L_CURLY "{" + WHITESPACE "\n " + EXPR_STMT + ASM_EXPR + BUILTIN_KW "builtin" + POUND "#" + ASM_KW "asm" + L_PAREN "(" + LITERAL + STRING "\"\"" + COMMA "," + WHITESPACE " " + ASM_OPERAND_NAMED + NAME + IDENT "f" + WHITESPACE " " + EQ "=" + WHITESPACE " " + ASM_SYM + SYM_KW "sym" + WHITESPACE " " + L_PAREN "(" + PATH + PATH + PATH_SEGMENT + NAME_REF + IDENT "foo" + COLON2 "::" + PATH_SEGMENT + NAME_REF + IDENT "bar" + R_PAREN ")" + R_PAREN ")" + SEMICOLON ";" + WHITESPACE "\n" + R_CURLY "}" + WHITESPACE "\n" diff --git a/crates/parser/test_data/parser/inline/ok/asm_sym_paren.rs b/crates/parser/test_data/parser/inline/ok/asm_sym_paren.rs new file mode 100644 index 0000000000..7b2f80704c --- /dev/null +++ b/crates/parser/test_data/parser/inline/ok/asm_sym_paren.rs @@ -0,0 +1,3 @@ +fn foo() { + builtin#asm("", f = sym (foo::bar)); +} diff --git a/crates/parser/test_data/parser/inline/ok/record_field_default_values.rast b/crates/parser/test_data/parser/inline/ok/record_field_default_values.rast index 33088f2cab..e53b886bbf 100644 --- a/crates/parser/test_data/parser/inline/ok/record_field_default_values.rast +++ b/crates/parser/test_data/parser/inline/ok/record_field_default_values.rast @@ -21,8 +21,9 @@ SOURCE_FILE WHITESPACE " " EQ "=" WHITESPACE " " - LITERAL - FLOAT_NUMBER "0.0" + CONST_ARG + LITERAL + FLOAT_NUMBER "0.0" WHITESPACE " " R_CURLY "}" WHITESPACE "\n" diff --git a/crates/parser/test_data/parser/inline/ok/try_block_expr.rast b/crates/parser/test_data/parser/inline/ok/try_block_expr.rast index aec8fbf477..472ce711c5 100644 --- a/crates/parser/test_data/parser/inline/ok/try_block_expr.rast +++ b/crates/parser/test_data/parser/inline/ok/try_block_expr.rast @@ -21,7 +21,42 @@ SOURCE_FILE EQ "=" WHITESPACE " " BLOCK_EXPR - TRY_KW "try" + TRY_BLOCK_MODIFIER + TRY_KW "try" + WHITESPACE " " + STMT_LIST + L_CURLY "{" + R_CURLY "}" + SEMICOLON ";" + WHITESPACE "\n " + LET_STMT + LET_KW "let" + WHITESPACE " " + WILDCARD_PAT + UNDERSCORE "_" + WHITESPACE " " + EQ "=" + WHITESPACE " " + BLOCK_EXPR + TRY_BLOCK_MODIFIER + TRY_KW "try" + WHITESPACE " " + BIKESHED_KW "bikeshed" + WHITESPACE " " + PATH_TYPE + PATH + PATH_SEGMENT + NAME_REF + IDENT "T" + GENERIC_ARG_LIST + L_ANGLE "<" + TYPE_ARG + PATH_TYPE + PATH + PATH_SEGMENT + NAME_REF + IDENT "U" + R_ANGLE ">" WHITESPACE " " STMT_LIST L_CURLY "{" diff --git a/crates/parser/test_data/parser/inline/ok/try_block_expr.rs b/crates/parser/test_data/parser/inline/ok/try_block_expr.rs index 0f1b41eb64..719980473c 100644 --- a/crates/parser/test_data/parser/inline/ok/try_block_expr.rs +++ b/crates/parser/test_data/parser/inline/ok/try_block_expr.rs @@ -1,3 +1,4 @@ fn foo() { let _ = try {}; + let _ = try bikeshed T<U> {}; } diff --git a/crates/parser/test_data/parser/inline/ok/variant_discriminant.rast b/crates/parser/test_data/parser/inline/ok/variant_discriminant.rast index 9f0c5a7610..3494085e88 100644 --- a/crates/parser/test_data/parser/inline/ok/variant_discriminant.rast +++ b/crates/parser/test_data/parser/inline/ok/variant_discriminant.rast @@ -23,8 +23,9 @@ SOURCE_FILE WHITESPACE " " EQ "=" WHITESPACE " " - LITERAL - INT_NUMBER "10" + CONST_ARG + LITERAL + INT_NUMBER "10" WHITESPACE " " R_CURLY "}" WHITESPACE "\n" diff --git a/crates/parser/test_data/parser/ok/0019_enums.rast b/crates/parser/test_data/parser/ok/0019_enums.rast index dd47e3aa47..51837e5372 100644 --- a/crates/parser/test_data/parser/ok/0019_enums.rast +++ b/crates/parser/test_data/parser/ok/0019_enums.rast @@ -78,8 +78,9 @@ SOURCE_FILE WHITESPACE " " EQ "=" WHITESPACE " " - LITERAL - INT_NUMBER "92" + CONST_ARG + LITERAL + INT_NUMBER "92" COMMA "," WHITESPACE "\n " VARIANT diff --git a/crates/proc-macro-api/Cargo.toml b/crates/proc-macro-api/Cargo.toml index 4de1a3e5dd..a135a469e8 100644 --- a/crates/proc-macro-api/Cargo.toml +++ b/crates/proc-macro-api/Cargo.toml @@ -31,6 +31,7 @@ span = { path = "../span", version = "0.0.0", default-features = false} intern.workspace = true postcard.workspace = true semver.workspace = true +rayon.workspace = true [features] sysroot-abi = ["proc-macro-srv", "proc-macro-srv/sysroot-abi"] diff --git a/crates/proc-macro-api/src/bidirectional_protocol/msg.rs b/crates/proc-macro-api/src/bidirectional_protocol/msg.rs index 3f0422dc5b..ab4bed81e6 100644 --- a/crates/proc-macro-api/src/bidirectional_protocol/msg.rs +++ b/crates/proc-macro-api/src/bidirectional_protocol/msg.rs @@ -21,6 +21,8 @@ pub enum SubRequest { LocalFilePath { file_id: u32 }, LineColumn { file_id: u32, ast_id: u32, offset: u32 }, ByteRange { file_id: u32, ast_id: u32, start: u32, end: u32 }, + SpanSource { file_id: u32, ast_id: u32, start: u32, end: u32, ctx: u32 }, + SpanParent { file_id: u32, ast_id: u32, start: u32, end: u32, ctx: u32 }, } #[derive(Debug, Serialize, Deserialize)] @@ -42,12 +44,31 @@ pub enum SubResponse { ByteRangeResult { range: Range<usize>, }, + SpanSourceResult { + file_id: u32, + ast_id: u32, + start: u32, + end: u32, + ctx: u32, + }, + SpanParentResult { + parent_span: Option<ParentSpan>, + }, Cancel { reason: String, }, } #[derive(Debug, Serialize, Deserialize)] +pub struct ParentSpan { + pub file_id: u32, + pub ast_id: u32, + pub start: u32, + pub end: u32, + pub ctx: u32, +} + +#[derive(Debug, Serialize, Deserialize)] pub enum BidirectionalMessage { Request(Request), Response(Response), diff --git a/crates/proc-macro-api/src/lib.rs b/crates/proc-macro-api/src/lib.rs index e4b121b033..e83ddb8594 100644 --- a/crates/proc-macro-api/src/lib.rs +++ b/crates/proc-macro-api/src/lib.rs @@ -10,7 +10,7 @@ feature = "sysroot-abi", feature(proc_macro_internals, proc_macro_diagnostic, proc_macro_span) )] -#![allow(internal_features)] +#![allow(internal_features, unused_features)] #![cfg_attr(feature = "in-rust-tree", feature(rustc_private))] #[cfg(feature = "in-rust-tree")] @@ -198,12 +198,8 @@ impl ProcMacroClient { } /// Loads a proc-macro dylib into the server process returning a list of `ProcMacro`s loaded. - pub fn load_dylib( - &self, - dylib: MacroDylib, - callback: Option<SubCallback<'_>>, - ) -> Result<Vec<ProcMacro>, ServerError> { - self.pool.load_dylib(&dylib, callback) + pub fn load_dylib(&self, dylib: MacroDylib) -> Result<Vec<ProcMacro>, ServerError> { + self.pool.load_dylib(&dylib) } /// Checks if the proc-macro server has exited. diff --git a/crates/proc-macro-api/src/pool.rs b/crates/proc-macro-api/src/pool.rs index a637bc0e48..e6541823da 100644 --- a/crates/proc-macro-api/src/pool.rs +++ b/crates/proc-macro-api/src/pool.rs @@ -1,10 +1,9 @@ //! A pool of proc-macro server processes use std::sync::Arc; -use crate::{ - MacroDylib, ProcMacro, ServerError, bidirectional_protocol::SubCallback, - process::ProcMacroServerProcess, -}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; + +use crate::{MacroDylib, ProcMacro, ServerError, process::ProcMacroServerProcess}; #[derive(Debug, Clone)] pub(crate) struct ProcMacroServerPool { @@ -50,11 +49,7 @@ impl ProcMacroServerPool { }) } - pub(crate) fn load_dylib( - &self, - dylib: &MacroDylib, - callback: Option<SubCallback<'_>>, - ) -> Result<Vec<ProcMacro>, ServerError> { + pub(crate) fn load_dylib(&self, dylib: &MacroDylib) -> Result<Vec<ProcMacro>, ServerError> { let _span = tracing::info_span!("ProcMacroServer::load_dylib").entered(); let dylib_path = Arc::new(dylib.path.clone()); @@ -64,14 +59,17 @@ impl ProcMacroServerPool { let (first, rest) = self.workers.split_first().expect("worker pool must not be empty"); let macros = first - .find_proc_macros(&dylib.path, callback)? + .find_proc_macros(&dylib.path)? .map_err(|e| ServerError { message: e, io: None })?; - for worker in rest { - worker - .find_proc_macros(&dylib.path, callback)? - .map_err(|e| ServerError { message: e, io: None })?; - } + rest.into_par_iter() + .map(|worker| { + worker + .find_proc_macros(&dylib.path)? + .map(|_| ()) + .map_err(|e| ServerError { message: e, io: None }) + }) + .collect::<Result<(), _>>()?; Ok(macros .into_iter() diff --git a/crates/proc-macro-api/src/process.rs b/crates/proc-macro-api/src/process.rs index 9f80880965..80e4ed05c3 100644 --- a/crates/proc-macro-api/src/process.rs +++ b/crates/proc-macro-api/src/process.rs @@ -18,7 +18,11 @@ use stdx::JodChild; use crate::{ ProcMacro, ProcMacroKind, ProtocolFormat, ServerError, - bidirectional_protocol::{self, SubCallback, msg::BidirectionalMessage, reject_subrequests}, + bidirectional_protocol::{ + self, SubCallback, + msg::{BidirectionalMessage, SubResponse}, + reject_subrequests, + }, legacy_protocol::{self, SpanMode}, version, }; @@ -207,14 +211,18 @@ impl ProcMacroServerProcess { pub(crate) fn find_proc_macros( &self, dylib_path: &AbsPath, - callback: Option<SubCallback<'_>>, ) -> Result<Result<Vec<(String, ProcMacroKind)>, String>, ServerError> { match self.protocol { Protocol::LegacyJson { .. } => legacy_protocol::find_proc_macros(self, dylib_path), Protocol::BidirectionalPostcardPrototype { .. } => { - let cb = callback.expect("callback required for bidirectional protocol"); - bidirectional_protocol::find_proc_macros(self, dylib_path, cb) + bidirectional_protocol::find_proc_macros(self, dylib_path, &|_| { + Ok(SubResponse::Cancel { + reason: String::from( + "Server should not do a sub request when loading proc-macros", + ), + }) + }) } } } diff --git a/crates/proc-macro-srv-cli/src/main_loop.rs b/crates/proc-macro-srv-cli/src/main_loop.rs index 9be3199a38..c525ed848b 100644 --- a/crates/proc-macro-srv-cli/src/main_loop.rs +++ b/crates/proc-macro-srv-cli/src/main_loop.rs @@ -273,6 +273,76 @@ impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> { other => handle_failure(other), } } + + fn span_source( + &mut self, + proc_macro_srv::span::Span { range, anchor, ctx }: proc_macro_srv::span::Span, + ) -> proc_macro_srv::span::Span { + match self.roundtrip(bidirectional::SubRequest::SpanSource { + file_id: anchor.file_id.as_u32(), + ast_id: anchor.ast_id.into_raw(), + start: range.start().into(), + end: range.end().into(), + ctx: ctx.into_u32(), + }) { + Ok(bidirectional::SubResponse::SpanSourceResult { + file_id, + ast_id, + start, + end, + ctx, + }) => { + proc_macro_srv::span::Span { + range: proc_macro_srv::span::TextRange::new( + proc_macro_srv::span::TextSize::new(start), + proc_macro_srv::span::TextSize::new(end), + ), + anchor: proc_macro_srv::span::SpanAnchor { + file_id: proc_macro_srv::span::EditionedFileId::from_raw(file_id), + ast_id: proc_macro_srv::span::ErasedFileAstId::from_raw(ast_id), + }, + // SAFETY: We only receive spans from the server. If someone mess up the communication UB can happen, + // but that will be their problem. + ctx: unsafe { proc_macro_srv::span::SyntaxContext::from_u32(ctx) }, + } + } + other => handle_failure(other), + } + } + + fn span_parent( + &mut self, + proc_macro_srv::span::Span { range, anchor, ctx }: proc_macro_srv::span::Span, + ) -> Option<proc_macro_srv::span::Span> { + let response = self.roundtrip(bidirectional::SubRequest::SpanParent { + file_id: anchor.file_id.as_u32(), + ast_id: anchor.ast_id.into_raw(), + start: range.start().into(), + end: range.end().into(), + ctx: ctx.into_u32(), + }); + + match response { + Ok(bidirectional::SubResponse::SpanParentResult { parent_span }) => { + parent_span.map(|bidirectional::ParentSpan { file_id, ast_id, start, end, ctx }| { + proc_macro_srv::span::Span { + range: proc_macro_srv::span::TextRange::new( + proc_macro_srv::span::TextSize::new(start), + proc_macro_srv::span::TextSize::new(end), + ), + anchor: proc_macro_srv::span::SpanAnchor { + file_id: proc_macro_srv::span::EditionedFileId::from_raw(file_id), + ast_id: proc_macro_srv::span::ErasedFileAstId::from_raw(ast_id), + }, + // SAFETY: spans originate from the server. If the protocol is violated, + // undefined behavior is the caller’s responsibility. + ctx: unsafe { proc_macro_srv::span::SyntaxContext::from_u32(ctx) }, + } + }) + } + other => handle_failure(other), + } + } } fn handle_expand_ra( diff --git a/crates/proc-macro-srv/src/dylib/proc_macros.rs b/crates/proc-macro-srv/src/dylib/proc_macros.rs index 76c5097101..4065dbd0b4 100644 --- a/crates/proc-macro-srv/src/dylib/proc_macros.rs +++ b/crates/proc-macro-srv/src/dylib/proc_macros.rs @@ -30,7 +30,7 @@ impl ProcMacros { if *trait_name == macro_name => { let res = client.run( - &bridge::server::SameThread, + &bridge::server::SAME_THREAD, S::make_server(call_site, def_site, mixed_site, callback), macro_body, cfg!(debug_assertions), @@ -39,7 +39,7 @@ impl ProcMacros { } bridge::client::ProcMacro::Bang { name, client } if *name == macro_name => { let res = client.run( - &bridge::server::SameThread, + &bridge::server::SAME_THREAD, S::make_server(call_site, def_site, mixed_site, callback), macro_body, cfg!(debug_assertions), @@ -48,7 +48,7 @@ impl ProcMacros { } bridge::client::ProcMacro::Attr { name, client } if *name == macro_name => { let res = client.run( - &bridge::server::SameThread, + &bridge::server::SAME_THREAD, S::make_server(call_site, def_site, mixed_site, callback), parsed_attributes, macro_body, diff --git a/crates/proc-macro-srv/src/lib.rs b/crates/proc-macro-srv/src/lib.rs index c548dc620a..0bdc379cb6 100644 --- a/crates/proc-macro-srv/src/lib.rs +++ b/crates/proc-macro-srv/src/lib.rs @@ -18,7 +18,8 @@ internal_features, clippy::disallowed_types, clippy::print_stderr, - unused_crate_dependencies + unused_crate_dependencies, + unused_features )] #![deny(deprecated_safe, clippy::undocumented_unsafe_blocks)] @@ -120,6 +121,8 @@ pub trait ProcMacroClientInterface { fn line_column(&mut self, span: Span) -> Option<(u32, u32)>; fn byte_range(&mut self, span: Span) -> Range<usize>; + fn span_source(&mut self, span: Span) -> Span; + fn span_parent(&mut self, span: Span) -> Option<Span>; } const EXPANDER_STACK_SIZE: usize = 8 * 1024 * 1024; @@ -325,7 +328,7 @@ impl<'snap> EnvChange<'snap> { let prev_working_dir = std::env::current_dir().ok(); if let Err(err) = std::env::set_current_dir(dir) { eprintln!( - "Failed to set the current working dir to {}. Error: {err:?}", + "Failed to change the current working dir to {}. Error: {err:?}", dir.display() ) } @@ -367,7 +370,7 @@ impl Drop for EnvChange<'_> { && let Err(err) = std::env::set_current_dir(dir) { eprintln!( - "Failed to set the current working dir to {}. Error: {:?}", + "Failed to change the current working dir back to {}. Error: {:?}", dir.display(), err ) diff --git a/crates/proc-macro-srv/src/server_impl/rust_analyzer_span.rs b/crates/proc-macro-srv/src/server_impl/rust_analyzer_span.rs index c114d52ec3..6b6bfcc934 100644 --- a/crates/proc-macro-srv/src/server_impl/rust_analyzer_span.rs +++ b/crates/proc-macro-srv/src/server_impl/rust_analyzer_span.rs @@ -62,8 +62,9 @@ impl server::Server for RaSpanServer<'_> { self.tracked_paths.insert(path.into()); } - fn literal_from_str(&mut self, s: &str) -> Result<Literal<Self::Span>, ()> { + fn literal_from_str(&mut self, s: &str) -> Result<Literal<Self::Span>, String> { literal_from_str(s, self.call_site) + .map_err(|()| "cannot parse string into literal".to_string()) } fn emit_diagnostic(&mut self, _: Diagnostic<Self::Span>) { @@ -81,14 +82,9 @@ impl server::Server for RaSpanServer<'_> { fn ts_is_empty(&mut self, stream: &Self::TokenStream) -> bool { stream.is_empty() } - fn ts_from_str(&mut self, src: &str) -> Self::TokenStream { - Self::TokenStream::from_str(src, self.call_site).unwrap_or_else(|e| { - Self::TokenStream::from_str( - &format!("compile_error!(\"failed to parse str to token stream: {e}\")"), - self.call_site, - ) - .unwrap() - }) + fn ts_from_str(&mut self, src: &str) -> Result<Self::TokenStream, String> { + Self::TokenStream::from_str(src, self.call_site) + .map_err(|e| format!("failed to parse str to token stream: {e}")) } fn ts_to_string(&mut self, stream: &Self::TokenStream) -> String { stream.to_string() @@ -168,12 +164,16 @@ impl server::Server for RaSpanServer<'_> { self.callback.as_mut()?.source_text(span) } - fn span_parent(&mut self, _span: Self::Span) -> Option<Self::Span> { - // FIXME requires db, looks up the parent call site + fn span_parent(&mut self, span: Self::Span) -> Option<Self::Span> { + if let Some(ref mut callback) = self.callback { + return callback.span_parent(span); + } None } fn span_source(&mut self, span: Self::Span) -> Self::Span { - // FIXME requires db, returns the top level call site + if let Some(ref mut callback) = self.callback { + return callback.span_source(span); + } span } fn span_byte_range(&mut self, span: Self::Span) -> Range<usize> { diff --git a/crates/proc-macro-srv/src/server_impl/token_id.rs b/crates/proc-macro-srv/src/server_impl/token_id.rs index 70484c4dc2..e1c96095c8 100644 --- a/crates/proc-macro-srv/src/server_impl/token_id.rs +++ b/crates/proc-macro-srv/src/server_impl/token_id.rs @@ -67,8 +67,9 @@ impl server::Server for SpanIdServer<'_> { self.tracked_paths.insert(path.into()); } - fn literal_from_str(&mut self, s: &str) -> Result<Literal<Self::Span>, ()> { + fn literal_from_str(&mut self, s: &str) -> Result<Literal<Self::Span>, String> { literal_from_str(s, self.call_site) + .map_err(|()| "cannot parse string into literal".to_string()) } fn emit_diagnostic(&mut self, _: Diagnostic<Self::Span>) {} @@ -84,14 +85,9 @@ impl server::Server for SpanIdServer<'_> { fn ts_is_empty(&mut self, stream: &Self::TokenStream) -> bool { stream.is_empty() } - fn ts_from_str(&mut self, src: &str) -> Self::TokenStream { - Self::TokenStream::from_str(src, self.call_site).unwrap_or_else(|e| { - Self::TokenStream::from_str( - &format!("compile_error!(\"failed to parse str to token stream: {e}\")"), - self.call_site, - ) - .unwrap() - }) + fn ts_from_str(&mut self, src: &str) -> Result<Self::TokenStream, String> { + Self::TokenStream::from_str(src, self.call_site) + .map_err(|e| format!("failed to parse str to token stream: {e}")) } fn ts_to_string(&mut self, stream: &Self::TokenStream) -> String { stream.to_string() diff --git a/crates/proc-macro-srv/src/tests/utils.rs b/crates/proc-macro-srv/src/tests/utils.rs index b7c5c4fdd2..31beca20d6 100644 --- a/crates/proc-macro-srv/src/tests/utils.rs +++ b/crates/proc-macro-srv/src/tests/utils.rs @@ -142,6 +142,14 @@ impl ProcMacroClientInterface for MockCallback<'_> { fn byte_range(&mut self, span: Span) -> Range<usize> { Range { start: span.range.start().into(), end: span.range.end().into() } } + + fn span_source(&mut self, span: Span) -> Span { + span + } + + fn span_parent(&mut self, _span: Span) -> Option<Span> { + None + } } pub fn assert_expand_with_callback( diff --git a/crates/profile/src/google_cpu_profiler.rs b/crates/profile/src/google_cpu_profiler.rs index cae6caeaa6..d77c945f26 100644 --- a/crates/profile/src/google_cpu_profiler.rs +++ b/crates/profile/src/google_cpu_profiler.rs @@ -9,7 +9,7 @@ use std::{ #[link(name = "profiler")] #[allow(non_snake_case)] -extern "C" { +unsafe extern "C" { fn ProfilerStart(fname: *const c_char) -> i32; fn ProfilerStop(); } diff --git a/crates/project-model/src/build_dependencies.rs b/crates/project-model/src/build_dependencies.rs index fedc6944f5..aff5391697 100644 --- a/crates/project-model/src/build_dependencies.rs +++ b/crates/project-model/src/build_dependencies.rs @@ -22,8 +22,9 @@ use triomphe::Arc; use crate::{ CargoConfig, CargoFeatures, CargoWorkspace, InvocationStrategy, ManifestPath, Package, Sysroot, - TargetKind, cargo_config_file::make_lockfile_copy, - cargo_workspace::MINIMUM_TOOLCHAIN_VERSION_SUPPORTING_LOCKFILE_PATH, utf8_stdout, + TargetKind, + cargo_config_file::{LockfileCopy, LockfileUsage, make_lockfile_copy}, + utf8_stdout, }; /// Output of the build script and proc-macro building steps for a workspace. @@ -436,7 +437,7 @@ impl WorkspaceBuildScripts { current_dir: &AbsPath, sysroot: &Sysroot, toolchain: Option<&semver::Version>, - ) -> io::Result<(Option<temp_dir::TempDir>, Command)> { + ) -> io::Result<(Option<LockfileCopy>, Command)> { match config.run_build_script_command.as_deref() { Some([program, args @ ..]) => { let mut cmd = toolchain::command(program, current_dir, &config.extra_env); @@ -461,17 +462,26 @@ impl WorkspaceBuildScripts { if let Some(target) = &config.target { cmd.args(["--target", target]); } - let mut temp_dir_guard = None; - if toolchain - .is_some_and(|v| *v >= MINIMUM_TOOLCHAIN_VERSION_SUPPORTING_LOCKFILE_PATH) - { + let mut lockfile_copy = None; + if let Some(toolchain) = toolchain { let lockfile_path = <_ as AsRef<Utf8Path>>::as_ref(manifest_path).with_extension("lock"); - if let Some((temp_dir, target_lockfile)) = make_lockfile_copy(&lockfile_path) { + lockfile_copy = make_lockfile_copy(toolchain, &lockfile_path); + if let Some(lockfile_copy) = &lockfile_copy { requires_unstable_options = true; - temp_dir_guard = Some(temp_dir); - cmd.arg("--lockfile-path"); - cmd.arg(target_lockfile.as_str()); + match lockfile_copy.usage { + LockfileUsage::WithFlag => { + cmd.arg("--lockfile-path"); + cmd.arg(lockfile_copy.path.as_str()); + } + LockfileUsage::WithEnvVar => { + cmd.arg("-Zlockfile-path"); + cmd.env( + "CARGO_RESOLVER_LOCKFILE_PATH", + lockfile_copy.path.as_os_str(), + ); + } + } } } match &config.features { @@ -542,7 +552,7 @@ impl WorkspaceBuildScripts { cmd.env("__CARGO_TEST_CHANNEL_OVERRIDE_DO_NOT_USE_THIS", "nightly"); cmd.arg("-Zunstable-options"); } - Ok((temp_dir_guard, cmd)) + Ok((lockfile_copy, cmd)) } } } diff --git a/crates/project-model/src/cargo_config_file.rs b/crates/project-model/src/cargo_config_file.rs index 5d6e5fd648..9ce88484f7 100644 --- a/crates/project-model/src/cargo_config_file.rs +++ b/crates/project-model/src/cargo_config_file.rs @@ -132,25 +132,66 @@ impl<'a> CargoConfigFileReader<'a> { } } +pub(crate) struct LockfileCopy { + pub(crate) path: Utf8PathBuf, + pub(crate) usage: LockfileUsage, + _temp_dir: temp_dir::TempDir, +} + +pub(crate) enum LockfileUsage { + /// Rust [1.82.0, 1.95.0). `cargo <subcmd> --lockfile-path <lockfile path>` + WithFlag, + /// Rust >= 1.95.0. `CARGO_RESOLVER_LOCKFILE_PATH=<lockfile path> cargo <subcmd>` + WithEnvVar, +} + pub(crate) fn make_lockfile_copy( + toolchain_version: &semver::Version, lockfile_path: &Utf8Path, -) -> Option<(temp_dir::TempDir, Utf8PathBuf)> { +) -> Option<LockfileCopy> { + const MINIMUM_TOOLCHAIN_VERSION_SUPPORTING_LOCKFILE_PATH_FLAG: semver::Version = + semver::Version { + major: 1, + minor: 82, + patch: 0, + pre: semver::Prerelease::EMPTY, + build: semver::BuildMetadata::EMPTY, + }; + + // TODO: turn this into a const and remove pre once 1.95 is stable + #[allow(non_snake_case)] + let MINIMUM_TOOLCHAIN_VERSION_SUPPORTING_LOCKFILE_PATH_ENV: semver::Version = semver::Version { + major: 1, + minor: 95, + patch: 0, + pre: semver::Prerelease::new("beta").unwrap(), + build: semver::BuildMetadata::EMPTY, + }; + + let usage = if *toolchain_version >= MINIMUM_TOOLCHAIN_VERSION_SUPPORTING_LOCKFILE_PATH_ENV { + LockfileUsage::WithEnvVar + } else if *toolchain_version >= MINIMUM_TOOLCHAIN_VERSION_SUPPORTING_LOCKFILE_PATH_FLAG { + LockfileUsage::WithFlag + } else { + return None; + }; + let temp_dir = temp_dir::TempDir::with_prefix("rust-analyzer").ok()?; - let target_lockfile = temp_dir.path().join("Cargo.lock").try_into().ok()?; - match std::fs::copy(lockfile_path, &target_lockfile) { + let path: Utf8PathBuf = temp_dir.path().join("Cargo.lock").try_into().ok()?; + let path = match std::fs::copy(lockfile_path, &path) { Ok(_) => { - tracing::debug!("Copied lock file from `{}` to `{}`", lockfile_path, target_lockfile); - Some((temp_dir, target_lockfile)) + tracing::debug!("Copied lock file from `{}` to `{}`", lockfile_path, path); + path } // lockfile does not yet exist, so we can just create a new one in the temp dir - Err(e) if e.kind() == std::io::ErrorKind::NotFound => Some((temp_dir, target_lockfile)), + Err(e) if e.kind() == std::io::ErrorKind::NotFound => path, Err(e) => { - tracing::warn!( - "Failed to copy lock file from `{lockfile_path}` to `{target_lockfile}`: {e}", - ); - None + tracing::warn!("Failed to copy lock file from `{lockfile_path}` to `{path}`: {e}",); + return None; } - } + }; + + Some(LockfileCopy { path, usage, _temp_dir: temp_dir }) } #[test] diff --git a/crates/project-model/src/cargo_workspace.rs b/crates/project-model/src/cargo_workspace.rs index 483ab28450..792206b74f 100644 --- a/crates/project-model/src/cargo_workspace.rs +++ b/crates/project-model/src/cargo_workspace.rs @@ -16,18 +16,10 @@ use toolchain::{NO_RUSTUP_AUTO_INSTALL_ENV, Tool}; use triomphe::Arc; use crate::{ - CfgOverrides, InvocationStrategy, ManifestPath, Sysroot, cargo_config_file::make_lockfile_copy, + CfgOverrides, InvocationStrategy, ManifestPath, Sysroot, + cargo_config_file::{LockfileCopy, LockfileUsage, make_lockfile_copy}, }; -pub(crate) const MINIMUM_TOOLCHAIN_VERSION_SUPPORTING_LOCKFILE_PATH: semver::Version = - semver::Version { - major: 1, - minor: 82, - patch: 0, - pre: semver::Prerelease::EMPTY, - build: semver::BuildMetadata::EMPTY, - }; - /// [`CargoWorkspace`] represents the logical structure of, well, a Cargo /// workspace. It pretty closely mirrors `cargo metadata` output. /// @@ -628,7 +620,7 @@ pub(crate) struct FetchMetadata { command: cargo_metadata::MetadataCommand, #[expect(dead_code)] manifest_path: ManifestPath, - lockfile_path: Option<Utf8PathBuf>, + lockfile_copy: Option<LockfileCopy>, #[expect(dead_code)] kind: &'static str, no_deps: bool, @@ -688,15 +680,14 @@ impl FetchMetadata { } } - let mut lockfile_path = None; + let mut lockfile_copy = None; if cargo_toml.is_rust_manifest() { other_options.push("-Zscript".to_owned()); - } else if config - .toolchain_version - .as_ref() - .is_some_and(|v| *v >= MINIMUM_TOOLCHAIN_VERSION_SUPPORTING_LOCKFILE_PATH) - { - lockfile_path = Some(<_ as AsRef<Utf8Path>>::as_ref(cargo_toml).with_extension("lock")); + } else if let Some(v) = config.toolchain_version.as_ref() { + lockfile_copy = make_lockfile_copy( + v, + &<_ as AsRef<Utf8Path>>::as_ref(cargo_toml).with_extension("lock"), + ); } if !config.targets.is_empty() { @@ -729,7 +720,7 @@ impl FetchMetadata { Self { manifest_path: cargo_toml.clone(), command, - lockfile_path, + lockfile_copy, kind: config.kind, no_deps, no_deps_result, @@ -749,7 +740,7 @@ impl FetchMetadata { let Self { mut command, manifest_path: _, - lockfile_path, + lockfile_copy, kind: _, no_deps, no_deps_result, @@ -761,13 +752,17 @@ impl FetchMetadata { } let mut using_lockfile_copy = false; - let mut _temp_dir_guard; - if let Some(lockfile) = lockfile_path - && let Some((temp_dir, target_lockfile)) = make_lockfile_copy(&lockfile) - { - _temp_dir_guard = temp_dir; - other_options.push("--lockfile-path".to_owned()); - other_options.push(target_lockfile.to_string()); + if let Some(lockfile_copy) = &lockfile_copy { + match lockfile_copy.usage { + LockfileUsage::WithFlag => { + other_options.push("--lockfile-path".to_owned()); + other_options.push(lockfile_copy.path.to_string()); + } + LockfileUsage::WithEnvVar => { + other_options.push("-Zlockfile-path".to_owned()); + command.env("CARGO_RESOLVER_LOCKFILE_PATH", lockfile_copy.path.as_os_str()); + } + } using_lockfile_copy = true; } if using_lockfile_copy || other_options.iter().any(|it| it.starts_with("-Z")) { diff --git a/crates/project-model/src/project_json.rs b/crates/project-model/src/project_json.rs index 6938010cbd..4ea136afbb 100644 --- a/crates/project-model/src/project_json.rs +++ b/crates/project-model/src/project_json.rs @@ -365,9 +365,27 @@ pub enum RunnableKind { /// May include {test_id} which will get the test clicked on by the user. TestOne, + /// Run tests matching a pattern (in RA, usually a path::to::module::of::tests) + /// May include {label} which will get the label from the `build` section of a crate. + /// May include {test_pattern} which will get the test module clicked on by the user. + TestMod, + + /// Run a single doctest + /// May include {label} which will get the label from the `build` section of a crate. + /// May include {test_id} which will get the doctest clicked on by the user. + DocTestOne, + + /// Run a single benchmark + /// May include {label} which will get the label from the `build` section of a crate. + /// May include {bench_id} which will get the benchmark clicked on by the user. + BenchOne, + /// Template for checking a target, emitting rustc JSON diagnostics. /// May include {label} which will get the label from the `build` section of a crate. Flycheck, + + /// For forwards-compatibility, i.e. old rust-analyzer binary with newer workspace discovery tools + Unknown, } #[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)] @@ -380,6 +398,8 @@ pub struct ProjectJsonData { crates: Vec<CrateData>, #[serde(default)] runnables: Vec<RunnableData>, + // + // New fields should be Option or #[serde(default)]. This applies to most of this datastructure. } #[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq, Default)] @@ -391,7 +411,6 @@ struct CrateData { display_name: Option<String>, root_module: Utf8PathBuf, edition: EditionData, - #[serde(default)] version: Option<semver::Version>, deps: Vec<Dep>, #[serde(default)] @@ -408,11 +427,8 @@ struct CrateData { source: Option<CrateSource>, #[serde(default)] is_proc_macro: bool, - #[serde(default)] repository: Option<String>, - #[serde(default)] build: Option<BuildData>, - #[serde(default)] proc_macro_cwd: Option<Utf8PathBuf>, } @@ -457,32 +473,40 @@ enum EditionData { } #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] -pub struct BuildData { +struct BuildData { label: String, build_file: Utf8PathBuf, target_kind: TargetKindData, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct RunnableData { - pub program: String, - pub args: Vec<String>, - pub cwd: Utf8PathBuf, - pub kind: RunnableKindData, +struct RunnableData { + program: String, + args: Vec<String>, + cwd: Utf8PathBuf, + kind: RunnableKindData, } #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] -pub enum RunnableKindData { +enum RunnableKindData { Flycheck, Check, Run, TestOne, + TestMod, + DocTestOne, + BenchOne, + + /// For forwards-compatibility, i.e. old rust-analyzer binary with newer workspace discovery tools + #[allow(unused)] + #[serde(other)] + Unknown, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] -pub enum TargetKindData { +enum TargetKindData { Bin, /// Any kind of Cargo lib crate-type (dylib, rlib, proc-macro, ...). Lib, @@ -545,7 +569,11 @@ impl From<RunnableKindData> for RunnableKind { RunnableKindData::Check => RunnableKind::Check, RunnableKindData::Run => RunnableKind::Run, RunnableKindData::TestOne => RunnableKind::TestOne, + RunnableKindData::TestMod => RunnableKind::TestMod, + RunnableKindData::DocTestOne => RunnableKind::DocTestOne, + RunnableKindData::BenchOne => RunnableKind::BenchOne, RunnableKindData::Flycheck => RunnableKind::Flycheck, + RunnableKindData::Unknown => RunnableKind::Unknown, } } } diff --git a/crates/project-model/src/sysroot.rs b/crates/project-model/src/sysroot.rs index f244c9736c..546a1e05a0 100644 --- a/crates/project-model/src/sysroot.rs +++ b/crates/project-model/src/sysroot.rs @@ -275,7 +275,10 @@ impl Sysroot { } tracing::debug!("Stitching sysroot library: {src_root}"); - let mut stitched = stitched::Stitched { crates: Default::default() }; + let mut stitched = stitched::Stitched { + crates: Default::default(), + edition: span::Edition::Edition2024, + }; for path in stitched::SYSROOT_CRATES.trim().lines() { let name = path.split('/').next_back().unwrap(); @@ -511,6 +514,7 @@ pub(crate) mod stitched { #[derive(Debug, Clone, Eq, PartialEq)] pub struct Stitched { pub(super) crates: Arena<RustLibSrcCrateData>, + pub(crate) edition: span::Edition, } impl ops::Index<RustLibSrcCrate> for Stitched { diff --git a/crates/project-model/src/tests.rs b/crates/project-model/src/tests.rs index a03ed562e1..395cea6f76 100644 --- a/crates/project-model/src/tests.rs +++ b/crates/project-model/src/tests.rs @@ -193,6 +193,12 @@ fn rust_project_hello_world_project_model() { } #[test] +fn rust_project_labeled_project_model() { + // This just needs to parse. + _ = load_rust_project("labeled-project.json"); +} + +#[test] fn rust_project_cfg_groups() { let (crate_graph, _proc_macros) = load_rust_project("cfg-groups.json"); check_crate_graph(crate_graph, expect_file!["../test_data/output/rust_project_cfg_groups.txt"]); diff --git a/crates/project-model/src/workspace.rs b/crates/project-model/src/workspace.rs index 8f15f7e150..581b5fa514 100644 --- a/crates/project-model/src/workspace.rs +++ b/crates/project-model/src/workspace.rs @@ -1831,7 +1831,7 @@ fn sysroot_to_crate_graph( let display_name = CrateDisplayName::from_canonical_name(&stitched[krate].name); let crate_id = crate_graph.add_crate_root( file_id, - Edition::CURRENT_FIXME, + stitched.edition, Some(display_name), None, cfg_options.clone(), diff --git a/crates/project-model/test_data/labeled-project.json b/crates/project-model/test_data/labeled-project.json new file mode 100644 index 0000000000..5c0e1f3397 --- /dev/null +++ b/crates/project-model/test_data/labeled-project.json @@ -0,0 +1,37 @@ +{ + "sysroot_src": null, + "crates": [ + { + "display_name": "hello_world", + "root_module": "$ROOT$src/lib.rs", + "edition": "2018", + "deps": [], + "is_workspace_member": true, + "build": { + "label": "//:hello_world", + "build_file": "$ROOT$BUILD", + "target_kind": "bin" + } + } + ], + "runnables": [ + { + "kind": "run", + "program": "bazel", + "args": ["run", "{label}"], + "cwd": "$ROOT$" + }, + { + "kind": "flycheck", + "program": "$ROOT$custom-flychecker.sh", + "args": ["{label}"], + "cwd": "$ROOT$" + }, + { + "kind": "we-ignore-unknown-runnable-kinds-for-forwards-compatibility", + "program": "abc", + "args": ["{label}"], + "cwd": "$ROOT$" + } + ] +} diff --git a/crates/rust-analyzer/Cargo.toml b/crates/rust-analyzer/Cargo.toml index d1283ca59e..beb83a8173 100644 --- a/crates/rust-analyzer/Cargo.toml +++ b/crates/rust-analyzer/Cargo.toml @@ -94,6 +94,8 @@ test-utils.workspace = true test-fixture.workspace = true syntax-bridge.workspace = true +intern = { path = "../intern", features = ["prevent-gc"] } + [features] jemalloc = ["jemallocator", "profile/jemalloc"] force-always-assert = ["stdx/force-always-assert"] diff --git a/crates/rust-analyzer/src/cli/analysis_stats.rs b/crates/rust-analyzer/src/cli/analysis_stats.rs index 1995d38898..74828cba02 100644 --- a/crates/rust-analyzer/src/cli/analysis_stats.rs +++ b/crates/rust-analyzer/src/cli/analysis_stats.rs @@ -10,15 +10,15 @@ use std::{ use cfg::{CfgAtom, CfgDiff}; use hir::{ - Adt, AssocItem, Crate, DefWithBody, FindPathConfig, HasCrate, HasSource, HirDisplay, ModuleDef, - Name, crate_lang_items, + Adt, AssocItem, Crate, DefWithBody, FindPathConfig, GenericDef, HasCrate, HasSource, + HirDisplay, ModuleDef, Name, Variant, VariantId, crate_lang_items, db::{DefDatabase, ExpandDatabase, HirDatabase}, next_solver::{DbInterner, GenericArgs}, }; use hir_def::{ - SyntheticSyntax, - expr_store::BodySourceMap, - hir::{ExprId, PatId}, + DefWithBodyId, ExpressionStoreOwnerId, GenericDefId, SyntheticSyntax, + expr_store::{Body, BodySourceMap, ExpressionStore}, + hir::{ExprId, PatId, generics::GenericParams}, }; use hir_ty::InferenceResult; use ide::{ @@ -91,6 +91,7 @@ impl flags::AnalysisStats { } }, prefill_caches: false, + num_worker_threads: 1, proc_macro_processes: 1, }; @@ -126,8 +127,8 @@ impl flags::AnalysisStats { let source_roots = krates .iter() .cloned() - .map(|krate| db.file_source_root(krate.root_file(db)).source_root_id(db)) - .unique(); + .map(|krate| (db.file_source_root(krate.root_file(db)).source_root_id(db), krate)) + .unique_by(|(source_root_id, _)| *source_root_id); let mut dep_loc = 0; let mut workspace_loc = 0; @@ -137,7 +138,7 @@ impl flags::AnalysisStats { let mut workspace_item_stats = PrettyItemStats::default(); let mut dep_item_stats = PrettyItemStats::default(); - for source_root_id in source_roots { + for (source_root_id, krate) in source_roots { let source_root = db.source_root(source_root_id).source_root(db); for file_id in source_root.iter() { if let Some(p) = source_root.path_for_file(&file_id) @@ -148,7 +149,8 @@ impl flags::AnalysisStats { let length = db.file_text(file_id).text(db).lines().count(); let item_stats = db .file_item_tree( - EditionedFileId::current_edition_guess_origin(db, file_id).into(), + EditionedFileId::current_edition(db, file_id).into(), + krate.into(), ) .item_tree_stats() .into(); @@ -160,7 +162,8 @@ impl flags::AnalysisStats { let length = db.file_text(file_id).text(db).lines().count(); let item_stats = db .file_item_tree( - EditionedFileId::current_edition_guess_origin(db, file_id).into(), + EditionedFileId::current_edition(db, file_id).into(), + krate.into(), ) .item_tree_stats() .into(); @@ -226,6 +229,8 @@ impl flags::AnalysisStats { eprint!(" crates: {num_crates}"); let mut num_decls = 0; let mut bodies = Vec::new(); + let mut signatures = Vec::new(); + let mut variants = Vec::new(); let mut adts = Vec::new(); let mut file_ids = Vec::new(); @@ -243,10 +248,15 @@ impl flags::AnalysisStats { match decl { ModuleDef::Function(f) => bodies.push(DefWithBody::from(f)), ModuleDef::Adt(a) => { - if let Adt::Enum(e) = a { - for v in e.variants(db) { - bodies.push(DefWithBody::from(v)); + match a { + Adt::Enum(e) => { + for v in e.variants(db) { + bodies.push(DefWithBody::from(v)); + variants.push(Variant::EnumVariant(v)); + } } + Adt::Struct(it) => variants.push(Variant::Struct(it)), + Adt::Union(it) => variants.push(Variant::Union(it)), } adts.push(a) } @@ -264,24 +274,32 @@ impl flags::AnalysisStats { }, _ => (), }; + if let Some(g) = decl.as_generic_def() { + signatures.push(g); + } } for impl_def in module.impl_defs(db) { + signatures.push(impl_def.into()); for item in impl_def.items(db) { num_decls += 1; match item { - AssocItem::Function(f) => bodies.push(DefWithBody::from(f)), + AssocItem::Function(f) => { + bodies.push(DefWithBody::from(f)); + signatures.push(f.into()) + } AssocItem::Const(c) => { bodies.push(DefWithBody::from(c)); + signatures.push(c.into()); } - _ => (), + AssocItem::TypeAlias(t) => signatures.push(t.into()), } } } } } eprintln!( - ", mods: {}, decls: {num_decls}, bodies: {}, adts: {}, consts: {}", + ", mods: {}, decls: {num_decls}, bodies: {}, adts: {}, consts: {}, signatures: {}, variants: {}", visited_modules.len(), bodies.len(), adts.len(), @@ -289,6 +307,8 @@ impl flags::AnalysisStats { .iter() .filter(|it| matches!(it, DefWithBody::Const(_) | DefWithBody::Static(_))) .count(), + signatures.len(), + variants.len() ); eprintln!(" Workspace:"); @@ -324,15 +344,15 @@ impl flags::AnalysisStats { } if !self.skip_lowering { - self.run_body_lowering(db, &vfs, &bodies, verbosity); + self.run_body_lowering(db, &vfs, &bodies, &signatures, &variants, verbosity); } if !self.skip_inference { - self.run_inference(db, &vfs, &bodies, verbosity); + self.run_inference(db, &vfs, &bodies, &signatures, &variants, verbosity); } if !self.skip_mir_stats { - self.run_mir_lowering(db, &bodies, verbosity); + self.run_mir_lowering(db, &bodies, &signatures, &variants, verbosity); } if !self.skip_data_layout { @@ -340,7 +360,7 @@ impl flags::AnalysisStats { } if !self.skip_const_eval { - self.run_const_eval(db, &bodies, verbosity); + self.run_const_eval(db, &bodies, &signatures, &variants, verbosity); } }); @@ -381,7 +401,7 @@ impl flags::AnalysisStats { let mut fail = 0; for &a in adts { let interner = DbInterner::new_no_crate(db); - let generic_params = db.generic_params(a.into()); + let generic_params = GenericParams::of(db, a.into()); if generic_params.iter_type_or_consts().next().is_some() || generic_params.iter_lt().next().is_some() { @@ -393,7 +413,7 @@ impl flags::AnalysisStats { hir_def::AdtId::from(a), GenericArgs::empty(interner).store(), hir_ty::ParamEnvAndCrate { - param_env: db.trait_environment(a.into()), + param_env: db.trait_environment(GenericDefId::from(a).into()), krate: a.krate(db).into(), } .store(), @@ -413,7 +433,14 @@ impl flags::AnalysisStats { report_metric("data layout time", data_layout_time.time.as_millis() as u64, "ms"); } - fn run_const_eval(&self, db: &RootDatabase, bodies: &[DefWithBody], verbosity: Verbosity) { + fn run_const_eval( + &self, + db: &RootDatabase, + bodies: &[DefWithBody], + _signatures: &[GenericDef], + _variants: &[Variant], + verbosity: Verbosity, + ) { let len = bodies .iter() .filter(|body| matches!(body, DefWithBody::Const(_) | DefWithBody::Static(_))) @@ -428,7 +455,9 @@ impl flags::AnalysisStats { let mut all = 0; let mut fail = 0; for &b in bodies { - bar.set_message(move || format!("const eval: {}", full_name(db, b, b.module(db)))); + bar.set_message(move || { + format!("const eval: {}", full_name(db, || b.name(db), b.module(db))) + }); let res = match b { DefWithBody::Const(c) => c.eval(db), DefWithBody::Static(s) => s.eval(db), @@ -492,7 +521,7 @@ impl flags::AnalysisStats { let mut sw = self.stop_watch(); for &file_id in file_ids { - let file_id = file_id.editioned_file_id(db); + let file_id = file_id.span_file_id(db); let sema = hir::Semantics::new(db); let display_target = match sema.first_crate(file_id.file_id()) { Some(krate) => krate.to_display_target(sema.db), @@ -684,7 +713,14 @@ impl flags::AnalysisStats { bar.finish_and_clear(); } - fn run_mir_lowering(&self, db: &RootDatabase, bodies: &[DefWithBody], verbosity: Verbosity) { + fn run_mir_lowering( + &self, + db: &RootDatabase, + bodies: &[DefWithBody], + _signatures: &[GenericDef], + _variants: &[Variant], + verbosity: Verbosity, + ) { let mut bar = match verbosity { Verbosity::Quiet | Verbosity::Spammy => ProgressReport::hidden(), _ if self.parallel || self.output.is_some() => ProgressReport::hidden(), @@ -695,14 +731,14 @@ impl flags::AnalysisStats { let mut fail = 0; for &body in bodies { bar.set_message(move || { - format!("mir lowering: {}", full_name(db, body, body.module(db))) + format!("mir lowering: {}", full_name(db, || body.name(db), body.module(db))) }); bar.inc(1); - if matches!(body, DefWithBody::Variant(_)) { + if matches!(body, DefWithBody::EnumVariant(_)) { continue; } let module = body.module(db); - if !self.should_process(db, body, module) { + if !self.should_process(db, || body.name(db), module) { continue; } @@ -740,6 +776,8 @@ impl flags::AnalysisStats { db: &RootDatabase, vfs: &Vfs, bodies: &[DefWithBody], + signatures: &[GenericDef], + variants: &[Variant], verbosity: Verbosity, ) { let mut bar = match verbosity { @@ -750,12 +788,31 @@ impl flags::AnalysisStats { if self.parallel { let mut inference_sw = self.stop_watch(); - let bodies = bodies.iter().filter_map(|&body| body.try_into().ok()).collect::<Vec<_>>(); + let bodies = bodies + .iter() + .filter_map(|&body| body.try_into().ok()) + .collect::<Vec<DefWithBodyId>>(); bodies .par_iter() .map_with(db.clone(), |snap, &body| { - snap.body(body); - InferenceResult::for_body(snap, body); + InferenceResult::of(snap, body); + }) + .count(); + let signatures = signatures + .iter() + .filter_map(|&signatures| signatures.try_into().ok()) + .collect::<Vec<GenericDefId>>(); + signatures + .par_iter() + .map_with(db.clone(), |snap, &signatures| { + InferenceResult::of(snap, signatures); + }) + .count(); + let variants = variants.iter().copied().map(Into::into).collect::<Vec<VariantId>>(); + variants + .par_iter() + .map_with(db.clone(), |snap, &variants| { + InferenceResult::of(snap, variants); }) .count(); eprintln!("{:<20} {}", "Parallel Inference:", inference_sw.elapsed()); @@ -779,7 +836,7 @@ impl flags::AnalysisStats { let display_target = module.krate(db).to_display_target(db); if let Some(only_name) = self.only.as_deref() && name.display(db, Edition::LATEST).to_string() != only_name - && full_name(db, body_id, module) != only_name + && full_name(db, || body_id.name(db), module) != only_name { continue; } @@ -789,7 +846,9 @@ impl flags::AnalysisStats { DefWithBody::Function(it) => it.source(db).map(|it| it.syntax().cloned()), DefWithBody::Static(it) => it.source(db).map(|it| it.syntax().cloned()), DefWithBody::Const(it) => it.source(db).map(|it| it.syntax().cloned()), - DefWithBody::Variant(it) => it.source(db).map(|it| it.syntax().cloned()), + DefWithBody::EnumVariant(it) => { + it.source(db).map(|it| it.syntax().cloned()) + } }; if let Some(src) = source { let original_file = src.file_id.original_file(db); @@ -797,33 +856,44 @@ impl flags::AnalysisStats { let syntax_range = src.text_range(); format!( "processing: {} ({} {:?})", - full_name(db, body_id, module), + full_name(db, || body_id.name(db), module), path, syntax_range ) } else { - format!("processing: {}", full_name(db, body_id, module)) + format!("processing: {}", full_name(db, || body_id.name(db), module)) } } else { - format!("processing: {}", full_name(db, body_id, module)) + format!("processing: {}", full_name(db, || body_id.name(db), module)) } }; if verbosity.is_spammy() { bar.println(msg()); } bar.set_message(msg); - let body = db.body(body_def_id); + let body = Body::of(db, body_def_id); let inference_result = - catch_unwind(AssertUnwindSafe(|| InferenceResult::for_body(db, body_def_id))); + catch_unwind(AssertUnwindSafe(|| InferenceResult::of(db, body_def_id))); let inference_result = match inference_result { Ok(inference_result) => inference_result, Err(p) => { if let Some(s) = p.downcast_ref::<&str>() { - eprintln!("infer panicked for {}: {}", full_name(db, body_id, module), s); + eprintln!( + "infer panicked for {}: {}", + full_name(db, || body_id.name(db), module), + s + ); } else if let Some(s) = p.downcast_ref::<String>() { - eprintln!("infer panicked for {}: {}", full_name(db, body_id, module), s); + eprintln!( + "infer panicked for {}: {}", + full_name(db, || body_id.name(db), module), + s + ); } else { - eprintln!("infer panicked for {}", full_name(db, body_id, module)); + eprintln!( + "infer panicked for {}", + full_name(db, || body_id.name(db), module) + ); } panics += 1; bar.inc(1); @@ -831,7 +901,7 @@ impl flags::AnalysisStats { } }; // This query is LRU'd, so actually calling it will skew the timing results. - let sm = || db.body_with_source_map(body_def_id).1; + let sm = || &Body::with_source_map(db, body_def_id).1; // region:expressions let (previous_exprs, previous_unknown, previous_partially_unknown) = @@ -842,7 +912,7 @@ impl flags::AnalysisStats { let unknown_or_partial = if ty.is_ty_error() { num_exprs_unknown += 1; if verbosity.is_spammy() { - if let Some((path, start, end)) = expr_syntax_range(db, vfs, &sm(), expr_id) + if let Some((path, start, end)) = expr_syntax_range(db, vfs, sm(), expr_id) { bar.println(format!( "{} {}:{}-{}:{}: Unknown type", @@ -869,7 +939,7 @@ impl flags::AnalysisStats { }; if self.only.is_some() && verbosity.is_spammy() { // in super-verbose mode for just one function, we print every single expression - if let Some((_, start, end)) = expr_syntax_range(db, vfs, &sm(), expr_id) { + if let Some((_, start, end)) = expr_syntax_range(db, vfs, sm(), expr_id) { bar.println(format!( "{}:{}-{}:{}: {}", start.line + 1, @@ -888,14 +958,14 @@ impl flags::AnalysisStats { if unknown_or_partial && self.output == Some(OutputFormat::Csv) { println!( r#"{},type,"{}""#, - location_csv_expr(db, vfs, &sm(), expr_id), + location_csv_expr(db, vfs, sm(), expr_id), ty.display(db, display_target) ); } if let Some(mismatch) = inference_result.type_mismatch_for_expr(expr_id) { num_expr_type_mismatches += 1; if verbosity.is_verbose() { - if let Some((path, start, end)) = expr_syntax_range(db, vfs, &sm(), expr_id) + if let Some((path, start, end)) = expr_syntax_range(db, vfs, sm(), expr_id) { bar.println(format!( "{} {}:{}-{}:{}: Expected {}, got {}", @@ -919,7 +989,7 @@ impl flags::AnalysisStats { if self.output == Some(OutputFormat::Csv) { println!( r#"{},mismatch,"{}","{}""#, - location_csv_expr(db, vfs, &sm(), expr_id), + location_csv_expr(db, vfs, sm(), expr_id), mismatch.expected.as_ref().display(db, display_target), mismatch.actual.as_ref().display(db, display_target) ); @@ -929,7 +999,7 @@ impl flags::AnalysisStats { if verbosity.is_spammy() { bar.println(format!( "In {}: {} exprs, {} unknown, {} partial", - full_name(db, body_id, module), + full_name(db, || body_id.name(db), module), num_exprs - previous_exprs, num_exprs_unknown - previous_unknown, num_exprs_partially_unknown - previous_partially_unknown @@ -946,7 +1016,7 @@ impl flags::AnalysisStats { let unknown_or_partial = if ty.is_ty_error() { num_pats_unknown += 1; if verbosity.is_spammy() { - if let Some((path, start, end)) = pat_syntax_range(db, vfs, &sm(), pat_id) { + if let Some((path, start, end)) = pat_syntax_range(db, vfs, sm(), pat_id) { bar.println(format!( "{} {}:{}-{}:{}: Unknown type", path, @@ -972,7 +1042,7 @@ impl flags::AnalysisStats { }; if self.only.is_some() && verbosity.is_spammy() { // in super-verbose mode for just one function, we print every single pattern - if let Some((_, start, end)) = pat_syntax_range(db, vfs, &sm(), pat_id) { + if let Some((_, start, end)) = pat_syntax_range(db, vfs, sm(), pat_id) { bar.println(format!( "{}:{}-{}:{}: {}", start.line + 1, @@ -991,14 +1061,14 @@ impl flags::AnalysisStats { if unknown_or_partial && self.output == Some(OutputFormat::Csv) { println!( r#"{},type,"{}""#, - location_csv_pat(db, vfs, &sm(), pat_id), + location_csv_pat(db, vfs, sm(), pat_id), ty.display(db, display_target) ); } if let Some(mismatch) = inference_result.type_mismatch_for_pat(pat_id) { num_pat_type_mismatches += 1; if verbosity.is_verbose() { - if let Some((path, start, end)) = pat_syntax_range(db, vfs, &sm(), pat_id) { + if let Some((path, start, end)) = pat_syntax_range(db, vfs, sm(), pat_id) { bar.println(format!( "{} {}:{}-{}:{}: Expected {}, got {}", path, @@ -1021,7 +1091,7 @@ impl flags::AnalysisStats { if self.output == Some(OutputFormat::Csv) { println!( r#"{},mismatch,"{}","{}""#, - location_csv_pat(db, vfs, &sm(), pat_id), + location_csv_pat(db, vfs, sm(), pat_id), mismatch.expected.as_ref().display(db, display_target), mismatch.actual.as_ref().display(db, display_target) ); @@ -1031,7 +1101,7 @@ impl flags::AnalysisStats { if verbosity.is_spammy() { bar.println(format!( "In {}: {} pats, {} unknown, {} partial", - full_name(db, body_id, module), + full_name(db, || body_id.name(db), module), num_pats - previous_pats, num_pats_unknown - previous_unknown, num_pats_partially_unknown - previous_partially_unknown @@ -1075,20 +1145,104 @@ impl flags::AnalysisStats { db: &RootDatabase, vfs: &Vfs, bodies: &[DefWithBody], + signatures: &[GenericDef], + variants: &[Variant], verbosity: Verbosity, ) { let mut bar = match verbosity { Verbosity::Quiet | Verbosity::Spammy => ProgressReport::hidden(), _ if self.output.is_some() => ProgressReport::hidden(), - _ => ProgressReport::new(bodies.len()), + _ => ProgressReport::new(bodies.len() + signatures.len() + variants.len()), }; let mut sw = self.stop_watch(); bar.tick(); + for &signature in signatures { + let Ok(signature_id) = signature.try_into() else { continue }; + let module = signature.module(db); + if !self.should_process(db, || signature.name(db), module) { + continue; + } + let msg = move || { + if verbosity.is_verbose() { + let source = match signature { + GenericDef::Function(it) => it.source(db).map(|it| it.syntax().cloned()), + GenericDef::Static(it) => it.source(db).map(|it| it.syntax().cloned()), + GenericDef::Const(it) => it.source(db).map(|it| it.syntax().cloned()), + GenericDef::Adt(adt) => adt.source(db).map(|it| it.syntax().cloned()), + GenericDef::Trait(it) => it.source(db).map(|it| it.syntax().cloned()), + GenericDef::TypeAlias(type_alias) => { + type_alias.source(db).map(|it| it.syntax().cloned()) + } + GenericDef::Impl(it) => it.source(db).map(|it| it.syntax().cloned()), + }; + if let Some(src) = source { + let original_file = src.file_id.original_file(db); + let path = vfs.file_path(original_file.file_id(db)); + let syntax_range = src.text_range(); + format!( + "processing: {} ({} {:?})", + full_name(db, || signature.name(db), module), + path, + syntax_range + ) + } else { + format!("processing: {}", full_name(db, || signature.name(db), module)) + } + } else { + format!("processing: {}", full_name(db, || signature.name(db), module)) + } + }; + if verbosity.is_spammy() { + bar.println(msg()); + } + bar.set_message(msg); + ExpressionStore::of(db, ExpressionStoreOwnerId::Signature(signature_id)); + bar.inc(1); + } + + for &variant in variants { + let variant_id = variant.into(); + let module = variant.module(db); + if !self.should_process(db, || Some(variant.name(db)), module) { + continue; + } + let msg = move || { + if verbosity.is_verbose() { + let source = match variant { + Variant::EnumVariant(it) => it.source(db).map(|it| it.syntax().cloned()), + Variant::Struct(it) => it.source(db).map(|it| it.syntax().cloned()), + Variant::Union(it) => it.source(db).map(|it| it.syntax().cloned()), + }; + if let Some(src) = source { + let original_file = src.file_id.original_file(db); + let path = vfs.file_path(original_file.file_id(db)); + let syntax_range = src.text_range(); + format!( + "processing: {} ({} {:?})", + full_name(db, || Some(variant.name(db)), module), + path, + syntax_range + ) + } else { + format!("processing: {}", full_name(db, || Some(variant.name(db)), module)) + } + } else { + format!("processing: {}", full_name(db, || Some(variant.name(db)), module)) + } + }; + if verbosity.is_spammy() { + bar.println(msg()); + } + bar.set_message(msg); + ExpressionStore::of(db, ExpressionStoreOwnerId::VariantFields(variant_id)); + bar.inc(1); + } + for &body_id in bodies { let Ok(body_def_id) = body_id.try_into() else { continue }; let module = body_id.module(db); - if !self.should_process(db, body_id, module) { + if !self.should_process(db, || body_id.name(db), module) { continue; } let msg = move || { @@ -1097,7 +1251,9 @@ impl flags::AnalysisStats { DefWithBody::Function(it) => it.source(db).map(|it| it.syntax().cloned()), DefWithBody::Static(it) => it.source(db).map(|it| it.syntax().cloned()), DefWithBody::Const(it) => it.source(db).map(|it| it.syntax().cloned()), - DefWithBody::Variant(it) => it.source(db).map(|it| it.syntax().cloned()), + DefWithBody::EnumVariant(it) => { + it.source(db).map(|it| it.syntax().cloned()) + } }; if let Some(src) = source { let original_file = src.file_id.original_file(db); @@ -1105,28 +1261,28 @@ impl flags::AnalysisStats { let syntax_range = src.text_range(); format!( "processing: {} ({} {:?})", - full_name(db, body_id, module), + full_name(db, || body_id.name(db), module), path, syntax_range ) } else { - format!("processing: {}", full_name(db, body_id, module)) + format!("processing: {}", full_name(db, || body_id.name(db), module)) } } else { - format!("processing: {}", full_name(db, body_id, module)) + format!("processing: {}", full_name(db, || body_id.name(db), module)) } }; if verbosity.is_spammy() { bar.println(msg()); } bar.set_message(msg); - db.body(body_def_id); + Body::of(db, body_def_id); bar.inc(1); } bar.finish_and_clear(); let body_lowering_time = sw.elapsed(); - eprintln!("{:<20} {}", "Body lowering:", body_lowering_time); + eprintln!("{:<20} {}", "Expression Store Lowering:", body_lowering_time); report_metric("body lowering time", body_lowering_time.time.as_millis() as u64, "ms"); } @@ -1280,12 +1436,17 @@ impl flags::AnalysisStats { eprintln!("{:<20} {} ({} files)", "IDE:", ide_time, file_ids.len()); } - fn should_process(&self, db: &RootDatabase, body_id: DefWithBody, module: hir::Module) -> bool { + fn should_process( + &self, + db: &RootDatabase, + name_fn: impl Fn() -> Option<Name>, + module: hir::Module, + ) -> bool { if let Some(only_name) = self.only.as_deref() { - let name = body_id.name(db).unwrap_or_else(Name::missing); + let name = name_fn().unwrap_or_else(Name::missing); if name.display(db, Edition::LATEST).to_string() != only_name - && full_name(db, body_id, module) != only_name + && full_name(db, name_fn, module) != only_name { return false; } @@ -1298,7 +1459,7 @@ impl flags::AnalysisStats { } } -fn full_name(db: &RootDatabase, body_id: DefWithBody, module: hir::Module) -> String { +fn full_name(db: &RootDatabase, name: impl Fn() -> Option<Name>, module: hir::Module) -> String { module .krate(db) .display_name(db) @@ -1310,7 +1471,7 @@ fn full_name(db: &RootDatabase, body_id: DefWithBody, module: hir::Module) -> St .into_iter() .filter_map(|it| it.name(db)) .rev() - .chain(Some(body_id.name(db).unwrap_or_else(Name::missing))) + .chain(Some(name().unwrap_or_else(Name::missing))) .map(|it| it.display(db, Edition::LATEST).to_string()), ) .join("::") diff --git a/crates/rust-analyzer/src/cli/diagnostics.rs b/crates/rust-analyzer/src/cli/diagnostics.rs index 575c77f842..efbaad3c49 100644 --- a/crates/rust-analyzer/src/cli/diagnostics.rs +++ b/crates/rust-analyzer/src/cli/diagnostics.rs @@ -41,6 +41,7 @@ impl flags::Diagnostics { load_out_dirs_from_check: !self.disable_build_scripts, with_proc_macro_server, prefill_caches: false, + num_worker_threads: 1, proc_macro_processes: 1, }; let (db, _vfs, _proc_macro) = diff --git a/crates/rust-analyzer/src/cli/flags.rs b/crates/rust-analyzer/src/cli/flags.rs index c522060181..03849938f5 100644 --- a/crates/rust-analyzer/src/cli/flags.rs +++ b/crates/rust-analyzer/src/cli/flags.rs @@ -40,6 +40,8 @@ xflags::xflags! { cmd parse { /// Suppress printing. optional --no-dump + /// Output as JSON. + optional --json } /// Parse stdin and print the list of symbols. @@ -189,6 +191,9 @@ xflags::xflags! { /// Exclude code from vendored libraries from the resulting index. optional --exclude-vendored-libraries + + /// The number of worker threads for cache priming. Defaults to the number of physical cores. + optional --num-threads num_threads: usize } } } @@ -233,6 +238,7 @@ pub struct LspServer { #[derive(Debug)] pub struct Parse { pub no_dump: bool, + pub json: bool, } #[derive(Debug)] @@ -257,8 +263,8 @@ pub struct AnalysisStats { pub disable_build_scripts: bool, pub disable_proc_macros: bool, pub proc_macro_srv: Option<PathBuf>, - pub skip_lowering: bool, pub skip_lang_items: bool, + pub skip_lowering: bool, pub skip_inference: bool, pub skip_mir_stats: bool, pub skip_data_layout: bool, @@ -335,6 +341,7 @@ pub struct Scip { pub output: Option<PathBuf>, pub config_path: Option<PathBuf>, pub exclude_vendored_libraries: bool, + pub num_threads: Option<usize>, } impl RustAnalyzer { diff --git a/crates/rust-analyzer/src/cli/lsif.rs b/crates/rust-analyzer/src/cli/lsif.rs index e5e238db63..3950a581fd 100644 --- a/crates/rust-analyzer/src/cli/lsif.rs +++ b/crates/rust-analyzer/src/cli/lsif.rs @@ -293,6 +293,7 @@ impl flags::Lsif { load_out_dirs_from_check: true, with_proc_macro_server: ProcMacroServerChoice::Sysroot, prefill_caches: false, + num_worker_threads: 1, proc_macro_processes: 1, }; let path = AbsPathBuf::assert_utf8(env::current_dir()?.join(self.path)); diff --git a/crates/rust-analyzer/src/cli/parse.rs b/crates/rust-analyzer/src/cli/parse.rs index 85ec95409a..aa1b659d8b 100644 --- a/crates/rust-analyzer/src/cli/parse.rs +++ b/crates/rust-analyzer/src/cli/parse.rs @@ -1,18 +1,101 @@ //! Read Rust code on stdin, print syntax tree on stdout. use ide::Edition; -use syntax::{AstNode, SourceFile}; +use ide_db::line_index::LineIndex; +use serde::Serialize; +use syntax::{AstNode, NodeOrToken, SourceFile, SyntaxNode, SyntaxToken}; use crate::cli::{flags, read_stdin}; +#[derive(Serialize)] +struct JsonNode { + kind: String, + #[serde(rename = "type")] + node_type: &'static str, + start: [u32; 3], + end: [u32; 3], + #[serde(skip_serializing_if = "Option::is_none")] + text: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + children: Option<Vec<JsonNode>>, +} + +fn pos(line_index: &LineIndex, offset: syntax::TextSize) -> [u32; 3] { + let offset_u32 = u32::from(offset); + let line_col = line_index.line_col(offset); + [offset_u32, line_col.line, line_col.col] +} + impl flags::Parse { pub fn run(self) -> anyhow::Result<()> { let _p = tracing::info_span!("flags::Parse::run").entered(); let text = read_stdin()?; + let line_index = LineIndex::new(&text); let file = SourceFile::parse(&text, Edition::CURRENT).tree(); + if !self.no_dump { - println!("{:#?}", file.syntax()); + if self.json { + let json_tree = node_to_json(NodeOrToken::Node(file.syntax().clone()), &line_index); + println!("{}", serde_json::to_string(&json_tree)?); + } else { + println!("{:#?}", file.syntax()); + } } + std::mem::forget(file); Ok(()) } } + +fn node_to_json(node: NodeOrToken<SyntaxNode, SyntaxToken>, line_index: &LineIndex) -> JsonNode { + let range = node.text_range(); + let kind = format!("{:?}", node.kind()); + + match node { + NodeOrToken::Node(n) => { + let children: Vec<_> = + n.children_with_tokens().map(|it| node_to_json(it, line_index)).collect(); + JsonNode { + kind, + node_type: "Node", + start: pos(line_index, range.start()), + end: pos(line_index, range.end()), + text: None, + children: Some(children), + } + } + NodeOrToken::Token(t) => JsonNode { + kind, + node_type: "Token", + start: pos(line_index, range.start()), + end: pos(line_index, range.end()), + text: Some(t.text().to_owned()), + children: None, + }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cli::flags; + + #[test] + fn test_parse_json_output() { + let text = "fn main() {}".to_owned(); + let flags = flags::Parse { json: true, no_dump: false }; + let line_index = LineIndex::new(&text); + + let file = SourceFile::parse(&text, Edition::CURRENT).tree(); + + let output = if flags.json { + let json_tree = node_to_json(NodeOrToken::Node(file.syntax().clone()), &line_index); + serde_json::to_string(&json_tree).unwrap() + } else { + format!("{:#?}", file.syntax()) + }; + + assert!(output.contains(r#""kind":"SOURCE_FILE""#)); + assert!(output.contains(r#""text":"main""#)); + assert!(output.contains(r#""start":[0,0,0]"#)); + } +} diff --git a/crates/rust-analyzer/src/cli/prime_caches.rs b/crates/rust-analyzer/src/cli/prime_caches.rs index d5da679179..beedcfae4e 100644 --- a/crates/rust-analyzer/src/cli/prime_caches.rs +++ b/crates/rust-analyzer/src/cli/prime_caches.rs @@ -38,6 +38,7 @@ impl flags::PrimeCaches { // we want to ensure that this command, not `load_workspace_at`, // is responsible for that work. prefill_caches: false, + num_worker_threads: 1, proc_macro_processes: config.proc_macro_num_processes(), }; diff --git a/crates/rust-analyzer/src/cli/run_tests.rs b/crates/rust-analyzer/src/cli/run_tests.rs index d4a56d773e..e8c88cadf6 100644 --- a/crates/rust-analyzer/src/cli/run_tests.rs +++ b/crates/rust-analyzer/src/cli/run_tests.rs @@ -23,6 +23,7 @@ impl flags::RunTests { load_out_dirs_from_check: true, with_proc_macro_server: ProcMacroServerChoice::Sysroot, prefill_caches: false, + num_worker_threads: 1, proc_macro_processes: 1, }; let (ref db, _vfs, _proc_macro) = diff --git a/crates/rust-analyzer/src/cli/rustc_tests.rs b/crates/rust-analyzer/src/cli/rustc_tests.rs index e8c6c5f4d4..49f28352b6 100644 --- a/crates/rust-analyzer/src/cli/rustc_tests.rs +++ b/crates/rust-analyzer/src/cli/rustc_tests.rs @@ -103,6 +103,7 @@ impl Tester { load_out_dirs_from_check: false, with_proc_macro_server: ProcMacroServerChoice::Sysroot, prefill_caches: false, + num_worker_threads: 1, proc_macro_processes: 1, }; let (db, _vfs, _proc_macro) = diff --git a/crates/rust-analyzer/src/cli/scip.rs b/crates/rust-analyzer/src/cli/scip.rs index ed0476697c..ef6d4399e6 100644 --- a/crates/rust-analyzer/src/cli/scip.rs +++ b/crates/rust-analyzer/src/cli/scip.rs @@ -52,6 +52,7 @@ impl flags::Scip { load_out_dirs_from_check: true, with_proc_macro_server: ProcMacroServerChoice::Sysroot, prefill_caches: true, + num_worker_threads: self.num_threads.unwrap_or_else(num_cpus::get_physical), proc_macro_processes: config.proc_macro_num_processes(), }; let cargo_config = config.cargo(None); diff --git a/crates/rust-analyzer/src/cli/ssr.rs b/crates/rust-analyzer/src/cli/ssr.rs index 5c69bda723..7b00aebbfc 100644 --- a/crates/rust-analyzer/src/cli/ssr.rs +++ b/crates/rust-analyzer/src/cli/ssr.rs @@ -20,6 +20,7 @@ impl flags::Ssr { load_out_dirs_from_check: true, with_proc_macro_server: ProcMacroServerChoice::Sysroot, prefill_caches: false, + num_worker_threads: 1, proc_macro_processes: 1, }; let (ref db, vfs, _proc_macro) = load_workspace_at( @@ -57,6 +58,7 @@ impl flags::Search { load_out_dirs_from_check: true, with_proc_macro_server: ProcMacroServerChoice::Sysroot, prefill_caches: false, + num_worker_threads: 1, proc_macro_processes: 1, }; let (ref db, _vfs, _proc_macro) = load_workspace_at( @@ -74,7 +76,7 @@ impl flags::Search { let sr = db.source_root(root).source_root(db); for file_id in sr.iter() { for debug_info in match_finder.debug_where_text_equal( - EditionedFileId::current_edition_guess_origin(db, file_id), + EditionedFileId::current_edition(db, file_id), debug_snippet, ) { println!("{debug_info:#?}"); diff --git a/crates/rust-analyzer/src/cli/unresolved_references.rs b/crates/rust-analyzer/src/cli/unresolved_references.rs index 49c6fcb91e..2d9b870f4d 100644 --- a/crates/rust-analyzer/src/cli/unresolved_references.rs +++ b/crates/rust-analyzer/src/cli/unresolved_references.rs @@ -44,6 +44,7 @@ impl flags::UnresolvedReferences { load_out_dirs_from_check: !self.disable_build_scripts, with_proc_macro_server, prefill_caches: false, + num_worker_threads: 1, proc_macro_processes: config.proc_macro_num_processes(), }; let (db, vfs, _proc_macro) = diff --git a/crates/rust-analyzer/src/config.rs b/crates/rust-analyzer/src/config.rs index 0dda7f3cc2..2ccd85f0e3 100644 --- a/crates/rust-analyzer/src/config.rs +++ b/crates/rust-analyzer/src/config.rs @@ -948,18 +948,18 @@ config_data! { /// Override the command used for bench runnables. /// The first element of the array should be the program to execute (for example, `cargo`). /// - /// Use the placeholders `${package}`, `${target_arg}`, `${target}`, `${test_name}` to dynamically + /// Use the placeholders `${package}`, `${target_arg}`, `${target}`, `${executable_args}` to dynamically /// replace the package name, target option (such as `--bin` or `--example`), the target name and - /// the test name (name of test function or test mod path). + /// the arguments passed to test binary args (includes `rust-analyzer.runnables.extraTestBinaryArgs`). runnables_bench_overrideCommand: Option<Vec<String>> = None, /// Command to be executed instead of 'cargo' for runnables. runnables_command: Option<String> = None, /// Override the command used for bench runnables. /// The first element of the array should be the program to execute (for example, `cargo`). /// - /// Use the placeholders `${package}`, `${target_arg}`, `${target}`, `${test_name}` to dynamically + /// Use the placeholders `${package}`, `${target_arg}`, `${target}`, `${executable_args}` to dynamically /// replace the package name, target option (such as `--bin` or `--example`), the target name and - /// the test name (name of test function or test mod path). + /// the arguments passed to test binary args (includes `rust-analyzer.runnables.extraTestBinaryArgs`). runnables_doctest_overrideCommand: Option<Vec<String>> = None, /// Additional arguments to be passed to cargo for runnables such as /// tests or binaries. For example, it may be `--release`. @@ -977,9 +977,9 @@ config_data! { /// Override the command used for test runnables. /// The first element of the array should be the program to execute (for example, `cargo`). /// - /// Use the placeholders `${package}`, `${target_arg}`, `${target}`, `${test_name}` to dynamically + /// Use the placeholders `${package}`, `${target_arg}`, `${target}`, `${executable_args}` to dynamically /// replace the package name, target option (such as `--bin` or `--example`), the target name and - /// the test name (name of test function or test mod path). + /// the arguments passed to test binary args (includes `rust-analyzer.runnables.extraTestBinaryArgs`). runnables_test_overrideCommand: Option<Vec<String>> = None, /// Path to the Cargo.toml of the rust compiler workspace, for usage in rustc_private diff --git a/crates/rust-analyzer/src/flycheck.rs b/crates/rust-analyzer/src/flycheck.rs index 512c231990..c41696bf3f 100644 --- a/crates/rust-analyzer/src/flycheck.rs +++ b/crates/rust-analyzer/src/flycheck.rs @@ -22,7 +22,6 @@ use serde_derive::Deserialize; pub(crate) use cargo_metadata::diagnostic::{ Applicability, Diagnostic, DiagnosticCode, DiagnosticLevel, DiagnosticSpan, }; -use toolchain::DISPLAY_COMMAND_IGNORE_ENVS; use toolchain::Tool; use triomphe::Arc; @@ -144,6 +143,7 @@ impl FlycheckConfig { } impl fmt::Display for FlycheckConfig { + /// Show a shortened version of the check command. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { FlycheckConfig::Automatic { cargo_options, .. } => { @@ -153,12 +153,23 @@ impl fmt::Display for FlycheckConfig { // Don't show `my_custom_check --foo $saved_file` literally to the user, as it // looks like we've forgotten to substitute $saved_file. // + // `my_custom_check --foo /home/user/project/src/dir/foo.rs` is too verbose. + // // Instead, show `my_custom_check --foo ...`. The // actual path is often too long to be worth showing // in the IDE (e.g. in the VS Code status bar). let display_args = args .iter() - .map(|arg| if arg == SAVED_FILE_PLACEHOLDER_DOLLAR { "..." } else { arg }) + .map(|arg| { + if (arg == SAVED_FILE_PLACEHOLDER_DOLLAR) + || (arg == SAVED_FILE_INLINE) + || arg.ends_with(".rs") + { + "..." + } else { + arg + } + }) .collect::<Vec<_>>(); write!(f, "{command} {}", display_args.join(" ")) @@ -371,6 +382,7 @@ enum FlycheckCommandOrigin { ProjectJsonRunnable, } +#[derive(Debug)] enum StateChange { Restart { generation: DiagnosticsGeneration, @@ -403,24 +415,31 @@ struct FlycheckActor { /// doesn't provide a way to read sub-process output without blocking, so we /// have to wrap sub-processes output handling in a thread and pass messages /// back over a channel. - command_handle: Option<CommandHandle<CargoCheckMessage>>, + command_handle: Option<CommandHandle<CheckMessage>>, /// The receiver side of the channel mentioned above. - command_receiver: Option<Receiver<CargoCheckMessage>>, + command_receiver: Option<Receiver<CheckMessage>>, diagnostics_cleared_for: FxHashSet<PackageSpecifier>, diagnostics_received: DiagnosticsReceived, } -#[derive(PartialEq)] +#[derive(PartialEq, Debug)] enum DiagnosticsReceived { - Yes, - No, - YesAndClearedForAll, + /// We started a flycheck, but we haven't seen any diagnostics yet. + NotYet, + /// We received a non-zero number of diagnostics from rustc or clippy (via + /// cargo or custom check command). This means there were errors or + /// warnings. + AtLeastOne, + /// We received a non-zero number of diagnostics, and the scope is + /// workspace, so we've discarded the previous workspace diagnostics. + AtLeastOneAndClearedWorkspace, } #[allow(clippy::large_enum_variant)] +#[derive(Debug)] enum Event { RequestStateChange(StateChange), - CheckEvent(Option<CargoCheckMessage>), + CheckEvent(Option<CheckMessage>), } /// This is stable behaviour. Don't change. @@ -428,6 +447,7 @@ const SAVED_FILE_PLACEHOLDER_DOLLAR: &str = "$saved_file"; const LABEL_INLINE: &str = "{label}"; const SAVED_FILE_INLINE: &str = "{saved_file}"; +#[derive(Debug)] struct Substitutions<'a> { label: Option<&'a str>, saved_file: Option<&'a str>, @@ -511,7 +531,7 @@ impl FlycheckActor { command_handle: None, command_receiver: None, diagnostics_cleared_for: Default::default(), - diagnostics_received: DiagnosticsReceived::No, + diagnostics_received: DiagnosticsReceived::NotYet, } } @@ -539,17 +559,37 @@ impl FlycheckActor { self.cancel_check_process(); } Event::RequestStateChange(StateChange::Restart { - generation, - scope, - saved_file, - target, + mut generation, + mut scope, + mut saved_file, + mut target, }) => { // Cancel the previously spawned process self.cancel_check_process(); + + // Debounce by briefly waiting for other state changes. while let Ok(restart) = inbox.recv_timeout(Duration::from_millis(50)) { - // restart chained with a stop, so just cancel - if let StateChange::Cancel = restart { - continue 'event; + match restart { + StateChange::Cancel => { + // We got a cancel straight after this restart request, so + // don't do anything. + continue 'event; + } + StateChange::Restart { + generation: g, + scope: s, + saved_file: sf, + target: t, + } => { + // We got another restart request. Take the parameters + // from the last restart request in this time window, + // because the most recent request is probably the most + // relevant to the user. + generation = g; + scope = s; + saved_file = sf; + target = t; + } } } @@ -563,23 +603,13 @@ impl FlycheckActor { }; let debug_command = format!("{command:?}"); - let user_facing_command = match origin { - // Don't show all the --format=json-with-blah-blah args, just the simple - // version - FlycheckCommandOrigin::Cargo => self.config.to_string(), - // show them the full command but pretty printed. advanced user - FlycheckCommandOrigin::ProjectJsonRunnable - | FlycheckCommandOrigin::CheckOverrideCommand => display_command( - &command, - Some(std::path::Path::new(self.root.as_path())), - ), - }; + let user_facing_command = self.config.to_string(); tracing::debug!(?origin, ?command, "will restart flycheck"); let (sender, receiver) = unbounded(); match CommandHandle::spawn( command, - CargoCheckParser, + CheckParser, sender, match &self.config { FlycheckConfig::Automatic { cargo_options, .. } => { @@ -640,7 +670,7 @@ impl FlycheckActor { error ); } - if self.diagnostics_received == DiagnosticsReceived::No { + if self.diagnostics_received == DiagnosticsReceived::NotYet { tracing::trace!(flycheck_id = self.id, "clearing diagnostics"); // We finished without receiving any diagnostics. // Clear everything for good measure @@ -699,7 +729,7 @@ impl FlycheckActor { self.report_progress(Progress::DidFinish(res)); } Event::CheckEvent(Some(message)) => match message { - CargoCheckMessage::CompilerArtifact(msg) => { + CheckMessage::CompilerArtifact(msg) => { tracing::trace!( flycheck_id = self.id, artifact = msg.target.name, @@ -729,15 +759,15 @@ impl FlycheckActor { }); } } - CargoCheckMessage::Diagnostic { diagnostic, package_id } => { + CheckMessage::Diagnostic { diagnostic, package_id } => { tracing::trace!( flycheck_id = self.id, message = diagnostic.message, package_id = package_id.as_ref().map(|it| it.as_str()), "diagnostic received" ); - if self.diagnostics_received == DiagnosticsReceived::No { - self.diagnostics_received = DiagnosticsReceived::Yes; + if self.diagnostics_received == DiagnosticsReceived::NotYet { + self.diagnostics_received = DiagnosticsReceived::AtLeastOne; } if let Some(package_id) = &package_id { if self.diagnostics_cleared_for.insert(package_id.clone()) { @@ -754,9 +784,10 @@ impl FlycheckActor { }); } } else if self.diagnostics_received - != DiagnosticsReceived::YesAndClearedForAll + != DiagnosticsReceived::AtLeastOneAndClearedWorkspace { - self.diagnostics_received = DiagnosticsReceived::YesAndClearedForAll; + self.diagnostics_received = + DiagnosticsReceived::AtLeastOneAndClearedWorkspace; self.send(FlycheckMessage::ClearDiagnostics { id: self.id, kind: ClearDiagnosticsKind::All(ClearScope::Workspace), @@ -792,7 +823,7 @@ impl FlycheckActor { fn clear_diagnostics_state(&mut self) { self.diagnostics_cleared_for.clear(); - self.diagnostics_received = DiagnosticsReceived::No; + self.diagnostics_received = DiagnosticsReceived::NotYet; } fn explicit_check_command( @@ -942,15 +973,19 @@ impl FlycheckActor { } #[allow(clippy::large_enum_variant)] -enum CargoCheckMessage { +#[derive(Debug)] +enum CheckMessage { + /// A message from `cargo check`, including details like the path + /// to the relevant `Cargo.toml`. CompilerArtifact(cargo_metadata::Artifact), + /// A diagnostic message from rustc itself. Diagnostic { diagnostic: Diagnostic, package_id: Option<PackageSpecifier> }, } -struct CargoCheckParser; +struct CheckParser; -impl JsonLinesParser<CargoCheckMessage> for CargoCheckParser { - fn from_line(&self, line: &str, error: &mut String) -> Option<CargoCheckMessage> { +impl JsonLinesParser<CheckMessage> for CheckParser { + fn from_line(&self, line: &str, error: &mut String) -> Option<CheckMessage> { let mut deserializer = serde_json::Deserializer::from_str(line); deserializer.disable_recursion_limit(); if let Ok(message) = JsonMessage::deserialize(&mut deserializer) { @@ -958,10 +993,10 @@ impl JsonLinesParser<CargoCheckMessage> for CargoCheckParser { // Skip certain kinds of messages to only spend time on what's useful JsonMessage::Cargo(message) => match message { cargo_metadata::Message::CompilerArtifact(artifact) if !artifact.fresh => { - Some(CargoCheckMessage::CompilerArtifact(artifact)) + Some(CheckMessage::CompilerArtifact(artifact)) } cargo_metadata::Message::CompilerMessage(msg) => { - Some(CargoCheckMessage::Diagnostic { + Some(CheckMessage::Diagnostic { diagnostic: msg.message, package_id: Some(PackageSpecifier::Cargo { package_id: Arc::new(msg.package_id), @@ -971,7 +1006,7 @@ impl JsonLinesParser<CargoCheckMessage> for CargoCheckParser { _ => None, }, JsonMessage::Rustc(message) => { - Some(CargoCheckMessage::Diagnostic { diagnostic: message, package_id: None }) + Some(CheckMessage::Diagnostic { diagnostic: message, package_id: None }) } }; } @@ -981,76 +1016,26 @@ impl JsonLinesParser<CargoCheckMessage> for CargoCheckParser { None } - fn from_eof(&self) -> Option<CargoCheckMessage> { + fn from_eof(&self) -> Option<CheckMessage> { None } } -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] #[serde(untagged)] enum JsonMessage { Cargo(cargo_metadata::Message), Rustc(Diagnostic), } -/// Not good enough to execute in a shell, but good enough to show the user without all the noisy -/// quotes -/// -/// Pass implicit_cwd if there is one regarded as the obvious by the user, so we can skip showing it. -/// Compactness is the aim of the game, the output typically gets truncated quite a lot. -fn display_command(c: &Command, implicit_cwd: Option<&std::path::Path>) -> String { - let mut o = String::new(); - use std::fmt::Write; - let lossy = std::ffi::OsStr::to_string_lossy; - if let Some(dir) = c.get_current_dir() { - if Some(dir) == implicit_cwd.map(std::path::Path::new) { - // pass - } else if dir.to_string_lossy().contains(" ") { - write!(o, "cd {:?} && ", dir).unwrap(); - } else { - write!(o, "cd {} && ", dir.display()).unwrap(); - } - } - for (env, val) in c.get_envs() { - let (env, val) = (lossy(env), val.map(lossy).unwrap_or(std::borrow::Cow::Borrowed(""))); - if DISPLAY_COMMAND_IGNORE_ENVS.contains(&env.as_ref()) { - continue; - } - if env.contains(" ") { - write!(o, "\"{}={}\" ", env, val).unwrap(); - } else if val.contains(" ") { - write!(o, "{}=\"{}\" ", env, val).unwrap(); - } else { - write!(o, "{}={} ", env, val).unwrap(); - } - } - let prog = lossy(c.get_program()); - if prog.contains(" ") { - write!(o, "{:?}", prog).unwrap(); - } else { - write!(o, "{}", prog).unwrap(); - } - for arg in c.get_args() { - let arg = lossy(arg); - if arg.contains(" ") { - write!(o, " \"{}\"", arg).unwrap(); - } else { - write!(o, " {}", arg).unwrap(); - } - } - o -} - #[cfg(test)] mod tests { + use super::*; use ide_db::FxHashMap; use itertools::Itertools; use paths::Utf8Path; use project_model::project_json; - use crate::flycheck::Substitutions; - use crate::flycheck::display_command; - #[test] fn test_substitutions() { let label = ":label"; @@ -1139,34 +1124,47 @@ mod tests { } #[test] - fn test_display_command() { - use std::path::Path; - let workdir = Path::new("workdir"); - let mut cmd = toolchain::command("command", workdir, &FxHashMap::default()); - assert_eq!(display_command(cmd.arg("--arg"), Some(workdir)), "command --arg"); - assert_eq!( - display_command(cmd.arg("spaced arg"), Some(workdir)), - "command --arg \"spaced arg\"" - ); - assert_eq!( - display_command(cmd.env("ENVIRON", "yeah"), Some(workdir)), - "ENVIRON=yeah command --arg \"spaced arg\"" - ); - assert_eq!( - display_command(cmd.env("OTHER", "spaced env"), Some(workdir)), - "ENVIRON=yeah OTHER=\"spaced env\" command --arg \"spaced arg\"" - ); - assert_eq!( - display_command(cmd.current_dir("/tmp"), Some(workdir)), - "cd /tmp && ENVIRON=yeah OTHER=\"spaced env\" command --arg \"spaced arg\"" - ); - assert_eq!( - display_command(cmd.current_dir("/tmp and/thing"), Some(workdir)), - "cd \"/tmp and/thing\" && ENVIRON=yeah OTHER=\"spaced env\" command --arg \"spaced arg\"" - ); - assert_eq!( - display_command(cmd.current_dir("/tmp and/thing"), Some(Path::new("/tmp and/thing"))), - "ENVIRON=yeah OTHER=\"spaced env\" command --arg \"spaced arg\"" - ); + fn test_flycheck_config_display() { + let clippy = FlycheckConfig::Automatic { + cargo_options: CargoOptions { + subcommand: "clippy".to_owned(), + target_tuples: vec![], + all_targets: false, + set_test: false, + no_default_features: false, + all_features: false, + features: vec![], + extra_args: vec![], + extra_test_bin_args: vec![], + extra_env: FxHashMap::default(), + target_dir_config: TargetDirectoryConfig::default(), + }, + ansi_color_output: true, + }; + assert_eq!(clippy.to_string(), "cargo clippy"); + + let custom_dollar = FlycheckConfig::CustomCommand { + command: "check".to_owned(), + args: vec!["--input".to_owned(), "$saved_file".to_owned()], + extra_env: FxHashMap::default(), + invocation_strategy: InvocationStrategy::Once, + }; + assert_eq!(custom_dollar.to_string(), "check --input ..."); + + let custom_inline = FlycheckConfig::CustomCommand { + command: "check".to_owned(), + args: vec!["--input".to_owned(), "{saved_file}".to_owned()], + extra_env: FxHashMap::default(), + invocation_strategy: InvocationStrategy::Once, + }; + assert_eq!(custom_inline.to_string(), "check --input ..."); + + let custom_rs = FlycheckConfig::CustomCommand { + command: "check".to_owned(), + args: vec!["--input".to_owned(), "/path/to/file.rs".to_owned()], + extra_env: FxHashMap::default(), + invocation_strategy: InvocationStrategy::Once, + }; + assert_eq!(custom_rs.to_string(), "check --input ..."); } } diff --git a/crates/rust-analyzer/src/handlers/dispatch.rs b/crates/rust-analyzer/src/handlers/dispatch.rs index 90deae2d90..67bd643fce 100644 --- a/crates/rust-analyzer/src/handlers/dispatch.rs +++ b/crates/rust-analyzer/src/handlers/dispatch.rs @@ -414,7 +414,8 @@ impl NotificationDispatcher<'_> { let params = match not.extract::<N::Params>(N::METHOD) { Ok(it) => it, Err(ExtractError::JsonError { method, error }) => { - panic!("Invalid request\nMethod: {method}\n error: {error}",) + tracing::error!(method = %method, error = %error, "invalid notification"); + return self; } Err(ExtractError::MethodMismatch(not)) => { self.not = Some(not); diff --git a/crates/rust-analyzer/src/handlers/notification.rs b/crates/rust-analyzer/src/handlers/notification.rs index 138310b78f..09b6794e4f 100644 --- a/crates/rust-analyzer/src/handlers/notification.rs +++ b/crates/rust-analyzer/src/handlers/notification.rs @@ -295,8 +295,9 @@ pub(crate) fn handle_did_change_watched_files( for change in params.changes.iter().unique_by(|&it| &it.uri) { if let Ok(path) = from_proto::abs_path(&change.uri) { if !trigger_flycheck { + // Trigger if no workspaces contain this file. trigger_flycheck = - state.config.workspace_roots().iter().any(|root| !path.starts_with(root)); + state.config.workspace_roots().iter().all(|root| !path.starts_with(root)); } state.loader.handle.invalidate(path); } diff --git a/crates/rust-analyzer/src/integrated_benchmarks.rs b/crates/rust-analyzer/src/integrated_benchmarks.rs index d16ca2fb48..6a74b8a54d 100644 --- a/crates/rust-analyzer/src/integrated_benchmarks.rs +++ b/crates/rust-analyzer/src/integrated_benchmarks.rs @@ -53,6 +53,7 @@ fn integrated_highlighting_benchmark() { load_out_dirs_from_check: true, with_proc_macro_server: ProcMacroServerChoice::Sysroot, prefill_caches: false, + num_worker_threads: 1, proc_macro_processes: 1, }; @@ -122,6 +123,7 @@ fn integrated_completion_benchmark() { load_out_dirs_from_check: true, with_proc_macro_server: ProcMacroServerChoice::Sysroot, prefill_caches: true, + num_worker_threads: 1, proc_macro_processes: 1, }; @@ -324,6 +326,7 @@ fn integrated_diagnostics_benchmark() { load_out_dirs_from_check: true, with_proc_macro_server: ProcMacroServerChoice::Sysroot, prefill_caches: true, + num_worker_threads: 1, proc_macro_processes: 1, }; diff --git a/crates/rust-analyzer/src/main_loop.rs b/crates/rust-analyzer/src/main_loop.rs index 64decc9e0d..7c494de6f7 100644 --- a/crates/rust-analyzer/src/main_loop.rs +++ b/crates/rust-analyzer/src/main_loop.rs @@ -92,7 +92,7 @@ impl fmt::Display for Event { Event::DeferredTask(_) => write!(f, "Event::DeferredTask"), Event::TestResult(_) => write!(f, "Event::TestResult"), Event::DiscoverProject(_) => write!(f, "Event::DiscoverProject"), - Event::FetchWorkspaces(_) => write!(f, "Event::SwitchWorkspaces"), + Event::FetchWorkspaces(_) => write!(f, "Event::FetchWorkspaces"), } } } @@ -1178,6 +1178,8 @@ impl GlobalState { } => self.diagnostics.clear_check_older_than_for_package(id, package_id, generation), FlycheckMessage::Progress { id, progress } => { let format_with_id = |user_facing_command: String| { + // When we're running multiple flychecks, we have to include a disambiguator in + // the title, or the editor complains. Note that this is a user-facing string. if self.flycheck.len() == 1 { user_facing_command } else { diff --git a/crates/rust-analyzer/src/target_spec.rs b/crates/rust-analyzer/src/target_spec.rs index b8d9acc02a..8be061cacf 100644 --- a/crates/rust-analyzer/src/target_spec.rs +++ b/crates/rust-analyzer/src/target_spec.rs @@ -6,7 +6,7 @@ use cargo_metadata::PackageId; use cfg::{CfgAtom, CfgExpr}; use hir::sym; use ide::{Cancellable, Crate, FileId, RunnableKind, TestId}; -use project_model::project_json::Runnable; +use project_model::project_json::{self, Runnable}; use project_model::{CargoFeatures, ManifestPath, TargetKind}; use rustc_hash::FxHashSet; use triomphe::Arc; @@ -72,48 +72,51 @@ pub(crate) struct ProjectJsonTargetSpec { } impl ProjectJsonTargetSpec { + fn find_replace_runnable( + &self, + kind: project_json::RunnableKind, + replacer: &dyn Fn(&Self, &str) -> String, + ) -> Option<Runnable> { + for runnable in &self.shell_runnables { + if runnable.kind == kind { + let mut runnable = runnable.clone(); + + let replaced_args: Vec<_> = + runnable.args.iter().map(|arg| replacer(self, arg)).collect(); + runnable.args = replaced_args; + + return Some(runnable); + } + } + + None + } + pub(crate) fn runnable_args(&self, kind: &RunnableKind) -> Option<Runnable> { match kind { - RunnableKind::Bin => { - for runnable in &self.shell_runnables { - if matches!(runnable.kind, project_model::project_json::RunnableKind::Run) { - let mut runnable = runnable.clone(); - - let replaced_args: Vec<_> = runnable - .args - .iter() - .map(|arg| arg.replace("{label}", &self.label)) - .collect(); - runnable.args = replaced_args; - - return Some(runnable); - } - } - - None - } + RunnableKind::Bin => self + .find_replace_runnable(project_json::RunnableKind::Run, &|this, arg| { + arg.replace("{label}", &this.label) + }), RunnableKind::Test { test_id, .. } => { - for runnable in &self.shell_runnables { - if matches!(runnable.kind, project_model::project_json::RunnableKind::TestOne) { - let mut runnable = runnable.clone(); - - let replaced_args: Vec<_> = runnable - .args - .iter() - .map(|arg| arg.replace("{test_id}", &test_id.to_string())) - .map(|arg| arg.replace("{label}", &self.label)) - .collect(); - runnable.args = replaced_args; - - return Some(runnable); - } - } - - None + self.find_replace_runnable(project_json::RunnableKind::Run, &|this, arg| { + arg.replace("{label}", &this.label).replace("{test_id}", &test_id.to_string()) + }) + } + RunnableKind::TestMod { path } => self + .find_replace_runnable(project_json::RunnableKind::TestMod, &|this, arg| { + arg.replace("{label}", &this.label).replace("{test_pattern}", path) + }), + RunnableKind::Bench { test_id } => { + self.find_replace_runnable(project_json::RunnableKind::BenchOne, &|this, arg| { + arg.replace("{label}", &this.label).replace("{bench_id}", &test_id.to_string()) + }) + } + RunnableKind::DocTest { test_id } => { + self.find_replace_runnable(project_json::RunnableKind::DocTestOne, &|this, arg| { + arg.replace("{label}", &this.label).replace("{test_id}", &test_id.to_string()) + }) } - RunnableKind::TestMod { .. } => None, - RunnableKind::Bench { .. } => None, - RunnableKind::DocTest { .. } => None, } } } @@ -129,38 +132,21 @@ impl CargoTargetSpec { let extra_test_binary_args = config.extra_test_binary_args; let mut cargo_args = Vec::new(); - let mut executable_args = Vec::new(); + let executable_args = Self::executable_args_for(kind, extra_test_binary_args); match kind { - RunnableKind::Test { test_id, attr } => { + RunnableKind::Test { .. } => { cargo_args.push(config.test_command); - executable_args.push(test_id.to_string()); - if let TestId::Path(_) = test_id { - executable_args.push("--exact".to_owned()); - } - executable_args.extend(extra_test_binary_args); - if attr.ignore { - executable_args.push("--ignored".to_owned()); - } } - RunnableKind::TestMod { path } => { + RunnableKind::TestMod { .. } => { cargo_args.push(config.test_command); - executable_args.push(path.clone()); - executable_args.extend(extra_test_binary_args); } - RunnableKind::Bench { test_id } => { + RunnableKind::Bench { .. } => { cargo_args.push(config.bench_command); - executable_args.push(test_id.to_string()); - if let TestId::Path(_) = test_id { - executable_args.push("--exact".to_owned()); - } - executable_args.extend(extra_test_binary_args); } - RunnableKind::DocTest { test_id } => { + RunnableKind::DocTest { .. } => { cargo_args.push("test".to_owned()); cargo_args.push("--doc".to_owned()); - executable_args.push(test_id.to_string()); - executable_args.extend(extra_test_binary_args); } RunnableKind::Bin => { let subcommand = match spec { @@ -253,16 +239,70 @@ impl CargoTargetSpec { TargetKind::BuildScript | TargetKind::Other => "", }; + let target = |kind, target| match kind { + TargetKind::Bin | TargetKind::Test | TargetKind::Bench | TargetKind::Example => target, + _ => "", + }; + let replace_placeholders = |arg: String| match &spec { Some(spec) => arg .replace("${package}", &spec.package) .replace("${target_arg}", target_arg(spec.target_kind)) - .replace("${target}", &spec.target) + .replace("${target}", target(spec.target_kind, &spec.target)) .replace("${test_name}", &test_name), _ => arg, }; - args.map(|args| args.into_iter().map(replace_placeholders).collect()) + let extra_test_binary_args = config.extra_test_binary_args; + let executable_args = Self::executable_args_for(kind, extra_test_binary_args); + + args.map(|mut args| { + let exec_args_idx = args.iter().position(|a| a == "${executable_args}"); + + if let Some(idx) = exec_args_idx { + args.splice(idx..idx + 1, executable_args); + } + + args.into_iter().map(replace_placeholders).filter(|a| !a.trim().is_empty()).collect() + }) + } + + fn executable_args_for( + kind: &RunnableKind, + extra_test_binary_args: impl IntoIterator<Item = String>, + ) -> Vec<String> { + let mut executable_args = Vec::new(); + + match kind { + RunnableKind::Test { test_id, attr } => { + executable_args.push(test_id.to_string()); + if let TestId::Path(_) = test_id { + executable_args.push("--exact".to_owned()); + } + executable_args.extend(extra_test_binary_args); + if attr.ignore { + executable_args.push("--ignored".to_owned()); + } + } + RunnableKind::TestMod { path } => { + executable_args.push(path.clone()); + executable_args.extend(extra_test_binary_args); + } + RunnableKind::Bench { test_id } => { + executable_args.push(test_id.to_string()); + if let TestId::Path(_) = test_id { + executable_args.push("--exact".to_owned()); + } + executable_args.extend(extra_test_binary_args); + } + RunnableKind::DocTest { test_id } => { + executable_args.push(test_id.to_string()); + executable_args.extend(extra_test_binary_args); + } + RunnableKind::Bin => {} + } + + executable_args } pub(crate) fn push_to(self, buf: &mut Vec<String>, kind: &RunnableKind) { diff --git a/crates/rust-analyzer/src/tracing/config.rs b/crates/rust-analyzer/src/tracing/config.rs index ca897aeb3e..2bc9f3c34a 100644 --- a/crates/rust-analyzer/src/tracing/config.rs +++ b/crates/rust-analyzer/src/tracing/config.rs @@ -1,7 +1,7 @@ //! Simple logger that logs either to stderr or to a file, using `tracing_subscriber` //! filter syntax and `tracing_appender` for non blocking output. -use std::io::{self}; +use std::io; use anyhow::Context; use tracing::level_filters::LevelFilter; diff --git a/crates/rust-analyzer/tests/slow-tests/flycheck.rs b/crates/rust-analyzer/tests/slow-tests/flycheck.rs new file mode 100644 index 0000000000..c1d53fb33a --- /dev/null +++ b/crates/rust-analyzer/tests/slow-tests/flycheck.rs @@ -0,0 +1,112 @@ +use test_utils::skip_slow_tests; + +use crate::support::Project; + +#[test] +fn test_flycheck_diagnostics_for_unused_variable() { + if skip_slow_tests() { + return; + } + + let server = Project::with_fixture( + r#" +//- /Cargo.toml +[package] +name = "foo" +version = "0.0.0" + +//- /src/main.rs +fn main() { + let x = 1; +} +"#, + ) + .with_config(serde_json::json!({ + "checkOnSave": true, + })) + .server() + .wait_until_workspace_is_loaded(); + + let diagnostics = server.wait_for_diagnostics(); + assert!( + diagnostics.diagnostics.iter().any(|d| d.message.contains("unused variable")), + "expected unused variable diagnostic, got: {:?}", + diagnostics.diagnostics, + ); +} + +#[test] +fn test_flycheck_diagnostic_cleared_after_fix() { + if skip_slow_tests() { + return; + } + + let server = Project::with_fixture( + r#" +//- /Cargo.toml +[package] +name = "foo" +version = "0.0.0" + +//- /src/main.rs +fn main() { + let x = 1; +} +"#, + ) + .with_config(serde_json::json!({ + "checkOnSave": true, + })) + .server() + .wait_until_workspace_is_loaded(); + + // Wait for the unused variable diagnostic to appear. + let diagnostics = server.wait_for_diagnostics(); + assert!( + diagnostics.diagnostics.iter().any(|d| d.message.contains("unused variable")), + "expected unused variable diagnostic, got: {:?}", + diagnostics.diagnostics, + ); + + // Fix the code by removing the unused variable. + server.write_file_and_save("src/main.rs", "fn main() {}\n".to_owned()); + + // Wait for diagnostics to be cleared. + server.wait_for_diagnostics_cleared(); +} + +#[test] +fn test_flycheck_diagnostic_with_override_command() { + if skip_slow_tests() { + return; + } + + let server = Project::with_fixture( + r#" +//- /Cargo.toml +[package] +name = "foo" +version = "0.0.0" + +//- /src/main.rs +fn main() {} +"#, + ) + .with_config(serde_json::json!({ + "checkOnSave": true, + "check": { + "overrideCommand": ["rustc", "--error-format=json", "$saved_file"] + } + })) + .server() + .wait_until_workspace_is_loaded(); + + server.write_file_and_save("src/main.rs", "fn main() {\n let x = 1;\n}\n".to_owned()); + + let diagnostics = server.wait_for_diagnostics(); + assert!( + diagnostics.diagnostics.iter().any(|d| d.message.contains("unused variable")), + "expected unused variable diagnostic, got: {:?}", + diagnostics.diagnostics, + ); +} diff --git a/crates/rust-analyzer/tests/slow-tests/main.rs b/crates/rust-analyzer/tests/slow-tests/main.rs index b4a7b44d16..fcdc8bb7cd 100644 --- a/crates/rust-analyzer/tests/slow-tests/main.rs +++ b/crates/rust-analyzer/tests/slow-tests/main.rs @@ -15,12 +15,14 @@ extern crate rustc_driver as _; mod cli; +mod flycheck; mod ratoml; mod support; mod testdir; use std::{collections::HashMap, path::PathBuf, time::Instant}; +use ide_db::FxHashMap; use lsp_types::{ CodeActionContext, CodeActionParams, CompletionParams, DidOpenTextDocumentParams, DocumentFormattingParams, DocumentRangeFormattingParams, FileRename, FormattingOptions, @@ -672,6 +674,17 @@ fn test_format_document_range() { return; } + // This test requires a nightly toolchain, so skip if it's not available. + let cwd = std::env::current_dir().unwrap_or_default(); + let has_nightly_rustfmt = toolchain::command("rustfmt", cwd, &FxHashMap::default()) + .args(["+nightly", "--version"]) + .output() + .is_ok_and(|out| out.status.success()); + if !has_nightly_rustfmt { + tracing::warn!("skipping test_format_document_range: nightly rustfmt not available"); + return; + } + let server = Project::with_fixture( r#" //- /Cargo.toml diff --git a/crates/rust-analyzer/tests/slow-tests/support.rs b/crates/rust-analyzer/tests/slow-tests/support.rs index 195ad226ae..7ee31f3d53 100644 --- a/crates/rust-analyzer/tests/slow-tests/support.rs +++ b/crates/rust-analyzer/tests/slow-tests/support.rs @@ -8,7 +8,9 @@ use std::{ use crossbeam_channel::{Receiver, after, select}; use itertools::Itertools; use lsp_server::{Connection, Message, Notification, Request}; -use lsp_types::{TextDocumentIdentifier, Url, notification::Exit, request::Shutdown}; +use lsp_types::{ + PublishDiagnosticsParams, TextDocumentIdentifier, Url, notification::Exit, request::Shutdown, +}; use parking_lot::{Mutex, MutexGuard}; use paths::{Utf8Path, Utf8PathBuf}; use rust_analyzer::{ @@ -407,6 +409,53 @@ impl Server { .unwrap_or_else(|Timeout| panic!("timeout while waiting for ws to load")); self } + pub(crate) fn wait_for_diagnostics(&self) -> PublishDiagnosticsParams { + for msg in self.messages.borrow().iter() { + if let Message::Notification(n) = msg + && n.method == "textDocument/publishDiagnostics" + { + let params: PublishDiagnosticsParams = + serde_json::from_value(n.params.clone()).unwrap(); + if !params.diagnostics.is_empty() { + return params; + } + } + } + loop { + let msg = self + .recv() + .unwrap_or_else(|Timeout| panic!("timeout while waiting for diagnostics")) + .expect("connection closed while waiting for diagnostics"); + if let Message::Notification(n) = &msg + && n.method == "textDocument/publishDiagnostics" + { + let params: PublishDiagnosticsParams = + serde_json::from_value(n.params.clone()).unwrap(); + if !params.diagnostics.is_empty() { + return params; + } + } + } + } + + pub(crate) fn wait_for_diagnostics_cleared(&self) { + loop { + let msg = self + .recv() + .unwrap_or_else(|Timeout| panic!("timeout while waiting for diagnostics to clear")) + .expect("connection closed while waiting for diagnostics to clear"); + if let Message::Notification(n) = &msg + && n.method == "textDocument/publishDiagnostics" + { + let params: PublishDiagnosticsParams = + serde_json::from_value(n.params.clone()).unwrap(); + if params.diagnostics.is_empty() { + return; + } + } + } + } + fn wait_for_message_cond( &self, n: usize, diff --git a/crates/span/src/ast_id.rs b/crates/span/src/ast_id.rs index 599b3c7175..f6500a9b4d 100644 --- a/crates/span/src/ast_id.rs +++ b/crates/span/src/ast_id.rs @@ -88,7 +88,6 @@ impl fmt::Debug for ErasedFileAstId { Module, Static, Trait, - TraitAlias, Variant, Const, Fn, @@ -129,7 +128,6 @@ enum ErasedFileAstIdKind { Module, Static, Trait, - TraitAlias, // Until here associated with `ErasedHasNameFileAstId`. // The following are associated with `ErasedAssocItemFileAstId`. Variant, @@ -208,6 +206,11 @@ impl ErasedFileAstId { self.0 >> (HASH_BITS + INDEX_BITS) } + #[inline] + pub fn is_root(self) -> bool { + self.kind() == ErasedFileAstIdKind::Root as u32 + } + fn ast_id_for( node: &SyntaxNode, index_map: &mut ErasedAstIdNextIndexMap, @@ -222,14 +225,16 @@ impl ErasedFileAstId { .or_else(|| asm_expr_ast_id(node, index_map)) } - fn should_alloc(node: &SyntaxNode) -> bool { + fn should_alloc(node: &SyntaxNode) -> Option<ErasedFileAstIdKind> { let kind = node.kind(); should_alloc_has_name(kind) - || should_alloc_assoc_item(kind) - || ast::ExternBlock::can_cast(kind) - || ast::Use::can_cast(kind) - || ast::Impl::can_cast(kind) - || ast::AsmExpr::can_cast(kind) + .or_else(|| should_alloc_assoc_item(kind)) + .or_else(|| { + ast::ExternBlock::can_cast(kind).then_some(ErasedFileAstIdKind::ExternBlock) + }) + .or_else(|| ast::Use::can_cast(kind).then_some(ErasedFileAstIdKind::Use)) + .or_else(|| ast::Impl::can_cast(kind).then_some(ErasedFileAstIdKind::Impl)) + .or_else(|| ast::AsmExpr::can_cast(kind).then_some(ErasedFileAstIdKind::AsmExpr)) } #[inline] @@ -480,8 +485,8 @@ macro_rules! register_has_name_ast_id { } } - fn should_alloc_has_name(kind: SyntaxKind) -> bool { - false $( || ast::$ident::can_cast(kind) )* + fn should_alloc_has_name(kind: SyntaxKind) -> Option<ErasedFileAstIdKind> { + $( if ast::$ident::can_cast(kind) { Some(ErasedFileAstIdKind::$ident) } else )* { None } } }; } @@ -530,8 +535,8 @@ macro_rules! register_assoc_item_ast_id { } } - fn should_alloc_assoc_item(kind: SyntaxKind) -> bool { - false $( || ast::$ident::can_cast(kind) )* + fn should_alloc_assoc_item(kind: SyntaxKind) -> Option<ErasedFileAstIdKind> { + $( if ast::$ident::can_cast(kind) { Some(ErasedFileAstIdKind::$ident) } else )* { None } } }; } @@ -614,22 +619,49 @@ impl AstIdMap { syntax::WalkEvent::Enter(node) => { if ast::BlockExpr::can_cast(node.kind()) { blocks.push((node, ContainsItems::No)); - } else if ErasedFileAstId::should_alloc(&node) { + } else if let Some(kind) = ErasedFileAstId::should_alloc(&node) { // Allocate blocks on-demand, only if they have items. // We don't associate items with blocks, only with items, since block IDs can be quite unstable. // FIXME: Is this the correct thing to do? Macro calls might actually be more incremental if // associated with blocks (not sure). Either way it's not a big deal. + let is_item = matches!( + kind, + ErasedFileAstIdKind::Enum + | ErasedFileAstIdKind::Struct + | ErasedFileAstIdKind::Union + | ErasedFileAstIdKind::ExternCrate + | ErasedFileAstIdKind::MacroDef + | ErasedFileAstIdKind::MacroRules + | ErasedFileAstIdKind::Module + | ErasedFileAstIdKind::Static + | ErasedFileAstIdKind::Trait + | ErasedFileAstIdKind::Const + | ErasedFileAstIdKind::Fn + | ErasedFileAstIdKind::TypeAlias + | ErasedFileAstIdKind::ExternBlock + | ErasedFileAstIdKind::Use + | ErasedFileAstIdKind::Impl + ); if let Some(( last_block_node, already_allocated @ ContainsItems::No, )) = blocks.last_mut() + && (is_item + || (kind == ErasedFileAstIdKind::MacroCall && { + let mut anc = node.ancestors(); + _ = anc.next(); + anc.next().is_some_and(|it| { + it.kind() == SyntaxKind::MACRO_EXPR + }) && anc.next().is_some_and(|it| { + it.kind() == SyntaxKind::EXPR_STMT + || it.kind() == SyntaxKind::STMT_LIST + }) + })) { - let block_ast_id = block_expr_ast_id( - last_block_node, - &mut index_map, - parent_of(parent_idx, &res), - ) - .expect("not a BlockExpr"); + let parent = parent_of(parent_idx, &res); + let block_ast_id = + block_expr_ast_id(last_block_node, &mut index_map, parent) + .expect("not a BlockExpr"); res.arena .alloc((SyntaxNodePtr::new(last_block_node), block_ast_id)); *already_allocated = ContainsItems::Yes; @@ -647,8 +679,9 @@ impl AstIdMap { } syntax::WalkEvent::Leave(node) => { if ast::BlockExpr::can_cast(node.kind()) { - assert_eq!( - blocks.pop().map(|it| it.0), + let block = blocks.pop(); + debug_assert_eq!( + block.map(|it| it.0), Some(node), "left a BlockExpr we never entered" ); diff --git a/crates/stdx/src/lib.rs b/crates/stdx/src/lib.rs index 7ab26b1890..275e0e5ac8 100644 --- a/crates/stdx/src/lib.rs +++ b/crates/stdx/src/lib.rs @@ -1,5 +1,6 @@ //! Missing batteries for standard libraries. +use std::borrow::Cow; use std::io as sio; use std::process::Command; use std::{cmp::Ordering, ops, time::Instant}; @@ -221,12 +222,7 @@ pub fn trim_indent(mut text: &str) -> String { if text.starts_with('\n') { text = &text[1..]; } - let indent = text - .lines() - .filter(|it| !it.trim().is_empty()) - .map(|it| it.len() - it.trim_start().len()) - .min() - .unwrap_or(0); + let indent = indent_of(text); text.split_inclusive('\n') .map( |line| { @@ -236,6 +232,43 @@ pub fn trim_indent(mut text: &str) -> String { .collect() } +#[must_use] +fn indent_of(text: &str) -> usize { + text.lines() + .filter(|it| !it.trim().is_empty()) + .map(|it| it.len() - it.trim_start().len()) + .min() + .unwrap_or(0) +} + +#[must_use] +pub fn dedent_by(spaces: usize, text: &str) -> String { + text.split_inclusive('\n') + .map(|line| { + let trimmed = line.trim_start_matches(' '); + if line.len() - trimmed.len() <= spaces { trimmed } else { &line[spaces..] } + }) + .collect() +} + +/// Indent non empty lines, including the first line +#[must_use] +pub fn indent_string(s: &str, indent_level: u8) -> String { + if indent_level == 0 || s.is_empty() { + return s.to_owned(); + } + let indent_str = " ".repeat(indent_level as usize); + s.split_inclusive("\n") + .map(|line| { + if line.trim_end().is_empty() { + Cow::Borrowed(line) + } else { + format!("{indent_str}{line}").into() + } + }) + .collect() +} + pub fn equal_range_by<T, F>(slice: &[T], mut key: F) -> ops::Range<usize> where F: FnMut(&T) -> Ordering, @@ -367,6 +400,37 @@ mod tests { } #[test] + fn test_dedent() { + assert_eq!(dedent_by(0, ""), ""); + assert_eq!(dedent_by(1, ""), ""); + assert_eq!(dedent_by(2, ""), ""); + assert_eq!(dedent_by(0, "foo"), "foo"); + assert_eq!(dedent_by(2, "foo"), "foo"); + assert_eq!(dedent_by(2, " foo"), "foo"); + assert_eq!(dedent_by(2, " foo"), " foo"); + assert_eq!(dedent_by(2, " foo\nbar"), " foo\nbar"); + assert_eq!(dedent_by(2, "foo\n bar"), "foo\n bar"); + assert_eq!(dedent_by(2, "foo\n\n bar"), "foo\n\n bar"); + assert_eq!(dedent_by(2, "foo\n.\n bar"), "foo\n.\n bar"); + assert_eq!(dedent_by(2, "foo\n .\n bar"), "foo\n.\n bar"); + assert_eq!(dedent_by(2, "foo\n .\n bar"), "foo\n .\n bar"); + } + + #[test] + fn test_indent_of() { + assert_eq!(indent_of(""), 0); + assert_eq!(indent_of(" "), 0); + assert_eq!(indent_of(" x"), 1); + assert_eq!(indent_of(" x\n"), 1); + assert_eq!(indent_of(" x\ny"), 0); + assert_eq!(indent_of(" x\n y"), 1); + assert_eq!(indent_of(" x\n y"), 1); + assert_eq!(indent_of(" x\n y"), 2); + assert_eq!(indent_of(" x\n y\n"), 2); + assert_eq!(indent_of(" x\n\n y\n"), 2); + } + + #[test] fn test_replace() { #[track_caller] fn test_replace(src: &str, from: char, to: &str, expected: &str) { diff --git a/crates/syntax/fuzz/Cargo.toml b/crates/syntax/fuzz/Cargo.toml index b2f238efc0..41db3ddcc5 100644 --- a/crates/syntax/fuzz/Cargo.toml +++ b/crates/syntax/fuzz/Cargo.toml @@ -23,6 +23,3 @@ path = "fuzz_targets/parser.rs" [[bin]] name = "reparse" path = "fuzz_targets/reparse.rs" - -[lints] -workspace = true diff --git a/crates/syntax/rust.ungram b/crates/syntax/rust.ungram index 991fe7d83a..3113fc7430 100644 --- a/crates/syntax/rust.ungram +++ b/crates/syntax/rust.ungram @@ -245,7 +245,7 @@ RecordFieldList = RecordField = Attr* Visibility? 'unsafe'? - Name ':' Type ('=' Expr)? + Name ':' Type ('=' default_val:ConstArg)? TupleFieldList = '(' fields:(TupleField (',' TupleField)* ','?)? ')' @@ -268,7 +268,7 @@ VariantList = Variant = Attr* Visibility? - Name FieldList? ('=' Expr)? + Name FieldList? ('=' ConstArg)? Union = Attr* Visibility? @@ -472,8 +472,11 @@ RefExpr = TryExpr = Attr* Expr '?' +TryBlockModifier = + 'try' ('bikeshed' Type)? + BlockExpr = - Attr* Label? ('try' | 'unsafe' | ('async' 'move'?) | ('gen' 'move'?) | 'const') StmtList + Attr* Label? (TryBlockModifier | 'unsafe' | ('async' 'move'?) | ('gen' 'move'?) | 'const') StmtList PrefixExpr = Attr* op:('-' | '!' | '*') Expr diff --git a/crates/syntax/src/algo.rs b/crates/syntax/src/algo.rs index 3ab9c90262..c679921b3f 100644 --- a/crates/syntax/src/algo.rs +++ b/crates/syntax/src/algo.rs @@ -132,3 +132,19 @@ pub fn previous_non_trivia_token(e: impl Into<SyntaxElement>) -> Option<SyntaxTo } None } + +pub fn next_non_trivia_token(e: impl Into<SyntaxElement>) -> Option<SyntaxToken> { + let mut token = match e.into() { + SyntaxElement::Node(n) => n.last_token()?, + SyntaxElement::Token(t) => t, + } + .next_token(); + while let Some(inner) = token { + if !inner.kind().is_trivia() { + return Some(inner); + } else { + token = inner.next_token(); + } + } + None +} diff --git a/crates/syntax/src/ast/edit.rs b/crates/syntax/src/ast/edit.rs index 9b30642fe4..b706d7f722 100644 --- a/crates/syntax/src/ast/edit.rs +++ b/crates/syntax/src/ast/edit.rs @@ -43,8 +43,14 @@ impl ops::Add<u8> for IndentLevel { } } +impl ops::AddAssign<u8> for IndentLevel { + fn add_assign(&mut self, rhs: u8) { + self.0 += rhs; + } +} + impl IndentLevel { - pub fn single() -> IndentLevel { + pub fn zero() -> IndentLevel { IndentLevel(0) } pub fn is_zero(&self) -> bool { diff --git a/crates/syntax/src/ast/edit_in_place.rs b/crates/syntax/src/ast/edit_in_place.rs index 1cd8146f68..7f59ae4213 100644 --- a/crates/syntax/src/ast/edit_in_place.rs +++ b/crates/syntax/src/ast/edit_in_place.rs @@ -9,8 +9,9 @@ use crate::{ SyntaxKind::{ATTR, COMMENT, WHITESPACE}, SyntaxNode, SyntaxToken, algo::{self, neighbor}, - ast::{self, HasGenericParams, edit::IndentLevel, make}, - ted::{self, Position}, + ast::{self, HasGenericParams, edit::IndentLevel, make, syntax_factory::SyntaxFactory}, + syntax_editor::{Position, SyntaxEditor}, + ted, }; use super::{GenericParam, HasName}; @@ -26,13 +27,13 @@ impl GenericParamsOwnerEdit for ast::Fn { Some(it) => it, None => { let position = if let Some(name) = self.name() { - Position::after(name.syntax) + ted::Position::after(name.syntax) } else if let Some(fn_token) = self.fn_token() { - Position::after(fn_token) + ted::Position::after(fn_token) } else if let Some(param_list) = self.param_list() { - Position::before(param_list.syntax) + ted::Position::before(param_list.syntax) } else { - Position::last_child_of(self.syntax()) + ted::Position::last_child_of(self.syntax()) }; create_generic_param_list(position) } @@ -42,11 +43,11 @@ impl GenericParamsOwnerEdit for ast::Fn { fn get_or_create_where_clause(&self) -> ast::WhereClause { if self.where_clause().is_none() { let position = if let Some(ty) = self.ret_type() { - Position::after(ty.syntax()) + ted::Position::after(ty.syntax()) } else if let Some(param_list) = self.param_list() { - Position::after(param_list.syntax()) + ted::Position::after(param_list.syntax()) } else { - Position::last_child_of(self.syntax()) + ted::Position::last_child_of(self.syntax()) }; create_where_clause(position); } @@ -60,8 +61,8 @@ impl GenericParamsOwnerEdit for ast::Impl { Some(it) => it, None => { let position = match self.impl_token() { - Some(imp_token) => Position::after(imp_token), - None => Position::last_child_of(self.syntax()), + Some(imp_token) => ted::Position::after(imp_token), + None => ted::Position::last_child_of(self.syntax()), }; create_generic_param_list(position) } @@ -71,8 +72,8 @@ impl GenericParamsOwnerEdit for ast::Impl { fn get_or_create_where_clause(&self) -> ast::WhereClause { if self.where_clause().is_none() { let position = match self.assoc_item_list() { - Some(items) => Position::before(items.syntax()), - None => Position::last_child_of(self.syntax()), + Some(items) => ted::Position::before(items.syntax()), + None => ted::Position::last_child_of(self.syntax()), }; create_where_clause(position); } @@ -86,11 +87,11 @@ impl GenericParamsOwnerEdit for ast::Trait { Some(it) => it, None => { let position = if let Some(name) = self.name() { - Position::after(name.syntax) + ted::Position::after(name.syntax) } else if let Some(trait_token) = self.trait_token() { - Position::after(trait_token) + ted::Position::after(trait_token) } else { - Position::last_child_of(self.syntax()) + ted::Position::last_child_of(self.syntax()) }; create_generic_param_list(position) } @@ -100,9 +101,9 @@ impl GenericParamsOwnerEdit for ast::Trait { fn get_or_create_where_clause(&self) -> ast::WhereClause { if self.where_clause().is_none() { let position = match (self.assoc_item_list(), self.semicolon_token()) { - (Some(items), _) => Position::before(items.syntax()), - (_, Some(tok)) => Position::before(tok), - (None, None) => Position::last_child_of(self.syntax()), + (Some(items), _) => ted::Position::before(items.syntax()), + (_, Some(tok)) => ted::Position::before(tok), + (None, None) => ted::Position::last_child_of(self.syntax()), }; create_where_clause(position); } @@ -116,11 +117,11 @@ impl GenericParamsOwnerEdit for ast::TypeAlias { Some(it) => it, None => { let position = if let Some(name) = self.name() { - Position::after(name.syntax) + ted::Position::after(name.syntax) } else if let Some(trait_token) = self.type_token() { - Position::after(trait_token) + ted::Position::after(trait_token) } else { - Position::last_child_of(self.syntax()) + ted::Position::last_child_of(self.syntax()) }; create_generic_param_list(position) } @@ -130,10 +131,10 @@ impl GenericParamsOwnerEdit for ast::TypeAlias { fn get_or_create_where_clause(&self) -> ast::WhereClause { if self.where_clause().is_none() { let position = match self.eq_token() { - Some(tok) => Position::before(tok), + Some(tok) => ted::Position::before(tok), None => match self.semicolon_token() { - Some(tok) => Position::before(tok), - None => Position::last_child_of(self.syntax()), + Some(tok) => ted::Position::before(tok), + None => ted::Position::last_child_of(self.syntax()), }, }; create_where_clause(position); @@ -148,11 +149,11 @@ impl GenericParamsOwnerEdit for ast::Struct { Some(it) => it, None => { let position = if let Some(name) = self.name() { - Position::after(name.syntax) + ted::Position::after(name.syntax) } else if let Some(struct_token) = self.struct_token() { - Position::after(struct_token) + ted::Position::after(struct_token) } else { - Position::last_child_of(self.syntax()) + ted::Position::last_child_of(self.syntax()) }; create_generic_param_list(position) } @@ -166,13 +167,13 @@ impl GenericParamsOwnerEdit for ast::Struct { ast::FieldList::TupleFieldList(it) => Some(it), }); let position = if let Some(tfl) = tfl { - Position::after(tfl.syntax()) + ted::Position::after(tfl.syntax()) } else if let Some(gpl) = self.generic_param_list() { - Position::after(gpl.syntax()) + ted::Position::after(gpl.syntax()) } else if let Some(name) = self.name() { - Position::after(name.syntax()) + ted::Position::after(name.syntax()) } else { - Position::last_child_of(self.syntax()) + ted::Position::last_child_of(self.syntax()) }; create_where_clause(position); } @@ -186,11 +187,11 @@ impl GenericParamsOwnerEdit for ast::Enum { Some(it) => it, None => { let position = if let Some(name) = self.name() { - Position::after(name.syntax) + ted::Position::after(name.syntax) } else if let Some(enum_token) = self.enum_token() { - Position::after(enum_token) + ted::Position::after(enum_token) } else { - Position::last_child_of(self.syntax()) + ted::Position::last_child_of(self.syntax()) }; create_generic_param_list(position) } @@ -200,11 +201,11 @@ impl GenericParamsOwnerEdit for ast::Enum { fn get_or_create_where_clause(&self) -> ast::WhereClause { if self.where_clause().is_none() { let position = if let Some(gpl) = self.generic_param_list() { - Position::after(gpl.syntax()) + ted::Position::after(gpl.syntax()) } else if let Some(name) = self.name() { - Position::after(name.syntax()) + ted::Position::after(name.syntax()) } else { - Position::last_child_of(self.syntax()) + ted::Position::last_child_of(self.syntax()) }; create_where_clause(position); } @@ -212,12 +213,12 @@ impl GenericParamsOwnerEdit for ast::Enum { } } -fn create_where_clause(position: Position) { +fn create_where_clause(position: ted::Position) { let where_clause = make::where_clause(empty()).clone_for_update(); ted::insert(position, where_clause.syntax()); } -fn create_generic_param_list(position: Position) -> ast::GenericParamList { +fn create_generic_param_list(position: ted::Position) -> ast::GenericParamList { let gpl = make::generic_param_list(empty()).clone_for_update(); ted::insert_raw(position, gpl.syntax()); gpl @@ -253,7 +254,7 @@ impl ast::GenericParamList { pub fn add_generic_param(&self, generic_param: ast::GenericParam) { match self.generic_params().last() { Some(last_param) => { - let position = Position::after(last_param.syntax()); + let position = ted::Position::after(last_param.syntax()); let elements = vec![ make::token(T![,]).into(), make::tokens::single_space().into(), @@ -262,7 +263,7 @@ impl ast::GenericParamList { ted::insert_all(position, elements); } None => { - let after_l_angle = Position::after(self.l_angle_token().unwrap()); + let after_l_angle = ted::Position::after(self.l_angle_token().unwrap()); ted::insert(after_l_angle, generic_param.syntax()); } } @@ -412,7 +413,7 @@ impl ast::UseTree { match self.use_tree_list() { Some(it) => it, None => { - let position = Position::last_child_of(self.syntax()); + let position = ted::Position::last_child_of(self.syntax()); let use_tree_list = make::use_tree_list(empty()).clone_for_update(); let mut elements = Vec::with_capacity(2); if self.coloncolon_token().is_none() { @@ -458,7 +459,7 @@ impl ast::UseTree { // Next, transform 'suffix' use tree into 'prefix::{suffix}' let subtree = self.clone_subtree().clone_for_update(); ted::remove_all_iter(self.syntax().children_with_tokens()); - ted::insert(Position::first_child_of(self.syntax()), prefix.syntax()); + ted::insert(ted::Position::first_child_of(self.syntax()), prefix.syntax()); self.get_or_create_use_tree_list().add_use_tree(subtree); fn split_path_prefix(prefix: &ast::Path) -> Option<()> { @@ -507,7 +508,7 @@ impl ast::UseTreeList { pub fn add_use_tree(&self, use_tree: ast::UseTree) { let (position, elements) = match self.use_trees().last() { Some(last_tree) => ( - Position::after(last_tree.syntax()), + ted::Position::after(last_tree.syntax()), vec![ make::token(T![,]).into(), make::tokens::single_space().into(), @@ -516,8 +517,8 @@ impl ast::UseTreeList { ), None => { let position = match self.l_curly_token() { - Some(l_curly) => Position::after(l_curly), - None => Position::last_child_of(self.syntax()), + Some(l_curly) => ted::Position::after(l_curly), + None => ted::Position::last_child_of(self.syntax()), }; (position, vec![use_tree.syntax.into()]) } @@ -582,15 +583,15 @@ impl ast::AssocItemList { let (indent, position, whitespace) = match self.assoc_items().last() { Some(last_item) => ( IndentLevel::from_node(last_item.syntax()), - Position::after(last_item.syntax()), + ted::Position::after(last_item.syntax()), "\n\n", ), None => match self.l_curly_token() { Some(l_curly) => { normalize_ws_between_braces(self.syntax()); - (IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly), "\n") + (IndentLevel::from_token(&l_curly) + 1, ted::Position::after(&l_curly), "\n") } - None => (IndentLevel::single(), Position::last_child_of(self.syntax()), "\n"), + None => (IndentLevel::zero(), ted::Position::last_child_of(self.syntax()), "\n"), }, }; let elements: Vec<SyntaxElement> = vec![ @@ -618,17 +619,17 @@ impl ast::RecordExprFieldList { let position = match self.fields().last() { Some(last_field) => { let comma = get_or_insert_comma_after(last_field.syntax()); - Position::after(comma) + ted::Position::after(comma) } None => match self.l_curly_token() { - Some(it) => Position::after(it), - None => Position::last_child_of(self.syntax()), + Some(it) => ted::Position::after(it), + None => ted::Position::last_child_of(self.syntax()), }, }; ted::insert_all(position, vec![whitespace.into(), field.syntax().clone().into()]); if is_multiline { - ted::insert(Position::after(field.syntax()), ast::make::token(T![,])); + ted::insert(ted::Position::after(field.syntax()), ast::make::token(T![,])); } } } @@ -656,7 +657,7 @@ impl ast::RecordExprField { ast::make::tokens::single_space().into(), expr.syntax().clone().into(), ]; - ted::insert_all_raw(Position::last_child_of(self.syntax()), children); + ted::insert_all_raw(ted::Position::last_child_of(self.syntax()), children); } } } @@ -679,17 +680,17 @@ impl ast::RecordPatFieldList { Some(last_field) => { let syntax = last_field.syntax(); let comma = get_or_insert_comma_after(syntax); - Position::after(comma) + ted::Position::after(comma) } None => match self.l_curly_token() { - Some(it) => Position::after(it), - None => Position::last_child_of(self.syntax()), + Some(it) => ted::Position::after(it), + None => ted::Position::last_child_of(self.syntax()), }, }; ted::insert_all(position, vec![whitespace.into(), field.syntax().clone().into()]); if is_multiline { - ted::insert(Position::after(field.syntax()), ast::make::token(T![,])); + ted::insert(ted::Position::after(field.syntax()), ast::make::token(T![,])); } } } @@ -703,7 +704,7 @@ fn get_or_insert_comma_after(syntax: &SyntaxNode) -> SyntaxToken { Some(it) => it, None => { let comma = ast::make::token(T![,]); - ted::insert(Position::after(syntax), &comma); + ted::insert(ted::Position::after(syntax), &comma); comma } } @@ -728,7 +729,7 @@ fn normalize_ws_between_braces(node: &SyntaxNode) -> Option<()> { } } Some(ws) if ws.kind() == T!['}'] => { - ted::insert(Position::after(l), make::tokens::whitespace(&format!("\n{indent}"))); + ted::insert(ted::Position::after(l), make::tokens::whitespace(&format!("\n{indent}"))); } _ => (), } @@ -780,6 +781,56 @@ impl ast::IdentPat { } } } + + pub fn set_pat_with_editor( + &self, + pat: Option<ast::Pat>, + syntax_editor: &mut SyntaxEditor, + syntax_factory: &SyntaxFactory, + ) { + match pat { + None => { + if let Some(at_token) = self.at_token() { + // Remove `@ Pat` + let start = at_token.clone().into(); + let end = self + .pat() + .map(|it| it.syntax().clone().into()) + .unwrap_or_else(|| at_token.into()); + syntax_editor.delete_all(start..=end); + + // Remove any trailing ws + if let Some(last) = + self.syntax().last_token().filter(|it| it.kind() == WHITESPACE) + { + last.detach(); + } + } + } + Some(pat) => { + if let Some(old_pat) = self.pat() { + // Replace existing pattern + syntax_editor.replace(old_pat.syntax(), pat.syntax()) + } else if let Some(at_token) = self.at_token() { + // Have an `@` token but not a pattern yet + syntax_editor.insert(Position::after(at_token), pat.syntax()); + } else { + // Don't have an `@`, should have a name + let name = self.name().unwrap(); + + syntax_editor.insert_all( + Position::after(name.syntax()), + vec![ + syntax_factory.whitespace(" ").into(), + syntax_factory.token(T![@]).into(), + syntax_factory.whitespace(" ").into(), + pat.syntax().clone().into(), + ], + ) + } + } + } + } } pub trait HasVisibilityEdit: ast::HasVisibility { diff --git a/crates/syntax/src/ast/expr_ext.rs b/crates/syntax/src/ast/expr_ext.rs index db66995381..b44150f868 100644 --- a/crates/syntax/src/ast/expr_ext.rs +++ b/crates/syntax/src/ast/expr_ext.rs @@ -375,7 +375,11 @@ impl ast::Literal { pub enum BlockModifier { Async(SyntaxToken), Unsafe(SyntaxToken), - Try(SyntaxToken), + Try { + try_token: SyntaxToken, + bikeshed_token: Option<SyntaxToken>, + result_type: Option<ast::Type>, + }, Const(SyntaxToken), AsyncGen(SyntaxToken), Gen(SyntaxToken), @@ -394,7 +398,13 @@ impl ast::BlockExpr { }) .or_else(|| self.async_token().map(BlockModifier::Async)) .or_else(|| self.unsafe_token().map(BlockModifier::Unsafe)) - .or_else(|| self.try_token().map(BlockModifier::Try)) + .or_else(|| { + let modifier = self.try_block_modifier()?; + let try_token = modifier.try_token()?; + let bikeshed_token = modifier.bikeshed_token(); + let result_type = modifier.ty(); + Some(BlockModifier::Try { try_token, bikeshed_token, result_type }) + }) .or_else(|| self.const_token().map(BlockModifier::Const)) .or_else(|| self.label().map(BlockModifier::Label)) } diff --git a/crates/syntax/src/ast/generated/nodes.rs b/crates/syntax/src/ast/generated/nodes.rs index 7b9f5b9166..7334de0fd9 100644 --- a/crates/syntax/src/ast/generated/nodes.rs +++ b/crates/syntax/src/ast/generated/nodes.rs @@ -323,6 +323,8 @@ impl BlockExpr { #[inline] pub fn stmt_list(&self) -> Option<StmtList> { support::child(&self.syntax) } #[inline] + pub fn try_block_modifier(&self) -> Option<TryBlockModifier> { support::child(&self.syntax) } + #[inline] pub fn async_token(&self) -> Option<SyntaxToken> { support::token(&self.syntax, T![async]) } #[inline] pub fn const_token(&self) -> Option<SyntaxToken> { support::token(&self.syntax, T![const]) } @@ -331,8 +333,6 @@ impl BlockExpr { #[inline] pub fn move_token(&self) -> Option<SyntaxToken> { support::token(&self.syntax, T![move]) } #[inline] - pub fn try_token(&self) -> Option<SyntaxToken> { support::token(&self.syntax, T![try]) } - #[inline] pub fn unsafe_token(&self) -> Option<SyntaxToken> { support::token(&self.syntax, T![unsafe]) } } pub struct BoxPat { @@ -1337,7 +1337,7 @@ impl ast::HasName for RecordField {} impl ast::HasVisibility for RecordField {} impl RecordField { #[inline] - pub fn expr(&self) -> Option<Expr> { support::child(&self.syntax) } + pub fn default_val(&self) -> Option<ConstArg> { support::child(&self.syntax) } #[inline] pub fn ty(&self) -> Option<Type> { support::child(&self.syntax) } #[inline] @@ -1630,6 +1630,19 @@ impl Trait { #[inline] pub fn unsafe_token(&self) -> Option<SyntaxToken> { support::token(&self.syntax, T![unsafe]) } } +pub struct TryBlockModifier { + pub(crate) syntax: SyntaxNode, +} +impl TryBlockModifier { + #[inline] + pub fn ty(&self) -> Option<Type> { support::child(&self.syntax) } + #[inline] + pub fn bikeshed_token(&self) -> Option<SyntaxToken> { + support::token(&self.syntax, T![bikeshed]) + } + #[inline] + pub fn try_token(&self) -> Option<SyntaxToken> { support::token(&self.syntax, T![try]) } +} pub struct TryExpr { pub(crate) syntax: SyntaxNode, } @@ -1883,7 +1896,7 @@ impl ast::HasName for Variant {} impl ast::HasVisibility for Variant {} impl Variant { #[inline] - pub fn expr(&self) -> Option<Expr> { support::child(&self.syntax) } + pub fn const_arg(&self) -> Option<ConstArg> { support::child(&self.syntax) } #[inline] pub fn field_list(&self) -> Option<FieldList> { support::child(&self.syntax) } #[inline] @@ -6320,6 +6333,38 @@ impl fmt::Debug for Trait { f.debug_struct("Trait").field("syntax", &self.syntax).finish() } } +impl AstNode for TryBlockModifier { + #[inline] + fn kind() -> SyntaxKind + where + Self: Sized, + { + TRY_BLOCK_MODIFIER + } + #[inline] + fn can_cast(kind: SyntaxKind) -> bool { kind == TRY_BLOCK_MODIFIER } + #[inline] + fn cast(syntax: SyntaxNode) -> Option<Self> { + if Self::can_cast(syntax.kind()) { Some(Self { syntax }) } else { None } + } + #[inline] + fn syntax(&self) -> &SyntaxNode { &self.syntax } +} +impl hash::Hash for TryBlockModifier { + fn hash<H: hash::Hasher>(&self, state: &mut H) { self.syntax.hash(state); } +} +impl Eq for TryBlockModifier {} +impl PartialEq for TryBlockModifier { + fn eq(&self, other: &Self) -> bool { self.syntax == other.syntax } +} +impl Clone for TryBlockModifier { + fn clone(&self) -> Self { Self { syntax: self.syntax.clone() } } +} +impl fmt::Debug for TryBlockModifier { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TryBlockModifier").field("syntax", &self.syntax).finish() + } +} impl AstNode for TryExpr { #[inline] fn kind() -> SyntaxKind @@ -9979,6 +10024,11 @@ impl std::fmt::Display for Trait { std::fmt::Display::fmt(self.syntax(), f) } } +impl std::fmt::Display for TryBlockModifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(self.syntax(), f) + } +} impl std::fmt::Display for TryExpr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self.syntax(), f) diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs index 98d759aef2..00971569a2 100644 --- a/crates/syntax/src/ast/make.rs +++ b/crates/syntax/src/ast/make.rs @@ -74,10 +74,18 @@ pub mod ext { expr_from_text("_") } pub fn expr_ty_default(ty: &ast::Type) -> ast::Expr { - expr_from_text(&format!("{ty}::default()")) + if !ty.needs_angles_in_path() { + expr_from_text(&format!("{ty}::default()")) + } else { + expr_from_text(&format!("<{ty}>::default()")) + } } pub fn expr_ty_new(ty: &ast::Type) -> ast::Expr { - expr_from_text(&format!("{ty}::new()")) + if !ty.needs_angles_in_path() { + expr_from_text(&format!("{ty}::new()")) + } else { + expr_from_text(&format!("<{ty}>::new()")) + } } pub fn expr_self() -> ast::Expr { expr_from_text("self") @@ -644,7 +652,8 @@ pub fn expr_await(expr: ast::Expr) -> ast::Expr { expr_from_text(&format!("{expr}.await")) } pub fn expr_match(expr: ast::Expr, match_arm_list: ast::MatchArmList) -> ast::MatchExpr { - expr_from_text(&format!("match {expr} {match_arm_list}")) + let ws = block_whitespace(&expr); + expr_from_text(&format!("match {expr}{ws}{match_arm_list}")) } pub fn expr_if( condition: ast::Expr, @@ -656,14 +665,17 @@ pub fn expr_if( Some(ast::ElseBranch::IfExpr(if_expr)) => format!("else {if_expr}"), None => String::new(), }; - expr_from_text(&format!("if {condition} {then_branch} {else_branch}")) + let ws = block_whitespace(&condition); + expr_from_text(&format!("if {condition}{ws}{then_branch} {else_branch}")) } pub fn expr_for_loop(pat: ast::Pat, expr: ast::Expr, block: ast::BlockExpr) -> ast::ForExpr { - expr_from_text(&format!("for {pat} in {expr} {block}")) + let ws = block_whitespace(&expr); + expr_from_text(&format!("for {pat} in {expr}{ws}{block}")) } pub fn expr_while_loop(condition: ast::Expr, block: ast::BlockExpr) -> ast::WhileExpr { - expr_from_text(&format!("while {condition} {block}")) + let ws = block_whitespace(&condition); + expr_from_text(&format!("while {condition}{ws}{block}")) } pub fn expr_loop(block: ast::BlockExpr) -> ast::Expr { @@ -723,6 +735,9 @@ pub fn expr_assignment(lhs: ast::Expr, rhs: ast::Expr) -> ast::BinExpr { fn expr_from_text<E: Into<ast::Expr> + AstNode>(text: &str) -> E { ast_from_text(&format!("const C: () = {text};")) } +fn block_whitespace(after: &impl AstNode) -> &'static str { + if after.syntax().text().contains_char('\n') { "\n" } else { " " } +} pub fn expr_let(pattern: ast::Pat, expr: ast::Expr) -> ast::LetExpr { ast_from_text(&format!("const _: () = while let {pattern} = {expr} {{}};")) } @@ -880,8 +895,9 @@ pub fn ref_pat(pat: ast::Pat) -> ast::RefPat { pub fn match_arm(pat: ast::Pat, guard: Option<ast::MatchGuard>, expr: ast::Expr) -> ast::MatchArm { let comma_str = if expr.is_block_like() { "" } else { "," }; + let ws = guard.as_ref().filter(|_| expr.is_block_like()).map_or(" ", block_whitespace); return match guard { - Some(guard) => from_text(&format!("{pat} {guard} => {expr}{comma_str}")), + Some(guard) => from_text(&format!("{pat} {guard} =>{ws}{expr}{comma_str}")), None => from_text(&format!("{pat} => {expr}{comma_str}")), }; @@ -890,19 +906,6 @@ pub fn match_arm(pat: ast::Pat, guard: Option<ast::MatchGuard>, expr: ast::Expr) } } -pub fn match_arm_with_guard( - pats: impl IntoIterator<Item = ast::Pat>, - guard: ast::Expr, - expr: ast::Expr, -) -> ast::MatchArm { - let pats_str = pats.into_iter().join(" | "); - return from_text(&format!("{pats_str} if {guard} => {expr}")); - - fn from_text(text: &str) -> ast::MatchArm { - ast_from_text(&format!("fn f() {{ match () {{{text}}} }}")) - } -} - pub fn match_guard(condition: ast::Expr) -> ast::MatchGuard { return from_text(&format!("if {condition}")); diff --git a/crates/syntax/src/ast/node_ext.rs b/crates/syntax/src/ast/node_ext.rs index 76cfea9d5b..63e4608d0f 100644 --- a/crates/syntax/src/ast/node_ext.rs +++ b/crates/syntax/src/ast/node_ext.rs @@ -717,6 +717,10 @@ impl ast::Type { None } } + + pub fn needs_angles_in_path(&self) -> bool { + !matches!(self, ast::Type::PathType(_)) || self.generic_arg_list().is_some() + } } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/crates/syntax/src/ast/syntax_factory/constructors.rs b/crates/syntax/src/ast/syntax_factory/constructors.rs index 5fe419ad4e..44114a7802 100644 --- a/crates/syntax/src/ast/syntax_factory/constructors.rs +++ b/crates/syntax/src/ast/syntax_factory/constructors.rs @@ -1,9 +1,11 @@ //! Wrappers over [`make`] constructors +use either::Either; + use crate::{ AstNode, NodeOrToken, SyntaxKind, SyntaxNode, SyntaxToken, ast::{ self, HasArgList, HasAttrs, HasGenericArgs, HasGenericParams, HasLoopBody, HasName, - HasTypeBounds, HasVisibility, RangeItem, make, + HasTypeBounds, HasVisibility, Lifetime, Param, RangeItem, make, }, syntax_editor::SyntaxMappingBuilder, }; @@ -19,6 +21,18 @@ impl SyntaxFactory { make::name_ref(name).clone_for_update() } + pub fn name_ref_self_ty(&self) -> ast::NameRef { + make::name_ref_self_ty().clone_for_update() + } + + pub fn expr_todo(&self) -> ast::Expr { + make::ext::expr_todo().clone_for_update() + } + + pub fn expr_self(&self) -> ast::Expr { + make::ext::expr_self().clone_for_update() + } + pub fn lifetime(&self, text: &str) -> ast::Lifetime { make::lifetime(text).clone_for_update() } @@ -49,6 +63,26 @@ impl SyntaxFactory { ast } + pub fn type_bound(&self, bound: ast::Type) -> ast::TypeBound { + make::type_bound(bound).clone_for_update() + } + + pub fn type_bound_list( + &self, + bounds: impl IntoIterator<Item = ast::TypeBound>, + ) -> Option<ast::TypeBoundList> { + let (bounds, input) = iterator_input(bounds); + let ast = make::type_bound_list(bounds)?.clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_children(input, ast.bounds().map(|b| b.syntax().clone())); + builder.finish(&mut mapping); + } + + Some(ast) + } + pub fn type_param( &self, name: ast::Name, @@ -75,6 +109,160 @@ impl SyntaxFactory { make::path_from_text(text).clone_for_update() } + pub fn path_concat(&self, first: ast::Path, second: ast::Path) -> ast::Path { + make::path_concat(first, second).clone_for_update() + } + + pub fn visibility_pub_crate(&self) -> ast::Visibility { + make::visibility_pub_crate().clone_for_update() + } + + pub fn visibility_pub(&self) -> ast::Visibility { + make::visibility_pub().clone_for_update() + } + + pub fn struct_( + &self, + visibility: Option<ast::Visibility>, + strukt_name: ast::Name, + generic_param_list: Option<ast::GenericParamList>, + field_list: ast::FieldList, + ) -> ast::Struct { + let ast = make::struct_( + visibility.clone(), + strukt_name.clone(), + generic_param_list.clone(), + field_list.clone(), + ) + .clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + if let Some(visibility) = visibility { + builder.map_node( + visibility.syntax().clone(), + ast.visibility().unwrap().syntax().clone(), + ); + } + builder.map_node(strukt_name.syntax().clone(), ast.name().unwrap().syntax().clone()); + if let Some(generic_param_list) = generic_param_list { + builder.map_node( + generic_param_list.syntax().clone(), + ast.generic_param_list().unwrap().syntax().clone(), + ); + } + builder + .map_node(field_list.syntax().clone(), ast.field_list().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + + pub fn unnamed_param(&self, ty: ast::Type) -> ast::Param { + let ast = make::unnamed_param(ty.clone()).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_node(ty.syntax().clone(), ast.ty().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + + pub fn ty_fn_ptr<I: Iterator<Item = Param>>( + &self, + is_unsafe: bool, + abi: Option<ast::Abi>, + params: I, + ret_type: Option<ast::RetType>, + ) -> ast::FnPtrType { + let (params, params_input) = iterator_input(params); + let ast = make::ty_fn_ptr(is_unsafe, abi.clone(), params.into_iter(), ret_type.clone()) + .clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + if let Some(abi) = abi { + builder.map_node(abi.syntax().clone(), ast.abi().unwrap().syntax().clone()); + } + builder.map_children( + params_input, + ast.param_list().unwrap().params().map(|p| p.syntax().clone()), + ); + if let Some(ret_type) = ret_type { + builder + .map_node(ret_type.syntax().clone(), ast.ret_type().unwrap().syntax().clone()); + } + builder.finish(&mut mapping); + } + + ast + } + + pub fn where_pred( + &self, + path: Either<ast::Lifetime, ast::Type>, + bounds: impl IntoIterator<Item = ast::TypeBound>, + ) -> ast::WherePred { + let (bounds, bounds_input) = iterator_input(bounds); + let ast = make::where_pred(path.clone(), bounds).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + match &path { + Either::Left(lifetime) => { + builder.map_node( + lifetime.syntax().clone(), + ast.lifetime().unwrap().syntax().clone(), + ); + } + Either::Right(ty) => { + builder.map_node(ty.syntax().clone(), ast.ty().unwrap().syntax().clone()); + } + } + if let Some(type_bound_list) = ast.type_bound_list() { + builder.map_children( + bounds_input, + type_bound_list.bounds().map(|b| b.syntax().clone()), + ); + } + builder.finish(&mut mapping); + } + + ast + } + + pub fn where_clause( + &self, + predicates: impl IntoIterator<Item = ast::WherePred>, + ) -> ast::WhereClause { + let (predicates, input) = iterator_input(predicates); + let ast = make::where_clause(predicates).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_children(input, ast.predicates().map(|p| p.syntax().clone())); + builder.finish(&mut mapping); + } + + ast + } + + pub fn impl_trait_type(&self, bounds: ast::TypeBoundList) -> ast::ImplTraitType { + let ast = make::impl_trait_type(bounds.clone()).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder + .map_node(bounds.syntax().clone(), ast.type_bound_list().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + pub fn expr_field(&self, receiver: ast::Expr, field: &str) -> ast::FieldExpr { let ast::Expr::FieldExpr(ast) = make::expr_field(receiver.clone(), field).clone_for_update() @@ -265,6 +453,64 @@ impl SyntaxFactory { ast } + pub fn generic_ty_path_segment( + &self, + name_ref: ast::NameRef, + generic_args: impl IntoIterator<Item = ast::GenericArg>, + ) -> ast::PathSegment { + let (generic_args, input) = iterator_input(generic_args); + let ast = make::generic_ty_path_segment(name_ref.clone(), generic_args).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_node(name_ref.syntax().clone(), ast.name_ref().unwrap().syntax().clone()); + builder.map_children( + input, + ast.generic_arg_list().unwrap().generic_args().map(|a| a.syntax().clone()), + ); + builder.finish(&mut mapping); + } + + ast + } + + pub fn tail_only_block_expr(&self, tail_expr: ast::Expr) -> ast::BlockExpr { + let ast = make::tail_only_block_expr(tail_expr.clone()).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let stmt_list = ast.stmt_list().unwrap(); + let mut builder = SyntaxMappingBuilder::new(stmt_list.syntax().clone()); + builder.map_node( + tail_expr.syntax().clone(), + stmt_list.tail_expr().unwrap().syntax().clone(), + ); + builder.finish(&mut mapping); + } + + ast + } + + pub fn expr_bin_op(&self, lhs: ast::Expr, op: ast::BinaryOp, rhs: ast::Expr) -> ast::Expr { + let ast::Expr::BinExpr(ast) = + make::expr_bin_op(lhs.clone(), op, rhs.clone()).clone_for_update() + else { + unreachable!() + }; + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_node(lhs.syntax().clone(), ast.lhs().unwrap().syntax().clone()); + builder.map_node(rhs.syntax().clone(), ast.rhs().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast.into() + } + + pub fn ty_placeholder(&self) -> ast::Type { + make::ty_placeholder().clone_for_update() + } + pub fn path_segment_generics( &self, name_ref: ast::NameRef, @@ -295,7 +541,23 @@ impl SyntaxFactory { visibility: Option<ast::Visibility>, use_tree: ast::UseTree, ) -> ast::Use { - make::use_(attrs, visibility, use_tree).clone_for_update() + let (attrs, attrs_input) = iterator_input(attrs); + let ast = make::use_(attrs, visibility.clone(), use_tree.clone()).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_children(attrs_input, ast.attrs().map(|attr| attr.syntax().clone())); + if let Some(visibility) = visibility { + builder.map_node( + visibility.syntax().clone(), + ast.visibility().unwrap().syntax().clone(), + ); + } + builder.map_node(use_tree.syntax().clone(), ast.use_tree().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast } pub fn use_tree( @@ -305,7 +567,25 @@ impl SyntaxFactory { alias: Option<ast::Rename>, add_star: bool, ) -> ast::UseTree { - make::use_tree(path, use_tree_list, alias, add_star).clone_for_update() + let ast = make::use_tree(path.clone(), use_tree_list.clone(), alias.clone(), add_star) + .clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_node(path.syntax().clone(), ast.path().unwrap().syntax().clone()); + if let Some(use_tree_list) = use_tree_list { + builder.map_node( + use_tree_list.syntax().clone(), + ast.use_tree_list().unwrap().syntax().clone(), + ); + } + if let Some(alias) = alias { + builder.map_node(alias.syntax().clone(), ast.rename().unwrap().syntax().clone()); + } + builder.finish(&mut mapping); + } + + ast } pub fn path_unqualified(&self, segment: ast::PathSegment) -> ast::Path { @@ -806,10 +1086,6 @@ impl SyntaxFactory { unreachable!() }; - if let Some(mut mapping) = self.mappings() { - SyntaxMappingBuilder::new(ast.syntax().clone()).finish(&mut mapping); - } - ast } @@ -1198,6 +1474,22 @@ impl SyntaxFactory { ast } + pub fn record_expr_field_list( + &self, + fields: impl IntoIterator<Item = ast::RecordExprField>, + ) -> ast::RecordExprFieldList { + let (fields, input) = iterator_input(fields); + let ast = make::record_expr_field_list(fields).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_children(input, ast.fields().map(|f| f.syntax().clone())); + builder.finish(&mut mapping); + } + + ast + } + pub fn record_expr_field( &self, name: ast::NameRef, @@ -1208,7 +1500,20 @@ impl SyntaxFactory { if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); - builder.map_node(name.syntax().clone(), ast.name_ref().unwrap().syntax().clone()); + if let Some(ast_name_ref) = ast.name_ref() { + // NameRef is a direct child + builder.map_node(name.syntax().clone(), ast_name_ref.syntax().clone()); + } else { + // NameRef is nested inside PathExpr > Path > PathSegment. + // map_node requires the output to be a direct child of the builder's parent, so + // we need a separate builder scoped to PathSegment. + let ast::Expr::PathExpr(path_expr) = ast.expr().unwrap() else { unreachable!() }; + let path_segment = path_expr.path().unwrap().segment().unwrap(); + let inner_name_ref = path_segment.name_ref().unwrap(); + let mut inner_builder = SyntaxMappingBuilder::new(path_segment.syntax().clone()); + inner_builder.map_node(name.syntax().clone(), inner_name_ref.syntax().clone()); + inner_builder.finish(&mut mapping); + } if let Some(expr) = expr { builder.map_node(expr.syntax().clone(), ast.expr().unwrap().syntax().clone()); } @@ -1413,8 +1718,10 @@ impl SyntaxFactory { } if let Some(discriminant) = discriminant { - builder - .map_node(discriminant.syntax().clone(), ast.expr().unwrap().syntax().clone()); + builder.map_node( + discriminant.syntax().clone(), + ast.const_arg().unwrap().syntax().clone(), + ); } builder.finish(&mut mapping); @@ -1507,6 +1814,10 @@ impl SyntaxFactory { ast } + pub fn assoc_item_list_empty(&self) -> ast::AssocItemList { + make::assoc_item_list(None).clone_for_update() + } + pub fn attr_outer(&self, meta: ast::Meta) -> ast::Attr { let ast = make::attr_outer(meta.clone()).clone_for_update(); @@ -1590,6 +1901,65 @@ impl SyntaxFactory { ast } + pub fn self_param(&self) -> ast::SelfParam { + let ast = make::self_param().clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + + pub fn impl_( + &self, + attrs: impl IntoIterator<Item = ast::Attr>, + generic_params: Option<ast::GenericParamList>, + generic_args: Option<ast::GenericArgList>, + path_type: ast::Type, + where_clause: Option<ast::WhereClause>, + body: Option<ast::AssocItemList>, + ) -> ast::Impl { + let (attrs, attrs_input) = iterator_input(attrs); + let ast = make::impl_( + attrs, + generic_params.clone(), + generic_args.clone(), + path_type.clone(), + where_clause.clone(), + body.clone(), + ) + .clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_children(attrs_input, ast.attrs().map(|attr| attr.syntax().clone())); + if let Some(generic_params) = generic_params { + builder.map_node( + generic_params.syntax().clone(), + ast.generic_param_list().unwrap().syntax().clone(), + ); + } + builder.map_node(path_type.syntax().clone(), ast.self_ty().unwrap().syntax().clone()); + if let Some(where_clause) = where_clause { + builder.map_node( + where_clause.syntax().clone(), + ast.where_clause().unwrap().syntax().clone(), + ); + } + if let Some(body) = body { + builder.map_node( + body.syntax().clone(), + ast.assoc_item_list().unwrap().syntax().clone(), + ); + } + builder.finish(&mut mapping); + } + + ast + } + pub fn ret_type(&self, ty: ast::Type) -> ast::RetType { let ast = make::ret_type(ty.clone()).clone_for_update(); @@ -1616,6 +1986,65 @@ impl SyntaxFactory { } ast } + + pub fn field_from_idents<'a>( + &self, + parts: impl std::iter::IntoIterator<Item = &'a str>, + ) -> Option<ast::Expr> { + make::ext::field_from_idents(parts) + } + + pub fn expr_await(&self, expr: ast::Expr) -> ast::AwaitExpr { + let ast::Expr::AwaitExpr(ast) = make::expr_await(expr.clone()).clone_for_update() else { + unreachable!() + }; + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_node(expr.syntax().clone(), ast.expr().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + + pub fn expr_break(&self, label: Option<Lifetime>, expr: Option<ast::Expr>) -> ast::BreakExpr { + let ast::Expr::BreakExpr(ast) = + make::expr_break(label.clone(), expr.clone()).clone_for_update() + else { + unreachable!() + }; + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + if let Some(label) = label { + builder.map_node(label.syntax().clone(), ast.lifetime().unwrap().syntax().clone()); + } + if let Some(expr) = expr { + builder.map_node(expr.syntax().clone(), ast.expr().unwrap().syntax().clone()); + } + builder.finish(&mut mapping); + } + + ast + } + + pub fn expr_continue(&self, label: Option<Lifetime>) -> ast::ContinueExpr { + let ast::Expr::ContinueExpr(ast) = make::expr_continue(label.clone()).clone_for_update() + else { + unreachable!() + }; + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + if let Some(label) = label { + builder.map_node(label.syntax().clone(), ast.lifetime().unwrap().syntax().clone()); + } + builder.finish(&mut mapping); + } + + ast + } } // `ext` constructors diff --git a/crates/syntax/src/syntax_editor.rs b/crates/syntax/src/syntax_editor.rs index 5683d891be..e6937e4d0f 100644 --- a/crates/syntax/src/syntax_editor.rs +++ b/crates/syntax/src/syntax_editor.rs @@ -20,7 +20,7 @@ mod edit_algo; mod edits; mod mapping; -pub use edits::Removable; +pub use edits::{GetOrCreateWhereClause, Removable}; pub use mapping::{SyntaxMapping, SyntaxMappingBuilder}; #[derive(Debug)] diff --git a/crates/syntax/src/syntax_editor/edits.rs b/crates/syntax/src/syntax_editor/edits.rs index 9090f7c9eb..44f0a8038e 100644 --- a/crates/syntax/src/syntax_editor/edits.rs +++ b/crates/syntax/src/syntax_editor/edits.rs @@ -10,6 +10,107 @@ use crate::{ syntax_editor::{Position, SyntaxEditor}, }; +pub trait GetOrCreateWhereClause: ast::HasGenericParams { + fn where_clause_position(&self) -> Option<Position>; + + fn get_or_create_where_clause( + &self, + editor: &mut SyntaxEditor, + make: &SyntaxFactory, + new_preds: impl Iterator<Item = ast::WherePred>, + ) { + let existing = self.where_clause(); + let all_preds: Vec<_> = + existing.iter().flat_map(|wc| wc.predicates()).chain(new_preds).collect(); + let new_where = make.where_clause(all_preds); + + if let Some(existing) = &existing { + editor.replace(existing.syntax(), new_where.syntax()); + } else if let Some(pos) = self.where_clause_position() { + editor.insert_all( + pos, + vec![make.whitespace(" ").into(), new_where.syntax().clone().into()], + ); + } + } +} + +impl GetOrCreateWhereClause for ast::Fn { + fn where_clause_position(&self) -> Option<Position> { + if let Some(ty) = self.ret_type() { + Some(Position::after(ty.syntax())) + } else if let Some(param_list) = self.param_list() { + Some(Position::after(param_list.syntax())) + } else { + Some(Position::last_child_of(self.syntax())) + } + } +} + +impl GetOrCreateWhereClause for ast::Impl { + fn where_clause_position(&self) -> Option<Position> { + if let Some(ty) = self.self_ty() { + Some(Position::after(ty.syntax())) + } else { + Some(Position::last_child_of(self.syntax())) + } + } +} + +impl GetOrCreateWhereClause for ast::Trait { + fn where_clause_position(&self) -> Option<Position> { + if let Some(gpl) = self.generic_param_list() { + Some(Position::after(gpl.syntax())) + } else if let Some(name) = self.name() { + Some(Position::after(name.syntax())) + } else { + Some(Position::last_child_of(self.syntax())) + } + } +} + +impl GetOrCreateWhereClause for ast::TypeAlias { + fn where_clause_position(&self) -> Option<Position> { + if let Some(gpl) = self.generic_param_list() { + Some(Position::after(gpl.syntax())) + } else if let Some(name) = self.name() { + Some(Position::after(name.syntax())) + } else { + Some(Position::last_child_of(self.syntax())) + } + } +} + +impl GetOrCreateWhereClause for ast::Struct { + fn where_clause_position(&self) -> Option<Position> { + let tfl = self.field_list().and_then(|fl| match fl { + ast::FieldList::RecordFieldList(_) => None, + ast::FieldList::TupleFieldList(it) => Some(it), + }); + if let Some(tfl) = tfl { + Some(Position::after(tfl.syntax())) + } else if let Some(gpl) = self.generic_param_list() { + Some(Position::after(gpl.syntax())) + } else if let Some(name) = self.name() { + Some(Position::after(name.syntax())) + } else { + Some(Position::last_child_of(self.syntax())) + } + } +} + +impl GetOrCreateWhereClause for ast::Enum { + fn where_clause_position(&self) -> Option<Position> { + if let Some(gpl) = self.generic_param_list() { + Some(Position::after(gpl.syntax())) + } else if let Some(name) = self.name() { + Some(Position::after(name.syntax())) + } else { + Some(Position::last_child_of(self.syntax())) + } + } +} + impl SyntaxEditor { /// Adds a new generic param to the function using `SyntaxEditor` pub fn add_generic_param(&mut self, function: &Fn, new_param: GenericParam) { @@ -109,7 +210,7 @@ impl ast::AssocItemList { normalize_ws_between_braces(editor, self.syntax()); (IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly), "\n") } - None => (IndentLevel::single(), Position::last_child_of(self.syntax()), "\n"), + None => (IndentLevel::zero(), Position::last_child_of(self.syntax()), "\n"), }, }; @@ -141,7 +242,7 @@ impl ast::VariantList { normalize_ws_between_braces(editor, self.syntax()); (IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly)) } - None => (IndentLevel::single(), Position::last_child_of(self.syntax())), + None => (IndentLevel::zero(), Position::last_child_of(self.syntax())), }, }; let elements: Vec<SyntaxElement> = vec![ diff --git a/crates/syntax/src/syntax_error.rs b/crates/syntax/src/syntax_error.rs index 1c902893ab..6f00ef4ed5 100644 --- a/crates/syntax/src/syntax_error.rs +++ b/crates/syntax/src/syntax_error.rs @@ -9,16 +9,6 @@ use crate::{TextRange, TextSize}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct SyntaxError(String, TextRange); -// FIXME: there was an unused SyntaxErrorKind previously (before this enum was removed) -// It was introduced in this PR: https://github.com/rust-lang/rust-analyzer/pull/846/files#diff-827da9b03b8f9faa1bade5cdd44d5dafR95 -// but it was not removed by a mistake. -// -// So, we need to find a place where to stick validation for attributes in match clauses. -// Code before refactor: -// InvalidMatchInnerAttr => { -// write!(f, "Inner attributes are only allowed directly after the opening brace of the match expression") -// } - impl SyntaxError { pub fn new(message: impl Into<String>, range: TextRange) -> Self { Self(message.into(), range) diff --git a/crates/test-fixture/src/lib.rs b/crates/test-fixture/src/lib.rs index ca68edd88c..e271c32c86 100644 --- a/crates/test-fixture/src/lib.rs +++ b/crates/test-fixture/src/lib.rs @@ -149,8 +149,8 @@ pub trait WithFixture: Default + ExpandDatabase + SourceDatabase + 'static { let fixture = ChangeFixture::parse(ra_fixture); fixture.change.apply(&mut db); assert_eq!(fixture.files.len(), 1, "Multiple file found in the fixture"); - let file = EditionedFileId::from_span_guess_origin(&db, fixture.files[0]); - (db, file) + let file_id = EditionedFileId::from_span_file_id(&db, fixture.files[0]); + (db, file_id) } /// See the trait documentation for more information on fixtures. @@ -165,7 +165,7 @@ pub trait WithFixture: Default + ExpandDatabase + SourceDatabase + 'static { let files = fixture .files .into_iter() - .map(|file| EditionedFileId::from_span_guess_origin(&db, file)) + .map(|file| EditionedFileId::from_span_file_id(&db, file)) .collect(); (db, files) } @@ -222,7 +222,7 @@ pub trait WithFixture: Default + ExpandDatabase + SourceDatabase + 'static { let (file_id, range_or_offset) = fixture .file_position .expect("Could not find file position in fixture. Did you forget to add an `$0`?"); - let file_id = EditionedFileId::from_span_guess_origin(&db, file_id); + let file_id = EditionedFileId::from_span_file_id(&db, file_id); (db, file_id, range_or_offset) } diff --git a/crates/test-utils/src/minicore.rs b/crates/test-utils/src/minicore.rs index 48c3e89525..86fb080732 100644 --- a/crates/test-utils/src/minicore.rs +++ b/crates/test-utils/src/minicore.rs @@ -43,6 +43,7 @@ //! dispatch_from_dyn: unsize, pin //! hash: sized //! include: +//! include_bytes: //! index: sized //! infallible: //! int_impl: size_of, transmute @@ -953,6 +954,9 @@ pub mod ops { #[lang = "from_residual"] fn from_residual(residual: R) -> Self; } + pub const trait Residual<O>: Sized { + type TryType: [const] Try<Output = O, Residual = Self>; + } #[lang = "Try"] pub trait Try: FromResidual<Self::Residual> { type Output; @@ -962,6 +966,12 @@ pub mod ops { #[lang = "branch"] fn branch(self) -> ControlFlow<Self::Residual, Self::Output>; } + #[lang = "into_try_type"] + pub const fn residual_into_try_type<R: [const] Residual<O>, O>( + r: R, + ) -> <R as Residual<O>>::TryType { + FromResidual::from_residual(r) + } impl<B, C> Try for ControlFlow<B, C> { type Output = C; @@ -985,6 +995,10 @@ pub mod ops { } } } + + impl<B, C> Residual<C> for ControlFlow<B, Infallible> { + type TryType = ControlFlow<B, C>; + } // region:option impl<T> Try for Option<T> { type Output = T; @@ -1008,6 +1022,10 @@ pub mod ops { } } } + + impl<T> const Residual<T> for Option<Infallible> { + type TryType = Option<T>; + } // endregion:option // region:result // region:from @@ -1037,10 +1055,14 @@ pub mod ops { } } } + + impl<T, E> const Residual<T> for Result<Infallible, E> { + type TryType = Result<T, E>; + } // endregion:from // endregion:result } - pub use self::try_::{ControlFlow, FromResidual, Try}; + pub use self::try_::{ControlFlow, FromResidual, Residual, Try}; // endregion:try // region:add @@ -1481,6 +1503,19 @@ pub mod slice { loop {} } } + + // region:default + impl<T> const Default for &[T] { + fn default() -> Self { + &[] + } + } + impl<T> const Default for &mut [T] { + fn default() -> Self { + &mut [] + } + } + // endregion:default } // endregion:slice @@ -1667,6 +1702,21 @@ pub mod iter { } } + pub struct Filter<I, P> { + iter: I, + predicate: P, + } + impl<I: Iterator, P> Iterator for Filter<I, P> + where + P: FnMut(&I::Item) -> bool, + { + type Item = I::Item; + + fn next(&mut self) -> Option<I::Item> { + loop {} + } + } + pub struct FilterMap<I, F> { iter: I, f: F, @@ -1683,7 +1733,7 @@ pub mod iter { } } } - pub use self::adapters::{FilterMap, Take}; + pub use self::adapters::{Filter, FilterMap, Take}; mod sources { mod repeat { @@ -1734,6 +1784,13 @@ pub mod iter { { loop {} } + fn filter<P>(self, predicate: P) -> crate::iter::Filter<Self, P> + where + Self: Sized, + P: FnMut(&Self::Item) -> bool, + { + loop {} + } fn filter_map<B, F>(self, _f: F) -> crate::iter::FilterMap<Self, F> where Self: Sized, @@ -2040,6 +2097,14 @@ mod macros { } // endregion:include + // region:include_bytes + #[rustc_builtin_macro] + #[macro_export] + macro_rules! include_bytes { + ($file:expr $(,)?) => {{ /* compiler built-in */ }}; + } + // endregion:include_bytes + // region:concat #[rustc_builtin_macro] #[macro_export] diff --git a/crates/toolchain/src/lib.rs b/crates/toolchain/src/lib.rs index 1a17269838..39319886cf 100644 --- a/crates/toolchain/src/lib.rs +++ b/crates/toolchain/src/lib.rs @@ -74,9 +74,6 @@ impl Tool { // Prevent rustup from automatically installing toolchains, see https://github.com/rust-lang/rust-analyzer/issues/20719. pub const NO_RUSTUP_AUTO_INSTALL_ENV: (&str, &str) = ("RUSTUP_AUTO_INSTALL", "0"); -// These get ignored when displaying what command is running in LSP status messages. -pub const DISPLAY_COMMAND_IGNORE_ENVS: &[&str] = &[NO_RUSTUP_AUTO_INSTALL_ENV.0]; - #[allow(clippy::disallowed_types)] /* generic parameter allows for FxHashMap */ pub fn command<H>( cmd: impl AsRef<OsStr>, diff --git a/crates/tt/src/storage.rs b/crates/tt/src/storage.rs index 4dd02d875a..50a1106175 100644 --- a/crates/tt/src/storage.rs +++ b/crates/tt/src/storage.rs @@ -488,7 +488,7 @@ impl TopSubtree { unreachable!() }; *open_span = S::new(span.open.range, 0); - *close_span = S::new(span.close.range, 0); + *close_span = S::new(span.close.range, 1); } dispatch! { match &mut self.repr => tt => do_it(tt, span) diff --git a/crates/vfs-notify/src/lib.rs b/crates/vfs-notify/src/lib.rs index c6393cc692..428b19c50b 100644 --- a/crates/vfs-notify/src/lib.rs +++ b/crates/vfs-notify/src/lib.rs @@ -208,6 +208,22 @@ impl NotifyActor { ) }) .filter_map(|path| -> Option<(AbsPathBuf, Option<Vec<u8>>)> { + // Ignore events for files/directories that we're not watching. + if !(self.watched_file_entries.contains(&path) + || self + .watched_dir_entries + .iter() + .any(|dir| dir.contains_file(&path))) + { + return None; + } + + // For removed files, fs::metadata() will return Err, but + // we still want to update the VFS. + if matches!(event.kind, EventKind::Remove(_)) { + return Some((path, None)); + } + let meta = fs::metadata(&path).ok()?; if meta.file_type().is_dir() && self @@ -223,15 +239,6 @@ impl NotifyActor { return None; } - if !(self.watched_file_entries.contains(&path) - || self - .watched_dir_entries - .iter() - .any(|dir| dir.contains_file(&path))) - { - return None; - } - let contents = read(&path); Some((path, contents)) }) @@ -317,7 +324,7 @@ impl NotifyActor { fn watch(&mut self, path: &Path) { if let Some((watcher, _)) = &mut self.watcher { - log_notify_error(watcher.watch(path, RecursiveMode::NonRecursive)); + log_notify_error(watcher.watch(path, RecursiveMode::Recursive)); } } diff --git a/docs/book/README.md b/docs/book/README.md index 0a3161f3af..cd4d8783a4 100644 --- a/docs/book/README.md +++ b/docs/book/README.md @@ -6,7 +6,7 @@ The rust analyzer manual uses [mdbook](https://rust-lang.github.io/mdBook/). To run the documentation site locally: -```shell +```bash cargo install mdbook cargo xtask codegen cd docs/book diff --git a/docs/book/src/configuration_generated.md b/docs/book/src/configuration_generated.md index 8460c2c7d0..35fba5accd 100644 --- a/docs/book/src/configuration_generated.md +++ b/docs/book/src/configuration_generated.md @@ -1380,9 +1380,9 @@ Default: `null` Override the command used for bench runnables. The first element of the array should be the program to execute (for example, `cargo`). -Use the placeholders `${package}`, `${target_arg}`, `${target}`, `${test_name}` to dynamically +Use the placeholders `${package}`, `${target_arg}`, `${target}`, `${executable_args}` to dynamically replace the package name, target option (such as `--bin` or `--example`), the target name and -the test name (name of test function or test mod path). +the arguments passed to test binary args (includes `rust-analyzer.runnables.extraTestBinaryArgs`). ## rust-analyzer.runnables.command {#runnables.command} @@ -1399,9 +1399,9 @@ Default: `null` Override the command used for bench runnables. The first element of the array should be the program to execute (for example, `cargo`). -Use the placeholders `${package}`, `${target_arg}`, `${target}`, `${test_name}` to dynamically +Use the placeholders `${package}`, `${target_arg}`, `${target}`, `${executable_args}` to dynamically replace the package name, target option (such as `--bin` or `--example`), the target name and -the test name (name of test function or test mod path). +the arguments passed to test binary args (includes `rust-analyzer.runnables.extraTestBinaryArgs`). ## rust-analyzer.runnables.extraArgs {#runnables.extraArgs} @@ -1444,9 +1444,9 @@ Default: `null` Override the command used for test runnables. The first element of the array should be the program to execute (for example, `cargo`). -Use the placeholders `${package}`, `${target_arg}`, `${target}`, `${test_name}` to dynamically +Use the placeholders `${package}`, `${target_arg}`, `${target}`, `${executable_args}` to dynamically replace the package name, target option (such as `--bin` or `--example`), the target name and -the test name (name of test function or test mod path). +the arguments passed to test binary args (includes `rust-analyzer.runnables.extraTestBinaryArgs`). ## rust-analyzer.rustc.source {#rustc.source} diff --git a/docs/book/src/contributing/README.md b/docs/book/src/contributing/README.md index c95a1dba62..bb2b6081ad 100644 --- a/docs/book/src/contributing/README.md +++ b/docs/book/src/contributing/README.md @@ -4,7 +4,7 @@ rust-analyzer is an ordinary Rust project, which is organized as a Cargo workspa So, just ```bash -$ cargo test +cargo test ``` should be enough to get you started! @@ -203,14 +203,14 @@ It is enabled by `RA_COUNT=1`. To measure time for from-scratch analysis, use something like this: ```bash -$ cargo run --release -p rust-analyzer -- analysis-stats ../chalk/ +cargo run --release -p rust-analyzer -- analysis-stats ../chalk/ ``` For measuring time of incremental analysis, use either of these: ```bash -$ cargo run --release -p rust-analyzer -- analysis-bench ../chalk/ --highlight ../chalk/chalk-engine/src/logic.rs -$ cargo run --release -p rust-analyzer -- analysis-bench ../chalk/ --complete ../chalk/chalk-engine/src/logic.rs:94:0 +cargo run --release -p rust-analyzer -- analysis-bench ../chalk/ --highlight ../chalk/chalk-engine/src/logic.rs +cargo run --release -p rust-analyzer -- analysis-bench ../chalk/ --complete ../chalk/chalk-engine/src/logic.rs:94:0 ``` Look for `fn benchmark_xxx` tests for a quick way to reproduce performance problems. @@ -283,7 +283,8 @@ repository. We use the [rustc-josh-sync](https://github.com/rust-lang/josh-sync) repositories. You can find documentation of the tool [here](https://github.com/rust-lang/josh-sync). You can install the synchronization tool using the following commands: -``` + +```bash cargo install --locked --git https://github.com/rust-lang/josh-sync ``` diff --git a/docs/book/src/contributing/debugging.md b/docs/book/src/contributing/debugging.md index fcda664f5e..ace9be025a 100644 --- a/docs/book/src/contributing/debugging.md +++ b/docs/book/src/contributing/debugging.md @@ -68,7 +68,7 @@ while d == 4 { // set a breakpoint here and change the value However for this to work, you will need to enable debug_assertions in your build -```rust +```bash RUSTFLAGS='--cfg debug_assertions' cargo build --release ``` diff --git a/docs/book/src/installation.md b/docs/book/src/installation.md index 3a4c0cf227..cc636c31e6 100644 --- a/docs/book/src/installation.md +++ b/docs/book/src/installation.md @@ -13,7 +13,9 @@ editor](./other_editors.html). rust-analyzer will attempt to install the standard library source code automatically. You can also install it manually with `rustup`. - $ rustup component add rust-src +```bash +rustup component add rust-src +``` Only the latest stable standard library source is officially supported for use with rust-analyzer. If you are using an older toolchain or have diff --git a/docs/book/src/non_cargo_based_projects.md b/docs/book/src/non_cargo_based_projects.md index f1f10ae336..9cc3292444 100644 --- a/docs/book/src/non_cargo_based_projects.md +++ b/docs/book/src/non_cargo_based_projects.md @@ -135,7 +135,7 @@ interface Crate { cfg_groups?: string[]; /// The set of cfgs activated for a given crate, like /// `["unix", "feature=\"foo\"", "feature=\"bar\""]`. - cfg: string[]; + cfg?: string[]; /// Target tuple for this Crate. /// /// Used when running `rustc --print cfg` @@ -143,7 +143,7 @@ interface Crate { target?: string; /// Environment variables, used for /// the `env!` macro - env: { [key: string]: string; }; + env?: { [key: string]: string; }; /// Extra crate-level attributes applied to this crate. /// /// rust-analyzer will behave as if these attributes @@ -155,7 +155,8 @@ interface Crate { crate_attrs?: string[]; /// Whether the crate is a proc-macro crate. - is_proc_macro: boolean; + /// Defaults to `false` if unspecified. + is_proc_macro?: boolean; /// For proc-macro crates, path to compiled /// proc-macro (.so file). proc_macro_dylib_path?: string; diff --git a/docs/book/src/rust_analyzer_binary.md b/docs/book/src/rust_analyzer_binary.md index c7ac3087ce..2b62011a8e 100644 --- a/docs/book/src/rust_analyzer_binary.md +++ b/docs/book/src/rust_analyzer_binary.md @@ -11,9 +11,11 @@ your `$PATH`. On Linux to install the `rust-analyzer` binary into `~/.local/bin`, these commands should work: - $ mkdir -p ~/.local/bin - $ curl -L https://github.com/rust-lang/rust-analyzer/releases/latest/download/rust-analyzer-x86_64-unknown-linux-gnu.gz | gunzip -c - > ~/.local/bin/rust-analyzer - $ chmod +x ~/.local/bin/rust-analyzer +```bash +mkdir -p ~/.local/bin +curl -L https://github.com/rust-lang/rust-analyzer/releases/latest/download/rust-analyzer-x86_64-unknown-linux-gnu.gz | gunzip -c - > ~/.local/bin/rust-analyzer +chmod +x ~/.local/bin/rust-analyzer +``` Make sure that `~/.local/bin` is listed in the `$PATH` variable and use the appropriate URL if you’re not on a `x86-64` system. @@ -24,8 +26,10 @@ or `/usr/local/bin` will work just as well. Alternatively, you can install it from source using the command below. You’ll need the latest stable version of the Rust toolchain. - $ git clone https://github.com/rust-lang/rust-analyzer.git && cd rust-analyzer - $ cargo xtask install --server +```bash +git clone https://github.com/rust-lang/rust-analyzer.git && cd rust-analyzer +cargo xtask install --server +``` If your editor can’t find the binary even though the binary is on your `$PATH`, the likely explanation is that it doesn’t see the same `$PATH` @@ -38,7 +42,9 @@ the environment should help. `rust-analyzer` is available in `rustup`: - $ rustup component add rust-analyzer +```bash +rustup component add rust-analyzer +``` ### Arch Linux @@ -53,7 +59,9 @@ User Repository): Install it with pacman, for example: - $ pacman -S rust-analyzer +```bash +pacman -S rust-analyzer +``` ### Gentoo Linux @@ -64,7 +72,9 @@ Install it with pacman, for example: The `rust-analyzer` binary can be installed via [Homebrew](https://brew.sh/). - $ brew install rust-analyzer +```bash +brew install rust-analyzer +``` ### Windows diff --git a/docs/book/src/troubleshooting.md b/docs/book/src/troubleshooting.md index a357cbef41..c315bfad7f 100644 --- a/docs/book/src/troubleshooting.md +++ b/docs/book/src/troubleshooting.md @@ -37,13 +37,13 @@ bypassing LSP machinery. When filing issues, it is useful (but not necessary) to try to minimize examples. An ideal bug reproduction looks like this: -```shell -$ git clone https://github.com/username/repo.git && cd repo && git switch --detach commit-hash -$ rust-analyzer --version +```bash +git clone https://github.com/username/repo.git && cd repo && git switch --detach commit-hash +rust-analyzer --version rust-analyzer dd12184e4 2021-05-08 dev -$ rust-analyzer analysis-stats . -💀 💀 💀 +rust-analyzer analysis-stats . ``` +💀 💀 💀 It is especially useful when the `repo` doesn’t use external crates or the standard library. diff --git a/docs/book/src/vs_code.md b/docs/book/src/vs_code.md index 233b862d2c..69a96156b8 100644 --- a/docs/book/src/vs_code.md +++ b/docs/book/src/vs_code.md @@ -49,7 +49,9 @@ Alternatively, download a VSIX corresponding to your platform from the Install the extension with the `Extensions: Install from VSIX` command within VS Code, or from the command line via: - $ code --install-extension /path/to/rust-analyzer.vsix +```bash +code --install-extension /path/to/rust-analyzer.vsix +``` If you are running an unsupported platform, you can install `rust-analyzer-no-server.vsix` and compile or obtain a server binary. @@ -64,8 +66,10 @@ example: Both the server and the Code plugin can be installed from source: - $ git clone https://github.com/rust-lang/rust-analyzer.git && cd rust-analyzer - $ cargo xtask install +```bash +git clone https://github.com/rust-lang/rust-analyzer.git && cd rust-analyzer +cargo xtask install +``` You’ll need Cargo, nodejs (matching a supported version of VS Code) and npm for this. @@ -76,7 +80,9 @@ Remote, instead you’ll need to install the `.vsix` manually. If you’re not using Code, you can compile and install only the LSP server: - $ cargo xtask install --server +```bash +cargo xtask install --server +``` Make sure that `.cargo/bin` is in `$PATH` and precedes paths where `rust-analyzer` may also be installed. Specifically, `rustup` includes a @@ -92,12 +98,12 @@ some directories, `/usr/bin` will always be mounted under system-wide installation of Rust, or any other libraries you might want to link to. Some compilers and libraries can be acquired as Flatpak SDKs, such as `org.freedesktop.Sdk.Extension.rust-stable` or -`org.freedesktop.Sdk.Extension.llvm15`. +`org.freedesktop.Sdk.Extension.llvm21`. If you use a Flatpak SDK for Rust, it must be in your `PATH`: - * install the SDK extensions with `flatpak install org.freedesktop.Sdk.Extension.{llvm15,rust-stable}//23.08` - * enable SDK extensions in the editor with the environment variable `FLATPAK_ENABLE_SDK_EXT=llvm15,rust-stable` (this can be done using flatseal or `flatpak override`) + * install the SDK extensions with `flatpak install org.freedesktop.Sdk.Extension.{llvm21,rust-stable}//25.08` + * enable SDK extensions in the editor with the environment variable `FLATPAK_ENABLE_SDK_EXT=llvm21,rust-stable` (this can be done using flatseal or `flatpak override`) If you want to use Flatpak in combination with `rustup`, the following steps might help: @@ -118,4 +124,3 @@ steps might help: A C compiler should already be available via `org.freedesktop.Sdk`. Any other tools or libraries you will need to acquire from Flatpak. - diff --git a/editors/code/package-lock.json b/editors/code/package-lock.json index 57f6bf69be..b51dc4d132 100644 --- a/editors/code/package-lock.json +++ b/editors/code/package-lock.json @@ -738,9 +738,9 @@ } }, "node_modules/@eslint/config-array/node_modules/brace-expansion": { - "version": "1.1.12", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", - "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "version": "1.1.13", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz", + "integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==", "dev": true, "license": "MIT", "dependencies": { @@ -749,9 +749,9 @@ } }, "node_modules/@eslint/config-array/node_modules/minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz", + "integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==", "dev": true, "license": "ISC", "dependencies": { @@ -799,9 +799,9 @@ } }, "node_modules/@eslint/eslintrc/node_modules/brace-expansion": { - "version": "1.1.12", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", - "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "version": "1.1.13", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz", + "integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==", "dev": true, "license": "MIT", "dependencies": { @@ -810,9 +810,9 @@ } }, "node_modules/@eslint/eslintrc/node_modules/minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz", + "integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==", "dev": true, "license": "ISC", "dependencies": { @@ -934,29 +934,6 @@ "url": "https://github.com/sponsors/nzakas" } }, - "node_modules/@isaacs/balanced-match": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/@isaacs/balanced-match/-/balanced-match-4.0.1.tgz", - "integrity": "sha512-yzMTt9lEb8Gv7zRioUilSglI0c0smZ9k5D65677DLWLtWJaXIS3CqcGyUFByYKlnUj6TkjLVs54fBl6+TiGQDQ==", - "dev": true, - "license": "MIT", - "engines": { - "node": "20 || >=22" - } - }, - "node_modules/@isaacs/brace-expansion": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/@isaacs/brace-expansion/-/brace-expansion-5.0.0.tgz", - "integrity": "sha512-ZT55BDLV0yv0RBm2czMiZ+SqCGO7AvmOM3G/w2xhVPH+te0aKgFjmBvGlL1dH+ql2tgGO3MVrbb3jCKyvpgnxA==", - "dev": true, - "license": "MIT", - "dependencies": { - "@isaacs/balanced-match": "^4.0.1" - }, - "engines": { - "node": "20 || >=22" - } - }, "node_modules/@isaacs/cliui": { "version": "8.0.2", "resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz", @@ -1486,7 +1463,6 @@ "integrity": "sha512-4gbs64bnbSzu4FpgMiQ1A+D+urxkoJk/kqlDJ2W//5SygaEiAP2B4GoS7TEdxgwol2el03gckFV9lJ4QOMiiHg==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@typescript-eslint/scope-manager": "8.25.0", "@typescript-eslint/types": "8.25.0", @@ -1841,9 +1817,9 @@ ] }, "node_modules/@vscode/vsce/node_modules/brace-expansion": { - "version": "1.1.12", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", - "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "version": "1.1.13", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz", + "integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==", "dev": true, "license": "MIT", "dependencies": { @@ -1852,9 +1828,9 @@ } }, "node_modules/@vscode/vsce/node_modules/minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz", + "integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==", "dev": true, "license": "ISC", "dependencies": { @@ -1870,7 +1846,6 @@ "integrity": "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA==", "dev": true, "license": "MIT", - "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -2085,9 +2060,9 @@ "license": "BSD-2-Clause" }, "node_modules/brace-expansion": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", - "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.3.tgz", + "integrity": "sha512-MCV/fYJEbqx68aE58kv2cA/kiky1G8vux3OR6/jbS+jIMe/6fJWa0DTzJU7dqijOWYwHi1t29FlfYI9uytqlpA==", "license": "MIT", "dependencies": { "balanced-match": "^1.0.0" @@ -2840,7 +2815,6 @@ "resolved": "https://registry.npmjs.org/d3-selection/-/d3-selection-3.0.0.tgz", "integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==", "license": "ISC", - "peer": true, "engines": { "node": ">=12" } @@ -3322,7 +3296,6 @@ "integrity": "sha512-KjeihdFqTPhOMXTt7StsDxriV4n66ueuF/jfPNC3j/lduHwr/ijDwJMsF+wyMJethgiKi5wniIE243vi07d3pg==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.2.0", "@eslint-community/regexpp": "^4.12.1", @@ -3443,9 +3416,9 @@ } }, "node_modules/eslint/node_modules/brace-expansion": { - "version": "1.1.12", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", - "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "version": "1.1.13", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz", + "integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==", "dev": true, "license": "MIT", "dependencies": { @@ -3467,9 +3440,9 @@ } }, "node_modules/eslint/node_modules/minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz", + "integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==", "dev": true, "license": "ISC", "dependencies": { @@ -3724,9 +3697,9 @@ } }, "node_modules/flatted": { - "version": "3.3.3", - "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz", - "integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==", + "version": "3.4.2", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz", + "integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==", "dev": true, "license": "ISC" }, @@ -3911,17 +3884,40 @@ "node": ">=10.13.0" } }, + "node_modules/glob/node_modules/balanced-match": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-4.0.4.tgz", + "integrity": "sha512-BLrgEcRTwX2o6gGxGOCNyMvGSp35YofuYzw9h1IMTRmKqttAZZVU67bdb9Pr2vUHA8+j3i2tJfjO6C6+4myGTA==", + "dev": true, + "license": "MIT", + "engines": { + "node": "18 || 20 || >=22" + } + }, + "node_modules/glob/node_modules/brace-expansion": { + "version": "5.0.5", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.5.tgz", + "integrity": "sha512-VZznLgtwhn+Mact9tfiwx64fA9erHH/MCXEUfB/0bX/6Fz6ny5EGTXYltMocqg4xFAQZtnO3DHWWXi8RiuN7cQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^4.0.2" + }, + "engines": { + "node": "18 || 20 || >=22" + } + }, "node_modules/glob/node_modules/minimatch": { - "version": "10.1.1", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-10.1.1.tgz", - "integrity": "sha512-enIvLvRAFZYXJzkCYG5RKmPfrFArdLv+R+lbQ53BmIMLIry74bjKzX6iHAm8WYamJkhSSEabrWN5D97XnKObjQ==", + "version": "10.2.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-10.2.4.tgz", + "integrity": "sha512-oRjTw/97aTBN0RHbYCdtF1MQfvusSIBQM0IZEgzl6426+8jSC0nF1a/GmnVLpfB9yyr6g6FTqWqiZVbxrtaCIg==", "dev": true, "license": "BlueOak-1.0.0", "dependencies": { - "@isaacs/brace-expansion": "^5.0.0" + "brace-expansion": "^5.0.2" }, "engines": { - "node": "20 || >=22" + "node": "18 || 20 || >=22" }, "funding": { "url": "https://github.com/sponsors/isaacs" @@ -4410,7 +4406,6 @@ "resolved": "https://registry.npmjs.org/jiti/-/jiti-2.4.2.tgz", "integrity": "sha512-rg9zJN+G4n2nfJl5MW3BMygZX56zKPNVEYYqq7adpmMh4Jn2QNEwhvQlFy6jPVdcod7txZtKHWnyZiA3a0zP7A==", "license": "MIT", - "peer": true, "bin": { "jiti": "lib/jiti-cli.mjs" } @@ -4655,9 +4650,9 @@ } }, "node_modules/lodash": { - "version": "4.17.21", - "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", - "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==", + "version": "4.17.23", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.23.tgz", + "integrity": "sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==", "dev": true, "license": "MIT" }, @@ -4827,9 +4822,9 @@ } }, "node_modules/micromatch/node_modules/picomatch": { - "version": "2.3.1", - "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", - "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", + "version": "2.3.2", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz", + "integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==", "dev": true, "license": "MIT", "engines": { @@ -4900,13 +4895,13 @@ } }, "node_modules/minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "version": "9.0.9", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.9.tgz", + "integrity": "sha512-OBwBN9AL4dqmETlpS2zasx+vTeWclWzkblfZk7KTA5j3jeOONz/tRCnZomUyvNg83wL5Zv9Ss6HMJXAgL8R2Yg==", "dev": true, "license": "ISC", "dependencies": { - "brace-expansion": "^2.0.1" + "brace-expansion": "^2.0.2" }, "engines": { "node": ">=16 || 14 >=14.17" @@ -5468,9 +5463,9 @@ "license": "ISC" }, "node_modules/picomatch": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.2.tgz", - "integrity": "sha512-M7BAV6Rlcy5u+m6oPhAPFgJTzAioX/6B0DxyvDlo9l8+T3nLKbrczg2WLUyzd45L8RqfUMyGPzekbMvX2Ldkwg==", + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz", + "integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==", "dev": true, "license": "MIT", "engines": { @@ -5584,9 +5579,9 @@ } }, "node_modules/qs": { - "version": "6.14.1", - "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.1.tgz", - "integrity": "sha512-4EK3+xJl8Ts67nLYNwqw/dsFVnCf+qR7RgXSK9jEEm9unao3njwMDdmsdvoKBKHzxd7tCYz5e5M+SnMjdtXGQQ==", + "version": "6.14.2", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.2.tgz", + "integrity": "sha512-V/yCWTTF7VJ9hIh18Ugr2zhJMP01MY7c5kh4J870L7imm6/DIzBsNLTXzMwUA3yZ5b/KBqLx8Kp3uRvd7xSe3Q==", "dev": true, "license": "BSD-3-Clause", "dependencies": { @@ -6678,7 +6673,6 @@ "integrity": "sha512-84MVSjMEHP+FQRPy3pX9sTVV/INIex71s9TL2Gm5FG/WG1SqXeKyZ0k7/blY/4FdOzI12CBy1vGc4og/eus0fw==", "dev": true, "license": "Apache-2.0", - "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -6725,9 +6719,9 @@ "license": "MIT" }, "node_modules/undici": { - "version": "6.21.3", - "resolved": "https://registry.npmjs.org/undici/-/undici-6.21.3.tgz", - "integrity": "sha512-gBLkYIlEnSp8pFbT64yFgGE6UIB9tAkhukC23PmMDCe5Nd+cRqKxSjw5y54MK2AZMgZfJWMaNE4nYUHgi1XEOw==", + "version": "6.24.1", + "resolved": "https://registry.npmjs.org/undici/-/undici-6.24.1.tgz", + "integrity": "sha512-sC+b0tB1whOCzbtlx20fx3WgCXwkW627p4EA9uM+/tNNPkSS+eSEld6pAs9nDv7WbY1UUljBMYPtu9BCOrCWKA==", "dev": true, "license": "MIT", "engines": { @@ -6846,9 +6840,9 @@ } }, "node_modules/vscode-languageclient/node_modules/minimatch": { - "version": "5.1.6", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.6.tgz", - "integrity": "sha512-lKwV/1brpG6mBUFHtb7NUmtABCb2WZZmm2wNiOA5hAb8VdCS4B3dtMWyvcoViccwAW/COERjXLt0zP1zXUN26g==", + "version": "5.1.9", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.9.tgz", + "integrity": "sha512-7o1wEA2RyMP7Iu7GNba9vc0RWWGACJOCZBJX2GJWip0ikV+wcOsgVuY9uE8CPiyQhkGFSlhuSkZPavN7u1c2Fw==", "license": "ISC", "dependencies": { "brace-expansion": "^2.0.1" diff --git a/editors/code/package.json b/editors/code/package.json index fc20597e88..1dd513c9de 100644 --- a/editors/code/package.json +++ b/editors/code/package.json @@ -2865,7 +2865,7 @@ "title": "Runnables", "properties": { "rust-analyzer.runnables.bench.overrideCommand": { - "markdownDescription": "Override the command used for bench runnables.\nThe first element of the array should be the program to execute (for example, `cargo`).\n\nUse the placeholders `${package}`, `${target_arg}`, `${target}`, `${test_name}` to dynamically\nreplace the package name, target option (such as `--bin` or `--example`), the target name and\nthe test name (name of test function or test mod path).", + "markdownDescription": "Override the command used for bench runnables.\nThe first element of the array should be the program to execute (for example, `cargo`).\n\nUse the placeholders `${package}`, `${target_arg}`, `${target}`, `${executable_args}` to dynamically\nreplace the package name, target option (such as `--bin` or `--example`), the target name and\nthe arguments passed to test binary args (includes `rust-analyzer.runnables.extraTestBinaryArgs`).", "default": null, "type": [ "null", @@ -2894,7 +2894,7 @@ "title": "Runnables", "properties": { "rust-analyzer.runnables.doctest.overrideCommand": { - "markdownDescription": "Override the command used for bench runnables.\nThe first element of the array should be the program to execute (for example, `cargo`).\n\nUse the placeholders `${package}`, `${target_arg}`, `${target}`, `${test_name}` to dynamically\nreplace the package name, target option (such as `--bin` or `--example`), the target name and\nthe test name (name of test function or test mod path).", + "markdownDescription": "Override the command used for bench runnables.\nThe first element of the array should be the program to execute (for example, `cargo`).\n\nUse the placeholders `${package}`, `${target_arg}`, `${target}`, `${executable_args}` to dynamically\nreplace the package name, target option (such as `--bin` or `--example`), the target name and\nthe arguments passed to test binary args (includes `rust-analyzer.runnables.extraTestBinaryArgs`).", "default": null, "type": [ "null", @@ -2948,7 +2948,7 @@ "title": "Runnables", "properties": { "rust-analyzer.runnables.test.overrideCommand": { - "markdownDescription": "Override the command used for test runnables.\nThe first element of the array should be the program to execute (for example, `cargo`).\n\nUse the placeholders `${package}`, `${target_arg}`, `${target}`, `${test_name}` to dynamically\nreplace the package name, target option (such as `--bin` or `--example`), the target name and\nthe test name (name of test function or test mod path).", + "markdownDescription": "Override the command used for test runnables.\nThe first element of the array should be the program to execute (for example, `cargo`).\n\nUse the placeholders `${package}`, `${target_arg}`, `${target}`, `${executable_args}` to dynamically\nreplace the package name, target option (such as `--bin` or `--example`), the target name and\nthe arguments passed to test binary args (includes `rust-analyzer.runnables.extraTestBinaryArgs`).", "default": null, "type": [ "null", diff --git a/lib/lsp-server/src/req_queue.rs b/lib/lsp-server/src/req_queue.rs index c216864bee..84748bbca8 100644 --- a/lib/lsp-server/src/req_queue.rs +++ b/lib/lsp-server/src/req_queue.rs @@ -18,6 +18,12 @@ impl<I, O> Default for ReqQueue<I, O> { } } +impl<I, O> ReqQueue<I, O> { + pub fn has_pending(&self) -> bool { + self.incoming.has_pending() || self.outgoing.has_pending() + } +} + #[derive(Debug)] pub struct Incoming<I> { pending: HashMap<RequestId, I>, @@ -51,6 +57,10 @@ impl<I> Incoming<I> { pub fn is_completed(&self, id: &RequestId) -> bool { !self.pending.contains_key(id) } + + pub fn has_pending(&self) -> bool { + !self.pending.is_empty() + } } impl<O> Outgoing<O> { @@ -64,4 +74,8 @@ impl<O> Outgoing<O> { pub fn complete(&mut self, id: RequestId) -> Option<O> { self.pending.remove(&id) } + + pub fn has_pending(&self) -> bool { + !self.pending.is_empty() + } } diff --git a/lib/smol_str/CHANGELOG.md b/lib/smol_str/CHANGELOG.md index 4aa25fa134..6327275d07 100644 --- a/lib/smol_str/CHANGELOG.md +++ b/lib/smol_str/CHANGELOG.md @@ -1,5 +1,8 @@ # Changelog +## 0.3.6 - 2026-03-04 +- Fix the `borsh` feature. + ## 0.3.5 - 2026-01-08 - Optimise `SmolStr::clone` 4-5x speedup inline, 0.5x heap (slow down). diff --git a/lib/smol_str/Cargo.toml b/lib/smol_str/Cargo.toml index 4e7844b49e..22068fe841 100644 --- a/lib/smol_str/Cargo.toml +++ b/lib/smol_str/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "smol_str" -version = "0.3.5" +version = "0.3.6" description = "small-string optimized string type with O(1) clone" license = "MIT OR Apache-2.0" repository = "https://github.com/rust-lang/rust-analyzer/tree/master/lib/smol_str" diff --git a/lib/smol_str/src/borsh.rs b/lib/smol_str/src/borsh.rs index 527ce85a17..44ae513ed4 100644 --- a/lib/smol_str/src/borsh.rs +++ b/lib/smol_str/src/borsh.rs @@ -16,7 +16,7 @@ impl BorshDeserialize for SmolStr { #[inline] fn deserialize_reader<R: Read>(reader: &mut R) -> borsh::io::Result<Self> { let len = u32::deserialize_reader(reader)?; - if (len as usize) < INLINE_CAP { + if (len as usize) <= INLINE_CAP { let mut buf = [0u8; INLINE_CAP]; reader.read_exact(&mut buf[..len as usize])?; _ = core::str::from_utf8(&buf[..len as usize]).map_err(|err| { diff --git a/lib/smol_str/src/lib.rs b/lib/smol_str/src/lib.rs index 0d1f01a32b..55ede286c2 100644 --- a/lib/smol_str/src/lib.rs +++ b/lib/smol_str/src/lib.rs @@ -34,13 +34,17 @@ use core::{ pub struct SmolStr(Repr); impl SmolStr { + /// The maximum byte length of a string that can be stored inline + /// without heap allocation. + pub const INLINE_CAP: usize = INLINE_CAP; + /// Constructs an inline variant of `SmolStr`. /// /// This never allocates. /// /// # Panics /// - /// Panics if `text.len() > 23`. + /// Panics if `text.len() > `[`SmolStr::INLINE_CAP`]. #[inline] pub const fn new_inline(text: &str) -> SmolStr { assert!(text.len() <= INLINE_CAP); // avoids bounds checks in loop @@ -100,6 +104,24 @@ impl SmolStr { pub const fn is_heap_allocated(&self) -> bool { matches!(self.0, Repr::Heap(..)) } + + /// Constructs a `SmolStr` from a byte slice, returning an error if the slice is not valid + /// UTF-8. + #[inline] + pub fn from_utf8(bytes: &[u8]) -> Result<SmolStr, core::str::Utf8Error> { + core::str::from_utf8(bytes).map(SmolStr::new) + } + + /// Constructs a `SmolStr` from a byte slice without checking that the bytes are valid UTF-8. + /// + /// # Safety + /// + /// `bytes` must be valid UTF-8. + #[inline] + pub unsafe fn from_utf8_unchecked(bytes: &[u8]) -> SmolStr { + // SAFETY: caller guarantees bytes are valid UTF-8 + SmolStr::new(unsafe { core::str::from_utf8_unchecked(bytes) }) + } } impl Clone for SmolStr { @@ -116,7 +138,10 @@ impl Clone for SmolStr { return cold_clone(self); } - // SAFETY: We verified that the payload of `Repr` is a POD + // SAFETY: The non-heap variants (`Repr::Inline` and `Repr::Static`) contain only + // `Copy` data (a `[u8; 23]` + `InlineSize` enum, or a `&'static str` fat pointer) + // and carry no drop glue, so a raw `ptr::read` bitwise copy is sound. + // The heap variant (`Repr::Heap`) is excluded above. unsafe { core::ptr::read(self as *const SmolStr) } } } @@ -142,7 +167,12 @@ impl ops::Deref for SmolStr { impl Eq for SmolStr {} impl PartialEq<SmolStr> for SmolStr { fn eq(&self, other: &SmolStr) -> bool { - self.0.ptr_eq(&other.0) || self.as_str() == other.as_str() + match (&self.0, &other.0) { + (Repr::Inline { len: l_len, buf: l_buf }, Repr::Inline { len: r_len, buf: r_buf }) => { + l_len == r_len && l_buf == r_buf + } + _ => self.as_str() == other.as_str(), + } } } @@ -215,6 +245,48 @@ impl PartialOrd for SmolStr { } } +impl PartialOrd<str> for SmolStr { + fn partial_cmp(&self, other: &str) -> Option<Ordering> { + Some(self.as_str().cmp(other)) + } +} + +impl<'a> PartialOrd<&'a str> for SmolStr { + fn partial_cmp(&self, other: &&'a str) -> Option<Ordering> { + Some(self.as_str().cmp(*other)) + } +} + +impl PartialOrd<SmolStr> for &str { + fn partial_cmp(&self, other: &SmolStr) -> Option<Ordering> { + Some((*self).cmp(other.as_str())) + } +} + +impl PartialOrd<String> for SmolStr { + fn partial_cmp(&self, other: &String) -> Option<Ordering> { + Some(self.as_str().cmp(other.as_str())) + } +} + +impl PartialOrd<SmolStr> for String { + fn partial_cmp(&self, other: &SmolStr) -> Option<Ordering> { + Some(self.as_str().cmp(other.as_str())) + } +} + +impl<'a> PartialOrd<&'a String> for SmolStr { + fn partial_cmp(&self, other: &&'a String) -> Option<Ordering> { + Some(self.as_str().cmp(other.as_str())) + } +} + +impl PartialOrd<SmolStr> for &String { + fn partial_cmp(&self, other: &SmolStr) -> Option<Ordering> { + Some(self.as_str().cmp(other.as_str())) + } +} + impl hash::Hash for SmolStr { fn hash<H: hash::Hasher>(&self, hasher: &mut H) { self.as_str().hash(hasher); @@ -359,6 +431,20 @@ impl AsRef<std::path::Path> for SmolStr { } } +impl From<char> for SmolStr { + #[inline] + fn from(c: char) -> SmolStr { + let mut buf = [0; INLINE_CAP]; + let len = c.len_utf8(); + c.encode_utf8(&mut buf); + SmolStr(Repr::Inline { + // SAFETY: A char is at most 4 bytes, which is always <= INLINE_CAP (23). + len: unsafe { InlineSize::transmute_from_u8(len as u8) }, + buf, + }) + } +} + impl From<&str> for SmolStr { #[inline] fn from(s: &str) -> SmolStr { @@ -483,11 +569,15 @@ enum InlineSize { } impl InlineSize { - /// SAFETY: `value` must be less than or equal to [`INLINE_CAP`] + /// # Safety + /// + /// `value` must be in the range `0..=23` (i.e. a valid `InlineSize` discriminant). + /// Values outside this range would produce an invalid enum discriminant, which is UB. #[inline(always)] const unsafe fn transmute_from_u8(value: u8) -> Self { debug_assert!(value <= InlineSize::_V23 as u8); - // SAFETY: The caller is responsible to uphold this invariant + // SAFETY: The caller guarantees `value` is a valid discriminant for this + // `#[repr(u8)]` enum (0..=23), so the transmute produces a valid `InlineSize`. unsafe { mem::transmute::<u8, Self>(value) } } } @@ -563,24 +653,15 @@ impl Repr { Repr::Static(data) => data, Repr::Inline { len, buf } => { let len = *len as usize; - // SAFETY: len is guaranteed to be <= INLINE_CAP + // SAFETY: `len` is an `InlineSize` discriminant (0..=23) which is always + // <= INLINE_CAP (23), so `..len` is always in bounds of `buf: [u8; 23]`. let buf = unsafe { buf.get_unchecked(..len) }; - // SAFETY: buf is guaranteed to be valid utf8 for ..len bytes + // SAFETY: All constructors that produce `Repr::Inline` copy from valid + // UTF-8 sources (`&str` or char encoding), so `buf[..len]` is valid UTF-8. unsafe { ::core::str::from_utf8_unchecked(buf) } } } } - - fn ptr_eq(&self, other: &Self) -> bool { - match (self, other) { - (Self::Heap(l0), Self::Heap(r0)) => Arc::ptr_eq(l0, r0), - (Self::Static(l0), Self::Static(r0)) => core::ptr::eq(l0, r0), - (Self::Inline { len: l_len, buf: l_buf }, Self::Inline { len: r_len, buf: r_buf }) => { - l_len == r_len && l_buf == r_buf - } - _ => false, - } - } } /// Convert value to [`SmolStr`] using [`fmt::Display`], potentially without allocating. @@ -666,7 +747,7 @@ impl StrExt for str { buf[..len].copy_from_slice(self.as_bytes()); buf[..len].make_ascii_lowercase(); SmolStr(Repr::Inline { - // SAFETY: `len` is in bounds + // SAFETY: `len` is guarded to be <= INLINE_CAP (23), a valid `InlineSize` discriminant. len: unsafe { InlineSize::transmute_from_u8(len as u8) }, buf, }) @@ -683,7 +764,7 @@ impl StrExt for str { buf[..len].copy_from_slice(self.as_bytes()); buf[..len].make_ascii_uppercase(); SmolStr(Repr::Inline { - // SAFETY: `len` is in bounds + // SAFETY: `len` is guarded to be <= INLINE_CAP (23), a valid `InlineSize` discriminant. len: unsafe { InlineSize::transmute_from_u8(len as u8) }, buf, }) @@ -703,8 +784,11 @@ impl StrExt for str { if let [from_u8] = from.as_bytes() && let [to_u8] = to.as_bytes() { + // SAFETY: `from` and `to` are single-byte `&str`s. In valid UTF-8, a single-byte + // code unit is always in the range 0x00..=0x7F (i.e. ASCII). The closure only + // replaces the matching ASCII byte with another ASCII byte, and returns all + // other bytes unchanged, so UTF-8 validity is preserved. return if self.len() <= count { - // SAFETY: `from_u8` & `to_u8` are ascii unsafe { replacen_1_ascii(self, |b| if b == from_u8 { *to_u8 } else { *b }) } } else { unsafe { @@ -736,7 +820,11 @@ impl StrExt for str { } } -/// SAFETY: `map` fn must only replace ascii with ascii or return unchanged bytes. +/// # Safety +/// +/// `map` must satisfy: for every byte `b` in `src`, if `b <= 0x7F` (ASCII) then `map(b)` must +/// also be `<= 0x7F` (ASCII). If `b > 0x7F` (part of a multi-byte UTF-8 sequence), `map` must +/// return `b` unchanged. This ensures the output is valid UTF-8 whenever the input is. #[inline] unsafe fn replacen_1_ascii(src: &str, mut map: impl FnMut(&u8) -> u8) -> SmolStr { if src.len() <= INLINE_CAP { @@ -745,13 +833,16 @@ unsafe fn replacen_1_ascii(src: &str, mut map: impl FnMut(&u8) -> u8) -> SmolStr buf[idx] = map(b); } SmolStr(Repr::Inline { - // SAFETY: `len` is in bounds + // SAFETY: `src` is a `&str` so `src.len()` <= INLINE_CAP <= 23, which is a + // valid `InlineSize` discriminant. len: unsafe { InlineSize::transmute_from_u8(src.len() as u8) }, buf, }) } else { let out = src.as_bytes().iter().map(map).collect(); - // SAFETY: We replaced ascii with ascii on valid utf8 strings. + // SAFETY: The caller guarantees `map` only substitutes ASCII bytes with ASCII + // bytes and leaves multi-byte UTF-8 continuation bytes untouched, so the + // output byte sequence is valid UTF-8. unsafe { String::from_utf8_unchecked(out).into() } } } @@ -773,9 +864,11 @@ fn inline_convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> ([u8; INLINE_C let mut is_ascii = [false; N]; while slice.len() >= N { - // SAFETY: checked in loop condition + // SAFETY: The loop condition guarantees `slice.len() >= N`, so `..N` is in bounds. let chunk = unsafe { slice.get_unchecked(..N) }; - // SAFETY: out_slice has at least same length as input slice and gets sliced with the same offsets + // SAFETY: `out_slice` starts with the same length as `slice` (both derived from + // `s.len()`) and both are advanced by the same offset `N` each iteration, so + // `out_slice.len() >= N` holds whenever `slice.len() >= N`. let out_chunk = unsafe { out_slice.get_unchecked_mut(..N) }; for j in 0..N { @@ -794,6 +887,7 @@ fn inline_convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> ([u8; INLINE_C out_chunk[j] = convert(&chunk[j]); } + // SAFETY: Same reasoning as above — both slices have len >= N at this point. slice = unsafe { slice.get_unchecked(N..) }; out_slice = unsafe { out_slice.get_unchecked_mut(N..) }; } @@ -804,7 +898,9 @@ fn inline_convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> ([u8; INLINE_C if byte > 127 { break; } - // SAFETY: out_slice has at least same length as input slice + // SAFETY: `out_slice` is always the same length as `slice` (both start equal and + // are advanced by 1 together), and `slice` is non-empty per the loop condition, + // so index 0 and `1..` are in bounds for both. unsafe { *out_slice.get_unchecked_mut(0) = convert(&byte); } @@ -813,8 +909,10 @@ fn inline_convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> ([u8; INLINE_C } unsafe { - // SAFETY: we know this is a valid char boundary - // since we only skipped over leading ascii bytes + // SAFETY: We only advanced past bytes that satisfy `b <= 127`, i.e. ASCII bytes. + // In UTF-8, ASCII bytes (0x00..=0x7F) are always single-byte code points and + // never appear as continuation bytes, so the remaining `slice` starts at a valid + // UTF-8 char boundary. let rest = core::str::from_utf8_unchecked(slice); (out, rest) } @@ -850,10 +948,18 @@ macro_rules! format_smolstr { /// A builder that can be used to efficiently build a [`SmolStr`]. /// /// This won't allocate if the final string fits into the inline buffer. -#[derive(Clone, Default, Debug, PartialEq, Eq)] +#[derive(Clone, Default, Debug)] pub struct SmolStrBuilder(SmolStrBuilderRepr); -#[derive(Clone, Debug, PartialEq, Eq)] +impl PartialEq for SmolStrBuilder { + fn eq(&self, other: &Self) -> bool { + self.as_str() == other.as_str() + } +} + +impl Eq for SmolStrBuilder {} + +#[derive(Clone, Debug)] enum SmolStrBuilderRepr { Inline { len: usize, buf: [u8; INLINE_CAP] }, Heap(String), @@ -873,11 +979,57 @@ impl SmolStrBuilder { Self(SmolStrBuilderRepr::Inline { buf: [0; INLINE_CAP], len: 0 }) } + /// Creates a new empty [`SmolStrBuilder`] with at least the specified capacity. + /// + /// If `capacity` is less than or equal to [`SmolStr::INLINE_CAP`], the builder + /// will use inline storage and not allocate. Otherwise, it will pre-allocate a + /// heap buffer of the requested capacity. + #[must_use] + pub fn with_capacity(capacity: usize) -> Self { + if capacity <= INLINE_CAP { + Self::new() + } else { + Self(SmolStrBuilderRepr::Heap(String::with_capacity(capacity))) + } + } + + /// Returns the number of bytes accumulated in the builder so far. + #[inline] + pub fn len(&self) -> usize { + match &self.0 { + SmolStrBuilderRepr::Inline { len, .. } => *len, + SmolStrBuilderRepr::Heap(heap) => heap.len(), + } + } + + /// Returns `true` if the builder has a length of zero bytes. + #[inline] + pub fn is_empty(&self) -> bool { + match &self.0 { + SmolStrBuilderRepr::Inline { len, .. } => *len == 0, + SmolStrBuilderRepr::Heap(heap) => heap.is_empty(), + } + } + + /// Returns a `&str` slice of the builder's current contents. + #[inline] + pub fn as_str(&self) -> &str { + match &self.0 { + SmolStrBuilderRepr::Inline { len, buf } => { + // SAFETY: `buf[..*len]` was built by prior `push`/`push_str` calls + // that only wrote valid UTF-8, and `*len <= INLINE_CAP` is maintained + // by the inline branch logic. + unsafe { core::str::from_utf8_unchecked(&buf[..*len]) } + } + SmolStrBuilderRepr::Heap(heap) => heap.as_str(), + } + } + /// Builds a [`SmolStr`] from `self`. #[must_use] - pub fn finish(&self) -> SmolStr { - SmolStr(match &self.0 { - &SmolStrBuilderRepr::Inline { len, buf } => { + pub fn finish(self) -> SmolStr { + SmolStr(match self.0 { + SmolStrBuilderRepr::Inline { len, buf } => { debug_assert!(len <= INLINE_CAP); Repr::Inline { // SAFETY: We know that `value.len` is less than or equal to the maximum value of `InlineSize` @@ -885,7 +1037,7 @@ impl SmolStrBuilder { buf, } } - SmolStrBuilderRepr::Heap(heap) => Repr::new(heap), + SmolStrBuilderRepr::Heap(heap) => Repr::new(&heap), }) } @@ -900,8 +1052,10 @@ impl SmolStrBuilder { *len += char_len; } else { let mut heap = String::with_capacity(new_len); - // copy existing inline bytes over to the heap - // SAFETY: inline data is guaranteed to be valid utf8 for `old_len` bytes + // SAFETY: `buf[..*len]` was built by prior `push`/`push_str` calls + // that only wrote valid UTF-8 (from `char::encode_utf8` or `&str` + // byte copies), so extending the Vec with these bytes preserves the + // String's UTF-8 invariant. unsafe { heap.as_mut_vec().extend_from_slice(&buf[..*len]) }; heap.push(c); self.0 = SmolStrBuilderRepr::Heap(heap); @@ -926,8 +1080,10 @@ impl SmolStrBuilder { let mut heap = String::with_capacity(*len); - // copy existing inline bytes over to the heap - // SAFETY: inline data is guaranteed to be valid utf8 for `old_len` bytes + // SAFETY: `buf[..old_len]` was built by prior `push`/`push_str` calls + // that only wrote valid UTF-8 (from `char::encode_utf8` or `&str` byte + // copies), so extending the Vec with these bytes preserves the String's + // UTF-8 invariant. unsafe { heap.as_mut_vec().extend_from_slice(&buf[..old_len]) }; heap.push_str(s); self.0 = SmolStrBuilderRepr::Heap(heap); @@ -945,6 +1101,30 @@ impl fmt::Write for SmolStrBuilder { } } +impl iter::Extend<char> for SmolStrBuilder { + fn extend<I: iter::IntoIterator<Item = char>>(&mut self, iter: I) { + for c in iter { + self.push(c); + } + } +} + +impl<'a> iter::Extend<&'a str> for SmolStrBuilder { + fn extend<I: iter::IntoIterator<Item = &'a str>>(&mut self, iter: I) { + for s in iter { + self.push_str(s); + } + } +} + +impl<'a> iter::Extend<&'a String> for SmolStrBuilder { + fn extend<I: iter::IntoIterator<Item = &'a String>>(&mut self, iter: I) { + for s in iter { + self.push_str(s); + } + } +} + impl From<SmolStrBuilder> for SmolStr { fn from(value: SmolStrBuilder) -> Self { value.finish() diff --git a/lib/smol_str/src/serde.rs b/lib/smol_str/src/serde.rs index 66cbcd3bad..9d82d64805 100644 --- a/lib/smol_str/src/serde.rs +++ b/lib/smol_str/src/serde.rs @@ -16,7 +16,7 @@ where impl<'a> Visitor<'a> for SmolStrVisitor { type Value = SmolStr; - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { formatter.write_str("a string") } diff --git a/lib/smol_str/tests/test.rs b/lib/smol_str/tests/test.rs index 640e7df681..83648edeec 100644 --- a/lib/smol_str/tests/test.rs +++ b/lib/smol_str/tests/test.rs @@ -10,6 +10,7 @@ use smol_str::{SmolStr, SmolStrBuilder}; #[cfg(target_pointer_width = "64")] fn smol_str_is_smol() { assert_eq!(::std::mem::size_of::<SmolStr>(), ::std::mem::size_of::<String>(),); + assert_eq!(::std::mem::size_of::<Option<SmolStr>>(), ::std::mem::size_of::<SmolStr>(),); } #[test] @@ -332,6 +333,29 @@ fn test_builder_push() { assert_eq!("a".repeat(24), s); } +#[test] +fn test_from_char() { + // ASCII char + let s: SmolStr = 'a'.into(); + assert_eq!(s, "a"); + assert!(!s.is_heap_allocated()); + + // Multi-byte char (2 bytes) + let s: SmolStr = SmolStr::from('ñ'); + assert_eq!(s, "ñ"); + assert!(!s.is_heap_allocated()); + + // 3-byte char + let s: SmolStr = '€'.into(); + assert_eq!(s, "€"); + assert!(!s.is_heap_allocated()); + + // 4-byte char (emoji) + let s: SmolStr = '🦀'.into(); + assert_eq!(s, "🦀"); + assert!(!s.is_heap_allocated()); +} + #[cfg(test)] mod test_str_ext { use smol_str::StrExt; @@ -393,7 +417,7 @@ mod test_str_ext { } } -#[cfg(feature = "borsh")] +#[cfg(all(feature = "borsh", feature = "std"))] mod borsh_tests { use borsh::BorshDeserialize; use smol_str::{SmolStr, ToSmolStr}; diff --git a/rust-version b/rust-version index a1011c4a0a..68f38716db 100644 --- a/rust-version +++ b/rust-version @@ -1 +1 @@ -ba284f468cd2cda48420251efc991758ec13d450 +1174f784096deb8e4ba93f7e4b5ccb7bb4ba2c55 diff --git a/triagebot.toml b/triagebot.toml index ac4efd0a24..5fd97f52d8 100644 --- a/triagebot.toml +++ b/triagebot.toml @@ -24,4 +24,5 @@ labels = ["has-merge-commits", "S-waiting-on-author"] [transfer] # Canonicalize issue numbers to avoid closing the wrong issue when upstreaming this subtree -[canonicalize-issue-links] +[issue-links] +check-commits = "uncanonicalized" diff --git a/xtask/src/codegen/grammar/ast_src.rs b/xtask/src/codegen/grammar/ast_src.rs index b9f570fe0e..564d9cc24e 100644 --- a/xtask/src/codegen/grammar/ast_src.rs +++ b/xtask/src/codegen/grammar/ast_src.rs @@ -112,7 +112,7 @@ const RESERVED: &[&str] = &[ // keywords that are keywords only in specific parse contexts #[doc(alias = "WEAK_KEYWORDS")] const CONTEXTUAL_KEYWORDS: &[&str] = - &["macro_rules", "union", "default", "raw", "dyn", "auto", "yeet", "safe"]; + &["macro_rules", "union", "default", "raw", "dyn", "auto", "yeet", "safe", "bikeshed"]; // keywords we use for special macro expansions const CONTEXTUAL_BUILTIN_KEYWORDS: &[&str] = &[ "asm", @@ -189,9 +189,9 @@ pub(crate) fn generate_kind_src( } } }); - PUNCT.iter().zip(used_puncts).filter(|(_, used)| !used).for_each(|((punct, _), _)| { + if let Some(punct) = PUNCT.iter().zip(used_puncts).find(|(_, used)| !used) { panic!("Punctuation {punct:?} is not used in grammar"); - }); + } keywords.extend(RESERVED.iter().copied()); keywords.sort(); keywords.dedup(); |