From dc2992cbff7251b3348fa3602b6a41a0d26d4448 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Mart=C3=ADnez?= <49537445+JasterV@users.noreply.github.com> Date: Tue, 27 Feb 2024 23:26:40 +0100 Subject: [PATCH] feat: add support for the skip_teardown keyword (#40) * feat: add support for the skip_teardown keyword --- README.md | 24 ++++++++ test-context-macros/src/args.rs | 43 +++++++++++++++ test-context-macros/src/lib.rs | 97 +++++++++++++++++++++++---------- test-context/src/lib.rs | 24 ++++++++ test-context/tests/test.rs | 20 +++++++ 5 files changed, 179 insertions(+), 29 deletions(-) create mode 100644 test-context-macros/src/args.rs diff --git a/README.md b/README.md index 42bc36e..b31e05a 100644 --- a/README.md +++ b/README.md @@ -69,4 +69,28 @@ async fn test_works(ctx: &mut MyAsyncContext) { } ``` +## Skipping the teardown execution + +If what you need is to take full **ownership** of the context and don't care about the +teardown execution for a specific test, you can use the `skip_teardown` keyword on the macro +like this: + +```rust + use test_context::{test_context, TestContext}; + + struct MyContext {} + + impl TestContext for MyContext { + fn setup() -> MyContext { + MyContext {} + } + } + +#[test_context(MyContext, skip_teardown)] +#[test] +fn test_without_teardown(ctx: MyContext) { + // Perform any operations that require full ownership of your context +} +``` + License: MIT diff --git a/test-context-macros/src/args.rs b/test-context-macros/src/args.rs new file mode 100644 index 0000000..7124bae --- /dev/null +++ b/test-context-macros/src/args.rs @@ -0,0 +1,43 @@ +use syn::{parse::Parse, Ident, Token}; + +pub(crate) struct TestContextArgs { + pub(crate) context_type: Ident, + pub(crate) skip_teardown: bool, +} + +impl Parse for TestContextArgs { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut skip_teardown = false; + let mut context_type: Option = None; + + while !input.is_empty() { + let lookahead = input.lookahead1(); + if lookahead.peek(kw::skip_teardown) { + if skip_teardown { + return Err(input.error("expected only a single `skip_teardown` argument")); + } + let _ = input.parse::()?; + skip_teardown = true; + } else if lookahead.peek(Ident) { + if context_type.is_some() { + return Err(input.error("expected only a single type identifier")); + } + context_type = Some(input.parse()?); + } else if lookahead.peek(Token![,]) { + let _ = input.parse::()?; + } else { + return Err(lookahead.error()); + } + } + + Ok(TestContextArgs { + context_type: context_type + .ok_or(input.error("expected at least one type identifier"))?, + skip_teardown, + }) + } +} + +mod kw { + syn::custom_keyword!(skip_teardown); +} diff --git a/test-context-macros/src/lib.rs b/test-context-macros/src/lib.rs index 28ca235..1908826 100644 --- a/test-context-macros/src/lib.rs +++ b/test-context-macros/src/lib.rs @@ -1,3 +1,6 @@ +mod args; + +use args::TestContextArgs; use proc_macro::TokenStream; use quote::{format_ident, quote}; use syn::Ident; @@ -24,7 +27,8 @@ use syn::Ident; /// ``` #[proc_macro_attribute] pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream { - let context_type = syn::parse_macro_input!(attr as syn::Ident); + let args = syn::parse_macro_input!(attr as TestContextArgs); + let input = syn::parse_macro_input!(item as syn::ItemFn); let ret = &input.sig.output; let name = &input.sig.ident; @@ -36,9 +40,9 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream { let wrapped_name = format_ident!("__test_context_wrapped_{}", name); let wrapper_body = if is_async { - async_wrapper_body(context_type, &wrapped_name) + async_wrapper_body(args, &wrapped_name) } else { - sync_wrapper_body(context_type, &wrapped_name) + sync_wrapper_body(args, &wrapped_name) }; let async_tag = if is_async { @@ -56,42 +60,77 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream { .into() } -fn async_wrapper_body(context_type: Ident, wrapped_name: &Ident) -> proc_macro2::TokenStream { +fn async_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro2::TokenStream { + let context_type = args.context_type; + let result_name = format_ident!("wrapped_result"); + + let body = if args.skip_teardown { + quote! { + let ctx = <#context_type as test_context::AsyncTestContext>::setup().await; + let #result_name = std::panic::AssertUnwindSafe( + #wrapped_name(ctx) + ).catch_unwind().await; + } + } else { + quote! { + let mut ctx = <#context_type as test_context::AsyncTestContext>::setup().await; + let ctx_reference = &mut ctx; + let #result_name = std::panic::AssertUnwindSafe( + #wrapped_name(ctx_reference) + ).catch_unwind().await; + <#context_type as test_context::AsyncTestContext>::teardown(ctx).await; + } + }; + + let handle_wrapped_result = handle_result(result_name); + quote! { { use test_context::futures::FutureExt; - let mut ctx = <#context_type as test_context::AsyncTestContext>::setup().await; - let wrapped_ctx = &mut ctx; - let result = async move { - std::panic::AssertUnwindSafe( - #wrapped_name(wrapped_ctx) - ).catch_unwind().await - }.await; - <#context_type as test_context::AsyncTestContext>::teardown(ctx).await; - match result { - Ok(returned_value) => returned_value, - Err(err) => { - std::panic::resume_unwind(err); - } - } + #body + #handle_wrapped_result } } } -fn sync_wrapper_body(context_type: Ident, wrapped_name: &Ident) -> proc_macro2::TokenStream { - quote! { - { +fn sync_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro2::TokenStream { + let context_type = args.context_type; + let result_name = format_ident!("wrapped_result"); + + let body = if args.skip_teardown { + quote! { + let ctx = <#context_type as test_context::TestContext>::setup(); + let #result_name = std::panic::catch_unwind(move || { + #wrapped_name(ctx) + }); + } + } else { + quote! { let mut ctx = <#context_type as test_context::TestContext>::setup(); - let mut wrapper = std::panic::AssertUnwindSafe(&mut ctx); - let result = std::panic::catch_unwind(move || { - #wrapped_name(*wrapper) + let mut pointer = std::panic::AssertUnwindSafe(&mut ctx); + let #result_name = std::panic::catch_unwind(move || { + #wrapped_name(*pointer) }); <#context_type as test_context::TestContext>::teardown(ctx); - match result { - Ok(returned_value) => returned_value, - Err(err) => { - std::panic::resume_unwind(err); - } + } + }; + + let handle_wrapped_result = handle_result(result_name); + + quote! { + { + #body + #handle_wrapped_result + } + } +} + +fn handle_result(result_name: Ident) -> proc_macro2::TokenStream { + quote! { + match #result_name { + Ok(value) => value, + Err(err) => { + std::panic::resume_unwind(err); } } } diff --git a/test-context/src/lib.rs b/test-context/src/lib.rs index 3ea5692..08d01cd 100644 --- a/test-context/src/lib.rs +++ b/test-context/src/lib.rs @@ -77,6 +77,30 @@ //! assert_eq!(ctx.value, "Hello, World!"); //! } //! ``` +//! +//! # Skipping the teardown execution +//! +//! If what you need is to take full __ownership__ of the context and don't care about the +//! teardown execution for a specific test, you can use the `skip_teardown` keyword on the macro +//! like this: +//! +//! ```no_run +//! use test_context::{test_context, TestContext}; +//! +//! struct MyContext {} +//! +//! impl TestContext for MyContext { +//! fn setup() -> MyContext { +//! MyContext {} +//! } +//! } +//! +//! #[test_context(MyContext, skip_teardown)] +//! #[test] +//! fn test_without_teardown(ctx: MyContext) { +//! // Perform any operations that require full ownership of your context +//! } +//! ``` // Reimported to allow for use in the macro. pub use futures; diff --git a/test-context/tests/test.rs b/test-context/tests/test.rs index 5319a2f..a0c2789 100644 --- a/test-context/tests/test.rs +++ b/test-context/tests/test.rs @@ -111,3 +111,23 @@ fn use_different_name(test_data: &mut Context) { async fn use_different_name_async(test_data: &mut AsyncContext) { assert_eq!(test_data.n, 1); } + +struct TeardownPanicContext {} + +impl AsyncTestContext for TeardownPanicContext { + async fn setup() -> Self { + Self {} + } + + async fn teardown(self) { + panic!("boom!"); + } +} + +#[test_context(TeardownPanicContext, skip_teardown)] +#[tokio::test] +async fn test_async_skip_teardown(mut _ctx: TeardownPanicContext) {} + +#[test_context(TeardownPanicContext, skip_teardown)] +#[test] +fn test_sync_skip_teardown(mut _ctx: TeardownPanicContext) {}