Gender Swap and CycleGAN in TensorFlow 2
Learn to use a Generative Adversarial Network to build the most popular Snapchat lens.
We are a team of experts to help your business adopt AI solutions or build new AI products. Contact us at info@imaginaire.ai or visit our website https://www.imaginaire.ai
The full source code for this CycleGAN is available here.
Background
Recently, the gender swap lens from Snapchat became very popular on the internet. There has been a lot of buzz about Generative Adversarial Networks since 2016, but this is the first time that ordinary people get to experience the power of GANs firsthand. What’s even more remarkable about this lens is its real-time performance, which makes it feel just like looking into a magic mirror. Although we can’t know the exact algorithm behind this viral lens, it’s most likely a CycleGAN, which was introduced in 2017 by Jun-Yan, Taesung, Phillip and Alexei in this paper. In this article, I’m going to show you how to implement a gender swap effect with TensorFlow 2.0, just like Snapchat does.
Generative Adversarial Network
First of all, let’s quickly go over the basics of Generative Adversarial Networks (GANs) for readers who are not familiar with them. In some scenarios, we want to generate an image that belongs to a particular domain. For example, we’d like the computer to draw a random interior design photo. To do so, we need a mathematical representation of the interior design domain. Assume there’s a function $F$ and a random input number $x$. We want $y = F(x)$ to always be very close to our target domain $Y$. However, this target domain lives in such a high-dimensional space that no human being can write down explicit rules to define it. A GAN is one such network: by playing a minimax game between two AI agents, it can eventually find an approximate representation $F$ of our target domain $Y$.
So how does a GAN accomplish this? The trick is to break the problem down into two parts: 1. a generator that keeps making new images out of some random seed, and 2. a discriminator that gives the generator feedback about how good each generated image is. The generator is like a young artist who has no idea how to paint but wants to forge a masterpiece, and the discriminator is a judge who can tell what’s wrong with a new painting. The judge doesn’t need to know how to paint himself. As long as he’s good at telling a good painting from a bad one, our young painter can certainly benefit from his feedback. So we use Deep Learning to build a good judge and, at the same time, use it to train a good painter.
To train a good judge, we feed both authentic images and generated images to the discriminator. Since we know beforehand which is authentic and which is fake, the discriminator can update its weights by comparing its decisions with the truth. The generator, in turn, looks at the discriminator’s decisions. If the discriminator is more inclined to accept a fake image as real, the generator is heading in the right direction, and vice versa. The tricky part is that we can’t train a judge without a painter, or a painter without a judge. They learn from each other and try to beat each other in this minimax game. If we train the discriminator too much without training the generator, the discriminator becomes too dominant, and the generator never gets a chance to improve because every move is a losing move.
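For readers who like to see the game written down, the textbook form of this minimax objective (from the original GAN formulation by Goodfellow et al.) is shown below, using the symbols from above: $F$ is the generator (the painter), $D$ is the discriminator (the judge) that outputs the probability that its input is real, $x$ is the random seed, and $y$ is a real image from the target domain $Y$. The discriminator tries to maximize the expression by scoring real images high and generated ones low, while the generator tries to minimize it. CycleGAN replaces these log terms with a least-squares (MSE) loss, so treat this as the general idea rather than the exact loss used later.

$$
\min_F \max_D \; \mathbb{E}_{y \sim p_{\text{data}}(y)}\big[\log D(y)\big] + \mathbb{E}_{x \sim p(x)}\big[\log\big(1 - D(F(x))\big)\big]
$$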
Eventually, both the generator and the discriminator will be really good at their jobs, and we can take the generator out to perform the painting task on its own. But what on earth does a GAN have to do with gender swap? In fact, generating a face of a different gender is just like generating an image from a random seed. The only difference is that we now take the original face image as input instead of the random seed. Let’s say we want to convert a male face to a female face. We are looking for a function $F$ that takes a male face $x$ and outputs an image $y = F(x)$ that’s very close to the real female version $y_0$ of that face.
The GAN approach sounds clear, but we haven’t discussed one crucial caveat. To train the discriminator, we need both real and fake images. We get the fake image $y$ from the generator, but where do we get the real image for that specific male face? We can’t just use a random female face, because we want the swapped face to keep the traits of the original person, and a face from a different person would ruin that. Paired training data is also really hard to get. You could go out and find real twin brothers and sisters and photograph them, or you could ask a professional makeup artist to ‘turn’ a man into a woman. Both are very expensive. So is there a way for the model to learn the most important facial differences between men and women from an unpaired dataset?
CycleGAN
Fortunately, researchers have found ways to use unpaired training data. One of the most famous models is CycleGAN. The main idea behind CycleGAN is that, instead of using paired data to train the discriminator, we form a cycle in which two generators work together to convert an image back and forth. More specifically, generator A2B first converts an image from domain A to domain B, and then generator B2A takes that output and converts it back from domain B to domain A. We then require the second image (the reconstructed image) to look as close as possible to the original input. For example, if generator A2B converts a horse into a zebra, generator B2A should convert that zebra back into a horse, and the reconstructed horse should look identical to the original one. This way, the generators learn not to make arbitrary changes, only the critical differences between the two domains; otherwise, they probably wouldn’t be able to convert the image back. With this goal in place, we can now use unpaired images as training data.
In practice, we need two cycles. Since we are training generator A2B and generator B2A together, we have to make sure both generators improve over time; otherwise, they will still struggle to reconstruct a good image. Moreover, as discussed above, improving a generator means we need to improve its discriminator at the same time. In the cycle A2B2A (A -> B -> A), discriminator B decides whether the intermediate image produced by generator A2B really belongs to domain B, so discriminator B gets trained. Likewise, in the cycle B2A2B (B -> A -> B), discriminator A decides whether the output of generator B2A belongs to domain A, so discriminator A is trained as well. If both discriminators are well trained, generators A2B and B2A can improve too!
There’s another great article on CycleGAN here for further reading. Now that you have the main idea of this network, let’s dive into some details.
Optimizer
As recommended here, Adam is the optimizer of choice for GAN training. This may be due to the delicate nature of GAN training: Adam adapts well to a complex gradient landscape. The learning rate stays at 0.0002 for the first 100 epochs and then decays linearly to 0 over the next 100 epochs. In the scheduler, total_batches is the number of mini-batches in each epoch, because our learning rate scheduler treats each mini-batch as one step.
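The scheduler code isn’t reproduced in this section, so here is a minimal sketch of the same schedule as a Keras LearningRateSchedule. The class name LinearDecay is hypothetical, and so is the beta_1 = 0.5 setting for Adam (it follows the official CycleGAN code, but treat it as an assumption). Because the optimizer counts steps in mini-batches, epochs are converted to steps by multiplying by total_batches.

```python
import tensorflow as tf


class LinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Constant learning rate, then a linear decay to zero.

    Steps are mini-batches, so epoch counts are multiplied by total_batches.
    """

    def __init__(self, initial_lr, total_batches, constant_epochs=100, decay_epochs=100):
        super().__init__()
        self.initial_lr = initial_lr
        self.total_batches = total_batches
        self.constant_epochs = constant_epochs
        self.decay_epochs = decay_epochs
        self.decay_start = constant_epochs * total_batches
        self.decay_steps = decay_epochs * total_batches

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        # Fraction of the decay phase already completed, clipped to [0, 1].
        progress = tf.clip_by_value(
            (step - self.decay_start) / float(self.decay_steps), 0.0, 1.0)
        return self.initial_lr * (1.0 - progress)

    def get_config(self):
        return {"initial_lr": self.initial_lr, "total_batches": self.total_batches,
                "constant_epochs": self.constant_epochs, "decay_epochs": self.decay_epochs}


total_batches = 260  # mini-batches per epoch (example value)
schedule = LinearDecay(2e-4, total_batches)
# beta_1=0.5 is an assumption borrowed from the official CycleGAN code.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=schedule, beta_1=0.5)
```

With 260 mini-batches per epoch, the rate stays at 0.0002 until step 26,000, reaches 0.0001 at step 39,000, and hits 0 at step 52,000.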
Generator
Network Structure
CycleGAN uses a fairly standard generator structure. It first encodes the input image into a feature map by applying 2D convolutions, which extract both local and global feature information.
Then, six or nine ResNet blocks transform the features from the encoder into features of the target domain. The skip connection in each ResNet block gives gradients a shortcut back to earlier layers, which keeps a deep stack of layers trainable. If you are not familiar with ResNet, please refer to this paper.
Finally, a few deconvolution (transposed convolution) layers act as a decoder. The decoder upsamples the target-domain features into an actual image in the target domain.
Unlike VGG and Inception networks, which favor stacks of small kernels, it’s recommended that a GAN generator use a larger convolution kernel such as 7×7 so that it can pick up broader information instead of focusing only on details. This makes sense because when we reconstruct an image, not only the details matter but also the overall pattern. Reflection padding is also used here to improve the quality around the image border.
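Below is a compact Keras sketch of this encoder, ResNet-block, decoder layout. It assumes the 9-block configuration and 64/128/256 filter widths described in the CycleGAN paper, uses 7×7 kernels in the first and last layers, and applies instance normalization, written here as a small custom layer to keep the example self-contained. ReflectionPad2D, InstanceNorm, residual_block and build_generator are hypothetical names, not the code from the author’s repo.

```python
import tensorflow as tf


class ReflectionPad2D(tf.keras.layers.Layer):
    """Pads height and width by mirroring the pixels next to the border."""

    def __init__(self, pad, **kwargs):
        super().__init__(**kwargs)
        self.pad = pad

    def call(self, x):
        p = self.pad
        return tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], mode="REFLECT")


class InstanceNorm(tf.keras.layers.Layer):
    """Normalizes each channel of each image over its spatial dimensions."""

    def build(self, input_shape):
        channels = input_shape[-1]
        self.gamma = self.add_weight(name="gamma", shape=(channels,), initializer="ones")
        self.beta = self.add_weight(name="beta", shape=(channels,), initializer="zeros")

    def call(self, x):
        mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
        return self.gamma * (x - mean) * tf.math.rsqrt(var + 1e-5) + self.beta


def norm_relu(x):
    return tf.keras.layers.ReLU()(InstanceNorm()(x))


def residual_block(x, filters):
    """Two 3x3 convs plus a skip connection that adds the block's input back."""
    y = norm_relu(tf.keras.layers.Conv2D(filters, 3)(ReflectionPad2D(1)(x)))
    y = InstanceNorm()(tf.keras.layers.Conv2D(filters, 3)(ReflectionPad2D(1)(y)))
    return tf.keras.layers.Add()([x, y])


def build_generator(size=256, n_res=9):
    inp = tf.keras.Input(shape=(size, size, 3))
    # Encoder: a 7x7 conv with reflection padding, then two stride-2 downsampling convs.
    x = norm_relu(tf.keras.layers.Conv2D(64, 7)(ReflectionPad2D(3)(inp)))
    for filters in (128, 256):
        x = norm_relu(tf.keras.layers.Conv2D(filters, 3, strides=2, padding="same")(x))
    # Transformer: ResNet blocks map encoder features to target-domain features.
    for _ in range(n_res):
        x = residual_block(x, 256)
    # Decoder: transposed convs upsample back to the input resolution.
    for filters in (128, 64):
        x = norm_relu(tf.keras.layers.Conv2DTranspose(filters, 3, strides=2, padding="same")(x))
    out = tf.keras.layers.Conv2D(3, 7, activation="tanh")(ReflectionPad2D(3)(x))
    return tf.keras.Model(inp, out, name="generator")
```

The final tanh keeps pixel values in [-1, 1], so input images are assumed to be scaled to the same range.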
Loss Functions
There are three types of loss we care about here:
- To calculate the GAN loss, we measure the L2 distance (MSE) between the discriminator’s decision on the generated image and the “real” label (a matrix of ones).
- To calculate the cyclic loss, we measure the L1 distance (MAE) between the image reconstructed by the cycle and the original input image.
- To calculate the identity loss, we measure the L1 distance (MAE) between the identity image (a generator’s output when it is fed an image already in its target domain) and that input image.
GAN loss is the typical loss used in GANs, so I won’t discuss it much here. The interesting parts are the cyclic loss and the identity loss. The cyclic loss measures how good the reconstructed image is, which helps both generators capture the essential style difference between the two domains. The identity loss is optional, but it discourages the generators from making unnecessary changes. The idea is that applying generator A2B to a real B image shouldn’t change anything, because that image is already the desired outcome. According to the authors, this mitigates some odd issues such as background color changes.
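The three losses can be sketched as small helper functions. calc_gan_loss is the name the article uses later, but its body here is a guess: it assumes the least-squares form, where the discriminator’s patch decisions are compared with a matrix of ones (real) or zeros (fake). calc_cycle_loss and calc_identity_loss are hypothetical names.

```python
import tensorflow as tf


def calc_gan_loss(decision, target_is_real):
    """Least-squares GAN loss: push every patch decision toward 1 (real) or 0 (fake)."""
    target = tf.ones_like(decision) if target_is_real else tf.zeros_like(decision)
    return tf.reduce_mean(tf.square(target - decision))


def calc_cycle_loss(real, reconstructed):
    """L1 distance between an input image and its A -> B -> A (or B -> A -> B) reconstruction."""
    return tf.reduce_mean(tf.abs(real - reconstructed))


def calc_identity_loss(real, same):
    """L1 distance between a real target-domain image and the generator's output on it."""
    return tf.reduce_mean(tf.abs(real - same))
```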
To combine all the losses, we also assign each loss a weight that indicates its importance. In the paper, the authors propose two lambda parameters, which weight the cycle loss by 10 and the identity loss by 5. Note the use of GradientTape here: it records the forward pass so that we can compute the gradients and apply gradient descent to both generators together. Here, real_a is a real image from domain A, real_b is a real image from domain B, fake_a2b is the image generated from domain A to domain B, and fake_b2a is the image generated from domain B to domain A.
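The paragraph above can be sketched as a single tf.function training step for the two generators. The variable names real_a, real_b, fake_a2b and fake_b2a mirror the text; the one-layer models, the 32×32 input size and the function name train_generators are hypothetical stand-ins so the example runs on its own. Only the generators’ variables are updated here; the discriminators are assumed to be trained in a separate step.

```python
import tensorflow as tf

LAMBDA_CYCLE = 10.0    # weight of the cycle loss
LAMBDA_IDENTITY = 5.0  # weight of the identity loss


def make_generator():
    # Hypothetical stand-in for the ResNet generator: one conv with tanh output.
    return tf.keras.Sequential([tf.keras.Input(shape=(32, 32, 3)),
                                tf.keras.layers.Conv2D(3, 3, padding="same", activation="tanh")])


def make_discriminator():
    # Hypothetical stand-in for the PatchGAN discriminator: a grid of patch scores.
    return tf.keras.Sequential([tf.keras.Input(shape=(32, 32, 3)),
                                tf.keras.layers.Conv2D(1, 4, strides=4, padding="same")])


gen_a2b, gen_b2a = make_generator(), make_generator()
disc_a, disc_b = make_discriminator(), make_discriminator()
optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)


def l2(target, pred):
    return tf.reduce_mean(tf.square(target - pred))


def l1(a, b):
    return tf.reduce_mean(tf.abs(a - b))


@tf.function
def train_generators(real_a, real_b):
    with tf.GradientTape() as tape:
        fake_a2b = gen_a2b(real_a, training=True)
        fake_b2a = gen_b2a(real_b, training=True)
        # Adversarial terms: each generator tries to fool the other domain's discriminator.
        d_b = disc_b(fake_a2b, training=True)
        d_a = disc_a(fake_b2a, training=True)
        gan_loss = l2(tf.ones_like(d_b), d_b) + l2(tf.ones_like(d_a), d_a)
        # Cycle terms: A -> B -> A and B -> A -> B should reconstruct the inputs.
        cycle_loss = (l1(real_a, gen_b2a(fake_a2b, training=True)) +
                      l1(real_b, gen_a2b(fake_b2a, training=True)))
        # Identity terms: a generator fed its own target domain should change nothing.
        identity_loss = (l1(real_b, gen_a2b(real_b, training=True)) +
                         l1(real_a, gen_b2a(real_a, training=True)))
        total = gan_loss + LAMBDA_CYCLE * cycle_loss + LAMBDA_IDENTITY * identity_loss
    variables = gen_a2b.trainable_variables + gen_b2a.trainable_variables
    grads = tape.gradient(total, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return total, fake_a2b, fake_b2a
```

Because one tape watches both generators, a single gradient computation updates A2B and B2A together, which is what lets the cycle loss shape both of them.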
Discriminator
Network Structure
Similar to other GANs, the discriminator consists of a few 2D convolution layers that extract features from its input image. However, to help the generator produce high-resolution detail, CycleGAN uses a technique called PatchGAN, which outputs a fine-grained decision matrix instead of a single decision value. Each value in this 32×32 decision matrix maps to a patch of the input image and indicates how real that patch looks.
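A minimal Keras sketch of such a PatchGAN discriminator follows. The 4×4 kernels and 64/128/256/512 filter widths follow the paper’s 70×70 PatchGAN; the "same" padding is an assumption chosen so that a 256×256 input produces the 32×32 decision matrix mentioned above (implementations that pad differently end up with a slightly smaller grid). Instance normalization is left out for brevity, and build_patch_discriminator is a hypothetical name.

```python
import tensorflow as tf


def build_patch_discriminator(size=256):
    inp = tf.keras.Input(shape=(size, size, 3))
    x = inp
    # Three stride-2 convs shrink the feature map: 256 -> 128 -> 64 -> 32.
    for filters in (64, 128, 256):
        x = tf.keras.layers.Conv2D(filters, 4, strides=2, padding="same")(x)
        x = tf.keras.layers.LeakyReLU(0.2)(x)
    x = tf.keras.layers.Conv2D(512, 4, strides=1, padding="same")(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    # A single output channel: one real/fake score per patch.
    out = tf.keras.layers.Conv2D(1, 4, strides=1, padding="same")(x)
    return tf.keras.Model(inp, out, name="patch_discriminator")
```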
In practice, we don’t actually crop patches from the input image. A final convolution layer does the job for us: since each output value of a stack of convolutions depends only on a limited region of the input (its receptive field), every value in the output grid is effectively a decision about one patch.
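To see which patch each output value “looks at”, we can compute the receptive field of the convolution stack. The sketch below assumes the five 4×4 layers with strides 2, 2, 2, 1, 1 of the discriminator described in the paper; receptive_field is a hypothetical helper.

```python
def receptive_field(layers):
    """Receptive field (in input pixels) of one output value of a conv stack.

    layers: list of (kernel_size, stride) pairs, ordered from input to output.
    """
    field, jump = 1, 1
    for kernel, stride in layers:
        # Each layer widens the field by (kernel - 1) steps of the current jump,
        # where jump is the product of the strides of all earlier layers.
        field += (kernel - 1) * jump
        jump *= stride
    return field


PATCHGAN_LAYERS = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
print(receptive_field(PATCHGAN_LAYERS))  # → 70
```

So each value in the decision matrix judges a 70×70 region of the input, and because the total stride is 8, neighboring patches overlap heavily.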
Loss Function
The loss functions for the discriminators are much more straightforward. Just like in typical GANs, we tell each discriminator to treat real images as real and generated images as fake. So each discriminator has two losses, loss_real and loss_fake, which contribute equally to its final loss. In calc_gan_loss, we are comparing two matrices. Usually, the output of a discriminator is a single value between 0 and 1. However, as mentioned above, we use PatchGAN, so the discriminator produces one decision for each patch, which together form a 32×32 decision matrix.
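Here is a sketch of the discriminator loss under the same least-squares assumption. The calc_gan_loss body is a guess at the article’s helper, and discriminator_loss is a hypothetical name. The 0.5 factor averages the two terms so they contribute equally; the CycleGAN paper also mentions halving the discriminator objective to slow its learning relative to the generator.

```python
import tensorflow as tf


def calc_gan_loss(decision, target_is_real):
    # Least-squares loss against an all-ones (real) or all-zeros (fake) patch grid.
    target = tf.ones_like(decision) if target_is_real else tf.zeros_like(decision)
    return tf.reduce_mean(tf.square(target - decision))


def discriminator_loss(decision_real, decision_fake):
    loss_real = calc_gan_loss(decision_real, True)   # real images should score 1
    loss_fake = calc_gan_loss(decision_fake, False)  # generated images should score 0
    # Equal weights, averaged (halved) as in the CycleGAN training details.
    return 0.5 * (loss_real + loss_fake)
```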
Training
Now that we have defined both models and loss functions, we can put them together and start training. Eager execution is enabled by default in TensorFlow 2.0, so we don’t have to build a graph explicitly. However, if you look carefully, you might notice that both the discriminator and the generator training functions are decorated with tf.function. This is the new mechanism introduced in TensorFlow 2.0 to replace the old tf.Session(). With this decorator, TensorFlow traces the operations inside the function into a graph, which can run considerably faster than the default eager mode. To learn more about tf.function, please refer to this article.
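A tiny experiment makes the graph conversion visible. Python statements inside a tf.function run only while TensorFlow traces the function, so a Python-side list reveals how often tracing happens. scaled_sum and traces are made-up names for this illustration.

```python
import tensorflow as tf

traces = []


@tf.function
def scaled_sum(x):
    traces.append(tuple(x.shape))  # Python side effect: runs only during tracing
    return tf.reduce_sum(x * 2.0)


first = scaled_sum(tf.constant([1.0, 2.0, 3.0]))   # traced, then run as a graph
second = scaled_sum(tf.constant([4.0, 5.0, 6.0]))  # same shape and dtype: graph reused
third = scaled_sum(tf.constant([1.0, 1.0]))        # new shape: triggers a retrace
print(len(traces))  # → 2
```

The same mechanism means Python-side randomness inside a tf.function is frozen at trace time, which is one reason to keep logic like the image pool outside the graph.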
One thing to mention is that, instead of feeding generated images to the discriminator directly, we use an image pool. Each time, the image pool randomly decides whether to give the discriminator a newly generated image or one generated in a previous step. The benefit is that the discriminator can learn from a wider range of cases and keep a kind of memory of the tricks the generator has used before. Unfortunately, we can’t use this random image pool in graph mode at the moment, so the generated images have to be brought back to the CPU when we select a random image from the pool. This does introduce some overhead.
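A plain Python/NumPy sketch of such an image pool is shown below; it runs outside the graph by design. The default size of 50 images follows the CycleGAN paper, while the 50% swap probability is borrowed from the official implementation and should be treated as an assumption, as should the ImagePool name.

```python
import random

import numpy as np


class ImagePool:
    """History buffer of generated images fed to the discriminator."""

    def __init__(self, pool_size=50, seed=None):
        self.pool_size = pool_size
        self.images = []
        self.rng = random.Random(seed)

    def query(self, images):
        """Return a batch where each image is either the new one or an older pooled one."""
        if self.pool_size == 0:
            return images
        out = []
        for img in images:
            if len(self.images) < self.pool_size:
                # Pool not full yet: remember the new image and pass it through.
                self.images.append(img)
                out.append(img)
            elif self.rng.random() < 0.5:
                # Return a random older image and store the new one in its place.
                idx = self.rng.randrange(self.pool_size)
                out.append(self.images[idx])
                self.images[idx] = img
            else:
                out.append(img)
        return np.stack(out)
```

In a training loop, the generated batch (converted with .numpy()) would pass through pool.query before being fed to the discriminator.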
The model illustrated in this article was trained on my own RTX 2080 home computer, so training was a bit slow. On an instance with a 16 GB V100 GPU and 64 GB of RAM, though, you should be able to set the mini-batch size to 4, and the trainer can process one epoch of 260 mini-batches in 3 minutes on the horse2zebra dataset. So it takes about 10 hours to fully train a horse2zebra model. If you reduce the image resolution and some network parameters accordingly, training can be faster. Each final generator is about 44 MB.
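As a quick sanity check of these numbers, assuming the 200-epoch schedule from the Optimizer section and 32-bit float weights:

$$
200\ \text{epochs} \times 3\ \tfrac{\text{min}}{\text{epoch}} = 600\ \text{min} = 10\ \text{h},
\qquad
\frac{44\ \text{MB}}{4\ \text{bytes per float32 parameter}} \approx 11 \times 10^{6}\ \text{parameters}.
$$

Storing the same parameters in 8 bits instead would take roughly 11 MB, in line with the quantized size mentioned in the questions at the end.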
To see the full training script, please visit my repo here.
We ran inference on a few datasets. horse2zebra and monet2photo are the original datasets from the paper, and the CelebA dataset is available here.
Results
Let’s look at some images generated by our model. In each pair, the original image is on the left and the generated one is on the right.
horse2zebra
Horse -> Zebra
Zebra -> Horse
monet2photo
Monet -> Photo
Photo -> Monet
CelebA
Male -> Female
Female -> Male
Gender Swap
We have successfully mapped a male face to a female face, but to use it in a production environment, we need a pipeline that orchestrates many other steps. One naive way to do so is:
- Run face detection to find a bounding box and keypoints for the most dominant face in the picture.
- Enlarge the bounding box slightly to match the distribution of the training dataset.
- Crop the picture with this enlarged bounding box and run CycleGAN on it.
- Patch the generated image back into the original picture.
- Overlay some hair, eyeliner, and beards on top of the new face based on the keypoints from the first step.
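The middle steps of this list (enlarging the box, cropping, and pasting back) can be sketched with NumPy as follows. The 1.4 enlargement factor and the translate callback are hypothetical placeholders: translate stands in for resizing the crop to the model’s input size, running the generator and resizing back, and a real pipeline would also blend the seam instead of pasting a hard rectangle.

```python
import numpy as np


def expand_box(box, scale, img_w, img_h):
    """Grow an (x0, y0, x1, y1) box around its center by `scale`, clipped to the image."""
    x0, y0, x1, y1 = box
    cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
    half_w, half_h = (x1 - x0) * scale / 2, (y1 - y0) * scale / 2
    return (max(0, int(round(cx - half_w))), max(0, int(round(cy - half_h))),
            min(img_w, int(round(cx + half_w))), min(img_h, int(round(cy + half_h))))


def swap_face(image, box, translate, scale=1.4):
    """Crop the enlarged face box, run `translate` on it, and paste the result back."""
    h, w = image.shape[:2]
    x0, y0, x1, y1 = expand_box(box, scale, w, h)
    crop = image[y0:y1, x0:x1]
    out = image.copy()  # leave the caller's image untouched
    out[y0:y1, x0:x1] = translate(crop)
    return out
```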
Here’s a great article that explains how this pipeline works in detail.
Questions
Lastly, I want to throw out some questions I have. I don’t know the answers yet, but I hope those who have experience building similar products can share their opinions in the comments below.
- The CycleGAN model turns out to be 44 MB; with quantization it could shrink to 12 MB, but that’s still too large. What are effective methods to make such models usable on mobile and embedded devices?
- The output image resolution isn’t great and has lost much of its sharpness. How can we generate a bigger image, such as 1024×1024, without blowing up the model size? Would a super-resolution model help in this case?
- How do we know whether a model is thoroughly trained and has converged? The loss isn’t a good metric here, but we also don’t know what the “best” output looks like. How can we measure the similarity between two styles?
References
- Understanding and Implementing CycleGAN in TensorFlow
- Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros, Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
- Phillip Isola, Jun-Yan Zhu, Tinghui Zhou, Alexei A. Efros, Image-to-Image Translation with Conditional Adversarial Networks
- Gender and Race Change on Your Selfie with Neural Nets