1use proc_macro2::{Span, TokenStream};
4use quote::{format_ident, quote, quote_spanned};
5use syn::{
6 braced,
7 parse::{End, Parse},
8 parse_quote,
9 punctuated::Punctuated,
10 spanned::Spanned,
11 token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
12};
13
14use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
15
/// A parsed initializer expression; the input of [`expand`].
///
/// Accepted syntax (see the [`Parse`] impl):
///
/// ```text
/// #[default_error(Type)]* (&ident in)? Path { entry, ..., (..Zeroable::init_zeroed())? } (? Type)?
/// ```
pub(crate) struct Initializer {
    /// Outer attributes on the initializer; parsing only accepts `#[default_error(Type)]`.
    attrs: Vec<InitializerAttribute>,
    /// Optional `&ident in` prefix; expansion binds `ident` to a `NonNull` of the slot.
    this: Option<This>,
    /// Path of the struct being initialized (used as `Path::__pin_data()` /
    /// `Path::__init_data()` and as the struct name in the field check).
    path: Path,
    /// The braces around the entries; the closing brace's span anchors the
    /// "expected `? <type>`" diagnostic.
    brace_token: token::Brace,
    /// The comma-separated field entries.
    fields: Punctuated<InitializerField, Token![,]>,
    /// Optional trailing `..expr`; expansion only accepts `..Zeroable::init_zeroed()`.
    rest: Option<(Token![..], Expr)>,
    /// Optional `? Type` after the closing brace, naming the initializer's error type.
    error: Option<(Token![?], Type)>,
}
25
/// The `&ident in` prefix of an initializer.
///
/// Expansion binds `ident` to `NonNull::new_unchecked(slot)`, i.e. a pointer to the
/// value under construction, before any field is initialized.
struct This {
    _and_token: Token![&],
    /// Name the `NonNull` pointer is bound to inside the initializer.
    ident: Ident,
    _in_token: Token![in],
}
31
/// One comma-separated entry in the initializer's braces.
struct InitializerField {
    /// Outer attributes written before the entry. All of them are applied to the
    /// generated initialization code and to the field check; only the `#[cfg(..)]`
    /// ones are repeated on the generated local binding and drop guard.
    attrs: Vec<Attribute>,
    /// What the entry does (write a value, run a nested initializer, or run code).
    kind: InitializerKind,
}
36
/// The syntactic form of a single entry.
enum InitializerKind {
    /// `ident: expr`, or the shorthand `ident` (then `value` is `None` and the
    /// in-scope local `ident` is used). The value is moved into the field with
    /// `ptr::write`.
    Value {
        ident: Ident,
        value: Option<(Token![:], Expr)>,
    },
    /// `ident <- expr`: `expr` is a nested initializer that is run with a raw
    /// pointer to the field.
    Init {
        ident: Ident,
        _left_arrow_token: Token![<-],
        value: Expr,
    },
    /// `_: { ... }`: a block of user code emitted as-is between the field
    /// initializations; it does not initialize any field.
    Code {
        _underscore_token: Token![_],
        _colon_token: Token![:],
        block: Block,
    },
}
53
54impl InitializerKind {
55 fn ident(&self) -> Option<&Ident> {
56 match self {
57 Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
58 Self::Code { .. } => None,
59 }
60 }
61}
62
/// An attribute placed on the initializer itself; any other attribute is rejected
/// with "unknown initializer attribute" while parsing.
enum InitializerAttribute {
    /// `#[default_error(Type)]`: error type used when no `? Type` follows the braces.
    DefaultError(DefaultErrorAttribute),
}
66
/// Arguments of a `#[default_error(Type)]` attribute.
struct DefaultErrorAttribute {
    /// The error type named inside the parentheses.
    ty: Box<Type>,
}
70
71pub(crate) fn expand(
72 Initializer {
73 attrs,
74 this,
75 path,
76 brace_token,
77 fields,
78 rest,
79 error,
80 }: Initializer,
81 default_error: Option<&'static str>,
82 pinned: bool,
83 dcx: &mut DiagCtxt,
84) -> Result<TokenStream, ErrorGuaranteed> {
85 let error = error.map_or_else(
86 || {
87 if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
88 #[expect(irrefutable_let_patterns)]
89 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
90 Some(ty.clone())
91 } else {
92 acc
93 }
94 }) {
95 default_error
96 } else if let Some(default_error) = default_error {
97 syn::parse_str(default_error).unwrap()
98 } else {
99 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
100 parse_quote!(::core::convert::Infallible)
101 }
102 },
103 |(_, err)| Box::new(err),
104 );
105 let slot = format_ident!("slot");
106 let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
107 (
108 format_ident!("HasPinData"),
109 format_ident!("PinData"),
110 format_ident!("__pin_data"),
111 format_ident!("pin_init_from_closure"),
112 )
113 } else {
114 (
115 format_ident!("HasInitData"),
116 format_ident!("InitData"),
117 format_ident!("__init_data"),
118 format_ident!("init_from_closure"),
119 )
120 };
121 let init_kind = get_init_kind(rest, dcx);
122 let zeroable_check = match init_kind {
123 InitKind::Normal => quote!(),
124 InitKind::Zeroing => quote! {
125 fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
130 where T: ::pin_init::Zeroable
131 {}
132 assert_zeroable(#slot);
134 unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
136 },
137 };
138 let this = match this {
139 None => quote!(),
140 Some(This { ident, .. }) => quote! {
141 let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
144 },
145 };
146 let data = Ident::new("__data", Span::mixed_site());
148 let init_fields = init_fields(&fields, pinned, &data, &slot);
149 let field_check = make_field_check(&fields, init_kind, &path);
150 Ok(quote! {{
151 let #data = unsafe {
154 use ::pin_init::__internal::#has_data_trait;
155 #path::#get_data()
158 };
159 let init = ::pin_init::__internal::#data_trait::make_closure::<_, #error>(
161 #data,
162 move |slot| {
163 #zeroable_check
164 #this
165 #init_fields
166 #field_check
167 Ok(unsafe { ::pin_init::__internal::InitOk::new() })
169 }
170 );
171 let init = move |slot| -> ::core::result::Result<(), #error> {
172 init(slot).map(|__InitOk| ())
173 };
174 let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
176 #[allow(
179 clippy::let_and_return,
180 reason = "some clippy versions warn about the let binding"
181 )]
182 init
183 }})
184}
185
/// How the initializer treats fields it does not mention, as derived from the `..`
/// tail by [`get_init_kind`].
enum InitKind {
    /// No tail: the field check requires every field to be mentioned.
    Normal,
    /// `..Zeroable::init_zeroed()`: the type must implement `Zeroable`, the whole
    /// value is zeroed before the fields are written, and unmentioned fields are
    /// allowed by the field check.
    Zeroing,
}
190
191fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
192 let Some((dotdot, expr)) = rest else {
193 return InitKind::Normal;
194 };
195 match &expr {
196 Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
197 Expr::Path(ExprPath {
198 attrs,
199 qself: None,
200 path:
201 Path {
202 leading_colon: None,
203 segments,
204 },
205 }) if attrs.is_empty()
206 && segments.len() == 2
207 && segments[0].ident == "Zeroable"
208 && segments[0].arguments.is_none()
209 && segments[1].ident == "init_zeroed"
210 && segments[1].arguments.is_none() =>
211 {
212 return InitKind::Zeroing;
213 }
214 _ => {}
215 },
216 _ => {}
217 }
218 dcx.error(
219 dotdot.span().join(expr.span()).unwrap_or(expr.span()),
220 "expected nothing or `..Zeroable::init_zeroed()`.",
221 );
222 InitKind::Normal
223}
224
225fn init_fields(
227 fields: &Punctuated<InitializerField, Token![,]>,
228 pinned: bool,
229 data: &Ident,
230 slot: &Ident,
231) -> TokenStream {
232 let mut guards = vec![];
233 let mut guard_attrs = vec![];
234 let mut res = TokenStream::new();
235 for InitializerField { attrs, kind } in fields {
236 let cfgs = {
237 let mut cfgs = attrs.clone();
238 cfgs.retain(|attr| attr.path().is_ident("cfg"));
239 cfgs
240 };
241 let init = match kind {
242 InitializerKind::Value { ident, value } => {
243 let mut value_ident = ident.clone();
244 let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
245 value_ident.set_span(value.span());
248 quote!(let #value_ident = #value;)
249 });
250 let write = quote_spanned!(ident.span()=> ::core::ptr::write);
252 let accessor = if pinned {
257 let project_ident = format_ident!("__project_{ident}");
258 quote! {
259 unsafe { #data.#project_ident(&mut (*#slot).#ident) }
261 }
262 } else {
263 quote! {
264 unsafe { &mut (*#slot).#ident }
266 }
267 };
268 quote! {
269 #(#attrs)*
270 {
271 #value_prep
272 unsafe { #write(&raw mut (*#slot).#ident, #value_ident) };
274 }
275 #(#cfgs)*
276 #[allow(unused_variables)]
277 let #ident = #accessor;
278 }
279 }
280 InitializerKind::Init { ident, value, .. } => {
281 let init = format_ident!("init", span = value.span());
283 let (value_init, accessor) = if pinned {
288 let project_ident = format_ident!("__project_{ident}");
289 (
290 quote! {
291 unsafe { #data.#ident(&raw mut (*#slot).#ident, #init)? };
297 },
298 quote! {
299 unsafe { #data.#project_ident(&mut (*#slot).#ident) }
301 },
302 )
303 } else {
304 (
305 quote! {
306 unsafe {
309 ::pin_init::Init::__init(
310 #init,
311 &raw mut (*#slot).#ident,
312 )?
313 };
314 },
315 quote! {
316 unsafe { &mut (*#slot).#ident }
318 },
319 )
320 };
321 quote! {
322 #(#attrs)*
323 {
324 let #init = #value;
325 #value_init
326 }
327 #(#cfgs)*
328 #[allow(unused_variables)]
329 let #ident = #accessor;
330 }
331 }
332 InitializerKind::Code { block: value, .. } => quote! {
333 #(#attrs)*
334 #[allow(unused_braces)]
335 #value
336 },
337 };
338 res.extend(init);
339 if let Some(ident) = kind.ident() {
340 let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
342 res.extend(quote! {
343 #(#cfgs)*
344 let #guard = unsafe {
350 ::pin_init::__internal::DropGuard::new(
351 &raw mut (*slot).#ident
352 )
353 };
354 });
355 guards.push(guard);
356 guard_attrs.push(cfgs);
357 }
358 }
359 quote! {
360 #res
361 #(
364 #(#guard_attrs)*
365 ::core::mem::forget(#guards);
366 )*
367 }
368}
369
370fn make_field_check(
372 fields: &Punctuated<InitializerField, Token![,]>,
373 init_kind: InitKind,
374 path: &Path,
375) -> TokenStream {
376 let field_attrs = fields
377 .iter()
378 .filter_map(|f| f.kind.ident().map(|_| &f.attrs));
379 let field_name = fields.iter().filter_map(|f| f.kind.ident());
380 match init_kind {
381 InitKind::Normal => quote! {
382 #[allow(unreachable_code, clippy::diverging_sub_expression)]
386 let _ = || unsafe {
388 ::core::ptr::write(slot, #path {
389 #(
390 #(#field_attrs)*
391 #field_name: ::core::panic!(),
392 )*
393 })
394 };
395 },
396 InitKind::Zeroing => quote! {
397 #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)]
402 let _ = || unsafe {
404 ::core::ptr::write(slot, #path {
405 #(
406 #(#field_attrs)*
407 #field_name: ::core::panic!(),
408 )*
409 ..::core::mem::zeroed()
410 })
411 };
412 },
413 }
414}
415
416impl Parse for Initializer {
417 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
418 let attrs = input.call(Attribute::parse_outer)?;
419 let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
420 let path = input.parse()?;
421 let content;
422 let brace_token = braced!(content in input);
423 let mut fields = Punctuated::new();
424 loop {
425 let lh = content.lookahead1();
426 if lh.peek(End) || lh.peek(Token![..]) {
427 break;
428 } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
429 fields.push_value(content.parse()?);
430 let lh = content.lookahead1();
431 if lh.peek(End) {
432 break;
433 } else if lh.peek(Token![,]) {
434 fields.push_punct(content.parse()?);
435 } else {
436 return Err(lh.error());
437 }
438 } else {
439 return Err(lh.error());
440 }
441 }
442 let rest = content
443 .peek(Token![..])
444 .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
445 .transpose()?;
446 let error = input
447 .peek(Token![?])
448 .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
449 .transpose()?;
450 let attrs = attrs
451 .into_iter()
452 .map(|a| {
453 if a.path().is_ident("default_error") {
454 a.parse_args::<DefaultErrorAttribute>()
455 .map(InitializerAttribute::DefaultError)
456 } else {
457 Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
458 }
459 })
460 .collect::<Result<Vec<_>, _>>()?;
461 Ok(Self {
462 attrs,
463 this,
464 path,
465 brace_token,
466 fields,
467 rest,
468 error,
469 })
470 }
471}
472
473impl Parse for DefaultErrorAttribute {
474 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
475 Ok(Self { ty: input.parse()? })
476 }
477}
478
479impl Parse for This {
480 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
481 Ok(Self {
482 _and_token: input.parse()?,
483 ident: input.parse()?,
484 _in_token: input.parse()?,
485 })
486 }
487}
488
489impl Parse for InitializerField {
490 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
491 let attrs = input.call(Attribute::parse_outer)?;
492 Ok(Self {
493 attrs,
494 kind: input.parse()?,
495 })
496 }
497}
498
499impl Parse for InitializerKind {
500 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
501 let lh = input.lookahead1();
502 if lh.peek(Token![_]) {
503 Ok(Self::Code {
504 _underscore_token: input.parse()?,
505 _colon_token: input.parse()?,
506 block: input.parse()?,
507 })
508 } else if lh.peek(Ident) {
509 let ident = input.parse()?;
510 let lh = input.lookahead1();
511 if lh.peek(Token![<-]) {
512 Ok(Self::Init {
513 ident,
514 _left_arrow_token: input.parse()?,
515 value: input.parse()?,
516 })
517 } else if lh.peek(Token![:]) {
518 Ok(Self::Value {
519 ident,
520 value: Some((input.parse()?, input.parse()?)),
521 })
522 } else if lh.peek(Token![,]) || lh.peek(End) {
523 Ok(Self::Value { ident, value: None })
524 } else {
525 Err(lh.error())
526 }
527 } else {
528 Err(lh.error())
529 }
530 }
531}