Rust泛型linspace函数

elcex8rz  于 2022-11-12  发布在  其他
关注(0)|答案(3)|浏览(168)

我正在尝试实现一个泛型函数linspace:

pub fn linspace<T> (x0: T, xend: T, n: usize) -> Vec<T>
    where
        T: Sub<Output = T>
        + Add<Output = T>
        + Div<Output = T>
        + Clone
{

    let dx = (xend - x0) / ((n - 1) as T);

    let mut x = vec![x0; n];

    for i in 1..n {
        x[i] = x[i - 1] + dx;
    }

    x
}

到目前为止,我已经了解到T必须实现SubAddDivClone,但现在我遇到了n as T语句的问题。

non-primitive cast: `usize` as `T`
let dx = (xend - x0) / ((n - 1) as T);
   |                   ^^^^^^^^^^^^^ an `as` expression can only be used to convert between primitive types or to coerce to a specific trait object

我知道num板条箱,但我试图实现这一点没有外部板条箱。有一个变通办法吗?谢谢!

k2fxgqgv

k2fxgqgv1#

除非您将此作为一个学习练习,否则我建议您使用num_traits crate中的边界,它具有Float这样的特性,在这里会很有用:

use num_traits::Float;

pub fn linspace<T: Float + TryFrom<usize>>(x0: T, xend: T, n: usize) -> Vec<T> {
    let dx = (xend - x0) / (n - 1).try_into().unwrap_or_else(|_| panic!());
    let mut x = vec![x0; n];
    for i in 1..n {
        x[i] = x[i - 1] + dx;
    }
    x
}

然而现在我得到的错误:没有为“f64”实现特征“From”。
它未实现,因为存在无法精确表示为f64usize值。该错误会让您决定如何处理这些值。如果遇到这样的值,我的代码会出现混乱。
此外,我相信浮点加法会累积错误,因此基于乘法的计算可能是更好的主意:

pub fn linspace<T: Float + TryFrom<usize>>(x0: T, xend: T, n: usize) -> Vec<T> {
    let to_float = |i: usize| i.try_into().unwrap_or_else(|_| panic!());
    let dx = (xend - x0) / to_float(n - 1);
    (0..n).map(|i| x0 + to_float(i) * dx).collect()
}

Playground

7z5jn7bk

7z5jn7bk2#

如果您想坚持使用标准库特性,您将需要使用TryInto,并处理这样一个事实,即由于数字类型可能小于输入,所请求的转换可能会失败。因此,我们必须使项的数目为u16,它可以转换为任何浮点类型。

use core::ops::Add;
use core::ops::Div;
use core::ops::Sub;
use std::fmt::Debug;

pub fn linspace<T>(x0: T, xend: T, n: u16) -> Vec<T>
where
    T: Sub<Output = T> + Add<Output = T> + Div<Output = T> + Clone + Debug,
    u16: TryInto<T> + TryInto<usize>,
    <u16 as TryInto<T>>::Error: Debug,
{
    let segments: T = (n - 1)
        .try_into()
        .expect("requested number of elements did not fit into T");
    let n_size: usize = n.try_into()
        .expect("requested number of elements exceeds usize");

    let dx = (xend - x0.clone()) / segments;

    let mut x = vec![x0; n_size];

    for i in 1..n_size {
        x[i] = x[i - 1].clone() + dx.clone();
    }

    x
}

如果给定的n太大,例如Tu8n1000,则会出现混乱。
(By重复添加dx的方式通常不是最好的方式,因为如果T是浮点类型,则它会累积错误;最后一个元素将不必等于xend

t3psigkw

t3psigkw3#

这样更简洁:

use num_traits::Float;

fn linspace<T: Float + std::convert::From<u16>>(l: T, h: T, n: usize) -> 
    Vec<T> {
    let size: T = (n as u16 - 1).try_into()
        .expect("too many elements: max is 2^16");
    let dx = (h - l) / size;

    (1..=n).scan(-dx, |a, _| { *a = *a + dx; Some(*a) }).collect()
}

相关问题