From Deep Learning with PyTorch by Eli Stevens and Luca Antiga
In this article, we explore some of PyTorch’s capabilities by playing with generative adversarial networks.
In part two we saw how to use a pre-trained model for image classification. Now, let’s take a look at GANs.
A pre-trained network that makes stuff up
Models falling under the name of Generative Adversarial Networks (GANs) are one of the most original outcomes of recent deep learning research. We’ll look into these later in our journey. For now, we’ll say that, whereas in standard neural network architectures we have one big network optimizing its weights in order to minimize a loss function related to, say, a classification task, in GANs we have a pair of networks, named the generator and the discriminator.
Figure 1. Concept of a GAN game.
The generator has the task of producing realistic-looking images starting from an input, while the discriminator has to tell whether a given image was fabricated by the generator or belongs to a set of real images. The end goal for the generator is to fool the discriminator into mixing up real and fake images. The end goal for the discriminator is to find out when it’s being tricked. This is called the GAN game.
Note that “Discriminator wins” or “Generator wins” shouldn’t be taken literally, as there’s no explicit match between the two. Both networks are associated with cost functions which depend on the outcome of the other network, and which are minimized in turn during training.
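To make "minimized in turn" a little more concrete, here is a minimal sketch of one round of the GAN game. Everything in it is an assumption made for illustration: the toy linear networks, the 8-dimensional vectors standing in for images, the binary cross-entropy loss and the SGD optimizers are not taken from any particular GAN implementation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (assumptions): "images" are 8-dim vectors, noise is 4-dim.
generator = nn.Sequential(nn.Linear(4, 8), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(8, 1))  # outputs a real/fake logit

bce = nn.BCEWithLogitsLoss()
opt_g = torch.optim.SGD(generator.parameters(), lr=0.01)
opt_d = torch.optim.SGD(discriminator.parameters(), lr=0.01)

real = torch.randn(16, 8)    # pretend batch of real samples
noise = torch.randn(16, 4)   # generator input
ones = torch.ones(16, 1)     # label "real"
zeros = torch.zeros(16, 1)   # label "fake"

# Discriminator's turn: learn to separate real from generated samples.
fake = generator(noise).detach()  # detach: do not update the generator here
loss_d = bce(discriminator(real), ones) + bce(discriminator(fake), zeros)
opt_d.zero_grad()
loss_d.backward()
opt_d.step()

# Generator's turn: try to make the discriminator answer "real" on fakes.
loss_g = bce(discriminator(generator(noise)), ones)
opt_g.zero_grad()
loss_g.backward()
opt_g.step()
```

Each loss depends on the other network’s current behavior, which is why neither side ever "wins" outright: the two updates simply alternate for as long as training runs.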
This technique has led to generators that produce realistic images out of noise and a conditioning signal, like an attribute, or another image. A well-trained generator learns a plausible model for generating real-world images.
An interesting evolution of this concept is CycleGAN. A CycleGAN can turn images of one domain into images of another domain (and back), without the need for explicitly providing matching pairs in the training set.
Figure 2. A CycleGAN trained to the point where it can fool both discriminator networks.
In CycleGAN, the generator learns to produce an image conforming to a target distribution (Monet paintings, for instance) starting from an image belonging to a different distribution (landscape photos, for instance), so that the discriminator can’t tell whether the image produced from a landscape photo is a genuine Monet painting. At the same time, and here’s where the Cycle prefix in the name comes in, the resulting painting is sent through a different generator going the other way, Monet to photo in our case, to be judged by another discriminator on the other side. Creating such a cycle considerably stabilizes the training process, whose instability was one of the original issues with GANs.
The fun part is that, at this point, we don’t need pairs of Monet/photos as ground truths: it’s enough to start from a collection of unrelated Monet works and landscape photos for the generators to learn their task, going beyond a purely supervised setting. The implications of this model go even further than this: the generator learns how to selectively change the appearance of objects in the scene without supervision on what’s what. No signals indicate that water is water and a tree is a tree, but they get translated to something which is the way water and trees are represented in the Monet domain and vice versa.
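The round trip described above can be sketched as a "cycle-consistency" penalty: translate a photo to the Monet domain, translate it back, and measure how far the result is from the original photo. The snippet below is only an illustration under stated assumptions: the two single-layer convolutions are hypothetical stand-ins for the real generators, the random tensor stands in for a photo, and the L1 distance is used as the reconstruction measure.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two generators (not the real ResNet generators).
photo_to_monet = nn.Conv2d(3, 3, kernel_size=3, padding=1)
monet_to_photo = nn.Conv2d(3, 3, kernel_size=3, padding=1)

photo = torch.rand(1, 3, 32, 32)  # stand-in for a batch with one landscape photo

fake_monet = photo_to_monet(photo)                 # photo -> "Monet"
reconstructed_photo = monet_to_photo(fake_monet)   # "Monet" -> photo again

# Penalize the round trip for drifting away from the starting image.
cycle_loss = nn.functional.l1_loss(reconstructed_photo, photo)
cycle_loss.backward()  # gradients reach both generators
```

Note that no paired Monet/photo example is needed to compute this term: the photo acts as its own ground truth.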
A NETWORK THAT TURNS HORSES INTO ZEBRAS
An even clearer example is the Horse2Zebra CycleGAN, which is what we’ll play with right now. In this case, the CycleGAN network was trained on a dataset of (unrelated) horse images and zebra images extracted from the ImageNet dataset. The network learns to take an image of one or more horses and turn them all into zebras, leaving the rest as unmodified as possible. While humankind hasn’t held its breath over the last few million years for a tool that turns horses into zebras, this task showcases the ability of these architectures to model complex real-world processes with distant supervision. While they have their limits, there are hints that in the future we won’t be able to tell real from fake from a live video feed, which opens a can of worms that we’ll duly close right now.
It’s time to play with a pre-trained CycleGAN. This gives us the opportunity to take a step closer and look at how a network, a generator in this case, is implemented. Let’s do it right away: this is what a possible generator architecture for the horse to zebra task looks like. In our case it’s our old friend ResNet. We’ll now show the full source code for the ResNetGenerator class, with the aim of demonstrating how compact it is for what it does. It takes an image, recognizes one or more horses in it by looking at the pixels and individually modifies the values of those pixels, resulting in something that looks like a credible zebra. We won’t recognize anything like that in the source code, because it’s not in there; the network is a scaffold, the juice is in the weights.
Throughout the article we’ll walk ourselves through code piece by piece, trying to provide all the explanations for why things are a certain way. We’ll start off by breaking this rule. We don’t have the tools yet to understand the code in detail, but we can get a feel for what it’s like and what we’ll be able to create at the end of this journey.
# In:
import torch
import torch.nn as nn

class ResNetBlock(nn.Module):

    def __init__(self, dim):
        super(ResNetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        conv_block = []

        # First padded 3x3 convolution, followed by normalization and ReLU
        conv_block += [nn.ReflectionPad2d(1)]
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim),
                       nn.ReLU(True)]

        # Second padded 3x3 convolution, followed by normalization only
        conv_block += [nn.ReflectionPad2d(1)]
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        # Residual connection: add the input back to the block's output
        out = x + self.conv_block(x)
        return out


class ResNetGenerator(nn.Module):

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9):
        assert(n_blocks >= 0)
        super(ResNetGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        # Initial 7x7 convolution
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(True)]

        # Downsampling: halve the resolution and double the channels each time
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=True),
                      nn.InstanceNorm2d(ngf * mult * 2),
                      nn.ReLU(True)]

        # Residual blocks at the lowest resolution
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResNetBlock(ngf * mult)]

        # Upsampling: double the resolution and halve the channels each time
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=True),
                      nn.InstanceNorm2d(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        # Final 7x7 convolution and Tanh, producing values in [-1, 1]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
Here we’ve declared two classes, a ResNetGenerator and a ResNetBlock. The latter is used by the former, and both of them derive from nn.Module, which is the PyTorch way to specify a portion of a neural network or, put more elegantly, a piece of differentiable computation. Every instance of nn.Module can be called as a function, with the same arguments as specified in its forward function (input in this case).
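One way to get a feel for the ResNetGenerator constructor without running PyTorch is to trace how the number of channels changes from layer to layer. The helper below is a plain-Python sketch that mirrors the mult arithmetic of the listing; the function name generator_channels is made up for this illustration and is not part of the original code.

```python
def generator_channels(input_nc=3, output_nc=3, ngf=64, n_blocks=9,
                       n_downsampling=2):
    """Return the channel count after each convolutional stage (sketch)."""
    widths = [input_nc, ngf]             # input image, then the first 7x7 conv

    for i in range(n_downsampling):      # each strided conv doubles the channels
        mult = 2 ** i
        widths.append(ngf * mult * 2)

    mult = 2 ** n_downsampling           # residual blocks keep the width fixed
    widths += [ngf * mult] * n_blocks

    for i in range(n_downsampling):      # each transposed conv halves the channels
        mult = 2 ** (n_downsampling - i)
        widths.append(ngf * mult // 2)

    widths.append(output_nc)             # final 7x7 conv back to image channels
    return widths


widths = generator_channels()
print(widths)
```

With the default arguments the trace starts at 3 channels, widens to 64, 128 and 256, stays at 256 through the nine residual blocks, and then narrows back through 128 and 64 to a 3-channel image.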
Without delving into the details, we can recognize the building blocks, also named modules in PyTorch and commonly named layers in other frameworks, that make up the computation. We can spot linear functions, such as Conv2d, whereby an input image is convolved with learned filters to produce an output, and non-linear functions, such as ReLU. All these are instantiated, accumulated in a list, model, and fed to an nn.Sequential container. When called with an input, the latter invokes each contained module with the output of the preceding module as input. This is one of the ways in which models can be defined in PyTorch.
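The chaining behavior of nn.Sequential can be checked on a small scale. The sketch below builds a short list of modules (the layer sizes are arbitrary choices for this example), wraps them in nn.Sequential, and compares the container’s output with a manual loop that feeds each module the output of the previous one.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A small, arbitrary stack of modules accumulated in a list, as in the listing.
layers = [nn.Conv2d(3, 8, kernel_size=3, padding=1),
          nn.ReLU(),
          nn.Conv2d(8, 3, kernel_size=3, padding=1),
          nn.Tanh()]
model = nn.Sequential(*layers)

x = torch.rand(1, 3, 16, 16)  # random stand-in for an image batch

with torch.no_grad():
    # Manual chaining: the output of each module is the input of the next.
    out = x
    for layer in layers:
        out = layer(out)

    same = torch.equal(model(x), out)

print(same)
```

Both paths run the same modules in the same order, so the two results should match exactly.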
At this point we can instantiate the ResNetGenerator class with the default parameters:
# In:
netG = ResNetGenerator()
At this point, the model is created, but it contains random weights. We mentioned earlier that we’d run a generator model which was pre-trained on the horse2zebra dataset. The weights of the model are saved in a pth file, which is nothing but a pickle file of the tensor parameters of the model. We can load those into our ResNetGenerator using the model’s load_state_dict method:
# In:
model_path = 'horse2zebra_0.4.0.pth'
model_data = torch.load(model_path)
netG.load_state_dict(model_data)
At this point netG has acquired all the knowledge it achieved during training. Note that this is fully equivalent to what happened when we loaded ResNet101 from torchvision, only that the torchvision.models.resnet101 function hid it from us.
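The save-and-load round trip can be tried without the actual horse2zebra file. The following sketch is an illustration only: a tiny nn.Linear stands in for the generator, and an in-memory buffer stands in for the pth file on disk.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)

trained = nn.Linear(4, 2)  # pretend these weights are the result of training
fresh = nn.Linear(4, 2)    # same architecture, different random weights

# Save only the parameters (the "state dict"), not the module object itself.
buffer = io.BytesIO()      # in-memory stand-in for a .pth file
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Load the parameters back and copy them into the freshly created module.
state = torch.load(buffer)
fresh.load_state_dict(state)

print(sorted(state.keys()))
```

The state dict maps parameter names to tensors, so loading only works if the receiving module has parameters with the same names and shapes. This is why we had to create a ResNetGenerator with a matching architecture before loading the weights.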
Let’s put the network in eval mode, as we did for ResNet101:
# In:
netG.eval()

# Out:
ResNetGenerator(
  (model): Sequential(
    ...
  )
)
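As a side note on what eval mode changes in general: some modules behave differently during training and inference, and dropout is the usual example. The generator listing above contains no dropout layers, so the snippet below is only a sketch of the general mechanism, using a standalone nn.Dropout module chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)

dropout.train()          # training mode: randomly zero elements, rescale the rest
train_out = dropout(x)   # every element is either 0 or 1 / (1 - p) = 2

dropout.eval()           # eval mode: dropout becomes the identity function
eval_out = dropout(x)

print(dropout.training, torch.equal(eval_out, x))
```

Calling eval() on a model sets this flag recursively on all of its submodules, which is why it is a sensible habit before running inference.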
We’re ready to load some random image of a horse and see what our generator produces. First of all, we need to import PIL and torchvision:
# In:
from PIL import Image
from torchvision import transforms
Then we define a few input transformations to make sure data enters the network with the right shape and size:
# In:
preprocess = transforms.Compose([transforms.Resize(256),
                                 transforms.ToTensor()])
Let’s open a horse file:
# In:
img = Image.open("horse.jpg")
img
Figure 3. A man riding a horse. A horse not having it.
Oh, there’s a dude on the horse. Not for long, judging by the picture. Anyhow, let’s pass it through preprocessing and turn it into a properly shaped variable:
# In:
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)
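For readers who like to see the shapes, the following sketch mimics the preprocessing without needing an image file. It rests on a few assumptions: the resized_size helper is hypothetical and only a rough model of what Resize(256) does with an integer argument (scale the shorter side to 256, keep the aspect ratio), the original size of 632x512 is made up, and a random tensor stands in for the output of ToTensor.

```python
import torch


def resized_size(width, height, target=256):
    """Rough model of Resize(target): shorter side becomes target (sketch)."""
    if width <= height:
        return target, int(target * height / width)
    return int(target * width / height), target


new_w, new_h = resized_size(632, 512)  # hypothetical original image size

# ToTensor yields a channels x height x width tensor with values in [0, 1];
# a random tensor of that shape stands in for it here.
img_t = torch.rand(3, new_h, new_w)

# The network expects a batch, so we add a leading dimension of size 1.
batch_t = torch.unsqueeze(img_t, 0)

print(img_t.shape, batch_t.shape)
```

The extra leading dimension is the batch dimension: here the batch simply contains a single image.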
We shouldn’t worry about the details right now. The important thing is that we follow from a distance. At this point, batch_t can be sent to our model:
# In:
batch_out = netG(batch_t)
batch_out is now the output of the generator, which we can convert back to an image:
# In:
out_t = (batch_out.data.squeeze() + 1.0) / 2.0
out_img = transforms.ToPILImage()(out_t)
# out_img.save('zebra.jpg')
out_img

# Out:
<PIL.Image.Image image mode=RGB size=316x256 at 0x1C0C8E4C550>
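The arithmetic in the conversion deserves a short explanation. The generator ends with a Tanh layer, so its output values lie between -1 and 1, while an image tensor is expected to hold values between 0 and 1. Adding 1 and dividing by 2 maps one range onto the other. The sketch below checks this on a random stand-in for the generator output (the small tensor size is arbitrary).

```python
import torch

torch.manual_seed(0)

# Stand-in for the generator output: a batch of one image with values in [-1, 1].
batch_out = torch.tanh(torch.randn(1, 3, 4, 5))

# Drop the batch dimension, then map [-1, 1] onto [0, 1].
out_t = (batch_out.squeeze() + 1.0) / 2.0

# The endpoints of the range map exactly onto 0 and 1.
endpoints = (torch.tensor([-1.0, 1.0]) + 1.0) / 2.0

print(out_t.shape, out_t.min().item(), out_t.max().item())
```

One caveat: squeeze() with no argument removes every dimension of size 1, so squeeze(0) would be the more defensive choice if an image could ever have a height or width of 1.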
Figure 4. A man riding a zebra. A zebra not having it.
Oh, man. Who rides a zebra that way? The resulting image isn’t perfect, but consider how unusual it is for the network to find someone riding (sort of) on top of a horse. It bears repeating that the learning process hasn’t passed through direct supervision, where humans have delineated tens of thousands of horses. The generator learned to produce an image that’d fool the discriminator into thinking that it’s a zebra and there’s nothing fishy with the image (clearly the discriminator has never been to a rodeo).
It’s hard to overstate the implications of this kind of work. Chances are we’ll see a lot of this technology in our future, probably in disparate aspects of our lives.
Numerous other fun generators have been developed using adversarial training or other approaches. Some of them are capable of creating credible human faces of non-existing individuals, while others can translate sketches into real-looking pictures of imaginary landscapes. Generative models are also being explored for producing real-sounding audio, credible text or enjoyable music. It’s likely that these models will be at the basis of future tools that support the creative process.
So far we’ve had a chance to play with a model that sees into images and a model that generates new images. We’ll end our tour with a model that involves one more, fundamental ingredient: natural language.
A pre-trained network that describes scenes
In order to get first-hand experience with a model involving natural language, we’ll use a pre-trained image captioning model, generously provided by Ruotian Luo and implemented after the work on NeuralTalk2 by Andrej Karpathy. We maintain a clone of the code at [REF]. This kind of model generates a caption in English describing a scene when presented with a natural image. Again, the interesting part is that the model is trained on a large dataset of images paired with sentence descriptions, e.g. “A Tabby cat is leaning on a wooden table, with one paw on a laser mouse and the other on a black laptop” [REF paper].
Figure 5. Concept of a captioning model.
This captioning model has two connected halves. The first half of the network learns to generate “descriptive” numerical representations of the scene (Tabby cat, laser mouse, paw), which are then taken as input to the second half. That half is a recurrent neural network which generates a coherent sentence by putting those descriptions together. The whole architecture is trained end-to-end on image-caption pairs.
A few other captioning models have been proposed, in particular img2seq models from the seq2seq family, a versatile kind of model specialized in encoding an input sequence (in this case a sequence of pixels) into a vector, which is then decoded into another sequence (a sequence of characters or words).
Back to the NeuralTalk2 model, we can find it at github.com/deep-learning-with-pytorch/ImageCaptioning.pytorch. We can just place a set of images in the data directory and run the following script:
python eval.py --model ./data/FC/fc-model.pth --infos_path ./data/FC/fc-infos.pkl --image_folder ./data
Let’s try with our horse.jpg image. It says “A person riding a horse on a beach”. Quite appropriate.
Now, for fun, let’s see if our CycleGAN can also fool this NeuralTalk2 model. Let’s add the zebra.jpg image to the data folder and rerun the model: “A group of zebras are standing in a field.” Well, it got the animal right, but it saw more than one of them in the image. For sure this isn’t a pose that the network has ever seen a zebra in, nor did it ever see a rider on top of a zebra (with some spurious zebra patterns). In addition, it’s likely that zebras are depicted in groups in the training dataset, and there might be some bias that one could investigate. The captioning network hasn’t seen the rider either. Again, it’s probably for the same reason: the network has never seen a rider on a zebra in the training dataset.
In any case, this is an impressive feat: we generated a fake image with an impossible situation and the captioning network was flexible enough to get the subject right.
We’d like to stress that something like this, which was extremely hard to achieve before the advent of deep learning, can be obtained with under a thousand lines of code, with a general-purpose architecture that knows nothing about horses or zebras, and a corpus of images and their descriptions (the MS COCO dataset, in this case). No hard-coded criterion or grammar—everything, including the sentence, is emerging from patterns in the data.
The network architecture in this last case was more complex than the ones we’ve seen earlier—it has a convolutional part and a recurrent part.
That’s where we will stop for now. And remember, this is just a taste of what PyTorch can do.
For more information about the book, check it out on liveBook for free here.