You can perform marginal and MPE (most probable explanation) inference using SPNs.
import libspn as spn
import tensorflow as tf
# Leaf layer: indicators for 2 variables with 2 values each. The sums below
# take columns [0, 1] and [2, 3] of this leaf; with num_vals=2 these are
# assumed to be the two values of the first and second variable respectively.
indicator_x = spn.IndicatorLeaf(num_vars=2, num_vals=2, name="indicator_x")


def _weighted_sum(columns, weights, name):
    """Return a Sum over `columns` of `indicator_x` with constant `weights`."""
    node = spn.Sum((indicator_x, columns), name=name)
    node.generate_weights(initializer=tf.initializers.constant(weights))
    return node


# Two weighted sums per variable (created in the same order as each other).
sum_11 = _weighted_sum([0, 1], [0.4, 0.6], "sum_11")
sum_12 = _weighted_sum([0, 1], [0.1, 0.9], "sum_12")
sum_21 = _weighted_sum([2, 3], [0.7, 0.3], "sum_21")
sum_22 = _weighted_sum([2, 3], [0.8, 0.2], "sum_22")

# Each product pairs one sum over columns [0, 1] with one over columns [2, 3].
prod_1, prod_2, prod_3 = (
    spn.Product(first, second, name=label)
    for first, second, label in (
        (sum_11, sum_21, "prod_1"),
        (sum_11, sum_22, "prod_2"),
        (sum_12, sum_22, "prod_3"),
    )
)

# Root mixture over the three products, with constant weights.
root = spn.Sum(prod_1, prod_2, prod_3, name="root")
root.generate_weights(initializer=tf.initializers.constant([0.5, 0.2, 0.3]))

# Latent indicator attached to the root sum (it could also be built manually).
indicator_y = root.generate_latent_indicators(name="indicator_y")

# Inspect the structure: node count, scope, and validity, printed in order.
for query in (root.get_num_nodes, root.get_scope, root.is_valid):
    print(query())
The visualization below uses `graphviz`. Depending on your setup (e.g. `jupyter lab` vs. `jupyter notebook`), it might fail to render. At least the combination of Chrome and `jupyter notebook` seems to work.
# Visualize SPN graph (uses graphviz; may not render in every notebook frontend).
spn.display_spn_graph(root)

# Op that initializes the weights of the SPN rooted at `root`.
init_weights = spn.initialize_weights(root)
# Ops computing the root value under MARGINAL and under MPE inference.
marginal_val = root.get_value(inference_type=spn.InferenceType.MARGINAL)
mpe_val = root.get_value(inference_type=spn.InferenceType.MPE)

# One row per sample, one column per variable.
# NOTE(review): -1 is assumed to mark an unobserved (marginalized) variable,
# following libspn's IndicatorLeaf convention — confirm against the libspn docs.
indicator_x_data = [
    [0, 1],
    [0, -1],
    [-1, -1],
]
# One latent value per sample: the first sample fixes the root's first child
# (index 0); the others leave it as -1.
indicator_y_data = [[0], [-1], [-1]]

# Build the feed once; both fetches use exactly the same inputs.
feed = {indicator_x: indicator_x_data, indicator_y: indicator_y_data}

with tf.Session() as sess:
    # `run()` uses the default session, which the `with` block installs.
    init_weights.run()
    # Fetch both values in a single run so the shared inputs are fed and the
    # shared part of the graph is evaluated only once.
    marginal_val_arr, mpe_val_arr = sess.run(
        [marginal_val, mpe_val], feed_dict=feed
    )
    print(marginal_val_arr)
    print(mpe_val_arr)