// pin_init_internal/init.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3use proc_macro2::{Span, TokenStream};
4use quote::{format_ident, quote};
5use syn::{
6    braced,
7    parse::{End, Parse},
8    parse_quote,
9    punctuated::Punctuated,
10    spanned::Spanned,
11    token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
12};
13
14use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
15
16pub(crate) struct Initializer {
17    attrs: Vec<InitializerAttribute>,
18    this: Option<This>,
19    path: Path,
20    brace_token: token::Brace,
21    fields: Punctuated<InitializerField, Token![,]>,
22    rest: Option<(Token![..], Expr)>,
23    error: Option<(Token![?], Type)>,
24}
25
26struct This {
27    _and_token: Token![&],
28    ident: Ident,
29    _in_token: Token![in],
30}
31
32struct InitializerField {
33    attrs: Vec<Attribute>,
34    kind: InitializerKind,
35}
36
37enum InitializerKind {
38    Value {
39        ident: Ident,
40        value: Option<(Token![:], Expr)>,
41    },
42    Init {
43        ident: Ident,
44        _left_arrow_token: Token![<-],
45        value: Expr,
46    },
47    Code {
48        _underscore_token: Token![_],
49        _colon_token: Token![:],
50        block: Block,
51    },
52}
53
54impl InitializerKind {
55    fn ident(&self) -> Option<&Ident> {
56        match self {
57            Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
58            Self::Code { .. } => None,
59        }
60    }
61}
62
63enum InitializerAttribute {
64    DefaultError(DefaultErrorAttribute),
65}
66
67struct DefaultErrorAttribute {
68    ty: Box<Type>,
69}
70
71pub(crate) fn expand(
72    Initializer {
73        attrs,
74        this,
75        path,
76        brace_token,
77        fields,
78        rest,
79        error,
80    }: Initializer,
81    default_error: Option<&'static str>,
82    pinned: bool,
83    dcx: &mut DiagCtxt,
84) -> Result<TokenStream, ErrorGuaranteed> {
85    let error = error.map_or_else(
86        || {
87            if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
88                #[expect(irrefutable_let_patterns)]
89                if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
90                    Some(ty.clone())
91                } else {
92                    acc
93                }
94            }) {
95                default_error
96            } else if let Some(default_error) = default_error {
97                syn::parse_str(default_error).unwrap()
98            } else {
99                dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
100                parse_quote!(::core::convert::Infallible)
101            }
102        },
103        |(_, err)| Box::new(err),
104    );
105    let slot = format_ident!("slot");
106    let (has_data_trait, get_data, init_from_closure) = if pinned {
107        (
108            format_ident!("HasPinData"),
109            format_ident!("__pin_data"),
110            format_ident!("pin_init_from_closure"),
111        )
112    } else {
113        (
114            format_ident!("HasInitData"),
115            format_ident!("__init_data"),
116            format_ident!("init_from_closure"),
117        )
118    };
119    let init_kind = get_init_kind(rest, dcx);
120    let zeroable_check = match init_kind {
121        InitKind::Normal => quote!(),
122        InitKind::Zeroing => quote! {
123            // The user specified `..Zeroable::zeroed()` at the end of the list of fields.
124            // Therefore we check if the struct implements `Zeroable` and then zero the memory.
125            // This allows us to also remove the check that all fields are present (since we
126            // already set the memory to zero and that is a valid bit pattern).
127            fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
128            where T: ::pin_init::Zeroable
129            {}
130            // Ensure that the struct is indeed `Zeroable`.
131            assert_zeroable(#slot);
132            // SAFETY: The type implements `Zeroable` by the check above.
133            unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
134        },
135    };
136    let this = match this {
137        None => quote!(),
138        Some(This { ident, .. }) => quote! {
139            // Create the `this` so it can be referenced by the user inside of the
140            // expressions creating the individual fields.
141            let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
142        },
143    };
144    // `mixed_site` ensures that the data is not accessible to the user-controlled code.
145    let data = Ident::new("__data", Span::mixed_site());
146    let init_fields = init_fields(&fields, pinned, &data, &slot);
147    let field_check = make_field_check(&fields, init_kind, &path);
148    Ok(quote! {{
149        // Get the data about fields from the supplied type.
150        // SAFETY: TODO
151        let #data = unsafe {
152            use ::pin_init::__internal::#has_data_trait;
153            // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
154            // generics (which need to be present with that syntax).
155            #path::#get_data()
156        };
157        // Ensure that `#data` really is of type `#data` and help with type inference:
158        let init = #data.__make_closure::<_, #error>(
159            move |slot| {
160                #zeroable_check
161                #this
162                #init_fields
163                #field_check
164                // SAFETY: we are the `init!` macro that is allowed to call this.
165                Ok(unsafe { ::pin_init::__internal::InitOk::new() })
166            }
167        );
168        let init = move |slot| -> ::core::result::Result<(), #error> {
169            init(slot).map(|__InitOk| ())
170        };
171        // SAFETY: TODO
172        unsafe { ::pin_init::#init_from_closure::<_, #error>(init) }
173    }})
174}
175
/// How fields that are not mentioned in the initializer are treated.
enum InitKind {
    /// No `..` trailer: nothing is zeroed.
    Normal,
    /// `..Zeroable::init_zeroed()` trailer: the memory is zeroed before fields are initialized.
    Zeroing,
}
180
181fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
182    let Some((dotdot, expr)) = rest else {
183        return InitKind::Normal;
184    };
185    match &expr {
186        Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
187            Expr::Path(ExprPath {
188                attrs,
189                qself: None,
190                path:
191                    Path {
192                        leading_colon: None,
193                        segments,
194                    },
195            }) if attrs.is_empty()
196                && segments.len() == 2
197                && segments[0].ident == "Zeroable"
198                && segments[0].arguments.is_none()
199                && segments[1].ident == "init_zeroed"
200                && segments[1].arguments.is_none() =>
201            {
202                return InitKind::Zeroing;
203            }
204            _ => {}
205        },
206        _ => {}
207    }
208    dcx.error(
209        dotdot.span().join(expr.span()).unwrap_or(expr.span()),
210        "expected nothing or `..Zeroable::init_zeroed()`.",
211    );
212    InitKind::Normal
213}
214
215/// Generate the code that initializes the fields of the struct using the initializers in `field`.
216fn init_fields(
217    fields: &Punctuated<InitializerField, Token![,]>,
218    pinned: bool,
219    data: &Ident,
220    slot: &Ident,
221) -> TokenStream {
222    let mut guards = vec![];
223    let mut guard_attrs = vec![];
224    let mut res = TokenStream::new();
225    for InitializerField { attrs, kind } in fields {
226        let cfgs = {
227            let mut cfgs = attrs.clone();
228            cfgs.retain(|attr| attr.path().is_ident("cfg"));
229            cfgs
230        };
231
232        let ident = match kind {
233            InitializerKind::Value { ident, .. } => ident,
234            InitializerKind::Init { ident, .. } => ident,
235            InitializerKind::Code { block, .. } => {
236                res.extend(quote! {
237                    #(#attrs)*
238                    #[allow(unused_braces)]
239                    #block
240                });
241                continue;
242            }
243        };
244
245        let slot = if pinned {
246            quote! {
247                // SAFETY:
248                // - `slot` is valid and properly aligned.
249                // - `make_field_check` checks that `&raw mut (*slot).#ident` is properly aligned.
250                // - `make_field_check` prevents `#ident` from being used twice, therefore
251                //   `(*slot).#ident` is exclusively accessed and has not been initialized.
252                (unsafe { #data.#ident(#slot) })
253            }
254        } else {
255            quote! {
256                // For `init!()` macro, everything is unpinned.
257                // SAFETY:
258                // - `&raw mut (*slot).#ident` is valid.
259                // - `make_field_check` checks that `&raw mut (*slot).#ident` is properly aligned.
260                // - `make_field_check` prevents `#ident` from being used twice, therefore
261                //   `(*slot).#ident` is exclusively accessed and has not been initialized.
262                (unsafe {
263                    ::pin_init::__internal::Slot::<::pin_init::__internal::Unpinned, _>::new(
264                        &raw mut (*#slot).#ident
265                    )
266                })
267            }
268        };
269
270        // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
271        let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
272
273        let init = match kind {
274            InitializerKind::Value { ident, value } => {
275                let value = value
276                    .as_ref()
277                    .map(|(_, value)| quote!(#value))
278                    .unwrap_or_else(|| quote!(#ident));
279
280                quote! {
281                    #(#attrs)*
282                    let mut #guard = #slot.write(#value);
283
284                }
285            }
286            InitializerKind::Init { value, .. } => {
287                quote! {
288                    #(#attrs)*
289                    let mut #guard = #slot.init(#value)?;
290                }
291            }
292            InitializerKind::Code { .. } => unreachable!(),
293        };
294
295        res.extend(quote! {
296            #init
297
298            #(#cfgs)*
299            // Allow `non_snake_case` since the same warning is going to be reported for the struct
300            // field.
301            #[allow(unused_variables, non_snake_case)]
302            let #ident = #guard.let_binding();
303        });
304
305        guards.push(guard);
306        guard_attrs.push(cfgs);
307    }
308    quote! {
309        #res
310        // If execution reaches this point, all fields have been initialized. Therefore we can now
311        // dismiss the guards by forgetting them.
312        #(
313            #(#guard_attrs)*
314            ::core::mem::forget(#guards);
315        )*
316    }
317}
318
319/// Generate the check for ensuring that every field has been initialized and aligned.
320fn make_field_check(
321    fields: &Punctuated<InitializerField, Token![,]>,
322    init_kind: InitKind,
323    path: &Path,
324) -> TokenStream {
325    let field_attrs: Vec<_> = fields
326        .iter()
327        .filter_map(|f| f.kind.ident().map(|_| &f.attrs))
328        .collect();
329    let field_name: Vec<_> = fields.iter().filter_map(|f| f.kind.ident()).collect();
330    let zeroing_trailer = match init_kind {
331        InitKind::Normal => None,
332        InitKind::Zeroing => Some(quote! {
333            ..::core::mem::zeroed()
334        }),
335    };
336    quote! {
337        #[allow(unreachable_code, clippy::diverging_sub_expression)]
338        // We use unreachable code to perform field checks. They're still checked by the compiler.
339        // SAFETY: this code is never executed.
340        let _ = || unsafe {
341            // Create references to ensure that the initialized field is properly aligned.
342            // Unaligned fields will cause the compiler to emit E0793. We do not support
343            // unaligned fields since `Init::__init` requires an aligned pointer; the call to
344            // `ptr::write` for value-initialization case has the same requirement.
345            #(
346                #(#field_attrs)*
347                let _ = &(*slot).#field_name;
348            )*
349
350            // If the zeroing trailer is not present, this checks that all fields have been
351            // mentioned exactly once. If the zeroing trailer is present, all missing fields will be
352            // zeroed, so this checks that all fields have been mentioned at most once. The use of
353            // struct initializer will still generate very natural error messages for any misuse.
354            ::core::ptr::write(slot, #path {
355                #(
356                    #(#field_attrs)*
357                    #field_name: loop {},
358                )*
359                #zeroing_trailer
360            })
361        };
362    }
363}
364
365impl Parse for Initializer {
366    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
367        let attrs = input.call(Attribute::parse_outer)?;
368        let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
369        let path = input.parse()?;
370        let content;
371        let brace_token = braced!(content in input);
372        let mut fields = Punctuated::new();
373        loop {
374            let lh = content.lookahead1();
375            if lh.peek(End) || lh.peek(Token![..]) {
376                break;
377            } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
378                fields.push_value(content.parse()?);
379                let lh = content.lookahead1();
380                if lh.peek(End) {
381                    break;
382                } else if lh.peek(Token![,]) {
383                    fields.push_punct(content.parse()?);
384                } else {
385                    return Err(lh.error());
386                }
387            } else {
388                return Err(lh.error());
389            }
390        }
391        let rest = content
392            .peek(Token![..])
393            .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
394            .transpose()?;
395        let error = input
396            .peek(Token![?])
397            .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
398            .transpose()?;
399        let attrs = attrs
400            .into_iter()
401            .map(|a| {
402                if a.path().is_ident("default_error") {
403                    a.parse_args::<DefaultErrorAttribute>()
404                        .map(InitializerAttribute::DefaultError)
405                } else {
406                    Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
407                }
408            })
409            .collect::<Result<Vec<_>, _>>()?;
410        Ok(Self {
411            attrs,
412            this,
413            path,
414            brace_token,
415            fields,
416            rest,
417            error,
418        })
419    }
420}
421
422impl Parse for DefaultErrorAttribute {
423    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
424        Ok(Self { ty: input.parse()? })
425    }
426}
427
428impl Parse for This {
429    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
430        Ok(Self {
431            _and_token: input.parse()?,
432            ident: input.parse()?,
433            _in_token: input.parse()?,
434        })
435    }
436}
437
438impl Parse for InitializerField {
439    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
440        let attrs = input.call(Attribute::parse_outer)?;
441        Ok(Self {
442            attrs,
443            kind: input.parse()?,
444        })
445    }
446}
447
448impl Parse for InitializerKind {
449    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
450        let lh = input.lookahead1();
451        if lh.peek(Token![_]) {
452            Ok(Self::Code {
453                _underscore_token: input.parse()?,
454                _colon_token: input.parse()?,
455                block: input.parse()?,
456            })
457        } else if lh.peek(Ident) {
458            let ident = input.parse()?;
459            let lh = input.lookahead1();
460            if lh.peek(Token![<-]) {
461                Ok(Self::Init {
462                    ident,
463                    _left_arrow_token: input.parse()?,
464                    value: input.parse()?,
465                })
466            } else if lh.peek(Token![:]) {
467                Ok(Self::Value {
468                    ident,
469                    value: Some((input.parse()?, input.parse()?)),
470                })
471            } else if lh.peek(Token![,]) || lh.peek(End) {
472                Ok(Self::Value { ident, value: None })
473            } else {
474                Err(lh.error())
475            }
476        } else {
477            Err(lh.error())
478        }
479    }
480}