import numpy as np
from scipy.cluster.hierarchy import leaders, ClusterNode, to_tree
from typing import Optional, Tuple, List
def get_node(
    linkage_matrix: np.ndarray,
    clusters_array: np.ndarray,
    cluster_num: int
) -> ClusterNode:
    """
    Return the ClusterNode (node of the cluster tree) corresponding to the given flat cluster number.

    :param linkage_matrix: linkage matrix (as produced by ``scipy.cluster.hierarchy.linkage``)
    :param clusters_array: flat cluster number of each observation (e.g. the output of ``fcluster``)
    :param cluster_num: flat cluster number whose tree node we want
    :return: the ClusterNode that is the leader (root) of the requested flat cluster
    :raises ValueError: if ``cluster_num`` does not occur among the flat clusters
    """
    # leaders() returns the node ids of the leader nodes and, in parallel,
    # the flat cluster number each leader belongs to.
    leader_ids, leader_clusters = leaders(linkage_matrix, clusters_array)
    matches = leader_ids[leader_clusters == cluster_num]
    if matches.size == 0:
        raise ValueError(f"cluster number {cluster_num} not found in clusters_array")
    # With rd=True, to_tree also returns a list whose i-th entry is the node with id i,
    # so the node can be looked up directly instead of searching the tree.
    _, nodes = to_tree(linkage_matrix, rd=True)
    return nodes[int(matches[0])]
def search_for_node(
    cur_node: Optional[ClusterNode],
    target: int
) -> Optional[ClusterNode]:
    """
    Search the subtree rooted at ``cur_node`` for the node whose id equals ``target``.

    The walk uses an explicit stack instead of recursion. Very deep (e.g. chained) cluster
    trees therefore do not hit Python's recursion limit.

    :param cur_node: root of the cluster subtree to search; may be None
    :param target: id of the target node (cluster)
    :return: ClusterNode with the given id if it exists in the subtree, None otherwise
    """
    stack = [cur_node]
    while stack:
        node = stack.pop()
        if node is None:
            continue
        if node.get_id() == target:
            return node
        # Push the right child first so the left child is examined first (left-first pre-order).
        stack.append(node.get_right())
        stack.append(node.get_left())
    return None
def get_LCA(
    node_1: ClusterNode,
    node_2: ClusterNode,
    root: ClusterNode
) -> Optional[ClusterNode]:
    """
    Return the lowest common ancestor of the given ClusterNodes in the subtree of root.

    Nodes are matched by id. If only one of the two nodes is in the subtree, that node
    (as found in the subtree) is returned. If neither is present, or ``root`` is None,
    None is returned. The tree is walked iteratively, so deep trees do not hit Python's
    recursion limit.

    :param node_1: ClusterNode
    :param node_2: ClusterNode
    :param root: ClusterNode - root of the subtree
    :return: the lowest common ancestor of the given ClusterNodes in the subtree of root
    """
    if root is None:
        return None

    # One pass over the subtree: map every node id to its node and its parent's id.
    # Ids are used as keys because ClusterNode defines __eq__ and is therefore unhashable.
    nodes = {}
    parent = {}
    stack = [(root, None)]
    while stack:
        node, parent_id = stack.pop()
        node_id = node.get_id()
        nodes[node_id] = node
        parent[node_id] = parent_id
        for child in (node.get_left(), node.get_right()):
            if child is not None:
                stack.append((child, node_id))

    id_1, id_2 = node_1.get_id(), node_2.get_id()
    if id_1 not in nodes or id_2 not in nodes:
        # At most one of the nodes lies in this subtree: return it if present, else None.
        if id_1 in nodes:
            return nodes[id_1]
        if id_2 in nodes:
            return nodes[id_2]
        return None

    # Collect node_1 and all its ancestors, then climb from node_2 until we hit one of them.
    ancestors = set()
    current = id_1
    while current is not None:
        ancestors.add(current)
        current = parent[current]
    current = id_2
    while current not in ancestors:
        current = parent[current]
    return nodes[current]
def get_num_steps(
    ancestor_node: ClusterNode,
    descendant_node: ClusterNode
) -> Optional[int]:
    """
    Return the number of edges on the path from the ancestor node down to the descendant node.

    Nodes are matched by id. The search uses an explicit stack, so deep trees do not hit
    Python's recursion limit.

    :param ancestor_node: ClusterNode - ancestor node
    :param descendant_node: ClusterNode - descendant node
    :return: number of steps from the ancestor node to the descendant node, or None if either
        node is None or the descendant node is not in the ancestor node's subtree
    """
    if ancestor_node is None or descendant_node is None:
        return None
    target_id = descendant_node.get_id()
    # Each stack entry pairs a node with its distance (in edges) from ancestor_node.
    stack = [(ancestor_node, 0)]
    while stack:
        node, depth = stack.pop()
        if node.get_id() == target_id:
            return depth
        for child in (node.get_right(), node.get_left()):
            if child is not None:
                stack.append((child, depth + 1))
    return None
def get_leaves_ids(node: Optional[ClusterNode]) -> List[int]:
    """
    Return the ids of all samples (leaf nodes) in the subtree of the given ClusterNode.

    Leaves are listed left to right. The walk uses an explicit stack, so deep trees do not
    hit Python's recursion limit.

    :param node: ClusterNode whose samples we want; None yields an empty list
    :return: list of ids of samples that belong to the given ClusterNode
    """
    leaves = []
    stack = [node]
    while stack:
        current = stack.pop()
        if current is None:
            continue
        if current.is_leaf():
            leaves.append(current.get_id())
            continue
        # Push the right child first so the left subtree's leaves are emitted first.
        stack.append(current.get_right())
        stack.append(current.get_left())
    return leaves
下面是使用这些helper函数来获取两个集群的祖先的示例:
from sklearn.datasets import make_blobs
import scipy.cluster.hierarchy as shc
import numpy as np

data = make_blobs(centers=10, cluster_std=0.9, n_samples=3000, random_state=0)
lkage = shc.linkage(data[0], method='ward')
# Here cluster_ids is computed with at most n_samples flat clusters (t = n_samples).
# Any other number of clusters works as well.
cluster_ids = shc.fcluster(lkage, t=data[0].shape[0], criterion='maxclust')
# Change the seed to pick other random clusters.
seed = 3
np.random.seed(seed)
# Pick two random cluster ids. randint's upper bound is exclusive, hence the + 1,
# so that the largest cluster id can be drawn too.
cluster_id_1, cluster_id_2 = np.random.randint(cluster_ids.min(), cluster_ids.max() + 1, size=2)
ancestor = get_LCA(
    node_1=get_node(lkage, cluster_ids, cluster_id_1),
    node_2=get_node(lkage, cluster_ids, cluster_id_2),
    root=shc.to_tree(lkage)
)
1条答案
按热度按时间3z6pesqy1#
不幸的是,没有开箱即用的方法来做到这一点。你需要一些样板代码:
下面是使用这些helper函数来获取两个集群的祖先的示例:
距离:
您可以使用以下代码可视化结果: