import torch
from torch import nn as nn

import models.archs.srflow_orig.Permutations
from models.archs.srflow_orig import flow, thops, FlowAffineCouplingsAblation
from utils.util import opt_get


def getConditional(rrdbResults, position):
    """Return the conditioning features to feed into a flow step.

    Args:
        rrdbResults: Either a ``torch.Tensor``, which is returned unchanged,
            or any object supporting ``rrdbResults[position]``.
        position: Key/index used to look up the features when
            ``rrdbResults`` is not a tensor.

    Returns:
        ``rrdbResults`` itself if it is a tensor, otherwise
        ``rrdbResults[position]``.
    """
    if isinstance(rrdbResults, torch.Tensor):
        return rrdbResults
    return rrdbResults[position]


class FlowStep(nn.Module):
    """One step of a flow: normalisation, permutation, then coupling.

    ``normal_flow`` applies the three stages in that order and
    ``reverse_flow`` applies them in the opposite order with
    ``reverse=True``.
    """

    # Dispatch table: permutation name -> callable(obj, z, logdet, rev)
    # returning ``(z, logdet)``.
    # NOTE(review): ``__init__`` only creates ``self.invconv`` when
    # ``flow_permutation == "invconv"``; the attributes the other entries
    # rely on (``reverse``, ``shuffle``, ``invconv`` for the remaining keys)
    # are not created in this file -- confirm they are provided elsewhere
    # before selecting those keys.
    FlowPermutation = {
        "reverse": lambda obj, z, logdet, rev: (obj.reverse(z, rev), logdet),
        "shuffle": lambda obj, z, logdet, rev: (obj.shuffle(z, rev), logdet),
        "invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "squeeze_invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "resqueeze_invconv_alternating_2_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "resqueeze_invconv_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "InvertibleConv1x1GridAlign": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "InvertibleConv1x1SubblocksShuf": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "InvertibleConv1x1GridAlignIndepBorder": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "InvertibleConv1x1GridAlignIndepBorder4": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
    }

    # Coupling names that always request conditioning features, regardless
    # of what the coupling layer's ``need_features`` attribute says.
    _FEATURE_COUPLINGS = ("condAffine", "condFtAffine", "condNormAffine")

    def __init__(self, in_channels, hidden_channels,
                 actnorm_scale=1.0, flow_permutation="invconv", flow_coupling="additive",
                 LU_decomposed=False, opt=None, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None,
                 position=None):
        """Build the normalisation, permutation and coupling sub-modules.

        Args:
            in_channels: Channel count passed to the actnorm, invertible
                convolution and coupling layers.
            hidden_channels: Accepted for interface compatibility; not used
                in this constructor.
            actnorm_scale: Scale argument forwarded to ``ActNorm2d``.
            flow_permutation: Key into ``FlowStep.FlowPermutation``.
            flow_coupling: ``"CondAffineSeparatedAndCond"`` or
                ``"noCoupling"``; any other value raises ``RuntimeError``
                (including the default ``"additive"``).
            LU_decomposed: Forwarded to ``InvertibleConv1x1``.
            opt: Forwarded to the coupling layer.
            image_injector: Stored on the instance; not used here.
            idx: Accepted for interface compatibility; not used here.
            acOpt: Stored on the instance; not used here.
            normOpt: Optional mapping; ``normOpt['type']`` selects the
                normalisation mode (``'ActNorm2d'`` when falsy).
            in_shape: Stored on the instance; not used here.
            position: Key used by ``getConditional`` to select features.

        Raises:
            AssertionError: If ``flow_permutation`` is not a known key.
            RuntimeError: If ``flow_coupling`` is not supported.
        """
        # check configures
        assert flow_permutation in FlowStep.FlowPermutation, \
            "flow_permutation should be in `{}`".format(
                FlowStep.FlowPermutation.keys())
        super().__init__()
        self.flow_permutation = flow_permutation
        self.flow_coupling = flow_coupling
        self.image_injector = image_injector

        self.norm_type = normOpt['type'] if normOpt else 'ActNorm2d'
        self.in_shape = in_shape
        # The original code first stored normOpt['position'] here and then
        # immediately overwrote it; only the explicit argument ever took
        # effect, so the dead store has been dropped.
        self.position = position
        self.acOpt = acOpt

        # 1. actnorm
        # NOTE(review): FlowActNorms is not imported explicitly in this file;
        # presumably it is loaded as a side effect of another import -- confirm.
        self.actnorm = models.archs.srflow_orig.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)

        # 2. permute
        if flow_permutation == "invconv":
            self.invconv = models.archs.srflow_orig.Permutations.InvertibleConv1x1(
                in_channels, LU_decomposed=LU_decomposed)

        # 3. coupling
        if flow_coupling == "CondAffineSeparatedAndCond":
            self.affine = models.archs.srflow_orig.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(
                in_channels=in_channels, opt=opt)
        elif flow_coupling == "noCoupling":
            pass
        else:
            raise RuntimeError("coupling not Found:", flow_coupling)

    def forward(self, input, logdet=None, rrdbResults=None, reverse=False):
        """Run the step forwards, or backwards when ``reverse`` is true.

        Returns:
            The ``(z, logdet)`` pair produced by ``normal_flow`` or
            ``reverse_flow``.
        """
        if reverse:
            return self.reverse_flow(input, logdet, rrdbResults)
        return self.normal_flow(input, logdet, rrdbResults)

    def normal_flow(self, z, logdet, rrdbResults=None):
        """Apply actnorm, permutation and coupling, in that order."""
        # NOTE(review): this constructor rejects "bentIdentityPreAct" and never
        # defines ``bentIdentPar``; the branch is kept only for compatibility
        # with code that may set these attributes externally.
        if self.flow_coupling == "bentIdentityPreAct":
            z, logdet = self.bentIdentPar(z, logdet, reverse=False)

        # 1. actnorm
        z, logdet = self._apply_norm(z, logdet, rrdbResults, reverse=False)

        # 2. permute
        z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
            self, z, logdet, False)

        # 3. coupling
        if self._coupling_uses_features():
            img_ft = getConditional(rrdbResults, self.position)
            z, logdet = self.affine(input=z, logdet=logdet, reverse=False, ft=img_ft)

        return z, logdet

    def reverse_flow(self, z, logdet, rrdbResults=None):
        """Undo ``normal_flow``: coupling, permutation, then actnorm."""
        # 1.coupling
        if self._coupling_uses_features():
            img_ft = getConditional(rrdbResults, self.position)
            z, logdet = self.affine(input=z, logdet=logdet, reverse=True, ft=img_ft)

        # 2. permute
        z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
            self, z, logdet, True)

        # 3. actnorm
        # Uses the same norm_type dispatch as the forward pass; previously the
        # reverse pass always applied plain actnorm, so a "noNorm" step was
        # not inverted correctly.
        z, logdet = self._apply_norm(z, logdet, rrdbResults, reverse=True)

        return z, logdet

    def affine_need_features(self):
        """Return ``self.affine.need_features``, or ``False`` if unavailable.

        ``False`` is returned both when no coupling layer was created and
        when the coupling layer has no ``need_features`` attribute.
        """
        coupling = getattr(self, "affine", None)
        return getattr(coupling, "need_features", False)

    def _coupling_uses_features(self):
        """Tell whether the coupling stage should run with image features."""
        return self.affine_need_features() or self.flow_coupling in self._FEATURE_COUPLINGS

    def _apply_norm(self, z, logdet, rrdbResults, reverse):
        """Apply the normalisation stage selected by ``self.norm_type``."""
        if self.norm_type == "ConditionalActNormImageInjector":
            img_ft = getConditional(rrdbResults, self.position)
            return self.actnorm(z, img_ft=img_ft, logdet=logdet, reverse=reverse)
        if self.norm_type == "noNorm":
            return z, logdet
        return self.actnorm(z, logdet=logdet, reverse=reverse)