Unnamed repository; edit this file 'description' to name the repository.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
use criterion::{Criterion, criterion_group, criterion_main};
use rand::distr::{Alphanumeric, SampleString};
use smol_str::{SmolStr, StrExt, ToSmolStr, format_smolstr};
use std::hint::black_box;

/// Input lengths, in bytes, exercised by every benchmark (all inputs are ASCII, so
/// characters and bytes coincide except where a benchmark appends a non-ASCII char):
/// - 12: small (inline)
/// - 50: medium (heap)
/// - 1000: large (heap)
///
/// NOTE(review): the inline/heap split reflects `SmolStr`'s small-string storage; the exact
/// inline capacity is defined in `smol_str` — confirm it still lies between 12 and 50.
const TEST_LENS: [usize; 3] = [12, 50, 1000];

fn format_bench(c: &mut Criterion) {
    for len in TEST_LENS {
        let n = rand::random_range(10000..99999);
        let str_len = len.checked_sub(n.to_smolstr().len()).unwrap();
        let str = Alphanumeric.sample_string(&mut rand::rng(), str_len);

        c.bench_function(&format!("format_smolstr! len={len}"), |b| {
            let mut v = <_>::default();
            b.iter(|| v = format_smolstr!("{str}-{n}"));
            assert_eq!(v, format!("{str}-{n}"));
        });
    }
}

/// Benchmarks building a `SmolStr` from a borrowed `String` for each length in `TEST_LENS`.
fn from_str_bench(c: &mut Criterion) {
    TEST_LENS.into_iter().for_each(|len| {
        let input = Alphanumeric.sample_string(&mut rand::rng(), len);
        let label = format!("SmolStr::from len={len}");

        c.bench_function(&label, |b| {
            let mut built = SmolStr::default();
            b.iter(|| built = SmolStr::from(black_box(&input)));
            assert_eq!(built, input);
        });
    });
}

/// Benchmarks `SmolStr::clone` for each length in `TEST_LENS`.
fn clone_bench(c: &mut Criterion) {
    for len in TEST_LENS {
        let expected = Alphanumeric.sample_string(&mut rand::rng(), len);
        let source = SmolStr::new(&expected);
        let label = format!("SmolStr::clone len={len}");

        c.bench_function(&label, |b| {
            let mut copy = SmolStr::default();
            b.iter(|| copy = source.clone());
            assert_eq!(copy, expected);
        });
    }
}

/// Benchmarks comparing a `SmolStr` against an equal `&String` for each length in `TEST_LENS`.
fn eq_bench(c: &mut Criterion) {
    TEST_LENS.into_iter().for_each(|len| {
        let plain = Alphanumeric.sample_string(&mut rand::rng(), len);
        let smol = SmolStr::new(&plain);

        c.bench_function(&format!("SmolStr::eq len={len}"), |b| {
            let mut equal = false;
            b.iter(|| equal = smol == black_box(&plain));
            assert!(equal);
        });
    });
}

fn to_lowercase_bench(c: &mut Criterion) {
    const END_CHAR: char = 'İ';

    for len in TEST_LENS {
        // mostly ascii seq with some non-ascii at the end
        let mut str = Alphanumeric.sample_string(&mut rand::rng(), len - END_CHAR.len_utf8());
        str.push(END_CHAR);
        let str = str.as_str();

        c.bench_function(&format!("to_lowercase_smolstr len={len}"), |b| {
            let mut v = <_>::default();
            b.iter(|| v = str.to_lowercase_smolstr());
            assert_eq!(v, str.to_lowercase());
        });
    }
}

fn to_ascii_lowercase_bench(c: &mut Criterion) {
    for len in TEST_LENS {
        let str = Alphanumeric.sample_string(&mut rand::rng(), len);
        let str = str.as_str();

        c.bench_function(&format!("to_ascii_lowercase_smolstr len={len}"), |b| {
            let mut v = <_>::default();
            b.iter(|| v = str.to_ascii_lowercase_smolstr());
            assert_eq!(v, str.to_ascii_lowercase());
        });
    }
}

fn replace_bench(c: &mut Criterion) {
    for len in TEST_LENS {
        let s_dash_s = Alphanumeric.sample_string(&mut rand::rng(), len / 2)
            + "-"
            + &Alphanumeric.sample_string(&mut rand::rng(), len - 1 - len / 2);
        let str = s_dash_s.as_str();

        c.bench_function(&format!("replace_smolstr len={len}"), |b| {
            let mut v = <_>::default();
            b.iter(|| v = str.replace_smolstr("-", "_"));
            assert_eq!(v, str.replace("-", "_"));
        });
    }
}

// All benchmarks run with Criterion's default configuration, in the order listed.
criterion_group! {
    name = benches;
    config = Criterion::default();
    targets =
        format_bench,
        from_str_bench,
        clone_bench,
        eq_bench,
        to_lowercase_bench,
        to_ascii_lowercase_bench,
        replace_bench
}
criterion_main!(benches);
1'>271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455
//! Definition of `SolverDefId`

use hir_def::{
    AdtId, AttrDefId, BuiltinDeriveImplId, CallableDefId, ConstId, DefWithBodyId, EnumId,
    EnumVariantId, FunctionId, GeneralConstId, GenericDefId, ImplId, StaticId, StructId, TraitId,
    TypeAliasId, UnionId,
};
use rustc_type_ir::inherent;
use stdx::impl_from;

use crate::db::{InternedClosureId, InternedCoroutineId, InternedOpaqueTyId};

use super::DbInterner;

/// A struct or enum variant referred to as a constructor.
///
/// `CallableIdWrapper` maps the `StructId` / `EnumVariantId` cases of `CallableDefId` to the
/// `Struct` / `Enum` variants here. Variant order is significant for the derived
/// `PartialOrd`/`Ord`.
#[derive(Debug, PartialOrd, Ord, Clone, Copy, PartialEq, Eq, Hash, salsa::Supertype)]
pub enum Ctor {
    Struct(StructId),
    Enum(EnumVariantId),
}

/// A definition id as used by the trait solver's `DbInterner` (it implements
/// `rustc_type_ir::inherent::DefId` below): a union of the `hir_def` ids and the interned
/// closure/coroutine/opaque-type ids.
///
/// The derived `PartialOrd`/`Ord` compare by variant declaration order first, so reordering
/// the variants changes ordering semantics. `Debug` is implemented by hand to print item names.
#[derive(PartialOrd, Ord, Clone, Copy, PartialEq, Eq, Hash, salsa::Supertype)]
pub enum SolverDefId {
    AdtId(AdtId),
    ConstId(ConstId),
    FunctionId(FunctionId),
    ImplId(ImplId),
    BuiltinDeriveImplId(BuiltinDeriveImplId),
    StaticId(StaticId),
    TraitId(TraitId),
    TypeAliasId(TypeAliasId),
    InternedClosureId(InternedClosureId),
    InternedCoroutineId(InternedCoroutineId),
    InternedOpaqueTyId(InternedOpaqueTyId),
    EnumVariantId(EnumVariantId),
    // FIXME(next-solver): Do we need the separation of `Ctor`? It duplicates some variants.
    Ctor(Ctor),
}

/// Formats the id together with the name of the item it refers to, where one is available,
/// e.g. `AdtId("Foo")`, `FunctionId("bar")` or `EnumVariantId("Enum::Variant")`. Impls,
/// closures, coroutines and opaque types print their raw id instead.
///
/// NOTE(review): names are looked up through `DbInterner::conjure()`, which presumably needs
/// an ambient database to be available — confirm before relying on this outside the solver.
impl std::fmt::Debug for SolverDefId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let interner = DbInterner::conjure();
        let db = interner.db;
        match *self {
            SolverDefId::AdtId(AdtId::StructId(id)) => {
                f.debug_tuple("AdtId").field(&db.struct_signature(id).name.as_str()).finish()
            }
            SolverDefId::AdtId(AdtId::EnumId(id)) => {
                f.debug_tuple("AdtId").field(&db.enum_signature(id).name.as_str()).finish()
            }
            SolverDefId::AdtId(AdtId::UnionId(id)) => {
                f.debug_tuple("AdtId").field(&db.union_signature(id).name.as_str()).finish()
            }
            // A const without a name prints as `_`.
            SolverDefId::ConstId(id) => f
                .debug_tuple("ConstId")
                .field(&db.const_signature(id).name.as_ref().map_or("_", |name| name.as_str()))
                .finish(),
            SolverDefId::FunctionId(id) => {
                f.debug_tuple("FunctionId").field(&db.function_signature(id).name.as_str()).finish()
            }
            SolverDefId::ImplId(id) => f.debug_tuple("ImplId").field(&id).finish(),
            // Use the variant's own name, like every other arm, so builtin-derive impls are
            // distinguishable from source impls in debug output.
            SolverDefId::BuiltinDeriveImplId(id) => {
                f.debug_tuple("BuiltinDeriveImplId").field(&id).finish()
            }
            SolverDefId::StaticId(id) => {
                f.debug_tuple("StaticId").field(&db.static_signature(id).name.as_str()).finish()
            }
            SolverDefId::TraitId(id) => {
                f.debug_tuple("TraitId").field(&db.trait_signature(id).name.as_str()).finish()
            }
            SolverDefId::TypeAliasId(id) => f
                .debug_tuple("TypeAliasId")
                .field(&db.type_alias_signature(id).name.as_str())
                .finish(),
            SolverDefId::InternedClosureId(id) => {
                f.debug_tuple("InternedClosureId").field(&id).finish()
            }
            SolverDefId::InternedCoroutineId(id) => {
                f.debug_tuple("InternedCoroutineId").field(&id).finish()
            }
            SolverDefId::InternedOpaqueTyId(id) => {
                f.debug_tuple("InternedOpaqueTyId").field(&id).finish()
            }
            SolverDefId::Ctor(Ctor::Struct(id)) => {
                f.debug_tuple("Ctor").field(&db.struct_signature(id).name.as_str()).finish()
            }
            // Enum variants render as `"Enum::Variant"` whether they appear as a plain
            // `EnumVariantId` or as a constructor; only the tuple name differs.
            SolverDefId::EnumVariantId(id) | SolverDefId::Ctor(Ctor::Enum(id)) => {
                let tuple_name = match *self {
                    SolverDefId::Ctor(_) => "Ctor",
                    _ => "EnumVariantId",
                };
                let parent_enum = id.loc(db).parent;
                f.debug_tuple(tuple_name)
                    .field(&format_args!(
                        "\"{}::{}\"",
                        db.enum_signature(parent_enum).name.as_str(),
                        parent_enum.enum_variants(db).variant_name_by_id(id).unwrap().as_str()
                    ))
                    .finish()
            }
        }
    }
}

impl_from!(
    AdtId(StructId, EnumId, UnionId),
    ConstId,
    FunctionId,
    ImplId,
    BuiltinDeriveImplId,
    StaticId,
    TraitId,
    TypeAliasId,
    InternedClosureId,
    InternedCoroutineId,
    InternedOpaqueTyId,
    EnumVariantId,
    Ctor
    for SolverDefId
);

impl From<GenericDefId> for SolverDefId {
    fn from(value: GenericDefId) -> Self {
        match value {
            GenericDefId::AdtId(adt_id) => SolverDefId::AdtId(adt_id),
            GenericDefId::ConstId(const_id) => SolverDefId::ConstId(const_id),
            GenericDefId::FunctionId(function_id) => SolverDefId::FunctionId(function_id),
            GenericDefId::ImplId(impl_id) => SolverDefId::ImplId(impl_id),
            GenericDefId::StaticId(static_id) => SolverDefId::StaticId(static_id),
            GenericDefId::TraitId(trait_id) => SolverDefId::TraitId(trait_id),
            GenericDefId::TypeAliasId(type_alias_id) => SolverDefId::TypeAliasId(type_alias_id),
        }
    }
}

/// A general const (`const` or `static`) maps onto the matching `SolverDefId` variant.
impl From<GeneralConstId> for SolverDefId {
    #[inline]
    fn from(value: GeneralConstId) -> Self {
        match value {
            GeneralConstId::ConstId(it) => it.into(),
            GeneralConstId::StaticId(it) => it.into(),
        }
    }
}

impl From<DefWithBodyId> for SolverDefId {
    #[inline]
    fn from(value: DefWithBodyId) -> Self {
        match value {
            DefWithBodyId::FunctionId(id) => id.into(),
            DefWithBodyId::StaticId(id) => id.into(),
            DefWithBodyId::ConstId(id) => id.into(),
            DefWithBodyId::VariantId(id) => id.into(),
        }
    }
}

/// Converts to the id used for attribute lookup. Builtin-derive impls, closures, coroutines
/// and opaque types have no `AttrDefId` and yield `Err(())`.
impl TryFrom<SolverDefId> for AttrDefId {
    type Error = ();
    #[inline]
    fn try_from(value: SolverDefId) -> Result<Self, Self::Error> {
        let attr_owner: AttrDefId = match value {
            SolverDefId::AdtId(adt) => adt.into(),
            SolverDefId::ConstId(const_) => const_.into(),
            SolverDefId::FunctionId(function) => function.into(),
            SolverDefId::ImplId(impl_) => impl_.into(),
            SolverDefId::StaticId(static_) => static_.into(),
            SolverDefId::TraitId(trait_) => trait_.into(),
            SolverDefId::TypeAliasId(alias) => alias.into(),
            // A struct constructor carries the struct's attributes.
            SolverDefId::Ctor(Ctor::Struct(strukt)) => strukt.into(),
            SolverDefId::EnumVariantId(variant) | SolverDefId::Ctor(Ctor::Enum(variant)) => {
                variant.into()
            }
            SolverDefId::BuiltinDeriveImplId(_)
            | SolverDefId::InternedClosureId(_)
            | SolverDefId::InternedCoroutineId(_)
            | SolverDefId::InternedOpaqueTyId(_) => return Err(()),
        };
        Ok(attr_owner)
    }
}

/// Succeeds for functions, consts, statics and enum variants (plain or as constructors);
/// every other kind of definition has no body and yields `Err(())`.
impl TryFrom<SolverDefId> for DefWithBodyId {
    type Error = ();

    #[inline]
    fn try_from(value: SolverDefId) -> Result<Self, Self::Error> {
        Ok(match value {
            SolverDefId::FunctionId(function) => function.into(),
            SolverDefId::ConstId(const_) => const_.into(),
            SolverDefId::StaticId(static_) => static_.into(),
            SolverDefId::EnumVariantId(variant) | SolverDefId::Ctor(Ctor::Enum(variant)) => {
                variant.into()
            }
            SolverDefId::AdtId(_)
            | SolverDefId::Ctor(Ctor::Struct(_))
            | SolverDefId::ImplId(_)
            | SolverDefId::BuiltinDeriveImplId(_)
            | SolverDefId::TraitId(_)
            | SolverDefId::TypeAliasId(_)
            | SolverDefId::InternedClosureId(_)
            | SolverDefId::InternedCoroutineId(_)
            | SolverDefId::InternedOpaqueTyId(_) => return Err(()),
        })
    }
}

impl TryFrom<SolverDefId> for GenericDefId {
    type Error = ();

    fn try_from(value: SolverDefId) -> Result<Self, Self::Error> {
        Ok(match value {
            SolverDefId::AdtId(adt_id) => GenericDefId::AdtId(adt_id),
            SolverDefId::ConstId(const_id) => GenericDefId::ConstId(const_id),
            SolverDefId::FunctionId(function_id) => GenericDefId::FunctionId(function_id),
            SolverDefId::ImplId(impl_id) => GenericDefId::ImplId(impl_id),
            SolverDefId::StaticId(static_id) => GenericDefId::StaticId(static_id),
            SolverDefId::TraitId(trait_id) => GenericDefId::TraitId(trait_id),
            SolverDefId::TypeAliasId(type_alias_id) => GenericDefId::TypeAliasId(type_alias_id),
            SolverDefId::InternedClosureId(_)
            | SolverDefId::InternedCoroutineId(_)
            | SolverDefId::InternedOpaqueTyId(_)
            | SolverDefId::EnumVariantId(_)
            | SolverDefId::BuiltinDeriveImplId(_)
            | SolverDefId::Ctor(_) => return Err(()),
        })
    }
}

impl SolverDefId {
    /// Returns the wrapped opaque-type id.
    ///
    /// # Panics
    /// If `self` is not `SolverDefId::InternedOpaqueTyId`; the panic is attributed to the caller.
    #[inline]
    #[track_caller]
    pub fn expect_opaque_ty(self) -> InternedOpaqueTyId {
        let SolverDefId::InternedOpaqueTyId(opaque) = self else {
            panic!("expected opaque type, found {self:?}");
        };
        opaque
    }

    /// Returns the wrapped type-alias id.
    ///
    /// # Panics
    /// If `self` is not `SolverDefId::TypeAliasId`; the panic is attributed to the caller.
    #[inline]
    #[track_caller]
    pub fn expect_type_alias(self) -> TypeAliasId {
        let SolverDefId::TypeAliasId(alias) = self else {
            panic!("expected type alias, found {self:?}");
        };
        alias
    }
}

/// Every `SolverDefId` is reported as local, and its local form is the id itself.
impl<'db> inherent::DefId<DbInterner<'db>> for SolverDefId {
    fn is_local(self) -> bool {
        true
    }

    fn as_local(self) -> Option<SolverDefId> {
        Option::Some(self)
    }
}

/// Declares `pub struct $wrapper(pub $inner)`, a `Copy` newtype, together with:
/// - lossless `From` conversions in both directions between `$wrapper` and `$inner`;
/// - `From<$wrapper> for SolverDefId` and the fallible reverse `TryFrom<SolverDefId>`;
/// - a `Debug` impl that formats through the corresponding `SolverDefId`;
/// - an `inherent::DefId` impl that always reports the id as local.
///
/// `$inner` must also be the name of a `SolverDefId` variant holding that type.
macro_rules! declare_id_wrapper {
    ($wrapper:ident, $inner:ident) => {
        #[derive(Clone, Copy, PartialEq, Eq, Hash)]
        pub struct $wrapper(pub $inner);

        impl From<$inner> for $wrapper {
            #[inline]
            fn from(value: $inner) -> $wrapper {
                Self(value)
            }
        }

        impl From<$wrapper> for $inner {
            #[inline]
            fn from(value: $wrapper) -> $inner {
                let $wrapper(inner) = value;
                inner
            }
        }

        impl From<$wrapper> for SolverDefId {
            #[inline]
            fn from(value: $wrapper) -> SolverDefId {
                SolverDefId::from(value.0)
            }
        }

        impl TryFrom<SolverDefId> for $wrapper {
            type Error = ();

            #[inline]
            fn try_from(value: SolverDefId) -> Result<Self, Self::Error> {
                if let SolverDefId::$inner(inner) = value { Ok(Self(inner)) } else { Err(()) }
            }
        }

        impl std::fmt::Debug for $wrapper {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                <SolverDefId as std::fmt::Debug>::fmt(&SolverDefId::from(self.0), f)
            }
        }

        impl<'db> inherent::DefId<DbInterner<'db>> for $wrapper {
            fn is_local(self) -> bool {
                true
            }

            fn as_local(self) -> Option<SolverDefId> {
                Some(SolverDefId::from(self))
            }
        }
    };
}

// Newtype wrappers whose inner id type shares its name with a `SolverDefId` variant.
declare_id_wrapper!(AdtIdWrapper, AdtId);
declare_id_wrapper!(TraitIdWrapper, TraitId);
declare_id_wrapper!(TypeAliasIdWrapper, TypeAliasId);
declare_id_wrapper!(ClosureIdWrapper, InternedClosureId);
declare_id_wrapper!(CoroutineIdWrapper, InternedCoroutineId);

/// Newtype over [`GeneralConstId`] (a `const` or a `static`), convertible to and from
/// `SolverDefId`. Written by hand because `GeneralConstId` spans two `SolverDefId` variants.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct GeneralConstIdWrapper(pub GeneralConstId);

impl std::fmt::Debug for GeneralConstIdWrapper {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Prints the raw `GeneralConstId`, not the name-resolving `SolverDefId` form.
        <GeneralConstId as std::fmt::Debug>::fmt(&self.0, f)
    }
}
impl From<GeneralConstId> for GeneralConstIdWrapper {
    #[inline]
    fn from(value: GeneralConstId) -> GeneralConstIdWrapper {
        Self(value)
    }
}
impl From<GeneralConstIdWrapper> for GeneralConstId {
    #[inline]
    fn from(value: GeneralConstIdWrapper) -> GeneralConstId {
        let GeneralConstIdWrapper(inner) = value;
        inner
    }
}
impl From<GeneralConstIdWrapper> for SolverDefId {
    #[inline]
    fn from(value: GeneralConstIdWrapper) -> SolverDefId {
        // Delegates to `From<GeneralConstId> for SolverDefId`.
        SolverDefId::from(value.0)
    }
}
impl TryFrom<SolverDefId> for GeneralConstIdWrapper {
    type Error = ();
    #[inline]
    fn try_from(value: SolverDefId) -> Result<Self, Self::Error> {
        let inner = match value {
            SolverDefId::ConstId(const_) => GeneralConstId::ConstId(const_),
            SolverDefId::StaticId(static_) => GeneralConstId::StaticId(static_),
            _ => return Err(()),
        };
        Ok(Self(inner))
    }
}
impl<'db> inherent::DefId<DbInterner<'db>> for GeneralConstIdWrapper {
    fn is_local(self) -> bool {
        true
    }

    fn as_local(self) -> Option<SolverDefId> {
        Some(SolverDefId::from(self))
    }
}

#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct CallableIdWrapper(pub CallableDefId);

impl std::fmt::Debug for CallableIdWrapper {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(&self.0, f)
    }
}
impl From<CallableIdWrapper> for CallableDefId {
    #[inline]
    fn from(value: CallableIdWrapper) -> CallableDefId {
        value.0
    }
}
impl From<CallableDefId> for CallableIdWrapper {
    #[inline]
    fn from(value: CallableDefId) -> CallableIdWrapper {
        Self(value)
    }
}
impl From<CallableIdWrapper> for SolverDefId {
    #[inline]
    fn from(value: CallableIdWrapper) -> SolverDefId {
        match value.0 {
            CallableDefId::FunctionId(it) => it.into(),
            CallableDefId::StructId(it) => Ctor::Struct(it).into(),
            CallableDefId::EnumVariantId(it) => Ctor::Enum(it).into(),
        }
    }
}
impl TryFrom<SolverDefId> for CallableIdWrapper {
    type Error = ();
    #[inline]
    fn try_from(value: SolverDefId) -> Result<Self, Self::Error> {
        match value {
            SolverDefId::FunctionId(it) => Ok(Self(it.into())),
            SolverDefId::Ctor(Ctor::Struct(it)) => Ok(Self(it.into())),
            SolverDefId::Ctor(Ctor::Enum(it)) => Ok(Self(it.into())),
            _ => Err(()),
        }
    }
}
impl<'db> inherent::DefId<DbInterner<'db>> for CallableIdWrapper {
    fn as_local(self) -> Option<SolverDefId> {
        Some(self.into())
    }
    fn is_local(self) -> bool {
        true
    }
}

/// Either a source `impl` block or a builtin-derive impl; each case maps onto the
/// same-named `SolverDefId` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnyImplId {
    ImplId(ImplId),
    BuiltinDeriveImplId(BuiltinDeriveImplId),
}

impl_from!(ImplId, BuiltinDeriveImplId for AnyImplId);

impl From<AnyImplId> for SolverDefId {
    #[inline]
    fn from(value: AnyImplId) -> SolverDefId {
        match value {
            AnyImplId::ImplId(impl_) => SolverDefId::ImplId(impl_),
            AnyImplId::BuiltinDeriveImplId(derive) => SolverDefId::BuiltinDeriveImplId(derive),
        }
    }
}
impl TryFrom<SolverDefId> for AnyImplId {
    type Error = ();
    #[inline]
    fn try_from(value: SolverDefId) -> Result<Self, Self::Error> {
        let any_impl = match value {
            SolverDefId::ImplId(impl_) => AnyImplId::ImplId(impl_),
            SolverDefId::BuiltinDeriveImplId(derive) => AnyImplId::BuiltinDeriveImplId(derive),
            _ => return Err(()),
        };
        Ok(any_impl)
    }
}
impl<'db> inherent::DefId<DbInterner<'db>> for AnyImplId {
    fn is_local(self) -> bool {
        true
    }

    fn as_local(self) -> Option<SolverDefId> {
        Some(SolverDefId::from(self))
    }
}