Tutorial 2: Inference

In [1]:
import libspn as spn

Building a Test Graph with Initialized Weights

In [2]:
def _weighted_sum(name, weights, *inputs):
    """Create an ``spn.Sum`` named ``name`` over ``inputs`` and attach ``weights``.

    The Sum is created first and its weights generated immediately afterwards,
    so nodes and weights come into existence in the same order as when the two
    calls are written out by hand.
    """
    node = spn.Sum(*inputs, name=name)
    node.generate_weights(weights)
    return node


# Indicators for two binary variables: columns 0-1 hold the values of the
# first variable, columns 2-3 the values of the second.
iv_x = spn.IVs(num_vars=2, num_vals=2, name="iv_x")

# Layer 1: two weighted sums per variable.
sum_11 = _weighted_sum("sum_11", [0.4, 0.6], (iv_x, [0, 1]))
sum_12 = _weighted_sum("sum_12", [0.1, 0.9], (iv_x, [0, 1]))
sum_21 = _weighted_sum("sum_21", [0.7, 0.3], (iv_x, [2, 3]))
sum_22 = _weighted_sum("sum_22", [0.8, 0.2], (iv_x, [2, 3]))

# Layer 2: products pairing a first-variable sum with a second-variable sum.
prod_1 = spn.Product(sum_11, sum_21, name="prod_1")
prod_2 = spn.Product(sum_11, sum_22, name="prod_2")
prod_3 = spn.Product(sum_12, sum_22, name="prod_3")

# Root: a weighted mixture of the three products, plus a latent indicator
# (iv_y) selecting which of the root's children is active.
root = _weighted_sum("root", [0.5, 0.2, 0.3], prod_1, prod_2, prod_3)
iv_y = root.generate_ivs(name="iv_y")

Visualizing the SPN Graph

In [3]:
# Draw the SPN rooted at `root` (the graph built in the previous cell).
# Kept as the cell's last expression so that whatever it returns is displayed.
spn.display_spn_graph(root)

Adding Value Ops

In [4]:
# Op that initializes the weights of the SPN rooted at `root`; it must be run
# inside a session before any value op is evaluated.
init_weights = spn.initialize_weights(root)

# One value op per inference type, built in the order MARGINAL, then MPE:
# MARGINAL sums over unobserved variables, MPE maximizes over them.
marginal_val, mpe_val = (
    root.get_value(inference_type=inference)
    for inference in (spn.InferenceType.MARGINAL, spn.InferenceType.MPE)
)

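Calculating Values

Each row of the feed is one sample. A value of -1 marks a variable as unobserved: it is summed out for marginal inference and maximized over for MPE inference.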

In [5]:
# One sample per row. Each entry is the observed value of a variable in iv_x;
# -1 marks the variable as unobserved (summed out for MARGINAL inference,
# maximized over for MPE inference).
iv_x_arr = [[0, 1],
            [0, -1],
            [-1, -1]]

# Latent indicator over the root's three children; -1 leaves it unobserved.
iv_y_arr = [[0], [-1], [-1]]

# Build the feed once and fetch both value ops in a single session run,
# instead of duplicating the feed dict across two separate sess.run calls.
value_feed = {iv_x: iv_x_arr, iv_y: iv_y_arr}

with spn.session() as (sess, _):
    init_weights.run()
    marginal_val_arr, mpe_val_arr = sess.run([marginal_val, mpe_val],
                                             feed_dict=value_feed)

print(marginal_val_arr)
print(mpe_val_arr)
[[ 0.06]
 [ 0.31]
 [ 1.  ]]
[[ 0.06      ]
 [ 0.14      ]
 [ 0.21600001]]

Adding MPE State Ops

In [6]:
# MPE state inference, configured to derive the state from MPE (max) values
# rather than MARGINAL (sum) values.
mpe_state = spn.MPEState(value_inference_type=spn.InferenceType.MPE)
# Ops yielding, per sample, a complete assignment of iv_x and of the latent iv_y
# (unobserved -1 entries are filled in with their most probable values).
iv_x_mpe, iv_y_mpe = mpe_state.get_state(root, iv_x, iv_y)
In [7]:
# Evaluate the MPE state ops on the same samples as the value cell above.
# Observed entries are returned unchanged; -1 entries are replaced by their
# most probable values.
mpe_feed = {iv_x: iv_x_arr, iv_y: iv_y_arr}

with spn.session() as (sess, _):
    init_weights.run()
    mpe_fetches = [iv_x_mpe, iv_y_mpe]
    iv_x_mpe_arr, iv_y_mpe_arr = sess.run(mpe_fetches, feed_dict=mpe_feed)

print(iv_x_mpe_arr)
print(iv_y_mpe_arr)
[[0 1]
 [0 0]
 [1 0]]
[[0]
 [0]
 [2]]