RGB vs. $L^*a^*b^*$
When we load an image, we get a rank-3 tensor (height, width, channels) whose last axis holds the color data. These values represent color in the RGB color space: each pixel has 3 numbers indicating how much red, green, and blue it contains.
In the $L^*a^*b^*$ color space, each pixel again has three numbers. The first channel, $L$, encodes the lightness of each pixel; visualized on its own, it looks like a grayscale (black-and-white) image. The $a^*$ and $b^*$ channels encode how green-red and how yellow-blue each pixel is, respectively.
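To make the channel semantics concrete, here is a minimal NumPy sketch of the standard sRGB to CIELAB conversion with the D65 white point. `srgb_to_lab` is a hypothetical helper written for illustration; the notebook itself uses `skimage.color.rgb2lab` (which also defaults to D65), and the two should agree closely. The signs of $a^*$ and $b^*$ for pure green, blue, and red show which end of each axis is which.

```python
import numpy as np

# Reference white for the D65 illuminant (XYZ, with Y normalized to 1)
WHITE_D65 = np.array([0.95047, 1.0, 1.08883])

# Linear sRGB -> XYZ matrix for D65
RGB_TO_XYZ = np.array([
    [0.4124564, 0.3575761, 0.1804375],
    [0.2126729, 0.7151522, 0.0721750],
    [0.0193339, 0.1191920, 0.9503041],
])

def srgb_to_lab(rgb):
    """Convert sRGB values in [0, 1] with shape (..., 3) to CIELAB with shape (..., 3)."""
    rgb = np.asarray(rgb, dtype=np.float64)
    # 1. Undo the sRGB gamma curve to get linear light
    linear = np.where(rgb <= 0.04045, rgb / 12.92, ((rgb + 0.055) / 1.055) ** 2.4)
    # 2. Linear RGB -> XYZ, normalized by the reference white
    xyz = linear @ RGB_TO_XYZ.T / WHITE_D65
    # 3. The non-linear compression used by CIELAB
    delta = 6 / 29
    f = np.where(xyz > delta ** 3, np.cbrt(xyz), xyz / (3 * delta ** 2) + 4 / 29)
    L = 116 * f[..., 1] - 16            # lightness: 0 (black) .. 100 (white)
    a = 500 * (f[..., 0] - f[..., 1])   # negative = green, positive = red
    b = 200 * (f[..., 1] - f[..., 2])   # negative = blue, positive = yellow
    return np.stack([L, a, b], axis=-1)

for name, color in [("white", (1, 1, 1)), ("gray", (0.5, 0.5, 0.5)),
                    ("red", (1, 0, 0)), ("green", (0, 1, 0)), ("blue", (0, 0, 1))]:
    print(name, np.round(srgb_to_lab(color), 1))
```

Gray values come out with $a^* \approx b^* \approx 0$, so all of their information lives in $L$; pure red lands near $L=53$, $a^*=80$, $b^*=67$.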
To train a model for colorization, we give it a grayscale image and want it to produce a colorful one. With $L^*a^*b^*$, we can give the $L$ channel to the model and ask it to predict the other two channels ($a^*$, $b^*$); after prediction, we concatenate all three channels to get the colorized image. With RGB, we would first have to convert the image to grayscale, feed the grayscale image to the model, and ask it to predict all 3 numbers per pixel. That is a much harder and less stable task, because there are many more possible combinations of 3 numbers than of 2. If we assume 256 choices for each number, predicting three numbers per pixel means choosing among $256^3$ combinations (more than 16 million), while predicting two numbers means choosing among $256^2$ (about 65 thousand).
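A minimal NumPy sketch of this setup, using random values as a stand-in $L^*a^*b^*$ image and a hypothetical `dummy_colorizer` in place of a trained model. It shows that $L$ passes through unchanged and only the two chroma channels have to be predicted, and it recomputes the size of the output space per pixel.

```python
import numpy as np

rng = np.random.default_rng(0)
H, W = 4, 5                                     # a tiny hypothetical image
lab = np.concatenate([
    rng.uniform(0, 100, size=(H, W, 1)),        # L channel, range [0, 100]
    rng.uniform(-110, 110, size=(H, W, 2)),     # a and b channels, roughly [-110, 110]
], axis=-1)

model_input = lab[..., :1]    # the model only sees L
target_ab = lab[..., 1:]      # ...and has to predict a and b

def dummy_colorizer(L):
    """Hypothetical stand-in for a trained model: always predicts zero chroma (gray)."""
    return np.zeros(L.shape[:-1] + (2,))

pred_ab = dummy_colorizer(model_input)
colorized = np.concatenate([model_input, pred_ab], axis=-1)   # input L + predicted ab

# Per-pixel output space with 256 levels per predicted number
rgb_choices = 256 ** 3
ab_choices = 256 ** 2
print(colorized.shape, rgb_choices, ab_choices)   # (4, 5, 3) 16777216 65536
```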
How to solve the problem?
The Colorful Image Colorization paper approached the problem as a classification task and also took the uncertainty of the problem into account (e.g., a car in an image can plausibly take on many different colors, so we cannot be sure of any single color for it); another paper, however, approached the problem as a regression task.
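To illustrate what treating colorization as classification means, here is a simplified sketch that quantizes the $ab$ plane into a uniform grid so each pixel's color becomes a class label that a model could predict with cross-entropy, outputting a distribution over colors instead of a single value. The bin size of 10, the range $[-110, 110)$, and the helper names are assumptions for illustration; the paper's actual scheme (which keeps only in-gamut bins and rebalances rare colors) differs in detail.

```python
import numpy as np

BIN_SIZE = 10                                 # assumed grid spacing in the ab plane
AB_MIN, AB_MAX = -110, 110                    # assumed ab range covered by the grid
N_PER_AXIS = (AB_MAX - AB_MIN) // BIN_SIZE    # 22 bins along a and along b
N_CLASSES = N_PER_AXIS ** 2                   # 484 classes on the full square grid

def ab_to_class(ab):
    """Map ab values with shape (..., 2) to integer class ids on the uniform grid."""
    idx = np.floor((np.asarray(ab, dtype=float) - AB_MIN) / BIN_SIZE).astype(int)
    idx = np.clip(idx, 0, N_PER_AXIS - 1)     # values on the upper edge go to the last bin
    return idx[..., 0] * N_PER_AXIS + idx[..., 1]

def bin_center(i):
    """ab coordinate at the center of bin index i along one axis."""
    return AB_MIN + (i + 0.5) * BIN_SIZE

def class_to_ab(cls):
    """Decode class ids back to the ab value at the center of their bin."""
    cls = np.asarray(cls)
    return np.stack([bin_center(cls // N_PER_AXIS), bin_center(cls % N_PER_AXIS)], axis=-1)

ab = np.array([[0.0, 0.0], [-110.0, -110.0], [42.3, -17.8]])
classes = ab_to_class(ab)          # one class label per pixel
decoded = class_to_ab(classes)     # each value lands within BIN_SIZE / 2 of the original
print(classes, decoded)
```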
Current Implementation
The Image-to-Image Translation with Conditional Adversarial Networks paper, also known as pix2pix, proposed a general solution to many image-to-image tasks in deep learning, one of which was colorization. This approach uses two losses: an L1 loss, which makes it a regression task, and an adversarial (GAN) loss, which helps solve the problem in an unsupervised manner: instead of comparing the output with a target pixel by pixel, it rewards outputs that a discriminator cannot tell apart from real images.
GAN
In a GAN, a generator and a discriminator are trained together: the generator tries to fool the discriminator, and the discriminator tries to tell real outputs from generated ones. In our setting, the generator takes a grayscale image ($L$ channel) and produces a 2-channel image ($a^*$, $b^*$). The discriminator takes these two produced channels, concatenates them with the input grayscale image, and decides whether this new 3-channel image is fake or real. It is also shown real 3-channel images, built from the same $L$ channel and the true $a^*$, $b^*$ channels.
The grayscale image that both the generator and the discriminator see is the condition we provide to both models in our GAN, and we expect them to take this condition into account.
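A shape-level NumPy sketch of how the condition enters the discriminator, mirroring the `torch.cat([...], dim=1)` calls in `MainModel` further below. Random arrays stand in for real data and for generator output (an assumption for illustration); the point is that real and fake pairs share the same $L$ channel and differ only in the color channels.

```python
import numpy as np

rng = np.random.default_rng(0)
N, H, W = 2, 8, 8                                    # hypothetical batch and image size
L = rng.uniform(-1, 1, size=(N, 1, H, W))            # the condition x: normalized L channel
real_ab = rng.uniform(-1, 1, size=(N, 2, H, W))      # ground-truth colors y
fake_ab = np.tanh(rng.normal(size=(N, 2, H, W)))     # stand-in for the generator output G(x)

# Both pairs carry the same condition; only the color channels differ
real_pair = np.concatenate([L, real_ab], axis=1)     # should be classified as real
fake_pair = np.concatenate([L, fake_ab], axis=1)     # should be classified as fake
print(real_pair.shape, fake_pair.shape)              # (2, 3, 8, 8) (2, 3, 8, 8)
```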
Let $x$ be the grayscale image, $z$ the input noise for the generator, and $y$ the 2-channel output we want from the generator. Also, let $G$ be the generator and $D$ the discriminator. (In pix2pix, $z$ is not an explicit input; noise is provided only in the form of dropout in the generator.) Then the loss for our conditional GAN is:
$$\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x, y}[\log D(x, y)] + \mathbb{E}_{x, z}[\log(1 - D(x, G(x, z)))]$$
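In practice each expectation is replaced by an average over a batch. A minimal NumPy sketch of that estimate, assuming the discriminator's outputs have already been turned into probabilities (the notebook's `PatchDiscriminator` actually outputs logits, one per patch); `cgan_objective` is a hypothetical name. With $D = \sigma(\text{logits})$, the discriminator loss computed later in `backward_D` equals $-\tfrac{1}{2}$ times this estimate, so minimizing that loss maximizes $\mathcal{L}_{cGAN}$ with respect to $D$.

```python
import numpy as np

def cgan_objective(d_real, d_fake, eps=1e-12):
    """Batch estimate of L_cGAN from discriminator probabilities.

    d_real: D(x, y) on real pairs; d_fake: D(x, G(x, z)) on generated pairs.
    Both are arrays of probabilities in (0, 1); clipping avoids log(0).
    """
    d_real = np.clip(d_real, eps, 1 - eps)
    d_fake = np.clip(d_fake, eps, 1 - eps)
    return np.mean(np.log(d_real)) + np.mean(np.log(1 - d_fake))

# A confident, correct discriminator pushes the objective toward its maximum of 0
confident = cgan_objective(np.array([0.99, 0.98]), np.array([0.01, 0.02]))
# A discriminator that cannot tell real from fake (D = 0.5 everywhere) gives 2 * log(0.5)
undecided = cgan_objective(np.full(4, 0.5), np.full(4, 0.5))
print(confident, undecided)
```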
Loss function we optimize
The loss above helps produce good-looking colorful images that look real, but to further help the models and introduce some supervision into the task, we combine it with the $L1$ loss between the predicted and the actual colors:
$$\mathcal{L}_{L1}(G) = \mathbb{E}_{x, y, z}[||y-G(x, z)||_1]$$
If we use the L1 loss alone, the model still learns to colorize images, but it is conservative and often uses colors like gray or brown: when it is unsure which color is best, it hedges toward an in-between color to keep the L1 loss as low as possible (similar to the blurring effect of L1 or L2 losses in super-resolution). The L1 loss is still preferred over the L2 loss (mean squared error) because it reduces this tendency to produce grayish images. So, our combined objective is:
$$G^* = \arg\min_{G} \max_{D} \; \mathcal{L}_{cGAN}(G, D) + \lambda \mathcal{L}_{L1}(G)$$
where $\lambda$ is a coefficient that balances the contributions of the two losses to the final loss (the discriminator's loss, of course, does not involve the L1 term).
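The L1-versus-L2 argument can be checked numerically: for a single ambiguous pixel, the best constant prediction under L2 is the mean of the plausible values, while under L1 it is the median. The values below are hypothetical: $a^* = +80$ in 60% of the training images and $-80$ in 40%. L2 settles on 16, a washed-out value close to gray (0), while L1 commits to 80. With a perfectly even 50/50 split, every value between $-80$ and $80$ (gray included) is L1-optimal, which is why L1 reduces the grayish effect rather than eliminating it.

```python
import numpy as np

# Hypothetical ambiguous pixel: red-ish (a* = +80) in 60% of cases, green-ish (-80) in 40%
values = np.array([80.0] * 6 + [-80.0] * 4)

candidates = np.linspace(-100, 100, 2001)                    # constant predictions to try
diffs = values[None, :] - candidates[:, None]                # shape (2001, 10)
l1 = np.abs(diffs).mean(axis=1)
l2 = (diffs ** 2).mean(axis=1)

best_l1 = candidates[np.argmin(l1)]   # the median of the values minimizes L1
best_l2 = candidates[np.argmin(l2)]   # the mean of the values minimizes L2
print(best_l1, best_l2)
```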
Implementing the paper: Baseline
```python
%%capture
!pip install scikit-image fastai
```

```python
import os
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
```
```python
from fastai.data.external import untar_data, URLs

# Download the COCO sample from fastai and collect the training image paths
coco_path = untar_data(URLs.COCO_SAMPLE)
coco_path = os.path.join(str(coco_path), 'train_sample')
paths = glob.glob(coco_path + '/*.jpg')

# Randomly pick 10,000 images and split them 8,000 / 2,000 into train / validation
np.random.seed(42)
paths_subset = np.random.choice(paths, 10000, replace=False)
rand_idxs = np.random.permutation(10000)
train_idxs = rand_idxs[:8000]
val_idxs = rand_idxs[8000:]
train_paths = paths_subset[train_idxs]
val_paths = paths_subset[val_idxs]
print(len(train_paths), len(val_paths))
```

```python
# Show 16 training images
_, axes = plt.subplots(4, 4, figsize=(10, 10))
for ax, img_path in zip(axes.flatten(), train_paths):
    ax.imshow(Image.open(img_path))
    ax.axis('off')
```
```python
SIZE = 256

class ColorizationDataset(Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomHorizontalFlip(),  # light augmentation, training only
            ])
        elif split == 'val':
            self.transforms = transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC)
        else:
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")
        self.split = split
        self.size = SIZE
        self.paths = paths

    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert('RGB')
        img = self.transforms(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype('float32')   # RGB -> L*a*b*
        img_lab = transforms.ToTensor()(img_lab)   # HWC float array -> CHW tensor (floats are not rescaled)
        L = img_lab[[0], ...] / 50. - 1.           # L in [0, 100] -> [-1, 1]
        ab = img_lab[[1, 2], ...] / 110.           # a, b roughly in [-110, 110] -> about [-1, 1]
        return {'L': L, 'ab': ab}

    def __len__(self):
        return len(self.paths)
```
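The two scaling lines in `__getitem__` put every channel on roughly the same $[-1, 1]$ scale, which matches the `nn.Tanh()` at the end of the baseline U-Net; `lab_to_rgb` further below undoes the scaling before calling `lab2rgb`. For sRGB images, $a^*$ and $b^*$ stay within roughly $\pm 110$. A tiny NumPy check of the mapping and its inverse, with hypothetical helper names:

```python
import numpy as np

def normalize_lab(L, ab):
    """Dataset-side scaling: L in [0, 100] -> [-1, 1], ab in about [-110, 110] -> about [-1, 1]."""
    return L / 50.0 - 1.0, ab / 110.0

def denormalize_lab(L_norm, ab_norm):
    """Inverse scaling, as applied in lab_to_rgb before converting back to RGB."""
    return (L_norm + 1.0) * 50.0, ab_norm * 110.0

L = np.array([0.0, 50.0, 100.0])
ab = np.array([-110.0, 0.0, 55.0])
L_n, ab_n = normalize_lab(L, ab)          # L maps to -1, 0, 1 and ab to -1, 0, 0.5
L_back, ab_back = denormalize_lab(L_n, ab_n)
print(L_n, ab_n, np.allclose(L_back, L), np.allclose(ab_back, ab))
```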
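```python
def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, **kwargs):
    dataset = ColorizationDataset(**kwargs)
    # Shuffle only the training split so each epoch sees the batches in a new order
    shuffle = kwargs.get('split', 'train') == 'train'
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
                            pin_memory=pin_memory, shuffle=shuffle)
    return dataloader

train_dl = make_dataloaders(paths=train_paths, split='train')
val_dl = make_dataloaders(paths=val_paths, split='val')

data = next(iter(train_dl))
Ls, abs_ = data['L'], data['ab']
print(Ls.shape, abs_.shape)       # torch.Size([16, 1, 256, 256]) torch.Size([16, 2, 256, 256])
print(len(train_dl), len(val_dl))  # 8000 / 16 = 500 and 2000 / 16 = 125 batches
```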
```python
class UNetBlock(nn.Module):
    """One U-Net level: downsample, recurse into `submodule`, upsample, add a skip connection.

    nf: channels produced by the up-convolution (also the input channels unless `input_c` is given).
    ni: channels after the down-convolution.
    """
    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False, innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        if input_c is None:
            input_c = nf
        downconv = nn.Conv2d(input_c, ni, kernel_size=4, stride=2, padding=1, bias=False)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = nn.BatchNorm2d(ni)
        uprelu = nn.ReLU(True)
        upnorm = nn.BatchNorm2d(nf)
        if outermost:
            # Receives the submodule's concatenated output (ni * 2 channels); Tanh keeps outputs in [-1, 1]
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4, stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # Bottleneck: no submodule, so nothing is concatenated on the way up
            upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4, stride=2, padding=1, bias=False)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4, stride=2, padding=1, bias=False)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if dropout:
                up += [nn.Dropout(0.5)]
            model = down + [submodule] + up
        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        # Skip connection: concatenate the block's input with its output along the channel axis
        return torch.cat([x, self.model(x)], 1)
```
```python
class UNet(nn.Module):
    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        # Build the network from the bottleneck outwards
        unet_block = UNetBlock(num_filters * 8, num_filters * 8, innermost=True)
        for _ in range(n_down - 5):
            unet_block = UNetBlock(num_filters * 8, num_filters * 8, submodule=unet_block, dropout=True)
        out_filters = num_filters * 8
        for _ in range(3):
            unet_block = UNetBlock(out_filters // 2, out_filters, submodule=unet_block)
            out_filters //= 2
        self.model = UNetBlock(output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True)

    def forward(self, x):
        return self.model(x)
```
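To see why `n_down=8` is paired with 256×256 inputs, here is a pure-Python sketch (hypothetical helper `unet_levels`) that mirrors the encoder bookkeeping of `UNet.__init__`: each down-convolution (kernel 4, stride 2, padding 1) halves the spatial size, and the channel count doubles from 64 up to a cap of 512. After eight halvings, a 256×256 input reaches a 1×1 bottleneck. The sketch assumes `n_down >= 5`, as the constructor does.

```python
def unet_levels(n_down=8, num_filters=64, size=256):
    """Return (level, spatial size, channels) after each down-convolution of the UNet above.

    Mirrors UNet.__init__: 3 blocks double the channels starting from num_filters, and all
    deeper blocks (the dropout blocks and the innermost one) stay at num_filters * 8.
    """
    if n_down < 5:
        raise ValueError("the UNet constructor assumes n_down >= 5")
    levels = []
    for level in range(1, n_down + 1):
        size //= 2                                        # k=4, s=2, p=1 halves even sizes
        channels = num_filters * 2 ** min(level - 1, 3)   # 64, 128, 256, then 512 onwards
        levels.append((level, size, channels))
    return levels

for level, size, channels in unet_levels():
    print(f"level {level}: {size}x{size}, {channels} channels")
```

Because the bottleneck is already 1×1, inputs smaller than 256×256 would need a smaller `n_down`.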
In a vanilla discriminator, the model outputs a single scalar that represents how likely it thinks the input is to be real. In a patch discriminator, the model outputs one number for every patch of the input image and decides for each patch separately whether it is fake. Using such a model for colorization seems reasonable to me, because the local changes the model needs to make are really important, and judging the whole image at once, as a vanilla discriminator does, may miss the subtleties of this task.
```python
class PatchDiscriminator(nn.Module):
    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        # First layer: no normalization
        model = [self.get_layers(input_c, num_filters, norm=False)]
        # Downsampling stack; the last of these layers uses stride 1
        model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1), s=1 if i == (n_down - 1) else 2)
                  for i in range(n_down)]
        # Final layer: one logit per patch, without normalization or activation
        model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)]
        self.model = nn.Sequential(*model)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        # Conv -> (BatchNorm) -> (LeakyReLU); the conv bias is redundant when BatchNorm follows
        layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
        if norm:
            layers += [nn.BatchNorm2d(nf)]
        if act:
            layers += [nn.LeakyReLU(0.2, True)]
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

PatchDiscriminator(3)
```

```python
discriminator = PatchDiscriminator(3)
dummy_input = torch.randn(16, 3, 256, 256)  # batch, channels (L + ab), height, width
out = discriminator(dummy_input)
out.shape  # torch.Size([16, 1, 30, 30]): one logit per patch
```
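The 30×30 output size, and the size of the patch each output judges, follow from the layer strides alone. A pure-Python sketch with hypothetical helper names, using the standard convolution output formula and walking the receptive field back from one output unit to the input:

```python
def conv_out(n, k=4, s=2, p=1):
    """Output length of a Conv2d along one spatial axis."""
    return (n + 2 * p - k) // s + 1

def patch_discriminator_geometry(size=256, n_down=3, k=4):
    # Strides used by PatchDiscriminator: first layer s=2, then n_down layers
    # (the last of them with s=1), then the final 1-channel conv with s=1
    strides = [2] + [1 if i == n_down - 1 else 2 for i in range(n_down)] + [1]
    out = size
    for s in strides:
        out = conv_out(out, k=k, s=s, p=1)
    # Receptive field: each layer maps r output pixels to s * r + (k - s) input pixels
    rf = 1
    for s in reversed(strides):
        rf = rf * s + (k - s)
    return out, rf

print(patch_discriminator_geometry())  # (30, 70)
```

So each of the 30×30 logits judges an overlapping 70×70 patch of the 256×256 input, the configuration the pix2pix paper calls a 70×70 PatchGAN.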
```python
class GANLoss(nn.Module):
    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        # Buffers move to the right device together with the module
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()   # the discriminator outputs raw logits
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            raise ValueError(f'Unknown gan_mode: {gan_mode}')

    def get_labels(self, preds, target_is_real):
        labels = self.real_label if target_is_real else self.fake_label
        # Broadcast the scalar label to one target per patch prediction
        return labels.expand_as(preds)

    def forward(self, preds, target_is_real):
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)
```
```python
def init_weights(net, init='norm', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:   # Conv2d and ConvTranspose2d
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)   # visits every submodule recursively
    print(f'Model initialized with {init} initialization')
    return net

def init_model(model, device):
    model = model.to(device)
    model = init_weights(model)
    return model
```
```python
class MainModel(nn.Module):
    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4, beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.lambda_L1 = lambda_L1
        if net_G is None:
            self.net_G = init_model(UNet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)   # e.g. a pretrained generator: keep its weights
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        # Fake pairs: detach so the discriminator update does not backpropagate into G
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        # Real pairs: the same L channel with the ground-truth ab channels
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        # The generator wants D to label its output as real
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        self.forward()
        # 1) Update the discriminator
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()
        # 2) Update the generator while D is frozen
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()
```
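The losses in `backward_D` and `backward_G` can be reproduced by hand. A NumPy sketch, assuming logits shaped like those of `PatchDiscriminator` (random stand-ins here) and hypothetical helper names: `nn.BCEWithLogitsLoss` is rewritten as its numerically stable formula, and λ = 100 matches the `lambda_L1` default. The generator uses the "real" label on its own outputs, i.e. it maximizes $\log D(x, G(x, z))$ instead of minimizing $\log(1 - D(x, G(x, z)))$; pix2pix adopts this variant from the original GAN paper because it gives stronger gradients early in training.

```python
import numpy as np

def bce_with_logits(logits, target):
    """Mean binary cross-entropy computed from raw logits, in a numerically stable form."""
    logits = np.asarray(logits, dtype=float)
    # max(x, 0) - x * t + log(1 + exp(-|x|)) equals -[t*log(sigmoid(x)) + (1 - t)*log(1 - sigmoid(x))]
    return np.mean(np.maximum(logits, 0) - logits * target + np.log1p(np.exp(-np.abs(logits))))

def discriminator_loss(real_logits, fake_logits):
    # Mirrors backward_D: real pairs -> label 1, fake pairs -> label 0, averaged with weight 0.5
    return 0.5 * (bce_with_logits(real_logits, 1.0) + bce_with_logits(fake_logits, 0.0))

def generator_loss(fake_logits, fake_ab, real_ab, lambda_l1=100.0):
    # Mirrors backward_G: fool D (label 1 on fakes) plus the lambda-weighted mean absolute error
    loss_gan = bce_with_logits(fake_logits, 1.0)
    loss_l1 = np.mean(np.abs(fake_ab - real_ab)) * lambda_l1
    return loss_gan + loss_l1, loss_gan, loss_l1

rng = np.random.default_rng(0)
real_ab = rng.uniform(-1, 1, size=(2, 2, 4, 4))   # hypothetical normalized ab batch
fake_ab = real_ab + 0.05                          # a prediction that is off by 0.05 everywhere
fake_logits = rng.normal(size=(2, 1, 30, 30))     # one logit per patch
total, loss_gan, loss_l1 = generator_loss(fake_logits, fake_ab, real_ab)
print(loss_gan, loss_l1)   # loss_l1 is 0.05 * 100 = 5.0 (up to floating-point error)
```

With λ = 100 and errors of this size, the L1 term dominates the generator loss, which keeps the colors anchored to the ground truth while the adversarial term adds realism.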
```python
class AverageMeter:
    """Keeps a running, count-weighted average of a scalar."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.count, self.avg, self.sum = [0.] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count

def create_loss_meters():
    loss_names = ['loss_D_fake', 'loss_D_real', 'loss_D', 'loss_G_GAN', 'loss_G_L1', 'loss_G']
    return {name: AverageMeter() for name in loss_names}

def update_losses(model, loss_meter_dict, count):
    for loss_name, loss_meter in loss_meter_dict.items():
        loss = getattr(model, loss_name)   # e.g. model.loss_D_fake
        loss_meter.update(loss.item(), count=count)

def lab_to_rgb(L, ab):
    """Undo the dataset scaling and convert a batch of Lab tensors to RGB numpy images."""
    L = (L + 1.) * 50.
    ab = ab * 110.   # not in place: `ab *= 110.` would also rescale the caller's tensor
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_imgs.append(img_rgb)
    return np.stack(rgb_imgs, axis=0)
```
```python
def visualize(model, data, save=True):
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    # Rows: grayscale input, predicted colors, ground truth
    fig = plt.figure(figsize=(15, 8))
    for i in range(5):
        ax = plt.subplot(3, 5, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis('off')
        ax = plt.subplot(3, 5, i + 1 + 5)
        ax.imshow(fake_imgs[i])
        ax.axis('off')
        ax = plt.subplot(3, 5, i + 1 + 10)
        ax.imshow(real_imgs[i])
        ax.axis('off')
    if save:
        fig.savefig(f'colorization_{time.time()}.png')   # save before showing the figure
    plt.show()

def log_results(loss_meter_dict):
    for loss_name, loss_meter in loss_meter_dict.items():
        print(f'{loss_name}: {loss_meter.avg:.5f}')
```
```python
def train_model(model, train_dl, epochs, display_every=200):
    # A fixed validation batch (from the global val_dl) for visual progress checks,
    # kept in its own variable so the training loop below does not overwrite it
    val_data = next(iter(val_dl))
    for e in range(epochs):
        loss_meter_dict = create_loss_meters()
        i = 0
        for data in tqdm(train_dl):
            model.setup_input(data)
            model.optimize()
            update_losses(model, loss_meter_dict, count=data['L'].size(0))
            i += 1
            if i % display_every == 0:
                print(f'\nEpoch {e+1}/{epochs}')
                print(f'Iteration {i}/{len(train_dl)}')
                log_results(loss_meter_dict)
                visualize(model, val_data, save=True)
```

```python
model = MainModel()
train_model(model, train_dl, 20)
```
New Strategy
Inspired by an idea from a super-resolution paper, we pretrain the generator separately in a supervised, deterministic manner. This avoids the problem of "the blind leading the blind" in the GAN, where neither the generator nor the discriminator knows anything about the task at the beginning of training.
Pretraining is done in 2 stages:
- The backbone of the generator (the downsampling path) is a model pretrained for classification (on ImageNet).
- The whole generator is then pretrained on the colorization task with L1 loss only.
We use a pretrained ResNet18 as the backbone of the U-Net, and for the second stage we train this U-Net on our training set with L1 loss only. Then we move on to the combined adversarial and L1 loss.
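One detail of the first pretraining stage: an ImageNet backbone expects 3 input channels, while our generator receives only $L$. `create_body(..., n_in=1)` adapts the first convolution for this. One common strategy, assumed here and not a description of fastai's exact code, is to sum the pretrained filters over the RGB axis, so a single-channel input produces the same activations as that image copied into R, G, and B. A NumPy check of that identity, with random weights standing in for pretrained ones and a hypothetical naive `conv2d_valid`:

```python
import numpy as np

def conv2d_valid(x, w):
    """Naive 'valid' cross-correlation: x (C, H, W), w (O, C, kH, kW) -> (O, H-kH+1, W-kW+1)."""
    O, C, kh, kw = w.shape
    _, H, W = x.shape
    out = np.zeros((O, H - kh + 1, W - kw + 1))
    for o in range(O):
        for i in range(H - kh + 1):
            for j in range(W - kw + 1):
                out[o, i, j] = np.sum(x[:, i:i + kh, j:j + kw] * w[o])
    return out

rng = np.random.default_rng(0)
w_rgb = rng.normal(size=(4, 3, 3, 3))              # stand-in for pretrained RGB filters
w_gray = w_rgb.sum(axis=1, keepdims=True)          # collapse the 3 input channels into 1

gray = rng.uniform(size=(1, 6, 6))                 # a single-channel (L) input
gray_as_rgb = np.repeat(gray, 3, axis=0)           # the same image copied into R, G, B

# Summed filters on the gray image match the original filters on the 3-channel copy
print(np.allclose(conv2d_valid(gray, w_gray), conv2d_valid(gray_as_rgb, w_rgb)))  # True
```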
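```python
from fastai.vision.learner import create_body
from torchvision.models import resnet18, ResNet18_Weights
from fastai.vision.models.unet import DynamicUnet

def build_res_unet(n_input=1, n_output=2, size=256):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # ImageNet-pretrained ResNet18 encoder; cut=-2 drops the pooling layer and the classifier head.
    # Assumes fastai >= 2.7, where create_body takes a model instance (older versions took the
    # architecture function instead, e.g. create_body(resnet18, pretrained=True, ...)).
    backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    body = create_body(backbone, n_in=n_input, pretrained=True, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G
```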
```python
def pretrain_generator(net_G, train_dl, opt, criterion, epochs):
    """Stage 2 of pretraining: fit the generator alone with L1 loss (no discriminator)."""
    net_G.train()
    for e in range(epochs):
        loss_meter = AverageMeter()
        for data in tqdm(train_dl):
            L, ab = data['L'].to(device), data['ab'].to(device)
            preds = net_G(L)
            loss = criterion(preds, ab)
            opt.zero_grad()
            loss.backward()
            opt.step()
            loss_meter.update(loss.item(), L.size(0))
        print(f'Epoch {e+1}/{epochs}')
        print(f'L1 loss: {loss_meter.avg:.5f}')
```

```python
net_G = build_res_unet(n_input=1, n_output=2, size=256)
opt = optim.Adam(net_G.parameters(), lr=1e-4)
criterion = nn.L1Loss()
pretrain_generator(net_G, train_dl, opt, criterion, 20)
torch.save(net_G.state_dict(), 'res18-unet.pt')
```
```python
# Load the L1-pretrained generator and continue with adversarial + L1 training
net_G = build_res_unet(n_input=1, n_output=2, size=256)
net_G.load_state_dict(torch.load('res18-unet.pt', map_location=device))
model = MainModel(net_G=net_G)
train_model(model, train_dl, 20)
```
Comparing the results of the pretrained U-Net with and without adversarial training
The U-Net we built with the ResNet18 backbone already colorizes images well after pretraining with L1 loss only (the step before the final adversarial training). However, the model is still conservative and tends toward grayish colors when it is not sure what an object is or what color it should be. It performs really well on common scene elements such as sky, trees, and grass.