为具有泛型Rust类型的Rust函数实现Python接口

cotxawn7  于 12个月前  发布在  Python
关注(0)|答案(1)|浏览(94)

这个函数在Rust上运行得很好:

fn jaccard_similarity<T>(s1: Vec<T>, s2: Vec<T>) -> f32
where
    T: Hash + Eq + Clone,
{
    let s1 = vec_to_set(&s1);
    let s2 = vec_to_set(&s2);
    let i = s1.intersection(&s2).count() as f32;
    let u = s1.union(&s2).count() as f32;
    return i / u;
}

fn vec_to_set<T>(vec: &Vec<T>) -> HashSet<T>
where
    T: Hash + Eq + Clone,{
    HashSet::from_iter(vec.iter().cloned())
}

字符串
在以下测试案例中:

#[test]
fn test_jaccard_similarity() {
    let left = vec!["kitten", "sitting", "saturday", "sunday"];
    let right = vec!["kitten", "sitting", "saturday", "sunday"];
    assert_eq!(jaccard_similarity(left, right), 1.0);
    let left = vec![1,2,3,4];
    let right = vec![1,2,3,4];
    assert_eq!(jaccard_similarity(left, right), 1.0);
    let left = vec![1,2,3,4];
    let right = vec![2,2,3,4];
    assert_eq!(jaccard_similarity(left, right), 0.75);

}


但是,一旦我将它 Package 为pyo3 crate [version:0.13.2]的#[pyfunction](并且我还更新了lib.rs和mod.rs文件)。对于上下文,我使用马图林库。

#[pyfunction]
fn jaccard_similarity<T>(s1: Vec<T>, s2: Vec<T>) -> f32
where
    T: Hash + Eq + Clone,
{
    let s1 = vec_to_set(&s1);
    let s2 = vec_to_set(&s2);
    let i = s1.intersection(&s2).count() as f32;
    let u = s1.union(&s2).count() as f32;
    return i / u;
}


我得到以下错误:

--> src\distance_functions\jaccard_similarity.rs:6:4
  |
6 | fn jaccard_similarity<T>(s1: Vec<T>, s2: Vec<T>) -> f32
  |    ^^^^^^^^^^^^^^^^^^ cannot infer type of the type parameter `T` declared on the function `jaccard_similarity`
  |
  = note: cannot satisfy `_: Hash`
note: required by a bound in `jaccard_similarity`
 --> src\distance_functions\jaccard_similarity.rs:8:8
  |
6 | fn jaccard_similarity<T>(s1: Vec<T>, s2: Vec<T>) -> f32
  |    ------------------ required by a bound in this function
7 | where
8 |     T: Hash + Eq + Clone,
  |        ^^^^ required by this bound in `jaccard_similarity`
help: consider specifying the generic argument
  |
6 | fn jaccard_similarity::<T><T>(s1: Vec<T>, s2: Vec<T>) -> f32
  |                      +++++


泛型参数已经在函数上声明了。我不明白编译器要求我做什么。
在Rust中工作的东西也应该在Rust代码 Package 在Python接口中时使用。
编辑:我更新了我的pyo3版本到0.20.0,现在我得到了一个更有意义的错误消息:

error: Python functions cannot have generic type parameters
 --> src\distance_functions\jaccard_similarity.rs:6:23
  |
6 | fn jaccard_similarity<T>(s1: Vec<T>, s2: Vec<T>) -> f32


有没有一种方法可以为Python函数使用泛型类型参数?

ttcibm8c

ttcibm8c1#

在Rust中工作的东西也应该在Rust代码 Package 在Python接口中时使用。
Python不是静态类型的,python接口不支持泛型,因此pyo3无法创建桥接函数和Rust实现的绑定。
实际上,这与Rust本身的行为相匹配:jaccard_similarity本身并不生成任何代码。相反,编译器会查看 call sites,并为每个调用函数的T生成一个示例,这些示例是最终在二进制文件中的代码。示例化步骤是pyo3无法实现的步骤,因此它无法工作。
我还要说的是,这段代码基本上没有用,python相当于5次调用C语言实现的东西(创建两个集合,相交,合并,除法)。将数据复制到向量,然后设置这些向量的开销可能会和让python做这些事情一样大。特别是使用Rust的默认哈希函数。
为了获得真实的收益,我认为您可能需要完全避免这两个vec(从PyList动态执行转换),并避免具体化其中一个集合(可能是较大的那个)。

相关问题