Python: getting train/valid/test indices for sequences from a pandas DataFrame

4sup72z8, posted on 2023-03-16 in Python

I want to add a train/valid/test column (`mode`) in which each user's last sequence is marked as test and the sequence before it as valid:

+----+-------+----+------+
|user|cnt_seq|item| mode |
+----+-------+----+------+
|   1|      1|   4| train|
|   1|      1|   7| train|
|   1|      2|   2| train|
|   1|      2|   9| train|
|   1|      3|   8| valid|
|   1|      4|   3|  test|
|   1|      4|  10|  test|
|   2|      1|   6| train|
|   2|      2|   7| valid|
|   2|      3|   1|  test|
+----+-------+----+------+

Each user has a different number of sequences (`cnt_seq`), and the sequences have different lengths, so my code is:

test_users = [1, 2]
mdict = df.groupby('user')['cnt_seq'].max().to_dict()
test_idx = [(k, v) for k, v in mdict.items() if k in test_users]
valid_idx = [(k, v-1) for k, v in mdict.items() if k in test_users] 

df['mode'] = 'train'

for i, j in valid_idx:
    df.loc[(df.user== i) & (df.cnt_seq == j), 'mode'] = 'valid'
for i, j in test_idx:
    df.loc[(df.user== i) & (df.cnt_seq == j), 'mode'] = 'test'

But I don't think this is very good, because it needs two for loops to handle valid and test. Is there a simpler way to write it?

mrfwxfqh #1

Try this:

import pandas as pd
import numpy as np

# Your original df
data = [{'user': 1, 'cnt_seq': 1, 'item': 4},
        {'user': 1, 'cnt_seq': 1, 'item': 7},
        {'user': 1, 'cnt_seq': 2, 'item': 2},
        {'user': 1, 'cnt_seq': 2, 'item': 9},
        {'user': 1, 'cnt_seq': 3, 'item': 8},
        {'user': 1, 'cnt_seq': 4, 'item': 3},
        {'user': 1, 'cnt_seq': 4, 'item': 10},
        {'user': 2, 'cnt_seq': 1, 'item': 6},
        {'user': 2, 'cnt_seq': 2, 'item': 7},
        {'user': 2, 'cnt_seq': 3, 'item': 1}]
df = pd.DataFrame(data)

# Calculate the maximum sequence number for each user
group_max = df.groupby(['user'])['cnt_seq'].transform('max')

# Assign modes to each sequence based on the maximum sequence number
df['mode'] = np.select(
    [
        df['cnt_seq'] == group_max,          # test set corresponds to the last sequence
        df['cnt_seq'] == group_max-1         # validation set corresponds to the previous sequence
    ],
    ['test', 'valid'],                       # corresponding modes
    'train'                                   # all other sequences are assigned train set
)

# Print the results
print(df)
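
Note that `group_max-1` only finds the previous sequence when each user's `cnt_seq` values are consecutive integers, and the snippet above labels every user rather than only those in `test_users` from the question. If either matters, one possible variation is sketched below. It ranks sequences from the end within each user instead of subtracting 1. The helper name `assign_modes` and the `demo` frame are made up for illustration and are not part of the original answer.

```python
import numpy as np
import pandas as pd


def assign_modes(df, test_users=None):
    """Return 'train'/'valid'/'test' labels aligned with df.index (hypothetical helper)."""
    # Dense rank of each sequence within its user, counted from the end:
    # 1.0 = last sequence, 2.0 = the one before it, and so on.
    # This works even when cnt_seq has gaps (e.g. 1, 3, 7).
    rank_from_end = (
        df.groupby('user')['cnt_seq']
          .rank(method='dense', ascending=False)
    )

    # Optionally restrict valid/test labels to the users in test_users;
    # all rows of other users stay 'train'.
    if test_users is None:
        eligible = pd.Series(True, index=df.index)
    else:
        eligible = df['user'].isin(test_users)

    labels = np.select(
        [eligible & (rank_from_end == 1), eligible & (rank_from_end == 2)],
        ['test', 'valid'],
        default='train',
    )
    return pd.Series(labels, index=df.index, name='mode')


# Made-up data with non-consecutive sequence numbers.
demo = pd.DataFrame({
    'user':    [1, 1, 1, 1, 1, 2, 2, 3],
    'cnt_seq': [1, 1, 3, 3, 7, 2, 5, 4],
    'item':    [4, 7, 2, 9, 8, 7, 1, 6],
})
demo['mode'] = assign_modes(demo, test_users=[1, 2])
print(demo)
```

Calling `assign_modes(df)` without `test_users` labels every user, which should match the `np.select` version above whenever the sequence numbers are consecutive.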
