feat: Support generic types in test_context macro

- Allow using test_context with generic contexts like MyContext<T>
This commit is contained in:
Changju 2025-01-14 12:08:47 +09:00
parent 7cab7279b4
commit dcf2bd7b46
3 changed files with 80 additions and 7 deletions

View file

@ -31,6 +31,42 @@ fn test_works(ctx: &mut MyContext) {
}
```
with generic types, you can use same type with different values
```rust
use test_context::{test_context, TestContext};
use std::marker::PhantomData;
struct MyGenericContext<T> {
value: u32,
_marker: PhantomData<T>,
}
impl TestContext for MyGenericContext<String> {
fn setup() -> MyGenericContext<String> {
MyGenericContext { value: 1, _marker: PhantomData }
}
}
#[test_context(MyGenericContext<String>)]
#[test]
fn test_generic_type(ctx: &mut MyGenericContext<String>) {
assert_eq!(ctx.value, 1);
}
impl TestContext for MyGenericContext<u32> {
fn setup() -> MyGenericContext<u32> {
MyGenericContext { value: 2, _marker: PhantomData }
}
}
#[test_context(MyGenericContext<u32>)]
#[test]
fn test_generic_type_u32(ctx: &mut MyGenericContext<u32>) {
assert_eq!(ctx.value, 2);
}
```
Alternatively, you can use `async` functions in your test context by using the
`AsyncTestContext`.

View file

@ -1,14 +1,14 @@
use syn::{parse::Parse, Ident, Token};
use syn::{parse::Parse, Token, Type};
pub(crate) struct TestContextArgs {
pub(crate) context_type: Ident,
pub(crate) context_type: Type,
pub(crate) skip_teardown: bool,
}
impl Parse for TestContextArgs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut skip_teardown = false;
let mut context_type: Option<Ident> = None;
let mut context_type: Option<Type> = None;
while !input.is_empty() {
let lookahead = input.lookahead1();
@ -18,10 +18,8 @@ impl Parse for TestContextArgs {
}
let _ = input.parse::<kw::skip_teardown>()?;
skip_teardown = true;
} else if lookahead.peek(Ident) {
if context_type.is_some() {
return Err(input.error("expected only a single type identifier"));
}
} else if context_type.is_none() {
// Parse any valid Rust type, including generic types
context_type = Some(input.parse()?);
} else if lookahead.peek(Token![,]) {
let _ = input.parse::<Token![,]>()?;

View file

@ -1,3 +1,5 @@
use std::marker::PhantomData;
use test_context::{test_context, AsyncTestContext, TestContext};
struct Context {
@ -47,6 +49,43 @@ fn includes_return_value() {
assert_eq!(return_value_func(), 1);
}
struct ContextGeneric<T> {
n: u32,
_marker: PhantomData<T>,
}
struct ContextGenericType1;
impl TestContext for ContextGeneric<ContextGenericType1> {
fn setup() -> Self {
Self {
n: 1,
_marker: PhantomData,
}
}
}
#[test_context(ContextGeneric<ContextGenericType1>)]
#[test]
fn test_generic_type(ctx: &mut ContextGeneric<ContextGenericType1>) {
assert_eq!(ctx.n, 1);
}
struct ContextGenericType2;
impl TestContext for ContextGeneric<ContextGenericType2> {
fn setup() -> Self {
Self {
n: 2,
_marker: PhantomData,
}
}
}
#[test_context(ContextGeneric<ContextGenericType2>)]
#[test]
fn test_generic_type_other(ctx: &mut ContextGeneric<ContextGenericType2>) {
assert_eq!(ctx.n, 2);
}
struct AsyncContext {
n: u32,
}