作为制作基于Rust的Tensor库的一部分,我实现了一个2D Tensor
类型,如下所示:
use ndarray::prelude::*;
pub struct Tensor(Rc<RefCell<TensorData>>);
pub struct TensorData {
pub data: Array2<f64>,
pub grad: Array2<f64>,
// other fields...
}
impl TensorData {
fn new(data: Array2<f64>) -> TensorData {
let shape = data.raw_dim();
TensorData {
data,
grad: Array2::zeros(shape),
// other fields...
}
}
}
impl Tensor {
pub fn new(array: Array2<f64>) -> Tensor {
Tensor(Rc::new(RefCell::new(TensorData::new(array))))
}
pub fn data(&self) -> impl Deref<Target = Array2<f64>> + '_ {
Ref::map((*self.0).borrow(), |mi| &mi.data)
}
}
字符串
现在,我希望能够迭代Tensor的行(例如,用于实现随机梯度下降)。这是我到目前为止所拥有的:
impl Tensor {
// other methods...
pub fn iter(&self) -> impl Iterator<Item = Tensor> + '_ {
self.data().outer_iter().map(|el| {
let reshaped_and_cloned_el = el.into_shape((el.shape()[0], 1)).unwrap().mapv(|el| el.clone());
reshaped_and_cloned_el
}).map(|el| Tensor::new(el))
}
}
型
这样做的问题是:(1)视觉上不舒服;(2)无法编译,因为迭代器是一个临时值,一旦超出范围就会被删除:
error[E0515]: cannot return value referencing temporary value
--> src/tensor/mod.rs:348:9
|
348 | self.data().outer_iter().map(|el| {
| ^----------
| |
| _________temporary value created here
| |
349 | | let reshaped_and_cloned_el = el.into_shape((el.shape()[0], 1)).unwrap().mapv(...
350 | | reshaped_and_cloned_el
351 | | }).map(|el| Tensor::new(el))
| |____________________________________^ returns a value referencing data owned by the current function
|
= help: use `.collect()` to allocate the iterator
型iter()
的另一种实现方式是什么,它将不会有这些问题?
1条答案
按热度按时间j13ufse21#
不幸的是,您不能使用轴迭代器,因为这将产生一个自引用结构(数据保护和从它借用的迭代器)。但是你可以索引来访问轴:
字符串