Skip to main content

pin_init_internal/
init.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3use proc_macro2::{Span, TokenStream};
4use quote::{format_ident, quote, quote_spanned};
5use syn::{
6    braced,
7    parse::{End, Parse},
8    parse_quote,
9    punctuated::Punctuated,
10    spanned::Spanned,
11    token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
12};
13
14use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
15
16pub(crate) struct Initializer {
17    attrs: Vec<InitializerAttribute>,
18    this: Option<This>,
19    path: Path,
20    brace_token: token::Brace,
21    fields: Punctuated<InitializerField, Token![,]>,
22    rest: Option<(Token![..], Expr)>,
23    error: Option<(Token![?], Type)>,
24}
25
26struct This {
27    _and_token: Token![&],
28    ident: Ident,
29    _in_token: Token![in],
30}
31
32struct InitializerField {
33    attrs: Vec<Attribute>,
34    kind: InitializerKind,
35}
36
37enum InitializerKind {
38    Value {
39        ident: Ident,
40        value: Option<(Token![:], Expr)>,
41    },
42    Init {
43        ident: Ident,
44        _left_arrow_token: Token![<-],
45        value: Expr,
46    },
47    Code {
48        _underscore_token: Token![_],
49        _colon_token: Token![:],
50        block: Block,
51    },
52}
53
54impl InitializerKind {
55    fn ident(&self) -> Option<&Ident> {
56        match self {
57            Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
58            Self::Code { .. } => None,
59        }
60    }
61}
62
63enum InitializerAttribute {
64    DefaultError(DefaultErrorAttribute),
65}
66
67struct DefaultErrorAttribute {
68    ty: Box<Type>,
69}
70
71pub(crate) fn expand(
72    Initializer {
73        attrs,
74        this,
75        path,
76        brace_token,
77        fields,
78        rest,
79        error,
80    }: Initializer,
81    default_error: Option<&'static str>,
82    pinned: bool,
83    dcx: &mut DiagCtxt,
84) -> Result<TokenStream, ErrorGuaranteed> {
85    let error = error.map_or_else(
86        || {
87            if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
88                #[expect(irrefutable_let_patterns)]
89                if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
90                    Some(ty.clone())
91                } else {
92                    acc
93                }
94            }) {
95                default_error
96            } else if let Some(default_error) = default_error {
97                syn::parse_str(default_error).unwrap()
98            } else {
99                dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
100                parse_quote!(::core::convert::Infallible)
101            }
102        },
103        |(_, err)| Box::new(err),
104    );
105    let slot = format_ident!("slot");
106    let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
107        (
108            format_ident!("HasPinData"),
109            format_ident!("PinData"),
110            format_ident!("__pin_data"),
111            format_ident!("pin_init_from_closure"),
112        )
113    } else {
114        (
115            format_ident!("HasInitData"),
116            format_ident!("InitData"),
117            format_ident!("__init_data"),
118            format_ident!("init_from_closure"),
119        )
120    };
121    let init_kind = get_init_kind(rest, dcx);
122    let zeroable_check = match init_kind {
123        InitKind::Normal => quote!(),
124        InitKind::Zeroing => quote! {
125            // The user specified `..Zeroable::zeroed()` at the end of the list of fields.
126            // Therefore we check if the struct implements `Zeroable` and then zero the memory.
127            // This allows us to also remove the check that all fields are present (since we
128            // already set the memory to zero and that is a valid bit pattern).
129            fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
130            where T: ::pin_init::Zeroable
131            {}
132            // Ensure that the struct is indeed `Zeroable`.
133            assert_zeroable(#slot);
134            // SAFETY: The type implements `Zeroable` by the check above.
135            unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
136        },
137    };
138    let this = match this {
139        None => quote!(),
140        Some(This { ident, .. }) => quote! {
141            // Create the `this` so it can be referenced by the user inside of the
142            // expressions creating the individual fields.
143            let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
144        },
145    };
146    // `mixed_site` ensures that the data is not accessible to the user-controlled code.
147    let data = Ident::new("__data", Span::mixed_site());
148    let init_fields = init_fields(&fields, pinned, &data, &slot);
149    let field_check = make_field_check(&fields, init_kind, &path);
150    Ok(quote! {{
151        // Get the data about fields from the supplied type.
152        // SAFETY: TODO
153        let #data = unsafe {
154            use ::pin_init::__internal::#has_data_trait;
155            // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
156            // generics (which need to be present with that syntax).
157            #path::#get_data()
158        };
159        // Ensure that `#data` really is of type `#data` and help with type inference:
160        let init = ::pin_init::__internal::#data_trait::make_closure::<_, #error>(
161            #data,
162            move |slot| {
163                #zeroable_check
164                #this
165                #init_fields
166                #field_check
167                // SAFETY: we are the `init!` macro that is allowed to call this.
168                Ok(unsafe { ::pin_init::__internal::InitOk::new() })
169            }
170        );
171        let init = move |slot| -> ::core::result::Result<(), #error> {
172            init(slot).map(|__InitOk| ())
173        };
174        // SAFETY: TODO
175        let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
176        // FIXME: this let binding is required to avoid a compiler error (cycle when computing the
177        // opaque type returned by this function) before Rust 1.81. Remove after MSRV bump.
178        #[allow(
179            clippy::let_and_return,
180            reason = "some clippy versions warn about the let binding"
181        )]
182        init
183    }})
184}
185
/// How fields that are not mentioned in the initializer are treated.
enum InitKind {
    /// No `..rest` was given (or it was rejected): every field must be mentioned.
    Normal,
    /// `..Zeroable::init_zeroed()` was given: the slot is zeroed before the fields are written.
    Zeroing,
}
190
191fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
192    let Some((dotdot, expr)) = rest else {
193        return InitKind::Normal;
194    };
195    match &expr {
196        Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
197            Expr::Path(ExprPath {
198                attrs,
199                qself: None,
200                path:
201                    Path {
202                        leading_colon: None,
203                        segments,
204                    },
205            }) if attrs.is_empty()
206                && segments.len() == 2
207                && segments[0].ident == "Zeroable"
208                && segments[0].arguments.is_none()
209                && segments[1].ident == "init_zeroed"
210                && segments[1].arguments.is_none() =>
211            {
212                return InitKind::Zeroing;
213            }
214            _ => {}
215        },
216        _ => {}
217    }
218    dcx.error(
219        dotdot.span().join(expr.span()).unwrap_or(expr.span()),
220        "expected nothing or `..Zeroable::init_zeroed()`.",
221    );
222    InitKind::Normal
223}
224
225/// Generate the code that initializes the fields of the struct using the initializers in `field`.
226fn init_fields(
227    fields: &Punctuated<InitializerField, Token![,]>,
228    pinned: bool,
229    data: &Ident,
230    slot: &Ident,
231) -> TokenStream {
232    let mut guards = vec![];
233    let mut guard_attrs = vec![];
234    let mut res = TokenStream::new();
235    for InitializerField { attrs, kind } in fields {
236        let cfgs = {
237            let mut cfgs = attrs.clone();
238            cfgs.retain(|attr| attr.path().is_ident("cfg"));
239            cfgs
240        };
241        let init = match kind {
242            InitializerKind::Value { ident, value } => {
243                let mut value_ident = ident.clone();
244                let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
245                    // Setting the span of `value_ident` to `value`'s span improves error messages
246                    // when the type of `value` is wrong.
247                    value_ident.set_span(value.span());
248                    quote!(let #value_ident = #value;)
249                });
250                // Again span for better diagnostics
251                let write = quote_spanned!(ident.span()=> ::core::ptr::write);
252                // NOTE: the field accessor ensures that the initialized field is properly aligned.
253                // Unaligned fields will cause the compiler to emit E0793. We do not support
254                // unaligned fields since `Init::__init` requires an aligned pointer; the call to
255                // `ptr::write` below has the same requirement.
256                let accessor = if pinned {
257                    let project_ident = format_ident!("__project_{ident}");
258                    quote! {
259                        // SAFETY: TODO
260                        unsafe { #data.#project_ident(&mut (*#slot).#ident) }
261                    }
262                } else {
263                    quote! {
264                        // SAFETY: TODO
265                        unsafe { &mut (*#slot).#ident }
266                    }
267                };
268                quote! {
269                    #(#attrs)*
270                    {
271                        #value_prep
272                        // SAFETY: TODO
273                        unsafe { #write(&raw mut (*#slot).#ident, #value_ident) };
274                    }
275                    #(#cfgs)*
276                    #[allow(unused_variables)]
277                    let #ident = #accessor;
278                }
279            }
280            InitializerKind::Init { ident, value, .. } => {
281                // Again span for better diagnostics
282                let init = format_ident!("init", span = value.span());
283                // NOTE: the field accessor ensures that the initialized field is properly aligned.
284                // Unaligned fields will cause the compiler to emit E0793. We do not support
285                // unaligned fields since `Init::__init` requires an aligned pointer; the call to
286                // `ptr::write` below has the same requirement.
287                let (value_init, accessor) = if pinned {
288                    let project_ident = format_ident!("__project_{ident}");
289                    (
290                        quote! {
291                            // SAFETY:
292                            // - `slot` is valid, because we are inside of an initializer closure, we
293                            //   return when an error/panic occurs.
294                            // - We also use `#data` to require the correct trait (`Init` or `PinInit`)
295                            //   for `#ident`.
296                            unsafe { #data.#ident(&raw mut (*#slot).#ident, #init)? };
297                        },
298                        quote! {
299                            // SAFETY: TODO
300                            unsafe { #data.#project_ident(&mut (*#slot).#ident) }
301                        },
302                    )
303                } else {
304                    (
305                        quote! {
306                            // SAFETY: `slot` is valid, because we are inside of an initializer
307                            // closure, we return when an error/panic occurs.
308                            unsafe {
309                                ::pin_init::Init::__init(
310                                    #init,
311                                    &raw mut (*#slot).#ident,
312                                )?
313                            };
314                        },
315                        quote! {
316                            // SAFETY: TODO
317                            unsafe { &mut (*#slot).#ident }
318                        },
319                    )
320                };
321                quote! {
322                    #(#attrs)*
323                    {
324                        let #init = #value;
325                        #value_init
326                    }
327                    #(#cfgs)*
328                    #[allow(unused_variables)]
329                    let #ident = #accessor;
330                }
331            }
332            InitializerKind::Code { block: value, .. } => quote! {
333                #(#attrs)*
334                #[allow(unused_braces)]
335                #value
336            },
337        };
338        res.extend(init);
339        if let Some(ident) = kind.ident() {
340            // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
341            let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
342            res.extend(quote! {
343                #(#cfgs)*
344                // Create the drop guard:
345                //
346                // We rely on macro hygiene to make it impossible for users to access this local
347                // variable.
348                // SAFETY: We forget the guard later when initialization has succeeded.
349                let #guard = unsafe {
350                    ::pin_init::__internal::DropGuard::new(
351                        &raw mut (*slot).#ident
352                    )
353                };
354            });
355            guards.push(guard);
356            guard_attrs.push(cfgs);
357        }
358    }
359    quote! {
360        #res
361        // If execution reaches this point, all fields have been initialized. Therefore we can now
362        // dismiss the guards by forgetting them.
363        #(
364            #(#guard_attrs)*
365            ::core::mem::forget(#guards);
366        )*
367    }
368}
369
370/// Generate the check for ensuring that every field has been initialized.
371fn make_field_check(
372    fields: &Punctuated<InitializerField, Token![,]>,
373    init_kind: InitKind,
374    path: &Path,
375) -> TokenStream {
376    let field_attrs = fields
377        .iter()
378        .filter_map(|f| f.kind.ident().map(|_| &f.attrs));
379    let field_name = fields.iter().filter_map(|f| f.kind.ident());
380    match init_kind {
381        InitKind::Normal => quote! {
382            // We use unreachable code to ensure that all fields have been mentioned exactly once,
383            // this struct initializer will still be type-checked and complain with a very natural
384            // error message if a field is forgotten/mentioned more than once.
385            #[allow(unreachable_code, clippy::diverging_sub_expression)]
386            // SAFETY: this code is never executed.
387            let _ = || unsafe {
388                ::core::ptr::write(slot, #path {
389                    #(
390                        #(#field_attrs)*
391                        #field_name: ::core::panic!(),
392                    )*
393                })
394            };
395        },
396        InitKind::Zeroing => quote! {
397            // We use unreachable code to ensure that all fields have been mentioned at most once.
398            // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will
399            // be zeroed. This struct initializer will still be type-checked and complain with a
400            // very natural error message if a field is mentioned more than once, or doesn't exist.
401            #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)]
402            // SAFETY: this code is never executed.
403            let _ = || unsafe {
404                ::core::ptr::write(slot, #path {
405                    #(
406                        #(#field_attrs)*
407                        #field_name: ::core::panic!(),
408                    )*
409                    ..::core::mem::zeroed()
410                })
411            };
412        },
413    }
414}
415
416impl Parse for Initializer {
417    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
418        let attrs = input.call(Attribute::parse_outer)?;
419        let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
420        let path = input.parse()?;
421        let content;
422        let brace_token = braced!(content in input);
423        let mut fields = Punctuated::new();
424        loop {
425            let lh = content.lookahead1();
426            if lh.peek(End) || lh.peek(Token![..]) {
427                break;
428            } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
429                fields.push_value(content.parse()?);
430                let lh = content.lookahead1();
431                if lh.peek(End) {
432                    break;
433                } else if lh.peek(Token![,]) {
434                    fields.push_punct(content.parse()?);
435                } else {
436                    return Err(lh.error());
437                }
438            } else {
439                return Err(lh.error());
440            }
441        }
442        let rest = content
443            .peek(Token![..])
444            .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
445            .transpose()?;
446        let error = input
447            .peek(Token![?])
448            .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
449            .transpose()?;
450        let attrs = attrs
451            .into_iter()
452            .map(|a| {
453                if a.path().is_ident("default_error") {
454                    a.parse_args::<DefaultErrorAttribute>()
455                        .map(InitializerAttribute::DefaultError)
456                } else {
457                    Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
458                }
459            })
460            .collect::<Result<Vec<_>, _>>()?;
461        Ok(Self {
462            attrs,
463            this,
464            path,
465            brace_token,
466            fields,
467            rest,
468            error,
469        })
470    }
471}
472
473impl Parse for DefaultErrorAttribute {
474    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
475        Ok(Self { ty: input.parse()? })
476    }
477}
478
479impl Parse for This {
480    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
481        Ok(Self {
482            _and_token: input.parse()?,
483            ident: input.parse()?,
484            _in_token: input.parse()?,
485        })
486    }
487}
488
489impl Parse for InitializerField {
490    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
491        let attrs = input.call(Attribute::parse_outer)?;
492        Ok(Self {
493            attrs,
494            kind: input.parse()?,
495        })
496    }
497}
498
499impl Parse for InitializerKind {
500    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
501        let lh = input.lookahead1();
502        if lh.peek(Token![_]) {
503            Ok(Self::Code {
504                _underscore_token: input.parse()?,
505                _colon_token: input.parse()?,
506                block: input.parse()?,
507            })
508        } else if lh.peek(Ident) {
509            let ident = input.parse()?;
510            let lh = input.lookahead1();
511            if lh.peek(Token![<-]) {
512                Ok(Self::Init {
513                    ident,
514                    _left_arrow_token: input.parse()?,
515                    value: input.parse()?,
516                })
517            } else if lh.peek(Token![:]) {
518                Ok(Self::Value {
519                    ident,
520                    value: Some((input.parse()?, input.parse()?)),
521                })
522            } else if lh.peek(Token![,]) || lh.peek(End) {
523                Ok(Self::Value { ident, value: None })
524            } else {
525                Err(lh.error())
526            }
527        } else {
528            Err(lh.error())
529        }
530    }
531}