Support 2x RRDB with 4x srflow
This commit is contained in:
parent
cad92bada8
commit
519ba6f10c
|
@ -48,7 +48,7 @@ class FlowUpsamplerNet(nn.Module):
|
|||
4: 'fea_up0'
|
||||
}
|
||||
|
||||
elif opt['scale'] == 4:
|
||||
elif opt['scale'] == 4 or opt['scale'] == 2:
|
||||
self.levelToName = {
|
||||
0: 'fea_up4',
|
||||
1: 'fea_up2',
|
||||
|
|
|
@ -186,6 +186,7 @@ class RRDBNet(nn.Module):
|
|||
|
||||
out = self.conv_last(self.lrelu(self.conv_hr(fea)))
|
||||
|
||||
if self.scale >= 4:
|
||||
results = {'last_lr_fea': last_lr_fea,
|
||||
'fea_up1': last_lr_fea,
|
||||
'fea_up2': fea_up2,
|
||||
|
@ -201,6 +202,28 @@ class RRDBNet(nn.Module):
|
|||
fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
|
||||
if fea_upn1_en:
|
||||
results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||
elif self.scale == 2:
|
||||
# "Pretend" this is is 4x by shuffling around the inputs a bit.
|
||||
half = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||
quarter = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||
eighth = F.interpolate(last_lr_fea, scale_factor=1/8, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||
results = {'last_lr_fea': half,
|
||||
'fea_up1': half,
|
||||
'fea_up2': last_lr_fea,
|
||||
'fea_up4': fea_up2,
|
||||
'fea_up8': fea_up4,
|
||||
'fea_up16': fea_up8,
|
||||
'fea_up32': fea_up16,
|
||||
'out': out}
|
||||
|
||||
fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False
|
||||
if fea_up0_en:
|
||||
results['fea_up0'] = quarter
|
||||
fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
|
||||
if fea_upn1_en:
|
||||
results['fea_up-1'] = eighth
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if get_steps:
|
||||
for k, v in block_results.items():
|
||||
|
|
|
@ -27,6 +27,7 @@ class SRFlowNet(nn.Module):
|
|||
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
|
||||
hidden_channels = hidden_channels or 64
|
||||
self.RRDB_training = opt_get(self.opt, ['networks', 'generator','train_RRDB'], default=False)
|
||||
self.flow_scale = opt_get(self.opt, ['networks', 'generator', 'flow_scale'], default=opt['scale']) # <!-- hack to enable RRDB to do 2x scaling while retaining the flow architecture of 4x.
|
||||
|
||||
self.flowUpsamplerNet = \
|
||||
FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
|
||||
|
@ -41,8 +42,8 @@ class SRFlowNet(nn.Module):
|
|||
if seed: torch.manual_seed(seed)
|
||||
if opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable']):
|
||||
C = self.flowUpsamplerNet.C
|
||||
H = int(self.opt['scale'] * lr_shape[2] // self.flowUpsamplerNet.scaleH)
|
||||
W = int(self.opt['scale'] * lr_shape[3] // self.flowUpsamplerNet.scaleW)
|
||||
H = int(self.flow_scale * lr_shape[2] // (self.flowUpsamplerNet.scaleH * self.flow_scale / self.RRDB.scale))
|
||||
W = int(self.flow_scale * lr_shape[3] // (self.flowUpsamplerNet.scaleW * self.flow_scale / self.RRDB.scale))
|
||||
|
||||
size = (batch_size, C, H, W)
|
||||
if heat == 0:
|
||||
|
@ -149,9 +150,9 @@ class SRFlowNet(nn.Module):
|
|||
keys.append('fea_up0')
|
||||
if 'fea_up-1' in rrdbResults.keys():
|
||||
keys.append('fea_up-1')
|
||||
if self.opt['scale'] >= 8:
|
||||
if self.flow_scale >= 8:
|
||||
keys.append('fea_up8')
|
||||
if self.opt['scale'] == 16:
|
||||
if self.flow_scale == 16:
|
||||
keys.append('fea_up16')
|
||||
for k in keys:
|
||||
h = rrdbResults[k].shape[2]
|
||||
|
@ -166,7 +167,7 @@ class SRFlowNet(nn.Module):
|
|||
|
||||
def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
|
||||
logdet = torch.zeros_like(lr[:, 0, 0, 0])
|
||||
pixels = thops.pixels(lr) * self.opt['scale'] ** 2
|
||||
pixels = thops.pixels(lr) * self.flow_scale ** 2
|
||||
|
||||
if add_gt_noise:
|
||||
logdet = logdet - float(-np.log(self.quant) * pixels)
|
||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_srflow.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user