SRFlow: accomodate mismatches between global scale and flow_scale
This commit is contained in:
parent
8f65f81ddb
commit
9a421a41f4
|
@ -19,10 +19,11 @@ class SRFlowNet(nn.Module):
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \
|
self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \
|
||||||
None else opt_get(opt, ['datasets', 'train', 'quant'])
|
None else opt_get(opt, ['datasets', 'train', 'quant'])
|
||||||
self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
|
initial_stride = opt_get(opt, ['networks', 'generator', 'initial_stride'], 1)
|
||||||
|
self.RRDB = RRDBNet(in_nc, out_nc, nf=nf, nb=nb, gc=gc, scale=scale, opt=opt, initial_conv_stride=initial_stride)
|
||||||
if 'pretrain_rrdb' in opt['networks']['generator'].keys():
|
if 'pretrain_rrdb' in opt['networks']['generator'].keys():
|
||||||
rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb'])
|
rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb'])
|
||||||
self.RRDB.load_state_dict(rrdb_state_dict, strict=False)
|
self.RRDB.load_state_dict(rrdb_state_dict, strict=True)
|
||||||
|
|
||||||
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
|
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
|
||||||
hidden_channels = hidden_channels or 64
|
hidden_channels = hidden_channels or 64
|
||||||
|
@ -43,8 +44,8 @@ class SRFlowNet(nn.Module):
|
||||||
if seed: torch.manual_seed(seed)
|
if seed: torch.manual_seed(seed)
|
||||||
if opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable']):
|
if opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable']):
|
||||||
C = self.flowUpsamplerNet.C
|
C = self.flowUpsamplerNet.C
|
||||||
H = int(self.flow_scale * lr_shape[2] // (self.flowUpsamplerNet.scaleH * self.flow_scale / self.RRDB.scale))
|
H = int(self.flow_scale * lr_shape[0] // (self.flowUpsamplerNet.scaleH * self.flow_scale / self.RRDB.scale))
|
||||||
W = int(self.flow_scale * lr_shape[3] // (self.flowUpsamplerNet.scaleW * self.flow_scale / self.RRDB.scale))
|
W = int(self.flow_scale * lr_shape[1] // (self.flowUpsamplerNet.scaleW * self.flow_scale / self.RRDB.scale))
|
||||||
|
|
||||||
size = (batch_size, C, H, W)
|
size = (batch_size, C, H, W)
|
||||||
if heat == 0:
|
if heat == 0:
|
||||||
|
@ -84,8 +85,9 @@ class SRFlowNet(nn.Module):
|
||||||
else:
|
else:
|
||||||
assert lr.shape[1] == 3
|
assert lr.shape[1] == 3
|
||||||
if z is None:
|
if z is None:
|
||||||
# Synthesize it.
|
# Synthesize it. Accommodate mismatches in LR scale and flow_scale, which are normally handled by the RRDB subnet.
|
||||||
z = self.get_random_z(eps_std, batch_size=lr.shape[0], lr_shape=lr.shape, device=lr.device)
|
lr_shape = [d * self.opt['scale'] / self.flow_scale for d in lr.shape[2:]]
|
||||||
|
z = self.get_random_z(eps_std, batch_size=lr.shape[0], lr_shape=lr_shape, device=lr.device)
|
||||||
if reverse_with_grad:
|
if reverse_with_grad:
|
||||||
return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
|
return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
|
||||||
add_gt_noise=add_gt_noise)
|
add_gt_noise=add_gt_noise)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user