Import libraries

from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import argparse
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
from models import *

import torch
import torch.optim

from utils.feature_inversion_utils import *
from utils.perceptual_loss.perceptual_loss import get_pretrained_net
from utils.common_utils import *

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
dtype = torch.cuda.FloatTensor

PLOT = True
fname = './data/feature_inversion/building.jpg'

pretrained_net = 'alexnet_caffe' # 'alexnet_caffe' # 'vgg19_caffe'
layers_to_use = 'fc6' # comma-separated string of layer names e.g. 'fc6,fc7'

Load the pretrained network

# cnn = get_pretrained_net(pretrained_net).type(dtype)
cnn = torch.load('./data/feature_inversion/alexnet-torch_py3.pth').type(dtype)

opt_content = {'layers': layers_to_use, 'what':'features'}

# Remove the layers we don't need
keys = [x for x in cnn._modules.keys()]
max_idx = max(keys.index(x) for x in opt_content['layers'].split(','))
for k in keys[max_idx+1:]:
    cnn._modules.pop(k)

print(cnn)

Here, the first line of the original source,

cnn = get_pretrained_net(pretrained_net).type(dtype)

calls the get_pretrained_net function, which is defined as follows:

def get_pretrained_net(name):
    """Loads pretrained network"""
    if name == 'alexnet_caffe':
        if not os.path.exists('alexnet-torch_py3.pth'):
            print('Downloading AlexNet')
            os.system('wget -O alexnet-torch_py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/77xSWvrDN0CiQtK/download')
            # os.system('wget -O alexnet-torch_py3.pth --no-check-certificate -nc https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth')
        return torch.load('alexnet-torch_py3.pth')
    elif name == 'vgg19_caffe':
        if not os.path.exists('vgg19-caffe-py3.pth'):
            print('Downloading VGG-19')
            os.system('wget -O vgg19-caffe-py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/HPcOFQTjXxbmp4X/download')

        vgg = get_vgg19_caffe()

        return vgg
    elif name == 'vgg16_caffe':
        if not os.path.exists('vgg16-caffe-py3.pth'):
            print('Downloading VGG-16')
            os.system('wget -O vgg16-caffe-py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/TUZ62HnPKWdxyLr/download')

        vgg = get_vgg16_caffe()

        return vgg
    elif name == 'vgg19_pytorch_modified':
        # os.system('wget -O data/feature_inversion/vgg19-caffe.pth --no-check-certificate -nc https://www.dropbox.com/s/xlbdo688dy4keyk/vgg19-caffe.pth?dl=1')

        model = VGGModified(vgg19(pretrained=False), 0.2)
        model.load_state_dict(torch.load('vgg_pytorch_modified.pkl')['state_dict'])

        return model
    else:
        assert False

In this function, line 6 and the similar download statements, for example

os.system('wget -O alexnet-torch_py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/77xSWvrDN0CiQtK/download')

cannot be executed successfully on my machine, so I abandon this function and instead load the model from a local file:

cnn = torch.load('./data/feature_inversion/alexnet-torch_py3.pth').type(dtype)
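As an aside, if the wget call fails (it is usually not available on Windows, which this machine appears to be running), the same checkpoint can also be fetched directly from Python with the standard library. The snippet below is only a sketch of that workaround; the URL is the one hard-coded in get_pretrained_net, and the target path matches the local load above.

import os
import ssl
import urllib.request

ckpt_path = './data/feature_inversion/alexnet-torch_py3.pth'
ckpt_url = 'https://box.skoltech.ru/index.php/s/77xSWvrDN0CiQtK/download'

if not os.path.exists(ckpt_path):
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
    # mirror wget's --no-check-certificate in case the server certificate is rejected
    ctx = ssl._create_unverified_context()
    with urllib.request.urlopen(ckpt_url, context=ctx) as r, open(ckpt_path, 'wb') as f:
        f.write(r.read())

cnn = torch.load(ckpt_path).type(dtype)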

With the checkpoint available locally, the cell finally runs correctly and prints the following:

E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.container.Sequential' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.normalization.CrossMapLRN2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.MaxPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.dropout.Dropout' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
Sequential(
(conv1): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4))
(relu1): ReLU()
(norm1): CrossMapLRN2d(5, alpha=0.0001, beta=0.75, k=1)
(pool1): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), dilation=1, ceil_mode=True)
(conv2): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=2)
(relu2): ReLU()
(norm2): CrossMapLRN2d(5, alpha=0.0001, beta=0.75, k=1)
(pool2): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), dilation=1, ceil_mode=True)
(conv3): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu3): ReLU()
(conv4): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
(relu4): ReLU()
(conv5): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
(relu5): ReLU()
(pool5): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), dilation=1, ceil_mode=True)
(torch_view): View()
(fc6): Linear(in_features=9216, out_features=4096, bias=True)
)

Load the image

# Target imsize 
imsize = 227 if pretrained_net == 'alexnet' else 224

# Something divisible by a power of two
imsize_net = 256

# VGG and Alexnet need input to be correctly normalized
preprocess, deprocess = get_preprocessor(imsize), get_deprocessor()


img_content_pil, img_content_np = get_image(fname, imsize)
img_content_prerocessed = preprocess(img_content_pil)[None,:].type(dtype)

img_content_pil

[Image: the loaded content image building.jpg]
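For reference, Caffe-trained AlexNet/VGG models typically expect BGR input in the 0-255 range with the ImageNet channel means subtracted. The sketch below shows what such a preprocessor commonly looks like; it is an assumption about get_preprocessor's behaviour, not the repo's exact code.

import torch
import torchvision.transforms as transforms

def caffe_preprocessor_sketch(imsize):
    # typical Caffe-style pipeline: resize/crop, to tensor, RGB->BGR,
    # scale to [0, 255], subtract per-channel ImageNet means
    mean = torch.tensor([103.939, 116.779, 123.68]).view(3, 1, 1)  # BGR means
    return transforms.Compose([
        transforms.Resize(imsize),
        transforms.CenterCrop(imsize),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x[[2, 1, 0], :, :] * 255 - mean),
    ])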

Set up the matcher and the network

matcher_content = get_matcher(cnn, opt_content)

matcher_content.mode = 'store'
cnn(img_content_prerocessed);
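get_matcher and its 'store'/'match' modes come from utils.feature_inversion_utils. The minimal sketch below is only my reading of the idea: a forward hook on the chosen layer caches the content image's activations in 'store' mode and produces an MSE loss against them in 'match' mode. The class name and details here are hypothetical, not the repo's implementation.

import torch
import torch.nn.functional as F

class FeatureMatcherSketch:
    """Hypothetical illustration of the store/match mechanism."""
    def __init__(self, layer):
        self.mode = 'store'      # 'store' caches the target, 'match' produces a loss
        self.target = None
        self.losses = {}
        layer.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        if self.mode == 'store':
            self.target = output.detach()
        elif self.mode == 'match':
            self.losses[module] = F.mse_loss(output, self.target)

# Usage sketch: attach to fc6, run the content image once, then switch modes.
# matcher = FeatureMatcherSketch(cnn._modules['fc6'])
# cnn(img_content_prerocessed)    # store the target features
# matcher.mode = 'match'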
INPUT = 'noise'
pad = 'zero' # 'reflection'
OPT_OVER = 'net' #'net,input'
OPTIMIZER = 'adam' # 'LBFGS'
LR = 0.001

num_iter = 3100

input_depth = 32
net_input = get_noise(input_depth, INPUT, imsize_net).type(dtype).detach()
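get_noise is defined in utils.common_utils; I assume it simply builds a fixed random tensor of shape [1, input_depth, imsize_net, imsize_net] that serves as the constant input to the generator. The sketch below illustrates that assumption (the scaling factor is hypothetical):

import torch

def get_noise_sketch(input_depth, spatial_size, scale=0.1):
    # a constant noise input for the generator: uniform noise, scaled down
    return torch.rand(1, input_depth, spatial_size, spatial_size) * scale

# net_input = get_noise_sketch(32, 256).type(dtype).detach()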
net = skip(input_depth, 3,
           num_channels_down = [16, 32, 64, 128, 128, 128],
           num_channels_up   = [16, 32, 64, 128, 128, 128],
           num_channels_skip = [4, 4, 4, 4, 4, 4],
           filter_size_down = [7, 7, 5, 5, 3, 3], filter_size_up = [7, 7, 5, 5, 3, 3],
           upsample_mode='nearest', downsample_mode='avg',
           need_sigmoid=True, pad=pad, act_fun='LeakyReLU').type(dtype)

# Compute number of parameters
s = sum(np.prod(list(p.size())) for p in net.parameters())
print ('Number of params: %d' % s)

Iterate

def closure():

    global i

    out = net(net_input)[:, :, :imsize, :imsize]

    cnn(vgg_preprocess_var(out))
    total_loss = sum(matcher_content.losses.values())
    total_loss.backward()

    print ('Iteration %05d Loss %.3f' % (i, total_loss.item()), '\r', end='')
    if PLOT and i % 200 == 0:
        out_np = np.clip(torch_to_np(out), 0, 1)
        plot_image_grid([out_np], 3, 3);

    i += 1

    return total_loss
i=0
matcher_content.mode = 'match'
p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR, num_iter)
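The optimize helper comes from utils.common_utils and is not reproduced here; for the 'adam' case I assume it boils down to a loop of roughly this shape (a sketch under that assumption, not the repo's exact code):

import torch

def optimize_sketch(params, closure, lr, num_iter):
    # build an Adam optimizer over the parameters and call the closure num_iter times
    optimizer = torch.optim.Adam(params, lr=lr)
    for _ in range(num_iter):
        optimizer.zero_grad()
        closure()        # forward pass, loss computation, and backward() happen inside
        optimizer.step()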

Because of the limited performance of my machine, the optimization could only run for a limited time, giving the result below with loss = 0.112.

[Image: feature inversion result after the partial run, loss = 0.112]