Strange performance of NumPy array2string

Posted by xytpbqjk 12 months ago in Other

I'm using NumPy's array2string to write ASCII files instead of doing Python string formatting in a loop or with map:

aa = np.array2string(array.flatten(), precision=precision, separator=' ', max_line_width=(precision + 4) * ncolumns, prefix='         ', floatmode='fixed')
aa = '          ' + aa[1:-1] + '\n'
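For reference, here is the snippet made self-contained, with hypothetical values (precision=4, ncolumns=3 and a small 4x3 array) chosen only for illustration. array2string indents every continuation line by len(prefix) + 1 spaces (one extra column for the opening '['), so after aa[1:-1] strips the brackets, prepending that same number of spaces keeps the first line aligned with the others.

```python
import numpy as np

precision = 4          # hypothetical example values
ncolumns = 3
prefix = ' ' * 9       # same 9-space prefix as in the snippet

array = np.arange(12, dtype=float).reshape(4, 3) / 8   # hypothetical small input

aa = np.array2string(array.flatten(), precision=precision, separator=' ',
                     max_line_width=(precision + 4) * ncolumns,
                     prefix=prefix, floatmode='fixed')
# aa starts with '[' and ends with ']'; continuation lines are already indented
# by len(prefix) + 1 spaces, so give the first line the same indent
aa = ' ' * (len(prefix) + 1) + aa[1:-1] + '\n'
print(aa, end='')
```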

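I noticed strange results when the number of elements is below a few thousand. The map/join version behaves as I would expect (it slows down as the arrays grow, and it is faster for small arrays because it avoids NumPy's function-call overhead):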

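What causes the peak in the numpy.array2string timings? A (100,3) array takes longer than a (50000,3) array. NumPy is the best choice for my data sizes (>1000 elements), but the peak seems odd. Full code: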

import numpy as np
import perfplot

precision = 16
ncolumns = 6

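# numpy method: one array2string call on the flattened array
def numpystring(array, precision, ncolumns):
    indent = '          '  # 10 spaces = len(prefix) + 1, the indent array2string gives continuation lines
    aa = np.array2string(array.flatten(), precision=precision, separator=' ',
                         max_line_width=(precision + 6) * ncolumns,
                         prefix='         ', floatmode='fixed')
    # aa[1:-1] drops the enclosing '[' and ']'; indent aligns the first line with the rest
    return indent + aa[1:-1] + '\n'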

# native python string creation
def nativepython_string(array, precision, ncolumns):
    fmt = '{' + f":.{precision}f" + '}'
    data_str = ''

    # calculate number of full rows
    if array.size <= ncolumns:
        nrows = 1
    else:
        nrows = int(array.size / ncolumns)

    # write full rows
    for row in range(nrows):
        shift = row * ncolumns
        data_str += '          ' + ' '.join(
            map(lambda x: fmt.format(x), array.flatten()[0 + shift:ncolumns + shift])) + '\n'

    # write any remaining data in last non-full row
    if array.size > ncolumns and array.size % ncolumns != 0:
        data_str += '          ' + ' '.join(
            map(lambda x: fmt.format(x), array.flatten()[ncolumns + shift::])) + '\n'

    return data_str

# Benchmark methods
out = perfplot.bench(
    setup=lambda n: np.random.random([n,3]),  # setup random nx3 array
    kernels=[
        lambda a: nativepython_string(a, precision, ncolumns),
        lambda a: numpystring(a, precision, ncolumns)
    ],
    equality_check=None,
    labels=["Native", "NumPy"],
    n_range=[2**k for k in range(16)],
    xlabel="Number of vectors [Nr.]",
    title="String Conversion Performance"
)

out.show(
    time_unit="us",  # set to one of ("auto", "s", "ms", "us", or "ns") to force plot units
)
out.save("perf.png", transparent=True, bbox_inches="tight")
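One detail worth ruling out before reading too much into the peak: array2string summarizes any array whose total number of elements exceeds its threshold argument, which defaults to the print option threshold of 1000. Only edgeitems (default 3) values from each end are then formatted, with '...' in between. In the benchmark, n*3 first exceeds 1000 at n = 512, so from there on numpystring formats just six numbers regardless of n, and because equality_check=None, perfplot never compares its output with the native version's. That would plausibly explain why the larger arrays look faster. The sketch below uses a fixed-seed random array chosen for illustration; passing threshold=sys.maxsize (the value the NumPy print-options documentation suggests for always printing the full array) turns summarization off.

```python
import sys
import numpy as np

rng = np.random.default_rng(0)       # fixed seed, for illustration only
flat = rng.random((512, 3)).ravel()  # 1536 elements, above the default threshold of 1000

# Default threshold: only the first and last `edgeitems` (3) values are printed, around '...'
short = np.array2string(flat, precision=16, separator=' ', floatmode='fixed')

# threshold=sys.maxsize disables summarization, so every element is formatted
full = np.array2string(flat, precision=16, separator=' ', floatmode='fixed',
                       threshold=sys.maxsize)

print(len(short), len(full))
```

Adding threshold=sys.maxsize to the array2string call in numpystring should make the benchmark time the full conversion.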
Answer #1 (by nxagd54h)

An example of using np.savetxt with a small 2D array:

In [87]: np.savetxt('test.txt', np.arange(24).reshape(3,8), fmt='%5d')
In [88]: cat test.txt
    0     1     2     3     4     5     6     7
    8     9    10    11    12    13    14    15
   16    17    18    19    20    21    22    23

In [90]: np.savetxt('test.txt', np.arange(24).reshape(3,8), fmt='%5d', newline=' ')
In [91]: cat test.txt
    0     1     2     3     4     5     6     7     8     9    10    11    12    13    14    15    16    17    18    19    20    21    22    23

savetxt builds a format string for a whole row from the fmt argument and the number of columns:

In [95]: fmt=' '.join(['%5d']*8)
In [96]: fmt
Out[96]: '%5d %5d %5d %5d %5d %5d %5d %5d'

It then applies that string to each row and writes the result to the file:

In [97]: fmt%tuple(np.arange(8))
Out[97]: '    0     1     2     3     4     5     6     7'
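Applied to the question's layout, a minimal sketch of the savetxt approach might look like the following. savetxt_string is a hypothetical helper name; the 10-space indent and the handling of a final short row mirror nativepython_string from the question. Writing into an io.StringIO keeps the result as a string, while passing a filename or open file instead writes it straight to disk. Because savetxt needs a rectangular 2D array, the leftover elements that do not fill a complete row are formatted separately with plain Python formatting.

```python
import io
import numpy as np

def savetxt_string(array, precision, ncolumns, indent=' ' * 10):
    """Hypothetical helper: format `array` as rows of `ncolumns` fixed-point numbers."""
    flat = array.ravel()
    nfull = (flat.size // ncolumns) * ncolumns   # elements that fill complete rows
    buf = io.StringIO()
    if nfull:
        # one multi-format string per row, like the '%5d %5d ...' string savetxt builds;
        # with exactly ncolumns '%' specifiers, savetxt uses it as-is and ignores delimiter
        row_fmt = indent + ' '.join([f'%.{precision}f'] * ncolumns)
        np.savetxt(buf, flat[:nfull].reshape(-1, ncolumns), fmt=row_fmt)
    rest = flat[nfull:]
    if rest.size:
        # savetxt needs a rectangular 2D array, so format the short last row directly
        buf.write(indent + ' '.join(f'{x:.{precision}f}' for x in rest) + '\n')
    return buf.getvalue()

example = np.arange(8, dtype=float).reshape(4, 2)   # 8 values: 2 full rows + 2 leftovers
print(savetxt_string(example, precision=2, ncolumns=3), end='')
```

With ncolumns=3, the 8 example values give two full rows written by savetxt and a final two-value row written directly.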
