macros/
kunit.rs

// SPDX-License-Identifier: GPL-2.0

//! Procedural macro to run KUnit tests using a user-space-like syntax.
//!
//! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com>

7use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
8use std::collections::HashMap;
9use std::fmt::Write;
10
11pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
12    let attr = attr.to_string();
13
14    if attr.is_empty() {
15        panic!("Missing test name in `#[kunit_tests(test_name)]` macro")
16    }
17
18    if attr.len() > 255 {
19        panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes")
20    }
21
22    let mut tokens: Vec<_> = ts.into_iter().collect();
23
24    // Scan for the `mod` keyword.
25    tokens
26        .iter()
27        .find_map(|token| match token {
28            TokenTree::Ident(ident) => match ident.to_string().as_str() {
29                "mod" => Some(true),
30                _ => None,
31            },
32            _ => None,
33        })
34        .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules");
35
36    // Retrieve the main body. The main body should be the last token tree.
37    let body = match tokens.pop() {
38        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
39        _ => panic!("Cannot locate main body of module"),
40    };
41
42    // Get the functions set as tests. Search for `[test]` -> `fn`.
43    let mut body_it = body.stream().into_iter();
44    let mut tests = Vec::new();
45    let mut attributes: HashMap<String, TokenStream> = HashMap::new();
46    while let Some(token) = body_it.next() {
47        match token {
48            TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() {
49                Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {
50                    if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() {
51                        // Collect attributes because we need to find which are tests. We also
52                        // need to copy `cfg` attributes so tests can be conditionally enabled.
53                        attributes
54                            .entry(name.to_string())
55                            .or_default()
56                            .extend([token, TokenTree::Group(g)]);
57                    }
58                    continue;
59                }
60                _ => (),
61            },
62            TokenTree::Ident(i) if i.to_string() == "fn" && attributes.contains_key("test") => {
63                if let Some(TokenTree::Ident(test_name)) = body_it.next() {
64                    tests.push((test_name, attributes.remove("cfg").unwrap_or_default()))
65                }
66            }
67
68            _ => (),
69        }
70        attributes.clear();
71    }
72
73    // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration.
74    let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap();
75    tokens.insert(
76        0,
77        TokenTree::Group(Group::new(Delimiter::None, config_kunit)),
78    );
79
80    // Generate the test KUnit test suite and a test case for each `#[test]`.
81    // The code generated for the following test module:
82    //
83    // ```
84    // #[kunit_tests(kunit_test_suit_name)]
85    // mod tests {
86    //     #[test]
87    //     fn foo() {
88    //         assert_eq!(1, 1);
89    //     }
90    //
91    //     #[test]
92    //     fn bar() {
93    //         assert_eq!(2, 2);
94    //     }
95    // }
96    // ```
97    //
98    // Looks like:
99    //
100    // ```
101    // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut ::kernel::bindings::kunit) { foo(); }
102    // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) { bar(); }
103    //
104    // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [
105    //     ::kernel::kunit::kunit_case(::kernel::c_str!("foo"), kunit_rust_wrapper_foo),
106    //     ::kernel::kunit::kunit_case(::kernel::c_str!("bar"), kunit_rust_wrapper_bar),
107    //     ::kernel::kunit::kunit_case_null(),
108    // ];
109    //
110    // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES);
111    // ```
112    let mut kunit_macros = "".to_owned();
113    let mut test_cases = "".to_owned();
114    let mut assert_macros = "".to_owned();
115    let path = crate::helpers::file();
116    let num_tests = tests.len();
117    for (test, cfg_attr) in tests {
118        let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}");
119        // Append any `cfg` attributes the user might have written on their tests so we don't
120        // attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce
121        // the length of the assert message.
122        let kunit_wrapper = format!(
123            r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit)
124            {{
125                (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
126                {cfg_attr} {{
127                    (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
128                    use ::kernel::kunit::is_test_result_ok;
129                    assert!(is_test_result_ok({test}()));
130                }}
131            }}"#,
132        );
133        writeln!(kunit_macros, "{kunit_wrapper}").unwrap();
134        writeln!(
135            test_cases,
136            "    ::kernel::kunit::kunit_case(::kernel::c_str!(\"{test}\"), {kunit_wrapper_fn_name}),"
137        )
138        .unwrap();
139        writeln!(
140            assert_macros,
141            r#"
142/// Overrides the usual [`assert!`] macro with one that calls KUnit instead.
143#[allow(unused)]
144macro_rules! assert {{
145    ($cond:expr $(,)?) => {{{{
146        kernel::kunit_assert!("{test}", "{path}", 0, $cond);
147    }}}}
148}}
149
150/// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead.
151#[allow(unused)]
152macro_rules! assert_eq {{
153    ($left:expr, $right:expr $(,)?) => {{{{
154        kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right);
155    }}}}
156}}
157        "#
158        )
159        .unwrap();
160    }
161
162    writeln!(kunit_macros).unwrap();
163    writeln!(
164        kunit_macros,
165        "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases}    ::kernel::kunit::kunit_case_null(),\n];",
166        num_tests + 1
167    )
168    .unwrap();
169
170    writeln!(
171        kunit_macros,
172        "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);"
173    )
174    .unwrap();
175
176    // Remove the `#[test]` macros.
177    // We do this at a token level, in order to preserve span information.
178    let mut new_body = vec![];
179    let mut body_it = body.stream().into_iter();
180
181    while let Some(token) = body_it.next() {
182        match token {
183            TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() {
184                Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (),
185                Some(next) => {
186                    new_body.extend([token, next]);
187                }
188                _ => {
189                    new_body.push(token);
190                }
191            },
192            _ => {
193                new_body.push(token);
194            }
195        }
196    }
197
198    let mut final_body = TokenStream::new();
199    final_body.extend::<TokenStream>(assert_macros.parse().unwrap());
200    final_body.extend(new_body);
201    final_body.extend::<TokenStream>(kunit_macros.parse().unwrap());
202
203    tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body)));
204
205    tokens.into_iter().collect()
206}