Extra logging for teco_resgen
This commit is contained in:
parent
2ab5054d4c
commit
f133243ac8
|
@ -367,6 +367,7 @@ class ConvGnLelu(nn.Module):
|
|||
else:
|
||||
return x
|
||||
|
||||
|
||||
''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
|
||||
kernel sizes. '''
|
||||
class ConvGnSilu(nn.Module):
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
|
||||
from utils.util import sequential_checkpoint
|
||||
from models.archs.arch_util import ConvGnSilu, make_layer
|
||||
|
||||
|
@ -29,7 +33,6 @@ class TecoUpconv(nn.Module):
|
|||
self.final_conv = ConvGnSilu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest")
|
||||
|
@ -59,5 +62,12 @@ class TecoGen(nn.Module):
|
|||
if ref is None:
|
||||
ref = torch.zeros_like(x)
|
||||
join = torch.cat([x, ref], dim=1)
|
||||
return x + sequential_checkpoint(self.core, 6, join)
|
||||
join = sequential_checkpoint(self.core, 6, join)
|
||||
self.join = join.detach().clone() + .5
|
||||
return x + join
|
||||
|
||||
def visual_dbg(self, step, path):
|
||||
torchvision.utils.save_image(self.join.cpu().float(), os.path.join(path, "%i_join.png" % (step,)))
|
||||
|
||||
def get_debug_values(self, step, net_name):
|
||||
return {'branch_std': self.join.std()}
|
||||
|
|
|
@ -5,7 +5,6 @@ import torch.nn as nn
|
|||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
|
||||
import utils.util
|
||||
from apex import amp
|
||||
|
||||
|
||||
class BaseModel():
|
||||
|
|
Loading…
Reference in New Issue
Block a user