0
votes

I am trying to write a Rust derive macro for retrieving data from a nested struct. The struct contains only the primitive types u8, i8, u16, i16, u32, i32, u64 and i64, or other structs built from those types. I have an enum, which I call Item, that wraps the leaf field data in a common type. I want the macro to generate a .get() method that returns an Item for a given u16 index.

Here is the desired behavior.

/// A single leaf value read out of a struct that derives `GetItem`, tagged with its
/// primitive type.
///
/// `Debug` and `PartialEq` are required by comparisons such as
/// `assert_eq!(data.get(0).unwrap(), Item::I32(42))`. `Clone`, `Copy`, `Eq` and `Hash`
/// are also derived because every payload is a plain integer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Item {
    U8(u8),
    I8(i8),
    U16(u16),
    I16(i16),
    U32(u32),
    I32(i32),
    U64(u64),
    I64(i64),
}

struct NestedData {
    a: u16,
    b: i32,
}

/// Example outer struct.
///
/// Desired index layout, per the assertions in the example below: leaves are numbered in
/// field declaration order, with the nested struct's fields flattened in place:
/// 0 => `a`, 1 => `b`, 2 => `c.a`, 3 => `c.b`.
#[derive(GetItem)]
struct Data {
    a: i32,
    b: u64,
    // Nested struct: its fields occupy indices 2 and 3 in the desired behavior.
    c: NestedData,
}

let nested = NestedData { a: 500, b: -2 };
let data = Data {
    a: 42,
    b: 1000,
    c: nested,
};

// Leaves are numbered in field order with nested fields flattened: a, b, c.a, c.b.
let expected = [Item::I32(42), Item::U64(1000), Item::U16(500), Item::I32(-2)];
for (index, want) in (0u16..).zip(expected) {
    assert_eq!(data.get(index).unwrap(), want);
}

I have a working macro for a single-level struct, but I am not sure how to modify it to support nested structs. Here is what I have so far:

use proc_macro2::TokenStream;
use quote::quote;

use syn::{Data, DataStruct, DeriveInput, Fields, Type, TypePath};

pub fn impl_get_item(input: DeriveInput) -> syn::Result<TokenStream> {
    let model_name = input.ident;

    let fields = match input.data {
        Data::Struct(DataStruct {
            fields: Fields::Named(fields),
            ..
        }) => fields.named,
        _ => panic!("The GetItem derive can only be applied to structs"),
    };

    let mut matches = TokenStream::new();
    let mut item_index: u16 = 0;
    for field in fields {
        let item_name = field.ident;
        let item_type = field.ty;
        let ts = match item_type {
            Type::Path(TypePath { path, .. }) if path.is_ident("u8") => {
                quote! {#item_index => Ok(Item::U8(self.#item_name)),}
            }
            Type::Path(TypePath { path, .. }) if path.is_ident("i8") => {
                quote! {#item_index => Ok(Item::I8(self.#item_name)),}
            }
            Type::Path(TypePath { path, .. }) if path.is_ident("u16") => {
                quote! {#item_index => Ok(Item::U16(self.#item_name)),}
            }
            Type::Path(TypePath { path, .. }) if path.is_ident("i16") => {
                quote! {#item_index => Ok(Item::I16(self.#item_name)),}
            }
            Type::Path(TypePath { path, .. }) if path.is_ident("u32") => {
                quote! {#item_index => Ok(Item::U32(self.#item_name)),}
            }
            Type::Path(TypePath { path, .. }) if path.is_ident("i32") => {
                quote! {#item_index => Ok(Item::I32(self.#item_name)),}
            }
            Type::Path(TypePath { path, .. }) if path.is_ident("u64") => {
                quote! {#item_index => Ok(Item::U64(self.#item_name)),}
            }
            Type::Path(TypePath { path, .. }) if path.is_ident("i64") => {
                quote! {#item_index => Ok(Item::I64(self.#item_name)),}
            }
            _ => panic!("{:?} uses unsupported type {:?}", item_name, item_type),
        };
        matches.extend(ts);
        item_index += 1;
    }

    let output = quote! {
        #[automatically_derived]
        impl #model_name {
            pub fn get(&self, index: u16) -> Result<Item, Error> {
                match index {
                    #matches
                    _ => Err(Error::BadIndex),
                }
            }
        }
    };

    Ok(output)
}