Conditional GANs (cGANs) for the MNIST Dataset

Aftab Gazali
Jan 25, 2023

GAN stands for Generative Adversarial Network, a type of deep learning model composed of two neural networks: a generator and a discriminator. The generator creates new data samples, while the discriminator tries to distinguish the generated samples from real examples. The two networks are trained together: the generator learns to create samples that fool the discriminator, and the discriminator learns to correctly identify the generated samples. GANs have been used for a variety of tasks, including image synthesis, text generation, and anomaly detection.
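For reference, the original GAN paper (Goodfellow et al., 2014) formalizes this two-player game as a minimax objective. Here D(x) is the discriminator's estimated probability that x is real, and G(z) maps a noise vector z drawn from a prior p_z to a synthetic sample. The discriminator tries to maximize V, the generator tries to minimize it:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```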

There are several different types of GANs, each with its own unique characteristics and use cases. Some of the most popular types include:

  1. Vanilla GAN: This is the original GAN architecture, consisting of a generator and a discriminator network.
  2. DCGAN (Deep Convolutional GAN): This type of GAN uses convolutional layers in the discriminator and transposed-convolutional layers in the generator, which makes training more stable and improves image quality compared with fully connected GANs.
  3. WGAN (Wasserstein GAN): This type of GAN replaces the traditional GAN loss with an estimate of the Wasserstein (Earth Mover's) distance between the real and generated distributions, which tends to make training more stable and reduce mode collapse.
  4. LSGAN (Least Squares GAN): This type of GAN uses a least squares loss function instead of the traditional binary cross-entropy loss, resulting in a more stable training process.
  5. CycleGAN: This type of GAN translates images from one domain to another without requiring paired training examples, for example, converting a picture of a horse into a picture of a zebra.
  6. BEGAN (Boundary Equilibrium Generative Adversarial Networks): This type of GAN uses an equilibrium objective that helps stabilize the training process, while also providing a measure of the quality of the generated images.
  7. BigGAN: A large-scale, class-conditional GAN that generates high-resolution images with fine detail and high fidelity.

In this project, we are going to use another variant called the Conditional GAN (cGAN).

A Conditional GAN, or cGAN for short, is a variant of the Generative Adversarial Network (GAN) architecture in which both the generator and the discriminator are conditioned on additional input information. This information can be used to control specific characteristics of the generated samples, such as class labels for image data or text prompts for natural language generation. As in a standard GAN, the generator's goal is to produce synthetic samples that resemble real samples from the target distribution, while the discriminator's goal is to distinguish between real and synthetic samples.
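Mirza and Osindero (2014), who introduced cGANs, express the same minimax game with both networks receiving the extra information y (for MNIST, the digit's class label). Informally, and with y written explicitly as an input to D in both terms:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x \mid y)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z \mid y) \mid y)\right)\right]
```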

Dataset Introduction

A handwritten digits dataset is a collection of images of handwritten digits together with their labels (i.e., the correct digit). Such datasets are commonly used to train and evaluate machine learning algorithms, particularly in image recognition and computer vision. One of the most popular is the MNIST dataset, which contains 60,000 training images and 10,000 test images of handwritten digits, along with their labels. Each image is a 28x28-pixel grayscale image, and MNIST is widely used as a benchmark for testing and comparing machine learning algorithms. We are going to use this dataset to train our GAN model. Another handwritten digits dataset is USPS, which contains digits scanned from U.S. Postal Service envelopes. A related benchmark, KMNIST (Kuzushiji-MNIST), follows the MNIST format but contains cursive Japanese (Kuzushiji) characters rather than digits.

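import matplotlib.pyplot as plt
from tensorflow.keras.datasets.mnist import load_data
(trainX, trainy), (_, _) = load_data()
# plot the first 16 training images in a 4x4 grid
fig, ax = plt.subplots(ncols=4, nrows=4, figsize=(20, 20))
for i in range(4):
    for j in range(4):
        ax[i][j].imshow(trainX[i * 4 + j], cmap='gray_r')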
Handwritten Digit Dataset

Preprocessing

  1. Feature Scaling: Feature scaling is an important data preprocessing step. Without it, features with larger numeric ranges can dominate the learning process, regardless of their units of measurement. Here we use min-max scaling, which maps each feature into a fixed range, in this case [0, 1]. Because pixel values always lie in [0, 255], min-max scaling reduces to dividing by 255.
  2. Expand Dims: In the NumPy library, numpy.expand_dims inserts a new axis of size 1. It takes the input array and the position where the new axis should be inserted, and returns an array with the new dimension added. Here it adds a channels axis, turning the (60000, 28, 28) image array into the (60000, 28, 28, 1) shape that Keras convolutional layers expect.
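import numpy as np
from tensorflow.keras.datasets.mnist import load_data

def load_real_samples():
    # load the MNIST training images (labels are not needed here)
    (trainX, _), (_, _) = load_data()
    # expand to 3d, e.g. add channels dimension: (60000, 28, 28) -> (60000, 28, 28, 1)
    X = np.expand_dims(trainX, axis=-1)
    # convert from unsigned ints to floats
    X = X.astype('float32')
    # scale from [0,255] to [0,1]
    X = X / 255.0
    return X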
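Since 8-bit pixels have a known range of 0 to 255, min-max scaling with those fixed bounds is exactly the division by 255 used above. The sketch below checks this on a small random array standing in for MNIST (the helper `min_max_scale` is hypothetical, not a library function) and shows the shape change from `np.expand_dims`:

```python
import numpy as np

def min_max_scale(x, data_min=0.0, data_max=255.0):
    """Min-max scaling to [0, 1] with a known data range (0-255 for 8-bit pixels)."""
    return (x - data_min) / (data_max - data_min)

# a hypothetical batch of two 28x28 8-bit images standing in for MNIST
batch = np.random.randint(0, 256, size=(2, 28, 28)).astype('uint8')

scaled = min_max_scale(batch.astype('float32'))
# with the range fixed at [0, 255], this is the same as dividing by 255
assert np.allclose(scaled, batch.astype('float32') / 255.0)

# add a trailing channel axis: (2, 28, 28) -> (2, 28, 28, 1)
with_channel = np.expand_dims(scaled, axis=-1)
print(with_channel.shape)
```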

Working of Conditional GANs

The working of a conditional GAN (cGAN) is similar to that of a standard GAN, but with an additional conditioning variable. The cGAN consists of two main components: the generator network and the discriminator network.

  1. The generator network: The generator network takes as input both a random noise vector and a conditioning variable, and produces a synthetic sample. The generator network is trained to produce samples that are similar to real samples from the target distribution and are consistent with the conditioning variable.
  2. The discriminator network: The discriminator network takes as input both real samples and synthetic samples, along with the conditioning variable, and attempts to distinguish between them. The discriminator network is trained to correctly classify whether a given sample is real or synthetic, conditioned on the input variable.

Both the generator and discriminator networks are trained concurrently in an adversarial manner. The generator network is trained to produce samples that are difficult for the discriminator network to distinguish from real samples, while the discriminator network is trained to correctly identify whether a given sample is real or synthetic.

In the training phase, the discriminator is shown real samples and synthetic samples, each paired with its conditioning variable, and learns to classify them as real or synthetic. The synthetic samples come from the generator, which receives a random noise vector and a conditioning variable. The generator is then updated through the discriminator: its weights are adjusted so that the discriminator's output on its synthetic samples moves toward real, while the discriminator's weights are held fixed.

In the inference phase, the generator network is presented with a random noise vector and a conditioning variable and produces a synthetic sample. The conditioning variable can be used to control specific characteristics of the generated sample, such as class labels in image classification tasks or specific text prompts in natural language generation tasks.
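At inference time, the only thing that changes between requests is the label array. A minimal sketch, assuming a generator that takes `[noise, labels]` like the Keras model built later in this post; the helper name `make_class_grid_inputs` is hypothetical. It builds inputs that ask for a fixed number of samples of every digit:

```python
import numpy as np

def make_class_grid_inputs(latent_dim, n_per_class, n_classes=10, seed=None):
    """Build (noise, label) inputs requesting n_per_class samples of every class."""
    rng = np.random.default_rng(seed)
    # one standard-normal latent vector per requested sample
    z = rng.standard_normal((n_per_class * n_classes, latent_dim))
    # labels 0,0,...,1,1,...,9: the conditioning variable selects the digit
    labels = np.repeat(np.arange(n_classes), n_per_class)
    return z, labels

z, labels = make_class_grid_inputs(latent_dim=100, n_per_class=10, seed=0)
print(z.shape, labels[:12])
# With a trained Keras generator (e.g. the g_model trained later in this post):
# images = g_model.predict([z, labels])
```

Keeping the noise fixed and changing only the labels is a simple way to see whether the generator has actually learned to respect the condition.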

cGANs have proven to be an effective approach for controlling the specific characteristics of the generated samples and have been used in a variety of tasks such as image-to-image translation, text-to-image synthesis, and video prediction.

Working of cGANs

Building Discriminator Model

In a conditional GAN (cGAN), the discriminator network is responsible for distinguishing between real and synthetic samples, conditioned on the input variable. Its input is a sample (real or synthetic) together with the conditioning variable, and its output is a scalar value, usually between 0 and 1, that represents the probability that the sample is real.

The structure of the discriminator network in a cGAN can be similar to that of a standard GAN. It can be a fully-connected neural network or a convolutional neural network, depending on the type of input data. The discriminator network takes as input both the real and synthetic samples, along with the conditioning variable, and produces a scalar value that represents the probability that the input sample is real.

The conditioning variable is concatenated with the input data before it is processed by the network. For example, if the input data is an image, the conditioning variable can be concatenated with the image as an additional channel. If the input data is text, the conditioning variable can be concatenated with the text as an additional feature.
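One simple way to add a class label as image channels is to one-hot encode it and broadcast that vector over every pixel. This is a fixed, non-learned alternative to the learned embedding the discriminator below uses, shown only to illustrate the concatenation; the helper `label_channels` is hypothetical:

```python
import numpy as np

def label_channels(labels, n_classes, height, width):
    """One-hot maps broadcast over every pixel: shape (n, height, width, n_classes)."""
    one_hot = np.eye(n_classes, dtype='float32')[labels]  # (n, n_classes)
    return np.broadcast_to(one_hot[:, None, None, :],
                           (len(labels), height, width, n_classes))

images = np.zeros((4, 28, 28, 1), dtype='float32')  # placeholder image batch
labels = np.array([3, 1, 4, 1])
# stack the image channel and the 10 label channels along the last axis
conditioned = np.concatenate([images, label_channels(labels, 10, 28, 28)], axis=-1)
print(conditioned.shape)
```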

The discriminator network is trained using the binary cross-entropy loss function, which measures the difference between the predicted probability and the true label (real or synthetic). The goal of the discriminator network is to correctly classify whether a given sample is real or synthetic, conditioned on the input variable.
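To make the loss concrete, here is a NumPy sketch of mean binary cross-entropy, roughly what Keras's `binary_crossentropy` computes for sigmoid outputs (Keras also clips predictions by a small epsilon). The prediction values are made up for illustration:

```python
import numpy as np

def binary_cross_entropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; y_true is 1 for real, 0 for synthetic."""
    y_pred = np.clip(y_pred, eps, 1 - eps)  # avoid log(0)
    return float(np.mean(-(y_true * np.log(y_pred)
                           + (1 - y_true) * np.log(1 - y_pred))))

y_true = np.array([1.0, 1.0, 0.0, 0.0])     # two real, two synthetic samples
confident = np.array([0.9, 0.8, 0.1, 0.2])  # mostly correct discriminator
fooled = np.array([0.4, 0.3, 0.7, 0.6])     # mostly wrong discriminator
print(binary_cross_entropy(y_true, confident))  # smaller loss
print(binary_cross_entropy(y_true, fooled))     # larger loss
```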

In summary, the cGAN discriminator takes a sample (real or synthetic) together with its conditioning variable and outputs a scalar value representing the probability that the sample is real.

from tensorflow.keras.layers import (Input, Embedding, Dense, Reshape, Concatenate,
                                     Conv2D, LeakyReLU, Flatten, Dropout)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model  # requires pydot and graphviz

def define_discriminator_model(in_shape=(28, 28, 1), n_classes=10):
    # label input
    in_label = Input(shape=(1,))
    # embedding for categorical input
    li = Embedding(n_classes, 50)(in_label)
    # scale up to image dimensions with linear activation
    n_nodes = in_shape[0] * in_shape[1]
    li = Dense(n_nodes)(li)
    # reshape to additional channel
    li = Reshape((in_shape[0], in_shape[1], 1))(li)
    # image input
    in_image = Input(shape=in_shape)
    # concat label as a channel
    merge = Concatenate()([in_image, li])
    # downsample to 14x14
    fe = Conv2D(128, (3, 3), strides=(2, 2), padding='same')(merge)
    fe = LeakyReLU(alpha=0.2)(fe)
    # downsample to 7x7
    fe = Conv2D(128, (3, 3), strides=(2, 2), padding='same')(fe)
    fe = LeakyReLU(alpha=0.2)(fe)
    # flatten feature maps
    fe = Flatten()(fe)
    # dropout
    fe = Dropout(0.4)(fe)
    # output: probability that the image is real
    out_layer = Dense(1, activation='sigmoid')(fe)
    # define model
    model = Model([in_image, in_label], out_layer)
    # compile model ('learning_rate' replaces the deprecated 'lr' argument)
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

discriminator = define_discriminator_model()
discriminator.summary()
plot_model(discriminator, to_file='discriminator_plot.png', show_shapes=True, show_layer_names=True)
Discriminator Plot Model

From the plot above, we can see that the discriminator takes two inputs: the class label and a 28x28 image. The class label is passed through an Embedding layer of size 50, which means each of the 10 MNIST classes (0 through 9) maps to its own 50-element vector that the discriminator learns during training. The embedding is then passed to a fully connected layer with a linear activation, whose 784 (28 × 28) outputs are reshaped into a 28x28 feature map and concatenated with the image as a second channel.
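The label path can be traced with plain NumPy. In this sketch, random arrays stand in for the weights the Embedding and Dense layers would learn, so only the shapes (not the values) mirror the Keras model. Keras's Embedding actually outputs shape (batch, 1, 50) for a (1,)-shaped input; the Reshape layer absorbs that extra axis.

```python
import numpy as np

rng = np.random.default_rng(0)
n_classes, embed_dim, height, width = 10, 50, 28, 28

# Embedding layer: a (10, 50) lookup table, one learnable row per digit class
embedding_table = rng.normal(size=(n_classes, embed_dim))
# Dense layer with linear activation: weights (50, 784) and bias (784,)
dense_w = rng.normal(size=(embed_dim, height * width))
dense_b = np.zeros(height * width)

labels = np.array([7, 2])                            # batch of two class labels
vectors = embedding_table[labels]                    # (2, 50): embedding lookup
projected = vectors @ dense_w + dense_b              # (2, 784): linear projection
label_map = projected.reshape(-1, height, width, 1)  # (2, 28, 28, 1): extra channel

images = np.zeros((2, height, width, 1))             # placeholder images
merged = np.concatenate([images, label_map], axis=-1)
print(vectors.shape, projected.shape, merged.shape)
```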

Building Generator Model

In a conditional GAN (cGAN), the generator network is responsible for producing synthetic samples that are similar to real samples from the target distribution and are consistent with the conditioning variable. The generator network takes as input both a random noise vector and a conditioning variable and produces a synthetic sample.

The structure of the generator network in a cGAN can be similar to that of a standard GAN. It can be a fully-connected neural network or a convolutional neural network, depending on the type of input data. The generator network takes as input both the random noise vector and the conditioning variable and produces a synthetic sample that is conditioned on the input variable.

The conditioning variable is combined with the random noise vector before it is processed by the network. In a fully connected generator, the conditioning variable (for example, a one-hot class vector) can simply be appended to the noise vector as additional features. In a convolutional generator such as the one below, it can instead be projected to a small feature map and concatenated with the reshaped noise as an additional channel.
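The fully connected variant, which is the formulation in the original cGAN paper, is a single concatenation. A minimal sketch; `conditioned_noise` is a hypothetical helper, and the convolutional generator in this post uses the embedding-plus-7×7-map approach instead:

```python
import numpy as np

def conditioned_noise(n_samples, latent_dim, labels, n_classes=10, seed=None):
    """Append a one-hot class vector to each noise vector: (n, latent_dim + n_classes)."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((n_samples, latent_dim))
    one_hot = np.eye(n_classes)[labels]
    return np.concatenate([z, one_hot], axis=1)

x = conditioned_noise(3, 100, np.array([0, 5, 9]), seed=1)
print(x.shape)
```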

The generator is also trained with the binary cross-entropy loss, but indirectly, through the discriminator: its synthetic samples are labelled as real (1), so minimizing the loss pushes the generator toward samples that the discriminator cannot distinguish from real ones. The discriminator's weights are not updated during this step.
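Labelling fake samples as real gives the generator the "non-saturating" loss suggested in Goodfellow et al. (2014), mean(-log D(G(z))), instead of minimizing mean(log(1 - D(G(z)))) from the minimax objective. The sketch below, with made-up discriminator outputs, compares the two and their gradients with respect to the discriminator's logit a (where D = sigmoid(a)). When D(G(z)) is near 0, as it often is early in training, the minimax form's gradient nearly vanishes:

```python
import numpy as np

# D(G(z)): the discriminator's probability that generated samples are real,
# typically close to 0 early in training when fakes are easy to spot
d_on_fake = np.array([0.01, 0.05, 0.10])

# minimax form: the generator minimizes mean(log(1 - D(G(z))))
minimax_loss = np.mean(np.log(1 - d_on_fake))
# inverted-label form: BCE with target 1 on fake samples, i.e. mean(-log D(G(z)))
inverted_label_loss = np.mean(-np.log(d_on_fake))

# per-sample gradients with respect to the logit a, where D = sigmoid(a):
#   d/da log(1 - sigmoid(a)) = -D        (vanishes as D -> 0)
#   d/da -log(sigmoid(a))    = -(1 - D)  (stays near -1 as D -> 0)
minimax_grad = -d_on_fake
inverted_grad = -(1 - d_on_fake)
print(minimax_loss, inverted_label_loss)
print(minimax_grad, inverted_grad)
```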

In summary, the cGAN generator takes a random noise vector and a conditioning variable as input and outputs a synthetic sample that resembles real data from the target distribution and is consistent with the conditioning variable.

from tensorflow.keras.layers import (Input, Embedding, Dense, Reshape, Concatenate,
                                     Conv2D, Conv2DTranspose, LeakyReLU)
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model  # requires pydot and graphviz

def define_generator(latent_dim, n_classes=10):
    # label input
    in_label = Input(shape=(1,))
    # embedding for categorical input
    li = Embedding(n_classes, 50)(in_label)
    # linear multiplication
    n_nodes = 7 * 7
    li = Dense(n_nodes)(li)
    # reshape to additional channel
    li = Reshape((7, 7, 1))(li)
    # image generator input
    in_lat = Input(shape=(latent_dim,))
    # foundation for 7x7 image
    n_nodes = 128 * 7 * 7
    gen = Dense(n_nodes)(in_lat)
    gen = LeakyReLU(alpha=0.2)(gen)
    gen = Reshape((7, 7, 128))(gen)
    # merge image gen and label input: 7x7x129
    merge = Concatenate()([gen, li])
    # upsample to 14x14
    gen = Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same')(merge)
    gen = LeakyReLU(alpha=0.2)(gen)
    # upsample to 28x28
    gen = Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same')(gen)
    gen = LeakyReLU(alpha=0.2)(gen)
    # output: one channel with tanh activation, values in [-1, 1]
    out_layer = Conv2D(1, (7, 7), activation='tanh', padding='same')(gen)
    # define model
    model = Model([in_lat, in_label], out_layer)
    return model

generator = define_generator(100)
plot_model(generator, to_file='generator_plot.png', show_shapes=True, show_layer_names=True)
Generator Plot Model

Compared with an unconditional generator, this generator takes the class label as a second input, which makes the point in the latent space conditional on the provided class label. As in the discriminator, the class label is passed through an embedding layer that maps it to a unique 50-element vector, followed by a fully connected layer with a linear activation whose output is then resized. In this case, the activations of the fully connected layer are resized into a single 7 × 7 feature map, matching the 7 × 7 feature maps that the unconditional generator builds from the latent vector. The new 7 × 7 feature map is added as one more channel to the existing 128, resulting in 129 feature maps that are then upsampled to 14 × 14 and 28 × 28 by the two Conv2DTranspose layers.
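The spatial sizes in both models follow from how Keras computes output sizes with `padding='same'`: a Conv2D layer produces ceil(input / stride), and a Conv2DTranspose layer (without `output_padding`) produces input × stride. The two helper functions below are hypothetical; they only redo that arithmetic to trace the shapes:

```python
import math

def conv_same(size, stride):
    # Conv2D with padding='same': output size = ceil(input size / stride)
    return math.ceil(size / stride)

def conv_transpose_same(size, stride):
    # Conv2DTranspose with padding='same' (no output_padding): output = input * stride
    return size * stride

# generator: 128 noise-derived feature maps + 1 label map at 7x7
channels = 128 + 1
g = conv_transpose_same(7, 2)  # 7 -> 14
g = conv_transpose_same(g, 2)  # 14 -> 28
# the final 7x7 Conv2D uses stride 1 and 'same' padding, so it keeps 28x28
g = conv_same(g, 1)
print('generator:', (7, 7, channels), '->', (g, g, 1))

# discriminator: 28x28x2 input (image + label map), two stride-2 Conv2D layers
d = conv_same(conv_same(28, 2), 2)  # 28 -> 14 -> 7
print('discriminator:', (28, 28, 2), '->', (d, d, 128), '-> Flatten:', d * d * 128)
```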

Before building the combined GAN model, let's look at the data preparation it requires.

Preprocessing

from numpy import expand_dims, ones, zeros
from numpy.random import randint, randn
from tensorflow.keras.datasets.mnist import load_data

def load_real_samples():
    # load dataset
    (trainX, trainy), (_, _) = load_data()
    # expand to 3d, e.g. add channels
    X = expand_dims(trainX, axis=-1)
    # convert from ints to floats
    X = X.astype('float32')
    # scale from [0,255] to [-1,1] to match the generator's tanh output
    X = (X - 127.5) / 127.5
    return [X, trainy]

# select real samples
def generate_real_samples(dataset, n_samples):
    # split into images and labels
    images, labels = dataset
    # choose random instances
    ix = randint(0, images.shape[0], n_samples)
    # select images and labels
    X, labels = images[ix], labels[ix]
    # generate class labels: 1 = real
    y = ones((n_samples, 1))
    return [X, labels], y

def generate_latent_points(latent_dim, n_samples, n_classes=10):
    # generate points in the latent space
    x_input = randn(latent_dim * n_samples)
    # reshape into a batch of inputs for the network
    z_input = x_input.reshape(n_samples, latent_dim)
    # generate random class labels
    labels = randint(0, n_classes, n_samples)
    return [z_input, labels]

def generate_fake_samples(generator, latent_dim, n_samples):
    # generate points in latent space
    z_input, labels_input = generate_latent_points(latent_dim, n_samples)
    # predict outputs (verbose=0 suppresses a progress bar on every call)
    images = generator.predict([z_input, labels_input], verbose=0)
    # create class labels: 0 = fake
    y = zeros((n_samples, 1))
    return [images, labels_input], y

In the snippet above, the pixel values are scaled from [0,255] to [-1,1] to match the range of the generator's tanh output. Real samples are selected at random together with their class labels and labelled 1 (real), while fake samples are generated by passing random latent points and random class labels through the generator and labelled 0 (fake).
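Because the generator outputs values in [-1, 1], its images usually need to be mapped back to [0, 1] before plotting. A small sketch of both directions; the helper names `to_tanh_range` and `to_display_range` are hypothetical:

```python
import numpy as np

def to_tanh_range(pixels):
    # map 8-bit pixel values in [0, 255] to [-1, 1], the output range of tanh
    return (pixels.astype('float32') - 127.5) / 127.5

def to_display_range(images):
    # map generator outputs in [-1, 1] back to [0, 1] for plotting with imshow
    return (images + 1.0) / 2.0

pixels = np.array([0, 51, 255], dtype=np.uint8)
scaled = to_tanh_range(pixels)
print(scaled)                    # approximately [-1.0, -0.6, 1.0]
print(to_display_range(scaled))  # approximately [0.0, 0.2, 1.0]
```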

Building the cGAN

from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

def define_gan(g_model, d_model):
    # make weights in the discriminator not trainable
    d_model.trainable = False
    # get noise and label inputs from generator model
    gen_noise, gen_label = g_model.input
    # get image output from the generator model
    gen_output = g_model.output
    # connect image output and label input from generator as inputs to discriminator
    gan_output = d_model([gen_output, gen_label])
    # define gan model as taking noise and label and outputting a classification
    model = Model([gen_noise, gen_label], gan_output)
    # compile model ('learning_rate' replaces the deprecated 'lr' argument)
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

The main difference from an unconditional DCGAN is that the combined model now takes two inputs, a noise vector and a class label. Both are fed to the generator, and the generated image is then passed to the discriminator together with the same label. In a DCGAN, the generator receives only noise and the discriminator receives only the image.

Training the cGAN

from numpy import ones

# uses define_discriminator_model, define_generator, define_gan, load_real_samples,
# generate_real_samples, generate_fake_samples and generate_latent_points from above
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=50, n_batch=128):
    bat_per_epo = int(dataset[0].shape[0] / n_batch)
    half_batch = int(n_batch / 2)
    # manually enumerate epochs
    for i in range(n_epochs):
        # enumerate batches over the training set
        for j in range(bat_per_epo):
            # get randomly selected 'real' samples
            [X_real, labels_real], y_real = generate_real_samples(dataset, half_batch)
            # update discriminator model weights
            d_loss1, _ = d_model.train_on_batch([X_real, labels_real], y_real)
            # generate 'fake' examples
            [X_fake, labels], y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            # update discriminator model weights
            d_loss2, _ = d_model.train_on_batch([X_fake, labels], y_fake)
            # prepare points in latent space as input for the generator
            [z_input, labels_input] = generate_latent_points(latent_dim, n_batch)
            # create inverted labels for the fake samples
            y_gan = ones((n_batch, 1))
            # update the generator via the discriminator's error
            g_loss = gan_model.train_on_batch([z_input, labels_input], y_gan)
            # summarize loss on this batch
            print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' %
                  (i+1, j+1, bat_per_epo, d_loss1, d_loss2, g_loss))
    # save the generator once training is complete
    g_model.save('cgan_generator.h5')

latent_dim = 100
# create the discriminator
d_model = define_discriminator_model()
# create the generator
g_model = define_generator(latent_dim)
# create the combined cGAN
gan_model = define_gan(g_model, d_model)
# load image data
dataset = load_real_samples()
# train model
train(g_model, d_model, gan_model, dataset, latent_dim)

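The cGAN is trained with a batch size of 128 for n_epochs epochs (50 by default in the train function above). In each batch, the discriminator is first updated on half a batch of real samples and then on half a batch of fake samples, after which the generator is updated through the combined cGAN model, with its fake samples labelled as real.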
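The training schedule can be worked out from the defaults in `train` and the 60,000-image MNIST training set. Because batches are drawn with `randint`, an "epoch" here means a fixed number of random batches rather than one pass over every image:

```python
n_train, n_batch, n_epochs = 60000, 128, 50  # MNIST training size and train() defaults

bat_per_epo = int(n_train / n_batch)  # 468 batches per epoch (60000 / 128 = 468.75)
half_batch = int(n_batch / 2)         # 64 real + 64 fake images per discriminator step
total_steps = bat_per_epo * n_epochs  # generator updates over the whole run

print(bat_per_epo, half_batch, total_steps)
```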

The following images were produced by the trained generator:

Results of cGAN

In conclusion, conditional GANs give more control over the generated data, making it possible to generate specific types of data, such as images of a certain object or text written in a specific style. They have been used in many areas, including image synthesis, video generation, and text generation.

Resources

https://machinelearningmastery.com/how-to-develop-a-conditional-generative-adversarial-network-from-scratch
