GANs on Hand-Written Digits Dataset

Aftab Gazali
12 min read · Jan 15, 2023

GAN stands for Generative Adversarial Network. It is a type of deep learning model composed of two neural networks: a generator and a discriminator. The generator creates new data samples, while the discriminator attempts to distinguish the generated samples from real ones. The two networks are trained simultaneously: the generator tries to create samples that fool the discriminator, and the discriminator tries to correctly identify the generated samples. GANs have been used for a variety of tasks, including image synthesis, text generation, and anomaly detection.

There are several different types of GANs, each with its own unique characteristics and use cases. Some of the most popular types include:

  1. Vanilla GAN: The original GAN architecture, in which the generator and the discriminator are typically simple fully connected (multilayer perceptron) networks.
  2. DCGAN (Deep Convolutional GAN): Uses convolutional layers in both the generator and the discriminator, which lets it generate sharper, higher-quality images than fully connected GANs.
  3. WGAN (Wasserstein GAN): Replaces the traditional GAN loss function with an estimate of the Wasserstein distance, which makes training more stable and reduces mode collapse.
  4. LSGAN (Least Squares GAN): Uses a least squares loss function instead of the traditional binary cross-entropy loss, resulting in a more stable training process.
  5. CycleGAN: Translates images from one domain to another without requiring paired training examples, for example converting a photo of a horse into a photo of a zebra.
  6. BEGAN (Boundary Equilibrium Generative Adversarial Networks): Uses an equilibrium objective that helps stabilize training while also providing a measure of the quality of the generated images.
  7. BigGAN: A very large GAN that generates high-resolution images with fine detail and high quality.

These are some of the most popular types of GANs, but new variations and architectures are being proposed regularly.

For a more in-depth explanation, refer to the links in the Resources section at the end of this article.

Dataset Introduction

A handwritten digits dataset is a collection of images of handwritten digits along with their corresponding labels (i.e., the correct digit). Such datasets are commonly used for training and evaluating machine learning algorithms, particularly in image recognition and computer vision. One of the most popular is the MNIST dataset, which consists of 60,000 training images and 10,000 test images of handwritten digits with their labels. We will use it to train our GAN model. The images are 28x28-pixel grayscale images, and MNIST is widely used as a benchmark for testing and comparing machine learning algorithms. Related datasets include USPS, which contains digits scanned from envelopes handled by the US Postal Service, and KMNIST (Kuzushiji-MNIST), which keeps MNIST's 28x28 format but replaces the digits with cursive Japanese characters.

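```python
from keras.datasets.mnist import load_data
from matplotlib import pyplot

# load the images into memory
(trainX, trainy), (testX, testy) = load_data()
# plot the first 25 images from the training dataset on a 5x5 grid
fig, ax = pyplot.subplots(ncols=5, nrows=5, figsize=(20, 20))
for i in range(5):
    for j in range(5):
        # row-major index: each of the first 25 images appears exactly once
        ax[i][j].imshow(trainX[i * 5 + j], cmap='gray_r')
pyplot.show()
```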
Handwritten Digits Dataset

As the figure above shows, the dataset consists of images of handwritten digits from 0 to 9.
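A note on the plotting loops in this article: the index expression (i+1)*(j+1)-1 does not step through the first 25 images one by one, because different (row, column) pairs can give the same product. The short pure-Python check below compares it with the row-major index i * 5 + j, which shows each of the first 25 images exactly once. It only counts indices and does not need MNIST.

```python
# Compare two ways of mapping a (row, column) cell of a 5x5 grid to an image index.
rows, cols = 5, 5

# product formula: different cells can collide, e.g. (0, 1) and (1, 0) both give 1
product_indices = [(i + 1) * (j + 1) - 1 for i in range(rows) for j in range(cols)]
# row-major formula: row i, column j -> i * cols + j
row_major_indices = [i * cols + j for i in range(rows) for j in range(cols)]

print(len(set(product_indices)))    # 14 distinct images, some drawn twice
print(len(set(row_major_indices)))  # 25 distinct images, each drawn once
# images among the first 25 that the product formula never shows
print(sorted(set(range(25)) - set(product_indices)))
# [6, 10, 12, 13, 16, 17, 18, 20, 21, 22, 23]
```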

Preprocessing

  1. Feature Scaling: Feature scaling is an important data preprocessing step. Raw MNIST pixels are integers in [0, 255], and feeding values of that magnitude into a network tends to produce large activations and unstable gradients, so they are rescaled to [0, 1]. This is min-max scaling with the known pixel range: dividing by 255 maps 0 to 0 and 255 to 1. The [0, 1] range also matches the sigmoid output layer of the generator defined later, so real and generated images share the same scale.
  2. Expand Dims: In NumPy, numpy.expand_dims inserts a new axis of size 1 at the given position and returns the resulting array. Keras Conv2D layers expect images shaped (height, width, channels), so the grayscale training array of shape (60000, 28, 28) is expanded to (60000, 28, 28, 1).
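You can check both preprocessing steps on a small stand-in array. This sketch uses a hypothetical batch of three random uint8 images in place of the real MNIST training set. It shows the dtype, range, and shape changes that load_real_samples() performs:

```python
import numpy as np

# hypothetical stand-in for MNIST: 3 random 28x28 uint8 images
# (the real training array has shape (60000, 28, 28))
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(3, 28, 28), dtype=np.uint8)

# min-max scaling with the known pixel range [0, 255]: (x - 0) / (255 - 0)
scaled = images.astype('float32') / 255.0

# add a trailing channels axis so Conv2D sees (height, width, channels)
expanded = np.expand_dims(scaled, axis=-1)

print(images.shape, '->', expanded.shape)  # (3, 28, 28) -> (3, 28, 28, 1)
print(expanded.dtype)                      # float32
```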
```python
import numpy as np
from keras.datasets.mnist import load_data

def load_real_samples():
    # load the MNIST training images (labels are not needed)
    (trainX, _), (_, _) = load_data()
    # add a channels dimension: (60000, 28, 28) -> (60000, 28, 28, 1)
    X = np.expand_dims(trainX, axis=-1)
    # convert from unsigned ints to floats
    X = X.astype('float32')
    # scale from [0,255] to [0,1]
    X = X / 255.0
    return X
```

Before building the GAN model, let's be clear about the goal. By the end of this project, we will have a generator model that produces an image of a handwritten digit from any random latent vector. Let's look at how a GAN works.

Working of GANs

Working of GANs [5]

A GAN consists of two models that compete with each other during training: a generator and a discriminator. As discussed above, the generator takes a random latent vector and produces an image. This image is passed to the discriminator, which is trained on a mix of real images from the training dataset (labeled real) and generated images (labeled fake). Its task is to decide whether a given image is real or fake. The generator tries to produce synthetic data that fools the discriminator, while the discriminator tries to correctly identify whether each sample is real or fake.

As training progresses, the generator gets better at producing synthetic data that looks like real-world data, while the discriminator gets better at distinguishing synthetic data from real data. Each network is updated based on the performance of the other. The generator learns to produce more realistic samples, and the discriminator learns to separate real from fake more reliably. In the ideal case, this continues until the generator's samples are indistinguishable from real data and the discriminator can do no better than guessing (an output of 0.5 for every input). In practice, as in this article, training is stopped after a fixed number of epochs and the generated images are inspected visually.
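The competition described above is commonly written as a two-player minimax game over a value function V(D, G). Here D(x) is the discriminator's probability that x is real, and G(z) is the generator's output for a latent vector z:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
$$

Suppose the generator is held fixed and its samples follow a distribution p_g. The best possible discriminator is then

$$
D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)},
$$

which equals 1/2 everywhere once p_g = p_data. This is the "indistinguishable" end point described above. Maximizing V over D is the same as minimizing binary cross-entropy with label 1 for real images and 0 for generated ones. That is why the discriminator in this article is compiled with loss='binary_crossentropy'.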

After training, the generator can be used on its own. Feed it random noise and it outputs new synthetic samples, which should resemble the real-world data the GAN was trained on.

Before training the discriminator model, we first need to take care of an important step: labeling real and fake samples.

Labeling Real & Fake Samples

```python
from numpy import ones
from numpy.random import randint

def generate_real_samples(dataset, n_samples):
    # choose random instances (sampled with replacement)
    ix = randint(0, dataset.shape[0], n_samples)
    # retrieve selected images
    X = dataset[ix]
    # generate 'real' class labels (1)
    y = ones((n_samples, 1))
    return X, y
```

In the snippet above, real samples are drawn at random from the actual dataset. Since these images are real, they are labeled one (1).

```python
from numpy import zeros
from numpy.random import rand

# generate n fake samples with class labels
def generate_fake_samples(n_samples):
    # generate uniform random numbers in [0,1]
    X = rand(28 * 28 * n_samples)
    # reshape into a batch of grayscale images
    X = X.reshape((n_samples, 28, 28, 1))
    # generate 'fake' class labels (0)
    y = zeros((n_samples, 1))
    return X, y
```

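In this snippet, the fake samples do not come from a generator yet. We draw uniform random pixel values in [0, 1], reshape them into a batch of 28x28x1 images, and label them zero (0). This pure noise is only a placeholder for testing the discriminator; later it is replaced by images from the generator.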
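Later, in the full training loop, a real half-batch and a fake half-batch are stacked into a single discriminator batch with vstack. This NumPy-only sketch mirrors that step. It uses a hypothetical random array as a stand-in for the scaled MNIST data and reproduces the logic of generate_real_samples() and the noise-based generate_fake_samples() rather than calling them:

```python
import numpy as np

rng = np.random.default_rng(1)
# hypothetical stand-in for the scaled MNIST array returned by load_real_samples()
dataset = rng.random((1000, 28, 28, 1), dtype=np.float32)
half_batch = 128

# real half-batch: random rows of the dataset, labeled 1
ix = rng.integers(0, dataset.shape[0], half_batch)
X_real, y_real = dataset[ix], np.ones((half_batch, 1))

# fake half-batch: uniform noise reshaped into images, labeled 0
X_fake = rng.random(28 * 28 * half_batch).reshape((half_batch, 28, 28, 1))
y_fake = np.zeros((half_batch, 1))

# one balanced discriminator batch: half real, half fake
X = np.vstack((X_real, X_fake))
y = np.vstack((y_real, y_fake))
print(X.shape, y.shape, y.mean())  # (256, 28, 28, 1) (256, 1) 0.5
```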

Let's build the discriminator model first, because it is much simpler than the generator model.

Building Discriminator Model

Training a discriminator model in a generative model setting is often considered to be easier than training the generator model for a few reasons:

  1. The discriminator model's job is simply to classify input samples as real or fake, whereas the generator's job is to produce realistic-looking samples that can fool the discriminator. This binary classification task is generally considered simpler than generating new samples.
  2. The discriminator is trained like an ordinary supervised classifier, using backpropagation and stochastic gradient descent on labeled examples. The generator is trained with the same techniques, but it has no direct targets: its gradients arrive only indirectly, by backpropagating through the discriminator.
  3. Discriminator models often have fewer parameters to learn. In this article, the discriminator has roughly 41 thousand parameters, compared with roughly 1.16 million for the generator.
  4. Discriminator models are trained to separate real data from fake data. Real data is already available, which makes this task easier.
  5. In the GAN architecture, the generator and discriminator models are trained simultaneously. As the generator gets better at generating fake data, the discriminator gets better at detecting it.

These are some of the reasons that make the training of the discriminator model relatively easy.

```python
from keras.models import Sequential
from keras.layers import Conv2D, Dense, Dropout, Flatten, LeakyReLU
from keras.optimizers import Adam

def define_discriminator_model(in_shape=(28,28,1)):
    model = Sequential()
    # 2x2 strides downsample 28x28 -> 14x14
    model.add(Conv2D(64, (3,3), strides=(2,2), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.2))
    # 2x2 strides downsample 14x14 -> 7x7
    model.add(Conv2D(64, (3,3), strides=(2,2), padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.2))
    model.add(Flatten())
    # single sigmoid output: probability that the input image is real
    model.add(Dense(1, activation='sigmoid'))
    # learning_rate replaces the deprecated lr argument
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

model = define_discriminator_model()
model.summary()
```
Discriminator Model Summary

As the summary shows, our discriminator works the same way an image classifier would. It takes a 28x28x1 image and classifies it as real or fake, outputting a single sigmoid probability that the image is real.
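You can verify the discriminator's downsampling by hand. With padding='same', a Keras convolution produces ceil(size / stride) outputs per spatial dimension, so two stride-2 layers take 28 to 14 and then to 7. The small calculation below (plain Python, no Keras needed) also gives the number of features that Flatten passes to the final Dense layer:

```python
import math

def same_padding_output_size(size, stride):
    # with padding='same', the output length is ceil(size / stride),
    # independent of the kernel size
    return math.ceil(size / stride)

size = 28
for layer in range(2):  # the two stride-2 Conv2D layers in the discriminator
    size = same_padding_output_size(size, 2)
    print(size)  # 14, then 7

flattened = size * size * 64  # 64 filters in the last Conv2D layer
print(flattened)  # 3136 features feed the final Dense(1) layer
```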

Training the Discriminator

```python
def train_discriminator(model, dataset, n_iter=100, n_batch=256):
    half_batch = int(n_batch / 2)
    # manually enumerate training iterations (one half-batch of each kind per step)
    for i in range(n_iter):
        # get randomly selected 'real' samples
        X_real, y_real = generate_real_samples(dataset, half_batch)
        # update discriminator on real samples
        _, real_acc = model.train_on_batch(X_real, y_real)
        # generate 'fake' examples (uniform noise from generate_fake_samples above)
        X_fake, y_fake = generate_fake_samples(half_batch)
        # update discriminator on fake samples
        _, fake_acc = model.train_on_batch(X_fake, y_fake)
        # summarize performance
        print('>%d real=%.0f%% fake=%.0f%%' % (i+1, real_acc*100, fake_acc*100))

# define the discriminator model
model = define_discriminator_model()
# load image data
dataset = load_real_samples()
# fit the model
train_discriminator(model, dataset)
```

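In the snippet above, the discriminator is trained on alternating half-batches of real images (label 1) and random-noise images (label 0). Noise is easy to tell apart from digits, so accuracy on both is typically expected to climb quickly toward 100%. This only checks that the model can learn; in the full GAN, the fake samples will come from the generator instead.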

Since the discriminator is done, let’s look into building the generator model.

Building Generator Model

Training a generator model in a generative model setting is often considered to be more challenging than training the discriminator model for a few reasons:

  1. The generator model’s job is to produce realistic-looking samples that can fool the discriminator, which is a more complex task than simply classifying input samples as real or fake.
  2. The generator cannot be trained as a standard supervised model, because there is no target image for a given latent vector. It is still trained with backpropagation, but its only learning signal comes from the discriminator's judgment of its output.
  3. Generator models often have more parameters to learn. In this article, the generator has roughly 1.16 million parameters, versus roughly 41 thousand for the discriminator.
  4. The generator model is trained to generate fake data that can fool the discriminator. As the discriminator gets better at detecting fake data, it becomes harder for the generator to produce data that fools it.
  5. The generator model is trained to output data similar to the real data, but the similarity between generated and real data is hard to evaluate.
  6. In the GAN architecture, the generator and discriminator models are trained simultaneously. As the generator gets better at generating fake data, the discriminator gets better at detecting it. This constant competition between the two models makes the generator's job harder.

```python
from keras.models import Sequential
from keras.layers import Conv2D, Conv2DTranspose, Dense, LeakyReLU, Reshape
from keras.utils import plot_model

def define_generator(latent_dim):
    model = Sequential()
    # foundation for 7x7 image: 128 feature maps of size 7x7
    n_nodes = 128 * 7 * 7
    model.add(Dense(n_nodes, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Reshape((7, 7, 128)))
    # upsample to 14x14
    model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # upsample to 28x28
    model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # single-channel output in [0,1], matching the scaled real images
    model.add(Conv2D(1, (7,7), activation='sigmoid', padding='same'))
    return model

# define the size of the latent space
latent_dim = 100
# define the generator model
model = define_generator(latent_dim)
# summarize the model
model.summary()
# plot the model (requires the pydot and graphviz packages)
plot_model(model, to_file='generator_plot.png', show_shapes=True, show_layer_names=True)
```
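You can check the claim that the discriminator has far fewer parameters than the generator with simple arithmetic. A convolution or transposed convolution with a kh x kw kernel has kh * kw * in_channels * filters weights plus one bias per filter. A Dense layer has in_features * units weights plus one bias per unit. LeakyReLU, Dropout, Flatten, and Reshape add no parameters. The totals below should agree with what model.summary() reports for the two models defined in this article:

```python
def conv_params(kernel_h, kernel_w, in_channels, filters):
    # weights for every filter plus one bias per filter
    # (the same count applies to Conv2DTranspose)
    return kernel_h * kernel_w * in_channels * filters + filters

def dense_params(in_features, units):
    # weight matrix plus one bias per unit
    return in_features * units + units

discriminator = [
    conv_params(3, 3, 1, 64),     # Conv2D(64, (3,3)) on a 1-channel image
    conv_params(3, 3, 64, 64),    # Conv2D(64, (3,3))
    dense_params(7 * 7 * 64, 1),  # Dense(1) after Flatten
]
generator = [
    dense_params(100, 128 * 7 * 7),  # Dense(6272) from the 100-d latent vector
    conv_params(4, 4, 128, 128),     # Conv2DTranspose(128, (4,4)): 7x7 -> 14x14
    conv_params(4, 4, 128, 128),     # Conv2DTranspose(128, (4,4)): 14x14 -> 28x28
    conv_params(7, 7, 128, 1),       # Conv2D(1, (7,7)) output layer
]
print(sum(discriminator))  # 40705
print(sum(generator))      # 1164289
```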

Before testing the generator model, a few steps deserve a closer look.

The generator takes a random latent vector as input: here, latent_dim = 100 values drawn from a standard normal distribution. The latent vector does not need to match the image shape. Instead, the generator's output must have the same shape as the images the discriminator is trained on (28x28x1). We therefore define two functions: one to generate latent points and one to turn them into fake samples with the generator.

Fake samples produced by the generator are labeled 0 so that they can be mixed with real samples when training the discriminator. The combined GAN model (discussed below) deliberately uses the opposite label, 1, for these samples when updating the generator.

```python
from matplotlib import pyplot
from numpy import zeros
from numpy.random import randn

def generate_latent_points(latent_dim, n_samples):
    # draw latent_dim * n_samples values from a standard normal distribution
    x_input = randn(latent_dim * n_samples)
    # reshape into a batch of inputs for the network: (n_samples, latent_dim)
    x_input = x_input.reshape(n_samples, latent_dim)
    return x_input

# use the generator to generate n fake examples, with class labels
# (this replaces the noise-based generate_fake_samples defined earlier)
def generate_fake_samples(g_model, latent_dim, n_samples):
    # generate points in latent space
    x_input = generate_latent_points(latent_dim, n_samples)
    # predict outputs: images of shape (n_samples, 28, 28, 1)
    X = g_model.predict(x_input, verbose=0)
    # create 'fake' class labels (0)
    y = zeros((n_samples, 1))
    return X, y

# size of the latent space
latent_dim = 100
# define the generator model
model = define_generator(latent_dim)
# generate samples
n_samples = 25
X, _ = generate_fake_samples(model, latent_dim, n_samples)
# plot the first four generated samples
fig, ax = pyplot.subplots(ncols=4, figsize=(20, 5))
for i in range(4):
    ax[i].imshow(X[i, :, :, 0], cmap='gray_r')
pyplot.show()
```
Sample Images Generated by Generator.

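The figure above shows the output of the untrained generator. These are clearly not digits yet, only noise from randomly initialized weights. Once the generator is trained against the discriminator, it learns to deceive the discriminator and produces recognizable digits. The models above already use a few common GAN hacks, such as LeakyReLU activations, dropout in the discriminator, and Adam with learning_rate=0.0002 and beta_1=0.5 (see [3]).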

Improving the Generator Model

The weights in the generator model are updated based on the performance of the discriminator model. When the discriminator is good at detecting fake samples, the generator is updated more. When the discriminator is relatively poor or confused at detecting fake samples, the generator is updated less. This defines the zero-sum, or adversarial, relationship between the two models. There are many ways to implement this with the Keras API, but perhaps the simplest is to create a new model that stacks the generator and the discriminator. The generator's output is fed into the discriminator, whose weights are frozen inside this combined model, so training the combined model updates only the generator.
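In the training loop below, the generator is updated through the combined model with fake images labeled 1 (the "inverted labels" in the code). With binary cross-entropy, this means the generator minimizes -log D(G(z)) instead of the log(1 - D(G(z))) term of the original minimax objective. This swap is commonly called the non-saturating generator loss. The pure-Python sketch below compares the gradients of the two losses with respect to the discriminator's pre-sigmoid output a (so D = sigmoid(a)) when the discriminator confidently rejects a fake. It illustrates the idea and is not code from the training loop:

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_saturating(a):
    # d/da of log(1 - sigmoid(a)), the generator term of the minimax objective
    return -sigmoid(a)

def grad_non_saturating(a):
    # d/da of -log(sigmoid(a)): binary cross-entropy with label 1 ("inverted label")
    return sigmoid(a) - 1.0

# the discriminator confidently rejects a fake: D(G(z)) = 0.01
a = math.log(0.01 / 0.99)
print(round(sigmoid(a), 4))                   # 0.01
print(round(abs(grad_saturating(a)), 4))      # 0.01: almost no signal for G
print(round(abs(grad_non_saturating(a)), 4))  # 0.99: strong signal for G
```

The non-saturating version gives the generator its strongest signal exactly when the discriminator is good at spotting fakes. This matches the behavior described above.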

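```python
from numpy import ones, vstack
from keras.models import Sequential
from keras.optimizers import Adam

# define_gan() and summarize_performance() are called below but are not shown in the
# original article; these are minimal assumed versions of the usual pattern.
def define_gan(g_model, d_model):
    # freeze the discriminator inside the combined model only; d_model was compiled
    # while trainable, so d_model.train_on_batch still updates it (tf.keras / Keras 2
    # behavior; Keras 3 may treat the trainable flag differently)
    d_model.trainable = False
    model = Sequential()
    model.add(g_model)
    model.add(d_model)
    model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    return model

def summarize_performance(epoch, g_model, d_model, dataset, latent_dim, n_samples=100):
    # report discriminator accuracy on real and generated images
    X_real, y_real = generate_real_samples(dataset, n_samples)
    _, acc_real = d_model.evaluate(X_real, y_real, verbose=0)
    X_fake, y_fake = generate_fake_samples(g_model, latent_dim, n_samples)
    _, acc_fake = d_model.evaluate(X_fake, y_fake, verbose=0)
    print('>Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real*100, acc_fake*100))
    # save the generator, e.g. generator_model_050.h5 after epoch 50
    g_model.save('generator_model_%03d.h5' % (epoch + 1))

def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=50, n_batch=256):
    bat_per_epo = int(dataset.shape[0] / n_batch)
    half_batch = int(n_batch / 2)
    # manually enumerate epochs
    for i in range(n_epochs):
        # enumerate batches over the training set
        for j in range(bat_per_epo):
            # get randomly selected 'real' samples
            X_real, y_real = generate_real_samples(dataset, half_batch)
            # generate 'fake' examples with the generator
            X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            # create training set for the discriminator
            X, y = vstack((X_real, X_fake)), vstack((y_real, y_fake))
            # update discriminator model weights
            d_loss, _ = d_model.train_on_batch(X, y)
            # prepare points in latent space as input for the generator
            X_gan = generate_latent_points(latent_dim, n_batch)
            # create inverted labels (1 = 'real') for the fake samples
            y_gan = ones((n_batch, 1))
            # update the generator via the discriminator's error
            g_loss = gan_model.train_on_batch(X_gan, y_gan)
            # summarize loss on this batch
            print('>%d, %d/%d, d=%.3f, g=%.3f' % (i+1, j+1, bat_per_epo, d_loss, g_loss))
        # evaluate and save the generator every 10 epochs
        if (i+1) % 10 == 0:
            summarize_performance(i, g_model, d_model, dataset, latent_dim)

latent_dim = 100
# create the discriminator and the generator
d_model = define_discriminator_model()
g_model = define_generator(latent_dim)
# create the gan
gan_model = define_gan(g_model, d_model)
# load image data
dataset = load_real_samples()
# train model
train(g_model, d_model, gan_model, dataset, latent_dim)
```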

The snippet above combines the discriminator and the generator into a single GAN model and trains it for 50 epochs. Training for 100 epochs is recommended for better results.
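The default arguments of train() translate into concrete update counts. This small calculation assumes the standard 60,000-image MNIST training set and the default n_batch=256:

```python
n_train_images = 60000                       # size of the MNIST training set
n_batch = 256                                # default n_batch in train()
bat_per_epo = int(n_train_images / n_batch)  # same formula as train(): 234
half_batch = int(n_batch / 2)                # 128 real + 128 fake per D update

# real images drawn per epoch (sampled at random with replacement)
real_images_per_epoch = bat_per_epo * half_batch  # 29952
updates_50 = 50 * bat_per_epo    # 11700 updates each for D and G
updates_100 = 100 * bat_per_epo  # 23400 updates each for D and G
print(bat_per_epo, real_images_per_epoch, updates_50, updates_100)
# 234 29952 11700 23400
```

Because only half of each batch is real, one "epoch" here draws roughly half as many real images as the training set contains.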

Generating Images

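```python
from keras.models import load_model
from matplotlib import pyplot

def save_plot(examples, n):
    # plot the first n*n generated images on an n x n grid
    fig, ax = pyplot.subplots(ncols=n, nrows=n, figsize=(20, 20))
    for i in range(n):
        for j in range(n):
            ax[i][j].imshow(examples[i * n + j, :, :, 0])
    pyplot.show()

# load the generator saved after epoch 50
model = load_model('generator_model_050.h5')
# generate 25 points in the 100-dimensional latent space
latent_points = generate_latent_points(100, 25)
# map the latent points to images
X = model.predict(latent_points)
# plot the result on a 5x5 grid
save_plot(X, 5)
```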
Generated Images (Matplotlib's default colormap)

Passing cmap='gray_r' to imshow renders the same images as dark digits on a light background:

Generated Gray Scaled Images

Looking at both generated figures, we can see that the generator has learned to produce digits from random latent samples. A few clearly identifiable digits are 8, 0, 9, and 6. These results come from 50 epochs of training; training the same model for 100 epochs should improve them further. Finally, GANs are a powerful type of machine learning model that can generate not only images but also audio, video, text, and more.

Resources

  1. https://developers.google.com/machine-learning/gan#:~:text=Generative%20adversarial%20networks%20(GANs)%20are,belong%20to%20any%20real%20person
  2. https://neptune.ai/blog/gan-loss-functions
  3. https://machinelearningmastery.com/how-to-code-generative-adversarial-network-hacks/
  4. https://medium.com/r?url=http%3A%2F%2Fwiki.pathmind.com%2Fgenerative-adversarial-network-gan
  5. https://medium.com/analytics-vidhya/introduction-to-generative-adversarial-networks-gans-852c8a29bd70
