I've found the following solution to create a macro that defines a function which returns true if an enum matches a variant:

macro_rules! is_variant {
    ($name: ident, $enum_type: ty, $enum_pattern: pat) => {
        fn $name(value: &$enum_type) -> bool {
            matches!(value, $enum_pattern)
        }
    }
}

Usage:

enum TestEnum {
    A,
    B(),
    C(i32, i32),
}

is_variant!(is_a, TestEnum, TestEnum::A);
is_variant!(is_b, TestEnum, TestEnum::B());
is_variant!(is_c, TestEnum, TestEnum::C(_, _));

fn main() {
    assert!(is_a(&TestEnum::A));
    assert!(!is_a(&TestEnum::B()));
    assert!(!is_a(&TestEnum::C(1, 1)));
}

Is there a way to define this macro so that providing placeholders for the variant data can be avoided?

In other words, change the macro to be able to use it like so:

is_variant!(is_a, TestEnum, TestEnum::A);
is_variant!(is_b, TestEnum, TestEnum::B);
is_variant!(is_c, TestEnum, TestEnum::C);

Using std::mem::discriminant, as described in Compare enums only by variant, not value, doesn't help, since it can only compare two enum instances. In this case there is only a single value and a variant identifier. That question also mentions matching on TestEnum::A(..), but that doesn't work if the variant has no data.
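For reference, the (..) limitation can be avoided without a proc macro: Rust also accepts the braced pattern Path { .. } for unit and tuple variants, not only for struct variants. The minimal sketch below assumes the variant is passed as a ::-separated path of identifiers (e.g. TestEnum::A) instead of a pat, so the macro can append { .. } itself. It still takes the function name explicitly.

```rust
// Declarative macro that needs no placeholders for variant data. It relies on
// the braced pattern `Path { .. }` matching unit, tuple, and struct variants.
macro_rules! is_variant {
    // The variant is taken as `::`-separated identifiers (e.g. `TestEnum::A`)
    // so that `{ .. }` can be pasted after it to form a struct pattern.
    ($name:ident, $enum_type:ty, $($variant:ident)::+) => {
        fn $name(value: &$enum_type) -> bool {
            matches!(value, $($variant)::+ { .. })
        }
    };
}

#[allow(dead_code)] // the tuple fields are never read in this example
enum TestEnum {
    A,
    B(),
    C(i32, i32),
}

is_variant!(is_a, TestEnum, TestEnum::A);
is_variant!(is_b, TestEnum, TestEnum::B);
is_variant!(is_c, TestEnum, TestEnum::C);

fn main() {
    assert!(is_a(&TestEnum::A));
    assert!(!is_a(&TestEnum::B()));
    assert!(is_b(&TestEnum::B()));
    assert!(is_c(&TestEnum::C(1, 1)));
}
```

Clippy may flag { .. } on a unit variant as unneeded, but rustc accepts it, which is what lets one macro arm cover all three variant kinds.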

vallentin: If you move it into impl TestEnum { ... }, change $enum_type from ty to ident, and add use $enum_type::*; before matches!, you will be able to drop the TestEnum:: prefix. However, if you want to "simplify" it further, it will require a proc macro, as Mihir mentioned in the now-deleted answer.
Mihir Luthra: @vallentin, thanks for the edit. However, the functions generated by the proc macro will be named is_A, is_B, etc., because the variant's name is used to name each function. Sadly, I don't know of a way to convert it to lowercase.
vallentin: @Mihir I just updated your answer, so now they are converted to lowercase :)
Mihir Luthra: @vallentin, ah, didn't notice, I will revert my edit. Thanks :)
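A minimal sketch of vallentin's first suggestion, assuming the macro is invoked inside impl TestEnum { ... } and that the generated functions take &self (the receiver is our choice; the comment does not specify it). The glob import brings the variants into scope so the patterns can omit the TestEnum:: prefix, but data placeholders such as C(_, _) are still required.

```rust
// `$enum_type` is an `ident` (not a `ty`) so it can be used as a path
// prefix in `use $enum_type::*;`.
macro_rules! is_variant {
    ($name:ident, $enum_type:ident, $enum_pattern:pat) => {
        fn $name(&self) -> bool {
            // Bring all variants into scope so patterns can omit `TestEnum::`
            use $enum_type::*;
            matches!(self, $enum_pattern)
        }
    };
}

#[allow(dead_code)] // the tuple fields are never read in this example
enum TestEnum {
    A,
    B(),
    C(i32, i32),
}

impl TestEnum {
    // Macro invocations are allowed inside impl blocks, so each call
    // generates a method on TestEnum.
    is_variant!(is_a, TestEnum, A);
    is_variant!(is_b, TestEnum, B());
    is_variant!(is_c, TestEnum, C(_, _));
}

fn main() {
    assert!(TestEnum::A.is_a());
    assert!(!TestEnum::B().is_a());
    assert!(TestEnum::B().is_b());
    assert!(TestEnum::C(1, 1).is_c());
}
```

One caveat with this design: if a variant name in the pattern is misspelled, it no longer resolves to a variant and becomes a catch-all binding, so the generated method would return true for every value (rustc usually emits warnings in that case).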

Answer

You can do that using proc macros. The chapter on macros in the Rust book may help.

Then you can use it like this:

use is_variant_derive::IsVariant;

#[derive(IsVariant)]
enum TestEnum {
    A,
    B(),
    C(i32, i32),
    D { _name: String, _age: i32 },
}

fn main() {
    let x = TestEnum::C(1, 2);
    assert!(x.is_c());

    let x = TestEnum::A;
    assert!(x.is_a());

    let x = TestEnum::B();
    assert!(x.is_b());

    let x = TestEnum::D {_name: "Jane Doe".into(), _age: 30 };
    assert!(x.is_d());
}

To get this behavior, the proc macro crate looks like this:

is_variant_derive/src/lib.rs:

extern crate proc_macro;

use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};

use quote::{format_ident, quote, quote_spanned};
use syn::spanned::Spanned;
use syn::{parse_macro_input, Data, DeriveInput, Error, Fields};

// https://crates.io/crates/convert_case
use convert_case::{Case, Casing};

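macro_rules! derive_error {
    // No trailing semicolon: the macro is used in expression position
    // (`return derive_error!(...)`), so it must expand to an expression.
    ($string: tt) => {
        Error::new(Span::call_site(), $string)
            .to_compile_error()
            .into()
    };
}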

#[proc_macro_derive(IsVariant)]
pub fn derive_is_variant(input: TokenStream) -> TokenStream {
    // See https://doc.servo.org/syn/derive/struct.DeriveInput.html
    let input: DeriveInput = parse_macro_input!(input as DeriveInput);

    // Get the enum's name and its data
    let name = &input.ident;
    let data = &input.data;

    // Accumulates one `is_<variant>` method per variant
    let mut variant_checker_functions = TokenStream2::new();

    // data is of type syn::Data
    // See https://doc.servo.org/syn/enum.Data.html
    match data {
        // Only if data is an enum do we generate methods
        Data::Enum(data_enum) => {
            // data_enum is of type syn::DataEnum
            // https://doc.servo.org/syn/struct.DataEnum.html

            // Iterate over enum variants
            // `variants` is of type `Punctuated`, which implements IntoIterator
            //
            // https://doc.servo.org/syn/punctuated/struct.Punctuated.html
            // https://doc.servo.org/syn/struct.Variant.html
            for variant in &data_enum.variants {

                // Variant's name
                let variant_name = &variant.ident;

                // A variant can have unnamed fields like `Variant(i32, i64)`,
                // named fields like `Variant { x: i32, y: i32 }`,
                // or no fields at all (a unit variant) like `Variant`
                let fields_in_variant = match &variant.fields {
                    Fields::Unnamed(_) => quote_spanned! {variant.span()=> (..) },
                    Fields::Unit => quote_spanned! { variant.span()=> },
                    Fields::Named(_) => quote_spanned! {variant.span()=> {..} },
                };

                // construct an identifier named is_<variant_name> for function name
                // We convert it to snake case using `to_case(Case::Snake)`
                // For example, if variant is `HelloWorld`, it will generate `is_hello_world`
                let mut is_variant_func_name =
                    format_ident!("is_{}", variant_name.to_string().to_case(Case::Snake));
                is_variant_func_name.set_span(variant_name.span());

                // Here we construct the function for the current variant
                variant_checker_functions.extend(quote_spanned! {variant.span()=>
                    fn #is_variant_func_name(&self) -> bool {
                        match self {
                            #name::#variant_name #fields_in_variant => true,
                            _ => false,
                        }
                    }
                });

                // Above, we build up a TokenStream with extend().
                // This works because proc_macro2::TokenStream implements
                // the Extend trait, so we can keep appending tokens to it.
                //
                // proc_macro2::TokenStream:- https://docs.rs/proc-macro2/1.0.24/proc_macro2/struct.TokenStream.html

                // Read about
                // quote:- https://docs.rs/quote/1.0.7/quote/
                // quote_spanned:- https://docs.rs/quote/1.0.7/quote/macro.quote_spanned.html
                // spans:- https://docs.rs/syn/1.0.54/syn/spanned/index.html
            }
        }
        _ => return derive_error!("IsVariant is only implemented for enums"),
    };

    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    let expanded = quote! {
        impl #impl_generics #name #ty_generics #where_clause {
            // variant_checker_functions gets replaced by all the functions
            // that were constructed above
            #variant_checker_functions
        }
    };

    TokenStream::from(expanded)
}
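To make the generated code concrete, here is a hand-written approximation of the impl block that #[derive(IsVariant)] produces for TestEnum, written without the proc macro so it compiles on its own. It is a sketch of the expansion, not the macro's literal output (spans and formatting differ).

```rust
// Hand-written approximation of the `#[derive(IsVariant)]` expansion for `TestEnum`.
#[allow(dead_code)] // the data fields are never read in this example
enum TestEnum {
    A,
    B(),
    C(i32, i32),
    D { _name: String, _age: i32 },
}

impl TestEnum {
    // Unit variant (`Fields::Unit`): no field pattern is appended
    fn is_a(&self) -> bool {
        match self {
            TestEnum::A => true,
            _ => false,
        }
    }

    // Tuple variants (`Fields::Unnamed`): `(..)` ignores all positional
    // fields, including the zero fields of `B()`
    fn is_b(&self) -> bool {
        match self {
            TestEnum::B(..) => true,
            _ => false,
        }
    }

    fn is_c(&self) -> bool {
        match self {
            TestEnum::C(..) => true,
            _ => false,
        }
    }

    // Struct variant (`Fields::Named`): `{ .. }` ignores all named fields
    fn is_d(&self) -> bool {
        match self {
            TestEnum::D { .. } => true,
            _ => false,
        }
    }
}

fn main() {
    let x = TestEnum::D { _name: "Jane Doe".into(), _age: 30 };
    assert!(x.is_d());
    assert!(!x.is_a() && !x.is_b() && !x.is_c());
}
```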

Cargo.toml for the library named is_variant_derive:

[lib]
proc-macro = true

[dependencies]
syn = "1.0"
quote = "1.0"
proc-macro2 = "1.0"
convert_case = "0.4.0"

Cargo.toml for the binary:

[dependencies]
is_variant_derive = { path = "../is_variant_derive" }

Put both crates in the same directory (a workspace), with this Cargo.toml at the top level:

[workspace]
members = [
    "bin",
    "is_variant_derive",
]

Also note that a proc macro has to live in its own separate crate, which is why is_variant_derive is a library with proc-macro = true.


Alternatively, you can use the is_variant crate directly.