fix: use any kind of pattern when defining context fn argument

This commit is contained in:
JasterV 2026-04-22 16:28:51 +02:00
parent bfe19a44c6
commit 80eeaf756d
3 changed files with 24 additions and 32 deletions

View file

@ -93,13 +93,12 @@ fn refactor_input_body(
let result_name = format_ident!("wrapped_result");
let body = &input.block;
let is_async = input.sig.asyncness.is_some();
let context_arg_name = context_arg.name;
let context_pattern = context_arg.pattern;
let context_binding = match context_arg.mode {
ContextArgMode::Owned => quote! { let #context_arg_name = __context; },
ContextArgMode::OwnedMut => quote! { let mut #context_arg_name = __context; },
ContextArgMode::Reference => quote! { let #context_arg_name = &__context; },
ContextArgMode::MutableReference => quote! { let #context_arg_name = &mut __context; },
ContextArgMode::Owned => quote! { let #context_pattern = __context; },
ContextArgMode::Reference => quote! { let #context_pattern = &__context; },
ContextArgMode::MutableReference => quote! { let #context_pattern = &mut __context; },
};
let body = if args.skip_teardown && is_async {

View file

@ -3,8 +3,8 @@ use syn::FnArg;
#[derive(Clone)]
pub struct ContextArg {
/// The identifier name used for the context argument.
pub name: syn::Ident,
/// The original pattern (left side of a `pattern: type` expression).
pub pattern: syn::Pat,
/// The mode in which the context was passed to the test function.
pub mode: ContextArgMode,
}
@ -13,8 +13,6 @@ pub struct ContextArg {
pub enum ContextArgMode {
/// The argument was passed as an owned value (`ContextType`). Only valid with `skip_teardown`.
Owned,
/// The argument was passed as an owned value (mut `ContextType`). Only valid with `skip_teardown`.
OwnedMut,
/// The argument was passed as an immutable reference (`&ContextType`).
Reference,
/// The argument was passed as a mutable reference (`&mut ContextType`).
@ -25,7 +23,6 @@ impl ContextArgMode {
pub fn is_owned(&self) -> bool {
match self {
ContextArgMode::Owned => true,
ContextArgMode::OwnedMut => true,
ContextArgMode::Reference => false,
ContextArgMode::MutableReference => false,
}
@ -40,15 +37,14 @@ pub enum TestArg {
impl TestArg {
pub fn parse_arg_with_expected_context(arg: FnArg, expected_context_type: &syn::Type) -> Self {
let syn::FnArg::Typed(pat_type) = &arg else {
return Self::Any(arg);
};
let syn::Pat::Ident(pat_ident) = &*pat_type.pat else {
return Self::Any(arg);
let pat_type = match arg {
FnArg::Typed(pat_type) => pat_type,
FnArg::Receiver(_) => return TestArg::Any(arg),
};
// fn example(pattern: arg_type)
let arg_type = &*pat_type.ty;
let pattern = &*pat_type.pat;
// Check for mutable/immutable reference
if let syn::Type::Reference(type_ref) = arg_type
@ -61,28 +57,19 @@ impl TestArg {
};
return TestArg::Context(ContextArg {
name: pat_ident.ident.clone(),
pattern: pattern.to_owned(),
mode,
});
}
if !types_equal(arg_type, expected_context_type) {
return TestArg::Any(arg);
if types_equal(arg_type, expected_context_type) {
return TestArg::Context(ContextArg {
pattern: pattern.to_owned(),
mode: ContextArgMode::Owned,
});
}
// To determine mutability for an owned type, we check the identifier pattern.
let mode = if pat_ident.mutability.is_some() {
// This catches signatures like: `mut my_ctx: ContextType`
ContextArgMode::OwnedMut
} else {
// This catches signatures like: `my_ctx: ContextType`
ContextArgMode::Owned
};
TestArg::Context(ContextArg {
name: pat_ident.ident.clone(),
mode,
})
TestArg::Any(FnArg::Typed(pat_type))
}
}

View file

@ -25,6 +25,12 @@ fn test_sync_setup(ctx: &mut Context) {
assert_eq!(ctx.n, 1);
}
#[test_context(Context)]
#[test]
fn test_pattern_match_setup(Context { n }: &mut Context) {
assert_eq!(*n, 1);
}
#[test_context(Context)]
#[test]
#[should_panic(expected = "Number changed")]