Skip to main content

macros/
for_lt.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3use proc_macro2::{
4    Span,
5    TokenStream, //
6};
7use quote::{
8    format_ident,
9    quote, //
10};
11use syn::{
12    parse::{
13        Parse,
14        ParseStream, //
15    },
16    visit::Visit,
17    visit_mut::VisitMut,
18    Lifetime,
19    Result,
20    Token,
21    Type, //
22};
23
24pub(crate) enum HigherRankedType {
25    Explicit {
26        _for_token: Token![for],
27        _lt_token: Token![<],
28        lifetime: Lifetime,
29        _gt_token: Token![>],
30        ty: Type,
31    },
32    Implicit {
33        ty: Type,
34    },
35}
36
37impl Parse for HigherRankedType {
38    fn parse(input: ParseStream<'_>) -> Result<Self> {
39        if input.peek(Token![for]) {
40            Ok(Self::Explicit {
41                _for_token: input.parse()?,
42                _lt_token: input.parse()?,
43                lifetime: input.parse()?,
44                _gt_token: input.parse()?,
45                ty: input.parse()?,
46            })
47        } else {
48            Ok(Self::Implicit { ty: input.parse()? })
49        }
50    }
51}
52
53trait TypeExt {
54    fn expand_elided_lifetime(&self, explicit_lt: &Lifetime) -> Type;
55    fn replace_lifetime(&self, src: &Lifetime, dst: &Lifetime) -> Type;
56    fn has_lifetime(&self, lt: &Lifetime) -> bool;
57}
58
59impl TypeExt for Type {
60    fn expand_elided_lifetime(&self, explicit_lt: &Lifetime) -> Type {
61        struct ElidedLifetimeExpander<'a>(&'a Lifetime);
62
63        impl VisitMut for ElidedLifetimeExpander<'_> {
64            fn visit_lifetime_mut(&mut self, lifetime: &mut Lifetime) {
65                // Expand explicit `'_`
66                if lifetime.ident == "_" {
67                    *lifetime = self.0.clone();
68                }
69            }
70
71            fn visit_type_reference_mut(&mut self, reference: &mut syn::TypeReference) {
72                syn::visit_mut::visit_type_reference_mut(self, reference);
73
74                if reference.lifetime.is_none() {
75                    reference.lifetime = Some(self.0.clone());
76                }
77            }
78        }
79
80        let mut ret = self.clone();
81        ElidedLifetimeExpander(explicit_lt).visit_type_mut(&mut ret);
82        ret
83    }
84
85    fn replace_lifetime(&self, src: &Lifetime, dst: &Lifetime) -> Type {
86        struct LifetimeReplacer<'a>(&'a Lifetime, &'a Lifetime);
87
88        impl VisitMut for LifetimeReplacer<'_> {
89            fn visit_lifetime_mut(&mut self, lifetime: &mut Lifetime) {
90                if lifetime.ident == self.0.ident {
91                    *lifetime = self.1.clone();
92                }
93            }
94        }
95
96        let mut ret = self.clone();
97        LifetimeReplacer(src, dst).visit_type_mut(&mut ret);
98        ret
99    }
100
101    fn has_lifetime(&self, lt: &Lifetime) -> bool {
102        struct HasLifetime<'a>(&'a Lifetime, bool);
103
104        impl Visit<'_> for HasLifetime<'_> {
105            fn visit_lifetime(&mut self, lifetime: &Lifetime) {
106                if lifetime.ident == self.0.ident {
107                    self.1 = true;
108                }
109            }
110
111            // Macro invocations are opaque; conservatively assume they may
112            // reference the lifetime.
113            fn visit_macro(&mut self, _: &syn::Macro) {
114                self.1 = true;
115            }
116        }
117
118        let mut visitor = HasLifetime(lt, false);
119        visitor.visit_type(self);
120        visitor.1
121    }
122}
123
124struct Prover<'a>(&'a Lifetime, Vec<&'a Type>);
125
126impl<'a> Prover<'a> {
127    /// Prove that `ty` is covariant over `'lt`.
128    ///
129    /// This also needs to prove that it'll be wellformed for any instance of `'lt`.
130    /// It can be assumed that `ty` will be wellformed if `'lt` is substituted to `'static`.
131    fn prove(&mut self, ty: &'a Type) {
132        match ty {
133            Type::Paren(ty) => self.prove(&ty.elem),
134            Type::Group(ty) => self.prove(&ty.elem),
135
136            // No lifetime involved
137            Type::Never(_) => {}
138
139            // `[T; N]` and `[T]` is covariant over `T`.
140            Type::Array(ty) => self.prove(&ty.elem),
141            Type::Slice(ty) => self.prove(&ty.elem),
142
143            Type::Tuple(ty) => {
144                for elem in &ty.elems {
145                    self.prove(elem);
146                }
147            }
148
149            // `*const T` is covariant over `T`
150            Type::Ptr(ty) if ty.const_token.is_some() => self.prove(&ty.elem),
151
152            // `&T` is covariant over `T` and lifetime.
153            //
154            // Note that if we encounter `&'other_lt T`, then we still need to make sure the type
155            // is wellformed if `T` involves `&'lt`, so we defer to the compiler.
156            //
157            // This is to block cases like `ForLt!(for<'a> &'static &'a u32)`, as the presence of
158            // the type implies `'a: 'static` but this is unsound.
159            Type::Reference(ty)
160                if ty.mutability.is_none() && ty.lifetime.as_ref() == Some(self.0) =>
161            {
162                self.prove(&ty.elem)
163            }
164
165            // `&[mut] T` is covariant over lifetime.
166            // In case we have `&[mut] NoLifetime`, we don't need to do additional checks.
167            Type::Reference(ty) if !ty.elem.has_lifetime(self.0) => (),
168
169            // No mention of lifetime at all, no need to perform compiler check.
170            ty if !ty.has_lifetime(self.0) => (),
171
172            // Otherwise, we need to emit checks so that compiler can determine if the types are
173            // actually covariant.
174            ty => self.1.push(ty),
175        }
176    }
177}
178
179pub(crate) fn for_lt(input: HigherRankedType) -> TokenStream {
180    let (ty, lifetime) = match input {
181        HigherRankedType::Explicit { lifetime, ty, .. } => (ty, lifetime),
182        HigherRankedType::Implicit { ty } => {
183            // If there's no explicit `for<'a>` binder, inject a synthetic `'__elided` lifetime
184            // and expand elided sites.
185            let lifetime = Lifetime {
186                apostrophe: Span::mixed_site(),
187                ident: format_ident!("__elided", span = Span::mixed_site()),
188            };
189            (ty.expand_elided_lifetime(&lifetime), lifetime)
190        }
191    };
192
193    let mut prover = Prover(&lifetime, Vec::new());
194    prover.prove(&ty);
195
196    let mut proof = Vec::new();
197
198    // Emit proofs for every type that requires additional compiler help in proving covariance.
199    for (idx, required_proof) in prover.1.into_iter().enumerate() {
200        // Insert a proof that the type is well-formed.
201        //
202        // This is intended to workaround a Rust compiler soundness bug related to HRTB.
203        // https://github.com/rust-lang/rust/issues/152489
204        //
205        // This needs to be a struct instead of fn to avoid the implied WF bounds.
206        let wf_proof_name = format_ident!("ProveWf{idx}");
207        proof.push(quote!(
208            struct #wf_proof_name<#lifetime>(
209                ::core::marker::PhantomData<&#lifetime ()>, #required_proof
210            );
211        ));
212
213        // Insert a proof that the type is covariant.
214        let cov_proof_name = format_ident!("prove_covariant_{idx}");
215        proof.push(quote!(
216            fn #cov_proof_name<'__short, '__long: '__short>(
217                long: #wf_proof_name<'__long>
218            ) -> #wf_proof_name<'__short> {
219                long
220            }
221        ));
222    }
223
224    // Make sure that the type is wellformed when substituting lifetime with `'static`.
225    //
226    // Currently the Rust compiler doesn't check this, see the above `ProveWf` documentation.
227    //
228    // We prefer to use this way of proving WF-ness as it can work when generics are involved.
229    let ty_static = ty.replace_lifetime(
230        &lifetime,
231        &Lifetime {
232            apostrophe: Span::mixed_site(),
233            ident: format_ident!("static"),
234        },
235    );
236
237    quote!(
238        ::kernel::types::for_lt::UnsafeForLtImpl::<
239            dyn for<#lifetime> ::kernel::types::for_lt::WithLt<#lifetime, Of = #ty>,
240            #ty_static,
241            {
242                #(#proof)*
243
244                0
245            }
246        >
247    )
248}