From 9a421a41f46521bbb33ce3424813abb66e1b727c Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 1 Dec 2020 11:11:51 -0700 Subject: [PATCH] SRFlow: accomodate mismatches between global scale and flow_scale --- codes/models/archs/srflow_orig/SRFlowNet_arch.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/codes/models/archs/srflow_orig/SRFlowNet_arch.py b/codes/models/archs/srflow_orig/SRFlowNet_arch.py index 933fe57e..3dd99896 100644 --- a/codes/models/archs/srflow_orig/SRFlowNet_arch.py +++ b/codes/models/archs/srflow_orig/SRFlowNet_arch.py @@ -19,10 +19,11 @@ class SRFlowNet(nn.Module): self.opt = opt self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \ None else opt_get(opt, ['datasets', 'train', 'quant']) - self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt) + initial_stride = opt_get(opt, ['networks', 'generator', 'initial_stride'], 1) + self.RRDB = RRDBNet(in_nc, out_nc, nf=nf, nb=nb, gc=gc, scale=scale, opt=opt, initial_conv_stride=initial_stride) if 'pretrain_rrdb' in opt['networks']['generator'].keys(): rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb']) - self.RRDB.load_state_dict(rrdb_state_dict, strict=False) + self.RRDB.load_state_dict(rrdb_state_dict, strict=True) hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels']) hidden_channels = hidden_channels or 64 @@ -43,8 +44,8 @@ class SRFlowNet(nn.Module): if seed: torch.manual_seed(seed) if opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable']): C = self.flowUpsamplerNet.C - H = int(self.flow_scale * lr_shape[2] // (self.flowUpsamplerNet.scaleH * self.flow_scale / self.RRDB.scale)) - W = int(self.flow_scale * lr_shape[3] // (self.flowUpsamplerNet.scaleW * self.flow_scale / self.RRDB.scale)) + H = int(self.flow_scale * lr_shape[0] // (self.flowUpsamplerNet.scaleH * self.flow_scale / self.RRDB.scale)) + W = int(self.flow_scale * lr_shape[1] // (self.flowUpsamplerNet.scaleW * self.flow_scale / self.RRDB.scale)) size = (batch_size, C, H, W) if heat == 0: @@ -84,8 +85,9 @@ class SRFlowNet(nn.Module): else: assert lr.shape[1] == 3 if z is None: - # Synthesize it. - z = self.get_random_z(eps_std, batch_size=lr.shape[0], lr_shape=lr.shape, device=lr.device) + # Synthesize it. Accommodate mismatches in LR scale and flow_scale, which are normally handled by the RRDB subnet. + lr_shape = [d * self.opt['scale'] / self.flow_scale for d in lr.shape[2:]] + z = self.get_random_z(eps_std, batch_size=lr.shape[0], lr_shape=lr_shape, device=lr.device) if reverse_with_grad: return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise)