Adjustments to srflow to (maybe?) fix training

This commit is contained in:
James Betker 2020-11-20 14:44:24 -07:00
parent 6c8c35ac47
commit d51d12a41a
3 changed files with 5 additions and 4 deletions

View File

@ -195,7 +195,6 @@ class FlowUpsamplerNet(nn.Module):
assert gt is not None assert gt is not None
assert rrdbResults is not None assert rrdbResults is not None
z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot) z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot)
return z, logdet return z, logdet
def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None): def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None):

View File

@ -3,6 +3,7 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision
import numpy as np import numpy as np
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet
@ -53,7 +54,7 @@ class SRFlowNet(nn.Module):
def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False, def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
lr_enc=None, lr_enc=None,
add_gt_noise=False, step=None, y_label=None): add_gt_noise=True, step=None, y_label=None):
if not reverse: if not reverse:
return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step, return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
y_onehot=y_label) y_onehot=y_label)
@ -91,7 +92,7 @@ class SRFlowNet(nn.Module):
logdet = logdet + float(-np.log(self.quant) * pixels) logdet = logdet + float(-np.log(self.quant) * pixels)
# Encode # Encode
epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses, epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=[],
y_onehot=y_onehot) y_onehot=y_onehot)
objective = logdet.clone() objective = logdet.clone()

View File

@ -49,7 +49,8 @@ class Split2d(nn.Module):
if eps is None: if eps is None:
#print("WARNING: eps is None, generating eps untested functionality!") #print("WARNING: eps is None, generating eps untested functionality!")
eps = GaussianDiag.sample_eps(mean.shape, eps_std) eps = GaussianDiag.sample(mean, logs, eps_std)
#eps = GaussianDiag.sample_eps(mean.shape, eps_std)
eps = eps.to(mean.device) eps = eps.to(mean.device)
z2 = mean + self.exp_eps(logs) * eps z2 = mean + self.exp_eps(logs) * eps