Extra logging for teco_resgen

James Betker 2020-10-28 15:21:22 -06:00
parent 2ab5054d4c
commit f133243ac8
3 changed files with 13 additions and 3 deletions

View File

@@ -367,6 +367,7 @@ class ConvGnLelu(nn.Module):
         else:
             return x
 
 ''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvGnSilu(nn.Module):
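For context on the block referenced in this hunk, below is a minimal sketch of what a Conv->norm->SiLU convenience module of this kind generally looks like. It is an illustrative stand-in, not the repository's actual ConvGnSilu implementation: the constructor arguments mirror the call site that appears later in this commit (ConvGnSilu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False)), while the use of GroupNorm, the kernel_size // 2 auto-padding, and the Kaiming initialization are assumptions.

```python
import torch
import torch.nn as nn

class ConvNormSilu(nn.Module):
    """Illustrative stand-in for a Conv->norm->SiLU convenience block.
    Auto-pads odd (standard) kernel sizes so spatial dimensions are preserved."""
    def __init__(self, in_ch, out_ch, kernel_size=3, norm=True, activation=True,
                 bias=True, num_groups=8):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size,
                              padding=kernel_size // 2, bias=bias)
        # GroupNorm is an assumption based on the "Gn" in the class name;
        # out_ch must be divisible by num_groups when norm is enabled.
        self.norm = nn.GroupNorm(num_groups, out_ch) if norm else None
        self.act = nn.SiLU() if activation else None
        # Weight initialization, as mentioned in the docstring above.
        nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu')

    def forward(self, x):
        x = self.conv(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.act is not None:
            x = self.act(x)
        return x
```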

View File

@@ -1,5 +1,9 @@
+import os
 import torch
 import torch.nn as nn
+import torchvision
 from utils.util import sequential_checkpoint
 from models.archs.arch_util import ConvGnSilu, make_layer
@@ -29,7 +33,6 @@ class TecoUpconv(nn.Module):
         self.final_conv = ConvGnSilu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False)
 
     def forward(self, x):
-        identity = x
         x = self.conv1(x)
         x = self.conv2(x)
         x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest")
@@ -59,5 +62,12 @@ class TecoGen(nn.Module):
         if ref is None:
             ref = torch.zeros_like(x)
         join = torch.cat([x, ref], dim=1)
-        return x + sequential_checkpoint(self.core, 6, join)
+        join = sequential_checkpoint(self.core, 6, join)
+        self.join = join.detach().clone() + .5
+        return x + join
+
+    def visual_dbg(self, step, path):
+        torchvision.utils.save_image(self.join.cpu().float(), os.path.join(path, "%i_join.png" % (step,)))
+
+    def get_debug_values(self, step, net_name):
+        return {'branch_std': self.join.std()}
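The extra logging works by caching the residual branch output on the module (self.join; the + .5 shifts the roughly zero-centered residual into a visible range for image dumps) and exposing it through two hooks: visual_dbg, which writes a "<step>_join.png" image, and get_debug_values, which reports the branch's standard deviation. The sketch below shows how a training loop might consume these hooks; the helper function, the net_name, the output directory, and the logging cadence are hypothetical, only the two hook signatures come from this commit.

```python
import os

def log_debug_info(model, step, dbg_dir="debug", net_name="generator", every=50):
    """Hypothetical trainer-side helper that consumes the hooks added in this commit."""
    # Scalar diagnostics, e.g. {'branch_std': <0-dim tensor>} from get_debug_values().
    if hasattr(model, "get_debug_values"):
        for name, value in model.get_debug_values(step, net_name).items():
            print("step %d %s/%s: %f" % (step, net_name, name, float(value)))
    # Periodic image dumps via visual_dbg(), which saves the cached branch output.
    if step % every == 0 and hasattr(model, "visual_dbg"):
        os.makedirs(dbg_dir, exist_ok=True)
        model.visual_dbg(step, dbg_dir)
```

In the repository this wiring would presumably live in whatever loop already drives the model; the helper above only illustrates the contract the two new methods expose.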

View File

@@ -5,7 +5,6 @@ import torch.nn as nn
 from torch.nn.parallel.distributed import DistributedDataParallel
 import utils.util
-from apex import amp
 
 class BaseModel():