fix: Regression in version 0.5.0 (#55)

An attempt to fix the issue defined in https://github.com/JasterV/test-context/issues/53 & some refactors
This commit is contained in:
Víctor Martínez 2025-11-04 17:27:53 +01:00 committed by GitHub
parent 90f0f92b66
commit ef4bdb28a9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 77 additions and 104 deletions

View file

@ -3,9 +3,9 @@ resolver = "2"
members = ["test-context", "test-context-macros"] members = ["test-context", "test-context-macros"]
[workspace.package] [workspace.package]
edition = "2021" edition = "2024"
version = "0.5.0" version = "0.5.0"
rust-version = "1.75.0" rust-version = "1.91.0"
homepage = "https://github.com/JasterV/test-context" homepage = "https://github.com/JasterV/test-context"
repository = "https://github.com/JasterV/test-context" repository = "https://github.com/JasterV/test-context"
authors = [ authors = [

View file

@ -1,3 +1,3 @@
[toolchain] [toolchain]
channel = "1.90" channel = "1.91"
components = ["rustfmt", "clippy", "rust-src", "rust-analyzer"] components = ["rustfmt", "clippy", "rust-src", "rust-analyzer"]

View file

@ -1,4 +1,4 @@
use syn::{parse::Parse, Token, Type}; use syn::{Token, Type, parse::Parse};
pub(crate) struct TestContextArgs { pub(crate) struct TestContextArgs {
pub(crate) context_type: Type, pub(crate) context_type: Type,

View file

@ -3,7 +3,7 @@ mod args;
use args::TestContextArgs; use args::TestContextArgs;
use proc_macro::TokenStream; use proc_macro::TokenStream;
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use syn::{Block, Ident}; use syn::Ident;
/// Macro to use on tests to add the setup/teardown functionality of your context. /// Macro to use on tests to add the setup/teardown functionality of your context.
/// ///
@ -28,111 +28,82 @@ use syn::{Block, Ident};
#[proc_macro_attribute] #[proc_macro_attribute]
pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream { pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = syn::parse_macro_input!(attr as TestContextArgs); let args = syn::parse_macro_input!(attr as TestContextArgs);
let input = syn::parse_macro_input!(item as syn::ItemFn); let input = syn::parse_macro_input!(item as syn::ItemFn);
let (input, context_arg_name) = remove_context_arg(input, args.context_type.clone());
let input = refactor_input_body(input, &args, context_arg_name);
quote! { #input }.into()
}
fn refactor_input_body(
mut input: syn::ItemFn,
args: &TestContextArgs,
context_arg_name: Option<Ident>,
) -> syn::ItemFn {
let context_type = &args.context_type;
let context_arg_name = context_arg_name.unwrap_or_else(|| format_ident!("test_ctx"));
let result_name = format_ident!("wrapped_result");
let body = &input.block;
let is_async = input.sig.asyncness.is_some(); let is_async = input.sig.asyncness.is_some();
let (new_input, context_arg_name) = let body = match (is_async, args.skip_teardown) {
extract_and_remove_context_arg(input.clone(), args.context_type.clone()); (true, true) => {
quote! {
let wrapper_body = if is_async { use test_context::futures::FutureExt;
async_wrapper_body(args, &context_arg_name, &input.block) let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
} else { let #context_arg_name = &mut __context;
sync_wrapper_body(args, &context_arg_name, &input.block) let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
};
let mut result_input = new_input;
result_input.block = Box::new(syn::parse2(wrapper_body).unwrap());
quote! { #result_input }.into()
}
fn async_wrapper_body(
args: TestContextArgs,
context_arg_name: &Option<syn::Ident>,
body: &Block,
) -> proc_macro2::TokenStream {
let context_type = args.context_type;
let result_name = format_ident!("wrapped_result");
let binding = format_ident!("test_ctx");
let context_name = context_arg_name.as_ref().unwrap_or(&binding);
let body = if args.skip_teardown {
quote! {
let #context_name = <#context_type as test_context::AsyncTestContext>::setup().await;
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
}
} else {
quote! {
let mut #context_name = <#context_type as test_context::AsyncTestContext>::setup().await;
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
<#context_type as test_context::AsyncTestContext>::teardown(#context_name).await;
}
};
let handle_wrapped_result = handle_result(result_name);
quote! {
{
use test_context::futures::FutureExt;
#body
#handle_wrapped_result
}
}
}
fn sync_wrapper_body(
args: TestContextArgs,
context_arg_name: &Option<syn::Ident>,
body: &Block,
) -> proc_macro2::TokenStream {
let context_type = args.context_type;
let result_name = format_ident!("wrapped_result");
let binding = format_ident!("test_ctx");
let context_name = context_arg_name.as_ref().unwrap_or(&binding);
let body = if args.skip_teardown {
quote! {
let mut #context_name = <#context_type as test_context::TestContext>::setup();
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let #context_name = &mut #context_name;
#body
}));
}
} else {
quote! {
let mut #context_name = <#context_type as test_context::TestContext>::setup();
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
#body
}));
<#context_type as test_context::TestContext>::teardown(#context_name);
}
};
let handle_wrapped_result = handle_result(result_name);
quote! {
{
#body
#handle_wrapped_result
}
}
}
fn handle_result(result_name: Ident) -> proc_macro2::TokenStream {
quote! {
match #result_name {
Ok(value) => value,
Err(err) => {
std::panic::resume_unwind(err);
} }
} }
} (true, false) => {
quote! {
use test_context::futures::FutureExt;
let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
let #context_arg_name = &mut __context;
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
<#context_type as test_context::AsyncTestContext>::teardown(__context).await;
}
}
(false, true) => {
quote! {
let mut __context = <#context_type as test_context::TestContext>::setup();
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let #context_arg_name = &mut __context;
#body
}));
}
}
(false, false) => {
quote! {
let mut __context = <#context_type as test_context::TestContext>::setup();
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let #context_arg_name = &mut __context;
#body
}));
<#context_type as test_context::TestContext>::teardown(__context);
}
}
};
let body = quote! {
{
#body
match #result_name {
Ok(value) => value,
Err(err) => {
std::panic::resume_unwind(err);
}
}
}
};
input.block = Box::new(syn::parse2(body).unwrap());
input
} }
fn extract_and_remove_context_arg( fn remove_context_arg(
mut input: syn::ItemFn, mut input: syn::ItemFn,
expected_context_type: syn::Type, expected_context_type: syn::Type,
) -> (syn::ItemFn, Option<syn::Ident>) { ) -> (syn::ItemFn, Option<syn::Ident>) {
@ -154,10 +125,12 @@ fn extract_and_remove_context_arg(
} }
} }
} }
new_args.push(arg.clone()); new_args.push(arg.clone());
} }
input.sig.inputs = new_args; input.sig.inputs = new_args;
(input, context_arg_name) (input, context_arg_name)
} }

View file

@ -1,7 +1,7 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use rstest::rstest; use rstest::rstest;
use test_context::{test_context, AsyncTestContext, TestContext}; use test_context::{AsyncTestContext, TestContext, test_context};
struct Context { struct Context {
n: u32, n: u32,