diff --git a/codes/models/archs/srflow/FlowAffineCouplingsAblation.py b/codes/models/archs/srflow/FlowAffineCouplingsAblation.py index 297b7969..c50d0d5d 100644 --- a/codes/models/archs/srflow/FlowAffineCouplingsAblation.py +++ b/codes/models/archs/srflow/FlowAffineCouplingsAblation.py @@ -3,11 +3,10 @@ from torch import nn as nn from models.archs.srflow import thops from models.archs.srflow.flow import Conv2d, Conv2dZeros -from utils.util import opt_get class CondAffineSeparatedAndCond(nn.Module): - def __init__(self, in_channels, opt): + def __init__(self, in_channels, hidden_channels=64, affine_eps=.00001): super().__init__() self.need_features = True self.in_channels = in_channels @@ -15,10 +14,8 @@ class CondAffineSeparatedAndCond(nn.Module): self.kernel_hidden = 1 self.affine_eps = 0.0001 self.n_hidden_layers = 1 - hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels']) - self.hidden_channels = 64 if hidden_channels is None else hidden_channels - - self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001) + self.hidden_channels = hidden_channels + self.affine_eps = affine_eps self.channels_for_nn = self.in_channels // 2 self.channels_for_co = self.in_channels - self.channels_for_nn diff --git a/codes/models/archs/srflow/FlowStep.py b/codes/models/archs/srflow/FlowStep.py index 37a9ebb1..87b8d6aa 100644 --- a/codes/models/archs/srflow/FlowStep.py +++ b/codes/models/archs/srflow/FlowStep.py @@ -1,10 +1,7 @@ import torch from torch import nn as nn -import models.archs.srflow -import models.archs.srflow.Permutations -from models.archs.srflow import flow, thops, FlowAffineCouplingsAblation -from utils.util import opt_get +from models.archs.srflow import flow, thops, FlowAffineCouplingsAblation, FlowActNorms, Permutations def getConditional(rrdbResults, position): @@ -28,7 +25,7 @@ class FlowStep(nn.Module): def __init__(self, in_channels, hidden_channels, actnorm_scale=1.0, flow_permutation="invconv", flow_coupling="additive", - LU_decomposed=False, opt=None, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None, + LU_decomposed=False, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None, position=None): # check configures assert flow_permutation in FlowStep.FlowPermutation, \ @@ -47,17 +44,16 @@ class FlowStep(nn.Module): self.acOpt = acOpt # 1. actnorm - self.actnorm = models.modules.FlowActNorms.ActNorm2d(in_channels, actnorm_scale) + self.actnorm = FlowActNorms.ActNorm2d(in_channels, actnorm_scale) # 2. permute if flow_permutation == "invconv": - self.invconv = models.modules.Permutations.InvertibleConv1x1( + self.invconv = Permutations.InvertibleConv1x1( in_channels, LU_decomposed=LU_decomposed) # 3. coupling if flow_coupling == "CondAffineSeparatedAndCond": - self.affine = models.modules.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels, - opt=opt) + self.affine = FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels) elif flow_coupling == "noCoupling": pass else: diff --git a/codes/models/archs/srflow/FlowUpsamplerNet.py b/codes/models/archs/srflow/FlowUpsamplerNet.py index 5beb1115..c492d7e4 100644 --- a/codes/models/archs/srflow/FlowUpsamplerNet.py +++ b/codes/models/archs/srflow/FlowUpsamplerNet.py @@ -3,34 +3,39 @@ import torch from torch import nn as nn import models.archs.srflow.Split -from models.archs.srflow import flow, thops +from models.archs.srflow import flow, thops, Split from models.archs.srflow.Split import Split2d from models.archs.srflow.glow_arch import f_conv2d_bias from models.archs.srflow.FlowStep import FlowStep from utils.util import opt_get +import torchvision class FlowUpsamplerNet(nn.Module): - def __init__(self, image_shape, hidden_channels, K, L=None, + def __init__(self, image_shape, hidden_channels, scale, + rrdb_blocks, actnorm_scale=1.0, - flow_permutation=None, + flow_permutation='invconv', flow_coupling="affine", - LU_decomposed=False, opt=None): + LU_decomposed=False, K=16, L=3, + norm_opt=None, + n_bypass_channels=None): super().__init__() self.layers = nn.ModuleList() self.output_shapes = [] - self.L = opt_get(opt, ['network_G', 'flow', 'L']) - self.K = opt_get(opt, ['network_G', 'flow', 'K']) + self.L = L + self.K = K + self.scale=scale if isinstance(self.K, int): self.K = [K for K in [K, ] * (self.L + 1)] - self.opt = opt H, W, self.C = image_shape + self.image_shape = image_shape self.check_image_shape() - if opt['scale'] == 16: + if scale == 16: self.levelToName = { 0: 'fea_up16', 1: 'fea_up8', @@ -39,7 +44,7 @@ class FlowUpsamplerNet(nn.Module): 4: 'fea_up1', } - if opt['scale'] == 8: + if scale == 8: self.levelToName = { 0: 'fea_up8', 1: 'fea_up4', @@ -48,7 +53,7 @@ class FlowUpsamplerNet(nn.Module): 4: 'fea_up0' } - elif opt['scale'] == 4: + elif scale == 4: self.levelToName = { 0: 'fea_up4', 1: 'fea_up2', @@ -57,14 +62,10 @@ class FlowUpsamplerNet(nn.Module): 4: 'fea_up-1' } - affineInCh = self.get_affineInCh(opt_get) - flow_permutation = self.get_flow_permutation(flow_permutation, opt) - - normOpt = opt_get(opt, ['network_G', 'flow', 'norm']) + affineInCh = self.get_affineInCh(rrdb_blocks) conditional_channels = {} - n_rrdb = self.get_n_rrdb_channels(opt, opt_get) - n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels']) + n_rrdb = self.get_n_rrdb_channels(rrdb_blocks) conditional_channels[0] = n_rrdb for level in range(1, self.L + 1): # Level 1 gets conditionals from 2, 3, 4 => L - level @@ -80,37 +81,29 @@ class FlowUpsamplerNet(nn.Module): H, W = self.arch_squeeze(H, W) # 2. K FlowStep - self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels, opt) + self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels) self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation, - hidden_channels, normOpt, opt, opt_get, - n_conditinal_channels=conditional_channels[level]) + hidden_channels, norm_opt, + n_conditional_channels=conditional_channels[level]) # Split - self.arch_split(H, W, level, self.L, opt, opt_get) - - if opt_get(opt, ['network_G', 'flow', 'split', 'enable']): - self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2) - else: - self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64) + self.arch_split(H, W, level, self.L) + self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2) self.H = H self.W = W - self.scaleH = 160 / H - self.scaleW = 160 / W - def get_n_rrdb_channels(self, opt, opt_get): - blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) + def get_n_rrdb_channels(self, blocks): n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64 return n_rrdb def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation, - hidden_channels, normOpt, opt, opt_get, n_conditinal_channels=None): - condAff = self.get_condAffSetting(opt, opt_get) + hidden_channels, normOpt, n_conditional_channels=None, condAff=None): if condAff is not None: - condAff['in_channels_rrdb'] = n_conditinal_channels + condAff['in_channels_rrdb'] = n_conditional_channels for k in range(K): - position_name = get_position_name(H, self.opt['scale']) + position_name = self.get_position_name(H, self.scale) if normOpt: normOpt['position'] = position_name self.layers.append( @@ -121,48 +114,37 @@ class FlowUpsamplerNet(nn.Module): flow_coupling=flow_coupling, acOpt=condAff, position=position_name, - LU_decomposed=LU_decomposed, opt=opt, idx=k, normOpt=normOpt)) + LU_decomposed=LU_decomposed, idx=k, normOpt=normOpt)) self.output_shapes.append( [-1, self.C, H, W]) - def get_condAffSetting(self, opt, opt_get): - condAff = opt_get(opt, ['network_G', 'flow', 'condAff']) or None - condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff - return condAff - - def arch_split(self, H, W, L, levels, opt, opt_get): - correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False) + def arch_split(self, H, W, L, levels, split_flow=True, correct_splits=False, logs_eps=0, consume_ratio=.5, split_conditional=False, cond_channels=None, split_type='Split2d'): correction = 0 if correct_splits else 1 - if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and L < levels - correction: - logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0 - consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5 - position_name = get_position_name(H, self.opt['scale']) - position = position_name if opt_get(opt, ['network_G', 'flow', 'split', 'conditional']) else None - cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels']) + if split_flow and L < levels - correction: + logs_eps = logs_eps + consume_ratio = consume_ratio + position_name = self.get_position_name(H, self.scale) + position = position_name if split_conditional else None cond_channels = 0 if cond_channels is None else cond_channels - t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d') - - if t == 'Split2d': - split = models.modules.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position, - cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt) + if split_type == 'Split2d': + split = Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position, + cond_channels=cond_channels, consume_ratio=consume_ratio) self.layers.append(split) self.output_shapes.append([-1, split.num_channels_pass, H, W]) self.C = split.num_channels_pass - def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt): - if 'additionalFlowNoAffine' in opt['network_G']['flow']: - n_additionalFlowNoAffine = int(opt['network_G']['flow']['additionalFlowNoAffine']) - for _ in range(n_additionalFlowNoAffine): - self.layers.append( - FlowStep(in_channels=self.C, - hidden_channels=hidden_channels, - actnorm_scale=actnorm_scale, - flow_permutation='invconv', - flow_coupling='noCoupling', - LU_decomposed=LU_decomposed, opt=opt)) - self.output_shapes.append( - [-1, self.C, H, W]) + def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, additionalFlowNoAffine=2): + for _ in range(additionalFlowNoAffine): + self.layers.append( + FlowStep(in_channels=self.C, + hidden_channels=hidden_channels, + actnorm_scale=actnorm_scale, + flow_permutation='invconv', + flow_coupling='noCoupling', + LU_decomposed=LU_decomposed)) + self.output_shapes.append( + [-1, self.C, H, W]) def arch_squeeze(self, H, W): self.C, H, W = self.C * 4, H // 2, W // 2 @@ -170,13 +152,8 @@ class FlowUpsamplerNet(nn.Module): self.output_shapes.append([-1, self.C, H, W]) return H, W - def get_flow_permutation(self, flow_permutation, opt): - flow_permutation = opt['network_G']['flow'].get('flow_permutation', 'invconv') - return flow_permutation - - def get_affineInCh(self, opt_get): - affineInCh = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] - affineInCh = (len(affineInCh) + 1) * 64 + def get_affineInCh(self, rrdb_blocks): + affineInCh = (len(rrdb_blocks) + 1) * 64 return affineInCh def check_image_shape(self): @@ -204,14 +181,14 @@ class FlowUpsamplerNet(nn.Module): level_conditionals = {} bypasses = {} - L = opt_get(self.opt, ['network_G', 'flow', 'L']) + L = self.L for level in range(1, L + 1): bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False) for layer, shape in zip(self.layers, self.output_shapes): size = shape[2] - level = int(np.log(160 / size) / np.log(2)) + level = int(np.log(self.image_shape[0] / size) / np.log(2)) if level > 0 and level not in level_conditionals.keys(): level_conditionals[level] = rrdbResults[self.levelToName[level]] @@ -255,15 +232,12 @@ class FlowUpsamplerNet(nn.Module): # debug.imwrite("fl_fea", fl_fea) bypasses = {} level_conditionals = {} - if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True: - for level in range(self.L + 1): - level_conditionals[level] = rrdbResults[self.levelToName[level]] + for level in range(self.L + 1): + level_conditionals[level] = rrdbResults[self.levelToName[level]] for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)): size = shape[2] - level = int(np.log(160 / size) / np.log(2)) - # size = fl_fea.shape[2] - # level = int(np.log(160 / size) / np.log(2)) + level = int(np.log(self.H / size) / np.log(2)) if isinstance(layer, Split2d): fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer, @@ -287,7 +261,7 @@ class FlowUpsamplerNet(nn.Module): return fl_fea, logdet -def get_position_name(H, scale): - downscale_factor = 160 // H - position_name = 'fea_up{}'.format(scale / downscale_factor) - return position_name + def get_position_name(self, H, scale): + downscale_factor = self.image_shape[0] // H + position_name = 'fea_up{}'.format(scale / downscale_factor) + return position_name diff --git a/codes/models/archs/srflow/Permutations.py b/codes/models/archs/srflow/Permutations.py index 86584e58..20548699 100644 --- a/codes/models/archs/srflow/Permutations.py +++ b/codes/models/archs/srflow/Permutations.py @@ -3,7 +3,7 @@ import torch from torch import nn as nn from torch.nn import functional as F -from models.modules import thops +from models.archs.srflow import thops class InvertibleConv1x1(nn.Module): diff --git a/codes/models/archs/srflow/RRDBNet_arch.py b/codes/models/archs/srflow/RRDBNet_arch.py index e747ca0e..033650b3 100644 --- a/codes/models/archs/srflow/RRDBNet_arch.py +++ b/codes/models/archs/srflow/RRDBNet_arch.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import models.archs.srflow.module_util as mutil -from utils.util import opt_get +from utils.util import opt_get, checkpoint class ResidualDenseBlock_5C(nn.Module): @@ -46,11 +46,14 @@ class RRDB(nn.Module): class RRDBNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None): - self.opt = opt + def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, block_outputs=[], fea_up0=True, + fea_up1=False): super(RRDBNet, self).__init__() RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) self.scale = scale + self.block_outputs = block_outputs + self.fea_up0 = fea_up0 + self.fea_up1 = fea_up1 self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb) @@ -73,11 +76,11 @@ class RRDBNet(nn.Module): def forward(self, x, get_steps=False): fea = self.conv_first(x) - block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] + block_idxs = self.block_outputs or [] block_results = {} for idx, m in enumerate(self.RRDB_trunk.children()): - fea = m(fea) + fea = checkpoint(m, fea) for b in block_idxs: if b == idx: block_results["block_{}".format(idx)] = fea @@ -117,11 +120,9 @@ class RRDBNet(nn.Module): 'fea_up32': fea_up32, 'out': out} - fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False - if fea_up0_en: + if self.fea_up0: results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True) - fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False - if fea_upn1_en: + if self.fea_up1: results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True) if get_steps: diff --git a/codes/models/archs/srflow/Split.py b/codes/models/archs/srflow/Split.py index fb600344..c24eaf41 100644 --- a/codes/models/archs/srflow/Split.py +++ b/codes/models/archs/srflow/Split.py @@ -7,7 +7,7 @@ from models.archs.srflow.flow import Conv2dZeros, GaussianDiag class Split2d(nn.Module): - def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None): + def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5): super().__init__() self.num_channels_consume = int(round(num_channels * consume_ratio)) @@ -17,7 +17,6 @@ class Split2d(nn.Module): out_channels=self.num_channels_consume * 2) self.logs_eps = logs_eps self.position = position - self.opt = opt def split2d_prior(self, z, ft): if ft is not None: diff --git a/codes/models/archs/srflow/srflow_arch.py b/codes/models/archs/srflow/srflow_arch.py index aaaddc95..2c127c3b 100644 --- a/codes/models/archs/srflow/srflow_arch.py +++ b/codes/models/archs/srflow/srflow_arch.py @@ -4,57 +4,41 @@ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np + from models.archs.srflow.FlowUpsamplerNet import FlowUpsamplerNet import models.archs.srflow.thops as thops import models.archs.srflow.flow as flow -from utils.util import opt_get +from models.archs.srflow.RRDBNet_arch import RRDBNet class SRFlowNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, K=None, opt=None, step=None): + def __init__(self, in_nc, out_nc, nf, nb, quant, flow_block_maps, noise_quant, + hidden_channels=64, gc=32, scale=4, K=16, L=3, train_rrdb_at_step=0, + hr_img_shape=(128,128,3), coupling='CondAffineSeparatedAndCond'): super(SRFlowNet, self).__init__() - self.opt = opt - self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \ - None else opt_get(opt, ['datasets', 'train', 'quant']) - self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt) - hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels']) - hidden_channels = hidden_channels or 64 - self.RRDB_training = True # Default is true + self.scale = scale + self.noise_quant = noise_quant + self.quant = quant + self.flow_block_maps = flow_block_maps + self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, flow_block_maps) + self.train_rrdb_step = train_rrdb_at_step + self.RRDB_training = True - train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay']) - set_RRDB_to_train = False - if set_RRDB_to_train: - self.set_rrdb_training(True) - - self.flowUpsamplerNet = \ - FlowUpsamplerNet((160, 160, 3), hidden_channels, K, - flow_coupling=opt['network_G']['flow']['coupling'], opt=opt) + self.flowUpsamplerNet = FlowUpsamplerNet(image_shape=hr_img_shape, + hidden_channels=hidden_channels, + scale=scale, rrdb_blocks=flow_block_maps, + K=K, L=L, flow_coupling=coupling) self.i = 0 - def set_rrdb_training(self, trainable): - if self.RRDB_training != trainable: - for p in self.RRDB.parameters(): - p.requires_grad = trainable - self.RRDB_training = trainable - return True - return False - - def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False, + def forward(self, gt=None, lr=None, reverse=False, z=None, eps_std=None, epses=None, reverse_with_grad=False, lr_enc=None, add_gt_noise=False, step=None, y_label=None): if not reverse: return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step, y_onehot=y_label) else: - # assert lr.shape[0] == 1 assert lr.shape[1] == 3 - # assert lr.shape[2] == 20 - # assert lr.shape[3] == 20 - # assert z.shape[0] == 1 - # assert z.shape[1] == 3 * 8 * 8 - # assert z.shape[2] == 20 - # assert z.shape[3] == 20 if reverse_with_grad: return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise) @@ -74,8 +58,7 @@ class SRFlowNet(nn.Module): if add_gt_noise: # Setup - noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True) - if noiseQuant: + if self.noise_quant: z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant) logdet = logdet + float(-np.log(self.quant) * pixels) @@ -100,24 +83,23 @@ class SRFlowNet(nn.Module): def rrdbPreprocessing(self, lr): rrdbResults = self.RRDB(lr, get_steps=True) - block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] + block_idxs = self.flow_block_maps if len(block_idxs) > 0: concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1) - if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False: - keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4'] - if 'fea_up0' in rrdbResults.keys(): - keys.append('fea_up0') - if 'fea_up-1' in rrdbResults.keys(): - keys.append('fea_up-1') - if self.opt['scale'] >= 8: - keys.append('fea_up8') - if self.opt['scale'] == 16: - keys.append('fea_up16') - for k in keys: - h = rrdbResults[k].shape[2] - w = rrdbResults[k].shape[3] - rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1) + keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4'] + if 'fea_up0' in rrdbResults.keys(): + keys.append('fea_up0') + if 'fea_up-1' in rrdbResults.keys(): + keys.append('fea_up-1') + if self.scale >= 8: + keys.append('fea_up8') + if self.scale == 16: + keys.append('fea_up16') + for k in keys: + h = rrdbResults[k].shape[2] + w = rrdbResults[k].shape[3] + rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1) return rrdbResults def get_score(self, disc_loss_sigma, z): @@ -127,7 +109,7 @@ class SRFlowNet(nn.Module): def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True): logdet = torch.zeros_like(lr[:, 0, 0, 0]) - pixels = thops.pixels(lr) * self.opt['scale'] ** 2 + pixels = thops.pixels(lr) * self.scale ** 2 if add_gt_noise: logdet = logdet - float(-np.log(self.quant) * pixels) @@ -138,4 +120,16 @@ class SRFlowNet(nn.Module): x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses, logdet=logdet) - return x, logdet \ No newline at end of file + return x, logdet + + def set_rrdb_training(self, trainable): + if self.RRDB_training != trainable: + for p in self.RRDB.parameters(): + if not trainable: + p.DO_NOT_TRAIN = True + elif hasattr(p, "DO_NOT_TRAIN"): + del p.DO_NOT_TRAIN + self.RRDB_training = trainable + + def update_for_step(self, step, experiments_path='.'): + self.set_rrdb_training(step > self.train_rrdb_step) \ No newline at end of file