1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
   | import os.path import logging import torch
  from utils import utils_logger from utils import utils_image as util from models.network_rrdbnet import RRDBNet as net
 
  def main():
      """Run the configured blind super-resolution models over the test sets.

      For each name in ``model_names``:

      * build an RRDBNet generator with that model's scale factor ``sf``;
      * load weights from ``model_zoo/<model_name>.pth``;
      * run it on every path returned by ``util.get_image_paths`` for
        ``testsets/<testset_L>``;
      * when ``save_results`` is set, write each result as
        ``testsets/<testset_L>_results_x<sf>/<img_name>_<model_name>.png``.

      Progress is logged through the ``'blind_sr_log'`` logger, which is
      configured to write to ``blind_sr_log.log``.
      """
      utils_logger.logger_info('blind_sr_log', log_path='blind_sr_log.log')
      logger = logging.getLogger('blind_sr_log')

      testsets = 'testsets'       # root folder holding the test-set sub-folders
      testset_Ls = ['RealSRSet']  # sub-folders of `testsets` to process

      # Other checkpoints this script can evaluate by listing them here:
      # 'RRDB', 'ESRGAN', 'FSSR_DPED', 'FSSR_JPEG', 'RealSR_DPED', 'RealSR_JPEG'
      model_names = ['BSRGAN']

      save_results = True
      use_cuda = torch.cuda.is_available()
      device = torch.device('cuda' if use_cuda else 'cpu')

      for model_name in model_names:
          # Choose the scale factor per model. It must be recomputed every
          # iteration; otherwise a x2 model would leak sf=2 into later models.
          sf = 2 if model_name in ('BSRGANx2',) else 4
          model_path = os.path.join('model_zoo', model_name + '.pth')
          logger.info('{:>16s} : {:s}'.format('Model Name', model_name))

          if use_cuda:
              # current_device() raises when CUDA is unavailable, so only
              # query it when the model will actually run on a GPU.
              logger.info('{:>16s} : {:<d}'.format('GPU ID', torch.cuda.current_device()))
              torch.cuda.empty_cache()

          model = net(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=sf)
          # Load onto the CPU first so GPU-saved checkpoints also work on
          # CPU-only hosts; the model is moved to `device` afterwards.
          model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
          model.eval()
          for param in model.parameters():
              param.requires_grad = False
          model = model.to(device)
          if use_cuda:
              torch.cuda.empty_cache()

          for testset_L in testset_Ls:
              L_path = os.path.join(testsets, testset_L)
              E_path = os.path.join(testsets, testset_L + '_results_x' + str(sf))
              util.mkdir(E_path)

              logger.info('{:>16s} : {:s}'.format('Input Path', L_path))
              logger.info('{:>16s} : {:s}'.format('Output Path', E_path))

              for idx, img in enumerate(util.get_image_paths(L_path), start=1):
                  img_name, ext = os.path.splitext(os.path.basename(img))
                  logger.info('{:->4d} --> {:<s} --> x{:<d}--> {:<s}'.format(idx, model_name, sf, img_name + ext))

                  # NOTE(review): uint2tensor4 presumably yields a batched
                  # (1, C, H, W) tensor from the 3-channel image; confirm in utils_image.
                  img_L = util.imread_uint(img, n_channels=3)
                  img_L = util.uint2tensor4(img_L).to(device)

                  # Inference only: skip autograd bookkeeping.
                  with torch.no_grad():
                      img_E = model(img_L)
                  img_E = util.tensor2uint(img_E)

                  if save_results:
                      util.imsave(img_E, os.path.join(E_path, img_name + '_' + model_name + '.png'))
 
  if "__main__" == __name__:
      # Script entry point: run the batch super-resolution job only when this
      # file is executed directly, not when it is imported.
      main()
 
   |