Pandas:基于不同列中的公共组为记录分配相同的集群ID

rn0zuynd  于 2023-09-29  发布在  其他
关注(0)|答案(1)|浏览(75)

我有一个这样的dataframe:
| 指数|一|B| C|
| --|--|--|--|
| 1 | 111 | 222 | 111 |
| 2 | 111 | 222 | 222 |
| 3 | 111 | 111 | 555 |
| 4 | 222 | 222 | 444 |
| 5 | 222 | 333 | 111 |
| 6 | 222 | 444 | 333 |
| 7 | 333 | 555 | 777 |
| 8 | 444 | 666 | 777 |

df = pd.DataFrame({
  'A': [111,111,111,222,222,222,333,444], 
  'B': [222,222,111,222,333,444,555,666],
  'C': [111,222,555,444,111,333,777,777]
})

我想创建新的列'集群'和分配相同的ID记录直接连接或通过共同组中的一列。
意思是,例如,我们看到前3个元素在'A'中由相同的组连接,但它们也连接到列'B'中具有相同组'222'、'111'的其他记录。以及所有在'C'列中有'111'、'222'、'555'的记录。
因此,基本上,所有前6个元素应该具有相同的集群ID。
| 指数|一|B| C|集群|
| --|--|--|--|--|
| 1 | 111 | 222 | 111 | 1 |
| 2 | 111 | 222 | 222 | 1 |
| 3 | 111 | 111 | 555 | 1 |
| 4 | 222 | 222 | 444 | 1 |
| 5 | 222 | 333 | 111 | 1 |
| 6 | 222 | 444 | 333 | 1 |
| 7 | 333 | 555 | 777 | 2 |
| 8 | 444 | 666 | 777 | 2 |
记录4-6连接到1-3,因为它们在列A中形成一个组,并且它们通过列B和C连接到先前的记录。
我在成对的列上使用多个结果应用函数,但现在想在这里应用连接的组件,但不知道如何做到这一点。
此外,主要的问题是,这个数据集是巨大的,> 30 000 000记录。
感谢任何帮助。

k97glaaz

k97glaaz1#

您确实可以使用networkx.connected_components来实现这一点,在重塑DataFrame以构建连续的边对之后。不过,对于3300万行来说,这不会那么快:

import networkx as nx

tmp = (df
   .melt(ignore_index=False, value_name='source')
   .assign(target=lambda d: d.groupby(level=0)['source'].shift())
   .dropna(subset=['source', 'target'])
   .drop_duplicates(subset=['source', 'target'])
)

G = nx.from_pandas_edgelist(tmp)

groups = {n: next(iter(g)) for g in nx.connected_components(G) for n in g}

df['cluster'] = df['A'].map(groups)

输出量:

A   B    C cluster
0  111  02  001     002
1  111  02  002     002
2  111  01  005     002
3  222  02  004     002
4  222  03  001     002
5  222  04  003     002

图表:

将不同列中的相同值视为不同节点:

有几个选项,一个简单的是使用字符串并在节点中为列名添加前缀

tmp1 = df.astype(str).radd([f'{c}|' for c in df])

tmp = (tmp1
   .melt(ignore_index=False, value_name='source')
   .assign(target=lambda d: d.groupby(level=0)['source'].shift())
   .dropna(subset=['source', 'target'])
   .drop_duplicates(subset=['source', 'target'])
)

G = nx.from_pandas_edgelist(tmp)

groups = {n: i for i, g in enumerate(nx.connected_components(G), start=1)
          for n in g}

df['cluster'] = tmp1['A'].map(groups)

输出量:

A    B    C  cluster
0  111  222  111        1
1  111  222  222        1
2  111  111  555        1
3  222  222  444        1
4  222  333  111        1
5  222  444  333        1
6  333  555  777        2
7  444  666  777        2

图表:

相关问题