Tuesday, 18 June 2019

Generative Adversarial Networks - Part V

This is Part 5 of a short series of posts introducing and building generative adversarial networks, known as GANs.

In this post we'll learn about a different architecture called a conditional GAN which enables us to direct the GAN to produce images of a class that we want, rather than images of a random class.


Previously:
  • Part 1 introduced the idea of adversarial learning, and we started to build the machinery of a GAN implementation.
  • Part 2 extended our code to learn a simple 1-dimensional pattern, 1010.
  • Part 3 developed our code to generate 2-dimensional greyscale images that look like handwritten digits.
  • Part 4 extended our code to learn full-colour faces, and developed convolutional networks to encourage learning localised image features.


Controlling What A GAN Creates

In parts 3 and 4 of this series, we trained a GAN on data that contained unique and diverse images. Each handwritten digit in the MNIST dataset is different, and each face in the CelebA data set is unique.

When we use the trained generator to create an image, we have no control over what kind of image it creates. All we know is that the image will be plausible enough to get past the discriminator.


We can't ask the generator from part 3 to create a 7 or a 9 for us. All we can do is feed the generator a random vector of numbers as input and see what image pops out.

If we experiment with that random input in an attempt to control what comes out of the generator, we find that it gives us little, if any, control over the output.

Is it possible to train the generator so that we can influence the output?

The answer is yes, and that is what a conditional GAN architecture aims to do.


Conditional GAN Architecture

The following picture shows the conditional GAN architecture.


You can see that both the generator and discriminator are provided with additional information about the image. For us, this additional information can be a label, such as the digit an MNIST image represents.

It is not immediately clear how this helps. Let's break it down:
  • The discriminator can use the label to improve how it identifies whether an image is real or fake. How it uses the label is up to the discriminator, but additional information can only help: without the label, the discriminator has a fixed amount of information on which to base its decision, and with the label it has more.
  • The generator can learn to associate the label with the image it generates. It doesn't have to, as it could choose to ignore the additional information. But the generator learns from the discriminator's feedback, and the discriminator has learned to associate labels with images from the training set. So the generator is encouraged to make the same association, generating images that match the image-label pairs the discriminator sees in the training data.


Training A Conditional GAN

The training loop is unchanged from a vanilla GAN. The only difference is the additional information appended to the inputs to the generator and the discriminator:

  • The discriminator is shown a real image from the training dataset, as well as that image's label. It is trained to output a 1 for real.
  • The discriminator is shown a fake image from the generator together with its label, and is trained to output a 0 for fake.
  • The generator is trained to cause the discriminator to output a 1 for real.

The labels associated with the real images are part of the training data.

The labels associated with the generator are randomly chosen one-hot vectors of the same length as the labels in the training dataset. We just need to make sure that this randomly chosen label remains the same when fed into the generator as part of the seed, and when associated with the generated image for the discriminator to test. We can't have a different label for these two parts of the training.
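The training loop later in this post calls a helper, generate_random_target(), whose definition isn't shown here. The following is a minimal sketch of what such a helper might look like. It assumes the helper returns a single-row float tensor with one randomly placed 1. The body is our guess, and unlike the post's code it runs on the CPU.

```python
import torch

def generate_random_target(size):
    # hypothetical sketch: a (1, size) float tensor with one 1 at a random index;
    # the (1, size) shape suits the torch.cat(..., 1) calls in forward()
    target = torch.zeros(1, size)
    random_idx = torch.randint(0, size, (1,)).item()
    target[0, random_idx] = 1.0
    return target

# draw the label once, then pass this same tensor to both G and D
random_target_tensor = generate_random_target(10)
```

The key point is the last line: the label is drawn once and the same tensor object is reused, so the generator and the discriminator cannot see different labels for the same fake image.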

When feeding the generator, the one-hot label vectors are combined with the random seed by concatenating the tensors like this:


def forward(self, noise_tensor, label_tensor):
    # combine random seed and label by concatenating along dimension 1
    inputs = torch.cat((noise_tensor, label_tensor), 1)
    
    # simply run model
    return self.model(inputs)
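
The forward() method above belongs to a generator class whose layers aren't shown in this post. Below is a minimal, self-contained sketch of such a class. The class name, layer sizes and activations are illustrative assumptions and may not match the full code. What it shows is the key step: a 100-element seed and a 10-element one-hot label are joined into a single 110-element input.

```python
import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    # illustrative sizes: 100-element seed + 10-element one-hot label -> 784 pixels
    def __init__(self, seed_size=100, num_classes=10, image_size=784):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(seed_size + num_classes, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, image_size),
            nn.Sigmoid(),  # pixel values in [0, 1]
        )

    def forward(self, noise_tensor, label_tensor):
        # combine seed and label along dim 1: (1, 100) + (1, 10) -> (1, 110)
        inputs = torch.cat((noise_tensor, label_tensor), 1)
        return self.model(inputs)

G = ConditionalGenerator()
seed = torch.randn(1, 100)
label = torch.zeros(1, 10)
label[0, 7] = 1.0  # ask for a 7
image = G(seed, label)
print(image.shape)  # torch.Size([1, 784])
```

An untrained generator like this one produces noise. The class only matters after training, once the label part of the input has come to steer the output.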


Similarly, when feeding the discriminator the one-hot label vectors are combined with the image data like this:


def forward(self, image_tensor, label_tensor):
    # combine image and label
    inputs = torch.cat((image_tensor.view(1, 784), label_tensor), 1)
    
    # simply run model
    return self.model(inputs)
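
As with the generator, the discriminator's layers aren't shown here. The following sketch is a self-contained discriminator built around the same forward() logic. The class name and layer sizes are assumptions for illustration. Note how view(1, 784) flattens a (1, 1, 28, 28) image, the shape the training loop passes in, before the 10-element label is appended.

```python
import torch
import torch.nn as nn

class ConditionalDiscriminator(nn.Module):
    # illustrative sizes: 784 pixels + 10-element one-hot label -> 1 score
    def __init__(self, image_size=784, num_classes=10):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(image_size + num_classes, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, 1),
            nn.Sigmoid(),  # probability that the (image, label) pair is real
        )

    def forward(self, image_tensor, label_tensor):
        # flatten a (1, 1, 28, 28) image to (1, 784), then append the label
        inputs = torch.cat((image_tensor.view(1, 784), label_tensor), 1)
        return self.model(inputs)

D = ConditionalDiscriminator()
image = torch.rand(1, 1, 28, 28)
label = torch.zeros(1, 10)
label[0, 3] = 1.0
score = D(image, label)
print(score.shape)  # torch.Size([1, 1])
```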


The following shows code for the training loop:


# train Discriminator and Generator

epochs = 12

for i in range(epochs):
    print('training epoch', i+1, "of", epochs)
    
    for label, image_data_tensor, target_tensor in mnist_dataset:

        # train discriminator on real data, paired with its real label
        D.train(image_data_tensor.view(1, 1, 28, 28), target_tensor, torch.cuda.FloatTensor([1.0]).view(1,1))
        
        # random 1-hot label, shared by the generator and the discriminator
        random_target_tensor = generate_random_target(10)

        # train discriminator on fake data
        # use detach() so only D is updated, not G
        D.train(G.forward(generate_random(100).view(1,100), random_target_tensor).detach(), random_target_tensor, torch.cuda.FloatTensor([0.0]).view(1,1))
        
        # fresh random 1-hot label for the generator step
        random_target_tensor = generate_random_target(10)
        
        # train generator to make D output 1 (real)
        G.train(D, generate_random(100).view(1,100), random_target_tensor, torch.cuda.FloatTensor([1.0]).view(1,1))
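
The D.train() and G.train() methods aren't shown in this post. Below is a self-contained sketch of one plausible way to write a single iteration of the loop above in plain PyTorch. Several parts are our assumptions: the tiny networks, the Adam optimiser, BCELoss, and the helper names random_one_hot, discriminate and generate. The sketch also runs on the CPU rather than with torch.cuda tensors. It is meant to illustrate the structure described earlier:

  • three updates happen for each training example
  • each random label is shared between G and D
  • detach() keeps the fake-data step from updating G

```python
import torch
import torch.nn as nn

# tiny stand-in networks with illustrative sizes, not the post's full models
G = nn.Sequential(nn.Linear(100 + 10, 784), nn.Sigmoid())
D = nn.Sequential(nn.Linear(784 + 10, 1), nn.Sigmoid())
g_opt = torch.optim.Adam(G.parameters(), lr=1e-4)
d_opt = torch.optim.Adam(D.parameters(), lr=1e-4)
bce = nn.BCELoss()

def random_one_hot(size=10):
    # hypothetical helper: (1, size) tensor with a single 1 at a random index
    t = torch.zeros(1, size)
    t[0, torch.randint(0, size, (1,)).item()] = 1.0
    return t

def discriminate(image, label):
    # D always sees an image together with a label
    return D(torch.cat((image.view(1, 784), label), 1))

def generate(seed, label):
    return G(torch.cat((seed, label), 1))

def train_step(real_image, real_label):
    real, fake = torch.ones(1, 1), torch.zeros(1, 1)

    # 1. D on a real (image, label) pair, target 1
    d_opt.zero_grad()
    loss_real = bce(discriminate(real_image, real_label), real)
    loss_real.backward()
    d_opt.step()

    # 2. D on a fake pair, target 0; the SAME label goes to G and to D,
    #    and detach() stops gradients flowing back into G
    label = random_one_hot()
    fake_image = generate(torch.randn(1, 100), label).detach()
    d_opt.zero_grad()
    loss_fake = bce(discriminate(fake_image, label), fake)
    loss_fake.backward()
    d_opt.step()

    # 3. G with a fresh label, again shared by G and D, target 1;
    #    only G's optimiser steps, so D's weights are not changed here
    label = random_one_hot()
    g_opt.zero_grad()
    loss_g = bce(discriminate(generate(torch.randn(1, 100), label), label), real)
    loss_g.backward()
    g_opt.step()

    return loss_real.item(), loss_fake.item(), loss_g.item()

real_label = torch.zeros(1, 10)
real_label[0, 5] = 1.0
losses = train_step(torch.rand(1, 1, 28, 28), real_label)
print(losses)
```

Step 3 still computes gradients for D's parameters as a side effect. They are cleared by d_opt.zero_grad() before D is next updated, so they never affect D.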


The full code is online:



Results

The results of training should be a generator that can create images of a desired class by providing it with the label as well as the normal random seed. So feeding the generator a label of 1 should result in images that look like a hand-drawn 1.

The following shows the results of 12 epochs of training.


The zeros at the top left are produced by feeding the trained generator a random seed augmented with a one-hot vector corresponding to the label 0, which would be 1000000000.
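To make the label encoding concrete, the following sketch uses PyTorch's torch.nn.functional.one_hot to build the one-hot label for every digit class. It then pairs each label with its own random seed. Batching all ten rows into a single (10, 110) tensor is our choice for illustration; the training loop above feeds one seed at a time.

```python
import torch
import torch.nn.functional as F

# one one-hot label per digit class, as a (10, 10) float tensor
labels = F.one_hot(torch.arange(10), num_classes=10).float()

# the label for digit 0, written as a string of bits
print("".join(str(int(v)) for v in labels[0]))  # 1000000000

# pair each label with its own random seed: a (10, 110) batch of generator inputs
seeds = torch.randn(10, 100)
generator_inputs = torch.cat((seeds, labels), 1)
print(generator_inputs.shape)  # torch.Size([10, 110])
```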

We can see that for each input label, the generator does indeed produce images of that label.

The following shows the results for 24 epochs of training:


The quality of the digits has improved.

As an experiment to see how important the labels are to training, the following results come from the same code, but with the one-hot label vector fed to the discriminator set to 0000000000.


We can see two things:

  • the generator no longer creates images of the desired class
  • the image quality overall is lower than without the label

This shows that it is important for the discriminator to learn the association between an image and its class, so that it can feed the generator useful gradients to learn from. The lower quality is likely because we have, in effect, an enlarged image to learn. That may require longer training, or a more efficient network design and a better choice of hyperparameters.
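Below is a minimal sketch of how that ablation might be wired in. It assumes the label passed to the discriminator is replaced with zeros and everything else stays the same. The tensors here are random placeholders rather than MNIST data.

```python
import torch

image = torch.rand(1, 1, 28, 28)
true_label = torch.zeros(1, 10)
true_label[0, 4] = 1.0

# ablation: the discriminator receives an all-zero label, 0000000000,
# so the 10 extra inputs carry no information about the class
blank_label = torch.zeros_like(true_label)
d_inputs = torch.cat((image.view(1, 784), blank_label), 1)
print(d_inputs.shape)  # torch.Size([1, 794])
```

The discriminator still receives 794 inputs, 10 of which are always zero. This is the "enlarged image" that carries no extra information.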


Experimenting With Input Labels

What happens if we give the trained generator input labels that are not one-hot, but have several elements activated?

We can use the plot_images() generator method to activate more than one location by supplying a tuple of labels. The following sets the input vector to be [0, 0, 0, 0, 0, 0, 1, 0, 0, 1].


G.plot_images((6,9))
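
The source of plot_images() isn't shown here. The following sketch assumes it builds the input label by setting a 1 at each index in the tuple. multi_hot is a hypothetical helper name.

```python
import torch

def multi_hot(labels, size=10):
    # hypothetical helper: set a 1 at every index in the tuple of labels
    vector = torch.zeros(1, size)
    for label in labels:
        vector[0, label] = 1.0
    return vector

print(multi_hot((6, 9)).int().tolist())  # [[0, 0, 0, 0, 0, 0, 1, 0, 0, 1]]
```

The resulting vector has the same (1, 10) shape as a one-hot label, so the trained generator accepts it without any change to its code.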


The resulting images are shapes which are intermediate between 6 and 9.


This is interesting as it shows that we can manipulate the input vector in ways that have a visual meaning.

The following shows the results for G.plot_images((3, 5)).


That also broadly works. Let's try a more challenging combination, G.plot_images((1, 7)).


The results are understandable as it is hard to find a shape that is both 1 and 7 in nature.


Conclusion

If we think about these results, they're quite impressive.

We've managed not only to train a GAN to generate plausible images, even though the generator never directly sees the training data, but also to control the class of image being generated by associating the learned representation with a label we provide.

Previously the learned representation was entangled, and it was difficult to induce the generator to produce an image of the class we wanted just by manipulating the random seed.

We also saw how we can manipulate the input vector to create images which have shapes representing combinations of more than one class. 




