numpy np.argmax不返回整数

ymdaylpp  于 2023-11-18  发布在  其他
关注(0)|答案(1)|浏览(123)

我有一些onehot编码的数据,叫做testoutput,它的形状是(1000,14)。
我想解码它,所以遵循我在网上找到的一些建议,我使用了以下代码:

# go from onehot encoding to integer
def decode(datum):
    return np.argmax(datum)

predictedhits=np.empty((1000))
for i in range(testoutput.shape[0]):
    datum = testoutput[i]
    predictedhits[i] = decode(datum)
    print('Event',i,'predicted number of hits: %s' % predictedhits[i])

字符串
问题是,我希望并期望np.argmax输出一个整数,但它却输出了一个numpy.float64。请有人能告诉我为什么会发生这种情况,该怎么办?简单地做predictedhits[i] = int(decode(datum))并没有改变任何东西。
提前感谢!

envsm3lx

envsm3lx1#

您误诊了问题。numpy.argmax没有返回numpy.float64的示例。相反,predictedhits具有float64 dtype。将任何值存储到该数组中都会将其存储为64位浮点数,而从数组中检索predictedhits[i]会生成numpy.float64对象。
与每次遍历testoutput的行并将值逐个存储到空数组中不同,只需沿所需的轴沿着调用argmax

predictedhits = np.argmax(testoutput, axis=1)

字符串
这样可以节省代码、节省运行时间,并生成一个具有正确dtype的数组。

相关问题