Adjustments to srflow to (maybe?) fix training
This commit is contained in: parent 6c8c35ac47, commit d51d12a41a
@@ -195,7 +195,6 @@ class FlowUpsamplerNet(nn.Module):
         assert gt is not None
         assert rrdbResults is not None
         z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot)
-
         return z, logdet

     def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None):
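No logic changes in this hunk: the normal (non-reverse) path asserts its inputs and delegates to encode(), returning (z, logdet). Below is a toy, self-contained illustration of that encode contract; it is not the FlowUpsamplerNet implementation, only the same interface shape, and every name in it is illustrative.

# Toy conditional flow with the same encode() signature as above: maps gt -> z
# and accumulates a log-determinant term. Purely illustrative, not SRFlow code.
import torch
import torch.nn as nn

class ToyConditionalFlow(nn.Module):
    def __init__(self, channels=3):
        super().__init__()
        self.log_scale = nn.Parameter(torch.zeros(channels))

    def encode(self, gt, rrdbResults=None, logdet=0.0, epses=None, y_onehot=None):
        # z = gt * exp(s) per channel, so each pixel contributes s to logdet.
        z = gt * torch.exp(self.log_scale.view(1, -1, 1, 1))
        pixels = gt.shape[2] * gt.shape[3]
        logdet = logdet + pixels * self.log_scale.sum()
        return z, logdet

flow = ToyConditionalFlow()
gt = torch.randn(2, 3, 32, 32)
z, logdet = flow.encode(gt, logdet=torch.zeros(2))
print(z.shape, logdet.shape)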
@@ -3,6 +3,7 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torchvision
 import numpy as np
 from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
 from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet
@@ -53,7 +54,7 @@ class SRFlowNet(nn.Module):

     def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
                 lr_enc=None,
-                add_gt_noise=False, step=None, y_label=None):
+                add_gt_noise=True, step=None, y_label=None):
         if not reverse:
             return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
                                     y_onehot=y_label)
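The only functional change in this hunk is the add_gt_noise default flipping from False to True. In flow-based SR training this flag typically gates uniform dequantization noise on the ground truth, paired with the -log(quant) * pixels correction to logdet that appears as context in the next hunk. Below is a minimal sketch of that idea, with quant=256 assumed for 8-bit images; the function and names are illustrative, not copied from SRFlowNet.

# Sketch of uniform dequantization as used in flow models (assumed behaviour).
import numpy as np
import torch

def dequantize(gt, logdet, quant=256, add_gt_noise=True):
    pixels = gt.shape[2] * gt.shape[3]
    if add_gt_noise:
        # Add uniform noise of width 1/quant so discrete pixel values fill a continuous bin.
        gt = gt + (torch.rand_like(gt) - 0.5) / quant
        # Each pixel contributes -log(quant) to the log-likelihood objective.
        logdet = logdet + float(-np.log(quant) * pixels)
    return gt, logdet

gt = torch.rand(1, 3, 64, 64)
gt_noisy, logdet = dequantize(gt, logdet=torch.zeros(1))
print(logdet)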
@@ -91,7 +92,7 @@ class SRFlowNet(nn.Module):
         logdet = logdet + float(-np.log(self.quant) * pixels)

         # Encode
-        epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses,
+        epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=[],
                                               y_onehot=y_onehot)

         objective = logdet.clone()
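Here epses switches from the caller-supplied value to a fresh empty list. In SRFlow-style encoders, passing a list usually makes the split layers append their factored-out latents into it rather than returning a single concatenated z. The toy sketch below shows that accumulator pattern; the split behaviour and names are assumptions, not this encoder's actual code.

# Toy illustration of the "epses as an accumulator list" pattern.
import torch

def toy_split_layer(z, epses):
    z1, z2 = torch.chunk(z, 2, dim=1)
    if isinstance(epses, list):
        epses.append(z2)   # stash the factored-out latent
        return z1, epses
    return z1, z2          # otherwise behave like a plain split

z = torch.randn(1, 8, 16, 16)
epses = []
z, epses = toy_split_layer(z, epses)
z, epses = toy_split_layer(z, epses)
print(z.shape, [e.shape for e in epses])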
@@ -49,7 +49,8 @@ class Split2d(nn.Module):

         if eps is None:
             #print("WARNING: eps is None, generating eps untested functionality!")
-            eps = GaussianDiag.sample_eps(mean.shape, eps_std)
+            eps = GaussianDiag.sample(mean, logs, eps_std)
+            #eps = GaussianDiag.sample_eps(mean.shape, eps_std)

         eps = eps.to(mean.device)
         z2 = mean + self.exp_eps(logs) * eps
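The sampling call changes from GaussianDiag.sample_eps(mean.shape, eps_std), a raw standard-normal draw the size of mean, to GaussianDiag.sample(mean, logs, eps_std), a draw already shifted and scaled by the Gaussian's parameters; the unchanged context line z2 = mean + self.exp_eps(logs) * eps then applies the affine mapping to whichever eps was produced. Below is a hedged sketch of the two variants, written from their names rather than this repository's GaussianDiag code, which may differ.

# Assumed behaviour of the two sampling styles the diff toggles between.
import torch

def sample_eps(shape, eps_std=1.0):
    # Standard-normal noise, optionally tempered by eps_std.
    return torch.randn(shape) * eps_std

def sample(mean, logs, eps_std=1.0):
    # Noise already mapped through the Gaussian's mean and log-std.
    return mean + torch.exp(logs) * torch.randn_like(mean) * eps_std

mean = torch.zeros(1, 4, 8, 8)
logs = torch.zeros(1, 4, 8, 8)
eps_raw = sample_eps(mean.shape, eps_std=0.8)     # what the old line produced
eps_full = sample(mean, logs, eps_std=0.8)        # what the new line produces
print(eps_raw.std().item(), eps_full.std().item())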