Make tecogen functional
This commit is contained in:
parent
10da206db6
commit
ac3da0c5a6
|
@ -6,6 +6,7 @@ from models.archs.arch_util import ConvGnSilu, make_layer
|
|||
|
||||
class TecoResblock(nn.Module):
|
||||
def __init__(self, nf):
|
||||
super(TecoResblock, self).__init__()
|
||||
self.nf = nf
|
||||
self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False, weight_init_factor=.1)
|
||||
self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False, weight_init_factor=.1)
|
||||
|
@ -19,6 +20,7 @@ class TecoResblock(nn.Module):
|
|||
|
||||
class TecoUpconv(nn.Module):
|
||||
def __init__(self, nf, scale):
|
||||
super(TecoUpconv, self).__init__()
|
||||
self.nf = nf
|
||||
self.scale = scale
|
||||
self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
|
||||
|
@ -32,7 +34,7 @@ class TecoUpconv(nn.Module):
|
|||
x = self.conv2(x)
|
||||
x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest")
|
||||
x = self.conv3(x)
|
||||
return identity + self.final_conv(x)
|
||||
return self.final_conv(x)
|
||||
|
||||
|
||||
# Extremely simple resnet based generator that is very similar to the one used in the tecogan paper.
|
||||
|
@ -43,12 +45,13 @@ class TecoUpconv(nn.Module):
|
|||
# - Upsample block is slightly more complicated.
|
||||
class TecoGen(nn.Module):
|
||||
def __init__(self, nf, scale):
|
||||
super(TecoGen, self).__init__()
|
||||
self.nf = nf
|
||||
self.scale = scale
|
||||
fea_conv = ConvGnSilu(6, nf, kernel_size=7, stride=self.scale, bias=True, norm=False, activation=True)
|
||||
res_layers = [TecoResblock(nf) for i in range(15)]
|
||||
upsample = TecoUpconv(nf)
|
||||
everything = [fea_conv] + res_layers + upsample
|
||||
upsample = TecoUpconv(nf, scale)
|
||||
everything = [fea_conv] + res_layers + [upsample]
|
||||
self.core = nn.Sequential(*everything)
|
||||
|
||||
def forward(self, x, ref=None):
|
||||
|
@ -56,5 +59,5 @@ class TecoGen(nn.Module):
|
|||
if ref is None:
|
||||
ref = torch.zeros_like(x)
|
||||
join = torch.cat([x, ref], dim=1)
|
||||
return sequential_checkpoint(self.core, 6, join)
|
||||
return x + sequential_checkpoint(self.core, 6, join)
|
||||
|
||||
|
|
|
@ -202,7 +202,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
# generators and discriminators by essentially having them skip steps while their counterparts "catch up".
|
||||
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
|
||||
if self.min_loss != 0:
|
||||
assert self.env['rank'] == 0 # distributed training does not support 'min_loss' - it can result in backward() desync by design.
|
||||
assert not self.env['dist'] # distributed training does not support 'min_loss' - it can result in backward() desync by design.
|
||||
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
|
||||
self.rb_ptr = 0
|
||||
self.losses_computed = 0
|
||||
|
|
|
@ -278,7 +278,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_tecogen.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
Loading…
Reference in New Issue
Block a user