diff --git a/codes/models/archs/srflow_orig/FlowUpsamplerNet.py b/codes/models/archs/srflow_orig/FlowUpsamplerNet.py index 0b0d4e23..52fd918f 100644 --- a/codes/models/archs/srflow_orig/FlowUpsamplerNet.py +++ b/codes/models/archs/srflow_orig/FlowUpsamplerNet.py @@ -195,7 +195,6 @@ class FlowUpsamplerNet(nn.Module): assert gt is not None assert rrdbResults is not None z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot) - return z, logdet def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None): diff --git a/codes/models/archs/srflow_orig/SRFlowNet_arch.py b/codes/models/archs/srflow_orig/SRFlowNet_arch.py index b0e909f0..5cb9d4b4 100644 --- a/codes/models/archs/srflow_orig/SRFlowNet_arch.py +++ b/codes/models/archs/srflow_orig/SRFlowNet_arch.py @@ -3,6 +3,7 @@ import math import torch import torch.nn as nn import torch.nn.functional as F +import torchvision import numpy as np from models.archs.srflow_orig.RRDBNet_arch import RRDBNet from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet @@ -53,7 +54,7 @@ class SRFlowNet(nn.Module): def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False, lr_enc=None, - add_gt_noise=False, step=None, y_label=None): + add_gt_noise=True, step=None, y_label=None): if not reverse: return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step, y_onehot=y_label) @@ -91,7 +92,7 @@ class SRFlowNet(nn.Module): logdet = logdet + float(-np.log(self.quant) * pixels) # Encode - epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses, + epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=[], y_onehot=y_onehot) objective = logdet.clone() diff --git a/codes/models/archs/srflow_orig/Split.py b/codes/models/archs/srflow_orig/Split.py index b7b1df98..1f5de9bf 100644 --- a/codes/models/archs/srflow_orig/Split.py +++ b/codes/models/archs/srflow_orig/Split.py @@ -49,7 +49,8 @@ class Split2d(nn.Module): if eps is None: #print("WARNING: eps is None, generating eps untested functionality!") - eps = GaussianDiag.sample_eps(mean.shape, eps_std) + eps = GaussianDiag.sample(mean, logs, eps_std) + #eps = GaussianDiag.sample_eps(mean.shape, eps_std) eps = eps.to(mean.device) z2 = mean + self.exp_eps(logs) * eps