pin_init_internal/
zeroable.rs

1// SPDX-License-Identifier: GPL-2.0
2
3#[cfg(not(kernel))]
4use proc_macro2 as proc_macro;
5
6use crate::helpers::{parse_generics, Generics};
7use proc_macro::{TokenStream, TokenTree};
8
9pub(crate) fn parse_zeroable_derive_input(
10    input: TokenStream,
11) -> (
12    Vec<TokenTree>,
13    Vec<TokenTree>,
14    Vec<TokenTree>,
15    Option<TokenTree>,
16) {
17    let (
18        Generics {
19            impl_generics,
20            decl_generics: _,
21            ty_generics,
22        },
23        mut rest,
24    ) = parse_generics(input);
25    // This should be the body of the struct `{...}`.
26    let last = rest.pop();
27    // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
28    let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
29    // Are we inside of a generic where we want to add `Zeroable`?
30    let mut in_generic = !impl_generics.is_empty();
31    // Have we already inserted `Zeroable`?
32    let mut inserted = false;
33    // Level of `<>` nestings.
34    let mut nested = 0;
35    for tt in impl_generics {
36        match &tt {
37            // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
38            TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
39                if in_generic && !inserted {
40                    new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
41                }
42                in_generic = true;
43                inserted = false;
44                new_impl_generics.push(tt);
45            }
46            // If we find `'`, then we are entering a lifetime.
47            TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
48                in_generic = false;
49                new_impl_generics.push(tt);
50            }
51            TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
52                new_impl_generics.push(tt);
53                if in_generic {
54                    new_impl_generics.extend(quote! { ::pin_init::Zeroable + });
55                    inserted = true;
56                }
57            }
58            TokenTree::Punct(p) if p.as_char() == '<' => {
59                nested += 1;
60                new_impl_generics.push(tt);
61            }
62            TokenTree::Punct(p) if p.as_char() == '>' => {
63                assert!(nested > 0);
64                nested -= 1;
65                new_impl_generics.push(tt);
66            }
67            _ => new_impl_generics.push(tt),
68        }
69    }
70    assert_eq!(nested, 0);
71    if in_generic && !inserted {
72        new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
73    }
74    (rest, new_impl_generics, ty_generics, last)
75}
76
77pub(crate) fn derive(input: TokenStream) -> TokenStream {
78    let (rest, new_impl_generics, ty_generics, last) = parse_zeroable_derive_input(input);
79    quote! {
80        ::pin_init::__derive_zeroable!(
81            parse_input:
82                @sig(#(#rest)*),
83                @impl_generics(#(#new_impl_generics)*),
84                @ty_generics(#(#ty_generics)*),
85                @body(#last),
86        );
87    }
88}
89
90pub(crate) fn maybe_derive(input: TokenStream) -> TokenStream {
91    let (rest, new_impl_generics, ty_generics, last) = parse_zeroable_derive_input(input);
92    quote! {
93        ::pin_init::__maybe_derive_zeroable!(
94            parse_input:
95                @sig(#(#rest)*),
96                @impl_generics(#(#new_impl_generics)*),
97                @ty_generics(#(#ty_generics)*),
98                @body(#last),
99        );
100    }
101}