Python pandas: finding the distance from a run of N consecutive values to a run of M consecutive values

k97glaaz  posted on 2023-02-02  in Python

TL;DR version:

I have a column that looks like this:

[2, 2, 0, 0, 0, 2, 2, 0, 3, 3, 3, 0, 0, 2, 2, 0, 0, 0, 0, 2, 2, 0, 0, 0, 3, 3, 3]
# Longer runs are also possible, such as 4, 5, 6, 7, 8...

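I need a function with parameters n and m. With n=2 and m=3, it should give me the distance from each run of 2 to the next run of 3, so that the final result after grouping would be: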

[6, 9]
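
For a column that already stores each run's length as its value, as in the list above, the idea can be sketched in plain Python. This is only an illustration under stated assumptions: the helper name `distances_n_to_m` is hypothetical, runs are assumed to be separated by zeros, and a stretch that opens after an n-run but never reaches an m-run is dropped.

```python
from itertools import groupby


def distances_n_to_m(values, n, m):
    """Count the elements between the end of an n-run and the start of the next m-run."""
    distances = []
    count = None  # None means no stretch is currently open
    for value, run in groupby(values):
        length = len(list(run))
        if value == m:
            # An m-run closes the open stretch, if there is one.
            if count is not None:
                distances.append(count)
            count = None
        elif count is not None:
            # Anything else inside an open stretch adds to the distance.
            count += length
        elif value == n:
            # The stretch starts right after this n-run.
            count = 0
    return distances


col = [2, 2, 0, 0, 0, 2, 2, 0, 3, 3, 3, 0, 0, 2, 2, 0, 0, 0, 0, 2, 2, 0, 0, 0, 3, 3, 3]
print(distances_n_to_m(col, 2, 3))  # [6, 9]
```

A second n-run inside an open stretch is counted as part of the distance, which is why the first result is 6 rather than 3.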

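Detailed version

Below is the test case. I am writing a function that takes n and m and produces a list of the distances between consecutive runs. At the moment the function only works with a single parameter N (that is, the distance from one run of N consecutive values to the next run of N consecutive values). I would like to modify it so that it also accepts M.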

import numpy as np
import pandas as pd

dummy = [1,1,0,0,0,1,1,0,1,1,1,0,0,1,1,0,0,0,0,1,1,0,0,0,1,1,1]

df = pd.DataFrame({'a': dummy})

What I have written so far:

def get_N_seq_stat(df, N=2, M=3):
    df["c1"] = (
        df.groupby(df.a.ne(df.a.shift()).cumsum())["a"]
        .transform("size")
        .where(df.a.eq(1), 0)
    )
    df["c2"] = np.where(df.c1.ne(N) , 1, 0)
    df["c3"] = df["c2"].ne(df["c2"].shift()).cumsum()

    result = df.loc[df["c2"] == 1].groupby("c3")["c2"].count().tolist()

    # If the last N rows are not a run of N ones, the trailing stretch is open-ended, so drop it.
    if not (df["c1"].tail(N) == N).all():
        del result[-1]
    if not (df["c1"].head(N) == N).all():
        del result[0]
    return result

If I set N=2, M=3 (from a run of 2 to a run of 3), the ideal return value would be [6, 9], as marked below.

dummy = [1,1,**0,0,0,1,1,0,**1,1,1,0,0,1,1,**0,0,0,0,1,1,0,0,0,**1,1,1]

Currently, if I set N=2, the returned list is [3, 6, 4], because:

dummy = [1,1,**0,0,0,**1,1,**0,1,1,1,0,0,**1,1,**0,0,0,0,**1,1,0,0,0,1,1,1]

fruv7luv  #1

I would modify your code like this:

def get_N_seq_stat(df, N=2, M=3, debug=False):
    # get number of consecutive 1s
    c1 = (
        df.groupby(df.a.ne(df.a.shift()).cumsum())["a"]
        .transform("size")
        .where(df.a.eq(1), 0)
    )

    # find stretches between N and M
    m1 = c1.eq(N)
    m2 = c1.eq(M)
    c2 = pd.Series(np.select([m1.shift()&~m1, m2], [True, False], np.nan),
                   index=df.index).ffill().eq(1)

    # debug mode to understand how this works
    if debug:
        return df.assign(c1=c1, c2=c2,
                          length=c2[c2].groupby(c2.ne(c2.shift()).cumsum())
                                       .transform('size')
                        )

    # get the length of the stretches
    return c2[c2].groupby(c2.ne(c2.shift()).cumsum()).size().to_list()

get_N_seq_stat(df, N=2, M=3)

Output: [6, 9]

Intermediate columns c1, c2, and length:

get_N_seq_stat(df, N=2, M=3, debug=True)

    a  c1     c2  length
0   1   2  False     NaN
1   1   2  False     NaN
2   0   0   True     6.0
3   0   0   True     6.0
4   0   0   True     6.0
5   1   2   True     6.0
6   1   2   True     6.0
7   0   0   True     6.0
8   1   3  False     NaN
9   1   3  False     NaN
10  1   3  False     NaN
11  0   0  False     NaN
12  0   0  False     NaN
13  1   2  False     NaN
14  1   2  False     NaN
15  0   0   True     9.0
16  0   0   True     9.0
17  0   0   True     9.0
18  0   0   True     9.0
19  1   2   True     9.0
20  1   2   True     9.0
21  0   0   True     9.0
22  0   0   True     9.0
23  0   0   True     9.0
24  1   3  False     NaN
25  1   3  False     NaN
26  1   3  False     NaN
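
The `c2` flag in the function above behaves like a start/stop switch that is carried forward row by row. Here is a minimal sketch of just that step on a shorter, made-up `c1` series. The names `start`, `stop`, `switch`, and `inside` are hypothetical, and `shift(fill_value=False)` is a small variation on the answer's `m1.shift()` that keeps the boolean dtype.

```python
import numpy as np
import pandas as pd

N, M = 2, 3
c1 = pd.Series([2, 2, 0, 0, 3, 3, 3, 0, 0])  # made-up run lengths

m1 = c1.eq(N)
start = m1.shift(fill_value=False) & ~m1  # first row after an N-run
stop = c1.eq(M)                           # every row of an M-run

# 1.0 switches the flag on, 0.0 switches it off, NaN means "no change".
switch = pd.Series(
    np.select([start.to_numpy(), stop.to_numpy()], [1.0, 0.0], default=np.nan),
    index=c1.index,
)

# Forward-filling carries the last on/off state down the column.
inside = switch.ffill().eq(1)
print(inside.tolist())
# [False, False, True, True, False, False, False, False, False]
```

Rows before the first N-run stay NaN after the forward fill, so `eq(1)` treats them as outside any stretch. With this approach, a stretch that is still open when the data ends is counted as well.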
