# Tutorial 3: Learning

In [1]:
import libspn as spn
import matplotlib.pyplot as plt
import tensorflow as tf


### Building a Test Graph with Random Weights

In [2]:
# Fix TensorFlow's graph-level seed *before* any ops are created, so the
# RANDOM_UNIFORM weight initialization below is reproducible on
# Restart & Run All.
# NOTE(review): assumes libspn draws the random initial weights with
# TensorFlow random ops (TF 1.x API) — confirm; harmless otherwise.
SEED = 42
tf.set_random_seed(SEED)

# Leaf layer: indicators for 2 variables with 2 values each (4 outputs).
# Indices [0, 1] and [2, 3] presumably select the indicators of the first
# and second variable respectively — confirm against libspn IVs layout.
iv_x = spn.IVs(num_vars=2, num_vals=2, name="iv_x")

# Sum layer: two sums over each variable's indicators.
sum_11 = spn.Sum((iv_x, [0,1]), name="sum_11")
sum_12 = spn.Sum((iv_x, [0,1]), name="sum_12")
sum_21 = spn.Sum((iv_x, [2,3]), name="sum_21")
sum_22 = spn.Sum((iv_x, [2,3]), name="sum_22")

# Product layer: each product combines one sum from each variable.
prod_1 = spn.Product(sum_11, sum_21, name="prod_1")
prod_2 = spn.Product(sum_11, sum_22, name="prod_2")
prod_3 = spn.Product(sum_12, sum_22, name="prod_3")

# Root mixture over the three products, plus indicator variables generated
# for the root sum (fed with iv_y_arr during training).
root = spn.Sum(prod_1, prod_2, prod_3, name="root")
iv_y = root.generate_ivs(name="iv_y")

# Attach weights to every sum, drawn uniformly from [0, 1).
spn.generate_weights(root, init_value=spn.ValueType.RANDOM_UNIFORM(0, 1))


### Visualizing the SPN Graph

In [3]:
# Visualize the SPN built in the previous cell, starting from its root node.
spn.display_spn_graph(root)


### Specify Training Data

In [4]:
# Eight training samples, one [value_of_var0, value_of_var1] row each.
iv_x_arr = [
    [0, 0], [0, 0],
    [1, 1], [1, 1], [1, 1],
    [0, 1], [0, 1], [0, 1],
]

# One iv_y entry per sample. -1 presumably marks the root's latent variable
# as unobserved (no evidence) — confirm against the libspn IVs docs.
iv_y_arr = [[-1] for _ in iv_x_arr]


In [5]:
# EM learner over the whole SPN using marginal value inference.
# NOTE(review): initial_accum_value=10 presumably seeds the EM accumulators
# (a smoothing prior) — confirm in the libspn EMLearner docs.
learner = spn.EMLearner(
    root,
    value_inference_type=spn.InferenceType.MARGINAL,
    initial_accum_value=10,
)

# Initializer ops: the weight variables first, then the learner's state.
init_weights = spn.initialize_weights(root)
init_learning = learner.initialize()

# Op performing one learning step.
learn = learner.learn()

# Mean of the root node's value over the fed batch; reported each epoch.
likelihood = tf.reduce_mean(learner.value.values[root])


### Run Learning

In [6]:
num_epochs = 20

# The training feed is identical every epoch, so build it once.
feed = {iv_x: iv_x_arr, iv_y: iv_y_arr}

with spn.session() as (sess, _):
    # Run the initializers in the same order they were created.
    sess.run(init_weights)
    sess.run(init_learning)

    likelihoods = []
    for _epoch in range(num_epochs):
        # Fetch the mean root value and run one learning step in a single call.
        likelihood_, _ = sess.run([likelihood, learn], feed_dict=feed)
        likelihoods.append(likelihood_)
        print("Avg. Likelihood: %s" % (likelihood_))

Avg. Likelihood: -1.4428
Avg. Likelihood: -1.30811
Avg. Likelihood: -1.26584
Avg. Likelihood: -1.23868
Avg. Likelihood: -1.21934
Avg. Likelihood: -1.20463
Avg. Likelihood: -1.19297
Avg. Likelihood: -1.18342
Avg. Likelihood: -1.17542
Avg. Likelihood: -1.16862
Avg. Likelihood: -1.16274
Avg. Likelihood: -1.15761
Avg. Likelihood: -1.1531
Avg. Likelihood: -1.14908
Avg. Likelihood: -1.1455
Avg. Likelihood: -1.14227
Avg. Likelihood: -1.13935
Avg. Likelihood: -1.13669
Avg. Likelihood: -1.13427
Avg. Likelihood: -1.13205

In [7]:
# Plot the per-epoch training likelihoods collected by the training loop.
# The recorded values are negative, so they appear to be log-space
# likelihoods rather than probabilities.
fig, ax = plt.subplots()
ax.plot(likelihoods)
ax.set(title="EM training progress", xlabel="Epoch", ylabel="Avg. log-likelihood")
plt.show()