Make tecogen functional
parent 10da206db6
commit ac3da0c5a6
@@ -6,6 +6,7 @@ from models.archs.arch_util import ConvGnSilu, make_layer
 
 class TecoResblock(nn.Module):
     def __init__(self, nf):
+        super(TecoResblock, self).__init__()
         self.nf = nf
         self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False, weight_init_factor=.1)
         self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False, weight_init_factor=.1)
@@ -19,6 +20,7 @@ class TecoResblock(nn.Module):
 
 class TecoUpconv(nn.Module):
     def __init__(self, nf, scale):
+        super(TecoUpconv, self).__init__()
         self.nf = nf
         self.scale = scale
         self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
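TecoResblock's forward pass is outside this diff; given a conv1 with activation and a conv2 without, both initialized with weight_init_factor=.1, a standard residual form consistent with this constructor would look like the sketch below. Plain Conv2d + SiLU stands in for ConvGnSilu, and the residual add itself is an assumption rather than code from the repo.

import torch.nn as nn

# Sketch only: Conv2d + SiLU stands in for ConvGnSilu, and the "x + branch(x)"
# form is assumed (the small weight_init_factor suggests a branch that starts
# near zero, i.e. a block that starts near identity).
class TecoResblockSketch(nn.Module):
    def __init__(self, nf):
        super(TecoResblockSketch, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(nf, nf, kernel_size=3, padding=1, bias=False), nn.SiLU())
        self.conv2 = nn.Conv2d(nf, nf, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        return x + self.conv2(self.conv1(x))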
@@ -32,7 +34,7 @@ class TecoUpconv(nn.Module):
         x = self.conv2(x)
         x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest")
         x = self.conv3(x)
-        return identity + self.final_conv(x)
+        return self.final_conv(x)
 
 
 # Extremely simple resnet based generator that is very similar to the one used in the tecogan paper.
@@ -43,12 +45,13 @@ class TecoUpconv(nn.Module):
 # - Upsample block is slightly more complicated.
 class TecoGen(nn.Module):
     def __init__(self, nf, scale):
+        super(TecoGen, self).__init__()
         self.nf = nf
         self.scale = scale
         fea_conv = ConvGnSilu(6, nf, kernel_size=7, stride=self.scale, bias=True, norm=False, activation=True)
         res_layers = [TecoResblock(nf) for i in range(15)]
-        upsample = TecoUpconv(nf)
-        everything = [fea_conv] + res_layers + upsample
+        upsample = TecoUpconv(nf, scale)
+        everything = [fea_conv] + res_layers + [upsample]
         self.core = nn.Sequential(*everything)
 
     def forward(self, x, ref=None):
@@ -56,5 +59,5 @@ class TecoGen(nn.Module):
         if ref is None:
             ref = torch.zeros_like(x)
         join = torch.cat([x, ref], dim=1)
-        return sequential_checkpoint(self.core, 6, join)
+        return x + sequential_checkpoint(self.core, 6, join)
 
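Taken together, the module wires a strided feature conv, fifteen residual blocks and a nearest-neighbour upsampling head into one nn.Sequential, checkpoints it in chunks of 6, and adds the result back onto the input frame. Below is a self-contained sketch of that data flow; plain Conv2d layers stand in for ConvGnSilu / TecoResblock / TecoUpconv, gradient checkpointing is omitted, and nf=64, scale=2 and 3-channel frames are illustrative assumptions.

import torch
import torch.nn as nn

# Data-flow sketch only: plain convolutions stand in for ConvGnSilu / TecoResblock /
# TecoUpconv, and sequential_checkpoint is omitted. nf=64, scale=2 and 3-channel
# frames are assumptions for illustration.
class TecoGenSketch(nn.Module):
    def __init__(self, nf=64, scale=2):
        super(TecoGenSketch, self).__init__()
        self.scale = scale
        # 6 input channels: the current frame concatenated with a reference frame.
        self.fea_conv = nn.Conv2d(6, nf, kernel_size=7, stride=scale, padding=3)
        self.res = nn.Sequential(*[nn.Sequential(nn.Conv2d(nf, nf, 3, padding=1), nn.SiLU(),
                                                 nn.Conv2d(nf, nf, 3, padding=1))
                                   for _ in range(15)])
        self.up = nn.Sequential(nn.Upsample(scale_factor=scale, mode="nearest"),
                                nn.Conv2d(nf, 3, kernel_size=3, padding=1))

    def forward(self, x, ref=None):
        if ref is None:
            ref = torch.zeros_like(x)      # no reference frame: condition on zeros
        join = torch.cat([x, ref], dim=1)  # 3 + 3 = 6 channels
        # The strided conv downsamples by `scale` and the upsample head restores it,
        # so the core's output matches x and can be added back as a residual.
        return x + self.up(self.res(self.fea_conv(join)))

x = torch.randn(1, 3, 64, 64)
print(TecoGenSketch()(x).shape)  # torch.Size([1, 3, 64, 64])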
@@ -202,7 +202,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         # generators and discriminators by essentially having them skip steps while their counterparts "catch up".
         self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
         if self.min_loss != 0:
-            assert self.env['rank'] == 0 # distributed training does not support 'min_loss' - it can result in backward() desync by design.
+            assert not self.env['dist'] # distributed training does not support 'min_loss' - it can result in backward() desync by design.
             self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
             self.rb_ptr = 0
             self.losses_computed = 0
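The buffer, pointer and counter initialized here are consumed elsewhere in DiscriminatorGanLoss (not shown in this hunk). The idea described by the comment is that a rolling window of recent losses gates whether a step is taken at all, letting one network idle while its counterpart catches up. A hedged sketch of that pattern follows; the field names mirror this hunk, but the skip decision itself is an assumption, not the repo's code.

import torch

# Hedged sketch of a min_loss gate over a rotating buffer; the field names mirror
# the hunk above, the skip decision is assumed rather than taken from the repo.
class MinLossGate:
    def __init__(self, min_loss=0.0, buffer_len=10):
        self.min_loss = min_loss
        self.loss_rotating_buffer = torch.zeros(buffer_len, requires_grad=False)
        self.rb_ptr = 0
        self.losses_computed = 0

    def should_skip(self, current_loss: float) -> bool:
        # Record the newest loss value in the rotating buffer.
        self.loss_rotating_buffer[self.rb_ptr] = current_loss
        self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.numel()
        self.losses_computed += 1
        # Do not gate until the buffer holds a full window of observations.
        if self.losses_computed < self.loss_rotating_buffer.numel():
            return False
        # Skip the step once the recent average drops below the floor.
        return self.loss_rotating_buffer.mean().item() < self.min_loss

Because the skip decision is local to each process, it cannot be used with distributed training without desynchronizing backward() across ranks, which is what the assert on env['dist'] enforces.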
@@ -278,7 +278,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_tecogen.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)