diff --git a/test-context-macros/src/lib.rs b/test-context-macros/src/lib.rs index af7185e..4efac74 100644 --- a/test-context-macros/src/lib.rs +++ b/test-context-macros/src/lib.rs @@ -40,7 +40,7 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream { let context_arg = context_args.into_iter().next().unwrap(); - if !args.skip_teardown && context_arg.mode == ContextArgMode::Owned { + if !args.skip_teardown && context_arg.mode.is_owned() { panic!( "It is not possible to take ownership of the context if the teardown has to be ran." ); @@ -97,6 +97,7 @@ fn refactor_input_body( let context_binding = match context_arg.mode { ContextArgMode::Owned => quote! { let #context_arg_name = __context; }, + ContextArgMode::OwnedMut => quote! { let mut #context_arg_name = __context; }, ContextArgMode::Reference => quote! { let #context_arg_name = &__context; }, ContextArgMode::MutableReference => quote! { let #context_arg_name = &mut __context; }, }; diff --git a/test-context-macros/src/test_args.rs b/test-context-macros/src/test_args.rs index 94069e4..664ee62 100644 --- a/test-context-macros/src/test_args.rs +++ b/test-context-macros/src/test_args.rs @@ -13,12 +13,25 @@ pub struct ContextArg { pub enum ContextArgMode { /// The argument was passed as an owned value (`ContextType`). Only valid with `skip_teardown`. Owned, + /// The argument was passed as an owned value (mut `ContextType`). Only valid with `skip_teardown`. + OwnedMut, /// The argument was passed as an immutable reference (`&ContextType`). Reference, /// The argument was passed as a mutable reference (`&mut ContextType`). MutableReference, } +impl ContextArgMode { + pub fn is_owned(&self) -> bool { + match self { + ContextArgMode::Owned => true, + ContextArgMode::OwnedMut => true, + ContextArgMode::Reference => false, + ContextArgMode::MutableReference => false, + } + } +} + #[derive(Clone)] pub enum TestArg { Any(FnArg), @@ -47,9 +60,18 @@ impl TestArg { mode, }) } else if types_equal(arg_type, expected_context_type) { + // To determine mutability for an owned type, we check the identifier pattern. + let mode = if pat_ident.mutability.is_some() { + // This catches signatures like: `mut my_ctx: ContextType` + ContextArgMode::OwnedMut + } else { + // This catches signatures like: `my_ctx: ContextType` + ContextArgMode::Owned + }; + TestArg::Context(ContextArg { name: pat_ident.ident.clone(), - mode: ContextArgMode::Owned, + mode, }) } else { TestArg::Any(arg) diff --git a/test-context/tests/test.rs b/test-context/tests/test.rs index 891d56f..c8919d6 100644 --- a/test-context/tests/test.rs +++ b/test-context/tests/test.rs @@ -182,6 +182,12 @@ async fn test_async_skip_teardown_with_immutable_ref(_ctx: &TeardownPanicContext #[tokio::test] async fn test_async_skip_teardown_with_full_ownership(_ctx: TeardownPanicContext) {} +#[test_context(TeardownPanicContext, skip_teardown)] +#[tokio::test] +async fn test_async_skip_teardown_with_full_mut_ownership(mut ctx: TeardownPanicContext) { + let _test = &mut ctx; +} + #[test_context(TeardownPanicContext, skip_teardown)] #[test] fn test_sync_skip_teardown(_ctx: &mut TeardownPanicContext) {} @@ -194,6 +200,12 @@ fn test_sync_skip_teardown_with_immutable_ref(_ctx: &TeardownPanicContext) {} #[test] fn test_sync_skip_teardown_with_full_ownership(_ctx: TeardownPanicContext) {} +#[test_context(TeardownPanicContext, skip_teardown)] +#[test] +fn test_sync_skip_teardown_with_full_mut_ownership(mut ctx: TeardownPanicContext) { + let _test = &mut ctx; +} + struct GenericContext { contents: T, }