refactor: clean up the macro implementation

This commit is contained in:
Victor Martinez 2024-02-26 15:56:59 +01:00
parent 890a3232e9
commit 0fed446341
2 changed files with 52 additions and 43 deletions

View file

@ -16,3 +16,4 @@ proc-macro = true
[dependencies] [dependencies]
quote = "1.0.3" quote = "1.0.3"
syn = { version = "^2", features = ["full"] } syn = { version = "^2", features = ["full"] }
proc-macro2 = "1.0.78"

View file

@ -1,5 +1,6 @@
use proc_macro::TokenStream; use proc_macro::TokenStream;
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use syn::Ident;
/// Macro to use on tests to add the setup/teardown functionality of your context. /// Macro to use on tests to add the setup/teardown functionality of your context.
/// ///
@ -25,53 +26,19 @@ use quote::{format_ident, quote};
pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream { pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
let context_type = syn::parse_macro_input!(attr as syn::Ident); let context_type = syn::parse_macro_input!(attr as syn::Ident);
let input = syn::parse_macro_input!(item as syn::ItemFn); let input = syn::parse_macro_input!(item as syn::ItemFn);
let ret = &input.sig.output; let ret = &input.sig.output;
let name = &input.sig.ident; let name = &input.sig.ident;
let arguments = &input.sig.inputs; let arguments = &input.sig.inputs;
let inner_body = &input.block; let body = &input.block;
let attrs = &input.attrs; let attrs = &input.attrs;
let is_async = input.sig.asyncness.is_some(); let is_async = input.sig.asyncness.is_some();
let wrapped_name = format_ident!("__test_context_wrapped_{}", name); let wrapped_name = format_ident!("__test_context_wrapped_{}", name);
let outer_body = if is_async { let wrapper_body = if is_async {
quote! { async_wrapper_body(context_type, &wrapped_name)
{
use test_context::futures::FutureExt;
let mut ctx = <#context_type as test_context::AsyncTestContext>::setup().await;
let wrapped_ctx = &mut ctx;
let result = async move {
std::panic::AssertUnwindSafe(
#wrapped_name(wrapped_ctx)
).catch_unwind().await
}.await;
<#context_type as test_context::AsyncTestContext>::teardown(ctx).await;
match result {
Ok(returned_value) => returned_value,
Err(err) => {
std::panic::resume_unwind(err);
}
}
}
}
} else { } else {
quote! { sync_wrapper_body(context_type, &wrapped_name)
{
let mut ctx = <#context_type as test_context::TestContext>::setup();
let mut wrapper = std::panic::AssertUnwindSafe(&mut ctx);
let result = std::panic::catch_unwind(move || {
#wrapped_name(*wrapper)
});
<#context_type as test_context::TestContext>::teardown(ctx);
match result {
Ok(returned_value) => returned_value,
Err(err) => {
std::panic::resume_unwind(err);
}
}
}
}
}; };
let async_tag = if is_async { let async_tag = if is_async {
@ -80,11 +47,52 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
quote! {} quote! {}
}; };
let result = quote! { quote! {
#(#attrs)* #(#attrs)*
#async_tag fn #name() #ret #outer_body #async_tag fn #name() #ret #wrapper_body
#async_tag fn #wrapped_name(#arguments) #ret #inner_body #async_tag fn #wrapped_name(#arguments) #ret #body
}; }
result.into() .into()
}
fn async_wrapper_body(context_type: Ident, wrapped_name: &Ident) -> proc_macro2::TokenStream {
quote! {
{
use test_context::futures::FutureExt;
let mut ctx = <#context_type as test_context::AsyncTestContext>::setup().await;
let wrapped_ctx = &mut ctx;
let result = async move {
std::panic::AssertUnwindSafe(
#wrapped_name(wrapped_ctx)
).catch_unwind().await
}.await;
<#context_type as test_context::AsyncTestContext>::teardown(ctx).await;
match result {
Ok(returned_value) => returned_value,
Err(err) => {
std::panic::resume_unwind(err);
}
}
}
}
}
fn sync_wrapper_body(context_type: Ident, wrapped_name: &Ident) -> proc_macro2::TokenStream {
quote! {
{
let mut ctx = <#context_type as test_context::TestContext>::setup();
let mut wrapper = std::panic::AssertUnwindSafe(&mut ctx);
let result = std::panic::catch_unwind(move || {
#wrapped_name(*wrapper)
});
<#context_type as test_context::TestContext>::teardown(ctx);
match result {
Ok(returned_value) => returned_value,
Err(err) => {
std::panic::resume_unwind(err);
}
}
}
}
} }