numpy 使用元组时Python中的键错误,但当我打印的时候,我在琴键上看到它

sycxhyv7  于 2023-04-30  发布在  Python
关注(0)|答案(1)|浏览(107)

我正在尝试写一个算法的循环信念传播。我使用Numpy和pGMpy。目标是首先初始化从节点到因子的消息。然后在每次迭代中,您将计算因子到节点的消息,然后更新从节点到因子的消息。
对于从节点到因子的消息(M_v_to_f)和从因子到节点的消息(M_f_to_v),我使用元组作为键。M_v_to_f将具有M_v_to_f[('x2 ',〈表示0x7ff6debe3490处的phi(x2:3,x3:2,x4:2)的离散因子〉)]。在一次迭代之后,更新M_v_to_f。
然而,在第二次迭代中,我遇到了一个关键错误问题。因此,我打印出了可能引发密钥错误的密钥,并将密钥打印在M_v_to_f中。问题是我看到了一个匹配,但我不知道为什么Python没有响应它。这表明我真的可以看到一把钥匙。
下面是代码,以防它有帮助:

import numpy as np
import copy
from pgmpy.models import FactorGraph
from pgmpy.factors.discrete import DiscreteFactor
from pgmpy.factors import factor_product
from pgmpy.readwrite import BIFReader

def make_debug_graph():
    
    G = FactorGraph()
    G.add_nodes_from(['x1', 'x2', 'x3', 'x4'])
    
    # add factors 
    phi1 = DiscreteFactor(['x1', 'x2'], [2, 3], np.array([0.5, 0.7, 0.2,
                                                          0.5, 0.3, 0.8]))
    phi2 = DiscreteFactor(['x2', 'x3', 'x4'], [3, 2, 2], np.array([0.2, 0.25, 0.70, 0.30,
                                                                   0.4, 0.25, 0.15, 0.65,
                                                                   0.4, 0.50, 0.15, 0.05]))
    phi3 = DiscreteFactor(['x3'], [2], np.array([0.5, 
                                                 0.5]))
    phi4 = DiscreteFactor(['x4'], [2], np.array([0.4, 
                                                 0.6]))
    G.add_factors(phi1, phi2, phi3, phi4)
    
    G.add_nodes_from([phi1, phi2, phi3, phi4])
    G.add_edges_from([('x1', phi1), ('x2', phi1), ('x2', phi2), ('x3', phi2), ('x4', phi2), ('x3', phi3), ('x4', phi4)])
    
    return G
G = make_debug_graph()
def _custom_reshape(arr, shape_len, axis):
    shape = tuple([1 if i != axis else arr.shape[0] for i in range(shape_len)])
    return np.reshape(arr, shape)
# initialize M_v_to_f
M_v_to_f = {}
for var in G.get_variable_nodes():
    for factor in G.neighbors(var):
        key = (var, factor)
        print(key)
        print(M_v_to_f)
        M_v_to_f[key] = np.ones(G.get_cardinality(var))

for epoch in range(10):
    print(epoch)
    M_f_to_v = {}
    for factor in G.get_factor_nodes():
        num_axis = len(factor.values.shape)
        for j, to_node in enumerate(factor.scope()):
            incoming_msg = []
            for k, in_node in enumerate(factor.scope()):
                if j==k: continue
                key = (in_node, factor) 
# Error on here on the second iteration.
                incoming_msg.append(_custom_reshape(M_v_to_f[key], num_axis, k))
            outgoing = factor.values
            for msg in incoming_msg:
                print(msg.shape)
                outgoing *= msg
            sum_axis = list(range(num_axis))
            sum_axis.remove(j)
            outgoing = np.sum(outgoing, axis = tuple(sum_axis))
            outgoing /= np.sum(outgoing)
            key = (factor, to_node)
            M_f_to_v[key] = outgoing
    # update the M_v_to_f
    for var in G.get_variable_nodes():
        for j, factor in enumerate(G.neighbors(var)):
            incoming_msg = []
            for k, in_fact in enumerate(G.neighbors(var)):
                if j == k: continue
                key = (in_fact, var)
                incoming_msg.append(M_f_to_v[key])
            
            if incoming_msg:
                outgoing = incoming_msg[0]
                for msg in incoming_msg[1:]:
                    outgoing *= msg
                outgoing /= np.sum(outgoing)
                key = (var,factor)
                M_v_to_f[key] = outgoing

enter image description here
我已经尝试了不同的方法来使用这些键(先定义元组)。我真的不知道该怎么办。
至于print语句,可以看到关键是:

('x2', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7f94f90db0d0>)

并且M_v_to_f是:

{('x2', <DiscreteFactor representing phi(x1:2, x2:3) at 0x7f94f90db190>): array([0.3625, 0.3625, 0.275 ]), **('x2', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7f94f90db0d0>)**: array([0.33333333, 0.33333333, 0.33333333]), ('x3', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7f94f90db0d0>): array([0.5, 0.5]), ('x3', <DiscreteFactor representing phi(x3:2) at 0x7f94f90db1f0>): array([0.5, 0.5]), ('x1', <DiscreteFactor representing phi(x1:2, x2:3) at 0x7f94f90db190>): array([1., 1.]), ('x4', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7f94f90db0d0>): array([0.4, 0.6]), ('x4', <DiscreteFactor representing phi(x4:2) at 0x7f94f90db130>): array([0.5, 0.5])}
qybjjes1

qybjjes11#

你正在改变你的dict键:

outgoing = factor.values
for msg in incoming_msg:
    print(msg.shape)
    outgoing *= msg

这会中断dict查找。

相关问题