假设我有一个一般的NdArray<T, N>
,它保证是矩形的,形状为[usize; N]
-例如。[2, 2]
表示方阵,或[3]
表示列向量。我想以与NumPy打印其n维数组相同的方式打印NdArray。这是我目前为止所做的(代码主要从tinynumpy移植而来):
use core::fmt::Debug;
use std::cmp::{min, max};
use std::fmt;
pub struct NdArray<T: Clone, const N: usize> {
pub shape: [usize; N],
pub data: Vec<T>,
}
impl<T: Clone, const N: usize> NdArray<T, N> {
pub fn from(array: Vec<T>, shape: [usize; N]) -> Self {
NdArray { shape, data: array }
}
}
fn _display_inner<T: Clone + Debug, const N: usize>(f: &mut fmt::Formatter<'_>, array: &NdArray<T, N>, axis: usize, offset: usize) -> std::fmt::Result {
let axisindent = min(2, max(0, array.shape.len() - axis -1));
if axis < array.shape.len() {
f.write_str("[")?;
for (k_index, k) in (0..array.shape[axis]).into_iter().enumerate() {
if k_index > 0 {
for _ in 0..axisindent {
f.write_str("\n ")?;
for _ in 0..axis {
f.write_str(" ")?;
}
}
}
let offset_ = offset + k;
_display_inner(f, array, axis + 1, offset_)?;
if k_index < &array.shape[axis] - 1 {
f.write_str(", ")?;
}
}
} else {
f.write_str(format!("{:?}", array.data[offset]).as_str())?;
}
Ok(())
}
impl<T: Clone, const N: usize> fmt::Debug for NdArray<T, N>
where T: Debug
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.shape.len() {
0 => write!(f, "NdArray([])"),
1 => write!(f, "NdArray({:?}, shape={:?})", self.data, self.shape),
_ => _display_inner(f, &self, 0, 0)
}
}
}
fn main() {
let a = NdArray::from(vec![1., 2., 3., 4.], [2, 2]);
println!("{:?}", a);
}
字符串
作为参考,这是原始的未移植的Python代码:
def print_nested_list(flat_array, shape):
def _repr_r(s, axis, offset):
axisindent = min(2, max(0, (len(shape) - axis - 1)))
if axis < len(shape):
s += '['
for k_index, k in enumerate(range(shape[axis])):
if k_index > 0:
s += ('\n ' + ' ' * axis) * axisindent
offset_ = offset + k
s = _repr_r(s, axis+1, offset_)
if k_index < shape[axis] - 1:
s += ', '
s += ']'
else:
r = repr(flat_array)[offset])
if '.' in r:
r = ' ' + r
if r.endswith('.0'):
r = r[:-1]
s += r
return s
s = _repr_r('', 0, 0)
print(f"array({s})")
型
然而,这导致:
thread 'main' panicked at 'attempt to subtract with overflow', src/main.rs:17:36
型
这可能是因为usize不支持负数,不像tinynumpy的原始Python代码,如果axis < 1,则会导致代码死机。
有没有更好的方法来实现Debug
?
1条答案
按热度按时间qv7cva1a1#
一个minimal reproducible example解决您的问题:
字符串
但是
**
u**size
不能低于0,因为它是**u
**nsigned,所以max
的第二个参数总是>= 0
,正如你注意到的那样,Rust通过panicing而不是在调试模式下静默地 Package 来避免你获得usize::MAX
。相反,您可以使用saturating_sub
来实现相同的饱和行为:型