Support 2x RRDB with 4x srflow
This commit is contained in:
parent
cad92bada8
commit
519ba6f10c
|
@ -48,7 +48,7 @@ class FlowUpsamplerNet(nn.Module):
|
||||||
4: 'fea_up0'
|
4: 'fea_up0'
|
||||||
}
|
}
|
||||||
|
|
||||||
elif opt['scale'] == 4:
|
elif opt['scale'] == 4 or opt['scale'] == 2:
|
||||||
self.levelToName = {
|
self.levelToName = {
|
||||||
0: 'fea_up4',
|
0: 'fea_up4',
|
||||||
1: 'fea_up2',
|
1: 'fea_up2',
|
||||||
|
|
|
@ -186,21 +186,44 @@ class RRDBNet(nn.Module):
|
||||||
|
|
||||||
out = self.conv_last(self.lrelu(self.conv_hr(fea)))
|
out = self.conv_last(self.lrelu(self.conv_hr(fea)))
|
||||||
|
|
||||||
results = {'last_lr_fea': last_lr_fea,
|
if self.scale >= 4:
|
||||||
'fea_up1': last_lr_fea,
|
results = {'last_lr_fea': last_lr_fea,
|
||||||
'fea_up2': fea_up2,
|
'fea_up1': last_lr_fea,
|
||||||
'fea_up4': fea_up4,
|
'fea_up2': fea_up2,
|
||||||
'fea_up8': fea_up8,
|
'fea_up4': fea_up4,
|
||||||
'fea_up16': fea_up16,
|
'fea_up8': fea_up8,
|
||||||
'fea_up32': fea_up32,
|
'fea_up16': fea_up16,
|
||||||
'out': out}
|
'fea_up32': fea_up32,
|
||||||
|
'out': out}
|
||||||
|
|
||||||
fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False
|
fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False
|
||||||
if fea_up0_en:
|
if fea_up0_en:
|
||||||
results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||||
fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
|
fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
|
||||||
if fea_upn1_en:
|
if fea_upn1_en:
|
||||||
results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||||
|
elif self.scale == 2:
|
||||||
|
# "Pretend" this is is 4x by shuffling around the inputs a bit.
|
||||||
|
half = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||||
|
quarter = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||||
|
eighth = F.interpolate(last_lr_fea, scale_factor=1/8, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||||
|
results = {'last_lr_fea': half,
|
||||||
|
'fea_up1': half,
|
||||||
|
'fea_up2': last_lr_fea,
|
||||||
|
'fea_up4': fea_up2,
|
||||||
|
'fea_up8': fea_up4,
|
||||||
|
'fea_up16': fea_up8,
|
||||||
|
'fea_up32': fea_up16,
|
||||||
|
'out': out}
|
||||||
|
|
||||||
|
fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False
|
||||||
|
if fea_up0_en:
|
||||||
|
results['fea_up0'] = quarter
|
||||||
|
fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
|
||||||
|
if fea_upn1_en:
|
||||||
|
results['fea_up-1'] = eighth
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
if get_steps:
|
if get_steps:
|
||||||
for k, v in block_results.items():
|
for k, v in block_results.items():
|
||||||
|
|
|
@ -27,6 +27,7 @@ class SRFlowNet(nn.Module):
|
||||||
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
|
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
|
||||||
hidden_channels = hidden_channels or 64
|
hidden_channels = hidden_channels or 64
|
||||||
self.RRDB_training = opt_get(self.opt, ['networks', 'generator','train_RRDB'], default=False)
|
self.RRDB_training = opt_get(self.opt, ['networks', 'generator','train_RRDB'], default=False)
|
||||||
|
self.flow_scale = opt_get(self.opt, ['networks', 'generator', 'flow_scale'], default=opt['scale']) # <!-- hack to enable RRDB to do 2x scaling while retaining the flow architecture of 4x.
|
||||||
|
|
||||||
self.flowUpsamplerNet = \
|
self.flowUpsamplerNet = \
|
||||||
FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
|
FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
|
||||||
|
@ -41,8 +42,8 @@ class SRFlowNet(nn.Module):
|
||||||
if seed: torch.manual_seed(seed)
|
if seed: torch.manual_seed(seed)
|
||||||
if opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable']):
|
if opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable']):
|
||||||
C = self.flowUpsamplerNet.C
|
C = self.flowUpsamplerNet.C
|
||||||
H = int(self.opt['scale'] * lr_shape[2] // self.flowUpsamplerNet.scaleH)
|
H = int(self.flow_scale * lr_shape[2] // (self.flowUpsamplerNet.scaleH * self.flow_scale / self.RRDB.scale))
|
||||||
W = int(self.opt['scale'] * lr_shape[3] // self.flowUpsamplerNet.scaleW)
|
W = int(self.flow_scale * lr_shape[3] // (self.flowUpsamplerNet.scaleW * self.flow_scale / self.RRDB.scale))
|
||||||
|
|
||||||
size = (batch_size, C, H, W)
|
size = (batch_size, C, H, W)
|
||||||
if heat == 0:
|
if heat == 0:
|
||||||
|
@ -149,9 +150,9 @@ class SRFlowNet(nn.Module):
|
||||||
keys.append('fea_up0')
|
keys.append('fea_up0')
|
||||||
if 'fea_up-1' in rrdbResults.keys():
|
if 'fea_up-1' in rrdbResults.keys():
|
||||||
keys.append('fea_up-1')
|
keys.append('fea_up-1')
|
||||||
if self.opt['scale'] >= 8:
|
if self.flow_scale >= 8:
|
||||||
keys.append('fea_up8')
|
keys.append('fea_up8')
|
||||||
if self.opt['scale'] == 16:
|
if self.flow_scale == 16:
|
||||||
keys.append('fea_up16')
|
keys.append('fea_up16')
|
||||||
for k in keys:
|
for k in keys:
|
||||||
h = rrdbResults[k].shape[2]
|
h = rrdbResults[k].shape[2]
|
||||||
|
@ -166,7 +167,7 @@ class SRFlowNet(nn.Module):
|
||||||
|
|
||||||
def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
|
def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
|
||||||
logdet = torch.zeros_like(lr[:, 0, 0, 0])
|
logdet = torch.zeros_like(lr[:, 0, 0, 0])
|
||||||
pixels = thops.pixels(lr) * self.opt['scale'] ** 2
|
pixels = thops.pixels(lr) * self.flow_scale ** 2
|
||||||
|
|
||||||
if add_gt_noise:
|
if add_gt_noise:
|
||||||
logdet = logdet - float(-np.log(self.quant) * pixels)
|
logdet = logdet - float(-np.log(self.quant) * pixels)
|
||||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_srflow.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user