diff --git a/test-context-macros/src/lib.rs b/test-context-macros/src/lib.rs index 8644f36..5e6d30f 100644 --- a/test-context-macros/src/lib.rs +++ b/test-context-macros/src/lib.rs @@ -93,13 +93,12 @@ fn refactor_input_body( let result_name = format_ident!("wrapped_result"); let body = &input.block; let is_async = input.sig.asyncness.is_some(); - let context_arg_name = context_arg.name; + let context_pattern = context_arg.pattern; let context_binding = match context_arg.mode { - ContextArgMode::Owned => quote! { let #context_arg_name = __context; }, - ContextArgMode::OwnedMut => quote! { let mut #context_arg_name = __context; }, - ContextArgMode::Reference => quote! { let #context_arg_name = &__context; }, - ContextArgMode::MutableReference => quote! { let #context_arg_name = &mut __context; }, + ContextArgMode::Owned => quote! { let #context_pattern = __context; }, + ContextArgMode::Reference => quote! { let #context_pattern = &__context; }, + ContextArgMode::MutableReference => quote! { let #context_pattern = &mut __context; }, }; let body = if args.skip_teardown && is_async { diff --git a/test-context-macros/src/test_args.rs b/test-context-macros/src/test_args.rs index b5359a1..9d07289 100644 --- a/test-context-macros/src/test_args.rs +++ b/test-context-macros/src/test_args.rs @@ -3,8 +3,8 @@ use syn::FnArg; #[derive(Clone)] pub struct ContextArg { - /// The identifier name used for the context argument. - pub name: syn::Ident, + /// The original pattern (left side of a `pattern: type` expression). + pub pattern: syn::Pat, /// The mode in which the context was passed to the test function. pub mode: ContextArgMode, } @@ -13,8 +13,6 @@ pub struct ContextArg { pub enum ContextArgMode { /// The argument was passed as an owned value (`ContextType`). Only valid with `skip_teardown`. Owned, - /// The argument was passed as an owned value (mut `ContextType`). Only valid with `skip_teardown`. - OwnedMut, /// The argument was passed as an immutable reference (`&ContextType`). Reference, /// The argument was passed as a mutable reference (`&mut ContextType`). @@ -25,7 +23,6 @@ impl ContextArgMode { pub fn is_owned(&self) -> bool { match self { ContextArgMode::Owned => true, - ContextArgMode::OwnedMut => true, ContextArgMode::Reference => false, ContextArgMode::MutableReference => false, } @@ -40,15 +37,14 @@ pub enum TestArg { impl TestArg { pub fn parse_arg_with_expected_context(arg: FnArg, expected_context_type: &syn::Type) -> Self { - let syn::FnArg::Typed(pat_type) = &arg else { - return Self::Any(arg); - }; - - let syn::Pat::Ident(pat_ident) = &*pat_type.pat else { - return Self::Any(arg); + let pat_type = match arg { + FnArg::Typed(pat_type) => pat_type, + FnArg::Receiver(_) => return TestArg::Any(arg), }; + // fn example(pattern: arg_type) let arg_type = &*pat_type.ty; + let pattern = &*pat_type.pat; // Check for mutable/immutable reference if let syn::Type::Reference(type_ref) = arg_type @@ -61,28 +57,19 @@ impl TestArg { }; return TestArg::Context(ContextArg { - name: pat_ident.ident.clone(), + pattern: pattern.to_owned(), mode, }); } - if !types_equal(arg_type, expected_context_type) { - return TestArg::Any(arg); + if types_equal(arg_type, expected_context_type) { + return TestArg::Context(ContextArg { + pattern: pattern.to_owned(), + mode: ContextArgMode::Owned, + }); } - // To determine mutability for an owned type, we check the identifier pattern. - let mode = if pat_ident.mutability.is_some() { - // This catches signatures like: `mut my_ctx: ContextType` - ContextArgMode::OwnedMut - } else { - // This catches signatures like: `my_ctx: ContextType` - ContextArgMode::Owned - }; - - TestArg::Context(ContextArg { - name: pat_ident.ident.clone(), - mode, - }) + TestArg::Any(FnArg::Typed(pat_type)) } } diff --git a/test-context/tests/test.rs b/test-context/tests/test.rs index a82e864..3065191 100644 --- a/test-context/tests/test.rs +++ b/test-context/tests/test.rs @@ -25,6 +25,12 @@ fn test_sync_setup(ctx: &mut Context) { assert_eq!(ctx.n, 1); } +#[test_context(Context)] +#[test] +fn test_pattern_match_setup(Context { n }: &mut Context) { + assert_eq!(*n, 1); +} + #[test_context(Context)] #[test] #[should_panic(expected = "Number changed")]