diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py
index ba2e2abd..7be2961f 100644
--- a/codes/models/archs/arch_util.py
+++ b/codes/models/archs/arch_util.py
@@ -367,6 +367,7 @@ class ConvGnLelu(nn.Module):
         else:
             return x
+
 '''
 Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard kernel sizes.
 '''
 class ConvGnSilu(nn.Module):
diff --git a/codes/models/archs/teco_resgen.py b/codes/models/archs/teco_resgen.py
index b89220a9..06d93fd4 100644
--- a/codes/models/archs/teco_resgen.py
+++ b/codes/models/archs/teco_resgen.py
@@ -1,5 +1,9 @@
+import os
+
 import torch
 import torch.nn as nn
+import torchvision
+
 from utils.util import sequential_checkpoint
 from models.archs.arch_util import ConvGnSilu, make_layer
 
@@ -29,7 +33,6 @@ class TecoUpconv(nn.Module):
         self.final_conv = ConvGnSilu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False)
 
     def forward(self, x):
-        identity = x
         x = self.conv1(x)
         x = self.conv2(x)
         x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest")
@@ -59,5 +62,12 @@ class TecoGen(nn.Module):
         if ref is None:
             ref = torch.zeros_like(x)
         join = torch.cat([x, ref], dim=1)
-        return x + sequential_checkpoint(self.core, 6, join)
+        join = sequential_checkpoint(self.core, 6, join)
+        self.join = join.detach().clone() + .5
+        return x + join
 
+    def visual_dbg(self, step, path):
+        torchvision.utils.save_image(self.join.cpu().float(), os.path.join(path, "%i_join.png" % (step,)))
+
+    def get_debug_values(self, step, net_name):
+        return {'branch_std': self.join.std()}
diff --git a/codes/models/base_model.py b/codes/models/base_model.py
index be942956..c60afdc1 100644
--- a/codes/models/base_model.py
+++ b/codes/models/base_model.py
@@ -5,7 +5,6 @@ import torch.nn as nn
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 import utils.util
-from apex import amp
 
 
 class BaseModel():