1use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
8use std::collections::HashMap;
9use std::fmt::Write;
10
11pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
12 let attr = attr.to_string();
13
14 if attr.is_empty() {
15 panic!("Missing test name in `#[kunit_tests(test_name)]` macro")
16 }
17
18 if attr.len() > 255 {
19 panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes")
20 }
21
22 let mut tokens: Vec<_> = ts.into_iter().collect();
23
24 tokens
26 .iter()
27 .find_map(|token| match token {
28 TokenTree::Ident(ident) => match ident.to_string().as_str() {
29 "mod" => Some(true),
30 _ => None,
31 },
32 _ => None,
33 })
34 .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules");
35
36 let body = match tokens.pop() {
38 Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
39 _ => panic!("Cannot locate main body of module"),
40 };
41
42 let mut body_it = body.stream().into_iter();
44 let mut tests = Vec::new();
45 let mut attributes: HashMap<String, TokenStream> = HashMap::new();
46 while let Some(token) = body_it.next() {
47 match token {
48 TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() {
49 Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {
50 if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() {
51 attributes
54 .entry(name.to_string())
55 .or_default()
56 .extend([token, TokenTree::Group(g)]);
57 }
58 continue;
59 }
60 _ => (),
61 },
62 TokenTree::Ident(i) if i.to_string() == "fn" && attributes.contains_key("test") => {
63 if let Some(TokenTree::Ident(test_name)) = body_it.next() {
64 tests.push((test_name, attributes.remove("cfg").unwrap_or_default()))
65 }
66 }
67
68 _ => (),
69 }
70 attributes.clear();
71 }
72
73 let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap();
75 tokens.insert(
76 0,
77 TokenTree::Group(Group::new(Delimiter::None, config_kunit)),
78 );
79
80 let mut kunit_macros = "".to_owned();
113 let mut test_cases = "".to_owned();
114 let mut assert_macros = "".to_owned();
115 let path = crate::helpers::file();
116 let num_tests = tests.len();
117 for (test, cfg_attr) in tests {
118 let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}");
119 let kunit_wrapper = format!(
123 r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit)
124 {{
125 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
126 {cfg_attr} {{
127 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
128 use ::kernel::kunit::is_test_result_ok;
129 assert!(is_test_result_ok({test}()));
130 }}
131 }}"#,
132 );
133 writeln!(kunit_macros, "{kunit_wrapper}").unwrap();
134 writeln!(
135 test_cases,
136 " ::kernel::kunit::kunit_case(::kernel::c_str!(\"{test}\"), {kunit_wrapper_fn_name}),"
137 )
138 .unwrap();
139 writeln!(
140 assert_macros,
141 r#"
142/// Overrides the usual [`assert!`] macro with one that calls KUnit instead.
143#[allow(unused)]
144macro_rules! assert {{
145 ($cond:expr $(,)?) => {{{{
146 kernel::kunit_assert!("{test}", "{path}", 0, $cond);
147 }}}}
148}}
149
150/// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead.
151#[allow(unused)]
152macro_rules! assert_eq {{
153 ($left:expr, $right:expr $(,)?) => {{{{
154 kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right);
155 }}}}
156}}
157 "#
158 )
159 .unwrap();
160 }
161
162 writeln!(kunit_macros).unwrap();
163 writeln!(
164 kunit_macros,
165 "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases} ::kernel::kunit::kunit_case_null(),\n];",
166 num_tests + 1
167 )
168 .unwrap();
169
170 writeln!(
171 kunit_macros,
172 "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);"
173 )
174 .unwrap();
175
176 let mut new_body = vec![];
179 let mut body_it = body.stream().into_iter();
180
181 while let Some(token) = body_it.next() {
182 match token {
183 TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() {
184 Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (),
185 Some(next) => {
186 new_body.extend([token, next]);
187 }
188 _ => {
189 new_body.push(token);
190 }
191 },
192 _ => {
193 new_body.push(token);
194 }
195 }
196 }
197
198 let mut final_body = TokenStream::new();
199 final_body.extend::<TokenStream>(assert_macros.parse().unwrap());
200 final_body.extend(new_body);
201 final_body.extend::<TokenStream>(kunit_macros.parse().unwrap());
202
203 tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body)));
204
205 tokens.into_iter().collect()
206}