python 从Pandas数据框的每行中提取连续非空值的最长块

9o685dep  于 2023-02-11  发布在  Python
关注(0)|答案(3)|浏览(127)

假设我有一个结构类似于下面的Pandas Dataframe :

data = {
    'A' : [5.0, np.nan, 1.0],
    'B' : [7.0, np.nan, np.nan],
    'C' : [9.0, 2.0, 6.0],
    'D' : [np.nan, 4.0, 9.0],
    'E' : [np.nan, 6.0, np.nan],
    'F' : [np.nan, np.nan, np.nan],
    'G' : [np.nan, np.nan, 8.0]
}

df = pd.DataFrame(
    data,
    index=['11','22','33']
)

从每一行中,我想提取最长的连续非空值块,并将它们附加到列表中。
因此,这些行中的以下值:

row11: [5,7,9]    
row22: [2,4,6]    
row33: [6,9]

给我一个值列表:

[5.0, 7.0, 9.0, 2.0, 4.0, 6.0, 6.0, 9.0]

我当前的方法使用iterrows()first_valid_index()last_valid_index()

mylist = []
for i, r in df.iterrows():
    start = r.first_valid_index()
    end = r.last_valid_index()
    mylist.extend(r[start: end].values)

当有效数字被分块在一起时,例如row11row22,这种方法很好用。但是,当数字中散布着空值时,例如row33,我的方法就失败了。在这种情况下,我的方法提取整行,因为第一个和最后一个索引包含非空值。我的解决方案(错误地)输出了一个最终列表:

[5.0, 7.0, 9.0, 2.0, 4.0, 6.0, 1.0, nan, 6.0, 9.0, nan, nan, 8.0]

我有以下问题:
1.)我该如何应对row33示例中遇到的错误呢?
2.)有没有比使用iterrows()更有效的方法?我的实际数据有数千行。虽然它不一定太慢,但我总是谨慎地使用Pandas时诉诸迭代。

3vpjnl9f

3vpjnl9f1#

一个选项是使用groupby获取非NA的拉伸,使用max过滤最长的拉伸:

def get_longest(s):
    m = s.isna()
    return max(s[~m].groupby(m.cumsum()),
               key=lambda x: len(x[1])
              )[1].dropna().tolist()

out = df.apply(get_longest, axis=1)

输出:

11    [5.0, 7.0, 9.0]
22    [2.0, 4.0, 6.0]
33         [6.0, 9.0]
dtype: object
xlpyo6sf

xlpyo6sf2#

使用numpy.ma.masked_invalidnumpy.ma.clump_unmasked函数将行拆分为 * non-nan * 值的连续切片,并选择长度最大的切片:

res = df.apply(lambda x: x[max(np.ma.clump_unmasked(np.ma.masked_invalid(x.values)),
                               key=lambda sl: sl.stop - sl.start)].tolist(), axis=1)
11    [5.0, 7.0, 9.0]
22    [2.0, 4.0, 6.0]
33         [6.0, 9.0]
nukf8bse

nukf8bse3#

使用切片和列表解析的另一种方法:

mapper = df.columns.get_loc
c = df.diff(axis=1).notna().cumsum(1)
min_ = c.shift(-1,axis=1).gt(0).idxmax(1).map(mapper)
max_ = c.idxmax(1).map(mapper)
out = pd.Series({df.index[e] : 
                 list(df.iloc[e,np.r_[a:b+1]]) 
                 for e, (a,b) in enumerate(zip(min_,max_))})
print(out)
11    [5.0, 7.0, 9.0]
22    [2.0, 4.0, 6.0]
33         [6.0, 9.0]
dtype: object

相关问题