numpy 对称Tensor的置换唯一元的选取

34gzjxbg  于 11个月前  发布在  其他
关注(0)|答案(2)|浏览(109)

我有一个代码,它创建了一个形状为(m, m, ..., m)的numpy数组A,其中有mn副本。通过构造,这个数组是一个对称Tensor(在数学意义上),这意味着A[i, j, ..., k] == A[i', j', ..., k'],其中(i', j', ..., k')(i, j, ..., k)的置换。
我们可以将A的置换唯一元素定义为A中对应的索引通过置换彼此不等价的条目的集合。例如,在形状为(3, 3)的矩阵中,置换唯一索引为(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)
对于一个一般的对称TensorA,我如何提取它的所有置换唯一元素,以及它们对应的索引?(np.unique不起作用,在两个元素A[i1, j1, ..., k1]A[i2, j2, ..., k2]重合相等的情况下,但(i1, j1, ..., k1)不是(i2, j2, ..., k2)的置换。

q9yhzks0

q9yhzks01#

考虑到你的标签包含numpy,在大型数组上有一个更快的解决方案:

def combinations_with_replacement_(m, n, dtype=int):
    if n < 0:
        raise ValueError('n must be non-negative')
    if m < 0:
        raise ValueError('m must be non-negative')

    if n == 0:
        return np.empty((1, 0), dtype)

    shape = (math.comb(m + n - 1, n), n)
    out = np.zeros(shape, dtype)
    out[:m, -1] = np.arange(m, dtype=dtype)

    lengths = np.arange(1, m, dtype=np.intp)
    start = m
    for col in reversed(range(1, n)):
        block = out[:start, col:]
        for i, length in enumerate(reversed(lengths), 1):
            stop = start + length
            out[start:stop, col:] = block[-length:]
            out[start:stop, col - 1] = i
            start = stop
        lengths = lengths.cumsum()

    return out

字符串
对于小型数组,有一些棘手的加速方法,但它们会损失一些可读性:

from itertools import combinations_with_replacement, chain

def comb_with_rep_faster_on_small(m, n, dtype=int):
    it = combinations_with_replacement(range(m), n)
    flatten = chain.from_iterable(it)
    buffer = bytes(flatten)
    return np.frombuffer(buffer, np.uint8).astype(dtype).reshape(-1, n)


一个简单的比较方法:

from itertools import combinations_with_replacement

def comb_with_rep(m, n, dtype=int):
    return np.array(list(combinations_with_replacement(range(m), n)), dtype)


基准:

import perfplot

funcs = [
    comb_with_rep,
    comb_with_rep_faster_on_small,
    combinations_with_replacement_
]

perfplot.bench(
    kernels=funcs,
    n_range=list(range(1, 20)),
    setup=lambda n: (7, n),
    equality_check=np.array_equal
).show()


的数据

eoxn13cs

eoxn13cs2#

为了回答我自己的问题:一种方法是选择所有索引(i_1, i_2, ..., i_n)i_1 <= i_2 <= ... <= i_n。伪代码是:

def perm_unique(m):

  index_list = []
  
  for i_1 in range(m):
    for i_2 in range(i_1, m):
      ...
        for i_n in range(i_{n-1}, m):
  
          index_list.append([i_1, i_2, ..., i_n])

  return index_list

字符串
注意事项:上面的伪代码需要预先知道n。一个更好的解决方案是将n作为perm_unique()的参数,但这需要递归,我不确定什么是最好的方法。
事实证明,解决方案非常简单。下面的代码就是这样做的:

index_list = list(itertools.combinations_with_replacement(range(m), n))

相关问题