Added the ability to work with rstest (#51)

* [FIX-CONFLICT-WITH-RSTEST-V2]: semi-stable version, fix test_context function, remove unnecessary field, remove wrapper function, add  ability to work with rstest and same

* [FIX-CONFLICT-WITH-RSTEST-V2]: stable version, fix bugs with AssertUnwindSafe in sync_wrapper_body, remove redundant code, fix two tests, add notes for this tests

* [FIX-CONFLICT-WITH-RSTEST-V2]: stable version, fix context name extraction function, fix readme, fix docs, increase version of crate

---------

Co-authored-by: Vyacheslav Volkov <v.volkov@st-falcon.ru>
This commit is contained in:
Vyacheslav 2025-10-29 12:19:40 +03:00 committed by GitHub
parent 5ce361bc24
commit 9227f3e693
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 208 additions and 56 deletions

View file

@ -4,7 +4,7 @@ members = ["test-context", "test-context-macros"]
[workspace.package]
edition = "2021"
version = "0.4.0"
version = "0.5.0"
rust-version = "1.75.0"
homepage = "https://github.com/JasterV/test-context"
repository = "https://github.com/JasterV/test-context"

View file

@ -118,7 +118,7 @@ enable the optional `tokio-runtime` feature so those steps run inside a Tokio ru
```toml
[dependencies]
test-context = { version = "0.4", features = ["tokio-runtime"] }
test-context = { version = "0.5", features = ["tokio-runtime"] }
```
With this feature, the crate tries to reuse an existing runtime; if none is present, it creates
@ -127,7 +127,7 @@ tests annotated with `#[tokio::test]` continue to work as usual without the feat
## Skipping the teardown execution
If what you need is to take full **ownership** of the context and don't care about the
Also, if you don't care about the
teardown execution for a specific test, you can use the `skip_teardown` keyword on the macro
like this:
@ -144,9 +144,73 @@ like this:
#[test_context(MyContext, skip_teardown)]
#[test]
fn test_without_teardown(ctx: MyContext) {
fn test_without_teardown(ctx: &mut MyContext) {
// Perform any operations that require full ownership of your context
}
```
## ⚠️ Ensure that the context type specified in the macro matches the test function argument type exactly
The error occurs when a context type with an absolute path is mixed with an it's alias.
For example:
```
mod database {
use test_context::TestContext;
pub struct Connection;
impl TestContext for :Connection {
fn setup() -> Self {Connection}
fn teardown(self) {...}
}
}
```
✅The following code will work:
```
use database::Connection as DbConn;
#[test_context(DbConn)]
#[test]
fn test1(ctx: &mut DbConn) {
//some test logic
}
// or
use database::Connection
#[test_context(database::Connection)]
#[test]
fn test1(ctx: &mut database::Connection) {
//some test logic
}
```
❌The following code will not work:
```
use database::Connection as DbConn;
#[test_context(database::Connection)]
#[test]
fn test1(ctx: &mut DbConn) {
//some test logic
}
// or
use database::Connection as DbConn;
#[test_context(DbConn)]
#[test]
fn test1(ctx: &mut database::Connection) {
//some test logic
}
```
Type mismatches will cause context parsing to fail during either static analysis or compilation.
License: MIT

View file

@ -3,7 +3,7 @@ mod args;
use args::TestContextArgs;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::Ident;
use syn::{Block, Ident};
/// Macro to use on tests to add the setup/teardown functionality of your context.
///
@ -30,55 +30,44 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = syn::parse_macro_input!(attr as TestContextArgs);
let input = syn::parse_macro_input!(item as syn::ItemFn);
let ret = &input.sig.output;
let name = &input.sig.ident;
let arguments = &input.sig.inputs;
let body = &input.block;
let attrs = &input.attrs;
let is_async = input.sig.asyncness.is_some();
let wrapped_name = format_ident!("__test_context_wrapped_{}", name);
let (new_input, context_arg_name) =
extract_and_remove_context_arg(input.clone(), args.context_type.clone());
let wrapper_body = if is_async {
async_wrapper_body(args, &wrapped_name)
async_wrapper_body(args, &context_arg_name, &input.block)
} else {
sync_wrapper_body(args, &wrapped_name)
sync_wrapper_body(args, &context_arg_name, &input.block)
};
let async_tag = if is_async {
quote! { async }
} else {
quote! {}
};
let mut result_input = new_input;
result_input.block = Box::new(syn::parse2(wrapper_body).unwrap());
quote! {
#(#attrs)*
#async_tag fn #name() #ret #wrapper_body
#async_tag fn #wrapped_name(#arguments) #ret #body
}
.into()
quote! { #result_input }.into()
}
fn async_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro2::TokenStream {
fn async_wrapper_body(
args: TestContextArgs,
context_arg_name: &Option<syn::Ident>,
body: &Block,
) -> proc_macro2::TokenStream {
let context_type = args.context_type;
let result_name = format_ident!("wrapped_result");
let binding = format_ident!("test_ctx");
let context_name = context_arg_name.as_ref().unwrap_or(&binding);
let body = if args.skip_teardown {
quote! {
let ctx = <#context_type as test_context::AsyncTestContext>::setup().await;
let #result_name = std::panic::AssertUnwindSafe(
#wrapped_name(ctx)
).catch_unwind().await;
let #context_name = <#context_type as test_context::AsyncTestContext>::setup().await;
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
}
} else {
quote! {
let mut ctx = <#context_type as test_context::AsyncTestContext>::setup().await;
let ctx_reference = &mut ctx;
let #result_name = std::panic::AssertUnwindSafe(
#wrapped_name(ctx_reference)
).catch_unwind().await;
<#context_type as test_context::AsyncTestContext>::teardown(ctx).await;
let mut #context_name = <#context_type as test_context::AsyncTestContext>::setup().await;
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
<#context_type as test_context::AsyncTestContext>::teardown(#context_name).await;
}
};
@ -93,25 +82,32 @@ fn async_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro
}
}
fn sync_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro2::TokenStream {
fn sync_wrapper_body(
args: TestContextArgs,
context_arg_name: &Option<syn::Ident>,
body: &Block,
) -> proc_macro2::TokenStream {
let context_type = args.context_type;
let result_name = format_ident!("wrapped_result");
let binding = format_ident!("test_ctx");
let context_name = context_arg_name.as_ref().unwrap_or(&binding);
let body = if args.skip_teardown {
quote! {
let ctx = <#context_type as test_context::TestContext>::setup();
let #result_name = std::panic::catch_unwind(move || {
#wrapped_name(ctx)
});
let mut #context_name = <#context_type as test_context::TestContext>::setup();
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let #context_name = &mut #context_name;
#body
}));
}
} else {
quote! {
let mut ctx = <#context_type as test_context::TestContext>::setup();
let mut pointer = std::panic::AssertUnwindSafe(&mut ctx);
let #result_name = std::panic::catch_unwind(move || {
#wrapped_name(*pointer)
});
<#context_type as test_context::TestContext>::teardown(ctx);
let mut #context_name = <#context_type as test_context::TestContext>::setup();
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
#body
}));
<#context_type as test_context::TestContext>::teardown(#context_name);
}
};
@ -135,3 +131,40 @@ fn handle_result(result_name: Ident) -> proc_macro2::TokenStream {
}
}
}
fn extract_and_remove_context_arg(
mut input: syn::ItemFn,
expected_context_type: syn::Type,
) -> (syn::ItemFn, Option<syn::Ident>) {
let mut context_arg_name = None;
let mut new_args = syn::punctuated::Punctuated::new();
for arg in &input.sig.inputs {
// Extract function arg:
if let syn::FnArg::Typed(pat_type) = arg {
// Extract arg identifier:
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
// Check that context arg is only ref or mutable ref:
if let syn::Type::Reference(type_ref) = &*pat_type.ty {
// Check that context has expected type:
if types_equal(&type_ref.elem, &expected_context_type) {
context_arg_name = Some(pat_ident.ident.clone());
continue;
}
}
}
}
new_args.push(arg.clone());
}
input.sig.inputs = new_args;
(input, context_arg_name)
}
fn types_equal(a: &syn::Type, b: &syn::Type) -> bool {
if let (syn::Type::Path(a_path), syn::Type::Path(b_path)) = (a, b) {
return a_path.path.segments.last().unwrap().ident
== b_path.path.segments.last().unwrap().ident;
}
quote!(#a).to_string() == quote!(#b).to_string()
}

View file

@ -13,8 +13,9 @@ authors.workspace = true
license.workspace = true
[dependencies]
test-context-macros = { version = "0.4.0", path = "../test-context-macros/" }
test-context-macros = { version = "0.5.0", path = "../test-context-macros/" }
futures = "0.3"
[dev-dependencies]
rstest = "0.26.1"
tokio = { version = "1.0", features = ["macros", "rt"] }

View file

@ -103,9 +103,8 @@
//!
//! # Skipping the teardown execution
//!
//! If what you need is to take full __ownership__ of the context and don't care about the
//! teardown execution for a specific test, you can use the `skip_teardown` keyword on the macro
//! like this:
//! Also, if you don't care about the teardown execution for a specific test,
//! you can use the `skip_teardown` keyword on the macro like this:
//!
//! ```no_run
//! use test_context::{test_context, TestContext};
@ -120,7 +119,7 @@
//!
//! #[test_context(MyContext, skip_teardown)]
//! #[test]
//! fn test_without_teardown(ctx: MyContext) {
//! fn test_without_teardown(ctx: &mut MyContext) {
//! // Perform any operations that require full ownership of your context
//! }
//! ```

View file

@ -1,5 +1,6 @@
use std::marker::PhantomData;
use rstest::rstest;
use test_context::{test_context, AsyncTestContext, TestContext};
struct Context {
@ -165,11 +166,11 @@ impl AsyncTestContext for TeardownPanicContext {
#[test_context(TeardownPanicContext, skip_teardown)]
#[tokio::test]
async fn test_async_skip_teardown(mut _ctx: TeardownPanicContext) {}
async fn test_async_skip_teardown(_ctx: &mut TeardownPanicContext) {}
#[test_context(TeardownPanicContext, skip_teardown)]
#[test]
fn test_sync_skip_teardown(mut _ctx: TeardownPanicContext) {}
fn test_sync_skip_teardown(_ctx: &mut TeardownPanicContext) {}
struct GenericContext<T> {
contents: T,
@ -209,6 +210,60 @@ fn test_generic_with_string(ctx: &mut GenericContext<String>) {
#[test_context(GenericContext<u64>)]
#[tokio::test]
async fn test_async_generic(ctx: &mut GenericContext<u64>) {
assert_eq!(ctx.contents, 1);
async fn test_async_generic(test_ctx: &mut GenericContext<u64>) {
assert_eq!(test_ctx.contents, 1);
}
struct MyAsyncContext {
what_the_of_life: u32,
}
impl AsyncTestContext for MyAsyncContext {
async fn setup() -> Self {
println!("I guess...");
MyAsyncContext {
what_the_of_life: 42,
}
}
async fn teardown(self) {
println!("Answer is {}", self.what_the_of_life);
drop(self);
}
}
#[test_context(MyAsyncContext)]
#[rstest]
#[case("Hello, World!")]
#[tokio::test]
async fn test_async_generic_with_sync(#[case] value: String, test_ctx: &mut MyAsyncContext) {
println!("Something happens sync... {}", value);
assert_eq!(test_ctx.what_the_of_life, 42);
}
struct MyContext {
what_the_of_life: u32,
}
impl TestContext for MyContext {
fn setup() -> Self {
println!("I guess...");
MyContext {
what_the_of_life: 42,
}
}
fn teardown(self) {
println!("Answer is {}", self.what_the_of_life);
drop(self);
}
}
#[test_context(MyContext)]
#[rstest]
#[case("Hello, World!")]
#[test]
fn test_async_generic_with_async(test_ctx: &mut MyContext, #[case] value: String) {
println!("Something happens async... {}", value);
assert_eq!(test_ctx.what_the_of_life, 42);
}