Skip to main content

zerocopy_derive/derive/
from_bytes.rs

1// SPDX-License-Identifier: (BSD-2-Clause OR Apache-2.0) OR MIT
2
3use proc_macro2::{Span, TokenStream};
4use syn::{
5    parse_quote, Data, DataEnum, DataStruct, DataUnion, Error, Expr, ExprLit, ExprUnary, Lit, UnOp,
6    WherePredicate,
7};
8
9use crate::{
10    derive::try_from_bytes::derive_try_from_bytes,
11    repr::{CompoundRepr, EnumRepr, Repr, Spanned},
12    util::{enum_size_from_repr, Ctx, FieldBounds, ImplBlockBuilder, Trait, TraitBound},
13};
14/// Returns `Ok(index)` if variant `index` of the enum has a discriminant of
15/// zero. If `Err(bool)` is returned, the boolean is true if the enum has
16/// unknown discriminants (e.g. discriminants set to const expressions which we
17/// can't evaluate in a proc macro). If the enum has unknown discriminants, then
18/// it might have a zero variant that we just can't detect.
19pub(crate) fn find_zero_variant(enm: &DataEnum) -> Result<usize, bool> {
20    // Discriminants can be anywhere in the range [i128::MIN, u128::MAX] because
21    // the discriminant type may be signed or unsigned. Since we only care about
22    // tracking the discriminant when it's less than or equal to zero, we can
23    // avoid u128 -> i128 conversions and bounds checking by making the "next
24    // discriminant" value implicitly negative.
25    // Technically 64 bits is enough, but 128 is better for future compatibility
26    // with https://github.com/rust-lang/rust/issues/56071
27    let mut next_negative_discriminant = Some(0);
28
29    // Sometimes we encounter explicit discriminants that we can't know the
30    // value of (e.g. a constant expression that requires evaluation). These
31    // could evaluate to zero or a negative number, but we can't assume that
32    // they do (no false positives allowed!). So we treat them like strictly-
33    // positive values that can't result in any zero variants, and track whether
34    // we've encountered any unknown discriminants.
35    let mut has_unknown_discriminants = false;
36
37    for (i, v) in enm.variants.iter().enumerate() {
38        match v.discriminant.as_ref() {
39            // Implicit discriminant
40            None => {
41                match next_negative_discriminant.as_mut() {
42                    Some(0) => return Ok(i),
43                    // n is nonzero so subtraction is always safe
44                    Some(n) => *n -= 1,
45                    None => (),
46                }
47            }
48            // Explicit positive discriminant
49            Some((_, Expr::Lit(ExprLit { lit: Lit::Int(int), .. }))) => {
50                match int.base10_parse::<u128>().ok() {
51                    Some(0) => return Ok(i),
52                    Some(_) => next_negative_discriminant = None,
53                    None => {
54                        // Numbers should never fail to parse, but just in case:
55                        has_unknown_discriminants = true;
56                        next_negative_discriminant = None;
57                    }
58                }
59            }
60            // Explicit negative discriminant
61            Some((_, Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }))) => match &**expr {
62                Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => {
63                    match int.base10_parse::<u128>().ok() {
64                        Some(0) => return Ok(i),
65                        // x is nonzero so subtraction is always safe
66                        Some(x) => next_negative_discriminant = Some(x - 1),
67                        None => {
68                            // Numbers should never fail to parse, but just in
69                            // case:
70                            has_unknown_discriminants = true;
71                            next_negative_discriminant = None;
72                        }
73                    }
74                }
75                // Unknown negative discriminant (e.g. const repr)
76                _ => {
77                    has_unknown_discriminants = true;
78                    next_negative_discriminant = None;
79                }
80            },
81            // Unknown discriminant (e.g. const expr)
82            _ => {
83                has_unknown_discriminants = true;
84                next_negative_discriminant = None;
85            }
86        }
87    }
88
89    Err(has_unknown_discriminants)
90}
91pub(crate) fn derive_from_zeros(ctx: &Ctx, top_level: Trait) -> Result<TokenStream, Error> {
92    let try_from_bytes = derive_try_from_bytes(ctx, top_level)?;
93    let from_zeros = match &ctx.ast.data {
94        Data::Struct(strct) => derive_from_zeros_struct(ctx, strct),
95        Data::Enum(enm) => derive_from_zeros_enum(ctx, enm)?,
96        Data::Union(unn) => derive_from_zeros_union(ctx, unn),
97    };
98    Ok(IntoIterator::into_iter([try_from_bytes, from_zeros]).collect())
99}
100pub(crate) fn derive_from_bytes(ctx: &Ctx, top_level: Trait) -> Result<TokenStream, Error> {
101    let from_zeros = derive_from_zeros(ctx, top_level)?;
102    let from_bytes = match &ctx.ast.data {
103        Data::Struct(strct) => derive_from_bytes_struct(ctx, strct),
104        Data::Enum(enm) => derive_from_bytes_enum(ctx, enm)?,
105        Data::Union(unn) => derive_from_bytes_union(ctx, unn),
106    };
107
108    Ok(IntoIterator::into_iter([from_zeros, from_bytes]).collect())
109}
110fn derive_from_zeros_struct(ctx: &Ctx, strct: &DataStruct) -> TokenStream {
111    ImplBlockBuilder::new(ctx, strct, Trait::FromZeros, FieldBounds::ALL_SELF).build()
112}
113fn derive_from_zeros_enum(ctx: &Ctx, enm: &DataEnum) -> Result<TokenStream, Error> {
114    let repr = EnumRepr::from_attrs(&ctx.ast.attrs)?;
115
116    // We don't actually care what the repr is; we just care that it's one of
117    // the allowed ones.
118    match repr {
119        Repr::Compound(Spanned { t: CompoundRepr::C | CompoundRepr::Primitive(_), span: _ }, _) => {
120        }
121        Repr::Transparent(_) | Repr::Compound(Spanned { t: CompoundRepr::Rust, span: _ }, _) => {
122            return ctx.error_or_skip(
123                Error::new(
124                    Span::call_site(),
125                    "must have #[repr(C)] or #[repr(Int)] attribute in order to guarantee this type's memory layout",
126                ),
127            );
128        }
129    }
130
131    let zero_variant = match find_zero_variant(enm) {
132        Ok(index) => enm.variants.iter().nth(index).unwrap(),
133        // Has unknown variants
134        Err(true) => {
135            return ctx.error_or_skip(Error::new_spanned(
136                &ctx.ast,
137                "FromZeros only supported on enums with a variant that has a discriminant of `0`\n\
138                help: This enum has discriminants which are not literal integers. One of those may \
139                define or imply which variant has a discriminant of zero. Use a literal integer to \
140                define or imply the variant with a discriminant of zero.",
141            ));
142        }
143        // Does not have unknown variants
144        Err(false) => {
145            return ctx.error_or_skip(Error::new_spanned(
146                &ctx.ast,
147                "FromZeros only supported on enums with a variant that has a discriminant of `0`",
148            ));
149        }
150    };
151
152    let zerocopy_crate = &ctx.zerocopy_crate;
153    let explicit_bounds = zero_variant
154        .fields
155        .iter()
156        .map(|field| {
157            let ty = &field.ty;
158            parse_quote! { #ty: #zerocopy_crate::FromZeros }
159        })
160        .collect::<Vec<WherePredicate>>();
161
162    Ok(ImplBlockBuilder::new(ctx, enm, Trait::FromZeros, FieldBounds::Explicit(explicit_bounds))
163        .build())
164}
165fn derive_from_zeros_union(ctx: &Ctx, unn: &DataUnion) -> TokenStream {
166    let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf]);
167    ImplBlockBuilder::new(ctx, unn, Trait::FromZeros, field_type_trait_bounds).build()
168}
169fn derive_from_bytes_struct(ctx: &Ctx, strct: &DataStruct) -> TokenStream {
170    ImplBlockBuilder::new(ctx, strct, Trait::FromBytes, FieldBounds::ALL_SELF).build()
171}
172fn derive_from_bytes_enum(ctx: &Ctx, enm: &DataEnum) -> Result<TokenStream, Error> {
173    let repr = EnumRepr::from_attrs(&ctx.ast.attrs)?;
174
175    let variants_required = 1usize << enum_size_from_repr(&repr)?;
176    if enm.variants.len() != variants_required {
177        return ctx.error_or_skip(Error::new_spanned(
178            &ctx.ast,
179            format!(
180                "FromBytes only supported on {} enum with {} variants",
181                repr.repr_type_name(),
182                variants_required
183            ),
184        ));
185    }
186
187    Ok(ImplBlockBuilder::new(ctx, enm, Trait::FromBytes, FieldBounds::ALL_SELF).build())
188}
189fn derive_from_bytes_union(ctx: &Ctx, unn: &DataUnion) -> TokenStream {
190    let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf]);
191    ImplBlockBuilder::new(ctx, unn, Trait::FromBytes, field_type_trait_bounds).build()
192}