From ac3da0c5a6c789f97f0dc45ff101aa8069f07381 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 27 Oct 2020 21:08:59 -0600 Subject: [PATCH] Make tecogen functional --- codes/models/archs/teco_resgen.py | 11 +++++++---- codes/models/steps/losses.py | 2 +- codes/train2.py | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/codes/models/archs/teco_resgen.py b/codes/models/archs/teco_resgen.py index 85d46502..b89220a9 100644 --- a/codes/models/archs/teco_resgen.py +++ b/codes/models/archs/teco_resgen.py @@ -6,6 +6,7 @@ from models.archs.arch_util import ConvGnSilu, make_layer class TecoResblock(nn.Module): def __init__(self, nf): + super(TecoResblock, self).__init__() self.nf = nf self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False, weight_init_factor=.1) self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False, weight_init_factor=.1) @@ -19,6 +20,7 @@ class TecoResblock(nn.Module): class TecoUpconv(nn.Module): def __init__(self, nf, scale): + super(TecoUpconv, self).__init__() self.nf = nf self.scale = scale self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) @@ -32,7 +34,7 @@ class TecoUpconv(nn.Module): x = self.conv2(x) x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest") x = self.conv3(x) - return identity + self.final_conv(x) + return self.final_conv(x) # Extremely simple resnet based generator that is very similar to the one used in the tecogan paper. @@ -43,12 +45,13 @@ class TecoUpconv(nn.Module): # - Upsample block is slightly more complicated. class TecoGen(nn.Module): def __init__(self, nf, scale): + super(TecoGen, self).__init__() self.nf = nf self.scale = scale fea_conv = ConvGnSilu(6, nf, kernel_size=7, stride=self.scale, bias=True, norm=False, activation=True) res_layers = [TecoResblock(nf) for i in range(15)] - upsample = TecoUpconv(nf) - everything = [fea_conv] + res_layers + upsample + upsample = TecoUpconv(nf, scale) + everything = [fea_conv] + res_layers + [upsample] self.core = nn.Sequential(*everything) def forward(self, x, ref=None): @@ -56,5 +59,5 @@ class TecoGen(nn.Module): if ref is None: ref = torch.zeros_like(x) join = torch.cat([x, ref], dim=1) - return sequential_checkpoint(self.core, 6, join) + return x + sequential_checkpoint(self.core, 6, join) diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 830ff6b2..5ea6c113 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -202,7 +202,7 @@ class DiscriminatorGanLoss(ConfigurableLoss): # generators and discriminators by essentially having them skip steps while their counterparts "catch up". self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0 if self.min_loss != 0: - assert self.env['rank'] == 0 # distributed training does not support 'min_loss' - it can result in backward() desync by design. + assert not self.env['dist'] # distributed training does not support 'min_loss' - it can result in backward() desync by design. self.loss_rotating_buffer = torch.zeros(10, requires_grad=False) self.rb_ptr = 0 self.losses_computed = 0 diff --git a/codes/train2.py b/codes/train2.py index f9c94a82..456b570c 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -278,7 +278,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_tecogen.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True)