1use proc_macro2::{
4 Span,
5 TokenStream, };
7use quote::{
8 format_ident,
9 quote, };
11use syn::{
12 parse::{
13 Parse,
14 ParseStream, },
16 visit::Visit,
17 visit_mut::VisitMut,
18 Lifetime,
19 Result,
20 Token,
21 Type, };
23
24pub(crate) enum HigherRankedType {
25 Explicit {
26 _for_token: Token![for],
27 _lt_token: Token![<],
28 lifetime: Lifetime,
29 _gt_token: Token![>],
30 ty: Type,
31 },
32 Implicit {
33 ty: Type,
34 },
35}
36
37impl Parse for HigherRankedType {
38 fn parse(input: ParseStream<'_>) -> Result<Self> {
39 if input.peek(Token![for]) {
40 Ok(Self::Explicit {
41 _for_token: input.parse()?,
42 _lt_token: input.parse()?,
43 lifetime: input.parse()?,
44 _gt_token: input.parse()?,
45 ty: input.parse()?,
46 })
47 } else {
48 Ok(Self::Implicit { ty: input.parse()? })
49 }
50 }
51}
52
53trait TypeExt {
54 fn expand_elided_lifetime(&self, explicit_lt: &Lifetime) -> Type;
55 fn replace_lifetime(&self, src: &Lifetime, dst: &Lifetime) -> Type;
56 fn has_lifetime(&self, lt: &Lifetime) -> bool;
57}
58
59impl TypeExt for Type {
60 fn expand_elided_lifetime(&self, explicit_lt: &Lifetime) -> Type {
61 struct ElidedLifetimeExpander<'a>(&'a Lifetime);
62
63 impl VisitMut for ElidedLifetimeExpander<'_> {
64 fn visit_lifetime_mut(&mut self, lifetime: &mut Lifetime) {
65 if lifetime.ident == "_" {
67 *lifetime = self.0.clone();
68 }
69 }
70
71 fn visit_type_reference_mut(&mut self, reference: &mut syn::TypeReference) {
72 syn::visit_mut::visit_type_reference_mut(self, reference);
73
74 if reference.lifetime.is_none() {
75 reference.lifetime = Some(self.0.clone());
76 }
77 }
78 }
79
80 let mut ret = self.clone();
81 ElidedLifetimeExpander(explicit_lt).visit_type_mut(&mut ret);
82 ret
83 }
84
85 fn replace_lifetime(&self, src: &Lifetime, dst: &Lifetime) -> Type {
86 struct LifetimeReplacer<'a>(&'a Lifetime, &'a Lifetime);
87
88 impl VisitMut for LifetimeReplacer<'_> {
89 fn visit_lifetime_mut(&mut self, lifetime: &mut Lifetime) {
90 if lifetime.ident == self.0.ident {
91 *lifetime = self.1.clone();
92 }
93 }
94 }
95
96 let mut ret = self.clone();
97 LifetimeReplacer(src, dst).visit_type_mut(&mut ret);
98 ret
99 }
100
101 fn has_lifetime(&self, lt: &Lifetime) -> bool {
102 struct HasLifetime<'a>(&'a Lifetime, bool);
103
104 impl Visit<'_> for HasLifetime<'_> {
105 fn visit_lifetime(&mut self, lifetime: &Lifetime) {
106 if lifetime.ident == self.0.ident {
107 self.1 = true;
108 }
109 }
110
111 fn visit_macro(&mut self, _: &syn::Macro) {
114 self.1 = true;
115 }
116 }
117
118 let mut visitor = HasLifetime(lt, false);
119 visitor.visit_type(self);
120 visitor.1
121 }
122}
123
124struct Prover<'a>(&'a Lifetime, Vec<&'a Type>);
125
126impl<'a> Prover<'a> {
127 fn prove(&mut self, ty: &'a Type) {
132 match ty {
133 Type::Paren(ty) => self.prove(&ty.elem),
134 Type::Group(ty) => self.prove(&ty.elem),
135
136 Type::Never(_) => {}
138
139 Type::Array(ty) => self.prove(&ty.elem),
141 Type::Slice(ty) => self.prove(&ty.elem),
142
143 Type::Tuple(ty) => {
144 for elem in &ty.elems {
145 self.prove(elem);
146 }
147 }
148
149 Type::Ptr(ty) if ty.const_token.is_some() => self.prove(&ty.elem),
151
152 Type::Reference(ty)
160 if ty.mutability.is_none() && ty.lifetime.as_ref() == Some(self.0) =>
161 {
162 self.prove(&ty.elem)
163 }
164
165 Type::Reference(ty) if !ty.elem.has_lifetime(self.0) => (),
168
169 ty if !ty.has_lifetime(self.0) => (),
171
172 ty => self.1.push(ty),
175 }
176 }
177}
178
179pub(crate) fn for_lt(input: HigherRankedType) -> TokenStream {
180 let (ty, lifetime) = match input {
181 HigherRankedType::Explicit { lifetime, ty, .. } => (ty, lifetime),
182 HigherRankedType::Implicit { ty } => {
183 let lifetime = Lifetime {
186 apostrophe: Span::mixed_site(),
187 ident: format_ident!("__elided", span = Span::mixed_site()),
188 };
189 (ty.expand_elided_lifetime(&lifetime), lifetime)
190 }
191 };
192
193 let mut prover = Prover(&lifetime, Vec::new());
194 prover.prove(&ty);
195
196 let mut proof = Vec::new();
197
198 for (idx, required_proof) in prover.1.into_iter().enumerate() {
200 let wf_proof_name = format_ident!("ProveWf{idx}");
207 proof.push(quote!(
208 struct #wf_proof_name<#lifetime>(
209 ::core::marker::PhantomData<&#lifetime ()>, #required_proof
210 );
211 ));
212
213 let cov_proof_name = format_ident!("prove_covariant_{idx}");
215 proof.push(quote!(
216 fn #cov_proof_name<'__short, '__long: '__short>(
217 long: #wf_proof_name<'__long>
218 ) -> #wf_proof_name<'__short> {
219 long
220 }
221 ));
222 }
223
224 let ty_static = ty.replace_lifetime(
230 &lifetime,
231 &Lifetime {
232 apostrophe: Span::mixed_site(),
233 ident: format_ident!("static"),
234 },
235 );
236
237 quote!(
238 ::kernel::types::for_lt::UnsafeForLtImpl::<
239 dyn for<#lifetime> ::kernel::types::for_lt::WithLt<#lifetime, Of = #ty>,
240 #ty_static,
241 {
242 #(#proof)*
243
244 0
245 }
246 >
247 )
248}