SRFlow: accomodate mismatches between global scale and flow_scale
This commit is contained in:
parent
8f65f81ddb
commit
9a421a41f4
|
@ -19,10 +19,11 @@ class SRFlowNet(nn.Module):
|
|||
self.opt = opt
|
||||
self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \
|
||||
None else opt_get(opt, ['datasets', 'train', 'quant'])
|
||||
self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
|
||||
initial_stride = opt_get(opt, ['networks', 'generator', 'initial_stride'], 1)
|
||||
self.RRDB = RRDBNet(in_nc, out_nc, nf=nf, nb=nb, gc=gc, scale=scale, opt=opt, initial_conv_stride=initial_stride)
|
||||
if 'pretrain_rrdb' in opt['networks']['generator'].keys():
|
||||
rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb'])
|
||||
self.RRDB.load_state_dict(rrdb_state_dict, strict=False)
|
||||
self.RRDB.load_state_dict(rrdb_state_dict, strict=True)
|
||||
|
||||
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
|
||||
hidden_channels = hidden_channels or 64
|
||||
|
@ -43,8 +44,8 @@ class SRFlowNet(nn.Module):
|
|||
if seed: torch.manual_seed(seed)
|
||||
if opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable']):
|
||||
C = self.flowUpsamplerNet.C
|
||||
H = int(self.flow_scale * lr_shape[2] // (self.flowUpsamplerNet.scaleH * self.flow_scale / self.RRDB.scale))
|
||||
W = int(self.flow_scale * lr_shape[3] // (self.flowUpsamplerNet.scaleW * self.flow_scale / self.RRDB.scale))
|
||||
H = int(self.flow_scale * lr_shape[0] // (self.flowUpsamplerNet.scaleH * self.flow_scale / self.RRDB.scale))
|
||||
W = int(self.flow_scale * lr_shape[1] // (self.flowUpsamplerNet.scaleW * self.flow_scale / self.RRDB.scale))
|
||||
|
||||
size = (batch_size, C, H, W)
|
||||
if heat == 0:
|
||||
|
@ -84,8 +85,9 @@ class SRFlowNet(nn.Module):
|
|||
else:
|
||||
assert lr.shape[1] == 3
|
||||
if z is None:
|
||||
# Synthesize it.
|
||||
z = self.get_random_z(eps_std, batch_size=lr.shape[0], lr_shape=lr.shape, device=lr.device)
|
||||
# Synthesize it. Accommodate mismatches in LR scale and flow_scale, which are normally handled by the RRDB subnet.
|
||||
lr_shape = [d * self.opt['scale'] / self.flow_scale for d in lr.shape[2:]]
|
||||
z = self.get_random_z(eps_std, batch_size=lr.shape[0], lr_shape=lr_shape, device=lr.device)
|
||||
if reverse_with_grad:
|
||||
return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
|
||||
add_gt_noise=add_gt_noise)
|
||||
|
|
Loading…
Reference in New Issue
Block a user