feat: make it so immutable references & full ownership can be taken depending on context (#58)

This pull requests introduces changes that make the use of the `test_context` macro more flexible.

Now, if the teardown is not skipped (default behavior), either an `immutable` or a `mutable` reference can be used for the context.

If the teardown is skipped with the `skip_teardown` option, an `immutable`, a `mutable` reference or full ownership can be taken.

So now the following is possible:

```rust
#[test_context(TeardownPanicContext, skip_teardown)]
#[tokio::test]
async fn test_async_skip_teardown(_ctx: &mut TeardownPanicContext) {}

#[test_context(TeardownPanicContext, skip_teardown)]
#[tokio::test]
async fn test_async_skip_teardown_with_immutable_ref(_ctx: &TeardownPanicContext) {}

#[test_context(TeardownPanicContext, skip_teardown)]
#[tokio::test]
async fn test_async_skip_teardown_with_full_ownership(_ctx: TeardownPanicContext) {}

#[test_context(TeardownPanicContext, skip_teardown)]
#[test]
fn test_sync_skip_teardown(_ctx: &mut TeardownPanicContext) {}

#[test_context(TeardownPanicContext, skip_teardown)]
#[test]
fn test_sync_skip_teardown_with_immutable_ref(_ctx: &TeardownPanicContext) {}

#[test_context(TeardownPanicContext, skip_teardown)]
#[test]
fn test_sync_skip_teardown_with_full_ownership(_ctx: TeardownPanicContext) {}
```
This commit is contained in:
Víctor Martínez 2025-11-06 13:52:20 +01:00 committed by GitHub
parent 5e407cdfa2
commit 39756352a8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 237 additions and 94 deletions

View file

@ -127,9 +127,8 @@ tests annotated with `#[tokio::test]` continue to work as usual without the feat
## Skipping the teardown execution
Also, if you don't care about the
teardown execution for a specific test, you can use the `skip_teardown` keyword on the macro
like this:
Also, if you don't care about the teardown execution for a specific test,
you can use the `skip_teardown` keyword on the macro like this:
```rust
use test_context::{test_context, TestContext};
@ -144,11 +143,50 @@ like this:
#[test_context(MyContext, skip_teardown)]
#[test]
fn test_without_teardown(ctx: &mut MyContext) {
fn test_without_teardown(ctx: &MyContext) {}
```
## Taking ownership of the context vs taking a reference
If the teardown is ON (default behavior), you can only take a reference to the context, either mutable or immutable, as follows:
```rust
#[test_context(MyContext)]
#[test]
fn test_with_teardown_using_immutable_ref(ctx: &MyContext) {}
#[test_context(MyContext)]
#[test]
fn test_with_teardown_using_mutable_ref(ctx: &mut MyContext) {}
```
❌The following is invalid:
```rust
#[test_context(MyContext)]
#[test]
fn test_with_teardown_taking_ownership(ctx: MyContext) {}
```
If the teardown is skipped (as specified in the section above), you can take an immutable ref, mutable ref or full ownership of the context:
```rust
#[test_context(MyContext, skip_teardown)]
#[test]
fn test_without_teardown(ctx: MyContext) {
// Perform any operations that require full ownership of your context
}
#[test_context(MyContext, skip_teardown)]
#[test]
fn test_without_teardown_taking_a_ref(ctx: &MyContext) {}
#[test_context(MyContext, skip_teardown)]
#[test]
fn test_without_teardown_taking_a_mut_ref(ctx: &mut MyContext) {}
```
## ⚠️ Ensure that the context type specified in the macro matches the test function argument type exactly
The error occurs when a context type with an absolute path is mixed with an it's alias.
@ -161,8 +199,8 @@ mod database {
pub struct Connection;
impl TestContext for :Connection {
fn setup() -> Self {Connection}
impl TestContext for Connection {
fn setup() -> Self { Connection }
fn teardown(self) {...}
}
}

View file

@ -3,4 +3,5 @@ pr_branch_prefix = "release-"
pr_labels = ["release"]
git_tag_enable = true
git_tag_name = "v{{ version }}"
git_release_name = "v{{ version }}"
pr_draft = true

View file

@ -1,9 +1,11 @@
mod args;
mod macro_args;
mod test_args;
use args::TestContextArgs;
use crate::test_args::{ContextArg, ContextArgMode, TestArg};
use macro_args::TestContextArgs;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::Ident;
use syn::ItemFn;
/// Macro to use on tests to add the setup/teardown functionality of your context.
///
@ -30,59 +32,108 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = syn::parse_macro_input!(attr as TestContextArgs);
let input = syn::parse_macro_input!(item as syn::ItemFn);
let (input, context_arg_name) = remove_context_arg(input, args.context_type.clone());
let input = refactor_input_body(input, &args, context_arg_name);
let (input, context_args) = remove_context_args(input, args.context_type.clone());
if context_args.len() != 1 {
panic!("Exactly one Context argument must be defined");
}
let context_arg = context_args.into_iter().next().unwrap();
if !args.skip_teardown && context_arg.mode == ContextArgMode::Owned {
panic!(
"It is not possible to take ownership of the context if the teardown has to be ran."
);
}
let input = refactor_input_body(input, &args, context_arg);
quote! { #input }.into()
}
fn refactor_input_body(
fn remove_context_args(
mut input: syn::ItemFn,
expected_context_type: syn::Type,
) -> (syn::ItemFn, Vec<ContextArg>) {
let test_args: Vec<TestArg> = input
.sig
.inputs
.into_iter()
.map(|arg| TestArg::parse_arg_with_expected_context(arg, &expected_context_type))
.collect();
let context_args: Vec<ContextArg> = test_args
.iter()
.cloned()
.filter_map(|arg| match arg {
TestArg::Any(_) => None,
TestArg::Context(context_arg_info) => Some(context_arg_info),
})
.collect();
let new_args: syn::punctuated::Punctuated<_, _> = test_args
.into_iter()
.filter_map(|arg| match arg {
TestArg::Any(fn_arg) => Some(fn_arg),
TestArg::Context(_) => None,
})
.collect();
input.sig.inputs = new_args;
(input, context_args)
}
fn refactor_input_body(
input: syn::ItemFn,
args: &TestContextArgs,
context_arg_name: Option<Ident>,
context_arg: ContextArg,
) -> syn::ItemFn {
let context_type = &args.context_type;
let context_arg_name = context_arg_name.unwrap_or_else(|| format_ident!("test_ctx"));
let result_name = format_ident!("wrapped_result");
let body = &input.block;
let is_async = input.sig.asyncness.is_some();
let context_arg_name = context_arg.name;
let body = match (is_async, args.skip_teardown) {
(true, true) => {
quote! {
use test_context::futures::FutureExt;
let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
let #context_arg_name = &mut __context;
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
}
let context_binding = match context_arg.mode {
ContextArgMode::Owned => quote! { let #context_arg_name = __context; },
ContextArgMode::Reference => quote! { let #context_arg_name = &__context; },
ContextArgMode::MutableReference => quote! { let #context_arg_name = &mut __context; },
};
let body = if args.skip_teardown && is_async {
quote! {
use test_context::futures::FutureExt;
let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
#context_binding
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
}
(true, false) => {
quote! {
use test_context::futures::FutureExt;
let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
let #context_arg_name = &mut __context;
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
<#context_type as test_context::AsyncTestContext>::teardown(__context).await;
}
} else if args.skip_teardown && !is_async {
quote! {
let mut __context = <#context_type as test_context::TestContext>::setup();
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
#context_binding
#body
}));
}
(false, true) => {
quote! {
let mut __context = <#context_type as test_context::TestContext>::setup();
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let #context_arg_name = &mut __context;
#body
}));
}
} else if !args.skip_teardown && is_async {
quote! {
use test_context::futures::FutureExt;
let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
#context_binding
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
<#context_type as test_context::AsyncTestContext>::teardown(__context).await;
}
(false, false) => {
quote! {
let mut __context = <#context_type as test_context::TestContext>::setup();
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let #context_arg_name = &mut __context;
#body
}));
<#context_type as test_context::TestContext>::teardown(__context);
}
}
// !args.skip_teardown && !is_async
else {
quote! {
let mut __context = <#context_type as test_context::TestContext>::setup();
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
#context_binding
#body
}));
<#context_type as test_context::TestContext>::teardown(__context);
}
};
@ -98,46 +149,8 @@ fn refactor_input_body(
}
};
input.block = Box::new(syn::parse2(body).unwrap());
input
}
fn remove_context_arg(
mut input: syn::ItemFn,
expected_context_type: syn::Type,
) -> (syn::ItemFn, Option<syn::Ident>) {
let mut context_arg_name = None;
let mut new_args = syn::punctuated::Punctuated::new();
for arg in &input.sig.inputs {
// Extract function arg:
if let syn::FnArg::Typed(pat_type) = arg {
// Extract arg identifier:
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
// Check that context arg is only ref or mutable ref:
if let syn::Type::Reference(type_ref) = &*pat_type.ty {
// Check that context has expected type:
if types_equal(&type_ref.elem, &expected_context_type) {
context_arg_name = Some(pat_ident.ident.clone());
continue;
}
}
}
}
new_args.push(arg.clone());
ItemFn {
block: Box::new(syn::parse2(body).unwrap()),
..input
}
input.sig.inputs = new_args;
(input, context_arg_name)
}
fn types_equal(a: &syn::Type, b: &syn::Type) -> bool {
if let (syn::Type::Path(a_path), syn::Type::Path(b_path)) = (a, b) {
return a_path.path.segments.last().unwrap().ident
== b_path.path.segments.last().unwrap().ident;
}
quote!(#a).to_string() == quote!(#b).to_string()
}

View file

@ -0,0 +1,69 @@
use quote::quote;
use syn::FnArg;
#[derive(Clone)]
pub struct ContextArg {
/// The identifier name used for the context argument.
pub name: syn::Ident,
/// The mode in which the context was passed to the test function.
pub mode: ContextArgMode,
}
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub enum ContextArgMode {
/// The argument was passed as an owned value (`ContextType`). Only valid with `skip_teardown`.
Owned,
/// The argument was passed as an immutable reference (`&ContextType`).
Reference,
/// The argument was passed as a mutable reference (`&mut ContextType`).
MutableReference,
}
#[derive(Clone)]
pub enum TestArg {
Any(FnArg),
Context(ContextArg),
}
impl TestArg {
pub fn parse_arg_with_expected_context(arg: FnArg, expected_context_type: &syn::Type) -> Self {
// Check if the argument is the context argument
if let syn::FnArg::Typed(pat_type) = &arg
&& let syn::Pat::Ident(pat_ident) = &*pat_type.pat
{
let arg_type = &*pat_type.ty;
// Check for mutable/immutable reference
if let syn::Type::Reference(type_ref) = arg_type
&& types_equal(&type_ref.elem, expected_context_type)
{
let mode = if type_ref.mutability.is_some() {
ContextArgMode::MutableReference
} else {
ContextArgMode::Reference
};
TestArg::Context(ContextArg {
name: pat_ident.ident.clone(),
mode,
})
} else if types_equal(arg_type, expected_context_type) {
TestArg::Context(ContextArg {
name: pat_ident.ident.clone(),
mode: ContextArgMode::Owned,
})
} else {
TestArg::Any(arg)
}
} else {
TestArg::Any(arg)
}
}
}
fn types_equal(a: &syn::Type, b: &syn::Type) -> bool {
if let (syn::Type::Path(a_path), syn::Type::Path(b_path)) = (a, b) {
return a_path.path.segments.last().unwrap().ident
== b_path.path.segments.last().unwrap().ident;
}
quote!(#a).to_string() == quote!(#b).to_string()
}

View file

@ -50,6 +50,18 @@ fn includes_return_value() {
assert_eq!(return_value_func(), 1);
}
#[test_context(Context)]
#[test]
fn use_different_name(test_data: &mut Context) {
assert_eq!(test_data.n, 1);
}
#[test_context(Context)]
#[test]
fn use_immutable_ref(test_data: &Context) {
assert_eq!(test_data.n, 1);
}
struct ContextGeneric<T> {
n: u32,
_marker: PhantomData<T>,
@ -140,12 +152,6 @@ fn async_auto_impls_sync(ctx: &mut AsyncContext) {
assert_eq!(ctx.n, 1);
}
#[test_context(Context)]
#[test]
fn use_different_name(test_data: &mut Context) {
assert_eq!(test_data.n, 1);
}
#[test_context(AsyncContext)]
#[tokio::test]
async fn use_different_name_async(test_data: &mut AsyncContext) {
@ -168,10 +174,26 @@ impl AsyncTestContext for TeardownPanicContext {
#[tokio::test]
async fn test_async_skip_teardown(_ctx: &mut TeardownPanicContext) {}
#[test_context(TeardownPanicContext, skip_teardown)]
#[tokio::test]
async fn test_async_skip_teardown_with_immutable_ref(_ctx: &TeardownPanicContext) {}
#[test_context(TeardownPanicContext, skip_teardown)]
#[tokio::test]
async fn test_async_skip_teardown_with_full_ownership(_ctx: TeardownPanicContext) {}
#[test_context(TeardownPanicContext, skip_teardown)]
#[test]
fn test_sync_skip_teardown(_ctx: &mut TeardownPanicContext) {}
#[test_context(TeardownPanicContext, skip_teardown)]
#[test]
fn test_sync_skip_teardown_with_immutable_ref(_ctx: &TeardownPanicContext) {}
#[test_context(TeardownPanicContext, skip_teardown)]
#[test]
fn test_sync_skip_teardown_with_full_ownership(_ctx: TeardownPanicContext) {}
struct GenericContext<T> {
contents: T,
}