# 导入库 (Import libraries)

from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import numpy as np
from models import *

import torch
import torch.optim

from utils.denoising_utils import *
from utils.sr_utils import load_LR_HR_imgs_sr
# Enable cuDNN and its benchmark mode, which lets cuDNN auto-tune convolution
# algorithms for the input shapes it sees (the network input here keeps one
# fixed shape for the whole run).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Tensor type used for every network, loss and image tensor in this script.
# Fall back to CPU tensors when no CUDA device is visible (e.g. the GPU index
# selected through CUDA_VISIBLE_DEVICES does not exist) instead of crashing on
# the first `.type(dtype)` call.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

# Image size setting; -1 presumably means "keep the original resolution" --
# confirm against utils.sr_utils.
imsize = -1
# Enables the periodic reconstruction plots during optimization.
PLOT = True

# 加载图片 (Load images)

# Load the flash / no-flash photo pair; only the 'HR_pil' entry of each loader
# result is used. `enforse_div32` is the helper's own (misspelled) keyword;
# 'CROP' presumably crops each side to a multiple of 32 -- confirm in sr_utils.
imgs = load_LR_HR_imgs_sr('data/flash_no_flash/cave01_00_flash.jpg', imsize, 1, enforse_div32='CROP')
# Reuse the result above rather than decoding and cropping the same file twice.
img_flash = imgs['HR_pil']
img_flash_np = pil_to_np(img_flash)

img_noflash = load_LR_HR_imgs_sr('data/flash_no_flash/cave01_01_noflash.jpg', imsize, 1, enforse_div32='CROP')['HR_pil']
img_noflash_np = pil_to_np(img_noflash)

# Show both inputs side by side.
g = plot_image_grid([img_flash_np, img_noflash_np], 3, 12)

# [Screenshot placeholder: image-20220808164442611 (output of the cell above)]

# 设置参数 (Set parameters)

# --- Optimization settings ---------------------------------------------
OPTIMIZER = 'adam'
LR = 1e-1
num_iter = 601
# What get_params optimizes over; 'net' presumably selects the network weights.
OPT_OVER = 'net'

# Strength of the Gaussian perturbation added to the input each iteration;
# any value <= 0 turns the perturbation off.
reg_noise_std = 0.

# --- Network padding and display settings ------------------------------
pad = 'reflection'
show_every = 50
figsize = 6

# The flash photograph itself is fed to the network (instead of noise), so
# input_depth must equal its channel count -- 3 assumes an RGB image.
input_depth = 3
net_input = np_to_torch(img_flash_np).type(dtype)
# Skip-connection network: input_depth input channels, 3 as the second
# argument (presumably the output channel count), and five entries in each
# per-level channel / upsampling list.
net = skip(
    input_depth, 3,
    num_channels_down=[128] * 5,
    num_channels_up=[128] * 5,
    num_channels_skip=[4] * 5,
    upsample_mode=['nearest'] * 2 + ['bilinear'] * 3,
    need_sigmoid=True,
    need_bias=True,
    pad=pad,
).type(dtype)

# Mean-squared-error loss, cast to the working tensor type.
mse = torch.nn.MSELoss().type(dtype)

# Both photographs as tensors of the working type; the no-flash image is the
# target the optimization loop fits the network output to.
img_flash_var = np_to_torch(img_flash_np).type(dtype)
img_noflash_var = np_to_torch(img_noflash_np).type(dtype)

# 迭代 (Optimization loop)

# Untouched copy of the input (noise is always added to this, never
# accumulated) and a same-shaped buffer that `normal_()` refills in place.
net_input_saved = net_input.detach().clone()
noise = net_input.detach().clone()


i = 0


def closure():
    """Run one forward/backward pass and return the loss.

    Passed to ``optimize`` below, which presumably calls it once per
    iteration. When ``reg_noise_std > 0`` the network input is replaced by
    the saved input plus fresh Gaussian noise. The output is compared to the
    no-flash image with MSE, gradients are back-propagated, and every
    ``show_every`` iterations (when ``PLOT`` is set) the current
    reconstruction is plotted.

    NOTE(review): gradients are not zeroed here; this assumes ``optimize``
    zeroes them between steps -- confirm.

    Returns:
        The scalar loss tensor for this iteration.
    """
    global i, net_input

    if reg_noise_std > 0:
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out = net(net_input)

    total_loss = mse(out, img_noflash_var)
    total_loss.backward()

    # The trailing '\r' with end='' makes each progress line overwrite the last.
    print('Iteration %05d Loss %f' % (i, total_loss.item()), '\r', end='')
    if PLOT and i % show_every == 0:
        out_np = torch_to_np(out)
        plot_image_grid([np.clip(out_np, 0, 1)], factor=figsize, nrow=1)

    i += 1

    return total_loss


p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR, num_iter)

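# [Screenshot placeholder: image-20220808164725401 (progress output of the cell above)]

# [Screenshot placeholder: image-20220808164745556]

# [Screenshot placeholder: image-20220808164805716]
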
# Final reconstruction shown next to the no-flash target. This pass is
# inference only, so run it without building an autograd graph; the output
# values are the same.
with torch.no_grad():
    out_np = torch_to_np(net(net_input))
q = plot_image_grid([np.clip(out_np, 0, 1), img_noflash_np], factor=13)

# [Screenshot placeholder: image-20220808165045267 (output of the cell above)]