diff --git a/codes/data/__init__.py b/codes/data/__init__.py index e4aa847c..aac80e77 100644 --- a/codes/data/__init__.py +++ b/codes/data/__init__.py @@ -36,13 +36,8 @@ def create_dataset(dataset_opt): # datasets for image corruption elif mode == 'downsample': from data.Downsample_dataset import DownsampleDataset as D - # datasets for video restoration - elif mode == 'REDS': - from data.REDS_dataset import REDSDataset as D - elif mode == 'Vimeo90K': - from data.Vimeo90K_dataset import Vimeo90KDataset as D - elif mode == 'video_test': - from data.video_test_dataset import VideoTestDataset as D + elif mode == 'fullimage': + from data.full_image_dataset import FullImageDataset as D else: raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) dataset = D(dataset_opt) diff --git a/codes/data/full_image_dataset.py b/codes/data/full_image_dataset.py index b4b2e03e..f0c99ebf 100644 --- a/codes/data/full_image_dataset.py +++ b/codes/data/full_image_dataset.py @@ -58,7 +58,7 @@ class FullImageDataset(data.Dataset): h, w, _ = image.shape if h == w: return image - offset = min(np.random.normal(scale=.3), 1.0) + offset = max(min(np.random.normal(scale=.3), 1.0), -1.0) if h > w: diff = h - w center = diff // 2 @@ -75,6 +75,14 @@ class FullImageDataset(data.Dataset): margin_center = margin_sz // 2 return min(max(int(min(np.random.normal(scale=dev), 1.0) * margin_sz + margin_center), 0), margin_sz) + def resize_point(self, point, orig_dim, new_dim): + oh, ow = orig_dim + nh, nw = new_dim + dh, dw = float(nh) / float(oh), float(nw) / float(ow) + point[0] = int(dh * float(point[0])) + point[1] = int(dw * float(point[1])) + return point + # - Randomly extracts a square from image and resizes it to opt['target_size']. # - Fills a mask with zeros, then places 1's where the square was extracted from. Resizes this mask and the source # image to the target_size and returns that too. @@ -83,11 +91,10 @@ class FullImageDataset(data.Dataset): # half-normal distribution, biasing towards the target_size. # - A biased normal distribution is also used to bias the tile selection towards the center of the source image. 
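For intuition, the resize_point helper above just rescales a [row, col] coordinate by the ratio of the new to the old image dimensions; a quick worked sketch with illustrative numbers (not part of the patch):

    import torch
    point = torch.tensor([256, 512], dtype=torch.long)  # center of a 512x1024 image
    # dh, dw = 128/512, 128/1024
    # resize_point(point, (512, 1024), (128, 128)) -> tensor([64, 64])

pull_tile, below, relies on this to keep the extracted tile's center expressed in the coordinates of the resized full image.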
     def pull_tile(self, image):
-        target_sz = self.opt['target_size']
+        target_sz = self.opt['min_tile_size']
         h, w, _ = image.shape
         possible_sizes_above_target = h - target_sz
         square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=.1)), 1.0))
-        print("Square size: %i" % (square_size,))
         # Pick the left,top coords to draw the patch from
         left = self.pick_along_range(w, square_size, .3)
         top = self.pick_along_range(w, square_size, .3)
@@ -95,12 +102,14 @@ class FullImageDataset(data.Dataset):
         mask = np.zeros((h, w, 1), dtype=np.float)
         mask[top:top+square_size, left:left+square_size] = 1
         patch = image[top:top+square_size, left:left+square_size, :]
+        center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)
         patch = cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
         image = cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
         mask = cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
+        center = self.resize_point(center, (h, w), image.shape[:2])

-        return patch, image, mask
+        return patch, image, mask, center

     def augment_tile(self, img_GT, img_LQ, strength=1):
         scale = self.opt['scale']
@@ -145,16 +154,22 @@ class FullImageDataset(data.Dataset):
         return img_LQ

     def __getitem__(self, index):
-        GT_path, LQ_path = None, None
         scale = self.opt['scale']
-        GT_size = self.opt['target_size']

         # get full size image
         full_path = self.paths_GT[index % len(self.paths_GT)]
+        LQ_path = full_path
         img_full = util.read_img(None, full_path, None)
-        img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0]
-        img_full = self.get_square_image(img_full)
-        img_GT, gt_fullsize_ref, gt_mask = self.pull_tile(img_full)
+        img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0]
+        if self.opt['phase'] == 'train':
+            img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0]
+            img_full = self.get_square_image(img_full)
+            img_GT, gt_fullsize_ref, gt_mask, gt_center = self.pull_tile(img_full)
+        else:
+            img_GT, gt_fullsize_ref = img_full, img_full
+            gt_mask = np.ones(img_full.shape[:2])
+            gt_center = torch.tensor([img_full.shape[0] // 2, img_full.shape[1] // 2], dtype=torch.long)
+        orig_gt_dim = gt_fullsize_ref.shape[:2]

         # get LQ image
         if self.paths_LQ:
@@ -162,11 +177,16 @@ class FullImageDataset(data.Dataset):
             img_lq_full = util.read_img(None, LQ_path, None)
             img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0]
             img_lq_full = self.get_square_image(img_lq_full)
-            img_LQ, lq_fullsize_ref, lq_mask = self.pull_tile(img_lq_full)
+            img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full)
         else:  # down-sampling on-the-fly
             # randomly scale during training
             if self.opt['phase'] == 'train':
+                GT_size = self.opt['target_size']
                 random_scale = random.choice(self.random_scale_list)
+                if len(img_GT.shape) == 2:
+                    print("ERROR: img_GT is missing its channel dimension:")
+                    print(img_GT.shape)
+                    print(full_path)
                 H_s, W_s, _ = img_GT.shape

                 def _mod(n, random_scale, scale, thres):
@@ -184,23 +204,34 @@ class FullImageDataset(data.Dataset):
             # using matlab imresize
             img_LQ = util.imresize_np(img_GT, 1 / scale, True)
+            lq_fullsize_ref = util.imresize_np(gt_fullsize_ref, 1 / scale, True)
             if img_LQ.ndim == 2:
                 img_LQ = np.expand_dims(img_LQ, axis=2)
-            lq_fullsize_ref, lq_mask = gt_fullsize_ref, gt_mask
+            lq_mask, lq_center = gt_mask, self.resize_point(gt_center.clone(), orig_gt_dim, lq_fullsize_ref.shape[:2])
+        orig_lq_dim = lq_fullsize_ref.shape[:2]

-        # Enforce force_resize constraints.
+        # Enforce force_resize constraints via clipping.
         h, w, _ = img_LQ.shape
         if h % self.force_multiple != 0 or w % self.force_multiple != 0:
-            h, w = (w - w % self.force_multiple), (h - h % self.force_multiple)
-            img_LQ = cv2.resize(img_LQ, (h, w))
+            h, w = (h - h % self.force_multiple), (w - w % self.force_multiple)
+            img_LQ = img_LQ[:h, :w, :]
+            lq_fullsize_ref = lq_fullsize_ref[:h, :w, :]
             h *= scale
             w *= scale
-            img_GT = cv2.resize(img_GT, (h, w))
+            img_GT = img_GT[:h, :w]
+            gt_fullsize_ref = gt_fullsize_ref[:h, :w, :]

         if self.opt['phase'] == 'train':
             img_GT, img_LQ = self.augment_tile(img_GT, img_LQ)
             gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2)
-        lq_mask = cv2.resize(lq_mask, img_LQ.shape[0:2], interpolation=cv2.INTER_LINEAR)
+
+        # Scale masks.
+        lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
+        gt_mask = cv2.resize(gt_mask, (gt_fullsize_ref.shape[1], gt_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
+
+        # Scale center coords
+        lq_center = self.resize_point(lq_center, orig_lq_dim, lq_fullsize_ref.shape[:2])
+        gt_center = self.resize_point(gt_center, orig_gt_dim, gt_fullsize_ref.shape[:2])

         # BGR to RGB, HWC to CHW, numpy to tensor
         if img_GT.shape[2] == 3:
@@ -210,8 +241,9 @@ class FullImageDataset(data.Dataset):
             gt_fullsize_ref = cv2.cvtColor(gt_fullsize_ref, cv2.COLOR_BGR2RGB)

         # LQ needs to go to a PIL image to perform the compression-artifact transformation.
-        img_LQ = self.pil_augment(img_LQ)
-        lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)
+        if self.opt['phase'] == 'train':
+            img_LQ = self.pil_augment(img_LQ)
+            lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)

         img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
         gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()
@@ -226,19 +258,19 @@ class FullImageDataset(data.Dataset):
            lq_fullsize_ref += lq_noise

         # Apply the masks to the full images.
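To make the new clipping path above concrete, here is the arithmetic with illustrative numbers (force_multiple=32, scale=2, an LQ image of 107x94); unlike the old cv2.resize path, slicing never resamples pixels:

    import numpy as np
    force_multiple, scale = 32, 2
    img_LQ = np.zeros((107, 94, 3))
    h, w = 107 - 107 % force_multiple, 94 - 94 % force_multiple  # 96, 64
    img_LQ = img_LQ[:h, :w, :]  # cropped, never resampled
    # the GT-side tensors are sliced to (h * scale, w * scale) the same way

The cat calls just below then append each resized mask as a fourth channel on the full-size reference images.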
- lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0) gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0) + lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0) - if LQ_path is None: - LQ_path = GT_path d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref, - 'LQ_path': LQ_path, 'GT_path': GT_path} + 'lq_center': lq_center, 'gt_center': gt_center, + 'LQ_path': LQ_path, 'GT_path': full_path} return d def __len__(self): return len(self.paths_GT) if __name__ == '__main__': + ''' opt = { 'name': 'amalgam', 'dataroot_GT': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\images'], @@ -249,19 +281,32 @@ if __name__ == '__main__': 'use_rot': True, 'lq_noise': 5, 'target_size': 128, + 'min_tile_size': 256, 'scale': 2, 'phase': 'train' } + ''' + opt = { + 'name': 'amalgam', + 'dataroot_GT': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\images'], + 'dataroot_GT_weights': [1], + 'force_multiple': 32, + 'scale': 2, + 'phase': 'test' + } + ds = FullImageDataset(opt) import os os.makedirs("debug", exist_ok=True) - for i in range(1000): + for i in range(300, len(ds)): + print(i) o = ds[i] for k, v in o.items(): if 'path' not in k: - if 'full' in k: - masked = v[:3, :, :] * v[3] - torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k)) - v = v[:3, :, :] - import torchvision - torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k)) \ No newline at end of file + #if 'full' in k: + #masked = v[:3, :, :] * v[3] + #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k)) + #v = v[:3, :, :] + #import torchvision + #torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k)) + pass \ No newline at end of file diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index aca4bfd4..501d6f67 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -127,6 +127,11 @@ class ExtensibleTrainer(BaseModel): input_ref = data['ref'] if 'ref' in data else data['GT'] self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)] + self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref} + for k, v in data.items(): + if k not in ['LQ', 'ref', 'GT'] and isinstance(v, torch.Tensor): + self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.mega_batch_factor, dim=0)] + def optimize_parameters(self, step): self.env['step'] = step @@ -136,7 +141,7 @@ class ExtensibleTrainer(BaseModel): net.update_for_step(step, os.path.join(self.opt['path']['models'], "..")) # Iterate through the steps, performing them one at a time. - state = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref} + state = self.dstate for step_num, s in enumerate(self.steps): # Only set requires_grad=True for the network being trained. nets_to_train = s.get_networks_trained() @@ -195,7 +200,7 @@ class ExtensibleTrainer(BaseModel): with torch.no_grad(): # Iterate through the steps, performing them one at a time. 
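The dstate addition above generalizes the trainer's microbatching: every tensor in the input dict is split along the batch dimension into mega_batch_factor chunks, so new dataset keys (like gt_center) flow through without further trainer changes. A minimal sketch of the pattern, with illustrative shapes:

    import torch
    mega_batch_factor = 2
    data = {'LQ': torch.zeros(8, 3, 32, 32), 'gt_center': torch.zeros(8, 2, dtype=torch.long)}
    dstate = {k: list(torch.chunk(v, chunks=mega_batch_factor, dim=0)) for k, v in data.items()}
    assert dstate['LQ'][0].shape[0] == 4  # each training step consumes one chunk at a time

With that in place, the eval path below can read its whole state from the same self.dstate.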
- state = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref} + state = self.dstate for step_num, s in enumerate(self.steps): ns = s.do_forward_backward(state, 0, step_num, backward=False) for k, v in ns.items(): diff --git a/codes/models/SPSR_model.py b/codes/models/SPSR_model.py deleted file mode 100644 index 7238349e..00000000 --- a/codes/models/SPSR_model.py +++ /dev/null @@ -1,458 +0,0 @@ -import os -import logging -from collections import OrderedDict - -import torch -import torch.nn as nn -from torch.optim import lr_scheduler -from apex import amp - -import models.networks as networks -from .base_model import BaseModel -from models.loss import GANLoss -import torchvision.utils as utils -from .archs.SPSR_arch import ImageGradient, ImageGradientNoPadding - -logger = logging.getLogger('base') - -class SPSRModel(BaseModel): - def __init__(self, opt): - super(SPSRModel, self).__init__(opt) - train_opt = opt['train'] - - # define networks and load pretrained models - self.netG = networks.define_G(opt).to(self.device) # G - if self.is_train: - self.netD = networks.define_D(opt).to(self.device) # D - self.netD_grad = networks.define_D(opt).to(self.device) # D_grad - self.netG.train() - self.netD.train() - self.netD_grad.train() - self.mega_batch_factor = 1 - self.load() # load G and D if needed - - # define losses, optimizer and scheduler - if self.is_train: - self.mega_batch_factor = train_opt['mega_batch_factor'] - - # G pixel loss - if train_opt['pixel_weight'] > 0: - l_pix_type = train_opt['pixel_criterion'] - if l_pix_type == 'l1': - self.cri_pix = nn.L1Loss().to(self.device) - elif l_pix_type == 'l2': - self.cri_pix = nn.MSELoss().to(self.device) - else: - raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) - self.l_pix_w = train_opt['pixel_weight'] - else: - logger.info('Remove pixel loss.') - self.cri_pix = None - - # G feature loss - if train_opt['feature_weight'] > 0: - l_fea_type = train_opt['feature_criterion'] - if l_fea_type == 'l1': - self.cri_fea = nn.L1Loss().to(self.device) - elif l_fea_type == 'l2': - self.cri_fea = nn.MSELoss().to(self.device) - else: - raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) - self.l_fea_w = train_opt['feature_weight'] - else: - logger.info('Remove feature loss.') - self.cri_fea = None - if self.cri_fea: # load VGG perceptual loss - self.netF = networks.define_F(use_bn=False).to(self.device) - - # GD gan loss - self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) - self.l_gan_w = train_opt['gan_weight'] - # D_update_ratio and D_init_iters are for WGAN - self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 - self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 - # Branch_init_iters - self.branch_pretrain = train_opt['branch_pretrain'] if train_opt['branch_pretrain'] else 0 - self.branch_init_iters = train_opt['branch_init_iters'] if train_opt['branch_init_iters'] else 1 - - # gradient_pixel_loss - if train_opt['gradient_pixel_weight'] > 0: - self.cri_pix_grad = nn.MSELoss().to(self.device) - self.l_pix_grad_w = train_opt['gradient_pixel_weight'] - else: - self.cri_pix_grad = None - - # gradient_gan_loss - if train_opt['gradient_gan_weight'] > 0: - self.cri_grad_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) - self.l_gan_grad_w = train_opt['gradient_gan_weight'] - else: - self.cri_grad_gan = None - - # G_grad pixel loss - if train_opt['pixel_branch_weight'] > 0: - l_pix_type = 
train_opt['pixel_branch_criterion'] - if l_pix_type == 'l1': - self.cri_pix_branch = nn.L1Loss().to(self.device) - elif l_pix_type == 'l2': - self.cri_pix_branch = nn.MSELoss().to(self.device) - else: - raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) - self.l_pix_branch_w = train_opt['pixel_branch_weight'] - else: - logger.info('Remove G_grad pixel loss.') - self.cri_pix_branch = None - - # optimizers - # G - wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 - - optim_params = [] - for k, v in self.netG.named_parameters(): # optimize part of the model - - if v.requires_grad: - optim_params.append(v) - else: - logger.warning('Params [{:s}] will not optimize.'.format(k)) - self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \ - weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) - self.optimizers.append(self.optimizer_G) - - # D - wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 - self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \ - weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) - - self.optimizers.append(self.optimizer_D) - - # D_grad - wd_D_grad = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 - self.optimizer_D_grad = torch.optim.Adam(self.netD_grad.parameters(), lr=train_opt['lr_D'], \ - weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) - - self.optimizers.append(self.optimizer_D_grad) - - # AMP - [self.netG, self.netD, self.netD_grad], [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad] = \ - amp.initialize([self.netG, self.netD, self.netD_grad], - [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad], - opt_level=self.amp_level, num_losses=3) - - # schedulers - if train_opt['lr_scheme'] == 'MultiStepLR': - for optimizer in self.optimizers: - self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ - train_opt['lr_steps'], train_opt['lr_gamma'])) - else: - raise NotImplementedError('MultiStepLR learning rate scheme is enough.') - - self.log_dict = OrderedDict() - self.get_grad = ImageGradient() - self.get_grad_nopadding = ImageGradientNoPadding() - - def feed_data(self, data, need_HR=True): - # LR - self.var_L = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0)] - - if need_HR: # train or val - self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)] - input_ref = data['ref'] if 'ref' in data else data['GT'] - self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref.to(self.device), chunks=self.mega_batch_factor, dim=0)] - - - - def optimize_parameters(self, step): - # Some generators have variants depending on the current step. 
- if hasattr(self.netG.module, "update_for_step"): - self.netG.module.update_for_step(step, os.path.join(self.opt['path']['models'], "..")) - if hasattr(self.netD.module, "update_for_step"): - self.netD.module.update_for_step(step, os.path.join(self.opt['path']['models'], "..")) - - # G - for p in self.netD.parameters(): - p.requires_grad = False - - for p in self.netD_grad.parameters(): - p.requires_grad = False - - if(self.branch_pretrain): - if(step < self.branch_init_iters): - for k,v in self.netG.named_parameters(): - if 'f_' not in k : - v.requires_grad=False - else: - for k,v in self.netG.named_parameters(): - if 'f_' not in k : - v.requires_grad=True - - self.optimizer_G.zero_grad() - - self.fake_H_branch = [] - self.fake_H = [] - self.grad_LR = [] - for var_L, var_H, var_ref in zip(self.var_L, self.var_H, self.var_ref): - fake_H_branch, fake_H, grad_LR = self.netG(var_L) - self.fake_H_branch.append(fake_H_branch.detach()) - self.fake_H.append(fake_H.detach()) - self.grad_LR.append(grad_LR.detach()) - - fake_H_grad = self.get_grad(fake_H) - var_H_grad = self.get_grad(var_H) - var_ref_grad = self.get_grad(var_ref) - var_H_grad_nopadding = self.get_grad_nopadding(var_H) - - l_g_total = 0 - if step % self.D_update_ratio == 0 and step > self.D_init_iters: - if self.cri_pix: # pixel loss - l_g_pix = self.l_pix_w * self.cri_pix(fake_H, var_H) - l_g_total += l_g_pix - if self.cri_fea: # feature loss - real_fea = self.netF(var_H).detach() - fake_fea = self.netF(fake_H) - l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) - l_g_total += l_g_fea - - if self.cri_pix_grad: #gradient pixel loss - l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad) - l_g_total += l_g_pix_grad - - if self.cri_pix_branch: #branch pixel loss - l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(fake_H_branch, var_H_grad_nopadding) - l_g_total += l_g_pix_grad_branch - - if self.l_gan_w > 0: - # G gan + cls loss - pred_g_fake = self.netD(fake_H) - pred_d_real = self.netD(var_ref).detach() - - l_g_gan = self.l_gan_w * (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + - self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 - l_g_total += l_g_gan - - if self.cri_grad_gan: - # grad G gan + cls loss - pred_g_fake_grad = self.netD_grad(fake_H_grad) - pred_d_real_grad = self.netD_grad(var_ref_grad).detach() - - l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) + - self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) /2 - l_g_total += l_g_gan_grad - - l_g_total /= self.mega_batch_factor - with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled: - l_g_total_scaled.backward() - - if step % self.D_update_ratio == 0 and step > self.D_init_iters: - self.optimizer_G.step() - - - if self.l_gan_w > 0: - # D - for p in self.netD.parameters(): - p.requires_grad = True - - self.optimizer_D.zero_grad() - for var_ref, fake_H in zip(self.var_ref, self.fake_H): - pred_d_real = self.netD(var_ref) - pred_d_fake = self.netD(fake_H) # detach to avoid BP to G - l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) - l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) - - l_d_total = (l_d_real + l_d_fake) / 2 - - l_d_total /= self.mega_batch_factor - with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled: - l_d_total_scaled.backward() - - self.optimizer_D.step() - - if self.cri_grad_gan: - for p in 
self.netD_grad.parameters(): - p.requires_grad = True - - self.optimizer_D_grad.zero_grad() - for var_ref, fake_H in zip(self.var_ref, self.fake_H): - fake_H_grad = self.get_grad(fake_H) - var_ref_grad = self.get_grad(var_ref) - - pred_d_real_grad = self.netD_grad(var_ref_grad) - pred_d_fake_grad = self.netD_grad(fake_H_grad.detach()) # detach to avoid BP to G - - l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True) - l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False) - - l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2 - l_d_total_grad /= self.mega_batch_factor - - with amp.scale_loss(l_d_total_grad, self.optimizer_D_grad, loss_id=2) as l_d_total_grad_scaled: - l_d_total_grad_scaled.backward() - - self.optimizer_D_grad.step() - - # Log sample images from first microbatch. - if step % 50 == 0: - sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp") - os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True) - os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True) - os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True) - os.makedirs(os.path.join(sample_save_path, "gen_grad"), exist_ok=True) - # fed_LQ is not chunked. - utils.save_image(self.var_H[0].cpu(), os.path.join(sample_save_path, "hr", "%05i.png" % (step,))) - utils.save_image(self.var_L[0].cpu(), os.path.join(sample_save_path, "lr", "%05i.png" % (step,))) - utils.save_image(self.fake_H[0].cpu(), os.path.join(sample_save_path, "gen", "%05i.png" % (step,))) - utils.save_image(self.grad_LR[0].cpu(), os.path.join(sample_save_path, "gen_grad", "%05i.png" % (step,))) - - - # set log - if step % self.D_update_ratio == 0 and step > self.D_init_iters: - # G - if self.cri_pix: - self.add_log_entry('l_g_pix', l_g_pix.item()) - if self.cri_fea: - self.add_log_entry('l_g_fea', l_g_fea.item()) - if self.l_gan_w > 0: - self.add_log_entry('l_g_gan', l_g_gan.item()) - - if self.cri_pix_branch: #branch pixel loss - self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.item()) - - if self.l_gan_w > 0: - self.add_log_entry('l_d_real', l_d_real.item()) - self.add_log_entry('l_d_fake', l_d_fake.item()) - self.add_log_entry('l_d_real_grad', l_d_real_grad.item()) - self.add_log_entry('l_d_fake_grad', l_d_fake_grad.item()) - self.add_log_entry('D_real', torch.mean(pred_d_real.detach())) - self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) - self.add_log_entry('D_real_grad', torch.mean(pred_d_real_grad.detach())) - self.add_log_entry('D_fake_grad', torch.mean(pred_d_fake_grad.detach())) - - # Allows the log to serve as an easy-to-use rotating buffer. - def add_log_entry(self, key, value): - key_it = "%s_it" % (key,) - log_rotating_buffer_size = 50 - if key not in self.log_dict.keys(): - self.log_dict[key] = [] - self.log_dict[key_it] = 0 - if len(self.log_dict[key]) < log_rotating_buffer_size: - self.log_dict[key].append(value) - else: - self.log_dict[key][self.log_dict[key_it] % log_rotating_buffer_size] = value - self.log_dict[key_it] += 1 - - def test(self): - self.netG.eval() - with torch.no_grad(): - self.fake_H_branch = [] - self.fake_H = [] - self.grad_LR = [] - for var_L in self.var_L: - fake_H_branch, fake_H, grad_LR = self.netG(var_L) - self.fake_H_branch.append(fake_H_branch) - self.fake_H.append(fake_H) - self.grad_LR.append(grad_LR) - - self.netG.train() - - # Fetches a summary of the log. 
- def get_current_log(self, step): - return_log = {} - for k in self.log_dict.keys(): - if not isinstance(self.log_dict[k], list): - continue - return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k]) - - # Some generators can do their own metric logging. - if hasattr(self.netG.module, "get_debug_values"): - return_log.update(self.netG.module.get_debug_values(step)) - if hasattr(self.netD.module, "get_debug_values"): - return_log.update(self.netD.module.get_debug_values(step)) - - return return_log - - def get_current_visuals(self, need_HR=True): - out_dict = OrderedDict() - out_dict['LR'] = self.var_L[0].float().cpu() - - out_dict['rlt'] = self.fake_H[0].float().cpu() - out_dict['SR_branch'] = self.fake_H_branch[0].float().cpu() - out_dict['LR_grad'] = self.grad_LR[0].float().cpu() - if need_HR: - out_dict['GT'] = self.var_H[0].float().cpu() - return out_dict - - def print_network(self): - # Generator - s, n = self.get_network_description(self.netG) - if isinstance(self.netG, nn.DataParallel): - net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, - self.netG.module.__class__.__name__) - else: - net_struc_str = '{}'.format(self.netG.__class__.__name__) - - logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) - logger.info(s) - if self.is_train: - # Disriminator - s, n = self.get_network_description(self.netD) - if isinstance(self.netD, nn.DataParallel): - net_struc_str = '{} - {}'.format(self.netD.__class__.__name__, - self.netD.module.__class__.__name__) - else: - net_struc_str = '{}'.format(self.netD.__class__.__name__) - - logger.info('Network D structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) - logger.info(s) - - if self.cri_fea: # F, Perceptual Network - s, n = self.get_network_description(self.netF) - if isinstance(self.netF, nn.DataParallel): - net_struc_str = '{} - {}'.format(self.netF.__class__.__name__, - self.netF.module.__class__.__name__) - else: - net_struc_str = '{}'.format(self.netF.__class__.__name__) - - logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) - logger.info(s) - - def load(self): - load_path_G = self.opt['path']['pretrain_model_G'] - if load_path_G is not None: - logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G)) - self.load_network(load_path_G, self.netG) - load_path_D = self.opt['path']['pretrain_model_D'] - if self.opt['is_train'] and load_path_D is not None: - logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D)) - self.load_network(load_path_D, self.netD) - load_path_D_grad = self.opt['path']['pretrain_model_D_grad'] - if self.opt['is_train'] and load_path_D_grad is not None: - logger.info('Loading pretrained model for D_grad [{:s}] ...'.format(load_path_D_grad)) - self.load_network(load_path_D_grad, self.netD_grad) - - def compute_fea_loss(self, real, fake): - if self.cri_fea is None: - return 0 - with torch.no_grad(): - real = real.unsqueeze(dim=0).to(self.device) - fake = fake.unsqueeze(dim=0).to(self.device) - real_fea = self.netF(real).detach() - fake_fea = self.netF(fake) - return self.cri_fea(fake_fea, real_fea).item() - - def force_restore_swapout(self): - pass - - def save(self, iter_step): - self.save_network(self.netG, 'G', iter_step) - self.save_network(self.netD, 'D', iter_step) - self.save_network(self.netD_grad, 'D_grad', iter_step) - - # override of load_network that allows loading partial params (like RRDB_PSNR_x4) - def load_network(self, load_path, network, strict=True): - if 
isinstance(network, nn.DataParallel): - network = network.module - pretrained_dict = torch.load(load_path) - model_dict = network.state_dict() - pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} - - model_dict.update(pretrained_dict) - network.load_state_dict(model_dict) \ No newline at end of file diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 223f8af1..ec0c616b 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -4,8 +4,8 @@ import torch.nn as nn import torch.nn.functional as F from models.archs import SPSR_util as B from .RRDBNet_arch import RRDB -from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock2 -from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConvBasisMultiplexer, ConfigurableSwitchComputer +from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock +from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConvBasisMultiplexer, ConfigurableSwitchComputer, ReferencingConvMultiplexer, ReferenceImageBranch from switched_conv_util import save_attention_to_image_rgb from switched_conv import compute_attention_specificity import functools @@ -351,3 +351,123 @@ class SwitchedSpsr(nn.Module): val["switch_%i_specificity" % (i,)] = means[i] val["switch_%i_histogram" % (i,)] = hists[i] return val + + +class SwitchedSpsrWithRef(nn.Module): + def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10): + super(SwitchedSpsrWithRef, self).__init__() + n_upscale = int(math.log(upscale, 2)) + + # switch options + transformation_filters = nf + switch_filters = nf + self.transformation_counts = xforms + self.reference_processor = ReferenceImageBranch(transformation_filters) + multiplx_fn = functools.partial(ReferencingConvMultiplexer, transformation_filters, switch_filters, self.transformation_counts) + pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1) + transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), + transformation_filters, kernel_size=3, depth=3, + weight_init_factor=.1) + + # Feature branch + self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) + self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn, transform_block=transform_fn, + attention_norm=True, + transform_count=self.transformation_counts, init_temp=init_temperature, + add_scalable_noise_to_transforms=True) + self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn, transform_block=transform_fn, + attention_norm=True, + transform_count=self.transformation_counts, init_temp=init_temperature, + add_scalable_noise_to_transforms=True) + self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False) + self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False) + + # Grad branch + self.get_g_nopadding = ImageGradientNoPadding() + self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False) + mplex_grad = functools.partial(ReferencingConvMultiplexer, nf * 2, nf * 2, self.transformation_counts // 2) + self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad, + pre_transform_block=pretransform_fn, transform_block=transform_fn, + 
attention_norm=True, + transform_count=self.transformation_counts // 2, init_temp=init_temperature, + add_scalable_noise_to_transforms=True) + # Upsampling + self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False) + self.grad_hr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False) + # Conv used to output grad branch shortcut. + self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False) + + # Conjoin branch. + # Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest. + transform_fn_cat = functools.partial(MultiConvBlock, transformation_filters * 2, int(transformation_filters * 1.5), + transformation_filters, kernel_size=3, depth=4, + weight_init_factor=.1) + pretransform_fn_cat = functools.partial(ConvGnLelu, transformation_filters * 2, transformation_filters * 2, norm=False, bias=False, weight_init_factor=.1) + self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn_cat, transform_block=transform_fn_cat, + attention_norm=True, + transform_count=self.transformation_counts, init_temp=init_temperature, + add_scalable_noise_to_transforms=True) + self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)]) + self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)]) + self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False) + self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False) + self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False) + self.switches = [self.sw1, self.sw2, self.sw_grad, self._branch_pretrain_sw] + self.attentions = None + self.init_temperature = init_temperature + self.final_temperature_step = 10000 + + def forward(self, x, ref, center_coord): + x_grad = self.get_g_nopadding(x) + ref = self.reference_processor(ref, center_coord) + x = self.model_fea_conv(x) + + x1, a1 = self.sw1(x, True, att_in=(x, ref)) + x2, a2 = self.sw2(x1, True, att_in=(x, ref)) + x_fea = self.feature_lr_conv(x2) + x_fea = self.feature_hr_conv2(x_fea) + + x_b_fea = self.b_fea_conv(x_grad) + x_grad, a3 = self.sw_grad(x_b_fea, att_in=(torch.cat([x1, x_b_fea], dim=1), ref), output_attention_weights=True) + x_grad = self.grad_lr_conv(x_grad) + x_grad = self.grad_hr_conv(x_grad) + x_out_branch = self.upsample_grad(x_grad) + x_out_branch = self.grad_branch_output_conv(x_out_branch) + + x__branch_pretrain_cat = torch.cat([x_grad, x_fea], dim=1) + x__branch_pretrain_cat, a4 = self._branch_pretrain_sw(x__branch_pretrain_cat, att_in=(x_fea, ref), identity=x_fea, output_attention_weights=True) + x_out = self.final_lr_conv(x__branch_pretrain_cat) + x_out = self.upsample(x_out) + x_out = self.final_hr_conv1(x_out) + x_out = self.final_hr_conv2(x_out) + + self.attentions = [a1, a2, a3, a4] + + return x_out_branch, x_out, x_grad + + def set_temperature(self, temp): + [sw.set_temperature(temp) for sw in self.switches] + + def update_for_step(self, step, experiments_path='.'): + if self.attentions: + temp = max(1, 1 + self.init_temperature * + (self.final_temperature_step - step) / self.final_temperature_step) + self.set_temperature(temp) + if step % 200 == 0: + output_path = os.path.join(experiments_path, 
"attention_maps", "a%i") + prefix = "attention_map_%i_%%i.png" % (step,) + [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] + + def get_debug_values(self, step): + temp = self.switches[0].switch.temperature + mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] + means = [i[0] for i in mean_hists] + hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] + val = {"switch_temperature": temp} + for i in range(len(means)): + val["switch_%i_specificity" % (i,)] = means[i] + val["switch_%i_histogram" % (i,)] = hists[i] + return val diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 1fe3ba3a..6fcdfab0 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -4,7 +4,7 @@ from switched_conv import BareConvSwitch, compute_attention_specificity, Attenti import torch.nn.functional as F import functools from collections import OrderedDict -from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2 +from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConjoinBlock from models.archs.RRDBNet_arch import ResidualDenseBlock_5C, RRDB from models.archs.spinenet_arch import SpineNet from switched_conv_util import save_attention_to_image_rgb @@ -92,6 +92,87 @@ class CachedBackboneWrapper: def get_forward_result(self): return self.cache +# torch.gather() which operates across 2d images. +def gather_2d(input, index): + b, c, h, w = input.shape + nodim = input.view(b, c, h * w) + ind_nd = index[:, 0]*w + index[:, 1] + ind_nd = ind_nd.unsqueeze(1) + ind_nd = ind_nd.repeat((1, c)) + ind_nd = ind_nd.unsqueeze(2) + result = torch.gather(nodim, dim=2, index=ind_nd) + return result.squeeze() + + +# Computes a linear latent by performing processing on the reference image and returning the filters of a single point, +# which should be centered on the image patch being processed. +# +# Output is base_filters * 8. +class ReferenceImageBranch(nn.Module): + def __init__(self, base_filters=64): + super(ReferenceImageBranch, self).__init__() + self.filter_conv = ConvGnSilu(4, base_filters, bias=True) + self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(base_filters * 2 ** i) for i in range(3)]) + reduction_filters = base_filters * 2 ** 3 + self.processing_blocks = nn.Sequential(OrderedDict([('block%i' % (i,), ConvGnSilu(reduction_filters, reduction_filters, bias=False)) for i in range(4)])) + + # center_point is a [b,2] long tensor describing the center point of where the patch was taken from the reference + # image. + def forward(self, x, center_point): + x = self.filter_conv(x) + reduction_identities = [] + for b in self.reduction_blocks: + reduction_identities.append(x) + x = b(x) + x = self.processing_blocks(x) + return gather_2d(x, center_point // 8) + + +# This is similar to ConvBasisMultiplexer, except that it takes a linear reference tensor as a second input to +# provide better results. 
It also has fixed parameterization in several places.
+class ReferencingConvMultiplexer(nn.Module):
+    def __init__(self, input_channels, base_filters, multiplexer_channels, use_gn=True):
+        super(ReferencingConvMultiplexer, self).__init__()
+        self.filter_conv = ConvGnSilu(input_channels, base_filters, bias=True)
+        self.ref_proc = nn.Linear(512, 512)
+        self.ref_red = nn.Linear(512, base_filters * 2)
+        self.feature_norm = torch.nn.InstanceNorm2d(base_filters)
+        self.style_norm = torch.nn.InstanceNorm1d(base_filters)
+
+        self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(base_filters * 2 ** i) for i in range(3)])
+        reduction_filters = base_filters * 2 ** 3
+        self.processing_blocks = nn.Sequential(OrderedDict([('block%i' % (i,), ConvGnSilu(reduction_filters, reduction_filters, bias=False)) for i in range(4)]))
+        self.expansion_blocks = nn.ModuleList([ExpansionBlock2(reduction_filters // (2 ** i)) for i in range(3)])
+
+        gap = base_filters - multiplexer_channels
+        cbl1_out = ((base_filters - (gap // 2)) // 4) * 4  # Must be multiples of 4 to use with group norm.
+        self.cbl1 = ConvGnSilu(base_filters, cbl1_out, norm=use_gn, bias=False, num_groups=4)
+        cbl2_out = ((base_filters - (3 * gap // 4)) // 4) * 4
+        self.cbl2 = ConvGnSilu(cbl1_out, cbl2_out, norm=use_gn, bias=False, num_groups=4)
+        self.cbl3 = ConvGnSilu(cbl2_out, multiplexer_channels, bias=True, norm=False)
+
+    def forward(self, x, ref):
+        # Start by fusing the reference vector and the input. Follows the AdaIN formula.
+        x = self.feature_norm(self.filter_conv(x))
+        ref = self.ref_proc(ref)
+        ref = self.ref_red(ref)
+        b, c = ref.shape
+        ref = self.style_norm(ref.view(b, 2, c // 2))
+        x = x * ref[:, 0, :].unsqueeze(dim=2).unsqueeze(dim=3).expand(x.shape) + ref[:, 1, :].unsqueeze(dim=2).unsqueeze(dim=3).expand(x.shape)
+
+        reduction_identities = []
+        for b in self.reduction_blocks:
+            reduction_identities.append(x)
+            x = b(x)
+        x = self.processing_blocks(x)
+        for i, b in enumerate(self.expansion_blocks):
+            x = b(x, reduction_identities[-i - 1])
+
+        x = self.cbl1(x)
+        x = self.cbl2(x)
+        x = self.cbl3(x)
+        return x
+

 class BackboneMultiplexer(nn.Module):
     def __init__(self, backbone: CachedBackboneWrapper, transform_count):
@@ -151,7 +232,10 @@ class ConfigurableSwitchComputer(nn.Module):
         if self.pre_transform:
             x = self.pre_transform(x)
         xformed = [t.forward(x) for t in self.transforms]
-        m = self.multiplexer(att_in)
+        if isinstance(att_in, tuple):
+            m = self.multiplexer(*att_in)
+        else:
+            m = self.multiplexer(att_in)

         outputs, attention = self.switch(xformed, m, True)
diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py
index 125a2222..3b1df730 100644
--- a/codes/models/archs/arch_util.py
+++ b/codes/models/archs/arch_util.py
@@ -415,32 +415,16 @@ class ExpansionBlock2(nn.Module):
         return self.reduce(x)


-# Similar to ExpansionBlock but does not upsample.
+# Similar to ExpansionBlock2 but does not upsample.
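For reference, the fusion at the top of ReferencingConvMultiplexer.forward is the AdaIN pattern: instance-normalize the content features, then scale and shift them per-channel with statistics predicted from the reference vector. A standalone sketch of that pattern (shapes illustrative, not the exact module):

    import torch
    b, c, h, w = 2, 64, 16, 16
    x = torch.randn(b, c, h, w)   # instance-normalized content features
    style = torch.randn(b, 2, c)  # predicted (scale, shift) pair per channel
    x = x * style[:, 0].reshape(b, c, 1, 1) + style[:, 1].reshape(b, c, 1, 1)

The ConjoinBlock below (per the comment above, the non-upsampling analog of ExpansionBlock2) is reworked next to accept a passthrough with its own channel count.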
class ConjoinBlock(nn.Module): - def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True): + def __init__(self, filters_in, filters_out=None, filters_pt=None, block=ConvGnSilu, norm=True): super(ConjoinBlock, self).__init__() if filters_out is None: filters_out = filters_in - self.decimate = block(filters_in*2, filters_out, kernel_size=1, bias=False, activation=False, norm=norm) - self.process = block(filters_out, filters_out, kernel_size=3, bias=False, activation=True, norm=norm) - - # input is the feature signal with shape (b, f, w, h) - # passthrough is the structure signal with shape (b, f/2, w*2, h*2) - # output is conjoined upsample with shape (b, f/2, w*2, h*2) - def forward(self, input, passthrough): - x = torch.cat([input, passthrough], dim=1) - x = self.decimate(x) - return self.process(x) - - -# Similar to ExpansionBlock2 but does not upsample. -class ConjoinBlock2(nn.Module): - def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True): - super(ConjoinBlock2, self).__init__() - if filters_out is None: - filters_out = filters_in - self.process = block(filters_in*2, filters_in*2, kernel_size=3, bias=False, activation=True, norm=norm) - self.decimate = block(filters_in*2, filters_out, kernel_size=1, bias=False, activation=False, norm=norm) + if filters_pt is None: + filters_pt = filters_in + self.process = block(filters_in + filters_pt, filters_in + filters_pt, kernel_size=3, bias=False, activation=True, norm=norm) + self.decimate = block(filters_in + filters_pt, filters_out, kernel_size=1, bias=False, activation=False, norm=norm) def forward(self, input, passthrough): x = torch.cat([input, passthrough], dim=1) diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index e8c3913a..6ba39264 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -159,12 +159,12 @@ class CrossCompareBlock(nn.Module): class CrossCompareDiscriminator(nn.Module): - def __init__(self, in_nc, nf, scale=4): + def __init__(self, in_nc, ref_channels, nf, scale=4): super(CrossCompareDiscriminator, self).__init__() assert scale == 2 or scale == 4 self.init_conv_hr = ConvGnLelu(in_nc, nf, stride=2, norm=False, bias=True, activation=True) - self.init_conv_lr = ConvGnLelu(in_nc, nf, stride=1, norm=False, bias=True, activation=True) + self.init_conv_lr = ConvGnLelu(ref_channels, nf, stride=1, norm=False, bias=True, activation=True) if scale == 4: strd_2 = 2 else: diff --git a/codes/models/networks.py b/codes/models/networks.py index 66656ad2..deb26144 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -119,6 +119,10 @@ def define_G(opt, net_key='network_G', scale=None): xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) + elif which_model == "spsr_switched_with_ref": + xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 + netG = spsr.SwitchedSpsrWithRef(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], + init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) # image corruption elif which_model == 'HighToLowResNet': @@ -159,7 +163,7 @@ def define_D_net(opt_net, img_sz=None): netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], 
nf=opt_net['nf'], initial_temp=opt_net['initial_temp'], final_temperature_step=opt_net['final_temperature_step']) elif which_model == "cross_compare_vgg128": - netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], nf=opt_net['nf'], scale=opt_net['scale']) + netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'], nf=opt_net['nf'], scale=opt_net['scale']) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index f86ff320..8effb6bf 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -41,7 +41,11 @@ class ImageGeneratorInjector(Injector): def forward(self, state): gen = self.env['generators'][self.opt['generator']] - results = gen(state[self.input]) + if isinstance(self.input, list): + params = [state[i] for i in self.input] + results = gen(*params) + else: + results = gen(state[self.input]) new_state = {} if isinstance(self.output, list): for i, k in enumerate(self.output): diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 9b040f5d..7b161a4d 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -78,7 +78,7 @@ class GeneratorGanLoss(ConfigurableLoss): netD = self.env['discriminators'][self.opt['discriminator']] if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']: if self.opt['gan_type'] == 'crossgan': - pred_g_fake = netD(state[self.opt['fake']], state['lq']) + pred_g_fake = netD(state[self.opt['fake']], state['lq_fullsize_ref']) else: pred_g_fake = netD(state[self.opt['fake']]) return self.criterion(pred_g_fake, True) @@ -101,9 +101,9 @@ class DiscriminatorGanLoss(ConfigurableLoss): self.metrics = [] if self.opt['gan_type'] == 'crossgan': - d_real = net(state[self.opt['real']], state['lq']) - d_fake = net(state[self.opt['fake']].detach(), state['lq']) - mismatched_lq = torch.roll(state['lq'], shifts=1, dims=0) + d_real = net(state[self.opt['real']], state['lq_fullsize_ref']) + d_fake = net(state[self.opt['fake']].detach(), state['lq_fullsize_ref']) + mismatched_lq = torch.roll(state['lq_fullsize_ref'], shifts=1, dims=0) d_mismatch_real = net(state[self.opt['real']], mismatched_lq) d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq) else: diff --git a/codes/train.py b/codes/train.py index 1b1444d2..5db47536 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/finetune_imgset_spsr_switched2_xlbatch_limfeat.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_fullimgref.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) diff --git a/codes/train2.py b/codes/train2.py index e3d107fc..8dccf326 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -161,7 +161,7 @@ def main(): current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: - current_step = 0 if 'start_step' not in opt.keys() else opt['start_step'] + current_step = -1 if 'start_step' not in opt.keys() else opt['start_step'] start_epoch = 0 #### training
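One note on the crossgan changes in losses.py above: torch.roll along the batch dimension pairs each real/fake image with a different sample's full-size LQ reference, giving the discriminator explicit mismatched-conditioning negatives. A toy illustration of the pairing (batch of 4, values illustrative):

    import torch
    lq_ref = torch.arange(4)  # stands in for state['lq_fullsize_ref']
    mismatched = torch.roll(lq_ref, shifts=1, dims=0)
    # lq_ref     -> tensor([0, 1, 2, 3])
    # mismatched -> tensor([3, 0, 1, 2]); image i is judged against reference i-1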