Make tecogen functional

This commit is contained in:
James Betker 2020-10-27 21:08:59 -06:00
parent 10da206db6
commit ac3da0c5a6
3 changed files with 9 additions and 6 deletions

View File

@@ -6,6 +6,7 @@ from models.archs.arch_util import ConvGnSilu, make_layer
class TecoResblock(nn.Module): class TecoResblock(nn.Module):
def __init__(self, nf): def __init__(self, nf):
super(TecoResblock, self).__init__()
self.nf = nf self.nf = nf
self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False, weight_init_factor=.1) self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False, weight_init_factor=.1)
self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False, weight_init_factor=.1) self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False, weight_init_factor=.1)
@@ -19,6 +20,7 @@ class TecoResblock(nn.Module):
class TecoUpconv(nn.Module): class TecoUpconv(nn.Module):
def __init__(self, nf, scale): def __init__(self, nf, scale):
super(TecoUpconv, self).__init__()
self.nf = nf self.nf = nf
self.scale = scale self.scale = scale
self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
@@ -32,7 +34,7 @@ class TecoUpconv(nn.Module):
x = self.conv2(x) x = self.conv2(x)
x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest") x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest")
x = self.conv3(x) x = self.conv3(x)
return identity + self.final_conv(x) return self.final_conv(x)
# Extremely simple resnet based generator that is very similar to the one used in the tecogan paper. # Extremely simple resnet based generator that is very similar to the one used in the tecogan paper.
@@ -43,12 +45,13 @@ class TecoUpconv(nn.Module):
# - Upsample block is slightly more complicated. # - Upsample block is slightly more complicated.
class TecoGen(nn.Module): class TecoGen(nn.Module):
def __init__(self, nf, scale): def __init__(self, nf, scale):
super(TecoGen, self).__init__()
self.nf = nf self.nf = nf
self.scale = scale self.scale = scale
fea_conv = ConvGnSilu(6, nf, kernel_size=7, stride=self.scale, bias=True, norm=False, activation=True) fea_conv = ConvGnSilu(6, nf, kernel_size=7, stride=self.scale, bias=True, norm=False, activation=True)
res_layers = [TecoResblock(nf) for i in range(15)] res_layers = [TecoResblock(nf) for i in range(15)]
upsample = TecoUpconv(nf) upsample = TecoUpconv(nf, scale)
everything = [fea_conv] + res_layers + upsample everything = [fea_conv] + res_layers + [upsample]
self.core = nn.Sequential(*everything) self.core = nn.Sequential(*everything)
def forward(self, x, ref=None): def forward(self, x, ref=None):
@@ -56,5 +59,5 @@ class TecoGen(nn.Module):
if ref is None: if ref is None:
ref = torch.zeros_like(x) ref = torch.zeros_like(x)
join = torch.cat([x, ref], dim=1) join = torch.cat([x, ref], dim=1)
return sequential_checkpoint(self.core, 6, join) return x + sequential_checkpoint(self.core, 6, join)

View File

@@ -202,7 +202,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
# generators and discriminators by essentially having them skip steps while their counterparts "catch up". # generators and discriminators by essentially having them skip steps while their counterparts "catch up".
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0 self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
if self.min_loss != 0: if self.min_loss != 0:
assert self.env['rank'] == 0 # distributed training does not support 'min_loss' - it can result in backward() desync by design. assert not self.env['dist'] # distributed training does not support 'min_loss' - it can result in backward() desync by design.
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False) self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
self.rb_ptr = 0 self.rb_ptr = 0
self.losses_computed = 0 self.losses_computed = 0

View File

@@ -278,7 +278,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_tecogen.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args() args = parser.parse_args()
opt = option.parse(args.opt, is_train=True) opt = option.parse(args.opt, is_train=True)