DCGANs for the Fashion-MNIST Dataset

Aftab Gazali
Jan 9



GANs (Generative Adversarial Networks) are a class of machine learning models used to generate new, synthetic data that resembles a training dataset. A GAN consists of two neural networks: a generator and a discriminator. The generator tries to create synthetic data that resembles the training data, while the discriminator tries to distinguish the synthetic data from the real training data. The two networks are trained together, and the generator improves over time as it learns to create data that can fool the discriminator. GANs have been used to generate a wide range of synthetic data, including images, audio, and text.

There are many different variations of generative adversarial networks (GANs), each with its own unique characteristics and applications. Some of the most common types of GANs include:

  1. Vanilla GANs: These are the simplest and most basic types of GANs. They consist of a generator and a discriminator, as described above.
  2. Conditional GANs: These GANs are able to generate synthetic data that is conditioned on some additional input. For example, a conditional GAN could be trained to generate images of a specific type of object (e.g. cats) when given a label indicating the desired object type as input.
  3. Deep Convolutional GANs (DCGANs): These GANs use deep convolutional neural networks as the generator and discriminator, which makes them well-suited for generating images.
  4. InfoGANs: These GANs are designed to disentangle the latent factors of variation in the training data and allow control over the generated data by manipulating these factors.
  5. Wasserstein GANs (WGANs): These GANs use the Wasserstein distance as a measure of the difference between the real data distribution and the synthetic data distribution, rather than the traditional GAN objective of minimizing the cross-entropy loss.
  6. CycleGANs: These GANs are used for image-to-image translation tasks, such as translating photos of horses into photos of zebras.
  7. StyleGANs: These GANs are able to generate highly realistic images, and are particularly well-suited for tasks such as generating synthetic faces.

Please refer to the links at the end of this article for an in-depth explanation of GANs.


Fashion-MNIST is a dataset of 60,000 28x28 grayscale images of 10 fashion categories, along with a test set of 10,000 images. It was created as a more challenging alternative to the MNIST dataset, which consists of 70,000 28x28 grayscale images of the 10 digits. Both datasets are widely used as benchmarks for machine learning algorithms and are often used as a starting point for developing and testing new models.

The fashion categories in the Fashion-MNIST dataset include:

  • T-shirt/top
  • Trouser
  • Pullover
  • Dress
  • Coat
  • Sandal
  • Shirt
  • Sneaker
  • Bag
  • Ankle boot
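import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
ds = tfds.load('fashion_mnist', split='train')
# Getting data out of the pipeline
dataiterator = ds.as_numpy_iterator()
fig, ax = plt.subplots(ncols=4, figsize=(20,20))
for idx in range(4):
    # Grab an image and label
    sample = next(dataiterator)
    # Plot the image using a specific subplot, titled with its integer label
    ax[idx].imshow(np.squeeze(sample['image']), cmap='gray')
    ax[idx].title.set_text(sample['label'])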

As the figure above shows, each sample in the dataset consists of an image and a label; the ten fashion categories are encoded as integer labels from 0 to 9.
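Because the labels are plain integers, a small lookup table makes plots easier to read. The sketch below assumes the standard Fashion-MNIST label order, which matches the category list above (0 = T-shirt/top through 9 = Ankle boot). `CLASS_NAMES` and `label_name` are hypothetical helpers, not part of TensorFlow Datasets.

```python
# Hypothetical helper: map Fashion-MNIST integer labels (0-9) to category names.
# Assumes the standard label order, which matches the category list in this article.
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def label_name(label: int) -> str:
    """Return the category name for an integer label in the range 0-9."""
    if not 0 <= label < len(CLASS_NAMES):
        raise ValueError(f"label must be in [0, 9], got {label}")
    return CLASS_NAMES[label]
```

For example, `label_name(sample['label'])` could be used as a subplot title in place of the raw number.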

Data Preprocessing

  1. Feature Scaling

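Feature scaling is an important data preprocessing step. Without it, features with larger numeric ranges can dominate those with smaller ranges regardless of their unit of measurement, and large raw input values can make gradient-based training slower and less stable. The scaler used here is min-max scaling, which transforms each feature to a given range, here [0, 1], via x' = (x - min) / (max - min). Because pixel values lie in [0, 255], this amounts to dividing every pixel by 255, which also matches the [0, 1] output range of the generator's final sigmoid activation.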
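A minimal NumPy sketch of min-max scaling, showing that for 8-bit pixels with a known range of [0, 255] it is the same as dividing by 255. The function name `min_max_scale` is a hypothetical helper for illustration.

```python
import numpy as np


def min_max_scale(x, x_min, x_max):
    """Min-max scaling: map values from [x_min, x_max] linearly onto [0, 1]."""
    x = np.asarray(x, dtype=np.float32)
    return (x - x_min) / (x_max - x_min)


# 8-bit grayscale pixels always lie in [0, 255],
# so min-max scaling reduces to dividing by 255.
pixels = np.array([0, 51, 255], dtype=np.uint8)
scaled = min_max_scale(pixels, 0, 255)
print(scaled)
```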

  2. Apply MCSBP (Map, Cache, Shuffle, Batch, Prefetch)

  1. Map: The “map” transformation applies a function to each element of an input dataset, creating a new dataset of transformed elements. For example, you might use the “map” transformation to apply a function that normalizes the values in each element of an input dataset.
  2. Cache: The “cache” transformation stores the elements of an input dataset in memory, allowing them to be reused in the future without the need to recreate them from the original data source. This can be useful for speeding up the input pipeline by avoiding expensive data preprocessing steps.
  3. Shuffle: The “shuffle” transformation randomly shuffles the elements of an input dataset. This is often used to improve the generalization performance of a machine learning model by preventing it from overfitting to the order of the input data.
  4. Batch: The “batch” transformation combines multiple elements of an input dataset into a single batch. This can be useful for improving the efficiency of the input pipeline by allowing the model to process multiple input examples at once.
  5. Prefetch: The “prefetch” transformation prepares upcoming elements of the dataset in the background while the model is working on the current ones, so they are ready when the model needs them. This improves input-pipeline performance by overlapping data preparation with the model's training steps.

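def image_scale(data):
    # Keep only the image and min-max scale its pixels from [0, 255] to [0, 1]
    image = data['image']
    return image / 255

ds = ds.map(image_scale)  # Map: scale every image
ds = ds.cache()           # Cache: keep the scaled images in memory after the first pass
ds = ds.shuffle(60000)    # Shuffle: the buffer covers the whole 60,000-image training set
ds = ds.batch(128)        # Batch: 128 images per training step
ds = ds.prefetch(64)      # Prefetch: prepare up to 64 batches ahead of training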
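A quick sanity check on the pipeline: 60,000 training images do not divide evenly into batches of 128. Because `batch()` is called here without `drop_remainder=True`, the final batch of each epoch is smaller. The plain-Python sketch below does the arithmetic, which is a good reason to read the batch size from the data inside a training step instead of assuming 128.

```python
import math

NUM_TRAIN_IMAGES = 60_000  # size of the Fashion-MNIST training split
BATCH_SIZE = 128           # value passed to ds.batch() above

# batch() without drop_remainder=True keeps the final, partial batch
num_batches = math.ceil(NUM_TRAIN_IMAGES / BATCH_SIZE)
full_batches = NUM_TRAIN_IMAGES // BATCH_SIZE
last_batch_size = NUM_TRAIN_IMAGES - full_batches * BATCH_SIZE

print(f"{num_batches} batches per epoch; the last one holds {last_batch_size} images")
# prints "469 batches per epoch; the last one holds 96 images"
```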

Our task is to train the GAN so that the generator learns to produce plausible images from a random latent vector; in this case, the generator must produce fashion images. Let's look at what the generator and discriminator do.

Working of GANs

GANs consist of two models that compete with each other: a generator and a discriminator. The generator takes a random latent vector and produces an image. The discriminator is shown both real images from the training dataset and images produced by the generator, and it learns to classify each image as real or fake. The two networks are trained together: the generator tries to produce synthetic data that fools the discriminator, while the discriminator tries to correctly identify whether each piece of data is real or fake.

As training progresses, the generator gets better at producing synthetic data that resembles the real data, while the discriminator gets better at telling synthetic data apart from real data. Each network is updated based on the performance of the other: the generator is updated to produce more realistic samples, and the discriminator is updated to better separate real from fake. Ideally, training approaches an equilibrium in which the generator's samples are indistinguishable from real data and the discriminator can do no better than guessing. In practice, training is usually stopped after a fixed number of epochs, or once the generated samples look good enough.
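For reference, the original GAN paper (Goodfellow et al., 2014) formalizes this competition as a two-player minimax game. Here D(x) is the discriminator's estimated probability that x is real, G(z) is the image the generator produces from latent noise z, and p_g is the distribution of generated images. The second part shows why the ideal end point is a discriminator that outputs 1/2 everywhere. Note that the implementation in this article flips the label convention (0 = real, 1 = fake), so its discriminator output plays the role of 1 - D(x); the idea is the same.

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]

% For a fixed generator G, the optimal discriminator is
D^{*}_{G}(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}
% so if p_g = p_data, then D^{*}_{G}(x) = 1/2 for every x.
```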

After training, the generator network can be used to generate new synthetic data by inputting random noise and generating the corresponding synthetic data. The synthetic data produced by the GAN is intended to be similar to the real-world data that the GAN was trained on.

Building the Neural Network

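from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Reshape, LeakyReLU, Dropout, UpSampling2D

def build_generator():
    model = Sequential()
    # Project the 128-value latent vector to 7*7*128 values and reshape into a 7x7x128 feature map
    # (the LeakyReLU slope of 0.2 used throughout is an assumed value)
    model.add(Dense(7*7*128, input_dim=128))
    model.add(LeakyReLU(0.2))
    model.add(Reshape((7, 7, 128)))
    # Upsampling block 1: 7x7 -> 14x14
    model.add(UpSampling2D())
    model.add(Conv2D(128, 5, padding='same'))
    model.add(LeakyReLU(0.2))
    # Upsampling block 2: 14x14 -> 28x28
    model.add(UpSampling2D())
    model.add(Conv2D(128, 5, padding='same'))
    model.add(LeakyReLU(0.2))
    # Two more convolutional blocks at 28x28
    model.add(Conv2D(128, 4, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(128, 4, padding='same'))
    model.add(LeakyReLU(0.2))
    # Collapse to one channel; the sigmoid keeps pixel values in [0, 1] like the scaled data
    model.add(Conv2D(1, 4, padding='same', activation='sigmoid'))
    return model

generator = build_generator()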
Generator Model Summary

We start from a random latent vector of 128 values. The Dense layer maps it to 7 x 7 x 128 = 6,272 values, which are reshaped into a 7x7x128 feature map and then upsampled and convolved until they reach the shape the discriminator expects: (28, 28, 1), a single-channel 28x28 image. This image is then passed to the discriminator, which decides whether it is real or fake.
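The shape arithmetic can be traced by hand. This sketch assumes the generator reaches 28x28 through a Reshape to 7x7x128 followed by two UpSampling2D layers (both layers appear in the imports above) and that every Conv2D uses padding='same', so convolutions leave height and width unchanged. It only tracks shapes and does not build a model.

```python
# Shape trace for the generator: pure arithmetic, no TensorFlow needed.
# Assumption: Dense -> Reshape((7, 7, 128)) -> two UpSampling2D layers,
# with every Conv2D using padding='same' (height and width unchanged).
dense_units = 7 * 7 * 128            # Dense output length: 6,272 values
height, width, channels = 7, 7, 128  # feature map after Reshape((7, 7, 128))

for _ in range(2):                   # each UpSampling2D (default size=(2, 2)) doubles H and W
    height, width = height * 2, width * 2

channels = 1                         # final Conv2D(1, 4, padding='same', activation='sigmoid')
output_shape = (height, width, channels)
print(dense_units, output_shape)     # prints "6272 (28, 28, 1)"
```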

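def build_discriminator():
    model = Sequential()
    # Four conv blocks with the default 'valid' padding: 28x28 -> 24 -> 20 -> 16 -> 12
    # (LeakyReLU slope 0.2 and Dropout rate 0.4 are assumed values)
    model.add(Conv2D(32, 5, input_shape=(28, 28, 1)))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.4))
    model.add(Conv2D(64, 5))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.4))
    model.add(Conv2D(128, 5))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.4))
    model.add(Conv2D(256, 5))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.4))
    # Flatten the 12x12x256 feature map so Dense outputs one probability per image
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    return model

discriminator = build_discriminator()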
Discriminator Model Summary
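The discriminator's Conv2D layers use the Keras default of padding='valid' with stride 1, so each 5x5 kernel trims 4 pixels from the height and width. The sketch below traces the shapes by hand, assuming a Flatten layer sits between the last convolution and the Dense output; without it, Dense would act on every spatial position separately instead of producing a single score per image.

```python
def valid_conv_size(size: int, kernel: int, stride: int = 1) -> int:
    """Output length of a convolution with padding='valid' (the Keras default)."""
    return (size - kernel) // stride + 1


size = 28
filters_per_block = [32, 64, 128, 256]  # filter counts of the four Conv2D layers
for filters in filters_per_block:
    size = valid_conv_size(size, kernel=5)
    print(f"Conv2D({filters}, 5): {size}x{size}x{filters}")

flattened = size * size * filters_per_block[-1]
print("Flatten:", flattened)  # 12 * 12 * 256 = 36864
```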

Next, we apply a few well-known GAN training tricks ("GAN hacks"): choosing suitable losses and optimizers, and adding noise to the discriminator's labels.

  1. Losses guide the training process. Both networks are trained with binary cross-entropy: the discriminator learns to classify images as real or fake, and the generator learns to produce images that the discriminator classifies as real.
  2. Adding noise can help the GAN learn more robustly. Here, small random noise is added to the discriminator's target labels, so the discriminator is never pushed toward perfectly confident predictions, which helps keep it from becoming overconfident. Noise can also be added to the training images or to the generator's inputs to encourage more diverse samples.
  3. Together, these choices can help stabilize training. Balancing the competing objectives of the generator and discriminator keeps either network from overpowering the other, and noise can help prevent the two networks from getting stuck in poor solutions.

Overall, these tricks can make GAN training more stable and robust and help the generator produce higher-quality samples. Note also that the discriminator's learning rate below (0.00001) is ten times smaller than the generator's (0.0001); a plausible reason is to slow the discriminator down so that it does not overpower the generator early in training.
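The label-noise trick used in the training step below can be illustrated on its own. This NumPy sketch follows the article's convention of 0 for real images and 1 for fake ones, with a noise scale of 0.15: real targets move up into [0, 0.15) and fake targets move down into (0.85, 1]. The helper names `noisy_labels` and `binary_cross_entropy` are hypothetical, and the loss is a plain reimplementation, not the Keras class.

```python
import numpy as np


def noisy_labels(n_real, n_fake, scale=0.15, rng=None):
    """Hypothetical helper: build noisy discriminator targets (0 = real, 1 = fake)."""
    rng = np.random.default_rng() if rng is None else rng
    # Clean targets: zeros for real images, ones for generated ones
    real = np.zeros(n_real)
    fake = np.ones(n_fake)
    # Real targets move up by less than `scale`; fake targets move down by at most `scale`
    real += scale * rng.uniform(size=n_real)
    fake -= scale * rng.uniform(size=n_fake)
    return np.concatenate([real, fake])


def binary_cross_entropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy, clipping predictions away from 0 and 1."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))))


labels = noisy_labels(4, 4, rng=np.random.default_rng(0))
print(labels.round(3))
# A perfectly confident discriminator is still penalized against the noisy targets
confident = np.array([0.0] * 4 + [1.0] * 4)
print(binary_cross_entropy(labels, confident))
```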

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
g_opt = Adam(learning_rate=0.0001)
d_opt = Adam(learning_rate=0.00001)
g_loss = BinaryCrossentropy()
d_loss = BinaryCrossentropy()
# Importing the base model class to subclass our training step
import tensorflow as tf
from tensorflow.keras.models import Model

class FashionMnistGAN(Model):
    def __init__(self, generator, discriminator, *args, **kwargs):
        # Pass through args and kwargs to base class
        super().__init__(*args, **kwargs)

        # Create attributes for gen and disc
        self.generator = generator
        self.discriminator = discriminator

    def compile(self, g_opt, d_opt, g_loss, d_loss, *args, **kwargs):
        # Compile with base class
        super().compile(*args, **kwargs)

        # Create attributes for losses and optimizers
        self.g_opt = g_opt
        self.d_opt = d_opt
        self.g_loss = g_loss
        self.d_loss = d_loss

    def train_step(self, batch):
        # Get the data; read the batch size from it, since the last batch of an epoch is smaller
        real_images = batch
        batch_size = tf.shape(real_images)[0]
        # Latent vectors of shape (batch_size, 128) match the generator's Dense(input_dim=128)
        fake_images = self.generator(tf.random.normal((batch_size, 128)), training=False)

        # Train the discriminator
        with tf.GradientTape() as d_tape:
            # Pass the real and fake images to the discriminator model
            yhat_real = self.discriminator(real_images, training=True)
            yhat_fake = self.discriminator(fake_images, training=True)
            yhat_realfake = tf.concat([yhat_real, yhat_fake], axis=0)

            # Create labels: this implementation uses 0 for real images and 1 for fakes
            y_realfake = tf.concat([tf.zeros_like(yhat_real), tf.ones_like(yhat_fake)], axis=0)

            # Add noise to the labels: real targets move up into [0, 0.15), fake targets down into (0.85, 1]
            noise_real = 0.15*tf.random.uniform(tf.shape(yhat_real))
            noise_fake = -0.15*tf.random.uniform(tf.shape(yhat_fake))
            y_realfake += tf.concat([noise_real, noise_fake], axis=0)

            # Calculate loss - binary cross-entropy
            total_d_loss = self.d_loss(y_realfake, yhat_realfake)

        # Apply backpropagation - only the discriminator's weights are updated here
        dgrad = d_tape.gradient(total_d_loss, self.discriminator.trainable_variables)
        self.d_opt.apply_gradients(zip(dgrad, self.discriminator.trainable_variables))

        # Train the generator
        with tf.GradientTape() as g_tape:
            # Generate some new images
            gen_images = self.generator(tf.random.normal((batch_size, 128)), training=True)

            # Create the predicted labels
            predicted_labels = self.discriminator(gen_images, training=False)

            # Calculate loss - the generator wants its images labeled "real" (0) by the discriminator
            total_g_loss = self.g_loss(tf.zeros_like(predicted_labels), predicted_labels)

        # Apply backprop - only the generator's weights are updated here
        ggrad = g_tape.gradient(total_g_loss, self.generator.trainable_variables)
        self.g_opt.apply_gradients(zip(ggrad, self.generator.trainable_variables))

        return {"d_loss": total_d_loss, "g_loss": total_g_loss}

# Create instance of subclassed model
fashgan = FashionMnistGAN(generator, discriminator)
# Compile the model
fashgan.compile(g_opt, d_opt, g_loss, d_loss)
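Why are the generator's targets all zeros? Under this article's convention (0 = real), the generator's loss is binary cross-entropy against label 0, which reduces to -log(1 - p), where p is the discriminator's output on a generated image. It is small when the discriminator calls the image real (p near 0) and grows as p approaches 1. A plain-Python sketch, with `generator_bce` as a hypothetical helper name:

```python
import math


def generator_bce(p_fake: float) -> float:
    """BCE with target 0 ('real' in this article's convention): -log(1 - p)."""
    return -math.log(1.0 - p_fake)


for p in (0.05, 0.5, 0.95):
    print(f"D(G(z)) = {p:.2f} -> generator loss {generator_bce(p):.3f}")
```

The loss rises steeply as the discriminator becomes sure an image is fake, which is what pushes the generator toward more realistic images.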
import os
import tensorflow as tf
from tensorflow.keras.preprocessing.image import array_to_img
from tensorflow.keras.callbacks import Callback

class ModelMonitor(Callback):
    def __init__(self, num_img=3, latent_dim=128):
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim
        # Make sure the output folder exists before the first image is saved
        os.makedirs('images', exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        # Sample latent vectors from the same normal distribution used during training
        random_latent_vectors = tf.random.normal((self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        # Sigmoid outputs lie in [0, 1]; scale them to [0, 255] pixel values
        generated_images *= 255
        for i in range(self.num_img):
            img = array_to_img(generated_images[i])
            img.save(os.path.join('images', f'generated_img_{epoch}_{i}.png'))

Training a GAN can take a long time for several reasons:

  1. GANs are complex models with many parameters, so they need a lot of data and computational resources to train effectively. The more complex the model, the more data and computation it needs to learn accurately.
  2. GANs are trained using an adversarial process, in which the generator and discriminator are competing against each other. This process can be slow, as the generator and discriminator must be trained alternately and the training process can be sensitive to the balance between them.

We trained our model for 1,300 epochs, which took nearly 22 hours, but we recommend giving the model at least 2,000 epochs to fully train. Because of the adversarial nature of GANs, training is very time-consuming and can sometimes take days.
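A rough estimate of what the recommended 2,000 epochs would cost, based only on the timing reported above (about 22 hours for 1,300 epochs) and assuming the time per epoch stays constant on the same hardware:

```python
HOURS_FOR_1300 = 22      # approximate training time reported above
EPOCHS_DONE = 1300
EPOCHS_RECOMMENDED = 2000

# Assumes every epoch takes the same time
seconds_per_epoch = HOURS_FOR_1300 * 3600 / EPOCHS_DONE
hours_for_2000 = EPOCHS_RECOMMENDED * seconds_per_epoch / 3600
print(f"~{seconds_per_epoch:.0f} s per epoch, ~{hours_for_2000:.1f} h for 2000 epochs")
# prints "~61 s per epoch, ~33.8 h for 2000 epochs"
```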

hist = fashgan.fit(ds, epochs=1300, callbacks=[ModelMonitor()])

Testing the Generator


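# Generate 12 images from random latent vectors (same normal distribution as in training)
imgs = generator.predict(tf.random.normal((12, 128)))
fig, ax = plt.subplots(nrows=3, ncols=4, figsize=(10,10))
for r in range(3):
    for c in range(4):
        # Show image r*4 + c in grid cell (r, c)
        ax[r][c].imshow(imgs[r*4 + c].squeeze(), cmap='gray')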
Loading the trained generator and testing it by passing random latent vectors

Generated Images

As the samples above show, the generator now produces images that resemble Fashion-MNIST clothing items, even though none of them are real. In other words, the generator has learned to turn random latent vectors into plausible fashion images of the kind it was trained to pass off as real to the discriminator.

References

  1. https://developers.google.com/machine-learning/gan
  2. https://neptune.ai/blog/gan-loss-functions
  3. https://machinelearningmastery.com/how-to-code-generative-adversarial-network-hacks/