numpy 相似性查找器、优化

zte4gxcn  于 2023-06-06  发布在  其他
关注(0)|答案(1)|浏览(142)
# Build 10,000 character sets, one per label "word0" .. "word9999".
root_words = [set(f"word{i}") for i in range(10000)]
# Cache the count once so the search loop below does not recompute it.
num_root_words = len(root_words)

def findsim(root_words):
    """Find all index pairs whose character sets share more than 3 characters.

    Args:
        root_words: a sequence of sets (here, sets of characters).

    Returns:
        A ``collections.defaultdict(list)`` mapping each index ``i`` to the
        ascending list of indices ``j > i`` such that
        ``len(root_words[i] & root_words[j]) > 3``. Indices with no
        qualifying partner do not appear as keys.
    """
    # Fix: derive the length from the argument instead of reading the
    # module-level ``num_root_words`` global, so the function works on any
    # input sequence, not just the one defined at module scope.
    n = len(root_words)
    pair_dict = collections.defaultdict(list)
    for i, word_i in enumerate(root_words):
        # Only look at j > i: the relation is symmetric, so each unordered
        # pair is recorded exactly once, under its smaller index.
        for j in range(i + 1, n):
            if len(word_i & root_words[j]) > 3:
                pair_dict[i].append(j)
    return pair_dict

pairdict = findsim(root_words)

有什么方法可以让这个函数运行得更快吗?
我尝试使用NumPy相关的优化,但仍然很慢,需要改善执行时间

lnvxswe2

lnvxswe21#

这应该有助于你开始:您可以将每个不同的值分配给一个位,然后使用按位&,这比使用集合比较快得多。然后,您可以对每个元素进行向量化比较,并生成具有>3个共同元素的索引列表,而无需python循环。
您可能需要用不同于 3 的阈值来检验结果——所有条目都共享公共前缀 "word"(即字符 w、o、r、d),因此任意两个集合的交集至少有 4 个元素。

import collections
import timeit
import numpy as np
root_words = [set(f'word{i}') for i in range(10000)]

def orig():
    """Baseline: pairwise set intersection over the module-level root_words.

    Returns a defaultdict mapping index i to the ascending list of indices
    j > i whose character sets share more than three characters with
    root_words[i].
    """
    num_root_words = len(root_words)

    def findsim(words):
        pairs = collections.defaultdict(list)
        # Compare every unordered pair exactly once (j always above i).
        for i in range(num_root_words):
            current = words[i]
            for j in range(i + 1, num_root_words):
                if len(current & words[j]) > 3:
                    pairs[i].append(j)
        return pairs

    return findsim(root_words)

def bit_np():
    """Bitmask variant of orig(): same result, but each word's character set
    is encoded as a bitmask (one bit per distinct character) so the pairwise
    intersection becomes a vectorized bitwise AND over uint8 arrays instead
    of Python-level set operations.
    """
    def enough_common_elems(spot, item, rest, gt_common_elems=3):
        # Rows of `rest` that share a character with `item` keep that bit set.
        common = item & rest
        # Manual popcount: extract each of the 8 bits of every uint8 byte,
        # add them up, then total across the byte columns (axis=1) to get
        # the number of shared characters per row.
        common_elems = ((common >> 0 & 1) +
                        (common >> 1 & 1) +
                        (common >> 2 & 1) +
                        (common >> 3 & 1) +
                        (common >> 4 & 1) +
                        (common >> 5 & 1) +
                        (common >> 6 & 1) +
                        (common >> 7 & 1)).sum(axis=1)
        # np.where indices are relative to `rest`, which starts at row
        # spot + 1 of the full matrix, hence the offset.
        return (np.where(common_elems > gt_common_elems)[0] + spot + 1).tolist()

    def findsim(root_words_sets):
        # Every distinct character across all sets gets one bit position.
        # NOTE(review): set iteration order is unstable across runs, so bit
        # assignment varies — results are still identical, only the internal
        # encoding differs.
        valid_chrs = list(set.union(*root_words_sets))
        # One arbitrary-precision Python int bitmask per word.
        # NOTE(review): with more than 64 distinct characters this becomes an
        # object-dtype array; the shifts below still work elementwise, but at
        # Python speed — presumably acceptable here since the alphabet is small.
        bit_reprs = np.array([sum([2 ** bit for bit in range(len(valid_chrs)) if valid_chrs[bit] in word]) for word in
                              root_words_sets]).transpose()
        # Pack each mask into ceil(bits / 8) uint8 columns for fast AND.
        n_bytes = (len(valid_chrs) + 7) // 8
        bit_np_vals = np.zeros((len(root_words_sets), n_bytes), dtype='u1')

        for byte in range(n_bytes):
            # Column `byte` holds bits [8*byte, 8*byte + 8) of every mask.
            bit_np_vals[:, byte] = bit_reprs >> 8 * byte & 255
        ret = {}
        for i, word_i in enumerate(bit_np_vals):
            # Vectorized comparison of row i against every later row.
            common_elems = enough_common_elems(i, word_i, bit_np_vals[i+1:, :])
            if common_elems:
                ret[i] = common_elems
        return ret
    return findsim(root_words)

# Time three runs of each implementation and report the speedup factor.
bef = timeit.timeit('orig()', number=3, globals=globals())
aft = timeit.timeit('bit_np()', number=3, globals=globals())
print(bef, aft, bef / aft)
# Sanity check: both implementations must produce the same pair mapping.
assert orig() == bit_np()

在我的PC上,我看到了4.5倍的加速:

34.5807251 7.576571700000002 4.564165227922279

相关问题