导入库

from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import numpy as np
from models.resnet import ResNet
from models.unet import UNet
from models.skip import skip
import torch
import torch.optim

from utils.inpainting_utils import *

# Notebook-wide switches.
PLOT = True        # plot intermediate reconstructions every `show_every` iterations
imsize = -1        # passed to get_image(); presumably -1 means "keep the original size" — TODO confirm
dim_div_by = 64    # passed to crop_image(); presumably crops H and W down to multiples of this

# cuDNN on, with benchmark-mode algorithm autotuning (the network input size never changes).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Every tensor and module below is cast to this CUDA float type via `.type(dtype)`.
dtype = torch.cuda.FloatTensor

选择待修复的图像与掩码

# Pick exactly ONE (image, mask) pair below. The hyperparameter branch used
# later is chosen by matching the image file name.

## Fig 6
# img_path = 'data/inpainting/vase.png'
# mask_path = 'data/inpainting/vase_mask.png'

## Fig 8
# img_path = 'data/inpainting/library.png'
# mask_path = 'data/inpainting/library_mask.png'

## Fig 7 (top)
img_path = 'data/inpainting/kate.png'
mask_path = 'data/inpainting/kate_mask.png'

# Another text inpainting example
# img_path = 'data/inpainting/peppers.png'
# mask_path = 'data/inpainting/peppers_mask.png'

# Network architecture. It is only consulted for library.png; the other
# images always use a fixed 5-level skip network. Valid values:
#   'skip_depthN' -- skip network whose depth is the trailing digit N (1..6)
#   'UNET'        -- U-Net
#   'ResNet'      -- ResNet
NET_TYPE = 'skip_depth6'

加载图像与水印掩码

# Read the mask and the target image from disk; get_image() hands back both a
# PIL image and a numpy array for each path.
img_mask_pil, img_mask_np = get_image(mask_path, imsize)
img_pil, img_np = get_image(img_path, imsize)
# Crop image and mask with the same divisor so they stay aligned, then
# regenerate the numpy views from the cropped PIL images (the arrays returned
# by get_image() are replaced here).
img_pil = crop_image(img_pil, dim_div_by)
img_mask_pil = crop_image(img_mask_pil, dim_div_by)

img_np, img_mask_np = pil_to_np(img_pil), pil_to_np(img_mask_pil)
# Torch copy of the mask, cast to the working dtype.
img_mask_var = np_to_torch(img_mask_np).type(dtype)

# Sanity check: original image, mask, and the image with the mask applied.
# The trailing ';' keeps Jupyter from echoing the return value.
plot_image_grid(
    [img_np, img_mask_np, img_np * img_mask_np],
    3,
    11,
);

image-20220816161249350

设置参数

# Optimisation settings shared by every example.
OPTIMIZER = 'adam'  # optimizer name handed to optimize()
OPT_OVER = 'net'    # handed to get_params(); presumably "optimise the network's parameters" — TODO confirm
pad = 'reflection'  # padding mode for skip(); 'zero' is the noted alternative
# Per-image hyperparameters and network, selected by the image file name.
#   INPUT, input_depth  -> passed to get_noise() to build the fixed network input
#   LR, num_iter        -> passed to optimize()
#   reg_noise_std       -> std of the Gaussian jitter added to the input every iteration
#   param_noise         -> whether closure() perturbs the 4-D weights every iteration
#   show_every, figsize -> plotting cadence / size used inside closure()
if 'vase.png' in img_path:
    INPUT = 'meshgrid'
    input_depth = 2
    LR = 0.01
    num_iter = 5001
    param_noise = False
    show_every = 50
    figsize = 5
    reg_noise_std = 0.03

    net = skip(input_depth, img_np.shape[0],
               num_channels_down=[128] * 5,
               num_channels_up=[128] * 5,
               num_channels_skip=[0] * 5,
               upsample_mode='nearest', filter_skip_size=1, filter_size_up=3, filter_size_down=3,
               need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)

elif ('kate.png' in img_path) or ('peppers.png' in img_path):
    # Same params and net as in super-resolution and denoising
    INPUT = 'noise'
    input_depth = 32
    LR = 0.01
    num_iter = 6001
    param_noise = False
    show_every = 50
    figsize = 5
    reg_noise_std = 0.03

    net = skip(input_depth, img_np.shape[0],
               num_channels_down=[128] * 5,
               num_channels_up=[128] * 5,
               num_channels_skip=[128] * 5,
               filter_size_up=3, filter_size_down=3,
               upsample_mode='nearest', filter_skip_size=1,
               need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)

elif 'library.png' in img_path:
    INPUT = 'noise'
    input_depth = 1

    num_iter = 3001
    show_every = 50
    figsize = 8
    reg_noise_std = 0.00
    param_noise = True  # overridden to False for UNET / ResNet below

    if 'skip' in NET_TYPE:
        # 'skip_depthN': the trailing digit selects how many levels to keep.
        depth = int(NET_TYPE[-1])
        net = skip(input_depth, img_np.shape[0],
                   num_channels_down=[16, 32, 64, 128, 128, 128][:depth],
                   num_channels_up=[16, 32, 64, 128, 128, 128][:depth],
                   num_channels_skip=[0, 0, 0, 0, 0, 0][:depth],
                   filter_size_up=3, filter_size_down=5, filter_skip_size=1,
                   upsample_mode='nearest',  # downsample_mode='avg',
                   need1x1_up=False,
                   need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)

        LR = 0.01

    elif NET_TYPE == 'UNET':
        # Output channels follow the image (previously hard-coded to 3, which
        # broke non-RGB inputs and disagreed with every other branch).
        net = UNet(num_input_channels=input_depth, num_output_channels=img_np.shape[0],
                   feature_scale=8, more_layers=1,
                   concat_x=False, upsample_mode='deconv',
                   pad='zero', norm_layer=torch.nn.InstanceNorm2d, need_sigmoid=True, need_bias=True)

        LR = 0.001
        param_noise = False

    elif NET_TYPE == 'ResNet':
        net = ResNet(input_depth, img_np.shape[0], 8, 32, need_sigmoid=True, act_fun='LeakyReLU')

        LR = 0.001
        param_noise = False

    else:
        # Explicit error instead of `assert False` (stripped under `python -O`).
        raise ValueError("Unsupported NET_TYPE %r; expected 'skip_depthN', 'UNET' or 'ResNet'"
                         % (NET_TYPE,))
else:
    raise ValueError("No settings defined for image %r; expected vase.png, kate.png, "
                     "peppers.png or library.png" % (img_path,))

# UNet / ResNet are built on the default type, so cast every network here.
net = net.type(dtype)
net_input = get_noise(input_depth, INPUT, img_np.shape[1:]).type(dtype)
# Compute number of parameters. Tensor.numel() is the element count; the
# former np.prod(list(p.size())) returned a float 1.0 for 0-dim shapes.
s = sum(p.numel() for p in net.parameters())
print('Number of params: %d' % s)

# Loss
mse = torch.nn.MSELoss().type(dtype)

# Target image and mask as torch tensors on the working dtype; the loss in
# closure() multiplies both the output and the target by mask_var.
img_var = np_to_torch(img_np).type(dtype)
mask_var = np_to_torch(img_mask_np).type(dtype)

迭代优化

i = 0  # iteration counter, advanced by closure()


def closure():
    """Run one forward/backward pass and return the masked MSE loss.

    Handed to optimize() together with the optimizer settings; it reads the
    module-level globals net, net_input_saved, noise, mask_var, img_var and
    the per-image hyperparameters, and increments the global counter ``i``.
    Gradient zeroing and the optimizer step are presumably done by
    optimize() — TODO confirm.
    """
    global i

    if param_noise:
        # Perturb every 4-D weight (presumably conv kernels) in place. The
        # original `n = n + ...` only rebound the loop variable, so the network
        # was never touched and param_noise silently did nothing.
        with torch.no_grad():
            for weight in [x for x in net.parameters() if len(x.size()) == 4]:
                weight.add_(torch.randn_like(weight) * weight.std() / 50)

    net_input = net_input_saved
    if reg_noise_std > 0:
        # Jitter the fixed input with fresh Gaussian noise each iteration,
        # reusing the preallocated `noise` buffer.
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out = net(net_input)

    # Masked-out (zero) pixels contribute nothing to the loss.
    total_loss = mse(out * mask_var, img_var * mask_var)
    total_loss.backward()

    print('Iteration %05d Loss %f' % (i, total_loss.item()), '\r', end='')
    if PLOT and i % show_every == 0:
        out_np = torch_to_np(out)
        plot_image_grid([np.clip(out_np, 0, 1)], factor=figsize, nrow=1)

    i += 1

    return total_loss


# Frozen copy of the network input, plus a same-shaped buffer that closure()
# refills with Gaussian noise every iteration.
net_input_saved = net_input.detach().clone()
noise = net_input.detach().clone()

p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR, num_iter)
Starting optimization with ADAM
Iteration 00000 Loss 0.120503

image-20220816160857245

Iteration 00050    Loss 0.016066 

image-20220816161210277

Iteration 06000    Loss 0.000105

image-20220816161332492