Pytorch入门实战:神经风格迁移

风格迁移,即获取两个图片(一张内容图片content-image、一张风格图片style-image),从而生成一张新的拥有style-image图像风格的内容图像。

在接下来的几篇博文中,我将用Gluon和Pytorch分别实现神经风格迁移。

本文需要掌握的几个用法

  1. torch.squeeze(dim):
    如果可以去掉dim维度,就去掉,如果不能去掉就不去。

  2. torch.unsqueeze(dim):
    在dim维度上新增一个维度。

  3. clone
    将tensor复制一份出来,但注意的是这个操作会记录到计算图中,梯度往cloned tensor反传时将会传到original tensor。

  4. torch.mm
    torch.mm(A, B)进行矩阵乘法。

  5. detach
    将tensor从计算图中分离出来。

  6. vgg19 = torchvision.models.vgg19(pretrained=True).features.to(device).eval()

    • 显现vgg由两部分组成,一个是特征提取的Squential,指的是前面的卷积层;另一个是分类的Squential,指的是后面的全连接层。这两个属性分别为features和classifier;
    • eval()表示模型在预测模式下。有的网络训练和预测的行为是不一样的。
  7. torch.clamp(src, min, max)
    将tensor的数值范围钳位在min和max之间,小于min的赋为min,大于max的赋为manx。

网络结构

网络使用预训练的vgg19,整个训练过程中,vgg19的参数不变。

  1. 训练的target是什么?
    对于content的target,我们可以简单的看成在某个卷积层的输出,具体就是content_img输入到vgg19上,在conv4层的输出;
    对于style的target,我们可以简单将它看成是像素点在每个通道的统计分布。例如要匹配两张图像的样式,我们可以匹配这两张图像在 RGB 这三个通道上的直方图。更一般的,假设卷积层的输出格式是 c×h×w,既(通道,高,宽)。那么我们可以把它变形成 c×hw 的二维数组,并将它看成是一个维度为 c 的随机变量采样到的 hw 个点。所谓的样式匹配就是使得两个 c 维随机变量统计分布一致。为了计算简单起见,我们只匹配二阶信息,即协方差。然后就是conv1、conv2、conv3、conv4、conv5各层输出做一个协方差。

  2. 输入是什么?
    我们将输入初始化为content_img,也可以随机初始化为白噪声,但白噪声一般收敛很慢。

  3. 损失是什么?
    损失就是我们输入通过vgg19的网络后,在content和sytle对应的各层产生的输出与target的差距。

  4. 既然vgg19的参数不变,那我们在训练什么?什么在一直改变?
    我们改变的是输入图像。

1
2
3
4
5
6
7
8
9
10
% matplotlib inline

import torch
import torchvision
import torchvision.transforms as transforms
import PIL
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import copy
1
2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')

导入图像

  1. 定义一个转换器,把图像resize成一定大小,如果是gpu就是512,如果是cpu就是128,然后转换成tensor形式。

  2. 定义一个image_loader(path)的函数,使用PIL中的函数将图像导入,并应用1定义的转换器转换,最后弄成第一个维度为batch,批量大小为1。

  3. 得到的style_img和content_img的大小一定要相同。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
imsize = 512 if torch.cuda.is_available() else 128
loader = transforms.Compose([
transforms.Resize((300, 500)),
transforms.ToTensor(),
])

def image_loader(path):
image = PIL.Image.open(path)
image = loader(image).unsqueeze(0).to(device)
return image

content_img = image_loader('./images/rainier.jpg')
style_img = image_loader('./images/autumn_oak.jpg')

assert content_img.shape == style_img.shape

定义一个imshow的函数来显示图像,图像从tensor格式变为正常的格式,注意:直接把image传入,这时可能我们在里面的操作会盖面image,所以我们要把image复制一份再操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
unloader = transforms.ToPILImage()

def imshow(image, title=None):
image = image.clone().to("cpu")
image = unloader(image.squeeze(0))
plt.imshow(image)
if title is not None:
plt.title(title)
plt.pause(0.001)

plt.figure()
imshow(content_img, 'content image')

plt.figure()
imshow(style_img, 'style image')

png

png

定义损失函数

我们有两种损失,content loss 和 style loss。我们将这两种损失都在加载vgg19的网络中,相当于在两个层之间加一个loss层,但这个loss层不改变输入输出,我们只是在里面计算一下loss保存起来而已。

两种loss都采用均方误差。对于content来说,直接是conv的输出的均方误差;对于style来说,要计算gamma矩阵。

定义两个 loss 的 class,定义一个计算 gamma 矩阵的函数

1
2
3
4
5
6
7
8
9
class ContentLoss(nn.Module):
def __init__(self, target):
super().__init__()
self.target = target.detach()

def forward(self, input):
self.loss = F.mse_loss(input, self.target)
# print('cl type', type(self.loss))
return input
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def gram_matrix(input):
a, b, c, d = input.size()
features = input.view(a * b, c * d)
G = torch.mm(features, features.t())
return G.div(a * b * c * d)

class StyleLoss(nn.Module):
def __init__(self, target):
super().__init__()
self.target = gram_matrix(target).detach()

def forward(self, input):
self.loss = F.mse_loss(gram_matrix(input), self.target)
# print('sl type', type(self.loss))
return input

导入模型

  1. 获取vgg19

  2. 定义一个层,这个层将输入转为均值为[0.485, 0.456, 0.406],标准差为[0.229, 0.224, 0.225]

  3. 定义一个函数get_model_and_losses,这个函数用来获取model和loss

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
vgg19 = torchvision.models.vgg19(pretrained=True).features.to(device).eval()

normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
normalization_std = torch.tensor([0.229, 0.224, 0.224]).to(device)

class Normalization(nn.Module):
def __init__(self, mean, std):
super().__init__()
self.mean = torch.tensor(mean).view(-1, 1, 1)
self.std = torch.tensor(std).view(-1, 1, 1)

def forward(self, img):
return (img - self.mean) / self.std


style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
content_layers_default = ['conv_4']

def get_model_and_losses(cnn, content_img, style_img, content_layers=content_layers_default,
style_layers=style_layers_default):

model = nn.Sequential()
model.add_module('norm', Normalization(normalization_mean, normalization_std))

# 这里deepcopy的作用
cnn = copy.deepcopy(cnn)
style_losses = []
content_losses = []
i = 0
# print(cnn)
for layer in cnn.children():
if isinstance(layer, nn.Conv2d):
i += 1
name = 'conv_{}'.format(i)
elif isinstance(layer, nn.ReLU):
name = 'relu_{}'.format(i)
# 试下把这个去掉的效果
layer = nn.ReLU(inplace=False)
elif isinstance(layer, nn.MaxPool2d):
name = 'maxpool_{}'.format(i)
elif isinstance(layer, nn.BatchNorm2d):
name = 'bn_{}'.format(i)
else:
raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

model.add_module(name, layer)
if name in style_layers:
target = model(style_img).detach()
style_loss = StyleLoss(target)
model.add_module('style_loss_{}'.format(i), style_loss)
style_losses.append(style_loss)

if name in content_layers:
target = model(content_img).detach()
content_loss = ContentLoss(target)
model.add_module('content_loss_{}'.format(i), content_loss)
content_losses.append(content_loss)

for i in range(len(model) - 1, -1, -1):
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
break

model = model[: (i + 1)]
return model, style_losses, content_losses
1
2
3
4
5
6
# 使用content image作为输入
input_img = content_img.clone()
# 使用白噪声作为输入
# input_img = torch.randn(content_img.shape, device=device)
plt.figure()
imshow(input_img, title="Input Image")

png

梯度下降

定义一个梯度下降的优化器,注意这个优化器下降的参数是imput_img。

1
2
3
def get_input_optimizer(input_img):
optimizer = torch.optim.LBFGS([input_img.requires_grad_()])
return optimizer

定义一个run_style_transfer的函数,我们的训练函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def run_style_transfer(cnn, normalize_mean, normalize_std, 
content_img, style_img, input_img,
num_steps=300, style_weight=10000000,
content_weight=1):
model, style_losses, content_losses = get_model_and_losses(cnn, content_img, style_img)
optimizer = get_input_optimizer(input_img)
run = [0]
while run[0] <= num_steps:
def closure():
# 随着输入的盖面,input可能已经不是0到1,这时候把它强制钳位
input_img.data.clamp_(0, 1)
optimizer.zero_grad()
model(input_img)
style_score = 0
content_score = 0

for sl in style_losses:
style_score += sl.loss
for cl in content_losses:
content_score += cl.loss

style_score *= style_weight
content_score *= content_weight

loss = style_score + content_score
loss.backward()

run[0] += 1
if run[0] % 100 == 0:
print("run {}:".format(run))
print('Style Loss : {:4f} Content Loss: {:4f}'.format(
style_score.item(), content_score.item()))
print()

return style_score + content_score

optimizer.step(closure)

input_img.data.clamp_(0, 1)
return input_img

训练

调训练函数进行训练

1
2
3
4
5
6
7
8
9
output = run_style_transfer(vgg19, normalization_mean, normalization_std,
content_img, style_img, input_img, num_steps=1500)

plt.figure()
imshow(output, title='Output Image')

# sphinx_gallery_thumbnail_number = 4
plt.ioff()
plt.show()
run [100]:
Style Loss : 3233.703857 Content Loss: 49.690994

run [200]:
Style Loss : 744.537781 Content Loss: 54.276138

run [300]:
Style Loss : 243.069855 Content Loss: 56.419792

run [400]:
Style Loss : 130.609573 Content Loss: 57.134781

run [500]:
Style Loss : 47.136871 Content Loss: 56.989006

run [600]:
Style Loss : 28.142462 Content Loss: 56.046928

run [700]:
Style Loss : 13.646354 Content Loss: 55.086681

run [800]:
Style Loss : 8.587596 Content Loss: 53.759434

run [900]:
Style Loss : 6.053142 Content Loss: 52.145359

run [1000]:
Style Loss : 4.736339 Content Loss: 50.745125

run [1100]:
Style Loss : 4.154952 Content Loss: 49.411102

run [1200]:
Style Loss : 808.660522 Content Loss: 47.144882

run [1300]:
Style Loss : 3.039961 Content Loss: 47.177006

run [1400]:
Style Loss : 2.938888 Content Loss: 46.098812

run [1500]:
Style Loss : 2.799488 Content Loss: 45.275436

png

本文参考NEURAL TRANSFER USING PYTORCH

持续技术分享,您的支持将鼓励我继续创作!