1use proc_macro2::{Span, TokenStream};
4use quote::{format_ident, quote};
5use syn::{
6 braced,
7 parse::{End, Parse},
8 parse_quote,
9 punctuated::Punctuated,
10 spanned::Spanned,
11 token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
12};
13
14use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
15
16pub(crate) struct Initializer {
17 attrs: Vec<InitializerAttribute>,
18 this: Option<This>,
19 path: Path,
20 brace_token: token::Brace,
21 fields: Punctuated<InitializerField, Token![,]>,
22 rest: Option<(Token![..], Expr)>,
23 error: Option<(Token![?], Type)>,
24}
25
26struct This {
27 _and_token: Token![&],
28 ident: Ident,
29 _in_token: Token![in],
30}
31
32struct InitializerField {
33 attrs: Vec<Attribute>,
34 kind: InitializerKind,
35}
36
37enum InitializerKind {
38 Value {
39 ident: Ident,
40 value: Option<(Token![:], Expr)>,
41 },
42 Init {
43 ident: Ident,
44 _left_arrow_token: Token![<-],
45 value: Expr,
46 },
47 Code {
48 _underscore_token: Token![_],
49 _colon_token: Token![:],
50 block: Block,
51 },
52}
53
54impl InitializerKind {
55 fn ident(&self) -> Option<&Ident> {
56 match self {
57 Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
58 Self::Code { .. } => None,
59 }
60 }
61}
62
63enum InitializerAttribute {
64 DefaultError(DefaultErrorAttribute),
65}
66
67struct DefaultErrorAttribute {
68 ty: Box<Type>,
69}
70
71pub(crate) fn expand(
72 Initializer {
73 attrs,
74 this,
75 path,
76 brace_token,
77 fields,
78 rest,
79 error,
80 }: Initializer,
81 default_error: Option<&'static str>,
82 pinned: bool,
83 dcx: &mut DiagCtxt,
84) -> Result<TokenStream, ErrorGuaranteed> {
85 let error = error.map_or_else(
86 || {
87 if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
88 #[expect(irrefutable_let_patterns)]
89 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
90 Some(ty.clone())
91 } else {
92 acc
93 }
94 }) {
95 default_error
96 } else if let Some(default_error) = default_error {
97 syn::parse_str(default_error).unwrap()
98 } else {
99 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
100 parse_quote!(::core::convert::Infallible)
101 }
102 },
103 |(_, err)| Box::new(err),
104 );
105 let slot = format_ident!("slot");
106 let (has_data_trait, get_data, init_from_closure) = if pinned {
107 (
108 format_ident!("HasPinData"),
109 format_ident!("__pin_data"),
110 format_ident!("pin_init_from_closure"),
111 )
112 } else {
113 (
114 format_ident!("HasInitData"),
115 format_ident!("__init_data"),
116 format_ident!("init_from_closure"),
117 )
118 };
119 let init_kind = get_init_kind(rest, dcx);
120 let zeroable_check = match init_kind {
121 InitKind::Normal => quote!(),
122 InitKind::Zeroing => quote! {
123 fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
128 where T: ::pin_init::Zeroable
129 {}
130 assert_zeroable(#slot);
132 unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
134 },
135 };
136 let this = match this {
137 None => quote!(),
138 Some(This { ident, .. }) => quote! {
139 let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
142 },
143 };
144 let data = Ident::new("__data", Span::mixed_site());
146 let init_fields = init_fields(&fields, pinned, &data, &slot);
147 let field_check = make_field_check(&fields, init_kind, &path);
148 Ok(quote! {{
149 let #data = unsafe {
152 use ::pin_init::__internal::#has_data_trait;
153 #path::#get_data()
156 };
157 let init = #data.__make_closure::<_, #error>(
159 move |slot| {
160 #zeroable_check
161 #this
162 #init_fields
163 #field_check
164 Ok(unsafe { ::pin_init::__internal::InitOk::new() })
166 }
167 );
168 let init = move |slot| -> ::core::result::Result<(), #error> {
169 init(slot).map(|__InitOk| ())
170 };
171 unsafe { ::pin_init::#init_from_closure::<_, #error>(init) }
173 }})
174}
175
/// How the optional `..rest` trailer of an initializer was interpreted.
enum InitKind {
    /// No trailer was given. This is also the fallback after an unsupported trailer has been
    /// reported.
    Normal,
    /// The trailer was `..Zeroable::init_zeroed()`: the slot is zero-filled before the listed
    /// fields are initialized.
    Zeroing,
}
180
181fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
182 let Some((dotdot, expr)) = rest else {
183 return InitKind::Normal;
184 };
185 match &expr {
186 Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
187 Expr::Path(ExprPath {
188 attrs,
189 qself: None,
190 path:
191 Path {
192 leading_colon: None,
193 segments,
194 },
195 }) if attrs.is_empty()
196 && segments.len() == 2
197 && segments[0].ident == "Zeroable"
198 && segments[0].arguments.is_none()
199 && segments[1].ident == "init_zeroed"
200 && segments[1].arguments.is_none() =>
201 {
202 return InitKind::Zeroing;
203 }
204 _ => {}
205 },
206 _ => {}
207 }
208 dcx.error(
209 dotdot.span().join(expr.span()).unwrap_or(expr.span()),
210 "expected nothing or `..Zeroable::init_zeroed()`.",
211 );
212 InitKind::Normal
213}
214
215fn init_fields(
217 fields: &Punctuated<InitializerField, Token![,]>,
218 pinned: bool,
219 data: &Ident,
220 slot: &Ident,
221) -> TokenStream {
222 let mut guards = vec![];
223 let mut guard_attrs = vec![];
224 let mut res = TokenStream::new();
225 for InitializerField { attrs, kind } in fields {
226 let cfgs = {
227 let mut cfgs = attrs.clone();
228 cfgs.retain(|attr| attr.path().is_ident("cfg"));
229 cfgs
230 };
231
232 let ident = match kind {
233 InitializerKind::Value { ident, .. } => ident,
234 InitializerKind::Init { ident, .. } => ident,
235 InitializerKind::Code { block, .. } => {
236 res.extend(quote! {
237 #(#attrs)*
238 #[allow(unused_braces)]
239 #block
240 });
241 continue;
242 }
243 };
244
245 let slot = if pinned {
246 quote! {
247 (unsafe { #data.#ident(#slot) })
253 }
254 } else {
255 quote! {
256 (unsafe {
263 ::pin_init::__internal::Slot::<::pin_init::__internal::Unpinned, _>::new(
264 &raw mut (*#slot).#ident
265 )
266 })
267 }
268 };
269
270 let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
272
273 let init = match kind {
274 InitializerKind::Value { ident, value } => {
275 let value = value
276 .as_ref()
277 .map(|(_, value)| quote!(#value))
278 .unwrap_or_else(|| quote!(#ident));
279
280 quote! {
281 #(#attrs)*
282 let mut #guard = #slot.write(#value);
283
284 }
285 }
286 InitializerKind::Init { value, .. } => {
287 quote! {
288 #(#attrs)*
289 let mut #guard = #slot.init(#value)?;
290 }
291 }
292 InitializerKind::Code { .. } => unreachable!(),
293 };
294
295 res.extend(quote! {
296 #init
297
298 #(#cfgs)*
299 #[allow(unused_variables, non_snake_case)]
302 let #ident = #guard.let_binding();
303 });
304
305 guards.push(guard);
306 guard_attrs.push(cfgs);
307 }
308 quote! {
309 #res
310 #(
313 #(#guard_attrs)*
314 ::core::mem::forget(#guards);
315 )*
316 }
317}
318
319fn make_field_check(
321 fields: &Punctuated<InitializerField, Token![,]>,
322 init_kind: InitKind,
323 path: &Path,
324) -> TokenStream {
325 let field_attrs: Vec<_> = fields
326 .iter()
327 .filter_map(|f| f.kind.ident().map(|_| &f.attrs))
328 .collect();
329 let field_name: Vec<_> = fields.iter().filter_map(|f| f.kind.ident()).collect();
330 let zeroing_trailer = match init_kind {
331 InitKind::Normal => None,
332 InitKind::Zeroing => Some(quote! {
333 ..::core::mem::zeroed()
334 }),
335 };
336 quote! {
337 #[allow(unreachable_code, clippy::diverging_sub_expression)]
338 let _ = || unsafe {
341 #(
346 #(#field_attrs)*
347 let _ = &(*slot).#field_name;
348 )*
349
350 ::core::ptr::write(slot, #path {
355 #(
356 #(#field_attrs)*
357 #field_name: loop {},
358 )*
359 #zeroing_trailer
360 })
361 };
362 }
363}
364
365impl Parse for Initializer {
366 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
367 let attrs = input.call(Attribute::parse_outer)?;
368 let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
369 let path = input.parse()?;
370 let content;
371 let brace_token = braced!(content in input);
372 let mut fields = Punctuated::new();
373 loop {
374 let lh = content.lookahead1();
375 if lh.peek(End) || lh.peek(Token![..]) {
376 break;
377 } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
378 fields.push_value(content.parse()?);
379 let lh = content.lookahead1();
380 if lh.peek(End) {
381 break;
382 } else if lh.peek(Token![,]) {
383 fields.push_punct(content.parse()?);
384 } else {
385 return Err(lh.error());
386 }
387 } else {
388 return Err(lh.error());
389 }
390 }
391 let rest = content
392 .peek(Token![..])
393 .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
394 .transpose()?;
395 let error = input
396 .peek(Token![?])
397 .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
398 .transpose()?;
399 let attrs = attrs
400 .into_iter()
401 .map(|a| {
402 if a.path().is_ident("default_error") {
403 a.parse_args::<DefaultErrorAttribute>()
404 .map(InitializerAttribute::DefaultError)
405 } else {
406 Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
407 }
408 })
409 .collect::<Result<Vec<_>, _>>()?;
410 Ok(Self {
411 attrs,
412 this,
413 path,
414 brace_token,
415 fields,
416 rest,
417 error,
418 })
419 }
420}
421
422impl Parse for DefaultErrorAttribute {
423 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
424 Ok(Self { ty: input.parse()? })
425 }
426}
427
428impl Parse for This {
429 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
430 Ok(Self {
431 _and_token: input.parse()?,
432 ident: input.parse()?,
433 _in_token: input.parse()?,
434 })
435 }
436}
437
438impl Parse for InitializerField {
439 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
440 let attrs = input.call(Attribute::parse_outer)?;
441 Ok(Self {
442 attrs,
443 kind: input.parse()?,
444 })
445 }
446}
447
448impl Parse for InitializerKind {
449 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
450 let lh = input.lookahead1();
451 if lh.peek(Token![_]) {
452 Ok(Self::Code {
453 _underscore_token: input.parse()?,
454 _colon_token: input.parse()?,
455 block: input.parse()?,
456 })
457 } else if lh.peek(Ident) {
458 let ident = input.parse()?;
459 let lh = input.lookahead1();
460 if lh.peek(Token![<-]) {
461 Ok(Self::Init {
462 ident,
463 _left_arrow_token: input.parse()?,
464 value: input.parse()?,
465 })
466 } else if lh.peek(Token![:]) {
467 Ok(Self::Value {
468 ident,
469 value: Some((input.parse()?, input.parse()?)),
470 })
471 } else if lh.peek(Token![,]) || lh.peek(End) {
472 Ok(Self::Value { ident, value: None })
473 } else {
474 Err(lh.error())
475 }
476 } else {
477 Err(lh.error())
478 }
479 }
480}