1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
use core::any::{type_name, Any};

pub use any_hint::AnyHint;
pub use any_visit::AnyVisit;
pub use id::ProtocolId;

use crate::{ControlFlow, Visitor, WalkerHints};

/// A protocol between a walker and visitor.
///
/// On the walker side this takes the form of hints a visitor can give.
/// On the visitor side this takes the form of visits a walker can inject values into.
///
/// When a visitor hints that a walker should use a particular protocol, it's expected
/// that the walker visits using that protocol.
///
/// A protocol never needs to be a value, so it's recommended to use an uninhabited type
/// like an empty enum to represent one. The `Any` supertrait makes every protocol type
/// `'static`, which is what allows it to be identified by its `TypeId` (see `ProtocolId`).
pub trait Protocol: Any {
    /// Arbitrary hint metadata for the protocol.
    ///
    /// This allows a visitor to give extra information to a walker when hinting to
    /// use the protocol.
    type Hint<'ctx>;

    /// Data known about the protocol before hinting.
    ///
    /// This allows a walker to give extra information to a visitor to make a
    /// better decision when selecting a hint.
    type Known<'ctx>;

    /// The visit data the walker provides to the visitor.
    ///
    /// This may be actual data or another walker for a part of the bigger value.
    /// The `'walking` lifetime is only alive while the walker is walking.
    /// As such, a visitor cannot borrow from a type containing the `'walking`
    /// lifetime for its output.
    type Accessor<'walking, 'ctx: 'walking>;
}

/// Protocol specific hint for a walker.
///
/// A walker can implement any number of these for any protocols it supports.
pub trait Hint<'ctx, P: Protocol> {
    /// Hint that protocol `P` should be used.
    ///
    /// After hinting a protocol, a walker should invoke only the same protocol on the visitor.
    /// This is not forced though.
    fn hint(&mut self, visitor: &mut dyn Visitor<'ctx>, hint: P::Hint<'ctx>) -> ControlFlow;

    /// Any information the walker has for the protocol.
    ///
    /// This information should be easy to get inside the walker, and should
    /// only be used when making a decision of what protocol to hint as a visitor.
    /// This can be helpful for doing things like preallocating space in the visitor.
    ///
    /// Most protocols will allow returning a value representing no knowledge is known by the
    /// walker.
    fn known(&mut self, hint: &P::Hint<'ctx>) -> P::Known<'ctx>;
}

/// Protocol specific visit for a visitor.
///
/// A visitor can implement any number of these for any protocols it supports.
pub trait Visit<'ctx, P: Protocol> {
    /// Visit a value from the walker.
    ///
    /// The `accessor` will either be a concrete value or another walker.
    /// This is dependant on the protocol `P`. The visitor can do whatever it wants
    /// with the `accessor` during the function call.
    fn visit(&mut self, accessor: P::Accessor<'_, 'ctx>) -> ControlFlow;
}

#[derive(thiserror::Error, Debug)]
#[error("Expected Hint to be for `{expected}` but got one for `{got}`.")]
pub struct WalkerWrongProtocol {
    pub got: &'static str,
    pub expected: &'static str,
}

#[derive(thiserror::Error, Debug)]
pub enum WalkerMissingProtocol {
    #[error("Walker doesn't support the protocol `{0}`.")]
    Missing(&'static str),

    #[error(transparent)]
    Wrong(#[from] WalkerWrongProtocol),
}

#[derive(thiserror::Error, Debug)]
#[error("Expected Visit to be for `{expected}` but got one for `{got}`.")]
pub struct VisitorWrongProtocol {
    pub got: &'static str,
    pub expected: &'static str,
}

#[derive(thiserror::Error, Debug)]
pub enum VisitorMissingProtocol {
    #[error("Visitor doesn't support the protocol `{0}`.")]
    Missing(&'static str),

    #[error(transparent)]
    Wrong(#[from] VisitorWrongProtocol),
}

/// Look up the [`Hint`] implementation for protocol `P` on a walker.
///
/// Returns `Ok(None)` when the walker has no hint for `P`'s [`ProtocolId`].
///
/// # Errors
///
/// Returns [`WalkerWrongProtocol`] when the walker answers the query for `P`
/// with a hint that was created for some other protocol.
pub fn try_lookup_hint<'ctx, P: Protocol, H: ?Sized + WalkerHints<'ctx>>(
    hints: &mut H,
) -> Result<Option<&mut dyn Hint<'ctx, P>>, WalkerWrongProtocol> {
    let Some(erased) = hints.protocol(ProtocolId::of::<P>()) else {
        return Ok(None);
    };

    match erased.downcast::<P>() {
        Ok(found) => Ok(Some(found)),
        Err(mismatched) => Err(WalkerWrongProtocol {
            got: mismatched.protocol_type_name(),
            expected: type_name::<P>(),
        }),
    }
}

/// Look up the [`Hint`] implementation for protocol `P` on a walker,
/// treating a missing hint as an error.
///
/// # Errors
///
/// - [`WalkerMissingProtocol::Missing`] when the walker has no hint for `P`.
/// - [`WalkerMissingProtocol::Wrong`] when the walker returned a hint for another protocol.
pub fn lookup_hint<'ctx, P: Protocol, H: ?Sized + WalkerHints<'ctx>>(
    hints: &mut H,
) -> Result<&mut dyn Hint<'ctx, P>, WalkerMissingProtocol> {
    match try_lookup_hint::<P, H>(hints)? {
        Some(hint) => Ok(hint),
        None => Err(WalkerMissingProtocol::Missing(type_name::<P>())),
    }
}

/// Look up the [`Visit`] implementation for protocol `P` on a visitor.
///
/// Returns `Ok(None)` when the visitor has no visit for `P`'s [`ProtocolId`].
///
/// # Errors
///
/// Returns [`VisitorWrongProtocol`] when the visitor answers the query for `P`
/// with a visit that was created for some other protocol.
pub fn try_lookup_visit<'ctx, P: Protocol, V: ?Sized + Visitor<'ctx>>(
    visitor: &mut V,
) -> Result<Option<&mut dyn Visit<'ctx, P>>, VisitorWrongProtocol> {
    let Some(erased) = visitor.protocol(ProtocolId::of::<P>()) else {
        return Ok(None);
    };

    match erased.downcast::<P>() {
        Ok(found) => Ok(Some(found)),
        Err(mismatched) => Err(VisitorWrongProtocol {
            got: mismatched.protocol_type_name(),
            expected: type_name::<P>(),
        }),
    }
}

/// Look up the [`Visit`] implementation for protocol `P` on a visitor,
/// treating a missing visit as an error.
///
/// # Errors
///
/// - [`VisitorMissingProtocol::Missing`] when the visitor has no visit for `P`.
/// - [`VisitorMissingProtocol::Wrong`] when the visitor returned a visit for another protocol.
pub fn lookup_visit<'ctx, P: Protocol, V: ?Sized + Visitor<'ctx>>(
    visitor: &mut V,
) -> Result<&mut dyn Visit<'ctx, P>, VisitorMissingProtocol> {
    match try_lookup_visit::<P, V>(visitor)? {
        Some(visit) => Ok(visit),
        None => Err(VisitorMissingProtocol::Missing(type_name::<P>())),
    }
}

mod id {
    use core::any::TypeId;

    use super::Protocol;

    /// Opaque identifier for a [`Protocol`] type.
    ///
    /// Walkers and visitors are queried with this value to ask whether they
    /// support a given protocol. It wraps the protocol's [`TypeId`], so two IDs
    /// compare equal exactly when they were built from the same protocol type.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct ProtocolId(TypeId);

    impl ProtocolId {
        /// Build the identifier for protocol `P`.
        pub fn of<P: Protocol>() -> Self {
            let type_id = TypeId::of::<P>();
            ProtocolId(type_id)
        }
    }
}

mod any_hint {
    use core::{
        any::{type_name, Any},
        marker::PhantomData,
        mem::MaybeUninit,
    };

    use crate::Hint;

    use super::{Protocol, ProtocolId};

    /// Form of `Hint` without `P`.
    ///
    /// Never implemented; it only exists so the marker and the storage size below
    /// can name a trait object type that does not mention `P`.
    trait ErasedHint<'ctx>: Any {}

    /// Size of a pointer to a trait object on this target.
    ///
    /// `new` and `downcast` transmute `&mut dyn Hint<'ctx, P>` to and from a byte
    /// buffer of this size; `transmute` refuses to compile if the sizes differ.
    const DYN_PTR_SIZE: usize = core::mem::size_of::<&mut dyn ErasedHint<'static>>();

    /// Type erased form of `&'walking mut dyn Hint<'ctx, P>` where `P` is erased.
    pub struct AnyHint<'walking, 'ctx> {
        /// ID of `P`, checked by `downcast` before the pointer is restored.
        id: ProtocolId,

        /// `type_name` of `P`, kept for error messages.
        name: &'static str,

        /// This field stores the bytes of a `&'walking mut dyn Hint<'ctx, P>`.
        fat_ptr: MaybeUninit<[u8; DYN_PTR_SIZE]>,

        /// Mimic the lifetimes of what is actually stored, using a trait without `P`.
        _marker: PhantomData<&'walking mut dyn ErasedHint<'ctx>>,
    }

    impl<'walking, 'ctx> AnyHint<'walking, 'ctx> {
        /// Erase the `P` in a hint.
        ///
        /// This allows returning a hint from an object safe method, where `P`
        /// cannot appear as a generic parameter.
        pub fn new<P: Protocol>(hint: &'walking mut dyn Hint<'ctx, P>) -> Self {
            Self {
                id: ProtocolId::of::<P>(),
                name: type_name::<P>(),
                // SAFETY: A maybe uninit array of bytes can hold any pointer.
                // Additionally, transmute makes sure the size is correct.
                fat_ptr: unsafe { core::mem::transmute(hint) },
                _marker: PhantomData,
            }
        }

        /// Try to downcast the hint for the given protocol.
        ///
        /// # Errors
        ///
        /// If the hint was not created for protocol `P`, `self` is handed back
        /// unchanged in `Err`, so the caller can still inspect it (for example
        /// with [`AnyHint::protocol_type_name`]).
        pub fn downcast<P: Protocol>(self) -> Result<&'walking mut dyn Hint<'ctx, P>, Self> {
            if self.id == ProtocolId::of::<P>() {
                // SAFETY: Only `new` can make a value of this type, and it stores the ID of `P`.
                // If the IDs are equal then we can act like any and downcast back to the real
                // type. The `'walking` and `'ctx` lifetimes are carried unchanged by the
                // type's parameters, so the restored reference has its original lifetimes.
                //
                // An important note is this method takes ownership. Which allows it to return
                // the borrow with the `'walking` lifetime instead of a sub-borrow.
                Ok(unsafe { core::mem::transmute(self.fat_ptr) })
            } else {
                Err(self)
            }
        }

        /// Name of the protocol type this hint was created for, as given by
        /// `core::any::type_name`.
        pub fn protocol_type_name(&self) -> &'static str {
            self.name
        }
    }
}

mod any_visit {
    use core::{
        any::{type_name, Any},
        marker::PhantomData,
        mem::MaybeUninit,
    };

    use crate::Visit;

    use super::{Protocol, ProtocolId};

    /// Form of `Visit` without `P`.
    ///
    /// Never implemented; it only exists so the marker and the storage size below
    /// can name a trait object type that does not mention `P`.
    trait ErasedVisit<'ctx>: Any {}

    /// Size of a pointer to a trait object on this target.
    ///
    /// `new` and `downcast` transmute `&mut dyn Visit<'ctx, P>` to and from a byte
    /// buffer of this size; `transmute` refuses to compile if the sizes differ.
    const DYN_PTR_SIZE: usize = core::mem::size_of::<&mut dyn ErasedVisit<'static>>();

    /// Type erased form of `&'walking mut dyn Visit<'ctx, P>` where `P` is erased.
    pub struct AnyVisit<'walking, 'ctx> {
        /// ID of `P`, checked by `downcast` before the pointer is restored.
        id: ProtocolId,

        /// `type_name` of `P`, kept for error messages.
        name: &'static str,

        /// This field stores the bytes of a `&'walking mut dyn Visit<'ctx, P>`.
        fat_ptr: MaybeUninit<[u8; DYN_PTR_SIZE]>,

        /// Mimic the lifetimes of what is actually stored, using a trait without `P`.
        _marker: PhantomData<&'walking mut dyn ErasedVisit<'ctx>>,
    }

    impl<'walking, 'ctx> AnyVisit<'walking, 'ctx> {
        /// Erase the `P` in a visit.
        ///
        /// This allows returning a visit from an object safe method, where `P`
        /// cannot appear as a generic parameter.
        pub fn new<P: Protocol>(visit: &'walking mut dyn Visit<'ctx, P>) -> Self {
            Self {
                id: ProtocolId::of::<P>(),
                name: type_name::<P>(),
                // SAFETY: A maybe uninit array of bytes can hold any pointer.
                // Additionally, transmute makes sure the size is correct.
                fat_ptr: unsafe { core::mem::transmute(visit) },
                _marker: PhantomData,
            }
        }

        /// Try to downcast the visit for the given protocol.
        ///
        /// # Errors
        ///
        /// If the visit was not created for protocol `P`, `self` is handed back
        /// unchanged in `Err`, so the caller can still inspect it (for example
        /// with [`AnyVisit::protocol_type_name`]).
        pub fn downcast<P: Protocol>(self) -> Result<&'walking mut dyn Visit<'ctx, P>, Self> {
            if self.id == ProtocolId::of::<P>() {
                // SAFETY: Only `new` can make a value of this type, and it stores the ID of `P`.
                // If the IDs are equal then we can act like any and downcast back to the real
                // type. The `'walking` and `'ctx` lifetimes are carried unchanged by the
                // type's parameters, so the restored reference has its original lifetimes.
                //
                // An important note is this method takes ownership. Which allows it to return
                // the borrow with the `'walking` lifetime instead of a sub-borrow.
                Ok(unsafe { core::mem::transmute(self.fat_ptr) })
            } else {
                Err(self)
            }
        }

        /// Name of the protocol type this visit was created for, as given by
        /// `core::any::type_name`.
        pub fn protocol_type_name(&self) -> &'static str {
            self.name
        }
    }
}

/// Safe, fully generic counterpart of the type-erased wrappers in this module.
///
/// It keeps `P` as a real type parameter, so the compiler itself checks that
/// returning the `'walking` borrow from `downcast` is sound.
#[cfg(test)]
// Only a compile-time illustration; nothing calls these items.
#[allow(unused)]
mod generic_example {
    use super::{Protocol, ProtocolId};
    use crate::Hint;

    /// Non-erased equivalent of the erased hint wrapper: stores the trait object
    /// reference directly instead of as raw bytes.
    pub struct Generic<'walking, 'ctx, P> {
        id: ProtocolId,
        hint: &'walking mut dyn Hint<'ctx, P>,
    }

    impl<'walking, 'ctx, P> Generic<'walking, 'ctx, P>
    where
        P: Protocol,
    {
        /// Wrap `visit` together with the ID of `P`.
        pub fn new(visit: &'walking mut dyn Hint<'ctx, P>) -> Self {
            let id = ProtocolId::of::<P>();
            Generic { id, hint: visit }
        }

        /// Hand back the wrapped reference for its full `'walking` lifetime.
        ///
        /// Because `self` is consumed, returning the original `'walking` borrow
        /// (rather than a shorter reborrow) type-checks without any unsafe code.
        pub fn downcast(self) -> Result<&'walking mut dyn Hint<'ctx, P>, Self> {
            if self.id != ProtocolId::of::<P>() {
                return Err(self);
            }
            // Notice how this is valid.
            Ok(self.hint)
        }
    }
}