Extra logging for teco_resgen
This commit is contained in:
parent
2ab5054d4c
commit
f133243ac8
|
@ -367,6 +367,7 @@ class ConvGnLelu(nn.Module):
|
||||||
else:
|
else:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
|
''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
|
||||||
kernel sizes. '''
|
kernel sizes. '''
|
||||||
class ConvGnSilu(nn.Module):
|
class ConvGnSilu(nn.Module):
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torchvision
|
||||||
|
|
||||||
from utils.util import sequential_checkpoint
|
from utils.util import sequential_checkpoint
|
||||||
from models.archs.arch_util import ConvGnSilu, make_layer
|
from models.archs.arch_util import ConvGnSilu, make_layer
|
||||||
|
|
||||||
|
@ -29,7 +33,6 @@ class TecoUpconv(nn.Module):
|
||||||
self.final_conv = ConvGnSilu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False)
|
self.final_conv = ConvGnSilu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
identity = x
|
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest")
|
x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest")
|
||||||
|
@ -59,5 +62,12 @@ class TecoGen(nn.Module):
|
||||||
if ref is None:
|
if ref is None:
|
||||||
ref = torch.zeros_like(x)
|
ref = torch.zeros_like(x)
|
||||||
join = torch.cat([x, ref], dim=1)
|
join = torch.cat([x, ref], dim=1)
|
||||||
return x + sequential_checkpoint(self.core, 6, join)
|
join = sequential_checkpoint(self.core, 6, join)
|
||||||
|
self.join = join.detach().clone() + .5
|
||||||
|
return x + join
|
||||||
|
|
||||||
|
def visual_dbg(self, step, path):
|
||||||
|
torchvision.utils.save_image(self.join.cpu().float(), os.path.join(path, "%i_join.png" % (step,)))
|
||||||
|
|
||||||
|
def get_debug_values(self, step, net_name):
|
||||||
|
return {'branch_std': self.join.std()}
|
||||||
|
|
|
@ -5,7 +5,6 @@ import torch.nn as nn
|
||||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||||
|
|
||||||
import utils.util
|
import utils.util
|
||||||
from apex import amp
|
|
||||||
|
|
||||||
|
|
||||||
class BaseModel():
|
class BaseModel():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user