1use std::num::NonZeroU32;
12
13use proc_macro2::{Span, TokenStream};
14use quote::{quote, quote_spanned, ToTokens};
15use syn::{
16 parse_quote, spanned::Spanned as _, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error,
17 Expr, ExprLit, Field, GenericParam, Ident, Index, Lit, LitStr, Meta, Path, Type, Variant,
18 Visibility, WherePredicate,
19};
20
21use crate::repr::{CompoundRepr, EnumRepr, PrimitiveRepr, Repr, Spanned};
22
23pub(crate) struct Ctx {
24 pub(crate) ast: DeriveInput,
25 pub(crate) zerocopy_crate: Path,
26
27 pub(crate) skip_on_error: bool,
30
31 pub(crate) on_error_span: Option<proc_macro2::Span>,
33}
34
35impl Ctx {
36 pub(crate) fn try_from_derive_input(ast: DeriveInput) -> Result<Self, Error> {
39 let mut path = parse_quote!(::zerocopy);
40 let mut skip_on_error = false;
41 let mut on_error_span = None;
42
43 for attr in &ast.attrs {
44 if let Meta::List(ref meta_list) = attr.meta {
45 if meta_list.path.is_ident("zerocopy") {
46 attr.parse_nested_meta(|meta| {
47 if meta.path.is_ident("crate") {
48 let expr = meta.value().and_then(|value| value.parse());
49 if let Ok(Expr::Lit(ExprLit { lit: Lit::Str(lit), .. })) = expr {
50 if let Ok(path_lit) = lit.parse::<Ident>() {
51 path = parse_quote!(::#path_lit);
52 return Ok(());
53 }
54 }
55
56 return Err(Error::new(
57 Span::call_site(),
58 "`crate` attribute requires a path as the value",
59 ));
60 }
61
62 if meta.path.is_ident("on_error") {
63 on_error_span = Some(meta.path.span());
64 let value = meta.value()?;
65 let s: LitStr = value.parse()?;
66 match s.value().as_str() {
67 "skip" => skip_on_error = true,
68 "fail" => skip_on_error = false,
69 _ => return Err(Error::new(
70 s.span(),
71 "unrecognized value for `on_error` attribute from `zerocopy`; expected `skip` or `fail`",
72 )),
73 }
74 return Ok(());
75 }
76
77 Err(Error::new(
78 Span::call_site(),
79 format!(
80 "unknown attribute encountered: {}",
81 meta.path.into_token_stream()
82 ),
83 ))
84 })?;
85 }
86 }
87 }
88
89 Ok(Self { ast, zerocopy_crate: path, skip_on_error, on_error_span })
90 }
91
92 pub(crate) fn with_input(&self, input: &DeriveInput) -> Self {
93 Self {
94 ast: input.clone(),
95 zerocopy_crate: self.zerocopy_crate.clone(),
96 skip_on_error: self.skip_on_error,
97 on_error_span: self.on_error_span,
98 }
99 }
100
101 pub(crate) fn core_path(&self) -> TokenStream {
102 let zerocopy_crate = &self.zerocopy_crate;
103 quote!(#zerocopy_crate::util::macro_util::core_reexport)
104 }
105
106 pub(crate) fn cfg_compile_error(&self) -> TokenStream {
107 if cfg!(zerocopy_unstable_derive_on_error) {
115 quote!()
116 } else if let Some(span) = self.on_error_span {
117 let core = self.core_path();
118 let error_message = "`on_error` is experimental; pass '--cfg zerocopy_unstable_derive_on_error' to enable";
119 quote::quote_spanned! {span=>
120 #[allow(unused_attributes, unexpected_cfgs)]
121 const _: () = {
122 #[cfg(not(zerocopy_unstable_derive_on_error))]
123 #core::compile_error!(#error_message);
124 };
125 }
126 } else {
127 quote!()
128 }
129 }
130
131 pub(crate) fn error_or_skip<E>(&self, error: E) -> Result<TokenStream, E> {
132 if self.skip_on_error {
133 Ok(self.cfg_compile_error())
134 } else {
135 Err(error)
136 }
137 }
138}
139
140pub(crate) trait DataExt {
141 fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)>;
150
151 fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)>;
152
153 fn tag(&self) -> Option<Ident>;
154}
155
156impl DataExt for Data {
157 fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> {
158 match self {
159 Data::Struct(strc) => strc.fields(),
160 Data::Enum(enm) => enm.fields(),
161 Data::Union(un) => un.fields(),
162 }
163 }
164
165 fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)> {
166 match self {
167 Data::Struct(strc) => strc.variants(),
168 Data::Enum(enm) => enm.variants(),
169 Data::Union(un) => un.variants(),
170 }
171 }
172
173 fn tag(&self) -> Option<Ident> {
174 match self {
175 Data::Struct(strc) => strc.tag(),
176 Data::Enum(enm) => enm.tag(),
177 Data::Union(un) => un.tag(),
178 }
179 }
180}
181
182impl DataExt for DataStruct {
183 fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> {
184 map_fields(&self.fields)
185 }
186
187 fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)> {
188 vec![(None, self.fields())]
189 }
190
191 fn tag(&self) -> Option<Ident> {
192 None
193 }
194}
195
196impl DataExt for DataEnum {
197 fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> {
198 map_fields(self.variants.iter().flat_map(|var| &var.fields))
199 }
200
201 fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)> {
202 self.variants.iter().map(|var| (Some(var), map_fields(&var.fields))).collect()
203 }
204
205 fn tag(&self) -> Option<Ident> {
206 Some(Ident::new("___ZerocopyTag", Span::call_site()))
207 }
208}
209
210impl DataExt for DataUnion {
211 fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> {
212 map_fields(&self.fields.named)
213 }
214
215 fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)> {
216 vec![(None, self.fields())]
217 }
218
219 fn tag(&self) -> Option<Ident> {
220 None
221 }
222}
223
224fn map_fields<'a>(
225 fields: impl 'a + IntoIterator<Item = &'a Field>,
226) -> Vec<(&'a Visibility, TokenStream, &'a Type)> {
227 fields
228 .into_iter()
229 .enumerate()
230 .map(|(idx, f)| {
231 (
232 &f.vis,
233 f.ident
234 .as_ref()
235 .map(ToTokens::to_token_stream)
236 .unwrap_or_else(|| Index::from(idx).to_token_stream()),
237 &f.ty,
238 )
239 })
240 .collect()
241}
242
/// Renders `t` with `to_string` and strips a leading raw-identifier prefix
/// (`r#`), so `r#type` becomes `type`.
///
/// Only a prefix is removed; `r#` appearing elsewhere is kept.
pub(crate) fn to_ident_str(t: &impl ToString) -> String {
    let rendered = t.to_string();
    match rendered.strip_prefix("r#") {
        Some(unprefixed) => unprefixed.to_owned(),
        None => rendered,
    }
}
251
252pub(crate) enum PaddingCheck {
255 Struct,
258 ReprCStruct,
260 Union,
262 Enum { tag_type_definition: TokenStream },
267}
268
269impl PaddingCheck {
270 pub(crate) fn validator_trait_and_macro_idents(&self) -> (Ident, Ident) {
273 let (trt, mcro) = match self {
274 PaddingCheck::Struct => ("PaddingFree", "struct_padding"),
275 PaddingCheck::ReprCStruct => ("DynamicPaddingFree", "repr_c_struct_has_padding"),
276 PaddingCheck::Union => ("PaddingFree", "union_padding"),
277 PaddingCheck::Enum { .. } => ("PaddingFree", "enum_padding"),
278 };
279
280 let trt = Ident::new(trt, Span::call_site());
281 let mcro = Ident::new(mcro, Span::call_site());
282 (trt, mcro)
283 }
284
285 pub(crate) fn validator_macro_context(&self) -> Option<&TokenStream> {
288 match self {
289 PaddingCheck::Struct | PaddingCheck::ReprCStruct | PaddingCheck::Union => None,
290 PaddingCheck::Enum { tag_type_definition } => Some(tag_type_definition),
291 }
292 }
293}
294
295#[derive(Clone)]
296pub(crate) enum Trait {
297 KnownLayout,
298 HasTag,
299 HasField {
300 variant_id: Box<Expr>,
301 field: Box<Type>,
302 field_id: Box<Expr>,
303 },
304 ProjectField {
305 variant_id: Box<Expr>,
306 field: Box<Type>,
307 field_id: Box<Expr>,
308 invariants: Box<Type>,
309 },
310 Immutable,
311 TryFromBytes,
312 FromZeros,
313 FromBytes,
314 IntoBytes,
315 Unaligned,
316 Sized,
317 ByteHash,
318 ByteEq,
319 SplitAt,
320}
321
322impl ToTokens for Trait {
323 fn to_tokens(&self, tokens: &mut TokenStream) {
324 let s = match self {
334 Trait::HasField { .. } => "HasField",
335 Trait::ProjectField { .. } => "ProjectField",
336 Trait::KnownLayout => "KnownLayout",
337 Trait::HasTag => "HasTag",
338 Trait::Immutable => "Immutable",
339 Trait::TryFromBytes => "TryFromBytes",
340 Trait::FromZeros => "FromZeros",
341 Trait::FromBytes => "FromBytes",
342 Trait::IntoBytes => "IntoBytes",
343 Trait::Unaligned => "Unaligned",
344 Trait::Sized => "Sized",
345 Trait::ByteHash => "ByteHash",
346 Trait::ByteEq => "ByteEq",
347 Trait::SplitAt => "SplitAt",
348 };
349 let ident = Ident::new(s, Span::call_site());
350 let arguments: Option<syn::AngleBracketedGenericArguments> = match self {
351 Trait::HasField { variant_id, field, field_id } => {
352 Some(parse_quote!(<#field, #variant_id, #field_id>))
353 }
354 Trait::ProjectField { variant_id, field, field_id, invariants } => {
355 Some(parse_quote!(<#field, #invariants, #variant_id, #field_id>))
356 }
357 Trait::KnownLayout
358 | Trait::HasTag
359 | Trait::Immutable
360 | Trait::TryFromBytes
361 | Trait::FromZeros
362 | Trait::FromBytes
363 | Trait::IntoBytes
364 | Trait::Unaligned
365 | Trait::Sized
366 | Trait::ByteHash
367 | Trait::ByteEq
368 | Trait::SplitAt => None,
369 };
370 tokens.extend(quote!(#ident #arguments));
371 }
372}
373
374impl Trait {
375 pub(crate) fn crate_path(&self, ctx: &Ctx) -> Path {
376 let zerocopy_crate = &ctx.zerocopy_crate;
377 let core = ctx.core_path();
378 match self {
379 Self::Sized => parse_quote!(#core::marker::#self),
380 _ => parse_quote!(#zerocopy_crate::#self),
381 }
382 }
383}
384
385pub(crate) enum TraitBound {
386 Slf,
387 Other(Trait),
388}
389
390pub(crate) enum FieldBounds<'a> {
391 None,
392 All(&'a [TraitBound]),
393 Trailing(&'a [TraitBound]),
394 Explicit(Vec<WherePredicate>),
395}
396
397impl<'a> FieldBounds<'a> {
398 pub(crate) const ALL_SELF: FieldBounds<'a> = FieldBounds::All(&[TraitBound::Slf]);
399 pub(crate) const TRAILING_SELF: FieldBounds<'a> = FieldBounds::Trailing(&[TraitBound::Slf]);
400}
401
402pub(crate) enum SelfBounds<'a> {
403 None,
404 All(&'a [Trait]),
405}
406
407#[allow(clippy::needless_lifetimes)]
410impl<'a> SelfBounds<'a> {
411 pub(crate) const SIZED: Self = Self::All(&[Trait::Sized]);
412}
413
414pub(crate) fn normalize_bounds<'a>(
416 slf: &'a Trait,
417 bounds: &'a [TraitBound],
418) -> impl 'a + Iterator<Item = Trait> {
419 bounds.iter().map(move |bound| match bound {
420 TraitBound::Slf => slf.clone(),
421 TraitBound::Other(trt) => trt.clone(),
422 })
423}
424
425pub(crate) struct ImplBlockBuilder<'a> {
426 ctx: &'a Ctx,
427 data: &'a dyn DataExt,
428 trt: Trait,
429 field_type_trait_bounds: FieldBounds<'a>,
430 self_type_trait_bounds: SelfBounds<'a>,
431 padding_check: Option<PaddingCheck>,
432 param_extras: Vec<GenericParam>,
433 inner_extras: Option<TokenStream>,
434 outer_extras: Option<TokenStream>,
435}
436
437impl<'a> ImplBlockBuilder<'a> {
438 pub(crate) fn new(
439 ctx: &'a Ctx,
440 data: &'a dyn DataExt,
441 trt: Trait,
442 field_type_trait_bounds: FieldBounds<'a>,
443 ) -> Self {
444 Self {
445 ctx,
446 data,
447 trt,
448 field_type_trait_bounds,
449 self_type_trait_bounds: SelfBounds::None,
450 padding_check: None,
451 param_extras: Vec::new(),
452 inner_extras: None,
453 outer_extras: None,
454 }
455 }
456
457 pub(crate) fn self_type_trait_bounds(mut self, self_type_trait_bounds: SelfBounds<'a>) -> Self {
458 self.self_type_trait_bounds = self_type_trait_bounds;
459 self
460 }
461
462 pub(crate) fn padding_check<P: Into<Option<PaddingCheck>>>(mut self, padding_check: P) -> Self {
463 self.padding_check = padding_check.into();
464 self
465 }
466
467 pub(crate) fn param_extras(mut self, param_extras: Vec<GenericParam>) -> Self {
468 self.param_extras.extend(param_extras);
469 self
470 }
471
472 pub(crate) fn inner_extras(mut self, inner_extras: TokenStream) -> Self {
473 self.inner_extras = Some(inner_extras);
474 self
475 }
476
477 pub(crate) fn outer_extras<T: Into<Option<TokenStream>>>(mut self, outer_extras: T) -> Self {
478 self.outer_extras = outer_extras.into();
479 self
480 }
481
482 pub(crate) fn build(self) -> TokenStream {
483 let type_ident = &self.ctx.ast.ident;
543 let trait_path = self.trt.crate_path(self.ctx);
544 let fields = self.data.fields();
545 let variants = self.data.variants();
546 let tag = self.data.tag();
547 let zerocopy_crate = &self.ctx.zerocopy_crate;
548
549 fn bound_tt(ty: &Type, traits: impl Iterator<Item = Trait>, ctx: &Ctx) -> WherePredicate {
550 let traits = traits.map(|t| t.crate_path(ctx));
551 parse_quote!(#ty: #(#traits)+*)
552 }
553 let field_type_bounds: Vec<_> = match (self.field_type_trait_bounds, &fields[..]) {
554 (FieldBounds::All(traits), _) => fields
555 .iter()
556 .map(|(_vis, _name, ty)| {
557 bound_tt(ty, normalize_bounds(&self.trt, traits), self.ctx)
558 })
559 .collect(),
560 (FieldBounds::None, _) | (FieldBounds::Trailing(..), []) => vec![],
561 (FieldBounds::Trailing(traits), [.., last]) => {
562 vec![bound_tt(last.2, normalize_bounds(&self.trt, traits), self.ctx)]
563 }
564 (FieldBounds::Explicit(bounds), _) => bounds,
565 };
566
567 let padding_check_bound = self
568 .padding_check
569 .map(|check| {
570 let repr =
575 Repr::<PrimitiveRepr, NonZeroU32>::from_attrs(&self.ctx.ast.attrs).unwrap();
576 let core = self.ctx.core_path();
577 let option = quote! { #core::option::Option };
578 let nonzero = quote! { #core::num::NonZeroUsize };
579 let none = quote! { #option::None::<#nonzero> };
580 let repr_align =
581 repr.get_align().map(|spanned| {
582 let n = spanned.t.get();
583 quote_spanned! { spanned.span => (#nonzero::new(#n as usize)) }
584 }).unwrap_or(quote! { (#none) });
585 let repr_packed =
586 repr.get_packed().map(|packed| {
587 let n = packed.get();
588 quote! { (#nonzero::new(#n as usize)) }
589 }).unwrap_or(quote! { (#none) });
590 let variant_types = variants.iter().map(|(_, fields)| {
591 let types = fields.iter().map(|(_vis, _name, ty)| ty);
592 quote!([#((#types)),*])
593 });
594 let validator_context = check.validator_macro_context();
595 let (trt, validator_macro) = check.validator_trait_and_macro_idents();
596 let t = tag.iter();
597 parse_quote! {
598 (): #zerocopy_crate::util::macro_util::#trt<
599 Self,
600 {
601 #validator_context
602 #zerocopy_crate::#validator_macro!(Self, #repr_align, #repr_packed, #(#t,)* #(#variant_types),*)
603 }
604 >
605 }
606 });
607
608 let self_bounds: Option<WherePredicate> = match self.self_type_trait_bounds {
609 SelfBounds::None => None,
610 SelfBounds::All(traits) => {
611 Some(bound_tt(&parse_quote!(Self), traits.iter().cloned(), self.ctx))
612 }
613 };
614
615 let bounds = self
616 .ctx
617 .ast
618 .generics
619 .where_clause
620 .as_ref()
621 .map(|where_clause| where_clause.predicates.iter())
622 .into_iter()
623 .flatten()
624 .chain(field_type_bounds.iter())
625 .chain(padding_check_bound.iter())
626 .chain(self_bounds.iter());
627
628 let mut params: Vec<_> = self
630 .ctx
631 .ast
632 .generics
633 .params
634 .clone()
635 .into_iter()
636 .map(|mut param| {
637 match &mut param {
638 GenericParam::Type(ty) => ty.default = None,
639 GenericParam::Const(cnst) => cnst.default = None,
640 GenericParam::Lifetime(_) => {}
641 }
642 parse_quote!(#param)
643 })
644 .chain(self.param_extras)
645 .collect();
646
647 params.sort_by_cached_key(|param| match param {
650 GenericParam::Lifetime(_) => 0,
651 GenericParam::Type(_) => 1,
652 GenericParam::Const(_) => 2,
653 });
654
655 let param_idents = self.ctx.ast.generics.params.iter().map(|param| match param {
658 GenericParam::Type(ty) => {
659 let ident = &ty.ident;
660 quote!(#ident)
661 }
662 GenericParam::Lifetime(l) => {
663 let ident = &l.lifetime;
664 quote!(#ident)
665 }
666 GenericParam::Const(cnst) => {
667 let ident = &cnst.ident;
668 quote!({#ident})
669 }
670 });
671
672 let inner_extras = self.inner_extras;
673 let allow_trivial_bounds =
674 if self.ctx.skip_on_error { quote!(#[allow(trivial_bounds)]) } else { quote!() };
675 let impl_tokens = quote! {
676 #allow_trivial_bounds
677 unsafe impl < #(#params),* > #trait_path for #type_ident < #(#param_idents),* >
678 where
679 #(#bounds,)*
680 {
681 fn only_derive_is_allowed_to_implement_this_trait() {}
682
683 #inner_extras
684 }
685 };
686
687 let outer_extras = self.outer_extras.filter(|e| !e.is_empty());
688 let cfg_compile_error = self.ctx.cfg_compile_error();
689 const_block([Some(cfg_compile_error), Some(impl_tokens), outer_extras])
690 }
691}
692
/// Same contract as std's `bool::then_some` (stable since Rust 1.62).
/// NOTE(review): presumably a polyfill for older toolchains; `#[allow(unused)]`
/// suggests it may be unused on some configurations.
#[allow(unused)]
trait BoolExt {
    /// Returns `Some(t)` if `self` is `true`, otherwise `None`. `t` is always
    /// evaluated.
    fn then_some<T>(self, t: T) -> Option<T>;
}

impl BoolExt for bool {
    fn then_some<T>(self, t: T) -> Option<T> {
        match self {
            true => Some(t),
            false => None,
        }
    }
}
714
715pub(crate) fn const_block(items: impl IntoIterator<Item = Option<TokenStream>>) -> TokenStream {
716 let items = items.into_iter().flatten();
717 quote! {
718 #[allow(
719 deprecated,
722 private_bounds,
726 non_local_definitions,
727 non_camel_case_types,
728 non_upper_case_globals,
729 non_snake_case,
730 non_ascii_idents,
731 clippy::missing_inline_in_public_items,
732 )]
733 #[deny(ambiguous_associated_items)]
734 #[automatically_derived]
737 const _: () = {
738 #(#items)*
739 };
740 }
741}
742pub(crate) fn generate_tag_enum(ctx: &Ctx, repr: &EnumRepr, data: &DataEnum) -> TokenStream {
743 let zerocopy_crate = &ctx.zerocopy_crate;
744 let variants = data.variants.iter().map(|v| {
745 let ident = &v.ident;
746 if let Some((eq, discriminant)) = &v.discriminant {
747 quote! { #ident #eq #discriminant }
748 } else {
749 quote! { #ident }
750 }
751 });
752
753 let repr = match repr {
757 EnumRepr::Transparent(span) => quote::quote_spanned! { *span => #[repr(transparent)] },
758 EnumRepr::Compound(c, _) => quote! { #c },
759 };
760
761 quote! {
762 #repr
763 #[allow(dead_code)]
764 pub enum ___ZerocopyTag {
765 #(#variants,)*
766 }
767
768 unsafe impl #zerocopy_crate::Immutable for ___ZerocopyTag {
771 fn only_derive_is_allowed_to_implement_this_trait() {}
772 }
773 }
774}
775pub(crate) fn enum_size_from_repr(repr: &EnumRepr) -> Result<usize, Error> {
776 use CompoundRepr::*;
777 use PrimitiveRepr::*;
778 use Repr::*;
779 match repr {
780 Transparent(span)
781 | Compound(
782 Spanned {
783 t: C | Rust | Primitive(U32 | I32 | U64 | I64 | U128 | I128 | Usize | Isize),
784 span,
785 },
786 _,
787 ) => Err(Error::new(
788 *span,
789 "`FromBytes` only supported on enums with `#[repr(...)]` attributes `u8`, `i8`, `u16`, or `i16`",
790 )),
791 Compound(Spanned { t: Primitive(U8 | I8), span: _ }, _align) => Ok(8),
792 Compound(Spanned { t: Primitive(U16 | I16), span: _ }, _align) => Ok(16),
793 }
794}
795
#[cfg(test)]
pub(crate) mod testutil {
    use proc_macro2::TokenStream;
    use syn::visit::{self, Visit};

    /// Panics if `ts` contains a multi-segment path that starts with `Self`
    /// (e.g. `Self::Item`).
    ///
    /// Generated code must use fully qualified access such as
    /// `<Self as Trait>::Item`. Also panics if `ts` does not parse as a
    /// `syn::File`.
    pub(crate) fn check_hygiene(ts: TokenStream) {
        struct SelfPathFinder;

        impl<'ast> Visit<'ast> for SelfPathFinder {
            fn visit_path(&mut self, path: &'ast syn::Path) {
                let starts_with_self =
                    path.segments.first().map_or(false, |first| first.ident == "Self");
                if starts_with_self && path.segments.len() > 1 {
                    panic!(
                        "Found ambiguous path `{}` in generated output. \
                         All associated item access must be fully qualified (e.g., `<Self as Trait>::Item`) \
                         to prevent hygiene issues.",
                        quote::quote!(#path)
                    );
                }
                visit::visit_path(self, path);
            }
        }

        let file = syn::parse2::<syn::File>(ts).expect("failed to parse generated output as File");
        SelfPathFinder.visit_file(&file);
    }

    #[test]
    fn test_check_hygiene_success() {
        let qualified = quote::quote! {
            fn foo() {
                let _ = <Self as Trait>::Item;
            }
        };
        check_hygiene(qualified);
    }

    #[test]
    #[should_panic(expected = "Found ambiguous path `Self :: Ambiguous`")]
    fn test_check_hygiene_failure() {
        let unqualified = quote::quote! {
            fn foo() {
                let _ = Self::Ambiguous;
            }
        };
        check_hygiene(unqualified);
    }
}