From f133243ac88f29494ea6a4b2b0ad71c660df6872 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 28 Oct 2020 15:21:22 -0600 Subject: [PATCH] Extra logging for teco_resgen --- codes/models/archs/arch_util.py | 1 + codes/models/archs/teco_resgen.py | 14 ++++++++++++-- codes/models/base_model.py | 1 - 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py index ba2e2abd..7be2961f 100644 --- a/codes/models/archs/arch_util.py +++ b/codes/models/archs/arch_util.py @@ -367,6 +367,7 @@ class ConvGnLelu(nn.Module): else: return x + ''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard kernel sizes. ''' class ConvGnSilu(nn.Module): diff --git a/codes/models/archs/teco_resgen.py b/codes/models/archs/teco_resgen.py index b89220a9..06d93fd4 100644 --- a/codes/models/archs/teco_resgen.py +++ b/codes/models/archs/teco_resgen.py @@ -1,5 +1,9 @@ +import os + import torch import torch.nn as nn +import torchvision + from utils.util import sequential_checkpoint from models.archs.arch_util import ConvGnSilu, make_layer @@ -29,7 +33,6 @@ class TecoUpconv(nn.Module): self.final_conv = ConvGnSilu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False) def forward(self, x): - identity = x x = self.conv1(x) x = self.conv2(x) x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest") @@ -59,5 +62,12 @@ class TecoGen(nn.Module): if ref is None: ref = torch.zeros_like(x) join = torch.cat([x, ref], dim=1) - return x + sequential_checkpoint(self.core, 6, join) + join = sequential_checkpoint(self.core, 6, join) + self.join = join.detach().clone() + .5 + return x + join + def visual_dbg(self, step, path): + torchvision.utils.save_image(self.join.cpu().float(), os.path.join(path, "%i_join.png" % (step,))) + + def get_debug_values(self, step, net_name): + return {'branch_std': self.join.std()} diff --git a/codes/models/base_model.py b/codes/models/base_model.py index be942956..c60afdc1 100644 --- a/codes/models/base_model.py +++ b/codes/models/base_model.py @@ -5,7 +5,6 @@ import torch.nn as nn from torch.nn.parallel.distributed import DistributedDataParallel import utils.util -from apex import amp class BaseModel():