From dc77d78608b537d21159bdaf522fecd713e2c9ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Victor=20Martinez=20Montan=C3=A9?= Date: Wed, 22 Apr 2026 17:26:08 +0200 Subject: [PATCH] fix/pattern-matching-breaks-compilation (#4) This PR solves the issue identified at https://codeberg.org/JasterV/test-context/issues/2. Previously, when defining the function argument that implements the `TestContext` trait, only "Ident" patterns were accepted (That is, only names). So, the following was valid: ```rust #[test] fn my_test(context: MyContext) {} ``` But the following would throw an error: ```rust #[test] fn my_test(MyContext { n }: MyContext {} ``` With this PR, we are now able to accept any kind of "pattern" such as struct pattern matching, enum pattern matching... etc. We only really care about the "type", and not the "pattern", so this PR makes sure that the "pattern" part of the binding is left untouched. Tests have been added to ensure that destructuring a struct compiles Co-authored-by: JasterV <49537445+JasterV@users.noreply.github.com> Reviewed-on: https://codeberg.org/JasterV/test-context/pulls/4 --- test-context-macros/src/lib.rs | 9 +++--- test-context-macros/src/test_args.rs | 41 ++++++++++------------------ test-context/tests/test.rs | 6 ++++ 3 files changed, 24 insertions(+), 32 deletions(-) diff --git a/test-context-macros/src/lib.rs b/test-context-macros/src/lib.rs index 8644f36..5e6d30f 100644 --- a/test-context-macros/src/lib.rs +++ b/test-context-macros/src/lib.rs @@ -93,13 +93,12 @@ fn refactor_input_body( let result_name = format_ident!("wrapped_result"); let body = &input.block; let is_async = input.sig.asyncness.is_some(); - let context_arg_name = context_arg.name; + let context_pattern = context_arg.pattern; let context_binding = match context_arg.mode { - ContextArgMode::Owned => quote! { let #context_arg_name = __context; }, - ContextArgMode::OwnedMut => quote! { let mut #context_arg_name = __context; }, - ContextArgMode::Reference => quote! { let #context_arg_name = &__context; }, - ContextArgMode::MutableReference => quote! { let #context_arg_name = &mut __context; }, + ContextArgMode::Owned => quote! { let #context_pattern = __context; }, + ContextArgMode::Reference => quote! { let #context_pattern = &__context; }, + ContextArgMode::MutableReference => quote! { let #context_pattern = &mut __context; }, }; let body = if args.skip_teardown && is_async { diff --git a/test-context-macros/src/test_args.rs b/test-context-macros/src/test_args.rs index b5359a1..9d07289 100644 --- a/test-context-macros/src/test_args.rs +++ b/test-context-macros/src/test_args.rs @@ -3,8 +3,8 @@ use syn::FnArg; #[derive(Clone)] pub struct ContextArg { - /// The identifier name used for the context argument. - pub name: syn::Ident, + /// The original pattern (left side of a `pattern: type` expression). + pub pattern: syn::Pat, /// The mode in which the context was passed to the test function. pub mode: ContextArgMode, } @@ -13,8 +13,6 @@ pub struct ContextArg { pub enum ContextArgMode { /// The argument was passed as an owned value (`ContextType`). Only valid with `skip_teardown`. Owned, - /// The argument was passed as an owned value (mut `ContextType`). Only valid with `skip_teardown`. - OwnedMut, /// The argument was passed as an immutable reference (`&ContextType`). Reference, /// The argument was passed as a mutable reference (`&mut ContextType`). @@ -25,7 +23,6 @@ impl ContextArgMode { pub fn is_owned(&self) -> bool { match self { ContextArgMode::Owned => true, - ContextArgMode::OwnedMut => true, ContextArgMode::Reference => false, ContextArgMode::MutableReference => false, } @@ -40,15 +37,14 @@ pub enum TestArg { impl TestArg { pub fn parse_arg_with_expected_context(arg: FnArg, expected_context_type: &syn::Type) -> Self { - let syn::FnArg::Typed(pat_type) = &arg else { - return Self::Any(arg); - }; - - let syn::Pat::Ident(pat_ident) = &*pat_type.pat else { - return Self::Any(arg); + let pat_type = match arg { + FnArg::Typed(pat_type) => pat_type, + FnArg::Receiver(_) => return TestArg::Any(arg), }; + // fn example(pattern: arg_type) let arg_type = &*pat_type.ty; + let pattern = &*pat_type.pat; // Check for mutable/immutable reference if let syn::Type::Reference(type_ref) = arg_type @@ -61,28 +57,19 @@ impl TestArg { }; return TestArg::Context(ContextArg { - name: pat_ident.ident.clone(), + pattern: pattern.to_owned(), mode, }); } - if !types_equal(arg_type, expected_context_type) { - return TestArg::Any(arg); + if types_equal(arg_type, expected_context_type) { + return TestArg::Context(ContextArg { + pattern: pattern.to_owned(), + mode: ContextArgMode::Owned, + }); } - // To determine mutability for an owned type, we check the identifier pattern. - let mode = if pat_ident.mutability.is_some() { - // This catches signatures like: `mut my_ctx: ContextType` - ContextArgMode::OwnedMut - } else { - // This catches signatures like: `my_ctx: ContextType` - ContextArgMode::Owned - }; - - TestArg::Context(ContextArg { - name: pat_ident.ident.clone(), - mode, - }) + TestArg::Any(FnArg::Typed(pat_type)) } } diff --git a/test-context/tests/test.rs b/test-context/tests/test.rs index a82e864..3065191 100644 --- a/test-context/tests/test.rs +++ b/test-context/tests/test.rs @@ -25,6 +25,12 @@ fn test_sync_setup(ctx: &mut Context) { assert_eq!(ctx.n, 1); } +#[test_context(Context)] +#[test] +fn test_pattern_match_setup(Context { n }: &mut Context) { + assert_eq!(*n, 1); +} + #[test_context(Context)] #[test] #[should_panic(expected = "Number changed")]