Adjustments to srflow to (maybe?) fix training
This commit is contained in:
parent
6c8c35ac47
commit
d51d12a41a
|
@ -195,7 +195,6 @@ class FlowUpsamplerNet(nn.Module):
|
|||
assert gt is not None
|
||||
assert rrdbResults is not None
|
||||
z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot)
|
||||
|
||||
return z, logdet
|
||||
|
||||
def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None):
|
||||
|
|
|
@ -3,6 +3,7 @@ import math
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
import numpy as np
|
||||
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
|
||||
from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet
|
||||
|
@ -53,7 +54,7 @@ class SRFlowNet(nn.Module):
|
|||
|
||||
def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
|
||||
lr_enc=None,
|
||||
add_gt_noise=False, step=None, y_label=None):
|
||||
add_gt_noise=True, step=None, y_label=None):
|
||||
if not reverse:
|
||||
return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
|
||||
y_onehot=y_label)
|
||||
|
@ -91,7 +92,7 @@ class SRFlowNet(nn.Module):
|
|||
logdet = logdet + float(-np.log(self.quant) * pixels)
|
||||
|
||||
# Encode
|
||||
epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses,
|
||||
epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=[],
|
||||
y_onehot=y_onehot)
|
||||
|
||||
objective = logdet.clone()
|
||||
|
|
|
@ -49,7 +49,8 @@ class Split2d(nn.Module):
|
|||
|
||||
if eps is None:
|
||||
#print("WARNING: eps is None, generating eps untested functionality!")
|
||||
eps = GaussianDiag.sample_eps(mean.shape, eps_std)
|
||||
eps = GaussianDiag.sample(mean, logs, eps_std)
|
||||
#eps = GaussianDiag.sample_eps(mean.shape, eps_std)
|
||||
|
||||
eps = eps.to(mean.device)
|
||||
z2 = mean + self.exp_eps(logs) * eps
|
||||
|
|
Loading…
Reference in New Issue
Block a user