From a58e13ecf5fb2f653f050565c25bb196766d5a79 Mon Sep 17 00:00:00 2001 From: Changju Lee Date: Mon, 27 Jan 2025 19:53:39 +0900 Subject: [PATCH] feat: Support generic types in test_context macro (#44) - Allow using test_context with generic contexts like MyContext --- README.md | 36 ++++++++++++++++++++++++++++++ test-context-macros/src/args.rs | 12 +++++----- test-context/tests/test.rs | 39 +++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index b31e05a..e8bf947 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,42 @@ fn test_works(ctx: &mut MyContext) { } ``` +with generic types, you can use same type with different values + +```rust +use test_context::{test_context, TestContext}; +use std::marker::PhantomData; + +struct MyGenericContext { + value: u32, + _marker: PhantomData, +} + +impl TestContext for MyGenericContext { + fn setup() -> MyGenericContext { + MyGenericContext { value: 1, _marker: PhantomData } + } +} + +#[test_context(MyGenericContext)] +#[test] +fn test_generic_type(ctx: &mut MyGenericContext) { + assert_eq!(ctx.value, 1); +} + +impl TestContext for MyGenericContext { + fn setup() -> MyGenericContext { + MyGenericContext { value: 2, _marker: PhantomData } + } +} + +#[test_context(MyGenericContext)] +#[test] +fn test_generic_type_u32(ctx: &mut MyGenericContext) { + assert_eq!(ctx.value, 2); +} +``` + Alternatively, you can use `async` functions in your test context by using the `AsyncTestContext`. diff --git a/test-context-macros/src/args.rs b/test-context-macros/src/args.rs index 7124bae..7a06b63 100644 --- a/test-context-macros/src/args.rs +++ b/test-context-macros/src/args.rs @@ -1,14 +1,14 @@ -use syn::{parse::Parse, Ident, Token}; +use syn::{parse::Parse, Token, Type}; pub(crate) struct TestContextArgs { - pub(crate) context_type: Ident, + pub(crate) context_type: Type, pub(crate) skip_teardown: bool, } impl Parse for TestContextArgs { fn parse(input: syn::parse::ParseStream) -> syn::Result { let mut skip_teardown = false; - let mut context_type: Option = None; + let mut context_type: Option = None; while !input.is_empty() { let lookahead = input.lookahead1(); @@ -18,10 +18,8 @@ impl Parse for TestContextArgs { } let _ = input.parse::()?; skip_teardown = true; - } else if lookahead.peek(Ident) { - if context_type.is_some() { - return Err(input.error("expected only a single type identifier")); - } + } else if context_type.is_none() { + // Parse any valid Rust type, including generic types context_type = Some(input.parse()?); } else if lookahead.peek(Token![,]) { let _ = input.parse::()?; diff --git a/test-context/tests/test.rs b/test-context/tests/test.rs index a0c2789..4cdcd88 100644 --- a/test-context/tests/test.rs +++ b/test-context/tests/test.rs @@ -1,3 +1,5 @@ +use std::marker::PhantomData; + use test_context::{test_context, AsyncTestContext, TestContext}; struct Context { @@ -47,6 +49,43 @@ fn includes_return_value() { assert_eq!(return_value_func(), 1); } +struct ContextGeneric { + n: u32, + _marker: PhantomData, +} + +struct ContextGenericType1; +impl TestContext for ContextGeneric { + fn setup() -> Self { + Self { + n: 1, + _marker: PhantomData, + } + } +} + +#[test_context(ContextGeneric)] +#[test] +fn test_generic_type(ctx: &mut ContextGeneric) { + assert_eq!(ctx.n, 1); +} + +struct ContextGenericType2; +impl TestContext for ContextGeneric { + fn setup() -> Self { + Self { + n: 2, + _marker: PhantomData, + } + } +} + +#[test_context(ContextGeneric)] +#[test] +fn test_generic_type_other(ctx: &mut ContextGeneric) { + assert_eq!(ctx.n, 2); +} + struct AsyncContext { n: u32, }