
# Posted by ru9i0ody on 2023-04-30 (category: Other)





import numpy as np
from scipy.cluster.hierarchy import leaders, ClusterNode, to_tree
from typing import Optional, Tuple, List

def get_node(
    linkage_matrix: np.ndarray,
    clusters_array: np.ndarray,
    cluster_num: int
) -> ClusterNode:
    Returns ClusterNode (the node of the cluster tree) corresponding to the given cluster number.
    :param linkage_matrix: linkage matrix
    :param clusters_array: array of cluster numbers for each point
    :param cluster_num: id of cluster for which we want to get ClusterNode
    :return: ClusterNode corresponding to the given cluster number
    L, M = leaders(linkage_matrix, clusters_array)
    idx = L[M == cluster_num]
    tree = to_tree(linkage_matrix)
    result = search_for_node(tree, idx)
    assert result
    return result

def search_for_node(
    cur_node: Optional[ClusterNode],
    target: int
) -> Optional[ClusterNode]:
    """
    Searches for the node with the given id of the cluster in the given subtree.

    The walk uses an explicit stack rather than recursion, so deep trees
    cannot trigger Python's recursion limit.

    :param cur_node: root of the cluster subtree to search for target node
    :param target: id of the target node (cluster)
    :return: ClusterNode with the given id if it exists in the subtree, None otherwise
    """
    stack = [cur_node]
    while stack:
        node = stack.pop()
        if node is None:
            continue
        if node.get_id() == target:
            return node
        # Push right before left so the left subtree is explored first
        # (pre-order), matching the original recursive search.
        stack.append(node.get_right())
        stack.append(node.get_left())
    return None

def get_LCA(
        node_1: ClusterNode,
        node_2: ClusterNode,
        root: Optional[ClusterNode]
) -> Optional[ClusterNode]:
    """
    Returns the lowest common ancestor of the given ClusterNodes in the subtree of root.

    Nodes are matched by id (``get_id()``), so the arguments may come from a
    different ``to_tree`` call than ``root``; the returned node always belongs
    to ``root``'s tree.

    :param node_1: ClusterNode
    :param node_2: ClusterNode
    :param root: ClusterNode - root of the subtree (None for an empty subtree)
    :return: the lowest common ancestor of the given ClusterNodes in the subtree of root;
        if only one of them is in the subtree, that node; if neither is, None
    """
    if root is None:
        return None
    # A node that is itself one of the targets is the answer for this subtree:
    # the other target is either below it or not here at all, so skip descending.
    if root.get_id() in (node_1.get_id(), node_2.get_id()):
        return root
    left = get_LCA(node_1, node_2, root.get_left())
    right = get_LCA(node_1, node_2, root.get_right())
    if left is not None and right is not None:
        # The targets sit in different child subtrees, so this node joins them.
        return root
    return left if left is not None else right

def get_num_steps(
    ancestor_node: ClusterNode,
    descendant_node: ClusterNode
) -> Optional[int]:
    """
    Returns number of steps from the ancestor node to the descendant node.

    Nodes are compared by id. The subtree is walked with an explicit stack
    (no recursion), tracking the depth of each visited node.

    :param ancestor_node: ClusterNode - ancestor node
    :param descendant_node: ClusterNode - descendant node
    :return: number of steps from the ancestor node to the descendant node or None if the
        descendant node is not a descendant of the ancestor node (or either is None)
    """
    if ancestor_node is None or descendant_node is None:
        return None
    target_id = descendant_node.get_id()
    stack = [(ancestor_node, 0)]
    while stack:
        node, depth = stack.pop()
        if node.get_id() == target_id:
            return depth
        # Right pushed first so the left child is visited first, as before.
        for child in (node.get_right(), node.get_left()):
            if child is not None:
                stack.append((child, depth + 1))
    return None

def get_leaves_ids(node: Optional[ClusterNode]) -> List[int]:
    """
    Returns ids of all samples (leaf nodes) that belong to the given ClusterNode (belong to the node's subtree).

    :param node: ClusterNode for which we want to get ids of samples (None yields an empty list)
    :return: list of ids of samples that belong to the given ClusterNode, in left-to-right order
    """
    if node is None:
        return []
    # ClusterNode.pre_order() traverses the subtree without recursive calls and,
    # with its default callback, collects the id of every leaf it reaches.
    return node.pre_order()


from sklearn.datasets import make_blobs
import scipy.cluster.hierarchy as shc
import numpy as np
data = make_blobs(centers=10, cluster_std=0.9, n_samples=3000, random_state=0)
lkage = shc.linkage(data[0], method='ward')
# For cluster_ids we use n_clusters = n_samples (every point is its own flat
# cluster); however, any number of clusters works here.
cluster_ids = shc.fcluster(lkage, t=data[0].shape[0], criterion='maxclust')

# Change the seed to pick other random clusters.
seed = 3
rng = np.random.default_rng(seed)
# Take two *distinct* cluster ids. Sampling without replacement from the ids
# that actually occur avoids picking the same cluster twice and, unlike
# randint(min, max), can also pick the largest id.
cluster_id_1, cluster_id_2 = rng.choice(np.unique(cluster_ids), size=2, replace=False)
ancestor = get_LCA(
    node_1=get_node(lkage, cluster_ids, cluster_id_1),
    node_2=get_node(lkage, cluster_ids, cluster_id_2),
    # get_LCA matches nodes by id, so a fresh tree of the same linkage works as root.
    root=shc.to_tree(lkage),
)

num_steps = get_num_steps(ancestor, get_node(lkage, cluster_ids, cluster_id_1))
print(f"Steps from the common ancestor to cluster {cluster_id_1}: {num_steps}")


from matplotlib import pyplot as plt
# Set initial_clusters_dot_size to 1 if initial clusters contain many dots
initial_clusters_dot_size = 50
n_points = data[0].shape[0]
# Colour codes: 0 = unrelated points, 3 = points under the common ancestor,
# 1 / 2 = the two chosen clusters (assigned last so they override 3).
color = np.zeros(shape=(n_points))
color[get_leaves_ids(ancestor)] = 3
color[cluster_ids == cluster_id_1] = 1
color[cluster_ids == cluster_id_2] = 2
# Enlarge only the two chosen clusters so they stand out.
sizes = np.ones(shape=(n_points))
sizes[cluster_ids == cluster_id_1] = initial_clusters_dot_size
sizes[cluster_ids == cluster_id_2] = initial_clusters_dot_size
plt.scatter(data[0][:, 0], data[0][:, 1], s=sizes, c=color)
# Needed when run as a script; harmless in notebooks.
plt.show()
