diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index d7f1bd49..7d61a19d 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -58,7 +58,7 @@ class ExtensibleTrainer(BaseModel): new_net = None if net['type'] == 'generator': if new_net is None: - new_net = networks.define_G(net, None, opt['scale']).to(self.device) + new_net = networks.define_G(opt, net, opt['scale']).to(self.device) self.netsG[name] = new_net elif net['type'] == 'discriminator': if new_net is None: diff --git a/codes/models/archs/srflow_orig/FlowActNorms.py b/codes/models/archs/srflow_orig/FlowActNorms.py index 3292aafa..e92dc642 100644 --- a/codes/models/archs/srflow_orig/FlowActNorms.py +++ b/codes/models/archs/srflow_orig/FlowActNorms.py @@ -1,7 +1,7 @@ import torch from torch import nn as nn -from models.modules import thops +from models.archs.srflow_orig import thops class _ActNorm(nn.Module): diff --git a/codes/models/archs/srflow_orig/FlowAffineCouplingsAblation.py b/codes/models/archs/srflow_orig/FlowAffineCouplingsAblation.py index 5a94abe3..a3c9fb0d 100644 --- a/codes/models/archs/srflow_orig/FlowAffineCouplingsAblation.py +++ b/codes/models/archs/srflow_orig/FlowAffineCouplingsAblation.py @@ -1,8 +1,8 @@ import torch from torch import nn as nn -from models.modules import thops -from models.modules.flow import Conv2d, Conv2dZeros +from models.archs.srflow_orig import thops +from models.archs.srflow_orig.flow import Conv2d, Conv2dZeros from utils.util import opt_get @@ -15,10 +15,10 @@ class CondAffineSeparatedAndCond(nn.Module): self.kernel_hidden = 1 self.affine_eps = 0.0001 self.n_hidden_layers = 1 - hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels']) + hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'CondAffineSeparatedAndCond', 'hidden_channels']) self.hidden_channels = 64 if hidden_channels is None else hidden_channels - self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001) + self.affine_eps = opt_get(opt, ['networks', 'generator','flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001) self.channels_for_nn = self.in_channels // 2 self.channels_for_co = self.in_channels - self.channels_for_nn diff --git a/codes/models/archs/srflow_orig/FlowStep.py b/codes/models/archs/srflow_orig/FlowStep.py index 41a867be..1c128c92 100644 --- a/codes/models/archs/srflow_orig/FlowStep.py +++ b/codes/models/archs/srflow_orig/FlowStep.py @@ -1,9 +1,8 @@ import torch from torch import nn as nn -import models.modules -import models.modules.Permutations -from models.modules import flow, thops, FlowAffineCouplingsAblation +import models.archs.srflow_orig.Permutations +from models.archs.srflow_orig import flow, thops, FlowAffineCouplingsAblation from utils.util import opt_get @@ -47,16 +46,16 @@ class FlowStep(nn.Module): self.acOpt = acOpt # 1. actnorm - self.actnorm = models.modules.FlowActNorms.ActNorm2d(in_channels, actnorm_scale) + self.actnorm = models.archs.srflow_orig.FlowActNorms.ActNorm2d(in_channels, actnorm_scale) # 2. permute if flow_permutation == "invconv": - self.invconv = models.modules.Permutations.InvertibleConv1x1( + self.invconv = models.archs.srflow_orig.Permutations.InvertibleConv1x1( in_channels, LU_decomposed=LU_decomposed) # 3. coupling if flow_coupling == "CondAffineSeparatedAndCond": - self.affine = models.modules.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels, + self.affine = models.archs.srflow_orig.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels, opt=opt) elif flow_coupling == "noCoupling": pass diff --git a/codes/models/archs/srflow_orig/FlowUpsamplerNet.py b/codes/models/archs/srflow_orig/FlowUpsamplerNet.py index 6bc4f5d8..0b0d4e23 100644 --- a/codes/models/archs/srflow_orig/FlowUpsamplerNet.py +++ b/codes/models/archs/srflow_orig/FlowUpsamplerNet.py @@ -2,11 +2,11 @@ import numpy as np import torch from torch import nn as nn -import models.modules.Split -from models.modules import flow, thops -from models.modules.Split import Split2d -from models.modules.glow_arch import f_conv2d_bias -from models.modules.FlowStep import FlowStep +import models.archs.srflow_orig.Split +from models.archs.srflow_orig import flow, thops +from models.archs.srflow_orig.Split import Split2d +from models.archs.srflow_orig.glow_arch import f_conv2d_bias +from models.archs.srflow_orig.FlowStep import FlowStep from utils.util import opt_get @@ -21,8 +21,8 @@ class FlowUpsamplerNet(nn.Module): self.layers = nn.ModuleList() self.output_shapes = [] - self.L = opt_get(opt, ['network_G', 'flow', 'L']) - self.K = opt_get(opt, ['network_G', 'flow', 'K']) + self.L = opt_get(opt, ['networks', 'generator','flow', 'L']) + self.K = opt_get(opt, ['networks', 'generator','flow', 'K']) if isinstance(self.K, int): self.K = [K for K in [K, ] * (self.L + 1)] @@ -60,11 +60,11 @@ class FlowUpsamplerNet(nn.Module): affineInCh = self.get_affineInCh(opt_get) flow_permutation = self.get_flow_permutation(flow_permutation, opt) - normOpt = opt_get(opt, ['network_G', 'flow', 'norm']) + normOpt = opt_get(opt, ['networks', 'generator','flow', 'norm']) conditional_channels = {} n_rrdb = self.get_n_rrdb_channels(opt, opt_get) - n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels']) + n_bypass_channels = opt_get(opt, ['networks', 'generator','flow', 'levelConditional', 'n_channels']) conditional_channels[0] = n_rrdb for level in range(1, self.L + 1): # Level 1 gets conditionals from 2, 3, 4 => L - level @@ -88,7 +88,7 @@ class FlowUpsamplerNet(nn.Module): # Split self.arch_split(H, W, level, self.L, opt, opt_get) - if opt_get(opt, ['network_G', 'flow', 'split', 'enable']): + if opt_get(opt, ['networks', 'generator','flow', 'split', 'enable']): self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2) else: self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64) @@ -99,7 +99,7 @@ class FlowUpsamplerNet(nn.Module): self.scaleW = 160 / W def get_n_rrdb_channels(self, opt, opt_get): - blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) + blocks = opt_get(opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64 return n_rrdb @@ -126,33 +126,33 @@ class FlowUpsamplerNet(nn.Module): [-1, self.C, H, W]) def get_condAffSetting(self, opt, opt_get): - condAff = opt_get(opt, ['network_G', 'flow', 'condAff']) or None - condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff + condAff = opt_get(opt, ['networks', 'generator','flow', 'condAff']) or None + condAff = opt_get(opt, ['networks', 'generator','flow', 'condFtAffine']) or condAff return condAff def arch_split(self, H, W, L, levels, opt, opt_get): - correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False) + correct_splits = opt_get(opt, ['networks', 'generator','flow', 'split', 'correct_splits'], False) correction = 0 if correct_splits else 1 - if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and L < levels - correction: - logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0 - consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5 + if opt_get(opt, ['networks', 'generator','flow', 'split', 'enable']) and L < levels - correction: + logs_eps = opt_get(opt, ['networks', 'generator','flow', 'split', 'logs_eps']) or 0 + consume_ratio = opt_get(opt, ['networks', 'generator','flow', 'split', 'consume_ratio']) or 0.5 position_name = get_position_name(H, self.opt['scale']) - position = position_name if opt_get(opt, ['network_G', 'flow', 'split', 'conditional']) else None - cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels']) + position = position_name if opt_get(opt, ['networks', 'generator','flow', 'split', 'conditional']) else None + cond_channels = opt_get(opt, ['networks', 'generator','flow', 'split', 'cond_channels']) cond_channels = 0 if cond_channels is None else cond_channels - t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d') + t = opt_get(opt, ['networks', 'generator','flow', 'split', 'type'], 'Split2d') if t == 'Split2d': - split = models.modules.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position, + split = models.archs.srflow_orig.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position, cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt) self.layers.append(split) self.output_shapes.append([-1, split.num_channels_pass, H, W]) self.C = split.num_channels_pass def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt): - if 'additionalFlowNoAffine' in opt['network_G']['flow']: - n_additionalFlowNoAffine = int(opt['network_G']['flow']['additionalFlowNoAffine']) + if 'additionalFlowNoAffine' in opt['networks']['generator']['flow']: + n_additionalFlowNoAffine = int(opt['networks']['generator']['flow']['additionalFlowNoAffine']) for _ in range(n_additionalFlowNoAffine): self.layers.append( FlowStep(in_channels=self.C, @@ -171,11 +171,11 @@ class FlowUpsamplerNet(nn.Module): return H, W def get_flow_permutation(self, flow_permutation, opt): - flow_permutation = opt['network_G']['flow'].get('flow_permutation', 'invconv') + flow_permutation = opt['networks']['generator']['flow'].get('flow_permutation', 'invconv') return flow_permutation def get_affineInCh(self, opt_get): - affineInCh = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] + affineInCh = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or [] affineInCh = (len(affineInCh) + 1) * 64 return affineInCh @@ -204,7 +204,7 @@ class FlowUpsamplerNet(nn.Module): level_conditionals = {} bypasses = {} - L = opt_get(self.opt, ['network_G', 'flow', 'L']) + L = opt_get(self.opt, ['networks', 'generator','flow', 'L']) for level in range(1, L + 1): bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False) @@ -255,7 +255,7 @@ class FlowUpsamplerNet(nn.Module): # debug.imwrite("fl_fea", fl_fea) bypasses = {} level_conditionals = {} - if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True: + if not opt_get(self.opt, ['networks', 'generator','flow', 'levelConditional', 'conditional']) == True: for level in range(self.L + 1): level_conditionals[level] = rrdbResults[self.levelToName[level]] diff --git a/codes/models/archs/srflow_orig/Permutations.py b/codes/models/archs/srflow_orig/Permutations.py index 86584e58..9115dcc7 100644 --- a/codes/models/archs/srflow_orig/Permutations.py +++ b/codes/models/archs/srflow_orig/Permutations.py @@ -3,7 +3,7 @@ import torch from torch import nn as nn from torch.nn import functional as F -from models.modules import thops +from models.archs.srflow_orig import thops class InvertibleConv1x1(nn.Module): diff --git a/codes/models/archs/srflow_orig/RRDBNet_arch.py b/codes/models/archs/srflow_orig/RRDBNet_arch.py index f5cdb4d5..a6d75117 100644 --- a/codes/models/archs/srflow_orig/RRDBNet_arch.py +++ b/codes/models/archs/srflow_orig/RRDBNet_arch.py @@ -2,70 +2,148 @@ import functools import torch import torch.nn as nn import torch.nn.functional as F -import models.modules.module_util as mutil +import models.archs.srflow_orig.module_util as mutil +from models.archs.arch_util import default_init_weights, ConvGnSilu from utils.util import opt_get -class ResidualDenseBlock_5C(nn.Module): - def __init__(self, nf=64, gc=32, bias=True): - super(ResidualDenseBlock_5C, self).__init__() - # gc: growth channel, i.e. intermediate channels - self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) - self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) - self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) - self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) - self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + mid_channels (int): Channel number of intermediate features. + growth_channels (int): Channels for each growth. + """ + + def __init__(self, mid_channels=64, growth_channels=32): + super(ResidualDenseBlock, self).__init__() + for i in range(5): + out_channels = mid_channels if i == 4 else growth_channels + self.add_module( + f'conv{i+1}', + nn.Conv2d(mid_channels + i * growth_channels, out_channels, 3, + 1, 1)) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + for i in range(5): + default_init_weights(getattr(self, f'conv{i+1}'), 0.1) - # initialization - mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ x1 = self.lrelu(self.conv1(x)) x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Emperically, we use 0.2 to scale the residual for better performance return x5 * 0.2 + x class RRDB(nn.Module): - '''Residual in Residual Dense Block''' + """Residual in Residual Dense Block. - def __init__(self, nf, gc=32): + Used in RRDB-Net in ESRGAN. + + Args: + mid_channels (int): Channel number of intermediate features. + growth_channels (int): Channels for each growth. + """ + + def __init__(self, mid_channels, growth_channels=32): super(RRDB, self).__init__() - self.RDB1 = ResidualDenseBlock_5C(nf, gc) - self.RDB2 = ResidualDenseBlock_5C(nf, gc) - self.RDB3 = ResidualDenseBlock_5C(nf, gc) + self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels) + self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels) + self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels) def forward(self, x): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + # Emperically, we use 0.2 to scale the residual for better performance return out * 0.2 + x +class RRDBWithBypass(nn.Module): + """Residual in Residual Dense Block. + + Used in RRDB-Net in ESRGAN. + + Args: + mid_channels (int): Channel number of intermediate features. + growth_channels (int): Channels for each growth. + """ + + def __init__(self, mid_channels, growth_channels=32): + super(RRDBWithBypass, self).__init__() + self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels) + self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels) + self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels) + self.bypass = nn.Sequential(ConvGnSilu(mid_channels*2, mid_channels, kernel_size=3, bias=True, activation=True, norm=True), + ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False), + ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False), + nn.Sigmoid()) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + bypass = self.bypass(torch.cat([x, out], dim=1)) + self.bypass_map = bypass.detach().clone() + # Empirically, we use 0.2 to scale the residual for better performance + return out * 0.2 * bypass + x + + class RRDBNet(nn.Module): def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None): self.opt = opt super(RRDBNet, self).__init__() - RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + + bypass = opt_get(self.opt, ['networks', 'generator', 'rrdb_bypass']) + if bypass: + RRDB_block_f = functools.partial(RRDBWithBypass, mid_channels=nf, growth_channels=gc) + else: + RRDB_block_f = functools.partial(RRDB, mid_channels=nf, growth_channels=gc) self.scale = scale self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) - self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb) - self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.body = mutil.make_layer(RRDB_block_f, nb) + self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) #### upsampling - self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) if self.scale >= 8: - self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_up3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) if self.scale >= 16: - self.upconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_up4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) if self.scale >= 32: - self.upconv5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_up5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_hr = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) @@ -73,23 +151,23 @@ class RRDBNet(nn.Module): def forward(self, x, get_steps=False): fea = self.conv_first(x) - block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] + block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or [] block_results = {} - for idx, m in enumerate(self.RRDB_trunk.children()): + for idx, m in enumerate(self.body.children()): fea = m(fea) for b in block_idxs: if b == idx: block_results["block_{}".format(idx)] = fea - trunk = self.trunk_conv(fea) + trunk = self.conv_body(fea) last_lr_fea = fea + trunk - fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest')) + fea_up2 = self.conv_up1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest')) fea = self.lrelu(fea_up2) - fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')) + fea_up4 = self.conv_up2(F.interpolate(fea, scale_factor=2, mode='nearest')) fea = self.lrelu(fea_up4) fea_up8 = None @@ -97,16 +175,16 @@ class RRDBNet(nn.Module): fea_up32 = None if self.scale >= 8: - fea_up8 = self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest')) + fea_up8 = self.conv_up3(F.interpolate(fea, scale_factor=2, mode='nearest')) fea = self.lrelu(fea_up8) if self.scale >= 16: - fea_up16 = self.upconv4(F.interpolate(fea, scale_factor=2, mode='nearest')) + fea_up16 = self.conv_up4(F.interpolate(fea, scale_factor=2, mode='nearest')) fea = self.lrelu(fea_up16) if self.scale >= 32: - fea_up32 = self.upconv5(F.interpolate(fea, scale_factor=2, mode='nearest')) + fea_up32 = self.conv_up5(F.interpolate(fea, scale_factor=2, mode='nearest')) fea = self.lrelu(fea_up32) - out = self.conv_last(self.lrelu(self.HRconv(fea))) + out = self.conv_last(self.lrelu(self.conv_hr(fea))) results = {'last_lr_fea': last_lr_fea, 'fea_up1': last_lr_fea, @@ -117,10 +195,10 @@ class RRDBNet(nn.Module): 'fea_up32': fea_up32, 'out': out} - fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False + fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False if fea_up0_en: results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True) - fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False + fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False if fea_upn1_en: results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True) diff --git a/codes/models/archs/srflow_orig/SRFlowNet_arch.py b/codes/models/archs/srflow_orig/SRFlowNet_arch.py index b95374ef..2bdab09d 100644 --- a/codes/models/archs/srflow_orig/SRFlowNet_arch.py +++ b/codes/models/archs/srflow_orig/SRFlowNet_arch.py @@ -4,10 +4,10 @@ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np -from models.modules.RRDBNet_arch import RRDBNet -from models.modules.FlowUpsamplerNet import FlowUpsamplerNet -import models.modules.thops as thops -import models.modules.flow as flow +from models.archs.srflow_orig.RRDBNet_arch import RRDBNet +from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet +import models.archs.srflow_orig.thops as thops +import models.archs.srflow_orig.flow as flow from utils.util import opt_get @@ -19,27 +19,40 @@ class SRFlowNet(nn.Module): self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \ None else opt_get(opt, ['datasets', 'train', 'quant']) self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt) - hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels']) + if 'pretrain_rrdb' in opt['networks']['generator'].keys(): + rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb']) + self.RRDB.load_state_dict(rrdb_state_dict, strict=True) + + hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels']) hidden_channels = hidden_channels or 64 self.RRDB_training = True # Default is true - train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay']) - set_RRDB_to_train = False - if set_RRDB_to_train: - self.set_rrdb_training(True) + train_RRDB_delay = opt_get(self.opt, ['networks', 'generator','train_RRDB_delay']) + self.RRDB_training = False self.flowUpsamplerNet = \ FlowUpsamplerNet((160, 160, 3), hidden_channels, K, - flow_coupling=opt['network_G']['flow']['coupling'], opt=opt) + flow_coupling=opt['networks']['generator']['flow']['coupling'], opt=opt) self.i = 0 - def set_rrdb_training(self, trainable): - if self.RRDB_training != trainable: - for p in self.RRDB.parameters(): - p.requires_grad = trainable - self.RRDB_training = trainable - return True - return False + def get_random_z(self, heat, seed=None, batch_size=1, lr_shape=None, device='cuda'): + if seed: torch.manual_seed(seed) + if opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable']): + C = self.flowUpsamplerNet.C + H = int(self.opt['scale'] * lr_shape[2] // self.flowUpsamplerNet.scaleH) + W = int(self.opt['scale'] * lr_shape[3] // self.flowUpsamplerNet.scaleW) + + size = (batch_size, C, H, W) + if heat == 0: + z = torch.zeros(size) + else: + z = torch.normal(mean=0, std=heat, size=size) + else: + L = opt_get(self.opt, ['networks', 'generator', 'flow', 'L']) or 3 + fac = 2 ** (L - 3) + z_size = int(self.lr_size // (2 ** (L - 3))) + z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size)) + return z.to(device) def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False, lr_enc=None, @@ -48,14 +61,10 @@ class SRFlowNet(nn.Module): return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step, y_onehot=y_label) else: - # assert lr.shape[0] == 1 assert lr.shape[1] == 3 - # assert lr.shape[2] == 20 - # assert lr.shape[3] == 20 - # assert z.shape[0] == 1 - # assert z.shape[1] == 3 * 8 * 8 - # assert z.shape[2] == 20 - # assert z.shape[3] == 20 + if z is None: + # Synthesize it. + z = self.get_random_z(eps_std, batch_size=lr.shape[0], lr_shape=lr.shape, device=lr.device) if reverse_with_grad: return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise) @@ -66,7 +75,11 @@ class SRFlowNet(nn.Module): def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None): if lr_enc is None: - lr_enc = self.rrdbPreprocessing(lr) + if self.RRDB_training: + lr_enc = self.rrdbPreprocessing(lr) + else: + with torch.no_grad(): + lr_enc = self.rrdbPreprocessing(lr) logdet = torch.zeros_like(gt[:, 0, 0, 0]) pixels = thops.pixels(gt) @@ -75,7 +88,7 @@ class SRFlowNet(nn.Module): if add_gt_noise: # Setup - noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True) + noiseQuant = opt_get(self.opt, ['networks', 'generator','flow', 'augmentation', 'noiseQuant'], True) if noiseQuant: z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant) logdet = logdet + float(-np.log(self.quant) * pixels) @@ -101,11 +114,11 @@ class SRFlowNet(nn.Module): def rrdbPreprocessing(self, lr): rrdbResults = self.RRDB(lr, get_steps=True) - block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] + block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or [] if len(block_idxs) > 0: concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1) - if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False: + if opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'concat']) or False: keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4'] if 'fea_up0' in rrdbResults.keys(): keys.append('fea_up0') @@ -134,7 +147,11 @@ class SRFlowNet(nn.Module): logdet = logdet - float(-np.log(self.quant) * pixels) if lr_enc is None: - lr_enc = self.rrdbPreprocessing(lr) + if self.RRDB_training: + lr_enc = self.rrdbPreprocessing(lr) + else: + with torch.no_grad(): + lr_enc = self.rrdbPreprocessing(lr) x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses, logdet=logdet) diff --git a/codes/models/archs/srflow_orig/Split.py b/codes/models/archs/srflow_orig/Split.py index 60897eb0..b7b1df98 100644 --- a/codes/models/archs/srflow_orig/Split.py +++ b/codes/models/archs/srflow_orig/Split.py @@ -1,9 +1,9 @@ import torch from torch import nn as nn -from models.modules import thops -from models.modules.FlowStep import FlowStep -from models.modules.flow import Conv2dZeros, GaussianDiag +from models.archs.srflow_orig import thops +from models.archs.srflow_orig.FlowStep import FlowStep +from models.archs.srflow_orig.flow import Conv2dZeros, GaussianDiag from utils.util import opt_get diff --git a/codes/models/archs/srflow_orig/flow.py b/codes/models/archs/srflow_orig/flow.py index 5c0ae968..71e03646 100644 --- a/codes/models/archs/srflow_orig/flow.py +++ b/codes/models/archs/srflow_orig/flow.py @@ -3,7 +3,7 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np -from models.modules.FlowActNorms import ActNorm2d +from models.archs.srflow_orig.FlowActNorms import ActNorm2d from . import thops diff --git a/codes/models/base_model.py b/codes/models/base_model.py index c60afdc1..55371407 100644 --- a/codes/models/base_model.py +++ b/codes/models/base_model.py @@ -105,12 +105,25 @@ class BaseModel(): if 'state_dict' in load_net: load_net = load_net['state_dict'] + is_srflow = False load_net_clean = OrderedDict() # remove unnecessary 'module.' for k, v in load_net.items(): if k.startswith('module.'): load_net_clean[k[7:]] = v if k.startswith('generator'): # Hack to fix ESRGAN pretrained model. load_net_clean[k[10:]] = v + if 'RRDB_trunk' in k or is_srflow: # Hacks to fix SRFlow imports, which uses some strange RDB names. + is_srflow = True + fixed_key = k.replace('RRDB_trunk', 'body') + if '.RDB' in fixed_key: + fixed_key = fixed_key.replace('.RDB', '.rdb') + elif '.upconv' in fixed_key: + fixed_key = fixed_key.replace('.upconv', '.conv_up') + elif '.trunk_conv' in fixed_key: + fixed_key = fixed_key.replace('.trunk_conv', '.conv_body') + elif '.HRconv' in fixed_key: + fixed_key = fixed_key.replace('.HRconv', '.conv_hr') + load_net_clean[fixed_key] = v else: load_net_clean[k] = v network.load_state_dict(load_net_clean, strict=strict) diff --git a/codes/models/networks.py b/codes/models/networks.py index fb1a0149..75130989 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -28,11 +28,7 @@ from models.archs.teco_resgen import TecoGen logger = logging.getLogger('base') # Generator -def define_G(opt, net_key='network_G', scale=None): - if net_key is not None: - opt_net = opt[net_key] - else: - opt_net = opt +def define_G(opt, opt_net, scale=None): if scale is None: scale = opt['scale'] which_model = opt_net['which_model_G'] @@ -141,6 +137,17 @@ def define_G(opt, net_key='network_G', scale=None): netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'], style_depth=opt_net['style_depth'], structure_input=is_structured, attn_layers=attn) + elif which_model == 'srflow': + from models.archs.srflow import SRFlow_arch + netG = SRFlow_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], + quant=opt_net['quant'], flow_block_maps=opt_net['rrdb_block_maps'], + noise_quant=opt_net['noise_quant'], hidden_channels=opt_net['nf'], + K=opt_net['K'], L=opt_net['L'], train_rrdb_at_step=opt_net['rrdb_train_step'], + hr_img_shape=opt_net['hr_shape'], scale=opt_net['scale']) + elif which_model == 'srflow_orig': + from models.archs.srflow_orig import SRFlowNet_arch + netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'], + K=opt_net['K'], opt=opt) else: raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) return netG diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 333601c5..3cb9366a 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -157,6 +157,8 @@ class AddNoiseInjector(Injector): scale = state[self.opt['scale']] else: scale = self.opt['scale'] + if scale is None: + scale = 1 ref = state[self.opt['in']] if self.mode == 'normal': diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index 483d9592..448d5c09 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -139,7 +139,8 @@ class ConfigurableStep(Module): continue # Don't do injections tagged with 'after' or 'before' when we are out of spec. if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \ - 'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']: + 'before' in inj.opt.keys() and self.env['step'] > inj.opt['before'] or \ + 'every' in inj.opt.keys() and self.env['step'] % inj.opt['every'] != 0: continue injected = inj(local_state) local_state.update(injected) diff --git a/codes/train.py b/codes/train.py index ecef64c3..bb0be9cb 100644 --- a/codes/train.py +++ b/codes/train.py @@ -291,7 +291,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srflow.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/train2.py b/codes/train2.py index 4eff3b79..ecef64c3 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -291,7 +291,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srg2_grad_penalty.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()