rust 对NdArray执行Debug导致减法下溢错误

zzwlnbp8  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(90)

假设我有一个一般的NdArray<T, N>,它保证是矩形的,形状为[usize; N]-例如。[2, 2]表示方阵,或[3]表示列向量。我想以与NumPy打印其n维数组相同的方式打印NdArray。这是我目前为止所做的(代码主要从tinynumpy移植而来):

use core::fmt::Debug;
use std::cmp::{min, max};
use std::fmt;

pub struct NdArray<T: Clone, const N: usize> {
    pub shape: [usize; N],
    pub data: Vec<T>,
}

impl<T: Clone, const N: usize> NdArray<T, N> {
    pub fn from(array: Vec<T>, shape: [usize; N]) -> Self {
        NdArray { shape, data: array }
    }
}

fn _display_inner<T: Clone + Debug, const N: usize>(f: &mut fmt::Formatter<'_>, array: &NdArray<T, N>, axis: usize, offset: usize) -> std::fmt::Result {
    let axisindent = min(2, max(0, array.shape.len() - axis -1));
    if axis < array.shape.len() {
        f.write_str("[")?;
        for (k_index, k) in (0..array.shape[axis]).into_iter().enumerate() {
            if k_index > 0 {
                for _ in 0..axisindent {
                    f.write_str("\n       ")?;
                    for _ in 0..axis {
                        f.write_str(" ")?;
                    }
                }
            }
            let offset_ = offset + k;
            _display_inner(f, array, axis + 1, offset_)?;
            if k_index < &array.shape[axis] - 1 {
                f.write_str(", ")?;
            }
        }
    } else {
        f.write_str(format!("{:?}", array.data[offset]).as_str())?;
    }
    Ok(())
}

impl<T: Clone, const N: usize> fmt::Debug for NdArray<T, N> 
where T: Debug
{ 
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.shape.len() {
            0 => write!(f, "NdArray([])"),
            1 => write!(f, "NdArray({:?}, shape={:?})", self.data, self.shape),
            _ => _display_inner(f, &self, 0, 0)
        }
    }
}

fn main() {
    let a = NdArray::from(vec![1., 2., 3., 4.], [2, 2]);
    println!("{:?}", a);
}

字符串
作为参考,这是原始的未移植的Python代码:

def print_nested_list(flat_array, shape):
    def _repr_r(s, axis, offset):
            axisindent = min(2, max(0, (len(shape) - axis - 1)))
            if axis < len(shape):
                s += '['
                for k_index, k in enumerate(range(shape[axis])):
                    if k_index > 0:
                        s += ('\n       ' + ' ' * axis)  * axisindent
                    offset_ = offset + k
                    s = _repr_r(s, axis+1, offset_)
                    if k_index < shape[axis] - 1:
                        s += ', '
                s += ']'
            else:
                r = repr(flat_array)[offset])
                if '.' in r:
                    r = ' ' + r
                    if r.endswith('.0'):
                        r = r[:-1]
                s += r
            return s
    s = _repr_r('', 0, 0)
    print(f"array({s})")


然而,这导致:

thread 'main' panicked at 'attempt to subtract with overflow', src/main.rs:17:36


这可能是因为usize不支持负数,不像tinynumpy的原始Python代码,如果axis < 1,则会导致代码死机。
有没有更好的方法来实现Debug

qv7cva1a

qv7cva1a1#

一个minimal reproducible example解决您的问题:

fn main() {
    let a = sub1(0);
}
fn saturating_sub1(a: usize) -> usize {
    std::cmp::max(0, a - 1);
}

字符串
但是**u**size不能低于0,因为它是**u**nsigned,所以max的第二个参数总是>= 0,正如你注意到的那样,Rust通过panicing而不是在调试模式下静默地 Package 来避免你获得usize::MAX。相反,您可以使用saturating_sub来实现相同的饱和行为:

fn saturating_sub1(a: usize) -> usize {
    a.saturating_sub(1)
}

相关问题