rust 从N维数组中获取数据的所有1维切片

m0rkklqb  于 11个月前  发布在  其他
关注(0)|答案(1)|浏览(156)

我的目标是做一个函数,它接受一个 n 维的&ndarray::ArrayD,并在每个维度的每个一维切片上做一个操作。
下面是一个3D数组的例子:

let values = array![
    [
        [0.0, 0.1, 0.2], // (x0, y0, z0), (x0, y0, z1), (x0, y0, z2)
        [0.3, 0.4, 0.5], // (x0, y1, z0), (x0, y1, z1), (x0, y1, z2)
        [0.6, 0.7, 0.8], // (x0, y2, z0), (x0, y0, z1), (x0, y2, z2)
    ],
    [
        [0.9, 1.0, 1.1], // (x1, y0, z0), (x1, y0, z1), (x1, y0, z2)
        [1.2, 1.3, 1.4], // (x1, y1, z0), (x1, y1, z1), (x1, y1, z2)
        [1.5, 1.6, 1.7], // (x1, y2, z0), (x1, y2, z1), (x1, y2, z2)
    ],
    [
        [1.8, 1.9, 2.0], // (x2, y0, z0), (x2, y0, z1), (x2, y0, z2)
        [2.1, 2.2, 2.3], // (x2, y1, z0), (x2, y1, z1), (x2, y1, z2)
        [2.4, 2.5, 2.6], // (x2, y2, z0), (x2, y2, z1), (x2, y2, z2)
    ],
].into_dyn();

字符串
因此,我需要一个迭代器或类似的东西来遍历这些项(最好是未复制的,顺序无关紧要,我只是以最容易阅读的顺序输入它们):

// 1-D arrays in first direction
[0.0, 0.1, 0.2],
[0.3, 0.4, 0.5],
...
[2.1, 2.2, 2.3],
[2.4, 2.5, 2.6],
// 1-D arrays in second direction
[0.0, 0.3, 0.6],
[0.1, 0.4, 0.7],
...
[1.9, 2.2, 2.5]
[2.0, 2.3, 2.6],
// 1-D arrays in third direction
[0.0, 0.9, 1.8],
[0.1, 1.0, 1.9],
...
[0.7, 1.6, 2.5]
[0.8, 1.7, 2.6],


我尝试用长度为 n - 1的切片进行索引(很容易由一个函数生成,该函数返回给定形状的所有可能的n-1维索引,不像插入..的切片),但这会引起恐慌;我认为,因为切片的维度必须与数组的维度相匹配。似乎从1到N的任何长度的索引都是有用的,所以这可能是ndarray特性请求的主题。
这是我想对所有这些1-D切片进行的实际操作,以检查任何非NaN数据在切片内是否连续。

// assuming arr is one of the 1-D arrays
assert!(
    arr.windows(2)
        .map(|w| w[0].is_nan() != w[1].is_nan())
        .filter(|&b| b)
        .count()
        <= 2
);
// e.g. the below would fail this assert!
// [f64::NAN, 0.0, 0.0, f64::NAN, 0.0, 0.0, f64::NAN]
// and this would be ok
// [f64::NAN, 0.0, 0.0, 0.0, 0.0, 0.0, f64::NAN]


有没有什么好的方法来解决这个问题?谢谢!

wqsoz72f

wqsoz72f1#

当然,你正在寻找lanes功能。

use ndarray::prelude::*;

fn main() {
    let values = array![
        [
            [0.0, 0.1, 0.2], // (x0, y0, z0), (x0, y0, z1), (x0, y0, z2)
            [0.3, 0.4, 0.5], // (x0, y1, z0), (x0, y1, z1), (x0, y1, z2)
            [0.6, 0.7, 0.8], // (x0, y2, z0), (x0, y0, z1), (x0, y2, z2)
        ],
        [
            [0.9, 1.0, 1.1], // (x1, y0, z0), (x1, y0, z1), (x1, y0, z2)
            [1.2, 1.3, 1.4], // (x1, y1, z0), (x1, y1, z1), (x1, y1, z2)
            [1.5, 1.6, 1.7], // (x1, y2, z0), (x1, y2, z1), (x1, y2, z2)
        ],
        [
            [1.8, 1.9, 2.0], // (x2, y0, z0), (x2, y0, z1), (x2, y0, z2)
            [2.1, 2.2, 2.3], // (x2, y1, z0), (x2, y1, z1), (x2, y1, z2)
            [2.4, 2.5, 2.6], // (x2, y2, z0), (x2, y2, z1), (x2, y2, z2)
        ],
    ]
    .into_dyn();

    // Axis 0 is the outermost, n-1 is the innermost
    // So we use rev() to iterate from inner to outer dimension
    for axis in (0..values.ndim()).rev() {
        for lane in values.lanes(Axis(axis)) {
            println!("{:?}", lane);
        }
    }
}

个字符

相关问题