Support 2x RRDB with 4x srflow

This commit is contained in:
James Betker 2020-11-21 14:46:15 -07:00
parent cad92bada8
commit 519ba6f10c
4 changed files with 45 additions and 21 deletions

View File

@ -48,7 +48,7 @@ class FlowUpsamplerNet(nn.Module):
4: 'fea_up0'
}
elif opt['scale'] == 4:
elif opt['scale'] == 4 or opt['scale'] == 2:
self.levelToName = {
0: 'fea_up4',
1: 'fea_up2',

View File

@ -186,6 +186,7 @@ class RRDBNet(nn.Module):
out = self.conv_last(self.lrelu(self.conv_hr(fea)))
if self.scale >= 4:
results = {'last_lr_fea': last_lr_fea,
'fea_up1': last_lr_fea,
'fea_up2': fea_up2,
@ -201,6 +202,28 @@ class RRDBNet(nn.Module):
fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
if fea_upn1_en:
results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
elif self.scale == 2:
# "Pretend" this is 4x by shuffling around the inputs a bit.
half = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
quarter = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
eighth = F.interpolate(last_lr_fea, scale_factor=1/8, mode='bilinear', align_corners=False, recompute_scale_factor=True)
results = {'last_lr_fea': half,
'fea_up1': half,
'fea_up2': last_lr_fea,
'fea_up4': fea_up2,
'fea_up8': fea_up4,
'fea_up16': fea_up8,
'fea_up32': fea_up16,
'out': out}
fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False
if fea_up0_en:
results['fea_up0'] = quarter
fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
if fea_upn1_en:
results['fea_up-1'] = eighth
else:
raise NotImplementedError
if get_steps:
for k, v in block_results.items():

View File

@ -27,6 +27,7 @@ class SRFlowNet(nn.Module):
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
hidden_channels = hidden_channels or 64
self.RRDB_training = opt_get(self.opt, ['networks', 'generator','train_RRDB'], default=False)
self.flow_scale = opt_get(self.opt, ['networks', 'generator', 'flow_scale'], default=opt['scale']) # Hack to enable RRDB to do 2x scaling while retaining the flow architecture of 4x.
self.flowUpsamplerNet = \
FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
@ -41,8 +42,8 @@ class SRFlowNet(nn.Module):
if seed: torch.manual_seed(seed)
if opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable']):
C = self.flowUpsamplerNet.C
H = int(self.opt['scale'] * lr_shape[2] // self.flowUpsamplerNet.scaleH)
W = int(self.opt['scale'] * lr_shape[3] // self.flowUpsamplerNet.scaleW)
H = int(self.flow_scale * lr_shape[2] // (self.flowUpsamplerNet.scaleH * self.flow_scale / self.RRDB.scale))
W = int(self.flow_scale * lr_shape[3] // (self.flowUpsamplerNet.scaleW * self.flow_scale / self.RRDB.scale))
size = (batch_size, C, H, W)
if heat == 0:
@ -149,9 +150,9 @@ class SRFlowNet(nn.Module):
keys.append('fea_up0')
if 'fea_up-1' in rrdbResults.keys():
keys.append('fea_up-1')
if self.opt['scale'] >= 8:
if self.flow_scale >= 8:
keys.append('fea_up8')
if self.opt['scale'] == 16:
if self.flow_scale == 16:
keys.append('fea_up16')
for k in keys:
h = rrdbResults[k].shape[2]
@ -166,7 +167,7 @@ class SRFlowNet(nn.Module):
def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
logdet = torch.zeros_like(lr[:, 0, 0, 0])
pixels = thops.pixels(lr) * self.opt['scale'] ** 2
pixels = thops.pixels(lr) * self.flow_scale ** 2
if add_gt_noise:
logdet = logdet - float(-np.log(self.quant) * pixels)

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_srflow.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()