diff --git a/Cargo.toml b/Cargo.toml index 87e702e..986f229 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,9 +3,9 @@ resolver = "2" members = ["test-context", "test-context-macros"] [workspace.package] -edition = "2021" +edition = "2024" version = "0.5.0" -rust-version = "1.75.0" +rust-version = "1.91.0" homepage = "https://github.com/JasterV/test-context" repository = "https://github.com/JasterV/test-context" authors = [ diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 5965521..2767a11 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.90" +channel = "1.91" components = ["rustfmt", "clippy", "rust-src", "rust-analyzer"] diff --git a/test-context-macros/src/args.rs b/test-context-macros/src/args.rs index 7a06b63..013dff2 100644 --- a/test-context-macros/src/args.rs +++ b/test-context-macros/src/args.rs @@ -1,4 +1,4 @@ -use syn::{parse::Parse, Token, Type}; +use syn::{Token, Type, parse::Parse}; pub(crate) struct TestContextArgs { pub(crate) context_type: Type, diff --git a/test-context-macros/src/lib.rs b/test-context-macros/src/lib.rs index 5f504c0..7713d5e 100644 --- a/test-context-macros/src/lib.rs +++ b/test-context-macros/src/lib.rs @@ -3,7 +3,7 @@ mod args; use args::TestContextArgs; use proc_macro::TokenStream; use quote::{format_ident, quote}; -use syn::{Block, Ident}; +use syn::Ident; /// Macro to use on tests to add the setup/teardown functionality of your context. /// @@ -28,111 +28,82 @@ use syn::{Block, Ident}; #[proc_macro_attribute] pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream { let args = syn::parse_macro_input!(attr as TestContextArgs); - let input = syn::parse_macro_input!(item as syn::ItemFn); + + let (input, context_arg_name) = remove_context_arg(input, args.context_type.clone()); + let input = refactor_input_body(input, &args, context_arg_name); + + quote! { #input }.into() +} + +fn refactor_input_body( + mut input: syn::ItemFn, + args: &TestContextArgs, + context_arg_name: Option, +) -> syn::ItemFn { + let context_type = &args.context_type; + let context_arg_name = context_arg_name.unwrap_or_else(|| format_ident!("test_ctx")); + let result_name = format_ident!("wrapped_result"); + let body = &input.block; let is_async = input.sig.asyncness.is_some(); - let (new_input, context_arg_name) = - extract_and_remove_context_arg(input.clone(), args.context_type.clone()); - - let wrapper_body = if is_async { - async_wrapper_body(args, &context_arg_name, &input.block) - } else { - sync_wrapper_body(args, &context_arg_name, &input.block) - }; - - let mut result_input = new_input; - result_input.block = Box::new(syn::parse2(wrapper_body).unwrap()); - - quote! { #result_input }.into() -} - -fn async_wrapper_body( - args: TestContextArgs, - context_arg_name: &Option, - body: &Block, -) -> proc_macro2::TokenStream { - let context_type = args.context_type; - let result_name = format_ident!("wrapped_result"); - - let binding = format_ident!("test_ctx"); - let context_name = context_arg_name.as_ref().unwrap_or(&binding); - - let body = if args.skip_teardown { - quote! { - let #context_name = <#context_type as test_context::AsyncTestContext>::setup().await; - let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await; - } - } else { - quote! { - let mut #context_name = <#context_type as test_context::AsyncTestContext>::setup().await; - let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await; - <#context_type as test_context::AsyncTestContext>::teardown(#context_name).await; - } - }; - - let handle_wrapped_result = handle_result(result_name); - - quote! { - { - use test_context::futures::FutureExt; - #body - #handle_wrapped_result - } - } -} - -fn sync_wrapper_body( - args: TestContextArgs, - context_arg_name: &Option, - body: &Block, -) -> proc_macro2::TokenStream { - let context_type = args.context_type; - let result_name = format_ident!("wrapped_result"); - - let binding = format_ident!("test_ctx"); - let context_name = context_arg_name.as_ref().unwrap_or(&binding); - - let body = if args.skip_teardown { - quote! { - let mut #context_name = <#context_type as test_context::TestContext>::setup(); - let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - let #context_name = &mut #context_name; - #body - })); - } - } else { - quote! { - let mut #context_name = <#context_type as test_context::TestContext>::setup(); - let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - #body - })); - <#context_type as test_context::TestContext>::teardown(#context_name); - } - }; - - let handle_wrapped_result = handle_result(result_name); - - quote! { - { - #body - #handle_wrapped_result - } - } -} - -fn handle_result(result_name: Ident) -> proc_macro2::TokenStream { - quote! { - match #result_name { - Ok(value) => value, - Err(err) => { - std::panic::resume_unwind(err); + let body = match (is_async, args.skip_teardown) { + (true, true) => { + quote! { + use test_context::futures::FutureExt; + let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await; + let #context_arg_name = &mut __context; + let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await; } } - } + (true, false) => { + quote! { + use test_context::futures::FutureExt; + let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await; + let #context_arg_name = &mut __context; + let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await; + <#context_type as test_context::AsyncTestContext>::teardown(__context).await; + } + } + (false, true) => { + quote! { + let mut __context = <#context_type as test_context::TestContext>::setup(); + let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + let #context_arg_name = &mut __context; + #body + })); + } + } + (false, false) => { + quote! { + let mut __context = <#context_type as test_context::TestContext>::setup(); + let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + let #context_arg_name = &mut __context; + #body + })); + <#context_type as test_context::TestContext>::teardown(__context); + } + } + }; + + let body = quote! { + { + #body + match #result_name { + Ok(value) => value, + Err(err) => { + std::panic::resume_unwind(err); + } + } + } + }; + + input.block = Box::new(syn::parse2(body).unwrap()); + + input } -fn extract_and_remove_context_arg( +fn remove_context_arg( mut input: syn::ItemFn, expected_context_type: syn::Type, ) -> (syn::ItemFn, Option) { @@ -154,10 +125,12 @@ fn extract_and_remove_context_arg( } } } + new_args.push(arg.clone()); } input.sig.inputs = new_args; + (input, context_arg_name) } diff --git a/test-context/tests/test.rs b/test-context/tests/test.rs index b894025..86fcf42 100644 --- a/test-context/tests/test.rs +++ b/test-context/tests/test.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use rstest::rstest; -use test_context::{test_context, AsyncTestContext, TestContext}; +use test_context::{AsyncTestContext, TestContext, test_context}; struct Context { n: u32,