Wicker Convolutional SPNs for continuous MNIST data

This notebook shows how to build Wicker Convolutional SPNs (WCSPNs) and use them to classify digits from the MNIST dataset.

Setting up the imports and preparing the data

We load the data from tf.keras.datasets. Preprocessing consists of flattening the images and scaling the pixel values to the $[0, 1]$ range, so that the continuous values can be modeled by the normal leaves used below.

In [ ]:
%matplotlib inline
import libspn as spn
import tensorflow as tf
import numpy as np
from libspn.examples.utils.dataiterator import DataIterator

# Load
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()

def normalize(x):
    return x / 255.

def flatten(x):
    return x.reshape(-1, np.prod(x.shape[1:]))

def preprocess(x, y):
    return normalize(flatten(x)), np.expand_dims(y, axis=1)

# Preprocess
train_x, train_y = preprocess(train_x, train_y)
test_x, test_y = preprocess(test_x, test_y)

Defining the hyperparameters

Some hyperparameters for the SPN.

  • num_vars corresponds to the number of input variables (the number of pixels in the case of MNIST).
  • num_leaf_components is the number of distribution components in the normal leaves.
  • inference_type determines the kind of forward inference, where spn.InferenceType.MARGINAL corresponds to sum nodes marginalizing their inputs. spn.InferenceType.MPE would correspond to having max nodes instead.
  • learning_rate is the learning rate for the AMSGrad optimizer used below.
  • scale_init is the initial scale value for the NormalLeaf node. This parameter strongly affects the stability of the training process.
  • num_classes, batch_size and num_epochs should be self-explanatory :)
In [ ]:
# Number of variables
num_vars = train_x.shape[1]
# Number of mixture components in the normal leaves
num_leaf_components = 4
# Inference type (can also be spn.InferenceType.MPE, in which case
# sum nodes behave as max nodes)
inference_type = spn.InferenceType.MARGINAL
# AMSGrad optimizer parameters
learning_rate = 1e-2
# Scale init
scale_init = 0.1
# Other params
num_classes = 10
batch_size = 32
num_epochs = 50

Building the SPN

Our SPN consists of a leaf node with normal distributions, followed by spatial products and sums. A ConvProducts node generates all possible permutations of the child channels (when feasible). A ConvProductsDepthwise node uses the subset of permutations that corresponds to depthwise convolutions. Products are in fact implemented as convolutions, since multiplications become additions in log space. LocalSums consist of sums that are applied 'locally' without weight sharing, so they are in a sense comparable to LocallyConnected layers in Keras.

Note that after two non-overlapping products (with kernel sizes of $2\times 2$ and strides of $2\times 2$), we have a 'wicker' stack where we use 'full' padding and exponentially increasing dilation rates.

Finally, we apply a ConvProductsDepthwise layer with 'wicker_top' padding so that the scopes at the final layer include all variables. This layer can then be connected to class roots, which are in turn connected to a single root node.
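
To make the layer arithmetic concrete, here is a small sanity check (plain NumPy, using the $28\times 28$ MNIST input assumed above) of the stack size and dilation rates that the code below constructs.

In [ ]:
# Two non-overlapping 2x2 products shrink the 28x28 image to 7x7, so the
# wicker stack needs ceil(log2(7)) = 3 dilated product layers with dilation
# rates 1, 2 and 4, and the top ConvProductsDepthwise layer uses dilation 8.
spatial_size = 28 // (2 * 2)
stack_size = int(np.ceil(np.log2(spatial_size)))
print("Spatial size after the two stride-2 products:", spatial_size)
print("Wicker stack size:", stack_size)
print("Dilation rates:", [2 ** i for i in range(stack_size)], "top layer:", 2 ** stack_size)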

In [ ]:
tf.reset_default_graph()
# Leaf nodes

normal_leafs = spn.NormalLeaf(
    num_components=num_leaf_components, num_vars=num_vars, 
    trainable_scale=False, trainable_loc=True, scale_init=scale_init)

# Twice non-overlapping convolutions
x = spn.ConvProducts(normal_leafs, num_channels=32, padding='valid', kernel_size=2, strides=2, spatial_dim_sizes=[28, 28])
x = spn.LocalSums(x, num_channels=32)
x = spn.ConvProductsDepthwise(x, padding='valid', kernel_size=2, strides=2)
x = spn.LocalSums(x, num_channels=32)

# Make a wicker stack
stack_size = int(np.ceil(np.log2(28 // 4)))
for i in range(stack_size):
    dilation_rate = 2 ** i
    x = spn.ConvProductsDepthwise(
        x, padding='full', kernel_size=2, strides=1, dilation_rate=dilation_rate)
    x = spn.LocalSums(x, num_channels=64)
# Create final layer of products
full_scope_prod = spn.ConvProductsDepthwise(
    x, padding='wicker_top', kernel_size=2, strides=1, dilation_rate=2 ** stack_size)
class_roots = spn.ParallelSums(full_scope_prod, num_sums=num_classes)
root = spn.Sum(class_roots)

# Add an IndicatorLeaf node to the root as a latent class variable
class_indicators = root.generate_latent_indicators()

# Generate the weights for the SPN rooted at `root`
spn.generate_weights(root, log=True, initializer=tf.initializers.random_uniform())

print("SPN depth: {}".format(root.get_depth()))
print("Number of products layers: {}".format(root.get_num_nodes(node_type=spn.ConvProducts)))
print("Number of sums layers: {}".format(root.get_num_nodes(node_type=spn.LocalSums)))

Defining the TensorFlow graph

Now that we have defined the SPN graph, we can declare the TensorFlow operations needed for training and evaluation. The MPEState class can be used to find the MPE state of any node in the graph. In this case we are interested in finding the most likely class given the evidence elsewhere in the network (the image pixels). This corresponds to the MPE state of class_indicators.

Note that for the gradient optimizer we use AMSGrad, which usually yields reasonable results much faster than Adam. Admittedly, more time needs to be spent on how the interdependencies between parameters (e.g. scale_init) affect training.

In [ ]:
from libspn.examples.convspn.amsgrad import AMSGrad

# Op for initializing all weights
weight_init_op = spn.initialize_weights(root)
# Op for getting the log probability of the root
root_log_prob = root.get_log_value(inference_type=inference_type)

# Set up ops for discriminative GD learning
gd_learning = spn.GDLearning(
    root=root, learning_task_type=spn.LearningTaskType.SUPERVISED,
    learning_method=spn.LearningMethodType.DISCRIMINATIVE)
optimizer = AMSGrad(learning_rate=learning_rate)

# Use post_gradient_ops=True to also normalize weights (and clip Gaussian variances)
gd_update_op = gd_learning.learn(optimizer=optimizer, post_gradient_ops=True)

# Compute predictions and matches
mpe_state = spn.MPEState()
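# Build a second root that reuses the root's weights but has its own latent
# class indicators fed with -1 (i.e. no class evidence), so that MPE inference
# over it yields the most likely class for each sample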
root_marginalized = spn.Sum(root.values[0], weights=root.weights)
marginalized_ivs = root_marginalized.generate_latent_indicators(
    feed=-tf.ones_like(class_indicators.feed)) 
predictions, = mpe_state.get_state(root_marginalized, marginalized_ivs)
with tf.name_scope("MatchPredictionsAndTarget"):
    match_op = tf.equal(tf.to_int64(predictions), tf.to_int64(class_indicators.feed))

Training the SPN

In [ ]:
# Set up some convenient iterators
train_iterator = DataIterator([train_x, train_y], batch_size=batch_size)
test_iterator = DataIterator([test_x, test_y], batch_size=batch_size)

def fd(x, y):
    return {normal_leafs: x, class_indicators: y}

with tf.Session() as sess:
    # Initialize things
    sess.run([tf.global_variables_initializer(), weight_init_op])
    
    # Do one pass over the test set to measure accuracy before training
    matches = []
    for batch_x, batch_y in test_iterator.iter_epoch("Testing"):
        batch_matches = sess.run(match_op, fd(batch_x, batch_y))
        matches.extend(batch_matches.ravel())
        test_iterator.display_progress(Accuracy="{:.2f}".format(np.mean(batch_matches)))
    mean_test_accuracy = np.mean(matches)
    
    print("Before training test accuracy = {:.2f}".format(mean_test_accuracy))                              
    for epoch in range(num_epochs):
        
        # Train
        matches = []
        for batch_x, batch_y in train_iterator.iter_epoch("Training"):
            batch_matches, _ = sess.run(
                [match_op, gd_update_op], fd(batch_x, batch_y))
            matches.extend(batch_matches.ravel())
            train_iterator.display_progress(Accuracy="{:.2f}".format(np.mean(batch_matches)))
        mean_train_accuracy = np.mean(matches)
        
        # Test
        matches = []
        for batch_x, batch_y in test_iterator.iter_epoch("Testing"):
            batch_matches = sess.run(match_op, fd(batch_x, batch_y))
            matches.extend(batch_matches.ravel())
            test_iterator.display_progress(Accuracy="{:.2f}".format(np.mean(batch_matches)))
        mean_test_accuracy = np.mean(matches)
        
        # Report
        print("Epoch {}, train accuracy = {:.2f}, test accuracy = {:.2f}".format(
            epoch, mean_train_accuracy, mean_test_accuracy))
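
Note that the root_log_prob op defined earlier is not used in the loop above. If you also want to monitor the model's likelihood, below is a minimal sketch of an extra evaluation pass; it assumes it is run inside the with tf.Session() block above (e.g. after the epoch loop), and it reports the joint log-likelihood of images and labels, since class_indicators is fed with the true labels.

In [ ]:
# Sketch: mean test log-likelihood (run inside the session block above)
log_likelihoods = []
for batch_x, batch_y in test_iterator.iter_epoch("Likelihood"):
    batch_llh = sess.run(root_log_prob, fd(batch_x, batch_y))
    log_likelihoods.extend(batch_llh.ravel())
print("Mean test log-likelihood = {:.2f}".format(np.mean(log_likelihoods)))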