我正在尝试写一个算法的循环信念传播。我使用Numpy和pGMpy。目标是首先初始化从节点到因子的消息。然后在每次迭代中,您将计算因子到节点的消息,然后更新从节点到因子的消息。
对于从节点到因子的消息(M_v_to_f)和从因子到节点的消息(M_f_to_v),我使用元组作为键。M_v_to_f将具有M_v_to_f[('x2 ',〈表示0x7ff6debe3490处的phi(x2:3,x3:2,x4:2)的离散因子〉)]。在一次迭代之后,更新M_v_to_f。
然而,在第二次迭代中,我遇到了一个关键错误问题。因此,我打印出了可能引发密钥错误的密钥,并将密钥打印在M_v_to_f中。问题是我看到了一个匹配,但我不知道为什么Python没有响应它。这表明我真的可以看到一把钥匙。
下面是代码,以防它有帮助:
import numpy as np
import copy
from pgmpy.models import FactorGraph
from pgmpy.factors.discrete import DiscreteFactor
from pgmpy.factors import factor_product
from pgmpy.readwrite import BIFReader
def make_debug_graph():
G = FactorGraph()
G.add_nodes_from(['x1', 'x2', 'x3', 'x4'])
# add factors
phi1 = DiscreteFactor(['x1', 'x2'], [2, 3], np.array([0.5, 0.7, 0.2,
0.5, 0.3, 0.8]))
phi2 = DiscreteFactor(['x2', 'x3', 'x4'], [3, 2, 2], np.array([0.2, 0.25, 0.70, 0.30,
0.4, 0.25, 0.15, 0.65,
0.4, 0.50, 0.15, 0.05]))
phi3 = DiscreteFactor(['x3'], [2], np.array([0.5,
0.5]))
phi4 = DiscreteFactor(['x4'], [2], np.array([0.4,
0.6]))
G.add_factors(phi1, phi2, phi3, phi4)
G.add_nodes_from([phi1, phi2, phi3, phi4])
G.add_edges_from([('x1', phi1), ('x2', phi1), ('x2', phi2), ('x3', phi2), ('x4', phi2), ('x3', phi3), ('x4', phi4)])
return G
G = make_debug_graph()
def _custom_reshape(arr, shape_len, axis):
shape = tuple([1 if i != axis else arr.shape[0] for i in range(shape_len)])
return np.reshape(arr, shape)
# initialize M_v_to_f
M_v_to_f = {}
for var in G.get_variable_nodes():
for factor in G.neighbors(var):
key = (var, factor)
print(key)
print(M_v_to_f)
M_v_to_f[key] = np.ones(G.get_cardinality(var))
for epoch in range(10):
print(epoch)
M_f_to_v = {}
for factor in G.get_factor_nodes():
num_axis = len(factor.values.shape)
for j, to_node in enumerate(factor.scope()):
incoming_msg = []
for k, in_node in enumerate(factor.scope()):
if j==k: continue
key = (in_node, factor)
# Error on here on the second iteration.
incoming_msg.append(_custom_reshape(M_v_to_f[key], num_axis, k))
outgoing = factor.values
for msg in incoming_msg:
print(msg.shape)
outgoing *= msg
sum_axis = list(range(num_axis))
sum_axis.remove(j)
outgoing = np.sum(outgoing, axis = tuple(sum_axis))
outgoing /= np.sum(outgoing)
key = (factor, to_node)
M_f_to_v[key] = outgoing
# update the M_v_to_f
for var in G.get_variable_nodes():
for j, factor in enumerate(G.neighbors(var)):
incoming_msg = []
for k, in_fact in enumerate(G.neighbors(var)):
if j == k: continue
key = (in_fact, var)
incoming_msg.append(M_f_to_v[key])
if incoming_msg:
outgoing = incoming_msg[0]
for msg in incoming_msg[1:]:
outgoing *= msg
outgoing /= np.sum(outgoing)
key = (var,factor)
M_v_to_f[key] = outgoing
enter image description here
我已经尝试了不同的方法来使用这些键(先定义元组)。我真的不知道该怎么办。
至于print语句,可以看到关键是:
('x2', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7f94f90db0d0>)
并且M_v_to_f是:
{('x2', <DiscreteFactor representing phi(x1:2, x2:3) at 0x7f94f90db190>): array([0.3625, 0.3625, 0.275 ]), **('x2', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7f94f90db0d0>)**: array([0.33333333, 0.33333333, 0.33333333]), ('x3', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7f94f90db0d0>): array([0.5, 0.5]), ('x3', <DiscreteFactor representing phi(x3:2) at 0x7f94f90db1f0>): array([0.5, 0.5]), ('x1', <DiscreteFactor representing phi(x1:2, x2:3) at 0x7f94f90db190>): array([1., 1.]), ('x4', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7f94f90db0d0>): array([0.4, 0.6]), ('x4', <DiscreteFactor representing phi(x4:2) at 0x7f94f90db130>): array([0.5, 0.5])}
1条答案
按热度按时间qybjjes11#
你正在改变你的dict键:
这会中断dict查找。