Rust: a procedural macro that counts the occurrences of a specific enum variant in a collection

kkih6yb8 posted on 2023-06-23 in Other

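I'm trying to create a procedural macro (CountOf) that counts the occurrences of a specific enum variant in a collection, in this case a vector.
This src/macros.rs is what I use to test the macro: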

#[derive(count_of::CountOf, Clone)]
pub enum SomeEnum {
    Variant1,
    Variant2,
    Variant3,
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn count_of_works() {
        use SomeEnum::*;
        let enums = vec![
            Variant1,
            Variant1,
            Variant2,
            Variant3,
            Variant2,
            Variant2,
        ];

        assert_eq!(enums.variant1_count(), 2);
        assert_eq!(enums.variant2_count(), 3);
        assert_eq!(enums.variant3_count(), 1);
    }
}

This is the lib.rs in my count_of crate:

use inflector::Inflector;
use quote::quote;

#[proc_macro_derive(CountOf)]
pub fn count_of(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = syn::parse_macro_input!(input as syn::ItemEnum);
    let name = input.ident;

    let variants = input.variants.iter().map(|variant| {
        let variant_name = &variant.ident;
        let variant_count = variant_name.to_string().to_lowercase();

        quote! {
            pub fn #variant_count(&self) -> usize {
                self.iter().filter(|&&x| x == &#name::#variant_name).count()
            }
        }
    });

    let output = quote! {
        impl #name {
            #(#variants)*
        }
    };

    proc_macro::TokenStream::from(output)
}

These are the dependencies in my Cargo.toml:

[dependencies]
count-of           = { path = "./count-of" }

And this is the Cargo.toml of the count-of crate:

[package]
name    = "count-of"
version = "0.1.0"

[lib]
proc-macro = true

[dependencies]
syn         = { version = "2.0.15", features = ["full"] }
quote       = "1.0.26"
Inflector   = "0.11.4"
proc-macro2 = "1.0.56"

I get this error:

error: expected identifier, found `"Variant1"`
  --> src/macros.rs:92:10
   |
92 | #[derive(count_of::CountOf, Clone)]
   |          ^^^^^^^^^^^^^^^^^
   |          |
   |          expected identifier
   |          while parsing this item list starting here
   |          the item list ends here
   |
   = note: this error originates in the derive macro `count_of::CountOf` (in Nightly builds, run with -Z macro-backtrace for more info)

error: proc-macro derive produced unparseable tokens
  --> src/macros.rs:92:10
   |
92 | #[derive(count_of::CountOf, Clone)]
   |          ^^^^^^^^^^^^^^^^^

error[E0599]: no method named `variant1_count` found for struct `Vec<SomeEnum>` in the current scope
   --> src/macros.rs:135:23
    |
135 |         assert_eq!(enums.variant1_count(), 2);
    |                          ^^^^^^^^^^^^^^ method not found in `Vec<SomeEnum>`

error[E0599]: no method named `variant2_count` found for struct `Vec<SomeEnum>` in the current scope
   --> src/macros.rs:136:23
    |
136 |         assert_eq!(enums.variant2_count(), 3);
    |                          ^^^^^^^^^^^^^^ method not found in `Vec<SomeEnum>`

error[E0599]: no method named `variant3_count` found for struct `Vec<SomeEnum>` in the current scope
   --> src/macros.rs:137:23
    |
137 |         assert_eq!(enums.variant3_count(), 1);
    |                          ^^^^^^^^^^^^^^ method not found in `Vec<SomeEnum>`

I think my procedural macro is the problem, but I can't find a way to fix it. Any solution would be much appreciated.

Answer #1 by ccrfmcuu

There are a few errors in your code.

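  • Don't interpolate a String inside quote!. quote! emits a String as a string literal token, not an identifier, which is exactly what the "expected identifier" error is complaining about. Use the proper type instead, in this case an Ident.
  • Don't implement the methods on the enum type itself if you want to call them on the Vec. An impl SomeEnum block adds methods to SomeEnum values, not to Vec<SomeEnum>.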
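Only the final part of the first fix, constructing the Ident, depends on proc_macro2; working out the identifier's text is plain string handling. The std-only sketch below isolates that naming step. count_method_name is a hypothetical helper that appears in neither the question nor the answer, and it goes slightly beyond to_lowercase(): it inserts an underscore at lowercase-to-uppercase boundaries, so a multi-word variant such as MyVariant yields my_variant_count instead of myvariant_count. (The question already imports Inflector, whose to_snake_case method is the ready-made way to get snake case.)

```rust
/// Hypothetical helper: builds the generated method's name from a variant
/// name, e.g. "Variant1" -> "variant1_count".
/// An underscore is inserted where an uppercase letter follows a lowercase
/// letter or a digit, so "MyVariant" -> "my_variant_count".
fn count_method_name(variant: &str) -> String {
    let mut out = String::with_capacity(variant.len() + 8);
    let mut prev: Option<char> = None;
    for c in variant.chars() {
        if c.is_uppercase() {
            // Word boundary: a lowercase letter or digit followed by an uppercase letter.
            if matches!(prev, Some(p) if p.is_lowercase() || p.is_ascii_digit()) {
                out.push('_');
            }
            out.extend(c.to_lowercase());
        } else {
            out.push(c);
        }
        prev = Some(c);
    }
    out.push_str("_count");
    out
}

fn main() {
    for v in ["Variant1", "MyVariant", "HTTPServer"] {
        // Prints "Variant1 -> variant1_count", "MyVariant -> my_variant_count",
        // "HTTPServer -> httpserver_count" (acronyms are not split).
        println!("{v} -> {}", count_method_name(v));
    }
}
```

Inside the derive, the returned string would then be passed to Ident::new(&name, variant_name.span()) as the answer does, so the generated method name is an identifier rather than a string literal.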

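The second one is the hard part, because you can't add inherent methods to a foreign type (Vec is defined in the standard library, not in your crate).
The solution is to create a new trait and implement that trait for the foreign type, here Vec<SomeEnum>.
Here is some working code: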

use proc_macro2::Ident;
use quote::{format_ident, quote};

#[proc_macro_derive(CountOf)]
pub fn count_of(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = syn::parse_macro_input!(input as syn::ItemEnum);
    let name = input.ident;

    let trait_name = format_ident!("{}VecExt", name);

    let variants = input.variants.iter().map(|variant| {
        let variant_name = &variant.ident;
        let variant_count = variant_name.to_string().to_lowercase() + "_count";
        let variant_count_ident = Ident::new(&variant_count, variant_name.span());

        quote! {
            fn #variant_count_ident(&self) -> usize {
                self.as_ref().iter().filter(|&x| x == &#name::#variant_name).count()
            }
        }
    });

    let output = quote! {
        pub trait #trait_name: AsRef<[#name]> {
            #(#variants)*
        }
        impl<T> #trait_name for T where T: AsRef<[#name]> {}
    };

    proc_macro::TokenStream::from(output)
}

Usage:

use rust_playground::CountOf;

#[derive(CountOf, Clone, Eq, PartialEq)]
pub enum SomeEnum {
    Variant1,
    Variant2,
    Variant3,
}

fn main() {
    use SomeEnum::*;

    let enums = vec![Variant1, Variant1, Variant2, Variant3, Variant2, Variant2];

    println!("Variant 1: {}", enums.variant1_count());
    println!("Variant 2: {}", enums.variant2_count());
    println!("Variant 3: {}", enums.variant3_count());
}

Output:

Variant 1: 2
Variant 2: 3
Variant 3: 1

Running the code through cargo expand shows what it expands to:

pub enum SomeEnum {
    Variant1,
    Variant2,
    Variant3,
}
pub trait SomeEnumVecExt: AsRef<[SomeEnum]> {
    fn variant1_count(&self) -> usize {
        self.as_ref().iter().filter(|&x| x == &SomeEnum::Variant1).count()
    }
    fn variant2_count(&self) -> usize {
        self.as_ref().iter().filter(|&x| x == &SomeEnum::Variant2).count()
    }
    fn variant3_count(&self) -> usize {
        self.as_ref().iter().filter(|&x| x == &SomeEnum::Variant3).count()
    }
}
impl<T> SomeEnumVecExt for T
where
    T: AsRef<[SomeEnum]>,
{}

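Also note the #[derive(Eq, PartialEq)]: without it the enums can't be compared with == (strictly speaking, PartialEq alone is enough for ==).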
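If you would rather not require PartialEq on the enum at all, one alternative (a different technique from the answer's ==) is to compare by pattern with matches!. The hand-written sketch below shows that comparison for one unit variant; in the macro it would roughly correspond to emitting matches!(x, #name::#variant_name) where the answer emits x == &#name::#variant_name. The free function variant1_count only mirrors the name of the generated method and is not part of the answer.

```rust
// Sketch: counting a variant with `matches!` instead of `==`.
// SomeEnum deliberately derives neither PartialEq nor Eq.
#[derive(Clone)]
pub enum SomeEnum {
    Variant1,
    Variant2,
    Variant3,
}

/// Counts the elements of `items` that are `SomeEnum::Variant1`.
fn variant1_count(items: &[SomeEnum]) -> usize {
    // `x` is `&&SomeEnum`; match ergonomics dereferences it for the pattern.
    items.iter().filter(|x| matches!(x, SomeEnum::Variant1)).count()
}

fn main() {
    use SomeEnum::*;
    let enums = vec![Variant1, Variant1, Variant2, Variant3, Variant2, Variant2];
    println!("Variant 1: {}", variant1_count(&enums)); // prints "Variant 1: 2"
}
```

This only covers unit variants like those in the question. For variants that carry data, the macro would have to emit SomeEnum::Variant(..) or SomeEnum::Variant { .. }, chosen from the variant's syn::Fields kind (Unit, Unnamed or Named).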
A short explanation of the trait-based approach:

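  • I create a trait named after the enum, <EnumName>VecExt (here SomeEnumVecExt), whose default methods do the counting.
  • I implement that trait for every type that implements AsRef<[SomeEnum]>, i.e. every type that can be viewed as &[SomeEnum], such as Vec<SomeEnum>, [SomeEnum; 10], &mut [SomeEnum], and so on.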
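To see which receiver types the blanket impl covers, here is a hand-written version of the trait from the cargo expand output, trimmed to a single method and two variants so it compiles on its own without the derive. It calls the method on a Vec, an array, a shared slice and a mutable slice.

```rust
// Hand-written stand-in for what #[derive(CountOf)] generates, trimmed to one
// method. PartialEq is needed for the `==` comparison below.
#[derive(Clone, PartialEq, Eq)]
pub enum SomeEnum {
    Variant1,
    Variant2,
}

pub trait SomeEnumVecExt: AsRef<[SomeEnum]> {
    fn variant1_count(&self) -> usize {
        self.as_ref().iter().filter(|&x| x == &SomeEnum::Variant1).count()
    }
}

// Blanket impl: anything that can be viewed as &[SomeEnum] gets the methods.
impl<T> SomeEnumVecExt for T where T: AsRef<[SomeEnum]> {}

fn main() {
    use SomeEnum::*;

    let v: Vec<SomeEnum> = vec![Variant1, Variant2, Variant1];
    let arr: [SomeEnum; 3] = [Variant2, Variant2, Variant1];
    let s: &[SomeEnum] = &v[..2];
    let mut buf = [Variant1, Variant1];
    let m: &mut [SomeEnum] = &mut buf;

    println!("{}", v.variant1_count()); // Vec<SomeEnum>: prints 2
    println!("{}", arr.variant1_count()); // [SomeEnum; 3]: prints 1
    println!("{}", s.variant1_count()); // &[SomeEnum]: prints 1
    println!("{}", m.variant1_count()); // &mut [SomeEnum]: prints 2
}
```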
