forked from mrq/DL-Art-School

Commit a65b07607c: Reference network
Parent: f224907603
@@ -36,13 +36,8 @@ def create_dataset(dataset_opt):
    # datasets for image corruption
    elif mode == 'downsample':
        from data.Downsample_dataset import DownsampleDataset as D
    # datasets for video restoration
    elif mode == 'REDS':
        from data.REDS_dataset import REDSDataset as D
    elif mode == 'Vimeo90K':
        from data.Vimeo90K_dataset import Vimeo90KDataset as D
    elif mode == 'video_test':
        from data.video_test_dataset import VideoTestDataset as D
    elif mode == 'fullimage':
        from data.full_image_dataset import FullImageDataset as D
    else:
        raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
    dataset = D(dataset_opt)

@@ -58,7 +58,7 @@ class FullImageDataset(data.Dataset):
        h, w, _ = image.shape
        if h == w:
            return image
-       offset = min(np.random.normal(scale=.3), 1.0)
+       offset = max(min(np.random.normal(scale=.3), 1.0), -1.0)
        if h > w:
            diff = h - w
            center = diff // 2
@@ -75,6 +75,14 @@ class FullImageDataset(data.Dataset):
        margin_center = margin_sz // 2
        return min(max(int(min(np.random.normal(scale=dev), 1.0) * margin_sz + margin_center), 0), margin_sz)

+   def resize_point(self, point, orig_dim, new_dim):
+       oh, ow = orig_dim
+       nh, nw = new_dim
+       dh, dw = float(nh) / float(oh), float(nw) / float(ow)
+       point[0] = int(dh * float(point[0]))
+       point[1] = int(dw * float(point[1]))
+       return point
+
    # - Randomly extracts a square from image and resizes it to opt['target_size'].
    # - Fills a mask with zeros, then places 1's where the square was extracted from. Resizes this mask and the source
    #   image to the target_size and returns that too.
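
Note: resize_point is just a proportional remap of a (y, x) coordinate between two resolutions. A minimal sketch of the arithmetic, with illustrative numbers not taken from the commit:

    # (y, x) in a 240x160 image, remapped into a 128x128 image
    point = [120, 80]
    oh, ow, nh, nw = 240, 160, 128, 128
    point = [int(float(nh) / oh * point[0]), int(float(nw) / ow * point[1])]
    assert point == [64, 64]  # the image center stays the center
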
@@ -83,11 +91,10 @@ class FullImageDataset(data.Dataset):
    # half-normal distribution, biasing towards the target_size.
    # - A biased normal distribution is also used to bias the tile selection towards the center of the source image.
    def pull_tile(self, image):
-       target_sz = self.opt['target_size']
+       target_sz = self.opt['min_tile_size']
        h, w, _ = image.shape
        possible_sizes_above_target = h - target_sz
        square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=.1)), 1.0))
-       print("Square size: %i" % (square_size,))
        # Pick the left,top coords to draw the patch from
        left = self.pick_along_range(w, square_size, .3)
        top = self.pick_along_range(w, square_size, .3)
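
Note: the tile size draw above is a half-normal (abs of a normal with scale .1) clamped to 1.0, so square_size hugs target_sz and only occasionally approaches the full image height. A hedged sketch of just that draw:

    import numpy as np
    target_sz, h = 256, 1024  # illustrative sizes
    possible = h - target_sz
    square_size = int(target_sz + possible * min(np.abs(np.random.normal(scale=.1)), 1.0))
    assert target_sz <= square_size <= h
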
@@ -95,12 +102,14 @@ class FullImageDataset(data.Dataset):
        mask = np.zeros((h, w, 1), dtype=np.float)
        mask[top:top+square_size, left:left+square_size] = 1
        patch = image[top:top+square_size, left:left+square_size, :]
+       center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)

        patch = cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
        image = cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
        mask = cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
+       center = self.resize_point(center, (h, w), image.shape[:2])

-       return patch, image, mask
+       return patch, image, mask, center

    def augment_tile(self, img_GT, img_LQ, strength=1):
        scale = self.opt['scale']
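
Note: the mask returned alongside the patch is all zeros except for ones over the extracted square, then it is resized together with the source image, so downstream consumers can see where the tile came from. Illustrative sketch (np.float32 substituted for the deprecated np.float alias):

    import cv2
    import numpy as np
    h = w = 512; top = left = 128; square = 256; target = 128
    mask = np.zeros((h, w, 1), dtype=np.float32)
    mask[top:top + square, left:left + square] = 1
    mask = cv2.resize(mask, (target, target), interpolation=cv2.INTER_LINEAR)  # edge values can go soft
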
@@ -145,16 +154,22 @@ class FullImageDataset(data.Dataset):
        return img_LQ

    def __getitem__(self, index):
        GT_path, LQ_path = None, None
        scale = self.opt['scale']
-       GT_size = self.opt['target_size']

        # get full size image
        full_path = self.paths_GT[index % len(self.paths_GT)]
        LQ_path = full_path
        img_full = util.read_img(None, full_path, None)
-       img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0]
-       img_full = self.get_square_image(img_full)
-       img_GT, gt_fullsize_ref, gt_mask = self.pull_tile(img_full)
+       img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0]
+       if self.opt['phase'] == 'train':
+           img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0]
+           img_full = self.get_square_image(img_full)
+           img_GT, gt_fullsize_ref, gt_mask, gt_center = self.pull_tile(img_full)
+       else:
+           img_GT, gt_fullsize_ref = img_full, img_full
+           gt_mask = np.ones(img_full.shape[:2])
+           gt_center = torch.tensor([img_full.shape[0] // 2, img_full.shape[1] // 2], dtype=torch.long)
+       orig_gt_dim = gt_fullsize_ref.shape[:2]

        # get LQ image
        if self.paths_LQ:
@@ -162,11 +177,16 @@ class FullImageDataset(data.Dataset):
            img_lq_full = util.read_img(None, LQ_path, None)
            img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0]
            img_lq_full = self.get_square_image(img_lq_full)
-           img_LQ, lq_fullsize_ref, lq_mask = self.pull_tile(img_lq_full)
+           img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full)
        else: # down-sampling on-the-fly
            # randomly scale during training
            if self.opt['phase'] == 'train':
+               GT_size = self.opt['target_size']
                random_scale = random.choice(self.random_scale_list)
+               if len(img_GT.shape) == 2:
+                   print("ERRAR:")
+                   print(img_GT.shape)
+                   print(full_path)
                H_s, W_s, _ = img_GT.shape

                def _mod(n, random_scale, scale, thres):
@@ -184,23 +204,34 @@ class FullImageDataset(data.Dataset):

            # using matlab imresize
            img_LQ = util.imresize_np(img_GT, 1 / scale, True)
+           lq_fullsize_ref = util.imresize_np(gt_fullsize_ref, 1 / scale, True)
            if img_LQ.ndim == 2:
                img_LQ = np.expand_dims(img_LQ, axis=2)
-           lq_fullsize_ref, lq_mask = gt_fullsize_ref, gt_mask
+           lq_mask, lq_center = gt_mask, self.resize_point(gt_center.clone(), orig_gt_dim, lq_fullsize_ref.shape[:2])
+       orig_lq_dim = lq_fullsize_ref.shape[:2]

-       # Enforce force_resize constraints.
+       # Enforce force_resize constraints via clipping.
        h, w, _ = img_LQ.shape
        if h % self.force_multiple != 0 or w % self.force_multiple != 0:
-           h, w = (w - w % self.force_multiple), (h - h % self.force_multiple)
-           img_LQ = cv2.resize(img_LQ, (h, w))
+           h, w = (h - h % self.force_multiple), (w - w % self.force_multiple)
+           img_LQ = img_LQ[:h, :w, :]
+           lq_fullsize_ref = lq_fullsize_ref[:h, :w, :]
            h *= scale
            w *= scale
-           img_GT = cv2.resize(img_GT, (h, w))
+           img_GT = img_GT[:h, :w]
            gt_fullsize_ref = gt_fullsize_ref[:h, :w, :]

        if self.opt['phase'] == 'train':
            img_GT, img_LQ = self.augment_tile(img_GT, img_LQ)
            gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2)
-       lq_mask = cv2.resize(lq_mask, img_LQ.shape[0:2], interpolation=cv2.INTER_LINEAR)

+       # Scale masks.
+       lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
+       gt_mask = cv2.resize(gt_mask, (gt_fullsize_ref.shape[1], gt_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)

+       # Scale center coords
+       lq_center = self.resize_point(lq_center, orig_lq_dim, lq_fullsize_ref.shape[:2])
+       gt_center = self.resize_point(gt_center, orig_gt_dim, gt_fullsize_ref.shape[:2])

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_GT.shape[2] == 3:
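
Note: the switch from cv2.resize to slicing means force_multiple is now enforced by cropping down to the nearest multiple, which keeps the LQ/GT pair pixel-aligned instead of interpolating. A sketch with made-up shapes:

    import numpy as np
    force_multiple, scale = 32, 2
    img_LQ, img_GT = np.zeros((70, 45, 3)), np.zeros((140, 90, 3))
    h = img_LQ.shape[0] - img_LQ.shape[0] % force_multiple
    w = img_LQ.shape[1] - img_LQ.shape[1] % force_multiple
    img_LQ, img_GT = img_LQ[:h, :w, :], img_GT[:h * scale, :w * scale, :]
    assert img_LQ.shape[:2] == (64, 32) and img_GT.shape[:2] == (128, 64)
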
@@ -210,8 +241,9 @@ class FullImageDataset(data.Dataset):
            gt_fullsize_ref = cv2.cvtColor(gt_fullsize_ref, cv2.COLOR_BGR2RGB)

        # LQ needs to go to a PIL image to perform the compression-artifact transformation.
-       img_LQ = self.pil_augment(img_LQ)
-       lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)
+       if self.opt['phase'] == 'train':
+           img_LQ = self.pil_augment(img_LQ)
+           lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)

        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
        gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()
@@ -226,19 +258,19 @@ class FullImageDataset(data.Dataset):
            lq_fullsize_ref += lq_noise

        # Apply the masks to the full images.
-       lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0)
        gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0)
+       lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0)

        if LQ_path is None:
            LQ_path = GT_path
        d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
-            'LQ_path': LQ_path, 'GT_path': GT_path}
+            'lq_center': lq_center, 'gt_center': gt_center,
+            'LQ_path': LQ_path, 'GT_path': full_path}
        return d

    def __len__(self):
        return len(self.paths_GT)

if __name__ == '__main__':
+   '''
    opt = {
        'name': 'amalgam',
        'dataroot_GT': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\images'],
@@ -249,19 +281,32 @@ if __name__ == '__main__':
        'use_rot': True,
        'lq_noise': 5,
-       'target_size': 128,
+       'min_tile_size': 256,
        'scale': 2,
        'phase': 'train'
    }
+   '''
+   opt = {
+       'name': 'amalgam',
+       'dataroot_GT': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\images'],
+       'dataroot_GT_weights': [1],
+       'force_multiple': 32,
+       'scale': 2,
+       'phase': 'test'
+   }

    ds = FullImageDataset(opt)
    import os
    os.makedirs("debug", exist_ok=True)
-   for i in range(1000):
+   for i in range(300, len(ds)):
        print(i)
        o = ds[i]
        for k, v in o.items():
            if 'path' not in k:
-               if 'full' in k:
-                   masked = v[:3, :, :] * v[3]
-                   torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
-               v = v[:3, :, :]
-               import torchvision
-               torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
+               #if 'full' in k:
+               #masked = v[:3, :, :] * v[3]
+               #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
+               #v = v[:3, :, :]
+               #import torchvision
+               #torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
+               pass

@@ -127,6 +127,11 @@ class ExtensibleTrainer(BaseModel):
        input_ref = data['ref'] if 'ref' in data else data['GT']
        self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]

+       self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
+       for k, v in data.items():
+           if k not in ['LQ', 'ref', 'GT'] and isinstance(v, torch.Tensor):
+               self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.mega_batch_factor, dim=0)]

    def optimize_parameters(self, step):
        self.env['step'] = step
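
Note: dstate generalizes what feed_data already did for LQ/GT/ref: any extra tensor in the batch (e.g. gt_center, lq_mask) is split into mega_batch_factor microbatches so each training step can iterate over aligned chunks. Hedged sketch of the chunking alone:

    import torch
    mega_batch_factor = 2
    data = {'gt_center': torch.zeros(8, 2, dtype=torch.long), 'GT_path': 'ignored'}
    dstate = {k: list(torch.chunk(v, chunks=mega_batch_factor, dim=0))
              for k, v in data.items() if isinstance(v, torch.Tensor)}
    assert [t.shape[0] for t in dstate['gt_center']] == [4, 4]
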
@@ -136,7 +141,7 @@ class ExtensibleTrainer(BaseModel):
                net.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))

        # Iterate through the steps, performing them one at a time.
-       state = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
+       state = self.dstate
        for step_num, s in enumerate(self.steps):
            # Only set requires_grad=True for the network being trained.
            nets_to_train = s.get_networks_trained()
@@ -195,7 +200,7 @@ class ExtensibleTrainer(BaseModel):

        with torch.no_grad():
            # Iterate through the steps, performing them one at a time.
-           state = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
+           state = self.dstate
            for step_num, s in enumerate(self.steps):
                ns = s.do_forward_backward(state, 0, step_num, backward=False)
                for k, v in ns.items():

@@ -1,458 +0,0 @@
import os
import logging
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from apex import amp

import models.networks as networks
from .base_model import BaseModel
from models.loss import GANLoss
import torchvision.utils as utils
from .archs.SPSR_arch import ImageGradient, ImageGradientNoPadding

logger = logging.getLogger('base')

class SPSRModel(BaseModel):
    def __init__(self, opt):
        super(SPSRModel, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netD_grad = networks.define_D(opt).to(self.device)  # D_grad
            self.netG.train()
            self.netD.train()
            self.netD_grad.train()
        self.mega_batch_factor = 1
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            self.mega_batch_factor = train_opt['mega_batch_factor']

            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(use_bn=False).to(self.device)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
            # Branch_init_iters
            self.branch_pretrain = train_opt['branch_pretrain'] if train_opt['branch_pretrain'] else 0
            self.branch_init_iters = train_opt['branch_init_iters'] if train_opt['branch_init_iters'] else 1

            # gradient_pixel_loss
            if train_opt['gradient_pixel_weight'] > 0:
                self.cri_pix_grad = nn.MSELoss().to(self.device)
                self.l_pix_grad_w = train_opt['gradient_pixel_weight']
            else:
                self.cri_pix_grad = None

            # gradient_gan_loss
            if train_opt['gradient_gan_weight'] > 0:
                self.cri_grad_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
                self.l_gan_grad_w = train_opt['gradient_gan_weight']
            else:
                self.cri_grad_gan = None

            # G_grad pixel loss
            if train_opt['pixel_branch_weight'] > 0:
                l_pix_type = train_opt['pixel_branch_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix_branch = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix_branch = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_branch_w = train_opt['pixel_branch_weight']
            else:
                logger.info('Remove G_grad pixel loss.')
                self.cri_pix_branch = None

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0

            optim_params = []
            for k, v in self.netG.named_parameters():  # optimize part of the model

                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            # D
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))

            self.optimizers.append(self.optimizer_D)

            # D_grad
            wd_D_grad = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D_grad = torch.optim.Adam(self.netD_grad.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))

            self.optimizers.append(self.optimizer_D_grad)

            # AMP
            [self.netG, self.netD, self.netD_grad], [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad] = \
                amp.initialize([self.netG, self.netD, self.netD_grad],
                               [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad],
                               opt_level=self.amp_level, num_losses=3)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()
        self.get_grad = ImageGradient()
        self.get_grad_nopadding = ImageGradientNoPadding()

    def feed_data(self, data, need_HR=True):
        # LR
        self.var_L = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0)]

        if need_HR:  # train or val
            self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref.to(self.device), chunks=self.mega_batch_factor, dim=0)]


    def optimize_parameters(self, step):
        # Some generators have variants depending on the current step.
        if hasattr(self.netG.module, "update_for_step"):
            self.netG.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
        if hasattr(self.netD.module, "update_for_step"):
            self.netD.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))

        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        for p in self.netD_grad.parameters():
            p.requires_grad = False

        if(self.branch_pretrain):
            if(step < self.branch_init_iters):
                for k, v in self.netG.named_parameters():
                    if 'f_' not in k:
                        v.requires_grad = False
            else:
                for k, v in self.netG.named_parameters():
                    if 'f_' not in k:
                        v.requires_grad = True

        self.optimizer_G.zero_grad()

        self.fake_H_branch = []
        self.fake_H = []
        self.grad_LR = []
        for var_L, var_H, var_ref in zip(self.var_L, self.var_H, self.var_ref):
            fake_H_branch, fake_H, grad_LR = self.netG(var_L)
            self.fake_H_branch.append(fake_H_branch.detach())
            self.fake_H.append(fake_H.detach())
            self.grad_LR.append(grad_LR.detach())

            fake_H_grad = self.get_grad(fake_H)
            var_H_grad = self.get_grad(var_H)
            var_ref_grad = self.get_grad(var_ref)
            var_H_grad_nopadding = self.get_grad_nopadding(var_H)

            l_g_total = 0
            if step % self.D_update_ratio == 0 and step > self.D_init_iters:
                if self.cri_pix:  # pixel loss
                    l_g_pix = self.l_pix_w * self.cri_pix(fake_H, var_H)
                    l_g_total += l_g_pix
                if self.cri_fea:  # feature loss
                    real_fea = self.netF(var_H).detach()
                    fake_fea = self.netF(fake_H)
                    l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                    l_g_total += l_g_fea

                if self.cri_pix_grad:  # gradient pixel loss
                    l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad)
                    l_g_total += l_g_pix_grad

                if self.cri_pix_branch:  # branch pixel loss
                    l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(fake_H_branch, var_H_grad_nopadding)
                    l_g_total += l_g_pix_grad_branch

                if self.l_gan_w > 0:
                    # G gan + cls loss
                    pred_g_fake = self.netD(fake_H)
                    pred_d_real = self.netD(var_ref).detach()

                    l_g_gan = self.l_gan_w * (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                                              self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
                    l_g_total += l_g_gan

                if self.cri_grad_gan:
                    # grad G gan + cls loss
                    pred_g_fake_grad = self.netD_grad(fake_H_grad)
                    pred_d_real_grad = self.netD_grad(var_ref_grad).detach()

                    l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) +
                                                        self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) /2
                    l_g_total += l_g_gan_grad

                l_g_total /= self.mega_batch_factor
                with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
                    l_g_total_scaled.backward()

        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            self.optimizer_G.step()


        if self.l_gan_w > 0:
            # D
            for p in self.netD.parameters():
                p.requires_grad = True

            self.optimizer_D.zero_grad()
            for var_ref, fake_H in zip(self.var_ref, self.fake_H):
                pred_d_real = self.netD(var_ref)
                pred_d_fake = self.netD(fake_H)  # detach to avoid BP to G
                l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
                l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)

                l_d_total = (l_d_real + l_d_fake) / 2

                l_d_total /= self.mega_batch_factor
                with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled:
                    l_d_total_scaled.backward()

            self.optimizer_D.step()

        if self.cri_grad_gan:
            for p in self.netD_grad.parameters():
                p.requires_grad = True

            self.optimizer_D_grad.zero_grad()
            for var_ref, fake_H in zip(self.var_ref, self.fake_H):
                fake_H_grad = self.get_grad(fake_H)
                var_ref_grad = self.get_grad(var_ref)

                pred_d_real_grad = self.netD_grad(var_ref_grad)
                pred_d_fake_grad = self.netD_grad(fake_H_grad.detach())  # detach to avoid BP to G

                l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
                l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)

                l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
                l_d_total_grad /= self.mega_batch_factor

                with amp.scale_loss(l_d_total_grad, self.optimizer_D_grad, loss_id=2) as l_d_total_grad_scaled:
                    l_d_total_grad_scaled.backward()

            self.optimizer_D_grad.step()

        # Log sample images from first microbatch.
        if step % 50 == 0:
            sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
            os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "gen_grad"), exist_ok=True)
            # fed_LQ is not chunked.
            utils.save_image(self.var_H[0].cpu(), os.path.join(sample_save_path, "hr", "%05i.png" % (step,)))
            utils.save_image(self.var_L[0].cpu(), os.path.join(sample_save_path, "lr", "%05i.png" % (step,)))
            utils.save_image(self.fake_H[0].cpu(), os.path.join(sample_save_path, "gen", "%05i.png" % (step,)))
            utils.save_image(self.grad_LR[0].cpu(), os.path.join(sample_save_path, "gen_grad", "%05i.png" % (step,)))

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            # G
            if self.cri_pix:
                self.add_log_entry('l_g_pix', l_g_pix.item())
            if self.cri_fea:
                self.add_log_entry('l_g_fea', l_g_fea.item())
            if self.l_gan_w > 0:
                self.add_log_entry('l_g_gan', l_g_gan.item())

            if self.cri_pix_branch:  # branch pixel loss
                self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.item())

        if self.l_gan_w > 0:
            self.add_log_entry('l_d_real', l_d_real.item())
            self.add_log_entry('l_d_fake', l_d_fake.item())
            self.add_log_entry('l_d_real_grad', l_d_real_grad.item())
            self.add_log_entry('l_d_fake_grad', l_d_fake_grad.item())
            self.add_log_entry('D_real', torch.mean(pred_d_real.detach()))
            self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
            self.add_log_entry('D_real_grad', torch.mean(pred_d_real_grad.detach()))
            self.add_log_entry('D_fake_grad', torch.mean(pred_d_fake_grad.detach()))

    # Allows the log to serve as an easy-to-use rotating buffer.
    def add_log_entry(self, key, value):
        key_it = "%s_it" % (key,)
        log_rotating_buffer_size = 50
        if key not in self.log_dict.keys():
            self.log_dict[key] = []
            self.log_dict[key_it] = 0
        if len(self.log_dict[key]) < log_rotating_buffer_size:
            self.log_dict[key].append(value)
        else:
            self.log_dict[key][self.log_dict[key_it] % log_rotating_buffer_size] = value
        self.log_dict[key_it] += 1

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H_branch = []
            self.fake_H = []
            self.grad_LR = []
            for var_L in self.var_L:
                fake_H_branch, fake_H, grad_LR = self.netG(var_L)
                self.fake_H_branch.append(fake_H_branch)
                self.fake_H.append(fake_H)
                self.grad_LR.append(grad_LR)

        self.netG.train()

    # Fetches a summary of the log.
    def get_current_log(self, step):
        return_log = {}
        for k in self.log_dict.keys():
            if not isinstance(self.log_dict[k], list):
                continue
            return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k])

        # Some generators can do their own metric logging.
        if hasattr(self.netG.module, "get_debug_values"):
            return_log.update(self.netG.module.get_debug_values(step))
        if hasattr(self.netD.module, "get_debug_values"):
            return_log.update(self.netD.module.get_debug_values(step))

        return return_log

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L[0].float().cpu()

        out_dict['rlt'] = self.fake_H[0].float().cpu()
        out_dict['SR_branch'] = self.fake_H_branch[0].float().cpu()
        out_dict['LR_grad'] = self.grad_LR[0].float().cpu()
        if need_HR:
            out_dict['GT'] = self.var_H[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)

        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
        logger.info(s)
        if self.is_train:
            # Disriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel):
                net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
                                                 self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)

            logger.info('Network D structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel):
                    net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
                                                     self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)

                logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
                logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD)
        load_path_D_grad = self.opt['path']['pretrain_model_D_grad']
        if self.opt['is_train'] and load_path_D_grad is not None:
            logger.info('Loading pretrained model for D_grad [{:s}] ...'.format(load_path_D_grad))
            self.load_network(load_path_D_grad, self.netD_grad)

    def compute_fea_loss(self, real, fake):
        if self.cri_fea is None:
            return 0
        with torch.no_grad():
            real = real.unsqueeze(dim=0).to(self.device)
            fake = fake.unsqueeze(dim=0).to(self.device)
            real_fea = self.netF(real).detach()
            fake_fea = self.netF(fake)
            return self.cri_fea(fake_fea, real_fea).item()

    def force_restore_swapout(self):
        pass

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        self.save_network(self.netD_grad, 'D_grad', iter_step)

    # override of load_network that allows loading partial params (like RRDB_PSNR_x4)
    def load_network(self, load_path, network, strict=True):
        if isinstance(network, nn.DataParallel):
            network = network.module
        pretrained_dict = torch.load(load_path)
        model_dict = network.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

        model_dict.update(pretrained_dict)
        network.load_state_dict(model_dict)

@@ -4,8 +4,8 @@ import torch.nn as nn
import torch.nn.functional as F
from models.archs import SPSR_util as B
from .RRDBNet_arch import RRDB
-from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock2
-from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConvBasisMultiplexer, ConfigurableSwitchComputer
+from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock
+from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConvBasisMultiplexer, ConfigurableSwitchComputer, ReferencingConvMultiplexer, ReferenceImageBranch
from switched_conv_util import save_attention_to_image_rgb
from switched_conv import compute_attention_specificity
import functools
@@ -351,3 +351,123 @@ class SwitchedSpsr(nn.Module):
            val["switch_%i_specificity" % (i,)] = means[i]
            val["switch_%i_histogram" % (i,)] = hists[i]
        return val
+
+
+class SwitchedSpsrWithRef(nn.Module):
+    def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
+        super(SwitchedSpsrWithRef, self).__init__()
+        n_upscale = int(math.log(upscale, 2))
+
+        # switch options
+        transformation_filters = nf
+        switch_filters = nf
+        self.transformation_counts = xforms
+        self.reference_processor = ReferenceImageBranch(transformation_filters)
+        multiplx_fn = functools.partial(ReferencingConvMultiplexer, transformation_filters, switch_filters, self.transformation_counts)
+        pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
+        transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
+                                         transformation_filters, kernel_size=3, depth=3,
+                                         weight_init_factor=.1)
+
+        # Feature branch
+        self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
+        self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                              pre_transform_block=pretransform_fn, transform_block=transform_fn,
+                                              attention_norm=True,
+                                              transform_count=self.transformation_counts, init_temp=init_temperature,
+                                              add_scalable_noise_to_transforms=True)
+        self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                              pre_transform_block=pretransform_fn, transform_block=transform_fn,
+                                              attention_norm=True,
+                                              transform_count=self.transformation_counts, init_temp=init_temperature,
+                                              add_scalable_noise_to_transforms=True)
+        self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
+        self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
+
+        # Grad branch
+        self.get_g_nopadding = ImageGradientNoPadding()
+        self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
+        mplex_grad = functools.partial(ReferencingConvMultiplexer, nf * 2, nf * 2, self.transformation_counts // 2)
+        self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad,
+                                                  pre_transform_block=pretransform_fn, transform_block=transform_fn,
+                                                  attention_norm=True,
+                                                  transform_count=self.transformation_counts // 2, init_temp=init_temperature,
+                                                  add_scalable_noise_to_transforms=True)
+        # Upsampling
+        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
+        self.grad_hr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
+        # Conv used to output grad branch shortcut.
+        self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False)
+
+        # Conjoin branch.
+        # Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest.
+        transform_fn_cat = functools.partial(MultiConvBlock, transformation_filters * 2, int(transformation_filters * 1.5),
+                                             transformation_filters, kernel_size=3, depth=4,
+                                             weight_init_factor=.1)
+        pretransform_fn_cat = functools.partial(ConvGnLelu, transformation_filters * 2, transformation_filters * 2, norm=False, bias=False, weight_init_factor=.1)
+        self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                                              pre_transform_block=pretransform_fn_cat, transform_block=transform_fn_cat,
+                                                              attention_norm=True,
+                                                              transform_count=self.transformation_counts, init_temp=init_temperature,
+                                                              add_scalable_noise_to_transforms=True)
+        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
+        self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
+        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
+        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
+        self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
+        self.switches = [self.sw1, self.sw2, self.sw_grad, self._branch_pretrain_sw]
+        self.attentions = None
+        self.init_temperature = init_temperature
+        self.final_temperature_step = 10000
+
+    def forward(self, x, ref, center_coord):
+        x_grad = self.get_g_nopadding(x)
+        ref = self.reference_processor(ref, center_coord)
+        x = self.model_fea_conv(x)
+
+        x1, a1 = self.sw1(x, True, att_in=(x, ref))
+        x2, a2 = self.sw2(x1, True, att_in=(x, ref))
+        x_fea = self.feature_lr_conv(x2)
+        x_fea = self.feature_hr_conv2(x_fea)
+
+        x_b_fea = self.b_fea_conv(x_grad)
+        x_grad, a3 = self.sw_grad(x_b_fea, att_in=(torch.cat([x1, x_b_fea], dim=1), ref), output_attention_weights=True)
+        x_grad = self.grad_lr_conv(x_grad)
+        x_grad = self.grad_hr_conv(x_grad)
+        x_out_branch = self.upsample_grad(x_grad)
+        x_out_branch = self.grad_branch_output_conv(x_out_branch)
+
+        x__branch_pretrain_cat = torch.cat([x_grad, x_fea], dim=1)
+        x__branch_pretrain_cat, a4 = self._branch_pretrain_sw(x__branch_pretrain_cat, att_in=(x_fea, ref), identity=x_fea, output_attention_weights=True)
+        x_out = self.final_lr_conv(x__branch_pretrain_cat)
+        x_out = self.upsample(x_out)
+        x_out = self.final_hr_conv1(x_out)
+        x_out = self.final_hr_conv2(x_out)
+
+        self.attentions = [a1, a2, a3, a4]
+
+        return x_out_branch, x_out, x_grad
+
+    def set_temperature(self, temp):
+        [sw.set_temperature(temp) for sw in self.switches]
+
+    def update_for_step(self, step, experiments_path='.'):
+        if self.attentions:
+            temp = max(1, 1 + self.init_temperature *
+                       (self.final_temperature_step - step) / self.final_temperature_step)
+            self.set_temperature(temp)
+            if step % 200 == 0:
+                output_path = os.path.join(experiments_path, "attention_maps", "a%i")
+                prefix = "attention_map_%i_%%i.png" % (step,)
+                [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
+
+    def get_debug_values(self, step):
+        temp = self.switches[0].switch.temperature
+        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
+        means = [i[0] for i in mean_hists]
+        hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
+        val = {"switch_temperature": temp}
+        for i in range(len(means)):
+            val["switch_%i_specificity" % (i,)] = means[i]
+            val["switch_%i_histogram" % (i,)] = hists[i]
+        return val

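
Note: a hedged instantiation sketch for the new generator; the sizes are illustrative, and the 4-channel reference assumes the mask channel that FullImageDataset concatenates onto the fullsize refs:

    import torch
    net = SwitchedSpsrWithRef(in_nc=3, out_nc=3, nf=64, xforms=8, upscale=4)
    x = torch.randn(1, 3, 32, 32)                        # LQ tile
    ref = torch.randn(1, 4, 128, 128)                    # RGB + mask reference
    center = torch.tensor([[64, 64]], dtype=torch.long)  # tile center inside ref
    branch, out, grad = net(x, ref, center)
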
@@ -4,7 +4,7 @@ from switched_conv import BareConvSwitch, compute_attention_specificity, Attenti
import torch.nn.functional as F
import functools
from collections import OrderedDict
-from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2
+from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConjoinBlock
from models.archs.RRDBNet_arch import ResidualDenseBlock_5C, RRDB
from models.archs.spinenet_arch import SpineNet
from switched_conv_util import save_attention_to_image_rgb
@@ -92,6 +92,87 @@ class CachedBackboneWrapper:
    def get_forward_result(self):
        return self.cache

+
+# torch.gather() which operates across 2d images.
+def gather_2d(input, index):
+    b, c, h, w = input.shape
+    nodim = input.view(b, c, h * w)
+    ind_nd = index[:, 0]*w + index[:, 1]
+    ind_nd = ind_nd.unsqueeze(1)
+    ind_nd = ind_nd.repeat((1, c))
+    ind_nd = ind_nd.unsqueeze(2)
+    result = torch.gather(nodim, dim=2, index=ind_nd)
+    return result.squeeze()
+
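
Note: gather_2d picks one feature vector per batch element at a (y, x) location. Usage sketch:

    import torch
    feats = torch.arange(2 * 3 * 4 * 4, dtype=torch.float).view(2, 3, 4, 4)
    coords = torch.tensor([[1, 2], [3, 0]])  # (y, x) per batch item
    picked = gather_2d(feats, coords)        # shape (2, 3)
    assert torch.equal(picked[0], feats[0, :, 1, 2])
    assert torch.equal(picked[1], feats[1, :, 3, 0])
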
+
+# Computes a linear latent by performing processing on the reference image and returning the filters of a single point,
+# which should be centered on the image patch being processed.
+#
+# Output is base_filters * 8.
+class ReferenceImageBranch(nn.Module):
+    def __init__(self, base_filters=64):
+        super(ReferenceImageBranch, self).__init__()
+        self.filter_conv = ConvGnSilu(4, base_filters, bias=True)
+        self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(base_filters * 2 ** i) for i in range(3)])
+        reduction_filters = base_filters * 2 ** 3
+        self.processing_blocks = nn.Sequential(OrderedDict([('block%i' % (i,), ConvGnSilu(reduction_filters, reduction_filters, bias=False)) for i in range(4)]))
+
+    # center_point is a [b,2] long tensor describing the center point of where the patch was taken from the reference
+    # image.
+    def forward(self, x, center_point):
+        x = self.filter_conv(x)
+        reduction_identities = []
+        for b in self.reduction_blocks:
+            reduction_identities.append(x)
+            x = b(x)
+        x = self.processing_blocks(x)
+        return gather_2d(x, center_point // 8)
+
+
+# This is similar to ConvBasisMultiplexer, except that it takes a linear reference tensor as a second input to
+# provide better results. It also has fixed parameterization in several places
+class ReferencingConvMultiplexer(nn.Module):
+    def __init__(self, input_channels, base_filters, multiplexer_channels, use_gn=True):
+        super(ReferencingConvMultiplexer, self).__init__()
+        self.filter_conv = ConvGnSilu(input_channels, multiplexer_channels, bias=True)
+        self.ref_proc = nn.Linear(512, 512)
+        self.ref_red = nn.Linear(512, base_filters * 2)
+        self.feature_norm = torch.nn.InstanceNorm2d(base_filters)
+        self.style_norm = torch.nn.InstanceNorm1d(base_filters)
+
+        self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(base_filters * 2 ** i) for i in range(3)])
+        reduction_filters = base_filters * 2 ** 3
+        self.processing_blocks = nn.Sequential(OrderedDict([('block%i' % (i,), ConvGnSilu(reduction_filters, reduction_filters, bias=False)) for i in range(2)]))
+        self.expansion_blocks = nn.ModuleList([ExpansionBlock2(reduction_filters // (2 ** i)) for i in range(3)])
+
+        gap = base_filters - multiplexer_channels
+        cbl1_out = ((base_filters - (gap // 2)) // 4) * 4  # Must be multiples of 4 to use with group norm.
+        self.cbl1 = ConvGnSilu(base_filters, cbl1_out, norm=use_gn, bias=False, num_groups=4)
+        cbl2_out = ((base_filters - (3 * gap // 4)) // 4) * 4
+        self.cbl2 = ConvGnSilu(cbl1_out, cbl2_out, norm=use_gn, bias=False, num_groups=4)
+        self.cbl3 = ConvGnSilu(cbl2_out, multiplexer_channels, bias=True, norm=False)
+
+    def forward(self, x, ref):
+        # Start by fusing the reference vector and the input. Follows the ADAIn formula.
+        x = self.feature_norm(self.filter_conv(x))
+        ref = self.ref_proc(ref)
+        ref = self.ref_red(ref)
+        b, c = ref.shape
+        ref = self.style_norm(ref.view(b, 2, c // 2))
+        x = x * ref[:, 0, :].unsqueeze(dim=2).unsqueeze(dim=3).expand(x.shape) + ref[:, 1, :].unsqueeze(dim=2).unsqueeze(dim=3).expand(x.shape)
+
+        reduction_identities = []
+        for b in self.reduction_blocks:
+            reduction_identities.append(x)
+            x = b(x)
+        x = self.processing_blocks(x)
+        for i, b in enumerate(self.expansion_blocks):
+            x = b(x, reduction_identities[-i - 1])
+
+        x = self.cbl1(x)
+        x = self.cbl2(x)
+        x = self.cbl3(x)
+        return x
+
+
class BackboneMultiplexer(nn.Module):
    def __init__(self, backbone: CachedBackboneWrapper, transform_count):
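
Note: the fusion at the top of ReferencingConvMultiplexer.forward is AdaIN-style: the reference vector is projected into a per-channel scale and shift that modulate the instance-normalized features. Shape sketch (omitting the style normalization):

    import torch
    b, c = 4, 64
    x = torch.randn(b, c, 32, 32)  # instance-normalized features
    style = torch.randn(b, 2 * c).view(b, 2, c)
    scale = style[:, 0, :].unsqueeze(2).unsqueeze(3)
    shift = style[:, 1, :].unsqueeze(2).unsqueeze(3)
    x = x * scale.expand_as(x) + shift.expand_as(x)
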
@@ -151,7 +232,10 @@ class ConfigurableSwitchComputer(nn.Module):
        if self.pre_transform:
            x = self.pre_transform(x)
        xformed = [t.forward(x) for t in self.transforms]
-       m = self.multiplexer(att_in)
+       if isinstance(att_in, tuple):
+           m = self.multiplexer(*att_in)
+       else:
+           m = self.multiplexer(att_in)

        outputs, attention = self.switch(xformed, m, True)

@@ -415,32 +415,16 @@ class ExpansionBlock2(nn.Module):
        return self.reduce(x)


-# Similar to ExpansionBlock but does not upsample.
+# Similar to ExpansionBlock2 but does not upsample.
class ConjoinBlock(nn.Module):
-   def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True):
+   def __init__(self, filters_in, filters_out=None, filters_pt=None, block=ConvGnSilu, norm=True):
        super(ConjoinBlock, self).__init__()
        if filters_out is None:
            filters_out = filters_in
-       self.decimate = block(filters_in*2, filters_out, kernel_size=1, bias=False, activation=False, norm=norm)
-       self.process = block(filters_out, filters_out, kernel_size=3, bias=False, activation=True, norm=norm)
-
-   # input is the feature signal with shape (b, f, w, h)
-   # passthrough is the structure signal with shape (b, f/2, w*2, h*2)
-   # output is conjoined upsample with shape (b, f/2, w*2, h*2)
-   def forward(self, input, passthrough):
-       x = torch.cat([input, passthrough], dim=1)
-       x = self.decimate(x)
-       return self.process(x)
-
-
-# Similar to ExpansionBlock2 but does not upsample.
-class ConjoinBlock2(nn.Module):
-   def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True):
-       super(ConjoinBlock2, self).__init__()
-       if filters_out is None:
-           filters_out = filters_in
-       self.process = block(filters_in*2, filters_in*2, kernel_size=3, bias=False, activation=True, norm=norm)
-       self.decimate = block(filters_in*2, filters_out, kernel_size=1, bias=False, activation=False, norm=norm)
+       if filters_pt is None:
+           filters_pt = filters_in
+       self.process = block(filters_in + filters_pt, filters_in + filters_pt, kernel_size=3, bias=False, activation=True, norm=norm)
+       self.decimate = block(filters_in + filters_pt, filters_out, kernel_size=1, bias=False, activation=False, norm=norm)

    def forward(self, input, passthrough):
        x = torch.cat([input, passthrough], dim=1)

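
Note: the merged ConjoinBlock now lets the passthrough carry its own channel count (filters_pt) instead of assuming it matches filters_in. A hedged sketch, assuming ConvGnSilu preserves spatial size at its defaults:

    import torch
    blk = ConjoinBlock(filters_in=64, filters_out=64, filters_pt=32)
    feat = torch.randn(1, 64, 16, 16)
    passthrough = torch.randn(1, 32, 16, 16)
    out = blk(feat, passthrough)  # cat to 96ch, process, decimate back to 64ch
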
@@ -159,12 +159,12 @@ class CrossCompareBlock(nn.Module):


class CrossCompareDiscriminator(nn.Module):
-   def __init__(self, in_nc, nf, scale=4):
+   def __init__(self, in_nc, ref_channels, nf, scale=4):
        super(CrossCompareDiscriminator, self).__init__()
        assert scale == 2 or scale == 4

        self.init_conv_hr = ConvGnLelu(in_nc, nf, stride=2, norm=False, bias=True, activation=True)
-       self.init_conv_lr = ConvGnLelu(in_nc, nf, stride=1, norm=False, bias=True, activation=True)
+       self.init_conv_lr = ConvGnLelu(ref_channels, nf, stride=1, norm=False, bias=True, activation=True)
        if scale == 4:
            strd_2 = 2
        else:

@@ -119,6 +119,10 @@ def define_G(opt, net_key='network_G', scale=None):
        xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
        netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
                                 init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
+   elif which_model == "spsr_switched_with_ref":
+       xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
+       netG = spsr.SwitchedSpsrWithRef(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
+                                       init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)

    # image corruption
    elif which_model == 'HighToLowResNet':
@@ -159,7 +163,7 @@ def define_D_net(opt_net, img_sz=None):
        netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'],
                                                 final_temperature_step=opt_net['final_temperature_step'])
    elif which_model == "cross_compare_vgg128":
-       netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], nf=opt_net['nf'], scale=opt_net['scale'])
+       netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'], nf=opt_net['nf'], scale=opt_net['scale'])
    else:
        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
    return netD

@@ -41,7 +41,11 @@ class ImageGeneratorInjector(Injector):

    def forward(self, state):
        gen = self.env['generators'][self.opt['generator']]
-       results = gen(state[self.input])
+       if isinstance(self.input, list):
+           params = [state[i] for i in self.input]
+           results = gen(*params)
+       else:
+           results = gen(state[self.input])
        new_state = {}
        if isinstance(self.output, list):
            for i, k in enumerate(self.output):

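
Note: with this change a step config can pass a list for 'in', and each named state entry becomes a positional argument to the generator, which is what routes (lq, ref, center) into SwitchedSpsrWithRef. Minimal dispatch sketch with stand-in names:

    state = {'lq': 0, 'lq_fullsize_ref': 1, 'lq_center': 2}  # stand-in state
    gen = lambda *a: a                                       # stand-in generator
    inp = ['lq', 'lq_fullsize_ref', 'lq_center']             # the new list form of 'in'
    args = [state[i] for i in inp] if isinstance(inp, list) else [state[inp]]
    results = gen(*args)                                     # gen(lq, ref, center)
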
@@ -78,7 +78,7 @@ class GeneratorGanLoss(ConfigurableLoss):
        netD = self.env['discriminators'][self.opt['discriminator']]
        if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
            if self.opt['gan_type'] == 'crossgan':
-               pred_g_fake = netD(state[self.opt['fake']], state['lq'])
+               pred_g_fake = netD(state[self.opt['fake']], state['lq_fullsize_ref'])
            else:
                pred_g_fake = netD(state[self.opt['fake']])
        return self.criterion(pred_g_fake, True)
@@ -101,9 +101,9 @@ class DiscriminatorGanLoss(ConfigurableLoss):
        self.metrics = []

        if self.opt['gan_type'] == 'crossgan':
-           d_real = net(state[self.opt['real']], state['lq'])
-           d_fake = net(state[self.opt['fake']].detach(), state['lq'])
-           mismatched_lq = torch.roll(state['lq'], shifts=1, dims=0)
+           d_real = net(state[self.opt['real']], state['lq_fullsize_ref'])
+           d_fake = net(state[self.opt['fake']].detach(), state['lq_fullsize_ref'])
+           mismatched_lq = torch.roll(state['lq_fullsize_ref'], shifts=1, dims=0)
            d_mismatch_real = net(state[self.opt['real']], mismatched_lq)
            d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq)
        else:

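
Note: torch.roll by one along the batch dimension pairs every HQ image with a neighbor's reference, so the cross-compare discriminator also trains on explicitly mismatched (real image, wrong reference) pairs. Sketch:

    import torch
    lq_ref = torch.arange(4.).view(4, 1, 1, 1)
    mismatched = torch.roll(lq_ref, shifts=1, dims=0)
    assert mismatched.flatten().tolist() == [3., 0., 1., 2.]
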
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
    #### options
    parser = argparse.ArgumentParser()
-   parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/finetune_imgset_spsr_switched2_xlbatch_limfeat.yml')
+   parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_fullimgref.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)

@@ -161,7 +161,7 @@ def main():
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
-       current_step = 0 if 'start_step' not in opt.keys() else opt['start_step']
+       current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
        start_epoch = 0

    #### training