From 39756352a86c52d2e93efa24c6f534c317630f9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Mart=C3=ADnez?= <49537445+JasterV@users.noreply.github.com> Date: Thu, 6 Nov 2025 13:52:20 +0100 Subject: [PATCH] feat: make it so immutable references & full ownership can be taken depending on context (#58) This pull requests introduces changes that make the use of the `test_context` macro more flexible. Now, if the teardown is not skipped (default behavior), either an `immutable` or a `mutable` reference can be used for the context. If the teardown is skipped with the `skip_teardown` option, an `immutable`, a `mutable` reference or full ownership can be taken. So now the following is possible: ```rust #[test_context(TeardownPanicContext, skip_teardown)] #[tokio::test] async fn test_async_skip_teardown(_ctx: &mut TeardownPanicContext) {} #[test_context(TeardownPanicContext, skip_teardown)] #[tokio::test] async fn test_async_skip_teardown_with_immutable_ref(_ctx: &TeardownPanicContext) {} #[test_context(TeardownPanicContext, skip_teardown)] #[tokio::test] async fn test_async_skip_teardown_with_full_ownership(_ctx: TeardownPanicContext) {} #[test_context(TeardownPanicContext, skip_teardown)] #[test] fn test_sync_skip_teardown(_ctx: &mut TeardownPanicContext) {} #[test_context(TeardownPanicContext, skip_teardown)] #[test] fn test_sync_skip_teardown_with_immutable_ref(_ctx: &TeardownPanicContext) {} #[test_context(TeardownPanicContext, skip_teardown)] #[test] fn test_sync_skip_teardown_with_full_ownership(_ctx: TeardownPanicContext) {} ``` --- README.md | 50 ++++- release-plz.toml | 1 + test-context-macros/src/lib.rs | 177 ++++++++++-------- .../src/{args.rs => macro_args.rs} | 0 test-context-macros/src/test_args.rs | 69 +++++++ test-context/tests/test.rs | 34 +++- 6 files changed, 237 insertions(+), 94 deletions(-) rename test-context-macros/src/{args.rs => macro_args.rs} (100%) create mode 100644 test-context-macros/src/test_args.rs diff --git a/README.md b/README.md index 2635e00..2f6334e 100644 --- a/README.md +++ b/README.md @@ -127,9 +127,8 @@ tests annotated with `#[tokio::test]` continue to work as usual without the feat ## Skipping the teardown execution -Also, if you don't care about the -teardown execution for a specific test, you can use the `skip_teardown` keyword on the macro -like this: +Also, if you don't care about the teardown execution for a specific test, +you can use the `skip_teardown` keyword on the macro like this: ```rust use test_context::{test_context, TestContext}; @@ -144,11 +143,50 @@ like this: #[test_context(MyContext, skip_teardown)] #[test] -fn test_without_teardown(ctx: &mut MyContext) { +fn test_without_teardown(ctx: &MyContext) {} +``` + +## Taking ownership of the context vs taking a reference + +If the teardown is ON (default behavior), you can only take a reference to the context, either mutable or immutable, as follows: + +```rust +#[test_context(MyContext)] +#[test] +fn test_with_teardown_using_immutable_ref(ctx: &MyContext) {} + +#[test_context(MyContext)] +#[test] +fn test_with_teardown_using_mutable_ref(ctx: &mut MyContext) {} +``` + +❌The following is invalid: + +```rust +#[test_context(MyContext)] +#[test] +fn test_with_teardown_taking_ownership(ctx: MyContext) {} +``` + +If the teardown is skipped (as specified in the section above), you can take an immutable ref, mutable ref or full ownership of the context: + +```rust +#[test_context(MyContext, skip_teardown)] +#[test] +fn test_without_teardown(ctx: MyContext) { // Perform any operations that require full ownership of your context } + +#[test_context(MyContext, skip_teardown)] +#[test] +fn test_without_teardown_taking_a_ref(ctx: &MyContext) {} + +#[test_context(MyContext, skip_teardown)] +#[test] +fn test_without_teardown_taking_a_mut_ref(ctx: &mut MyContext) {} ``` + ## ⚠️ Ensure that the context type specified in the macro matches the test function argument type exactly The error occurs when a context type with an absolute path is mixed with an it's alias. @@ -161,8 +199,8 @@ mod database { pub struct Connection; - impl TestContext for :Connection { - fn setup() -> Self {Connection} + impl TestContext for Connection { + fn setup() -> Self { Connection } fn teardown(self) {...} } } diff --git a/release-plz.toml b/release-plz.toml index 0b7b3ac..23cda84 100644 --- a/release-plz.toml +++ b/release-plz.toml @@ -3,4 +3,5 @@ pr_branch_prefix = "release-" pr_labels = ["release"] git_tag_enable = true git_tag_name = "v{{ version }}" +git_release_name = "v{{ version }}" pr_draft = true diff --git a/test-context-macros/src/lib.rs b/test-context-macros/src/lib.rs index 7713d5e..af7185e 100644 --- a/test-context-macros/src/lib.rs +++ b/test-context-macros/src/lib.rs @@ -1,9 +1,11 @@ -mod args; +mod macro_args; +mod test_args; -use args::TestContextArgs; +use crate::test_args::{ContextArg, ContextArgMode, TestArg}; +use macro_args::TestContextArgs; use proc_macro::TokenStream; use quote::{format_ident, quote}; -use syn::Ident; +use syn::ItemFn; /// Macro to use on tests to add the setup/teardown functionality of your context. /// @@ -30,59 +32,108 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream { let args = syn::parse_macro_input!(attr as TestContextArgs); let input = syn::parse_macro_input!(item as syn::ItemFn); - let (input, context_arg_name) = remove_context_arg(input, args.context_type.clone()); - let input = refactor_input_body(input, &args, context_arg_name); + let (input, context_args) = remove_context_args(input, args.context_type.clone()); + + if context_args.len() != 1 { + panic!("Exactly one Context argument must be defined"); + } + + let context_arg = context_args.into_iter().next().unwrap(); + + if !args.skip_teardown && context_arg.mode == ContextArgMode::Owned { + panic!( + "It is not possible to take ownership of the context if the teardown has to be ran." + ); + } + + let input = refactor_input_body(input, &args, context_arg); quote! { #input }.into() } -fn refactor_input_body( +fn remove_context_args( mut input: syn::ItemFn, + expected_context_type: syn::Type, +) -> (syn::ItemFn, Vec) { + let test_args: Vec = input + .sig + .inputs + .into_iter() + .map(|arg| TestArg::parse_arg_with_expected_context(arg, &expected_context_type)) + .collect(); + + let context_args: Vec = test_args + .iter() + .cloned() + .filter_map(|arg| match arg { + TestArg::Any(_) => None, + TestArg::Context(context_arg_info) => Some(context_arg_info), + }) + .collect(); + + let new_args: syn::punctuated::Punctuated<_, _> = test_args + .into_iter() + .filter_map(|arg| match arg { + TestArg::Any(fn_arg) => Some(fn_arg), + TestArg::Context(_) => None, + }) + .collect(); + + input.sig.inputs = new_args; + + (input, context_args) +} + +fn refactor_input_body( + input: syn::ItemFn, args: &TestContextArgs, - context_arg_name: Option, + context_arg: ContextArg, ) -> syn::ItemFn { let context_type = &args.context_type; - let context_arg_name = context_arg_name.unwrap_or_else(|| format_ident!("test_ctx")); let result_name = format_ident!("wrapped_result"); let body = &input.block; let is_async = input.sig.asyncness.is_some(); + let context_arg_name = context_arg.name; - let body = match (is_async, args.skip_teardown) { - (true, true) => { - quote! { - use test_context::futures::FutureExt; - let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await; - let #context_arg_name = &mut __context; - let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await; - } + let context_binding = match context_arg.mode { + ContextArgMode::Owned => quote! { let #context_arg_name = __context; }, + ContextArgMode::Reference => quote! { let #context_arg_name = &__context; }, + ContextArgMode::MutableReference => quote! { let #context_arg_name = &mut __context; }, + }; + + let body = if args.skip_teardown && is_async { + quote! { + use test_context::futures::FutureExt; + let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await; + #context_binding + let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await; } - (true, false) => { - quote! { - use test_context::futures::FutureExt; - let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await; - let #context_arg_name = &mut __context; - let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await; - <#context_type as test_context::AsyncTestContext>::teardown(__context).await; - } + } else if args.skip_teardown && !is_async { + quote! { + let mut __context = <#context_type as test_context::TestContext>::setup(); + let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + #context_binding + #body + })); } - (false, true) => { - quote! { - let mut __context = <#context_type as test_context::TestContext>::setup(); - let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - let #context_arg_name = &mut __context; - #body - })); - } + } else if !args.skip_teardown && is_async { + quote! { + use test_context::futures::FutureExt; + let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await; + #context_binding + let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await; + <#context_type as test_context::AsyncTestContext>::teardown(__context).await; } - (false, false) => { - quote! { - let mut __context = <#context_type as test_context::TestContext>::setup(); - let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - let #context_arg_name = &mut __context; - #body - })); - <#context_type as test_context::TestContext>::teardown(__context); - } + } + // !args.skip_teardown && !is_async + else { + quote! { + let mut __context = <#context_type as test_context::TestContext>::setup(); + let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + #context_binding + #body + })); + <#context_type as test_context::TestContext>::teardown(__context); } }; @@ -98,46 +149,8 @@ fn refactor_input_body( } }; - input.block = Box::new(syn::parse2(body).unwrap()); - - input -} - -fn remove_context_arg( - mut input: syn::ItemFn, - expected_context_type: syn::Type, -) -> (syn::ItemFn, Option) { - let mut context_arg_name = None; - let mut new_args = syn::punctuated::Punctuated::new(); - - for arg in &input.sig.inputs { - // Extract function arg: - if let syn::FnArg::Typed(pat_type) = arg { - // Extract arg identifier: - if let syn::Pat::Ident(pat_ident) = &*pat_type.pat { - // Check that context arg is only ref or mutable ref: - if let syn::Type::Reference(type_ref) = &*pat_type.ty { - // Check that context has expected type: - if types_equal(&type_ref.elem, &expected_context_type) { - context_arg_name = Some(pat_ident.ident.clone()); - continue; - } - } - } - } - - new_args.push(arg.clone()); + ItemFn { + block: Box::new(syn::parse2(body).unwrap()), + ..input } - - input.sig.inputs = new_args; - - (input, context_arg_name) -} - -fn types_equal(a: &syn::Type, b: &syn::Type) -> bool { - if let (syn::Type::Path(a_path), syn::Type::Path(b_path)) = (a, b) { - return a_path.path.segments.last().unwrap().ident - == b_path.path.segments.last().unwrap().ident; - } - quote!(#a).to_string() == quote!(#b).to_string() } diff --git a/test-context-macros/src/args.rs b/test-context-macros/src/macro_args.rs similarity index 100% rename from test-context-macros/src/args.rs rename to test-context-macros/src/macro_args.rs diff --git a/test-context-macros/src/test_args.rs b/test-context-macros/src/test_args.rs new file mode 100644 index 0000000..94069e4 --- /dev/null +++ b/test-context-macros/src/test_args.rs @@ -0,0 +1,69 @@ +use quote::quote; +use syn::FnArg; + +#[derive(Clone)] +pub struct ContextArg { + /// The identifier name used for the context argument. + pub name: syn::Ident, + /// The mode in which the context was passed to the test function. + pub mode: ContextArgMode, +} + +#[derive(PartialEq, Eq, Debug, Clone, Copy)] +pub enum ContextArgMode { + /// The argument was passed as an owned value (`ContextType`). Only valid with `skip_teardown`. + Owned, + /// The argument was passed as an immutable reference (`&ContextType`). + Reference, + /// The argument was passed as a mutable reference (`&mut ContextType`). + MutableReference, +} + +#[derive(Clone)] +pub enum TestArg { + Any(FnArg), + Context(ContextArg), +} + +impl TestArg { + pub fn parse_arg_with_expected_context(arg: FnArg, expected_context_type: &syn::Type) -> Self { + // Check if the argument is the context argument + if let syn::FnArg::Typed(pat_type) = &arg + && let syn::Pat::Ident(pat_ident) = &*pat_type.pat + { + let arg_type = &*pat_type.ty; + // Check for mutable/immutable reference + if let syn::Type::Reference(type_ref) = arg_type + && types_equal(&type_ref.elem, expected_context_type) + { + let mode = if type_ref.mutability.is_some() { + ContextArgMode::MutableReference + } else { + ContextArgMode::Reference + }; + + TestArg::Context(ContextArg { + name: pat_ident.ident.clone(), + mode, + }) + } else if types_equal(arg_type, expected_context_type) { + TestArg::Context(ContextArg { + name: pat_ident.ident.clone(), + mode: ContextArgMode::Owned, + }) + } else { + TestArg::Any(arg) + } + } else { + TestArg::Any(arg) + } + } +} + +fn types_equal(a: &syn::Type, b: &syn::Type) -> bool { + if let (syn::Type::Path(a_path), syn::Type::Path(b_path)) = (a, b) { + return a_path.path.segments.last().unwrap().ident + == b_path.path.segments.last().unwrap().ident; + } + quote!(#a).to_string() == quote!(#b).to_string() +} diff --git a/test-context/tests/test.rs b/test-context/tests/test.rs index 86fcf42..891d56f 100644 --- a/test-context/tests/test.rs +++ b/test-context/tests/test.rs @@ -50,6 +50,18 @@ fn includes_return_value() { assert_eq!(return_value_func(), 1); } +#[test_context(Context)] +#[test] +fn use_different_name(test_data: &mut Context) { + assert_eq!(test_data.n, 1); +} + +#[test_context(Context)] +#[test] +fn use_immutable_ref(test_data: &Context) { + assert_eq!(test_data.n, 1); +} + struct ContextGeneric { n: u32, _marker: PhantomData, @@ -140,12 +152,6 @@ fn async_auto_impls_sync(ctx: &mut AsyncContext) { assert_eq!(ctx.n, 1); } -#[test_context(Context)] -#[test] -fn use_different_name(test_data: &mut Context) { - assert_eq!(test_data.n, 1); -} - #[test_context(AsyncContext)] #[tokio::test] async fn use_different_name_async(test_data: &mut AsyncContext) { @@ -168,10 +174,26 @@ impl AsyncTestContext for TeardownPanicContext { #[tokio::test] async fn test_async_skip_teardown(_ctx: &mut TeardownPanicContext) {} +#[test_context(TeardownPanicContext, skip_teardown)] +#[tokio::test] +async fn test_async_skip_teardown_with_immutable_ref(_ctx: &TeardownPanicContext) {} + +#[test_context(TeardownPanicContext, skip_teardown)] +#[tokio::test] +async fn test_async_skip_teardown_with_full_ownership(_ctx: TeardownPanicContext) {} + #[test_context(TeardownPanicContext, skip_teardown)] #[test] fn test_sync_skip_teardown(_ctx: &mut TeardownPanicContext) {} +#[test_context(TeardownPanicContext, skip_teardown)] +#[test] +fn test_sync_skip_teardown_with_immutable_ref(_ctx: &TeardownPanicContext) {} + +#[test_context(TeardownPanicContext, skip_teardown)] +#[test] +fn test_sync_skip_teardown_with_full_ownership(_ctx: TeardownPanicContext) {} + struct GenericContext { contents: T, }