refactor: internal code to make it more understandable

This commit is contained in:
JasterV 2025-11-06 13:37:44 +01:00
parent 059eb65276
commit 36a23a499c
3 changed files with 152 additions and 143 deletions

View file

@ -1,26 +1,11 @@
mod args;
mod macro_args;
mod test_args;
use args::TestContextArgs;
use crate::test_args::{ContextArg, ContextArgMode, TestArg};
use macro_args::TestContextArgs;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::punctuated::Punctuated;
#[derive(PartialEq, Eq, Debug)]
enum ContextArgMode {
/// The argument was passed as an owned value (`ContextType`). Only valid with `skip_teardown`.
Owned,
/// The argument was passed as an immutable reference (`&ContextType`).
Reference,
/// The argument was passed as a mutable reference (`&mut ContextType`).
MutableReference,
}
struct ContextArgInfo {
/// The identifier name used for the context argument.
pub name: syn::Ident,
/// The mode in which the context was passed to the test function.
pub mode: ContextArgMode,
}
use syn::ItemFn;
/// Macro to use on tests to add the setup/teardown functionality of your context.
///
@ -47,145 +32,108 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = syn::parse_macro_input!(attr as TestContextArgs);
let input = syn::parse_macro_input!(item as syn::ItemFn);
let (input, context_arg_info) =
remove_context_arg(input, args.context_type.clone(), args.skip_teardown);
let input = refactor_input_body(input, &args, context_arg_info);
quote! { #input }.into()
}
fn remove_context_arg(
mut input: syn::ItemFn,
expected_context_type: syn::Type,
skip_teardown: bool,
) -> (syn::ItemFn, ContextArgInfo) {
// 1. Partition the function arguments into two groups:
// (Context arguments, Other arguments)
let (context_args, new_args) = input
.sig
.inputs
.into_iter()
.partition::<Punctuated<_, _>, _>(|arg| {
// Check if the argument is the context argument
if let syn::FnArg::Typed(pat_type) = arg {
if let syn::Pat::Ident(_) = &*pat_type.pat {
let arg_type = &*pat_type.ty;
// Check for mutable/immutable reference
if let syn::Type::Reference(type_ref) = arg_type {
return types_equal(&type_ref.elem, &expected_context_type);
}
// If skip_teardown is true, we also consider the fact
// that the context type could be fully owned and not just a reference
if skip_teardown && types_equal(arg_type, &expected_context_type) {
return true;
} else if types_equal(arg_type, &expected_context_type) {
panic!("If skip_teardown is false, we can't use an owned type")
} else {
return false;
}
}
}
false
});
let (input, context_args) = remove_context_args(input, args.context_type.clone());
if context_args.len() != 1 {
panic!("Exactly one Context argument needs to be provided to the test");
panic!("Exactly one Context argument must be defined");
}
let context_arg = context_args.into_iter().next().unwrap();
if !args.skip_teardown && context_arg.mode == ContextArgMode::Owned {
panic!(
"It is not possible to take ownership of the context if the teardown has to be ran."
);
}
let input = refactor_input_body(input, &args, context_arg);
quote! { #input }.into()
}
fn remove_context_args(
mut input: syn::ItemFn,
expected_context_type: syn::Type,
) -> (syn::ItemFn, Vec<ContextArg>) {
let test_args: Vec<TestArg> = input
.sig
.inputs
.into_iter()
.map(|arg| TestArg::parse_arg_with_expected_context(arg, &expected_context_type))
.collect();
let context_args: Vec<ContextArg> = test_args
.iter()
.cloned()
.filter_map(|arg| match arg {
TestArg::Any(_) => None,
TestArg::Context(context_arg_info) => Some(context_arg_info),
})
.collect();
let new_args: syn::punctuated::Punctuated<_, _> = test_args
.into_iter()
.filter_map(|arg| match arg {
TestArg::Any(fn_arg) => Some(fn_arg),
TestArg::Context(_) => None,
})
.collect();
input.sig.inputs = new_args;
// 2. Extract the identifier and mode from the single context argument found (if any).
let context_arg_info = if let syn::FnArg::Typed(pat_type) = context_arg
&& let syn::Pat::Ident(pat_ident) = *pat_type.pat
{
let arg_type = &*pat_type.ty;
let mode = if let syn::Type::Reference(type_ref) = arg_type {
if type_ref.mutability.is_some() {
ContextArgMode::MutableReference
} else {
ContextArgMode::Reference
}
} else {
ContextArgMode::Owned
};
ContextArgInfo {
name: pat_ident.ident,
mode,
}
} else {
panic!("Invalid context argument provided, it must be a reference or an owned type");
};
(input, context_arg_info)
(input, context_args)
}
fn refactor_input_body(
mut input: syn::ItemFn,
input: syn::ItemFn,
args: &TestContextArgs,
context_arg_info: ContextArgInfo,
context_arg: ContextArg,
) -> syn::ItemFn {
let context_type = &args.context_type;
let result_name = format_ident!("wrapped_result");
let body = &input.block;
let is_async = input.sig.asyncness.is_some();
let context_arg_name = context_arg.name;
// Determine the identifier and its mode. Default to "test_ctx" and MutableReference.
let (context_arg_name, context_mode) = (context_arg_info.name, context_arg_info.mode);
let context_binding = match context_mode {
let context_binding = match context_arg.mode {
ContextArgMode::Owned => quote! { let #context_arg_name = __context; },
ContextArgMode::Reference => quote! { let #context_arg_name = &__context; },
ContextArgMode::MutableReference => quote! { let #context_arg_name = &mut __context; },
};
let body = match (is_async, args.skip_teardown) {
// ASYNC and SKIP_TEARDOWN
(true, true) => {
quote! {
use test_context::futures::FutureExt;
let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
let body = if args.skip_teardown && is_async {
quote! {
use test_context::futures::FutureExt;
let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
#context_binding
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
}
} else if args.skip_teardown && !is_async {
quote! {
let mut __context = <#context_type as test_context::TestContext>::setup();
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
#context_binding
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
}
#body
}));
}
// SYNC and SKIP_TEARDOWN
(false, true) => {
quote! {
let mut __context = <#context_type as test_context::TestContext>::setup();
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
#context_binding
#body
}));
}
} else if !args.skip_teardown && is_async {
quote! {
use test_context::futures::FutureExt;
let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
#context_binding
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
<#context_type as test_context::AsyncTestContext>::teardown(__context).await;
}
// ASYNC and TEARDOWN (Teardown requires context ownership, so the test body must use &mut)
(true, false) => {
quote! {
use test_context::futures::FutureExt;
let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
// MUST bind as &mut regardless of user's original signature to allow teardown
}
// !args.skip_teardown && !is_async
else {
quote! {
let mut __context = <#context_type as test_context::TestContext>::setup();
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
#context_binding
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
<#context_type as test_context::AsyncTestContext>::teardown(__context).await;
}
}
// SYNC and TEARDOWN (Teardown requires context ownership, so the test body must use &mut)
(false, false) => {
quote! {
let mut __context = <#context_type as test_context::TestContext>::setup();
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
#context_binding
#body
}));
<#context_type as test_context::TestContext>::teardown(__context);
}
#body
}));
<#context_type as test_context::TestContext>::teardown(__context);
}
};
@ -201,16 +149,8 @@ fn refactor_input_body(
}
};
input.block = Box::new(syn::parse2(body).unwrap());
input
}
// Note: The rest of the functions (test_context, refactor_input_body, types_equal) remain unchanged.
fn types_equal(a: &syn::Type, b: &syn::Type) -> bool {
if let (syn::Type::Path(a_path), syn::Type::Path(b_path)) = (a, b) {
return a_path.path.segments.last().unwrap().ident
== b_path.path.segments.last().unwrap().ident;
ItemFn {
block: Box::new(syn::parse2(body).unwrap()),
..input
}
quote!(#a).to_string() == quote!(#b).to_string()
}

View file

@ -0,0 +1,69 @@
use quote::quote;
use syn::FnArg;
#[derive(Clone)]
pub struct ContextArg {
/// The identifier name used for the context argument.
pub name: syn::Ident,
/// The mode in which the context was passed to the test function.
pub mode: ContextArgMode,
}
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub enum ContextArgMode {
/// The argument was passed as an owned value (`ContextType`). Only valid with `skip_teardown`.
Owned,
/// The argument was passed as an immutable reference (`&ContextType`).
Reference,
/// The argument was passed as a mutable reference (`&mut ContextType`).
MutableReference,
}
#[derive(Clone)]
pub enum TestArg {
Any(FnArg),
Context(ContextArg),
}
impl TestArg {
pub fn parse_arg_with_expected_context(arg: FnArg, expected_context_type: &syn::Type) -> Self {
// Check if the argument is the context argument
if let syn::FnArg::Typed(pat_type) = &arg
&& let syn::Pat::Ident(pat_ident) = &*pat_type.pat
{
let arg_type = &*pat_type.ty;
// Check for mutable/immutable reference
if let syn::Type::Reference(type_ref) = arg_type
&& types_equal(&type_ref.elem, expected_context_type)
{
let mode = if type_ref.mutability.is_some() {
ContextArgMode::MutableReference
} else {
ContextArgMode::Reference
};
TestArg::Context(ContextArg {
name: pat_ident.ident.clone(),
mode,
})
} else if types_equal(arg_type, expected_context_type) {
TestArg::Context(ContextArg {
name: pat_ident.ident.clone(),
mode: ContextArgMode::Owned,
})
} else {
TestArg::Any(arg)
}
} else {
TestArg::Any(arg)
}
}
}
fn types_equal(a: &syn::Type, b: &syn::Type) -> bool {
if let (syn::Type::Path(a_path), syn::Type::Path(b_path)) = (a, b) {
return a_path.path.segments.last().unwrap().ident
== b_path.path.segments.last().unwrap().ident;
}
quote!(#a).to_string() == quote!(#b).to_string()
}