rust 如何为 Package NdArray的类型实现iter()?

rdlzhqv9  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(78)

作为制作基于Rust的Tensor库的一部分,我实现了一个2D Tensor类型,如下所示:

use ndarray::prelude::*;

pub struct Tensor(Rc<RefCell<TensorData>>);

pub struct TensorData {
    pub data: Array2<f64>,
    pub grad: Array2<f64>,
    // other fields...
}

impl TensorData {
    fn new(data: Array2<f64>) -> TensorData {
        let shape = data.raw_dim();
        TensorData {
            data,
            grad: Array2::zeros(shape),
            // other fields...
        }
    }
}

impl Tensor {
    pub fn new(array: Array2<f64>) -> Tensor {
        Tensor(Rc::new(RefCell::new(TensorData::new(array))))
    }

    pub fn data(&self) -> impl Deref<Target = Array2<f64>> + '_ {
        Ref::map((*self.0).borrow(), |mi| &mi.data)
    }
}

字符串
现在,我希望能够迭代Tensor的行(例如,用于实现随机梯度下降)。这是我到目前为止所拥有的:

impl Tensor {
    // other methods...
    
    pub fn iter(&self) -> impl Iterator<Item = Tensor> + '_ {
        self.data().outer_iter().map(|el| {
            let reshaped_and_cloned_el = el.into_shape((el.shape()[0], 1)).unwrap().mapv(|el| el.clone());
            reshaped_and_cloned_el
        }).map(|el| Tensor::new(el))
    }
}


这样做的问题是:(1)视觉上不舒服;(2)无法编译,因为迭代器是一个临时值,一旦超出范围就会被删除:

error[E0515]: cannot return value referencing temporary value
   --> src/tensor/mod.rs:348:9
    |
348 |           self.data().outer_iter().map(|el| {
    |           ^----------
    |           |
    |  _________temporary value created here
    | |
349 | |             let reshaped_and_cloned_el = el.into_shape((el.shape()[0], 1)).unwrap().mapv(...
350 | |             reshaped_and_cloned_el
351 | |         }).map(|el| Tensor::new(el))
    | |____________________________________^ returns a value referencing data owned by the current function
    |
    = help: use `.collect()` to allocate the iterator


iter()的另一种实现方式是什么,它将不会有这些问题?

j13ufse2

j13ufse21#

不幸的是,您不能使用轴迭代器,因为这将产生一个自引用结构(数据保护和从它借用的迭代器)。但是你可以索引来访问轴:

pub fn iter(&self) -> impl Iterator<Item = Tensor> + '_ {
    let data = self.data();
    (0..data.shape()[0]).map(move |i| {
        let el = data.index_axis(Axis(0), i);
        let reshaped_and_cloned_el = el
            .into_shape((el.shape()[0], 1))
            .unwrap()
            .mapv(|el| el.clone());
        Tensor::new(reshaped_and_cloned_el)
    })
}

字符串

相关问题