# Hyper-parameters for the denoising run.
# Within this part of the script only INPUT (passed to get_noise) and pad
# (passed to the network builders) are read; the remaining values are
# presumably consumed by the optimisation loop later in the file.
INPUT, pad, OPT_OVER = 'noise', 'reflection', 'net'

reg_noise_std = 1.0 / 30.0
LR = 0.01

OPTIMIZER = 'adam'
show_every = 100
exp_weight = 0.99
# Per-image settings: iteration budget (num_iter), depth of the noise input
# (input_depth), plot size (figsize) and the generator network (net).
# Only the two images below are configured; anything else is rejected.
if fname == 'data/denoising/snail.jpg':
    num_iter = 2400
    input_depth = 3
    figsize = 5
    net = skip(
        input_depth, 3,
        num_channels_down=[8, 16, 32, 64, 128],
        num_channels_up=[8, 16, 32, 64, 128],
        num_channels_skip=[0, 0, 0, 4, 4],
        upsample_mode='bilinear',
        need_sigmoid=True,
        need_bias=True,
        pad=pad,
        act_fun='LeakyReLU',
    )
elif fname == 'data/denoising/F16_GT.png':
    num_iter = 3000
    input_depth = 32
    figsize = 4
    net = get_net(
        input_depth, 'skip', pad,
        skip_n33d=128,
        skip_n33u=128,
        skip_n11=4,
        num_scales=5,
        upsample_mode='bilinear',
    )
else:
    # An explicit raise (unlike `assert False`) still fires under `python -O`,
    # instead of letting the script continue with undefined or stale
    # input_depth / net values.
    raise ValueError('No denoising configuration for image %r' % (fname,))

# Both branches convert the network to the working tensor type; do it once.
net = net.type(dtype)

# Network input built by get_noise with input_depth channels; detached so it
# is not part of the autograd graph. NOTE(review): assumes img_pil is a PIL
# image, whose .size is (width, height), so this passes (height, width).
net_input = get_noise(input_depth, INPUT, (img_pil.size[1], img_pil.size[0])).type(dtype).detach()
# Total number of scalar parameters in the network: every tensor yielded by
# net.parameters() is counted, regardless of requires_grad.
s = sum(p.numel() for p in net.parameters())
print('Number of params: %d' % s)
# Mean-squared-error criterion, converted to the working tensor type.
mse = torch.nn.MSELoss()
mse = mse.type(dtype)

# Noisy image as a torch tensor of the same type (np_to_torch presumably
# converts the numpy array to a tensor; confirm against its definition).
img_noisy_torch = np_to_torch(img_noisy_np)
img_noisy_torch = img_noisy_torch.type(dtype)