I have the following code:
import sentence_transformers
import multiprocessing
from tqdm import tqdm
from multiprocessing import Pool
embedding_model = sentence_transformers.SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
data = [[100227, 7382501.0, 'view', 30065006, False, ''],
        [100227, 7382501.0, 'view', 57072062, True, ''],
        [100227, 7382501.0, 'view', 66405922, True, ''],
        [100227, 7382501.0, 'view', 5221475, False, ''],
        [100227, 7382501.0, 'view', 63283995, True, '']]
# Define the function to be executed in parallel
def process_data(chunk):
    results = []
    for row in chunk:
        print(row[0])
        work_id = row[1]
        mentioning_work_id = row[3]
        print(work_id)
        if work_id in df_text and mentioning_work_id in df_text:
            title1 = df_text[work_id]['title']
            title2 = df_text[mentioning_work_id]['title']
            embeddings_title1 = embedding_model.encode(title1, convert_to_numpy=True)
            embeddings_title2 = embedding_model.encode(title2, convert_to_numpy=True)
            similarity = np.matmul(embeddings_title1, embeddings_title2.T)
            results.append([row[0], row[1], row[2], row[3], row[4], similarity])
        else:
            continue
    return results
# Define the number of CPU cores to use
num_cores = multiprocessing.cpu_count()
# Split the data into chunks
chunk_size = len(data) // num_cores
# chunks = [data[i:i+chunk_size] for i in range(0, len(data), chunk_size)]
# Create a pool of worker processes
pool = multiprocessing.Pool(processes=num_cores)
results = []
with tqdm(total=len(data)) as pbar:
    for i, result_chunk in enumerate(pool.map(process_data, data)):
        # Update the progress bar
        pbar.update()
        # Add the results to the list
        results += result_chunk
# Concatenate the results
final_result = results
When I run this code, I get the following error:
---------------------------------------------------------------------------
RemoteTraceback Traceback (most recent call last)
RemoteTraceback:
"""
Traceback (most recent call last):
File "/opt/conda/lib/python3.7/multiprocessing/pool.py", line 121, in worker
result = (True, func(*args, **kwds))
File "/opt/conda/lib/python3.7/multiprocessing/pool.py", line 44, in mapstar
return list(map(*args))
File "<ipython-input-4-3aab73406a3b>", line 18, in process_data
print(row[0])
TypeError: 'int' object is not subscriptable
"""
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
<ipython-input-4-3aab73406a3b> in <module>
46 results = []
47 with tqdm(total=len(data)) as pbar:
---> 48 for i, result_chunk in enumerate(pool.map(process_data, data)):
49 # Update the progress bar
50 pbar.update()
/opt/conda/lib/python3.7/multiprocessing/pool.py in map(self, func, iterable, chunksize)
266 in a list that is returned.
267 '''
--> 268 return self._map_async(func, iterable, mapstar, chunksize).get()
269
270 def starmap(self, func, iterable, chunksize=None):
/opt/conda/lib/python3.7/multiprocessing/pool.py in get(self, timeout)
655 return self._value
656 else:
--> 657 raise self._value
658
659 def _set(self, i, obj):
TypeError: 'int' object is not subscriptable
How can I pass a list of lists to pool.map() and have it processed in parallel?
1 Answer
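The error occurs because, when you pass data to pool.map, what is handed to process_data on each map iteration (and bound to the parameter chunk) is not data itself but one element of data, and each element is just a list, not a list of lists. So when you iterate over chunk with for row in chunk, row is not a list but one of the inner values, such as 100227, 7382501.0, and so on. Finally, when row is actually 100227 and you try to execute print(row[0]), the interpreter raises the error above: 'int' object is not subscriptable.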
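To see this without starting any worker processes, here is a minimal sketch that uses the built-in map, which hands out elements one at a time in the same way pool.map does. The helper describe is a hypothetical name introduced only for this illustration, and data is shortened to two rows.

```python
# Two rows from the question's data, enough to show the shape of each call.
data = [
    [100227, 7382501.0, "view", 30065006, False, ""],
    [100227, 7382501.0, "view", 57072062, True, ""],
]


def describe(chunk):
    """Report what one call receives and what iterating over it yields."""
    return type(chunk).__name__, [type(item).__name__ for item in chunk]


# map calls describe once per element of data, so each call gets ONE row.
for received in map(describe, data):
    print(received)

# Indexing an item of a row reproduces the error from the traceback.
row = data[0][0]  # 100227, an int
try:
    row[0]
except TypeError as exc:
    print(exc)  # 'int' object is not subscriptable
```

Each call receives a flat list, so a loop over it yields ints, floats, strings and bools rather than rows.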
To fix this, you only need to change the implementation of process_data so that it accepts a single list instead of a list of lists.
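A minimal sketch of what a per-row process_data could look like. To keep it runnable on its own, df_text is filled with made-up titles and title_similarity is a hypothetical placeholder (word overlap) standing in for the question's embedding_model.encode and np.matmul steps; neither comes from the original post.

```python
# Hypothetical lookup table; in the question, df_text is the asker's own data.
df_text = {
    7382501.0: {"title": "parallel text embeddings"},
    57072062: {"title": "text embeddings in parallel"},
}


def title_similarity(title1, title2):
    """Placeholder score (word overlap). Swap in the embedding dot product."""
    words1, words2 = set(title1.split()), set(title2.split())
    return len(words1 & words2) / len(words1 | words2)


def process_data(row):
    """Handle ONE row, e.g. [100227, 7382501.0, 'view', 57072062, True, '']."""
    work_id = row[1]
    mentioning_work_id = row[3]
    if work_id not in df_text or mentioning_work_id not in df_text:
        return None  # nothing to compare for this row
    similarity = title_similarity(
        df_text[work_id]["title"], df_text[mentioning_work_id]["title"]
    )
    # Same output layout as the question: first five fields plus the score.
    return row[:5] + [similarity]
```

Returning None for skipped rows, rather than an empty list, lets the caller tell "no result" apart from a real row.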
Then, in your main code, change the loop so that it expects a single list as each element of the iteration.
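One way the calling side might look, as a sketch under a few assumptions: process_data here is a hypothetical stand-in that returns one result list per row (or None to skip it), collect is a name made up for this example, and pool.imap is swapped in for the question's pool.map so that results can be consumed one at a time.

```python
import multiprocessing

data = [
    [100227, 7382501.0, "view", 30065006, False, ""],
    [100227, 7382501.0, "view", 57072062, True, ""],
    [100227, 7382501.0, "view", 66405922, True, ""],
]


def process_data(row):
    """Hypothetical stand-in worker: one result list per row, or None."""
    if not row[4]:
        return None
    return row[:5] + [0.0]  # 0.0 is a placeholder for the similarity score


def collect(rows, processes=2):
    """Send each row to a worker and gather the non-empty results in order."""
    results = []
    with multiprocessing.Pool(processes=processes) as pool:
        # imap yields one result per row, lazily and in input order.
        for result in pool.imap(process_data, rows):
            if result is not None:
                results.append(result)  # one list per row: append, not +=
    return results


if __name__ == "__main__":
    final_result = collect(data)
    print(len(final_result))
```

To keep the progress bar, the loop can sit inside the question's tqdm block with pbar.update() called once per result. With pool.map the whole result list is only returned after every row has finished, so the bar would jump from empty to full at the end.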