Writing Numbers with a GAN

This notebook is an example of using a Conditional Generative Adversarial Network (GAN) to generate ‘handwritten’ digits.

One of the first things to try after training a neural network to classify the handwritten digits from the MNIST dataset is to ‘drive it backwards’ and use it to generate images of numbers. The idea is to start with an image of random noise then use the optimizer to tweak pixels so the classifier strongly predicts it as being the number you want to generate.

Implementing this in TensorFlow is straightforward: just fix the weights and biases, and make the image a trainable variable rather than a placeholder.

Unfortunately it doesn’t work. The classifier divides 784-dimensional space (one dimension for each of the 28x28 pixels) into 10 categories, one for each digit. This space is large, but the classifier has only seen training examples for the parts of it that look like a number. For the parts where it hasn’t seen training data, the output is effectively an arbitrary choice among the 10 digits. When we ask the optimizer to find a ‘4’ it will find an image that the classifier strongly believes is a ‘4’, but it will turn out to be a point somewhere in the backwaters of this 784-dimensional space that doesn’t look like a ‘4’ to you or me.
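
As a toy illustration of ‘driving a classifier backwards’, here is a minimal sketch in plain Python rather than TensorFlow. The two-input logistic classifier and its weights `W` and `B` are made up for the example and are not part of this notebook; the point is only that the weights stay fixed while the input is the thing being optimized.

```python
import math

# Hypothetical, already "trained" logistic classifier on 2 inputs (assumed values).
W = [2.0, -3.0]
B = 0.5

def predict(x):
    """Probability that x belongs to the target class."""
    z = sum(w * xi for w, xi in zip(W, x)) + B
    return 1.0 / (1.0 + math.exp(-z))

def drive_backwards(x, steps=200, lr=0.1):
    """Gradient ascent on log p(target | x) with respect to the input x."""
    x = list(x)
    for _ in range(steps):
        p = predict(x)
        # d/dx log(sigmoid(w.x + b)) = (1 - p) * w; the weights never change.
        x = [xi + lr * (1.0 - p) * w for xi, w in zip(x, W)]
    return x

start = [0.0, 0.0]
result = drive_backwards(start)
confidence_before = predict(start)
confidence_after = predict(result)
```

The optimizer simply walks along the direction of `W` until the classifier is very confident. Nothing in the procedure pulls the result towards inputs that resemble real training data, which is the failure described above.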

One alternative is a ‘Generative Adversarial Network’ or GAN. Instead of driving a classifier backwards, this sets up a system of two separate networks called a ‘Generator’ (G) and a ‘Discriminator’ (D). This notebook calls the Discriminator the ‘Detector’.

The Generator is given a random seed and told to ‘generate a number 4’. It does this without ever seeing what a human-written ‘4’ looks like.

The Detector is given either a real image or the output of the Generator and must determine whether the image was produced by a human or by the computer.

The two are trained alternately, with the Generator learning to fool the Detector and the Detector learning to distinguish real from fake numbers. If there is any systematic difference between a human and the Generator, then the Detector will learn it. This might be anything from ‘the image is white noise’ to ‘the writing is too neat’. The Generator will then have to learn to generate things that more closely match (in a probability distribution sense) what people draw.

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import math
import re
import os
import os.path
%matplotlib inline
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

sess = tf.InteractiveSession()
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

The Detector Network

This network takes an image and a one-hot encoding of the digit it is meant to be, and returns the probability that the image is fake (p_fake in the code). It is directly derived from the Google TensorFlow ‘Deep MNIST’ example, and consists of two convolution layers followed by a fully connected (FC) layer, then a final ‘decision’ layer that outputs a single fake/real probability.

The information about which digit was drawn is injected in the first FC layer by concatenating it with the 7x7x64 values from the 2nd convolution layer.
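
The label injection can be sketched in plain Python. This is only an illustration of the shapes involved: `one_hot` is a helper invented for the example, and the list of zeros stands in for the real activations of the 2nd convolution layer.

```python
def one_hot(digit, num_classes=10):
    """Return a list with 1.0 at position `digit` and 0.0 elsewhere."""
    return [1.0 if i == digit else 0.0 for i in range(num_classes)]

# Stand-in for the flattened output of the 2nd convolution layer (7x7x64 values).
h_pool2_flat = [0.0] * (7 * 7 * 64)

# Concatenating the label gives the input row for the first FC layer.
fc1_input = h_pool2_flat + one_hot(4)
fc1_input_width = len(fc1_input)  # 7*7*64 + 10
```

This is why the first FC weight matrix in the code below has 7 * 7 * 64 + 10 rows.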

def weight_variable(shape, name):
    initial = tf.truncated_normal(shape, stddev=2.0/np.sqrt(np.prod(shape))) # scaled init, loosely after He et al.
    return tf.Variable(initial, name=name)
def bias_variable(shape, name):
    initial = tf.constant(0.0, shape=shape)
    return tf.Variable(initial, name=name)
def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')
def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1,2,2,1],
                          strides=[1,2,2,1], padding='SAME')

class Detect(object):
    def __init__(self):

        # First conv. layer
        self.W_conv1 = weight_variable([5,5,1,32], name='W_conv1')
        self.b_conv1 = bias_variable([32], name='b_conv1')

        # 2nd conv. layer
        self.W_conv2 = weight_variable([5,5,32,64], name='W_conv2')
        self.b_conv2 = bias_variable([64], name='b_conv2')

        # Wide FC layer (y is injected here)
        self.W_fc1 = weight_variable([7 * 7 * 64 + 10, 1024], name='W_fc1')
        self.b_fc1 = bias_variable([1024], name='b_fc1')

        # Final decision layer
        self.W_fc2 = weight_variable([1024, 1], name='W_fc2')
        self.b_fc2 = bias_variable([1], name='b_fc2')

        self.all_vars = [
            self.W_conv1, self.b_conv1,
            self.W_conv2, self.b_conv2,
            self.W_fc1, self.b_fc1,
            self.W_fc2, self.b_fc2
        ]
    def calc(self, x, y, keep_prob):
        # First conv. layer
        x_image = tf.reshape(x, [-1, 28, 28,1])
        h_conv1 = tf.nn.relu(conv2d(x_image, self.W_conv1) + self.b_conv1)
        h_pool1 = max_pool_2x2(h_conv1)
        # 2nd conv. layer
        h_conv2 = tf.nn.relu(conv2d(h_pool1, self.W_conv2) + self.b_conv2)
        h_pool2 = max_pool_2x2(h_conv2)

        # Wide FC layer (y is injected here)
        h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
        h_pool2_flat_plus_y = tf.concat(1, (h_pool2_flat, y))
        h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat_plus_y, self.W_fc1) + self.b_fc1)

        h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
        p_fake = tf.sigmoid(tf.matmul(h_fc1_drop, self.W_fc2) + self.b_fc2)
        return tf.clip_by_value(p_fake, 1e-9, 1-1e-9)

The Generator Network

My initial implementation of this was a flipped copy of the TensorFlow Deep MNIST/Detector network. However, I wasn’t able to train this successfully, and instead used a shallower network: a single FC layer that takes the random seed and the one-hot encoded number to produce, followed by 2 inverted (transposed) convolution layers.
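
To see how the shapes line up in this shallower network, here is a small sanity check in plain Python. It assumes the transposed convolutions use 'SAME' padding (which I believe is the default for tf.nn.conv2d_transpose); the two helper functions are invented for this sketch.

```python
import math

def same_conv_output_size(input_size, stride):
    """Spatial output size of a forward convolution with 'SAME' padding."""
    return int(math.ceil(input_size / float(stride)))

def transpose_shape_is_valid(input_size, output_size, stride):
    """A transposed convolution can map input_size -> output_size only if the
    matching forward convolution would map output_size -> input_size."""
    return same_conv_output_size(output_size, stride) == input_size

fc_inputs = 10 + 2        # one-hot label plus 2 noise values
fc_outputs = 7 * 7 * 64   # reshaped to a 7x7 image with 64 channels
first_ok = transpose_shape_is_valid(7, 28, stride=4)    # 7x7x64   -> 28x28x32
second_ok = transpose_shape_is_valid(28, 28, stride=1)  # 28x28x32 -> 28x28x1
```

The stride of 4 in the first transposed convolution is what takes the 7x7 feature map straight to the 28x28 MNIST image size.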

class GenerateSmaller(object):
    """A simpler network that I was able to train successfully"""
    def __init__(self, y, noise, keep_prob, batch_size):
        W_fc1 = weight_variable([10 + 2, 7*7*64], name='G_W_fc1')
        B_fc1 = bias_variable([7*7*64], name='G_B_fc1')
        h_fc1 = tf.nn.relu(tf.matmul(tf.concat(1, (y,noise)), W_fc1) + B_fc1)

        h_img = tf.reshape(h_fc1, [-1, 7, 7, 64])

        W_conv = weight_variable([5, 5, 32, 64], name="G_W_conv")
        B_conv = bias_variable([32], name="G_B_conv")
        h2_img = tf.nn.relu(
            tf.nn.conv2d_transpose(
                value=h_img,
                filter=W_conv,
                output_shape=[batch_size, 28, 28, 32],
                strides=[1, 4, 4, 1])
            + B_conv)

        W_conv2 = weight_variable([3, 3, 1, 32], name="G_W_conv2")
        B_conv2 = bias_variable([1], name="G_B_conv2")

        x_img = tf.nn.conv2d_transpose(
                value=h2_img,
                filter=W_conv2,
                output_shape=[batch_size, 28, 28, 1],
                strides=[1, 1, 1, 1]) + B_conv2

        self.x = tf.clip_by_value(tf.sigmoid(tf.reshape(x_img, [-1, 28*28,1])) * 1.2 - 0.1, 0.0, 1.0)
        self.all_vars = [
            W_fc1, B_fc1,
            W_conv, B_conv,
            W_conv2, B_conv2
        ]

Training the networks

For training, there are two copies of the detector network. One of these is always fed real images, and the other always gets its input from the Generator. The copies can’t ‘cheat’ by learning which feed they are attached to, because they share the same variables :)
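
The two quantities being optimized in the training cell can be written out in plain Python for a single pair of detector outputs. This is a sketch of the arithmetic only, not the TensorFlow graph; the functions reuse the notebook's names but take ordinary floats, where each argument is the detector's ‘probability of fake’.

```python
import math

def detect_surprise(p_fake_on_real, p_fake_on_generated):
    """Detector loss: it wants p_fake near 0 on real images
    and near 1 on generated images."""
    return -math.log(1.0 - p_fake_on_real) - math.log(p_fake_on_generated)

def detect_fake_surprise(p_fake_on_generated):
    """How surprised the detector is that a generated image is fake.
    The generator is trained to make this large."""
    return -math.log(p_fake_on_generated)

undecided = detect_surprise(0.5, 0.5)     # detector is guessing: 2 * ln(2)
confident = detect_surprise(0.01, 0.99)   # detector is winning: close to 0
fooled = detect_fake_surprise(0.01)       # generator is winning: large
```

The detector minimizes the first quantity, while the generator minimizes the negative of the second, so the two networks pull the detector's output on generated images in opposite directions.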

The training process takes a few days on my PC, and so checkpointing is used to recover from crashes.
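
Resuming from a checkpoint relies on recovering the step count from the checkpoint file name. A minimal sketch of that parsing, using the same regular expression as the cell below (the function name is invented for this example):

```python
import re

def step_from_checkpoint(path):
    """Extract the global step that follows the last '-' in a checkpoint path.
    Returns None if the path does not contain '-<digits>'."""
    m = re.compile('.+-([0-9]+)').match(path)
    return int(m.group(1)) if m else None

step = step_from_checkpoint('digits-with-gan-checkpoints/ckpt-303000')
missing = step_from_checkpoint('ckpt')
```

Because `.+` is greedy, the match backtracks to the last hyphen that is followed by digits, so the hyphens in the directory name do not confuse it.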

checkpoint_dir = 'digits-with-gan-checkpoints'
if not os.path.exists(checkpoint_dir):
    os.mkdir(checkpoint_dir)

x = tf.placeholder(tf.float32, [None, 784], name='x_real')
y = tf.placeholder(tf.float32, [None, 10], name='y')

keep_prob = tf.placeholder(tf.float32, name='keep_prob')

detect = Detect()
detect_on_real = detect.calc(x, y, keep_prob)

noise = tf.placeholder(tf.float32, [None, 2], name='noise')
batch_size = 10
generate = GenerateSmaller(y, noise, keep_prob, batch_size=batch_size)
detect_on_generate = detect.calc(generate.x, y, keep_prob)

detect_surprise = tf.reduce_mean(
    - tf.log(1-detect_on_real) # On real data it should predict 'not fake' i.e. close to zero
    - tf.log(detect_on_generate), # On fake data it should predict 'fake', i.e close to one
    reduction_indices=[1])

detect_fake_surprise = tf.reduce_mean(-tf.log(detect_on_generate), reduction_indices=[1])

detect_optimiser = tf.train.AdamOptimizer(1e-4).minimize(detect_surprise, var_list=detect.all_vars)

generate_optimiser = tf.train.AdamOptimizer(1e-4).minimize(-detect_fake_surprise, var_list=generate.all_vars)

generate_error_history = []
detect_error_history = []
saver = tf.train.Saver()
last_chpkt = tf.train.latest_checkpoint(checkpoint_dir)
if last_chpkt:
    m = re.compile('.+-([0-9]+)').match(last_chpkt)
    global_step = int(m.group(1))
    saver.restore(sess, last_chpkt)
    print "Restoring from %s global_step:%d" % (last_chpkt, global_step)
else:
    sess.run(tf.initialize_all_variables())
    global_step = 0
Restoring from digits-with-gan-checkpoints/ckpt-303000 global_step:303000

One of the problems I had training this model was that the training would get stuck in a local minimum, with either the detector or the generator ‘winning’ every time. In this case the losing side never manages to improve itself.

To avoid that, when the error is a long way from balanced (in the code below, when the average of the two measured errors falls below 0.3 or rises above 0.8), I stop training the ‘winner’ to let the other side catch up. This was a ‘made up’ solution: I haven’t seen it in the literature.
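
The rule for pausing the ‘winner’ can be isolated as a small function. This is a sketch that mirrors the thresholds used in the training loop below; the function name and the `low`/`high` parameter names are invented for the example.

```python
def choose_who_trains(error1, error2, low=0.3, high=0.8):
    """Return (optimise_generator, optimise_detector) from the two errors."""
    overall_error = (error1 + error2) / 2.0
    # High error: the detector is being fooled, so pause the generator.
    optimise_generator = overall_error < high
    # Low error: the detector is winning easily, so pause the detector.
    optimise_detector = low < overall_error
    return optimise_generator, optimise_detector

# Error pairs taken from the training log shown below.
at_304000 = choose_who_trains(0.429973, 0.103284)   # average is about 0.27
at_1000000 = choose_who_trains(0.732500, 0.212447)  # average is about 0.47
```

With the logged values, the detector is paused after step 304000 while both networks keep training after step 1000000.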

optimise_detector = True
optimise_generator = True
while global_step <= 1000*1000:
    batch = mnist.train.next_batch(batch_size)

    if optimise_generator:
        generate_optimiser.run(feed_dict={
            noise:np.random.randn(batch_size, 2),
            y: batch[1],
            keep_prob:1.0,
        })
    if optimise_detector:
        detect_optimiser.run(feed_dict={
            keep_prob: .5,
            x: batch[0],
            y: batch[1],
            noise:np.random.randn(batch_size, 2)
        })
    global_step += 1
    if global_step % 100 == 0:
        error1 = np.mean(sess.run(detect_surprise, feed_dict={
            keep_prob: 1.0,
            noise:np.random.randn(batch_size, 2),
            x: batch[0],
            y: batch[1],
        }))
        error2 = np.mean(sess.run(detect_fake_surprise, feed_dict={
            keep_prob: 1.0,
            noise:np.random.randn(batch_size, 2),
            x: batch[0],
            y: batch[1]
        }))
        for v in detect.all_vars + generate.all_vars:
            a = sess.run(tf.reduce_max(tf.abs(v)))
            if 1e9 < a:
                print "%s => %f" % (v.name, a)
        if math.isnan(error1) or math.isnan(error2):
            print "Aborting on NAN Detected"
            break
        detect_error_history.append(error1)
        generate_error_history.append(error2)
        overall_error = (error1 + error2)/2
        optimise_generator = overall_error < 0.8
        optimise_detector = 0.3 < overall_error

        if global_step % 1000 == 0:
            print "%d Error is %f %f" % (global_step, error1, error2)
            save_path = saver.save(sess, os.path.join(checkpoint_dir, 'ckpt'), global_step=global_step)
            print 'Model saved in file: ', save_path
304000 Error is 0.429973 0.103284
Model saved in file:  digits-with-gan-checkpoints/ckpt-304000
...
999000 Error is 0.359853 0.121428
Model saved in file:  digits-with-gan-checkpoints/ckpt-999000
1000000 Error is 0.732500 0.212447
Model saved in file:  digits-with-gan-checkpoints/ckpt-1000000

Finally, let’s use the generator to create some numbers:

plt.figure().set_size_inches(8, 8)
for n in xrange(10):
    target = [0] * 10
    target[n] = 1.0
    aa = sess.run(generate.x, feed_dict= {
        y: [target]*batch_size,
        noise: np.random.randn(batch_size, 2),
        keep_prob: 1.0
    })
    for i in range(10):
        plt.subplot(10,10,i*10 + n + 1)
        img = np.reshape(aa[i], (28,28))
        a = plt.imshow(img)
        a.set_cmap('gray')
        plt.axis('off')
# plt.savefig("numbers.png", dpi = 72)
_images/digits-with-gan_12_0.png

Also, we can interpolate between two different numbers. For example, what number is halfway between 7 and 9?
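
The interpolation works by feeding the generator a label vector that is no longer strictly one-hot. A minimal sketch of building such a vector (the helper name `blend_targets` is invented for this example):

```python
def blend_targets(a, b, t, num_classes=10):
    """Label vector that moves from digit `a` (t=0) to digit `b` (t=1)."""
    target = [0.0] * num_classes
    target[a] += 1.0 - t
    target[b] += t
    return target

halfway = blend_targets(7, 9, 0.5)  # equal parts '7' and '9'
```

The generator was only ever trained on exact one-hot labels, so what it draws for these in-between vectors is not guaranteed to be meaningful; the plot is an experiment rather than a designed feature.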

noise_data = np.random.randn(batch_size, 2)/2
plt.figure().set_size_inches(8, 8)
for n in xrange(10):
    target = [0] * 10
    target[7] = 1.0 - n/9.
    target[9] = n/9.
    aa = sess.run(generate.x, feed_dict= {
        y: [target]*batch_size,
        noise: noise_data,
        keep_prob: 1.0
    })
    for i in range(10):
        plt.subplot(10,10,i*10 + n + 1)
        img = np.reshape(aa[i], (28,28))
        a = plt.imshow(img)
        a.set_cmap('gray')
        plt.axis('off')
_images/digits-with-gan_14_0.png

Since we can interpolate between numbers, let’s use that to give a new take on a digital clock where the numbers cross-fade in the latent ‘numberness’ space that we’ve just trained:
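
The cross-fade uses a cosine easing curve so that each digit lingers before blending into the next. Here is a sketch of the per-frame arithmetic in plain Python; the function names and the `frames_per_digit` parameter are invented for this example, with 60 frames per digit as in the cell below.

```python
import math

def ease_in_out(t):
    """Cosine easing: 0 at t=0, 0.5 at t=0.5, 1 at t=1, flat at both ends."""
    return 0.5 - math.cos(math.pi * t) / 2.0

def clock_frame(n, frames_per_digit=60):
    """Return (current digit, next digit, blend weight) for frame n."""
    digit = (n // frames_per_digit) % 10
    next_digit = (digit + 1) % 10
    t = (n % frames_per_digit) / float(frames_per_digit)
    return digit, next_digit, ease_in_out(t)

start = clock_frame(0)     # all '0', none of '1'
middle = clock_frame(30)   # halfway between '0' and '1'
wrap = clock_frame(599)    # last frame: almost entirely '0' again, coming from '9'
```

The blend weight goes to the next digit and one minus the weight to the current digit, which gives the label vector for that frame.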

noise_data = np.zeros((batch_size,2))
imgn = 0
ys = []
for n in xrange(60*10):
    target = [0] * 10
    ease = 0.5 - np.cos(np.pi * (n%60)/60.)/2
    target[n//60] = 1.0 - ease
    target[(n//60+1)%10] = ease
    ys.append(target)
    if (len(ys) == batch_size):
        aa = sess.run(generate.x, feed_dict= {
            y: ys,
            noise: noise_data,
            keep_prob: 1.0
        })
        for i in range(batch_size):
            img = np.reshape(aa[i], (28,28))
            plt.imsave('nn-%03d.png' % imgn, img, vmin=0.0, vmax=1.0, cmap='gray')
            imgn += 1
        ys = []
# mplayer mf://nn*.png -fs -zoom -loop 0 -vo gl -fps 60
import matplotlib.image

# Usage reminder: matplotlib.image.imsave('name.png', array) saves a 2-D array as an image.
# Reload an earlier checkpoint and print the largest absolute value of each variable.
saver.restore(sess, "digits-with-gan-checkpoints/ckpt-529000")
for v in detect.all_vars + generate.all_vars:
    a = sess.run(tf.reduce_max(tf.abs(v)))
    print "%s => %f" % (v.name,a)
W_conv1:0 => 0.149732
b_conv1:0 => 0.027031
W_conv2:0 => 0.063758
b_conv2:0 => 0.021519
W_fc1:0 => 0.130650
b_fc1:0 => 0.024133
W_fc2:0 => 0.136678
b_fc2:0 => 0.015503
G_W_fc1:0 => 0.279702
G_B_fc1:0 => 0.397946
G_W_conv:0 => 0.456417
G_B_conv:0 => 0.249894
G_W_conv2:0 => 0.295345
G_B_conv2:0 => 0.054249
plt.figure().set_size_inches(8, 8)
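# Scale the '4' input from 0 up to 2.25, including values beyond the 1.0 used in training.
noise_data = np.random.randn(batch_size, 2)  # must match the noise placeholder shape [None, 2]
for n in xrange(10):
    target = [0] * 10
    target[4] = n/4.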
    aa = sess.run(generate.x, feed_dict= {
        y: [target]*batch_size,
        noise: noise_data,
        keep_prob: 1.0
    })
    for i in range(10):
        plt.subplot(10,10,i*10 + n + 1)
        img = np.reshape(aa[i], (28,28))
        a = plt.imshow(img)
        a.set_cmap('gray')
        plt.axis('off')