SRFlow: accommodate mismatches between global scale and flow_scale

James Betker 2020-12-01 11:11:51 -07:00
parent 8f65f81ddb
commit 9a421a41f4


@@ -19,10 +19,11 @@ class SRFlowNet(nn.Module):
         self.opt = opt
         self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \
             None else opt_get(opt, ['datasets', 'train', 'quant'])
-        self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
+        initial_stride = opt_get(opt, ['networks', 'generator', 'initial_stride'], 1)
+        self.RRDB = RRDBNet(in_nc, out_nc, nf=nf, nb=nb, gc=gc, scale=scale, opt=opt, initial_conv_stride=initial_stride)
         if 'pretrain_rrdb' in opt['networks']['generator'].keys():
             rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb'])
-            self.RRDB.load_state_dict(rrdb_state_dict, strict=False)
+            self.RRDB.load_state_dict(rrdb_state_dict, strict=True)
         hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
         hidden_channels = hidden_channels or 64
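
For context, a minimal sketch of the option dict these opt_get calls traverse. The key paths are taken from the hunk above; every concrete value, and the checkpoint path, is an illustrative assumption rather than a repository default.

# Sketch of the options consumed by the constructor above; key paths mirror
# the opt_get() calls in this commit, values are placeholders.
opt = {
    'scale': 4,                                  # global LR -> HR scale
    'datasets': {'train': {'quant': 32}},        # falls back to 255 when absent
    'networks': {
        'generator': {
            'initial_stride': 2,                 # new option: stride of the RRDB input conv
            'pretrain_rrdb': 'rrdb_checkpoint.pth',  # hypothetical path
            'flow': {'hidden_channels': 64, 'split': {'enable': True}},
        },
    },
}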
@@ -43,8 +44,8 @@ class SRFlowNet(nn.Module):
         if seed: torch.manual_seed(seed)
         if opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable']):
             C = self.flowUpsamplerNet.C
-            H = int(self.flow_scale * lr_shape[2] // (self.flowUpsamplerNet.scaleH * self.flow_scale / self.RRDB.scale))
-            W = int(self.flow_scale * lr_shape[3] // (self.flowUpsamplerNet.scaleW * self.flow_scale / self.RRDB.scale))
+            H = int(self.flow_scale * lr_shape[0] // (self.flowUpsamplerNet.scaleH * self.flow_scale / self.RRDB.scale))
+            W = int(self.flow_scale * lr_shape[1] // (self.flowUpsamplerNet.scaleW * self.flow_scale / self.RRDB.scale))
             size = (batch_size, C, H, W)
             if heat == 0:
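
A worked instance of the revised size arithmetic, with every number assumed rather than taken from the repository; note that lr_shape now carries only spatial dimensions, hence the switch from indices [2]/[3] to [0]/[1].

# All values below are illustrative assumptions.
flow_scale, rrdb_scale = 2, 2
scaleH = scaleW = 2                    # flowUpsamplerNet scale factors (assumed)
lr_shape = (64.0, 64.0)                # spatial dims, pre-scaled by the forward() hunk below
H = int(flow_scale * lr_shape[0] // (scaleH * flow_scale / rrdb_scale))   # -> 64
W = int(flow_scale * lr_shape[1] // (scaleW * flow_scale / rrdb_scale))   # -> 64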
@@ -84,8 +85,9 @@ class SRFlowNet(nn.Module):
         else:
             assert lr.shape[1] == 3
             if z is None:
-                # Synthesize it.
-                z = self.get_random_z(eps_std, batch_size=lr.shape[0], lr_shape=lr.shape, device=lr.device)
+                # Synthesize it. Accommodate mismatches in LR scale and flow_scale, which are normally handled by the RRDB subnet.
+                lr_shape = [d * self.opt['scale'] / self.flow_scale for d in lr.shape[2:]]
+                z = self.get_random_z(eps_std, batch_size=lr.shape[0], lr_shape=lr_shape, device=lr.device)
             if reverse_with_grad:
                 return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
                                          add_gt_noise=add_gt_noise)
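
Finally, the rescaling that now feeds get_random_z, again with illustrative numbers: per the new comment, the gap between the global scale and flow_scale is what the RRDB subnet normally absorbs, so the synthesized z is resized to compensate.

# Illustrative numbers only: a (1, 3, 32, 32) LR batch, global scale 4,
# flow_scale 2. get_random_z then receives 64x64 spatial dims.
import torch
lr = torch.zeros(1, 3, 32, 32)
global_scale, flow_scale = 4, 2
lr_shape = [d * global_scale / flow_scale for d in lr.shape[2:]]   # [64.0, 64.0]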