This commit is contained in:
James Betker 2020-11-11 21:49:06 -07:00
parent 88f349bdf1
commit fd97573085
3 changed files with 10 additions and 4 deletions

View File

@ -206,5 +206,6 @@ class RRDBNet(nn.Module):
def visual_dbg(self, step, path):
for i, bm in enumerate(self.body):
torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
if hasattr(bm, 'bypass_map'):
torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))

View File

@ -675,9 +675,9 @@ class PyramidDiscriminator(nn.Module):
ResidualBlockGN(nf),
ResidualBlockGN(nf),
ResidualBlockGN(nf),
ConvGnLelu(nf, nf // 2, kernel_size=1, activation=True, norm=True, bias=True),
ConvGnLelu(nf // 2, nf // 4, kernel_size=1, activation=True, norm=True, bias=True),
ConvGnLelu(nf // 4, 1, activation=False, norm=False, bias=True)])
ConvGnLelu(nf, nf // 2, kernel_size=1, activation=True, norm=False, bias=True),
ConvGnLelu(nf // 2, nf // 4, kernel_size=1, activation=True, norm=False, bias=True),
ConvGnLelu(nf // 4, 1, kernel_size=1, activation=False, norm=False, bias=True)])
def forward(self, x):
fea = self.initial_conv(x)

View File

@ -146,6 +146,11 @@ class AdaRRDBNet(nn.Module):
self.conv_up2, self.conv_hr, self.conv_last
]:
default_init_weights(m, 0.1)
self.latent_mean = 0
self.latent_std = 0
self.latent_var = 0
self.block_residual_means = []
self.block_residual_stds = []
def forward(self, x, latent=None, ref=None):
latent_was_none = latent