Introduction

  • Generative Adversarial Networks (GANs)
  • How GANs Work
  • GANs Process
  • Examples
  • Generating Hand-Written digits

Generative Adversarial Networks (GANs)

Generative Adversarial Networks (GANs) are used to generate new data, most commonly images, that never existed before. They learn what the training examples look like (objects, animals and so forth) and then create new, realistic-looking versions of them rather than copying existing ones.

They have two components:

  • A Generator - this creates the images.
  • A Discriminator - this assesses an image and estimates whether it is real (taken from the training set of real-world examples) or fake (created by the generator). Its verdicts are the feedback the generator learns from.

When training the network, both the generator and discriminator start from scratch and learn together.

How GANs Work

G for Generative - this is the generator, a model that takes a random noise signal (a vector of random numbers) as input and outputs an image.

A for Adversarial - this is the discriminator, the opponent of the generator. This is capable of learning about objects, animals or other features specified. For example: if you supply it with pictures of dogs and non-dogs, it would be able to identify the difference between the two.

Using this example, once the discriminator has been trained, showing it a picture that isn't a dog makes it return a value close to 0, whereas showing it a dog makes it return a value close to 1. These outputs are probabilities: the discriminator's estimate that the image is a dog.

N for Network - meaning the generator and discriminator are both neural networks.

GANs Process

Step 1 - We input random noise signals into the generator, which creates a batch of fake images. Together with a batch of real images (the features we want the discriminator to learn), these are used to train the discriminator. For each image, the discriminator outputs a probability that the image is real. Because the discriminator has only just started training, these probabilities are close to random guesses. They are compared with the correct labels (0 for fake, 1 for real), the error is calculated, and it is backpropagated through the discriminator, whose weights are then updated.

Next we train the generator. We pass a batch of images it created through the discriminator again, this time without the real images. The generator learns by tricking the discriminator into outputting false positives, that is, classifying its fake images as real.

The discriminator again outputs probabilities. This time they are compared with the label the generator wants, 1 (real). The error is calculated and backpropagated through the discriminator into the generator, but only the generator's weights are updated.
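One round of Step 1 can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the notebook's implementation: the "real" data are hypothetical 1-D samples from a normal distribution, and both networks are tiny made-up MLPs. It shows the two phases described above: the discriminator is trained on detached fakes (label 0) and reals (label 1), then the generator is trained on fresh fakes labelled as real (1).

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy setup: "real" data are 1-D samples from N(4, 1.25^2),
# and both networks are tiny MLPs (not the MNIST models used later).
z_dim = 8
gen = nn.Sequential(nn.Linear(z_dim, 16), nn.ReLU(), nn.Linear(16, 1))
disc = nn.Sequential(nn.Linear(1, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
gen_opt = torch.optim.Adam(gen.parameters(), lr=1e-3)
disc_opt = torch.optim.Adam(disc.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()  # applies the sigmoid internally

def train_step(batch_size=64):
    real = 4 + 1.25 * torch.randn(batch_size, 1)

    # 1) Discriminator: fakes should score 0, reals should score 1.
    fake = gen(torch.randn(batch_size, z_dim)).detach()  # no gradient into the generator
    fake_pred, real_pred = disc(fake), disc(real)
    disc_loss = (criterion(fake_pred, torch.zeros_like(fake_pred))
                 + criterion(real_pred, torch.ones_like(real_pred))) / 2
    disc_opt.zero_grad()
    disc_loss.backward()
    disc_opt.step()

    # 2) Generator: new fakes only (no real data); it is rewarded
    #    when the discriminator labels them as real (1).
    fake_pred = disc(gen(torch.randn(batch_size, z_dim)))
    gen_loss = criterion(fake_pred, torch.ones_like(fake_pred))
    gen_opt.zero_grad()
    gen_loss.backward()  # gradients flow through the discriminator...
    gen_opt.step()       # ...but only the generator's weights are updated
    return disc_loss.item(), gen_loss.item()

# Later steps simply repeat the same round many times.
for step in range(200):
    d_loss, g_loss = train_step()
```

The stray gradients that the generator phase leaves in the discriminator's parameters are harmless here, because `disc_opt.zero_grad()` clears them before the next discriminator update.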

Step 2 - This is the same as step 1, and the process is repeated many times so the generator and discriminator keep training. Through backpropagation the generator learns from its mistakes, and its images become more and more like the real training images.

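For image data, the generator is often built as a deconvolutional (transposed-convolution) neural network, which upsamples the noise vector into an image. The MNIST example below uses a simpler fully connected generator instead.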
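To make the idea of a deconvolutional generator concrete, here is a minimal sketch using PyTorch's `nn.ConvTranspose2d`. The channel counts and kernel sizes are illustrative assumptions, chosen so that a noise vector treated as a 1x1 "image" is upsampled 1x1 → 7x7 → 14x14 → 28x28, the MNIST image size.

```python
import torch
from torch import nn

# Illustrative layer sizes (assumptions), chosen to upsample 1x1 -> 7x7 -> 14x14 -> 28x28.
z_dim = 64
deconv_gen = nn.Sequential(
    nn.ConvTranspose2d(z_dim, 128, kernel_size=7, stride=1),          # 1x1 -> 7x7
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7x7 -> 14x14
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),    # 14x14 -> 28x28
    nn.Sigmoid(),  # pixel values in [0, 1]
)

# Treat each noise vector as a 1x1 "image" with z_dim channels
noise = torch.randn(16, z_dim, 1, 1)
images = deconv_gen(noise)
print(images.shape)  # torch.Size([16, 1, 28, 28])
```

Each stride-2 layer with kernel 4 and padding 1 doubles the height and width: (H - 1) * 2 - 2 + 4 = 2H.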

Examples

GANs can be used for the following:

  • Generating Images
  • Image Modification
  • Super Resolution
  • Assisting Artists
  • Photo-Realistic Images
  • Speech Generation
  • Face Ageing

It’s Training Cats and Dogs: NVIDIA Research Uses AI to Turn Cats Into Dogs, Lions and Tigers, Too

Cats into Dogs

Horses into Zebras

Generating Hand-Written digits

Let's start by importing some useful packages.

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST # Training dataset
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # Fixing Seed for Reproducibility

import warnings
warnings.filterwarnings('ignore')

Let's create a visualizer function to see the input and output data.

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in a uniform grid.
    '''
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

MNIST Dataset

The training images your discriminator will use come from a dataset called MNIST. Its training set contains 60,000 images of handwritten digits, from 0 to 9, like these:

MNIST

Discriminator vs. Generator

The first thing that we will get to explore is the difference between a discriminator and a generator. Remember that these are the two main components of a GAN!

What is a discriminator?

One of the most widely used types of ML models is the classifier, which sifts through items in a dataset and classifies them into different categories. For example, you might be familiar with image classifiers that discriminate between images of a cat and images of a dog. For a well-trained classifier: when you give it an image of a cat, it will say "cat"! When you give it an image of a dog, it will say "dog"! You can also use other classes, like coconuts vs. starfruit.

Discriminative Model

Image Credit: Google

In terms of probabilities, the classifier wants to find $p(y|x)$: the probability that given an image input $x$, the image class $y$ is cat, $p(y=\text{cat}|x)$, or dog, $p(y=\text{dog}|x)$.

The discriminator is simply a classifier with two classes: real and fake. Given an input $x$, the discriminator will calculate the probabilities $p(y=\text{real}|x)$ and $p(y=\text{fake}|x)$ and classify $x$. The input $x$ can be anything you train the generator to create and is not limited to images. Your GAN can be trained on videos, text, audio, etc.
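With only two classes, a single number per input is enough. The discriminator built later in this notebook outputs one raw score (a logit) per image, and the sigmoid is applied inside the loss function. The sketch below, using made-up logits, shows how such a score maps to $p(y=\text{real}|x)$ and $p(y=\text{fake}|x)$; the 0.5 decision threshold is an assumption for illustration.

```python
import torch

# Hypothetical raw discriminator outputs (logits) for three images
logits = torch.tensor([2.0, -1.0, 0.0])

p_real = torch.sigmoid(logits)  # p(y=real | x)
p_fake = 1 - p_real             # p(y=fake | x): the two probabilities sum to 1
is_real = p_real > 0.5          # classify x as real if p(real|x) > 0.5

print(p_real)  # tensor([0.8808, 0.2689, 0.5000])
print(is_real)
```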

What is a generator?

Generators are designed to have a different goal from discriminators (classifiers). Imagine you're working at a tropical fruit stand and are asked to sort the fruit into two categories: coconuts and starfruit. That's the job of a classifier. But what if a customer comes up to the stand and asks: what is a starfruit? You can't just say that it's not a coconut. You would need to explain what makes something a starfruit and what doesn't, not just its differences from a coconut. That's the job of a generator: to represent different classes in general, not just distinguish them.

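In terms of probabilities, the generator wants to figure out $p(x|y)$: the probability of an image $x$ given the class $y$, for example how likely each possible image $x$ is to be a picture of a starfruit, $p(x|y=\text{starfruit})$. Learning this distribution is what lets the generator produce new samples from it. When a GAN is trained on a single unlabeled dataset, as in the MNIST example below, it learns $p(x)$ instead. The output space of possible starfruit images is huge, so that makes this challenging for the generator.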

This can be a much harder task than discrimination. Typically, you will need the generator to take multiple steps to improve itself for every step the discriminator takes. It's easy to tell the difference between a coconut and a starfruit when you look at a mix of them, but to know exactly all the features of all possible coconuts in the world? That's a lot, but it's really cool if you can get even close to it, because you can start generating all sorts of coconuts and starfruit when you do. I don't have a generator for these tropical fruits figured out for you to play with, but you can generate all kinds of cool things with these models, like realistic faces!

In the image below, the generator is trying to learn the features that represent all cats, using the feedback from the discriminator.

Generative Model

Image Credit: TensorFlow

Generator

The first step is to build the generator component.

We will start by creating a function to make a single layer/block for the generator's neural network. Each block should include a linear transformation to map to another shape, a batch normalization for stabilization, and finally a non-linear activation function (ReLU) so the output can be transformed in complex ways.

def generator_block(input_dim, output_dim):
    '''
    Function for returning a block of the generator's neural network
    given input and output dimensions.
    Parameters:
        input_dim: the dimension of the input vector, a scalar
        output_dim: the dimension of the output vector, a scalar
    Returns:
        a generator neural network layer, with a linear transformation 
          followed by a batch normalization and then a relu activation
    '''
    return nn.Sequential(
        # Linear Transformation Layer
        nn.Linear(input_dim, output_dim),
        # BatchNorm1d for stabilizing the training process
        nn.BatchNorm1d(output_dim),
        # Add some non-linearity effect
        nn.ReLU(inplace=True)
    )

Now, we can build the generator class. It will take 3 values:

  • The noise vector dimension
  • The image dimension
  • The initial hidden dimension

Using these values, the generator will build a neural network with 6 blocks followed by an output layer. Beginning with the noise vector, the generator applies non-linear transformations via the block function until the tensor is mapped to the size of the image to be output (the same size as the real images from MNIST, 28 x 28 = 784 values). The final layer has no batch normalization and, instead of ReLU, uses a sigmoid activation to scale its outputs to the range [0, 1], the same range as the real MNIST pixels loaded with transforms.ToTensor().

Finally, you are given a forward pass function that takes in a noise vector and generates an image of the output dimension using your neural network.

class Generator(nn.Module):
    '''
    Generator Class
    Values:
        z_dim: the dimension of the noise vector, a scalar
        im_dim: the dimension of the images, fitted for the dataset used, a scalar
          (MNIST images are 28 x 28 = 784 so that is your default)
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super(Generator, self).__init__()
        # Building a Generator Neural Network
        # 6 - Generator Blocks (Previously Built)
        self.gen = nn.Sequential(
            generator_block(z_dim, hidden_dim), 
            generator_block(hidden_dim, hidden_dim * 2), 
            generator_block(hidden_dim * 2, hidden_dim * 4), 
            generator_block(hidden_dim * 4, hidden_dim * 8), 
            generator_block(hidden_dim * 8, hidden_dim * 16), 
            generator_block(hidden_dim * 16, hidden_dim * 8), 
            # Output Layer = Linear Transformation + Sigmoid activation
            nn.Linear(hidden_dim * 8, im_dim),
            nn.Sigmoid()
            )
        
    def forward(self, noise):
        '''
        Function for completing a forward pass of the generator: Given a noise tensor, 
        returns generated images.
        Parameters:
            noise: a noise tensor with dimensions (n_samples, z_dim)
        '''
        return self.gen(noise)
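A quick way to check the architecture is to trace the tensor sizes through it. The snippet below is a compact, self-contained re-declaration of the generator above with its default values (z_dim=10, im_dim=784, hidden_dim=128), written as a loop over layer widths; it is a sanity check under those assumptions, not a replacement for the class.

```python
import torch
from torch import nn

# Same layer sizes as the Generator class above, with its default arguments
z_dim, im_dim, hidden_dim = 10, 784, 128
dims = [z_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4,
        hidden_dim * 8, hidden_dim * 16, hidden_dim * 8]

layers = []
for d_in, d_out in zip(dims[:-1], dims[1:]):
    # One generator block: Linear -> BatchNorm1d -> ReLU
    layers += [nn.Linear(d_in, d_out), nn.BatchNorm1d(d_out), nn.ReLU(inplace=True)]
layers += [nn.Linear(dims[-1], im_dim), nn.Sigmoid()]  # output layer
gen = nn.Sequential(*layers)

print(dims)  # [10, 128, 256, 512, 1024, 2048, 1024]
fake = gen(torch.randn(16, z_dim))
print(fake.shape)                      # torch.Size([16, 784])
print(fake.view(-1, 1, 28, 28).shape)  # torch.Size([16, 1, 28, 28])
```

A batch of 16 is used because, in training mode, nn.BatchNorm1d needs more than one sample per batch to compute its statistics.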

Noise

In order to use your generator, you will need to be able to create noise vectors. The noise vector $z$ has the important role of making sure the images generated from the same class don't all look the same; it acts like a random seed. We will generate it with PyTorch by sampling random numbers from the standard normal distribution. Since multiple images are processed per pass, we generate all the noise vectors at once, one row per image.

def get_noise(n_samples, z_dim, device='cuda'):
    '''
    Function for creating noise vectors: Given the dimensions (n_samples, z_dim),
    creates a tensor of that shape filled with random numbers from the normal distribution.
    Parameters:
        n_samples: the number of samples to generate, a scalar
        z_dim: the dimension of the noise vector, a scalar
        device: the device type
    '''
    return torch.randn((n_samples, z_dim), device=device)
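The "random seed" analogy can be checked directly. This sketch redefines the same noise function with a CPU default (an assumption, so it runs without a GPU) and shows that resetting PyTorch's seed reproduces the same noise, and therefore the same generated images, while fresh draws differ.

```python
import torch

def get_noise(n_samples, z_dim, device='cpu'):
    # Same as the function above, but defaulting to the CPU
    return torch.randn((n_samples, z_dim), device=device)

torch.manual_seed(0)
a = get_noise(4, 64)
torch.manual_seed(0)
b = get_noise(4, 64)  # same seed -> identical noise -> identical generated images
c = get_noise(4, 64)  # seed not reset -> different noise -> different images

print(a.shape)            # torch.Size([4, 64])
print(torch.equal(a, b))  # True
print(torch.equal(a, c))  # False
```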

Discriminator

The second component that you need to construct is the discriminator. As with the generator component, we will start by creating a function that builds a neural network block for the discriminator.

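Note: We use leaky ReLUs to prevent the "dying ReLU" problem: a ReLU unit whose input is always negative outputs 0 and has a gradient of 0, so it stops learning. A leaky ReLU keeps a small slope (here 0.2) for negative inputs, so some gradient always flows.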
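The difference is easy to see in the gradients. This sketch compares plain ReLU with LeakyReLU (negative slope 0.2, as in discriminator_block) on one negative and one positive input.

```python
import torch
from torch import nn

x = torch.tensor([-2.0, 3.0], requires_grad=True)

# Plain ReLU: zero gradient for the negative input
nn.ReLU()(x).sum().backward()
relu_grad = x.grad.clone()
x.grad = None  # reset before the second backward pass

# LeakyReLU with the slope used in discriminator_block
nn.LeakyReLU(negative_slope=0.2)(x).sum().backward()
leaky_grad = x.grad.clone()

print(relu_grad)   # tensor([0., 1.])
print(leaky_grad)  # tensor([0.2000, 1.0000])
```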

def discriminator_block(input_dim, output_dim):
    '''
    Discriminator Block
    Function for returning a neural network of the discriminator given input and output dimensions.
    Parameters:
        input_dim: the dimension of the input vector, a scalar
        output_dim: the dimension of the output vector, a scalar
    Returns:
        a discriminator neural network layer, with a linear transformation 
          followed by an nn.LeakyReLU activation with negative slope of 0.2 
          (https://pytorch.org/docs/master/generated/torch.nn.LeakyReLU.html)
    '''
    
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.LeakyReLU(negative_slope=0.2, inplace=True)
    )

Now we can use these blocks to make a discriminator! The discriminator class holds 2 values:

  • The image dimension
  • The hidden dimension

The discriminator will build a neural network with 5 blocks followed by a linear output layer. It starts with the flattened image tensor and transforms it until it returns a single number per image (an output tensor of shape (n_samples, 1)). This output, a raw score or logit, classifies whether an image is fake or real. Note that you do not need a sigmoid after the output layer since it is included in the loss function (nn.BCEWithLogitsLoss, set up in the Training section). Finally, to use your discriminator's neural network you are given a forward pass function that takes in an image tensor to be classified.
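Why the sigmoid can be left out: `nn.BCEWithLogitsLoss` applies the sigmoid itself and then computes binary cross-entropy, which the PyTorch documentation describes as more numerically stable than a separate sigmoid followed by `nn.BCELoss`. The sketch below, on made-up logits and labels, checks that the two give the same value.

```python
import math
import torch
from torch import nn

logits = torch.tensor([[1.5], [-0.3], [0.0]])  # hypothetical discriminator outputs, shape (n, 1)
targets = torch.tensor([[1.0], [0.0], [1.0]])  # real = 1, fake = 0

with_logits = nn.BCEWithLogitsLoss()(logits, targets)
manual = nn.BCELoss()(torch.sigmoid(logits), targets)
print(torch.allclose(with_logits, manual))  # True

# A logit of 0 means p(real|x) = 0.5 (pure guessing); the loss is then ln 2
guess = nn.BCEWithLogitsLoss()(torch.zeros(4, 1), torch.ones(4, 1))
print(abs(guess.item() - math.log(2)) < 1e-6)  # True
```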

class Discriminator(nn.Module):    
    '''
    Discriminator Class
    Values:
        im_dim: the dimension of the images, fitted for the dataset used, a scalar
            (MNIST images are 28x28 = 784 so that is your default)
        hidden_dim: the inner dimension, a scalar
    '''
    
    def __init__(self, im_dim=784, hidden_dim=128):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # 5 - Discriminator Blocks
            discriminator_block(im_dim, hidden_dim * 4),
            discriminator_block(hidden_dim * 4, hidden_dim * 8),
            discriminator_block(hidden_dim * 8, hidden_dim * 4),
            discriminator_block(hidden_dim * 4, hidden_dim * 2),
            discriminator_block(hidden_dim * 2, hidden_dim),

            # Output layer: a single raw score (logit) per image
            nn.Linear(hidden_dim, 1)

        )
        

    def forward(self, image):  
        '''
        Function for completing a forward pass of the discriminator: Given an image tensor, 
        returns a 1-dimension tensor representing fake/real.
        Parameters:
            image: a batch of flattened image tensors with dimensions (n_samples, im_dim)
        ''' 
        return self.disc(image)
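As with the generator, a shape check helps confirm the wiring. The snippet re-declares the discriminator compactly with its defaults (im_dim=784, hidden_dim=128) and feeds it a random stand-in batch shaped like MNIST (an assumption; real images come from the dataloader). Images must be flattened first, as the training loop does.

```python
import torch
from torch import nn

# Same layer sizes as the Discriminator class above, with its default arguments
im_dim, hidden_dim = 784, 128
dims = [im_dim, hidden_dim * 4, hidden_dim * 8, hidden_dim * 4, hidden_dim * 2, hidden_dim]

layers = []
for d_in, d_out in zip(dims[:-1], dims[1:]):
    # One discriminator block: Linear -> LeakyReLU(0.2)
    layers += [nn.Linear(d_in, d_out), nn.LeakyReLU(negative_slope=0.2, inplace=True)]
layers.append(nn.Linear(dims[-1], 1))  # one logit per image, no sigmoid
disc = nn.Sequential(*layers)

images = torch.rand(8, 1, 28, 28)    # stand-in batch shaped like MNIST
flat = images.view(len(images), -1)  # flatten to (n_samples, 784)
print(disc(flat).shape)              # torch.Size([8, 1])
```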
    

Training

First, we will set our parameters:

  • criterion: the loss function
  • n_epochs: the number of times you iterate through the entire dataset when training
  • z_dim: the dimension of the noise vector
  • display_step: how often to display/visualize the images
  • batch_size: the number of images per forward/backward pass
  • lr: the learning rate
  • device: the device type, here using a GPU (which runs CUDA), not CPU

Next, we will load the MNIST dataset as tensors using a dataloader.

criterion = nn.BCEWithLogitsLoss()
n_epochs = 500
z_dim = 64
display_step = 10000
batch_size = 64
lr = 0.00001

# Load MNIST dataset as tensors
dataloader = DataLoader(
    MNIST('.', download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True)

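### Set the device for training: use the GPU (CUDA) if available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'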

Now, we can initialize the generator, discriminator, and optimizers. Note that each optimizer only takes the parameters of one particular model, since we want each optimizer to optimize only one of the models.

gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator().to(device) 
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
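The point about separate optimizers can be demonstrated with tiny stand-in models (hypothetical sizes, not the notebook's Generator and Discriminator). A generator loss produces gradients in both models, because it is computed through the discriminator, yet `gen_opt.step()` only changes the parameters it was given.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Tiny stand-in models (hypothetical sizes)
gen = nn.Linear(4, 3)
disc = nn.Linear(3, 1)
gen_opt = torch.optim.Adam(gen.parameters(), lr=0.01)
disc_opt = torch.optim.Adam(disc.parameters(), lr=0.01)

gen_before = gen.weight.detach().clone()
disc_before = disc.weight.detach().clone()

# Generator update: the loss depends on both models...
pred = disc(gen(torch.randn(5, 4)))
loss = nn.BCEWithLogitsLoss()(pred, torch.ones_like(pred))
gen_opt.zero_grad()
loss.backward()
print(disc.weight.grad is not None)  # True: gradients reached the discriminator too

# ...but gen_opt only holds the generator's parameters, so only those change
gen_opt.step()
print(torch.equal(gen.weight, gen_before))    # False
print(torch.equal(disc.weight, disc_before))  # True
```

In the training loop, the discriminator's leftover gradients from the generator update are cleared by `disc_opt.zero_grad()` before its next backward pass.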

Before we train our GAN network, we will need to create functions to calculate the discriminator's loss and the generator's loss. This is how the discriminator and generator will know how they are doing and improve themselves.
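Written out, the two loss functions implemented below are as follows. In this notation (introduced here for illustration), $D(\cdot)$ is the discriminator's raw output (logit), $\sigma$ is the sigmoid applied inside nn.BCEWithLogitsLoss, $G$ is the generator, $x^{(i)}$ are the $m$ real images in a batch and $z^{(i)}$ are noise vectors:

$$\mathcal{L}_D = \frac{1}{2}\left[-\frac{1}{m}\sum_{i=1}^{m}\log\sigma\big(D(x^{(i)})\big) - \frac{1}{m}\sum_{i=1}^{m}\log\Big(1-\sigma\big(D(G(z^{(i)}))\big)\Big)\right]$$

$$\mathcal{L}_G = -\frac{1}{m}\sum_{i=1}^{m}\log\sigma\big(D(G(z^{(i)}))\big)$$

The first term of $\mathcal{L}_D$ is the real-image loss (label 1) and the second is the fake-image loss (label 0); the fake images are detached, so $G$ is held fixed while $D$ is updated. $\mathcal{L}_G$ uses a fresh batch of noise with label 1. If the discriminator is only guessing, $\sigma(D(\cdot)) = 0.5$ everywhere and both losses equal $\ln 2 \approx 0.693$.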

def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    '''
    Return the loss of the discriminator given inputs.
    Parameters:
        gen: the generator model, which returns an image given z-dimensional noise
        disc: the discriminator model, which returns a single-dimensional prediction of real/fake
        criterion: the loss function, which should be used to compare 
               the discriminator's predictions to the ground truth reality of the images 
               (e.g. fake = 0, real = 1)
        real: a batch of real images
        num_images: the number of images the generator should produce, 
                which is also the length of the real images
        z_dim: the dimension of the noise vector, a scalar
        device: the device type
    Returns:
        disc_loss: a torch scalar loss value for the current batch
    '''
    # Create noise vectors and generate a batch (num_images) of fake images.
    z_noise = get_noise(num_images, z_dim, device)
    fake_imgs = gen(z_noise)
    
    # Get the discriminator's prediction of the fake image and calculate the loss
    disc_fake_pred = disc(fake_imgs.detach())
    disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
    
    # Get the discriminator's prediction of the real image and calculate the loss.
    disc_real_pred = disc(real)
    disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
    
    # Calculate the discriminator's loss by averaging the real and fake loss
    disc_loss = (disc_real_loss + disc_fake_loss)/2
    
    return disc_loss

def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    '''
    Return the loss of the generator given inputs.
    Parameters:
        gen: the generator model, which returns an image given z-dimensional noise
        disc: the discriminator model, which returns a single-dimensional prediction of real/fake
        criterion: the loss function, which should be used to compare 
               the discriminator's predictions to the ground truth reality of the images 
               (e.g. fake = 0, real = 1)
        num_images: the number of images the generator should produce, 
                which is also the length of the real images
        z_dim: the dimension of the noise vector, a scalar
        device: the device type
    Returns:
        gen_loss: a torch scalar loss value for the current batch
    '''
    # Create noise vectors and generate a batch of fake images.
    z_noise = get_noise(num_images, z_dim, device)
    fake_imgs = gen(z_noise)
    
    # Get the discriminator's prediction of the fake image.
    disc_fake_pred = disc(fake_imgs)
    
    # Calculate the generator's loss. Note: it compares the discriminator's predictions
    # on the fake images with the "real" label (1), since the generator wants its fakes classified as real
    gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
    return gen_loss

cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
test_generator = True # Whether the generator should be tested
gen_loss = False
error = False
for epoch in range(n_epochs):
  
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)

        # Flatten the batch of real images from the dataset
        real = real.view(cur_batch_size, -1).to(device)

        ### Update discriminator ###
        # Zero out the gradients before backpropagation
        disc_opt.zero_grad()

        # Calculate discriminator loss
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)

        # Compute the gradients; the fake images were detached, so no graph needs to be retained
        disc_loss.backward()

        # Update the discriminator's weights
        disc_opt.step()

        # For testing purposes, to keep track of the generator weights
        if test_generator:
            old_generator_weights = gen.gen[0][0].weight.detach().clone()

        ### Update generator ###
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        # Only the generator's weights are updated; the gradients this pass leaves in the
        # discriminator are cleared by disc_opt.zero_grad() at the start of the next step
        gen_opt.step()

        # For testing purposes, to check that your code changes the generator weights
        if test_generator:
            try:
                assert lr > 0.0000002 or (gen.gen[0][0].weight.grad.abs().max() < 0.0005 and epoch == 0)
                assert torch.any(gen.gen[0][0].weight.detach().clone() != old_generator_weights)
            except AssertionError:
                error = True
                print("Runtime tests have failed")
        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item() / display_step

        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item() / display_step

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Epoch {epoch}, step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1

Epoch 10, step 10000: Generator loss: 5.4047161012649445, discriminator loss: 0.05443974889267115
Epoch 21, step 20000: Generator loss: 6.291417294359194, discriminator loss: 0.05477201426360978
Epoch 31, step 30000: Generator loss: 3.850531491291531, discriminator loss: 0.15501205386193492
Epoch 42, step 40000: Generator loss: 2.3283984214425137, discriminator loss: 0.23069154401868655
Epoch 53, step 50000: Generator loss: 2.0571084298729923, discriminator loss: 0.2636471902921786
Epoch 63, step 60000: Generator loss: 1.6916157222747796, discriminator loss: 0.3298123072624213
Epoch 74, step 70000: Generator loss: 1.479275153839581, discriminator loss: 0.3825816682457922
Epoch 85, step 80000: Generator loss: 1.3958063300907644, discriminator loss: 0.4082342538863425
Epoch 95, step 90000: Generator loss: 1.4873849587976908, discriminator loss: 0.38783631078004704
Epoch 106, step 100000: Generator loss: 1.5395997340500345, discriminator loss: 0.3742863032758239
Epoch 117, step 110000: Generator loss: 1.500580434679985, discriminator loss: 0.3821151517182578
Epoch 127, step 120000: Generator loss: 1.3551663604080733, discriminator loss: 0.4202767654120942
Epoch 138, step 130000: Generator loss: 1.3603083747446494, discriminator loss: 0.41840620750039575
Epoch 149, step 140000: Generator loss: 1.3575507786154786, discriminator loss: 0.42234055254906416
Epoch 159, step 150000: Generator loss: 1.4248633385539051, discriminator loss: 0.40728020892888367
Epoch 170, step 160000: Generator loss: 1.47160139136314, discriminator loss: 0.394411649782956
Epoch 181, step 170000: Generator loss: 1.4830472575783757, discriminator loss: 0.3922100943788883
Epoch 191, step 180000: Generator loss: 1.4683524356603588, discriminator loss: 0.40119695761352636
Epoch 202, step 190000: Generator loss: 1.4103956043541401, discriminator loss: 0.4197945091888331
Epoch 213, step 200000: Generator loss: 1.3819597438156572, discriminator loss: 0.423482037006321
Epoch 223, step 210000: Generator loss: 1.4246015510618661, discriminator loss: 0.41522652971297497
Epoch 234, step 220000: Generator loss: 1.4182038937389847, discriminator loss: 0.41454712654650205
Epoch 245, step 230000: Generator loss: 1.4623772296845878, discriminator loss: 0.40369796157926335
Epoch 255, step 240000: Generator loss: 1.4457483284533021, discriminator loss: 0.4046686042502515
Epoch 266, step 250000: Generator loss: 1.38512691426873, discriminator loss: 0.4234224717453112
Epoch 277, step 260000: Generator loss: 1.352188807457682, discriminator loss: 0.4336424077853562
Epoch 287, step 270000: Generator loss: 1.3844688882827767, discriminator loss: 0.41897474683076086
Epoch 298, step 280000: Generator loss: 1.3627459356606006, discriminator loss: 0.4291179351970558
Epoch 309, step 290000: Generator loss: 1.3106346433222356, discriminator loss: 0.44792365804612694
Epoch 319, step 300000: Generator loss: 1.3448638754308249, discriminator loss: 0.4394607469886559
Epoch 330, step 310000: Generator loss: 1.3012820241808845, discriminator loss: 0.4524194141492234
Epoch 341, step 320000: Generator loss: 1.344684604495768, discriminator loss: 0.4374074421599516
Epoch 351, step 330000: Generator loss: 1.2217702299892876, discriminator loss: 0.47974520444720825
Epoch 362, step 340000: Generator loss: 1.231750753808027, discriminator loss: 0.4726514756649734
Epoch 373, step 350000: Generator loss: 1.2394801887631413, discriminator loss: 0.4716021126791837
Epoch 383, step 360000: Generator loss: 1.2483579675018717, discriminator loss: 0.4689559111356739
Epoch 394, step 370000: Generator loss: 1.2837351714789904, discriminator loss: 0.4569161177024241
Epoch 405, step 380000: Generator loss: 1.294854753339291, discriminator loss: 0.4585107137933375
Epoch 415, step 390000: Generator loss: 1.2179591452300491, discriminator loss: 0.47934117802828724
Epoch 426, step 400000: Generator loss: 1.2247088456273103, discriminator loss: 0.4760446860566737
Epoch 437, step 410000: Generator loss: 1.1542475645661356, discriminator loss: 0.5038219578146936
Epoch 447, step 420000: Generator loss: 1.0992902897894385, discriminator loss: 0.5215150975674391
Epoch 458, step 430000: Generator loss: 1.0931267674624914, discriminator loss: 0.5264235710889109
Epoch 469, step 440000: Generator loss: 1.1034114745616896, discriminator loss: 0.5200620216712355
Epoch 479, step 450000: Generator loss: 1.0829909142255831, discriminator loss: 0.5259332243412744
Epoch 490, step 460000: Generator loss: 1.0734401089608672, discriminator loss: 0.5309209023088215
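The epoch numbers in the log above follow from the batch arithmetic. With 60,000 training images and batch_size = 64, the DataLoader (which keeps the last, smaller batch by default) yields 938 batches per epoch, the last one holding only 32 images, which is why the loop uses cur_batch_size = len(real). The sketch below reproduces the epoch index for two logged steps and the total number of steps.

```python
import math

n_images, batch_size, n_epochs, display_step = 60_000, 64, 500, 10_000

steps_per_epoch = math.ceil(n_images / batch_size)  # last, smaller batch is kept
last_batch = n_images - (steps_per_epoch - 1) * batch_size
print(steps_per_epoch, last_batch)  # 938 32

# Epoch index in which a given global step falls (both count from 0)
for step in (10_000, 460_000):
    print(step, step // steps_per_epoch)  # 10000 10, then 460000 490

print(n_epochs * steps_per_epoch)  # 469000 total steps
```

Since training ends after step 468,999, the last multiple of display_step reached is 460,000, which matches the final line of the log.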