feat: add support for the skip_teardown keyword

This commit is contained in:
Victor Martinez 2024-02-26 18:22:18 +01:00
parent 8fd7841baa
commit cd3d9bf943
3 changed files with 134 additions and 28 deletions

View file

@ -0,0 +1,43 @@
use syn::{parse::Parse, Ident, Token};
pub(crate) struct TestContextArgs {
pub(crate) context_type: Ident,
pub(crate) skip_teardown: bool,
}
impl Parse for TestContextArgs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut skip_teardown = false;
let mut context_type: Option<Ident> = None;
while !input.is_empty() {
let lookahead = input.lookahead1();
if lookahead.peek(kw::skip_teardown) {
if skip_teardown {
return Err(input.error("expected only a single `skip_teardown` argument"));
}
let _ = input.parse::<kw::skip_teardown>()?;
skip_teardown = true;
} else if lookahead.peek(Ident) {
if context_type.is_some() {
return Err(input.error("expected only a single type identifier"));
}
context_type = Some(input.parse()?);
} else if lookahead.peek(Token![,]) {
let _ = input.parse::<Token![,]>()?;
} else {
return Err(lookahead.error());
}
}
Ok(TestContextArgs {
context_type: context_type
.ok_or(input.error("expected at least one type identifier"))?,
skip_teardown,
})
}
}
mod kw {
syn::custom_keyword!(skip_teardown);
}

View file

@ -1,3 +1,6 @@
mod args;
use args::TestContextArgs;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::Ident;
@ -24,7 +27,8 @@ use syn::Ident;
/// ```
#[proc_macro_attribute]
pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
let context_type = syn::parse_macro_input!(attr as syn::Ident);
let args = syn::parse_macro_input!(attr as TestContextArgs);
let input = syn::parse_macro_input!(item as syn::ItemFn);
let ret = &input.sig.output;
let name = &input.sig.ident;
@ -36,9 +40,9 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
let wrapped_name = format_ident!("__test_context_wrapped_{}", name);
let wrapper_body = if is_async {
async_wrapper_body(context_type, &wrapped_name)
async_wrapper_body(args, &wrapped_name)
} else {
sync_wrapper_body(context_type, &wrapped_name)
sync_wrapper_body(args, &wrapped_name)
};
let async_tag = if is_async {
@ -56,42 +60,81 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
.into()
}
fn async_wrapper_body(context_type: Ident, wrapped_name: &Ident) -> proc_macro2::TokenStream {
quote! {
{
use test_context::futures::FutureExt;
let mut ctx = <#context_type as test_context::AsyncTestContext>::setup().await;
let wrapped_ctx = &mut ctx;
let result = async move {
fn async_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro2::TokenStream {
let context_type = args.context_type;
let result_name = format_ident!("wrapped_result");
let body = if args.skip_teardown {
quote! {
let ctx = <#context_type as test_context::AsyncTestContext>::setup().await;
let #result_name = async move {
std::panic::AssertUnwindSafe(
#wrapped_name(wrapped_ctx)
#wrapped_name(ctx)
).catch_unwind().await
}.await;
}
} else {
quote! {
let mut ctx = <#context_type as test_context::AsyncTestContext>::setup().await;
let ctx_reference = &mut ctx;
let #result_name = async move {
std::panic::AssertUnwindSafe(
#wrapped_name(ctx_reference)
).catch_unwind().await
}.await;
<#context_type as test_context::AsyncTestContext>::teardown(ctx).await;
match result {
Ok(returned_value) => returned_value,
Err(err) => {
std::panic::resume_unwind(err);
}
}
}
};
let handle_wrapped_result = handle_result(result_name);
quote! {
{
use test_context::futures::FutureExt;
#body
#handle_wrapped_result
}
}
}
fn sync_wrapper_body(context_type: Ident, wrapped_name: &Ident) -> proc_macro2::TokenStream {
quote! {
{
fn sync_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro2::TokenStream {
let context_type = args.context_type;
let result_name = format_ident!("wrapped_result");
let body = if args.skip_teardown {
quote! {
let ctx = <#context_type as test_context::TestContext>::setup();
let #result_name = std::panic::catch_unwind(move || {
#wrapped_name(ctx)
});
}
} else {
quote! {
let mut ctx = <#context_type as test_context::TestContext>::setup();
let mut wrapper = std::panic::AssertUnwindSafe(&mut ctx);
let result = std::panic::catch_unwind(move || {
#wrapped_name(*wrapper)
let mut pointer = std::panic::AssertUnwindSafe(&mut ctx);
let #result_name = std::panic::catch_unwind(move || {
#wrapped_name(*pointer)
});
<#context_type as test_context::TestContext>::teardown(ctx);
match result {
Ok(returned_value) => returned_value,
Err(err) => {
std::panic::resume_unwind(err);
}
}
};
let handle_wrapped_result = handle_result(result_name);
quote! {
{
#body
#handle_wrapped_result
}
}
}
fn handle_result(result_name: Ident) -> proc_macro2::TokenStream {
quote! {
match #result_name {
Ok(value) => value,
Err(err) => {
std::panic::resume_unwind(err);
}
}
}

View file

@ -111,3 +111,23 @@ fn use_different_name(test_data: &mut Context) {
async fn use_different_name_async(test_data: &mut AsyncContext) {
assert_eq!(test_data.n, 1);
}
struct TeardownPanicContext {}
impl AsyncTestContext for TeardownPanicContext {
async fn setup() -> Self {
Self {}
}
async fn teardown(self) {
panic!("boom!");
}
}
#[test_context(TeardownPanicContext, skip_teardown)]
#[tokio::test]
async fn test_async_skip_teardown(mut _ctx: TeardownPanicContext) {}
#[test_context(TeardownPanicContext, skip_teardown)]
#[test]
fn test_sync_skip_teardown(mut _ctx: TeardownPanicContext) {}