Tutorial 2: Inference

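SPNs support both marginal inference and MPE (most probable explanation) inference. This tutorial builds a small SPN and evaluates both on a few evidence rows.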

In [1]:
import libspn as spn
import tensorflow as tf

Build the SPN

In [2]:
indicator_x = spn.IndicatorLeaf(
    num_vars=2, num_vals=2, name="indicator_x")

# Build the structure and attach weights
sum_11 = spn.Sum((indicator_x, [0,1]), name="sum_11")
sum_11.generate_weights(initializer=tf.initializers.constant([0.4, 0.6]))
sum_12 = spn.Sum((indicator_x, [0,1]), name="sum_12")
sum_12.generate_weights(initializer=tf.initializers.constant([0.1, 0.9]))
sum_21 = spn.Sum((indicator_x, [2,3]), name="sum_21")
sum_21.generate_weights(initializer=tf.initializers.constant([0.7, 0.3]))
sum_22 = spn.Sum((indicator_x, [2,3]), name="sum_22")
sum_22.generate_weights(initializer=tf.initializers.constant([0.8, 0.2]))
prod_1 = spn.Product(sum_11, sum_21, name="prod_1")
prod_2 = spn.Product(sum_11, sum_22, name="prod_2")
prod_3 = spn.Product(sum_12, sum_22, name="prod_3")
root = spn.Sum(prod_1, prod_2, prod_3, name="root")
root.generate_weights(initializer=tf.initializers.constant([0.5, 0.2, 0.3]))

# Connect a latent indicator
indicator_y = root.generate_latent_indicators(name="indicator_y")  # A latent indicator can also be added manually

Inspect

In [3]:
# Print the node count, the scope, and the validity of the SPN
print(root.get_num_nodes())
print(root.get_scope())
print(root.is_valid())
15
[Scope({indicator_x:1, indicator_x:0, indicator_y:0})]
True

Visualize the SPN Graph

The visualization below uses graphviz. Depending on your setup (e.g. JupyterLab vs. Jupyter Notebook), the graph may fail to render. Chrome with Jupyter Notebook appears to work.

In [4]:
# Visualize SPN graph
spn.display_spn_graph(root)

Initialize weights and build inference Ops

In [5]:
init_weights = spn.initialize_weights(root)
marginal_val = root.get_value(inference_type=spn.InferenceType.MARGINAL)
mpe_val = root.get_value(inference_type=spn.InferenceType.MPE)
In [6]:
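# Each row assigns values to the two variables of indicator_x;
# -1 marks a variable as unobserved (no evidence).
indicator_x_data = [
    [0, 1],
    [0, -1],
    [-1, -1]
]

# Evidence for the latent variable of the root sum; -1 again means unobserved.
indicator_y_data = [[0], [-1], [-1]]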

feed_dict = {indicator_x: indicator_x_data, indicator_y: indicator_y_data}

with tf.Session() as sess:
    sess.run(init_weights)
    marginal_val_arr = sess.run(marginal_val, feed_dict=feed_dict)
    mpe_val_arr = sess.run(mpe_val, feed_dict=feed_dict)

print(marginal_val_arr)
print(mpe_val_arr)
[[0.06      ]
 [0.30999994]
 [1.        ]]
[[0.06 ]
 [0.14 ]
 [0.216]]
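
To see where these numbers come from, the same SPN can be evaluated by hand. The block below is a minimal NumPy sketch, not libspn code: the helpers `indicators` and `evaluate` are hypothetical names introduced here for illustration. It assumes that a negative input sets all indicators of that variable to 1, and that the latent indicator multiplies the weighted children of the root. Marginal inference evaluates sum nodes as weighted sums, while MPE inference replaces each weighted sum with a maximum over the weighted children.

```python
import numpy as np

# Weights copied from the SPN defined in this tutorial.
W_SUM_11 = np.array([0.4, 0.6])
W_SUM_12 = np.array([0.1, 0.9])
W_SUM_21 = np.array([0.7, 0.3])
W_SUM_22 = np.array([0.8, 0.2])
W_ROOT = np.array([0.5, 0.2, 0.3])


def indicators(value, num_vals):
    """Return a one-hot vector; a negative value (no evidence) gives all ones."""
    if value < 0:
        return np.ones(num_vals)
    vec = np.zeros(num_vals)
    vec[value] = 1.0
    return vec


def evaluate(x0, x1, y, mpe=False):
    """Evaluate the root bottom-up; weighted sums become maxima when mpe=True."""
    reduce_fn = np.max if mpe else np.sum
    ix0, ix1 = indicators(x0, 2), indicators(x1, 2)
    iy = indicators(y, 3)  # assumed: latent indicator gates the root's children
    s11 = reduce_fn(W_SUM_11 * ix0)
    s12 = reduce_fn(W_SUM_12 * ix0)
    s21 = reduce_fn(W_SUM_21 * ix1)
    s22 = reduce_fn(W_SUM_22 * ix1)
    # prod_1, prod_2, prod_3 as wired in the tutorial.
    prods = np.array([s11 * s21, s11 * s22, s12 * s22])
    return float(reduce_fn(W_ROOT * iy * prods))


# Same evidence rows as above: (x0, x1, y).
rows = [(0, 1, 0), (0, -1, -1), (-1, -1, -1)]
marginal = [evaluate(*row) for row in rows]
mpe = [evaluate(*row, mpe=True) for row in rows]
print(marginal)
print(mpe)
```

Up to floating-point rounding, the two lists match the marginal values (0.06, 0.31, 1.0) and the MPE values (0.06, 0.14, 0.216) printed by the session above. For the second row, for example, the marginal value is 0.5·0.4 + 0.2·0.4 + 0.3·0.1 = 0.31, whereas the MPE value keeps only the largest term after maximizing the sums over x1, 0.5·0.4·0.7 = 0.14.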