diff --git a/codes/models/archs/srflow/FlowActNorms.py b/codes/models/archs/srflow/FlowActNorms.py deleted file mode 100644 index 9a8d4d90..00000000 --- a/codes/models/archs/srflow/FlowActNorms.py +++ /dev/null @@ -1,125 +0,0 @@ -import torch -from torch import nn as nn - -from models.archs.srflow import thops - - -class _ActNorm(nn.Module): - """ - Activation Normalization - Initialize the bias and scale with a given minibatch, - so that the output per-channel have zero mean and unit variance for that. - - After initialization, `bias` and `logs` will be trained as parameters. - """ - - def __init__(self, num_features, scale=1.): - super().__init__() - # register mean and scale - size = [1, num_features, 1, 1] - self.register_parameter("bias", nn.Parameter(torch.zeros(*size))) - self.register_parameter("logs", nn.Parameter(torch.zeros(*size))) - self.num_features = num_features - self.scale = float(scale) - self.inited = False - - def _check_input_dim(self, input): - return NotImplemented - - def initialize_parameters(self, input): - self._check_input_dim(input) - if not self.training: - return - if (self.bias != 0).any(): - self.inited = True - return - assert input.device == self.bias.device, (input.device, self.bias.device) - with torch.no_grad(): - bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0 - vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True) - logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6)) - self.bias.data.copy_(bias.data) - self.logs.data.copy_(logs.data) - self.inited = True - - def _center(self, input, reverse=False, offset=None): - bias = self.bias - - if offset is not None: - bias = bias + offset - - if not reverse: - return input + bias - else: - return input - bias - - def _scale(self, input, logdet=None, reverse=False, offset=None): - logs = self.logs - - if offset is not None: - logs = logs + offset - - if not reverse: - input = input * torch.exp(logs) # should have shape batchsize, n_channels, 1, 1 - # input = input * torch.exp(logs+logs_offset) - else: - input = input * torch.exp(-logs) - if logdet is not None: - """ - logs is log_std of `mean of channels` - so we need to multiply pixels - """ - dlogdet = thops.sum(logs) * thops.pixels(input) - if reverse: - dlogdet *= -1 - logdet = logdet + dlogdet - return input, logdet - - def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None): - if not self.inited: - self.initialize_parameters(input) - self._check_input_dim(input) - - if offset_mask is not None: - logs_offset *= offset_mask - bias_offset *= offset_mask - # no need to permute dims as old version - if not reverse: - # center and scale - - # self.input = input - input = self._center(input, reverse, bias_offset) - input, logdet = self._scale(input, logdet, reverse, logs_offset) - else: - # scale and center - input, logdet = self._scale(input, logdet, reverse, logs_offset) - input = self._center(input, reverse, bias_offset) - return input, logdet - - -class ActNorm2d(_ActNorm): - def __init__(self, num_features, scale=1.): - super().__init__(num_features, scale) - - def _check_input_dim(self, input): - assert len(input.size()) == 4 - assert input.size(1) == self.num_features, ( - "[ActNorm]: input should be in shape as `BCHW`," - " channels should be {} rather than {}".format( - self.num_features, input.size())) - - -class MaskedActNorm2d(ActNorm2d): - def __init__(self, num_features, scale=1.): - super().__init__(num_features, scale) - - def forward(self, input, mask, logdet=None, reverse=False): - - assert mask.dtype == torch.bool - output, logdet_out = super().forward(input, logdet, reverse) - - input[mask] = output[mask] - logdet[mask] = logdet_out[mask] - - return input, logdet - diff --git a/codes/models/archs/srflow/FlowAffineCouplingsAblation.py b/codes/models/archs/srflow/FlowAffineCouplingsAblation.py deleted file mode 100644 index c50d0d5d..00000000 --- a/codes/models/archs/srflow/FlowAffineCouplingsAblation.py +++ /dev/null @@ -1,116 +0,0 @@ -import torch -from torch import nn as nn - -from models.archs.srflow import thops -from models.archs.srflow.flow import Conv2d, Conv2dZeros - - -class CondAffineSeparatedAndCond(nn.Module): - def __init__(self, in_channels, hidden_channels=64, affine_eps=.00001): - super().__init__() - self.need_features = True - self.in_channels = in_channels - self.in_channels_rrdb = 320 - self.kernel_hidden = 1 - self.affine_eps = 0.0001 - self.n_hidden_layers = 1 - self.hidden_channels = hidden_channels - self.affine_eps = affine_eps - - self.channels_for_nn = self.in_channels // 2 - self.channels_for_co = self.in_channels - self.channels_for_nn - - if self.channels_for_nn is None: - self.channels_for_nn = self.in_channels // 2 - - self.fAffine = self.F(in_channels=self.channels_for_nn + self.in_channels_rrdb, - out_channels=self.channels_for_co * 2, - hidden_channels=self.hidden_channels, - kernel_hidden=self.kernel_hidden, - n_hidden_layers=self.n_hidden_layers) - - self.fFeatures = self.F(in_channels=self.in_channels_rrdb, - out_channels=self.in_channels * 2, - hidden_channels=self.hidden_channels, - kernel_hidden=self.kernel_hidden, - n_hidden_layers=self.n_hidden_layers) - - def forward(self, input: torch.Tensor, logdet=None, reverse=False, ft=None): - if not reverse: - z = input - assert z.shape[1] == self.in_channels, (z.shape[1], self.in_channels) - - # Feature Conditional - scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures) - z = z + shiftFt - z = z * scaleFt - logdet = logdet + self.get_logdet(scaleFt) - - # Self Conditional - z1, z2 = self.split(z) - scale, shift = self.feature_extract_aff(z1, ft, self.fAffine) - self.asserts(scale, shift, z1, z2) - z2 = z2 + shift - z2 = z2 * scale - - logdet = logdet + self.get_logdet(scale) - z = thops.cat_feature(z1, z2) - output = z - else: - z = input - - # Self Conditional - z1, z2 = self.split(z) - scale, shift = self.feature_extract_aff(z1, ft, self.fAffine) - self.asserts(scale, shift, z1, z2) - z2 = z2 / scale - z2 = z2 - shift - z = thops.cat_feature(z1, z2) - logdet = logdet - self.get_logdet(scale) - - # Feature Conditional - scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures) - z = z / scaleFt - z = z - shiftFt - logdet = logdet - self.get_logdet(scaleFt) - - output = z - return output, logdet - - def asserts(self, scale, shift, z1, z2): - assert z1.shape[1] == self.channels_for_nn, (z1.shape[1], self.channels_for_nn) - assert z2.shape[1] == self.channels_for_co, (z2.shape[1], self.channels_for_co) - assert scale.shape[1] == shift.shape[1], (scale.shape[1], shift.shape[1]) - assert scale.shape[1] == z2.shape[1], (scale.shape[1], z1.shape[1], z2.shape[1]) - - def get_logdet(self, scale): - return thops.sum(torch.log(scale), dim=[1, 2, 3]) - - def feature_extract(self, z, f): - h = f(z) - shift, scale = thops.split_feature(h, "cross") - scale = (torch.sigmoid(scale + 2.) + self.affine_eps) - return scale, shift - - def feature_extract_aff(self, z1, ft, f): - z = torch.cat([z1, ft], dim=1) - h = f(z) - shift, scale = thops.split_feature(h, "cross") - scale = (torch.sigmoid(scale + 2.) + self.affine_eps) - return scale, shift - - def split(self, z): - z1 = z[:, :self.channels_for_nn] - z2 = z[:, self.channels_for_nn:] - assert z1.shape[1] + z2.shape[1] == z.shape[1], (z1.shape[1], z2.shape[1], z.shape[1]) - return z1, z2 - - def F(self, in_channels, out_channels, hidden_channels, kernel_hidden=1, n_hidden_layers=1): - layers = [Conv2d(in_channels, hidden_channels), nn.ReLU(inplace=False)] - - for _ in range(n_hidden_layers): - layers.append(Conv2d(hidden_channels, hidden_channels, kernel_size=[kernel_hidden, kernel_hidden])) - layers.append(nn.ReLU(inplace=False)) - layers.append(Conv2dZeros(hidden_channels, out_channels)) - - return nn.Sequential(*layers) diff --git a/codes/models/archs/srflow/FlowStep.py b/codes/models/archs/srflow/FlowStep.py deleted file mode 100644 index 87b8d6aa..00000000 --- a/codes/models/archs/srflow/FlowStep.py +++ /dev/null @@ -1,117 +0,0 @@ -import torch -from torch import nn as nn - -from models.archs.srflow import flow, thops, FlowAffineCouplingsAblation, FlowActNorms, Permutations - - -def getConditional(rrdbResults, position): - img_ft = rrdbResults if isinstance(rrdbResults, torch.Tensor) else rrdbResults[position] - return img_ft - - -class FlowStep(nn.Module): - FlowPermutation = { - "reverse": lambda obj, z, logdet, rev: (obj.reverse(z, rev), logdet), - "shuffle": lambda obj, z, logdet, rev: (obj.shuffle(z, rev), logdet), - "invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), - "squeeze_invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), - "resqueeze_invconv_alternating_2_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), - "resqueeze_invconv_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), - "InvertibleConv1x1GridAlign": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), - "InvertibleConv1x1SubblocksShuf": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), - "InvertibleConv1x1GridAlignIndepBorder": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), - "InvertibleConv1x1GridAlignIndepBorder4": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), - } - - def __init__(self, in_channels, hidden_channels, - actnorm_scale=1.0, flow_permutation="invconv", flow_coupling="additive", - LU_decomposed=False, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None, - position=None): - # check configures - assert flow_permutation in FlowStep.FlowPermutation, \ - "float_permutation should be in `{}`".format( - FlowStep.FlowPermutation.keys()) - super().__init__() - self.flow_permutation = flow_permutation - self.flow_coupling = flow_coupling - self.image_injector = image_injector - - self.norm_type = normOpt['type'] if normOpt else 'ActNorm2d' - self.position = normOpt['position'] if normOpt else None - - self.in_shape = in_shape - self.position = position - self.acOpt = acOpt - - # 1. actnorm - self.actnorm = FlowActNorms.ActNorm2d(in_channels, actnorm_scale) - - # 2. permute - if flow_permutation == "invconv": - self.invconv = Permutations.InvertibleConv1x1( - in_channels, LU_decomposed=LU_decomposed) - - # 3. coupling - if flow_coupling == "CondAffineSeparatedAndCond": - self.affine = FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels) - elif flow_coupling == "noCoupling": - pass - else: - raise RuntimeError("coupling not Found:", flow_coupling) - - def forward(self, input, logdet=None, reverse=False, rrdbResults=None): - if not reverse: - return self.normal_flow(input, logdet, rrdbResults) - else: - return self.reverse_flow(input, logdet, rrdbResults) - - def normal_flow(self, z, logdet, rrdbResults=None): - if self.flow_coupling == "bentIdentityPreAct": - z, logdet = self.bentIdentPar(z, logdet, reverse=False) - - # 1. actnorm - if self.norm_type == "ConditionalActNormImageInjector": - img_ft = getConditional(rrdbResults, self.position) - z, logdet = self.actnorm(z, img_ft=img_ft, logdet=logdet, reverse=False) - elif self.norm_type == "noNorm": - pass - else: - z, logdet = self.actnorm(z, logdet=logdet, reverse=False) - - # 2. permute - z, logdet = FlowStep.FlowPermutation[self.flow_permutation]( - self, z, logdet, False) - - need_features = self.affine_need_features() - - # 3. coupling - if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]: - img_ft = getConditional(rrdbResults, self.position) - z, logdet = self.affine(input=z, logdet=logdet, reverse=False, ft=img_ft) - return z, logdet - - def reverse_flow(self, z, logdet, rrdbResults=None): - - need_features = self.affine_need_features() - - # 1.coupling - if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]: - img_ft = getConditional(rrdbResults, self.position) - z, logdet = self.affine(input=z, logdet=logdet, reverse=True, ft=img_ft) - - # 2. permute - z, logdet = FlowStep.FlowPermutation[self.flow_permutation]( - self, z, logdet, True) - - # 3. actnorm - z, logdet = self.actnorm(z, logdet=logdet, reverse=True) - - return z, logdet - - def affine_need_features(self): - need_features = False - try: - need_features = self.affine.need_features - except: - pass - return need_features diff --git a/codes/models/archs/srflow/FlowUpsamplerNet.py b/codes/models/archs/srflow/FlowUpsamplerNet.py deleted file mode 100644 index c492d7e4..00000000 --- a/codes/models/archs/srflow/FlowUpsamplerNet.py +++ /dev/null @@ -1,267 +0,0 @@ -import numpy as np -import torch -from torch import nn as nn - -import models.archs.srflow.Split -from models.archs.srflow import flow, thops, Split -from models.archs.srflow.Split import Split2d -from models.archs.srflow.glow_arch import f_conv2d_bias -from models.archs.srflow.FlowStep import FlowStep -from utils.util import opt_get -import torchvision - - -class FlowUpsamplerNet(nn.Module): - def __init__(self, image_shape, hidden_channels, scale, - rrdb_blocks, - actnorm_scale=1.0, - flow_permutation='invconv', - flow_coupling="affine", - LU_decomposed=False, K=16, L=3, - norm_opt=None, - n_bypass_channels=None): - - super().__init__() - - self.layers = nn.ModuleList() - self.output_shapes = [] - self.L = L - self.K = K - self.scale=scale - if isinstance(self.K, int): - self.K = [K for K in [K, ] * (self.L + 1)] - - H, W, self.C = image_shape - self.image_shape = image_shape - self.check_image_shape() - - if scale == 16: - self.levelToName = { - 0: 'fea_up16', - 1: 'fea_up8', - 2: 'fea_up4', - 3: 'fea_up2', - 4: 'fea_up1', - } - - if scale == 8: - self.levelToName = { - 0: 'fea_up8', - 1: 'fea_up4', - 2: 'fea_up2', - 3: 'fea_up1', - 4: 'fea_up0' - } - - elif scale == 4: - self.levelToName = { - 0: 'fea_up4', - 1: 'fea_up2', - 2: 'fea_up1', - 3: 'fea_up0', - 4: 'fea_up-1' - } - - affineInCh = self.get_affineInCh(rrdb_blocks) - - conditional_channels = {} - n_rrdb = self.get_n_rrdb_channels(rrdb_blocks) - conditional_channels[0] = n_rrdb - for level in range(1, self.L + 1): - # Level 1 gets conditionals from 2, 3, 4 => L - level - # Level 2 gets conditionals from 3, 4 - # Level 3 gets conditionals from 4 - # Level 4 gets conditionals from None - n_bypass = 0 if n_bypass_channels is None else (self.L - level) * n_bypass_channels - conditional_channels[level] = n_rrdb + n_bypass - - # Upsampler - for level in range(1, self.L + 1): - # 1. Squeeze - H, W = self.arch_squeeze(H, W) - - # 2. K FlowStep - self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels) - self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, - flow_permutation, - hidden_channels, norm_opt, - n_conditional_channels=conditional_channels[level]) - # Split - self.arch_split(H, W, level, self.L) - - self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2) - self.H = H - self.W = W - - def get_n_rrdb_channels(self, blocks): - n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64 - return n_rrdb - - def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation, - hidden_channels, normOpt, n_conditional_channels=None, condAff=None): - if condAff is not None: - condAff['in_channels_rrdb'] = n_conditional_channels - - for k in range(K): - position_name = self.get_position_name(H, self.scale) - if normOpt: normOpt['position'] = position_name - - self.layers.append( - FlowStep(in_channels=self.C, - hidden_channels=hidden_channels, - actnorm_scale=actnorm_scale, - flow_permutation=flow_permutation, - flow_coupling=flow_coupling, - acOpt=condAff, - position=position_name, - LU_decomposed=LU_decomposed, idx=k, normOpt=normOpt)) - self.output_shapes.append( - [-1, self.C, H, W]) - - def arch_split(self, H, W, L, levels, split_flow=True, correct_splits=False, logs_eps=0, consume_ratio=.5, split_conditional=False, cond_channels=None, split_type='Split2d'): - correction = 0 if correct_splits else 1 - if split_flow and L < levels - correction: - logs_eps = logs_eps - consume_ratio = consume_ratio - position_name = self.get_position_name(H, self.scale) - position = position_name if split_conditional else None - cond_channels = 0 if cond_channels is None else cond_channels - - if split_type == 'Split2d': - split = Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position, - cond_channels=cond_channels, consume_ratio=consume_ratio) - self.layers.append(split) - self.output_shapes.append([-1, split.num_channels_pass, H, W]) - self.C = split.num_channels_pass - - def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, additionalFlowNoAffine=2): - for _ in range(additionalFlowNoAffine): - self.layers.append( - FlowStep(in_channels=self.C, - hidden_channels=hidden_channels, - actnorm_scale=actnorm_scale, - flow_permutation='invconv', - flow_coupling='noCoupling', - LU_decomposed=LU_decomposed)) - self.output_shapes.append( - [-1, self.C, H, W]) - - def arch_squeeze(self, H, W): - self.C, H, W = self.C * 4, H // 2, W // 2 - self.layers.append(flow.SqueezeLayer(factor=2)) - self.output_shapes.append([-1, self.C, H, W]) - return H, W - - def get_affineInCh(self, rrdb_blocks): - affineInCh = (len(rrdb_blocks) + 1) * 64 - return affineInCh - - def check_image_shape(self): - assert self.C == 1 or self.C == 3, ("image_shape should be HWC, like (64, 64, 3)" - "self.C == 1 or self.C == 3") - - def forward(self, gt=None, rrdbResults=None, z=None, epses=None, logdet=0., reverse=False, eps_std=None, - y_onehot=None): - - if reverse: - epses_copy = [eps for eps in epses] if isinstance(epses, list) else epses - - sr, logdet = self.decode(rrdbResults, z, eps_std, epses=epses_copy, logdet=logdet, y_onehot=y_onehot) - return sr, logdet - else: - assert gt is not None - assert rrdbResults is not None - z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot) - - return z, logdet - - def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None): - fl_fea = gt - reverse = False - level_conditionals = {} - bypasses = {} - - L = self.L - - for level in range(1, L + 1): - bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False) - - for layer, shape in zip(self.layers, self.output_shapes): - size = shape[2] - level = int(np.log(self.image_shape[0] / size) / np.log(2)) - - if level > 0 and level not in level_conditionals.keys(): - level_conditionals[level] = rrdbResults[self.levelToName[level]] - - level_conditionals[level] = rrdbResults[self.levelToName[level]] - - if isinstance(layer, FlowStep): - fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=level_conditionals[level]) - elif isinstance(layer, Split2d): - fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level], - y_onehot=y_onehot) - else: - fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse) - - z = fl_fea - - if not isinstance(epses, list): - return z, logdet - - epses.append(z) - return epses, logdet - - def forward_preFlow(self, fl_fea, logdet, reverse): - if hasattr(self, 'preFlow'): - for l in self.preFlow: - fl_fea, logdet = l(fl_fea, logdet, reverse=reverse) - return fl_fea, logdet - - def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None): - ft = None if layer.position is None else rrdbResults[layer.position] - fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot) - - if isinstance(epses, list): - epses.append(eps) - return fl_fea, logdet - - def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None): - z = epses.pop() if isinstance(epses, list) else z - - fl_fea = z - # debug.imwrite("fl_fea", fl_fea) - bypasses = {} - level_conditionals = {} - for level in range(self.L + 1): - level_conditionals[level] = rrdbResults[self.levelToName[level]] - - for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)): - size = shape[2] - level = int(np.log(self.H / size) / np.log(2)) - - if isinstance(layer, Split2d): - fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer, - rrdbResults[self.levelToName[level]], logdet=logdet, - y_onehot=y_onehot) - elif isinstance(layer, FlowStep): - fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, rrdbResults=level_conditionals[level]) - else: - fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True) - - sr = fl_fea - - assert sr.shape[1] == 3 - return sr, logdet - - def forward_split2d_reverse(self, eps_std, epses, fl_fea, layer, rrdbResults, logdet, y_onehot=None): - ft = None if layer.position is None else rrdbResults[layer.position] - fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, - eps=epses.pop() if isinstance(epses, list) else None, - eps_std=eps_std, ft=ft, y_onehot=y_onehot) - return fl_fea, logdet - - - def get_position_name(self, H, scale): - downscale_factor = self.image_shape[0] // H - position_name = 'fea_up{}'.format(scale / downscale_factor) - return position_name diff --git a/codes/models/archs/srflow/Permutations.py b/codes/models/archs/srflow/Permutations.py deleted file mode 100644 index 20548699..00000000 --- a/codes/models/archs/srflow/Permutations.py +++ /dev/null @@ -1,42 +0,0 @@ -import numpy as np -import torch -from torch import nn as nn -from torch.nn import functional as F - -from models.archs.srflow import thops - - -class InvertibleConv1x1(nn.Module): - def __init__(self, num_channels, LU_decomposed=False): - super().__init__() - w_shape = [num_channels, num_channels] - w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32) - self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init))) - self.w_shape = w_shape - self.LU = LU_decomposed - - def get_weight(self, input, reverse): - w_shape = self.w_shape - pixels = thops.pixels(input) - dlogdet = torch.slogdet(self.weight)[1] * pixels - if not reverse: - weight = self.weight.view(w_shape[0], w_shape[1], 1, 1) - else: - weight = torch.inverse(self.weight.double()).float() \ - .view(w_shape[0], w_shape[1], 1, 1) - return weight, dlogdet - def forward(self, input, logdet=None, reverse=False): - """ - log-det = log|abs(|W|)| * pixels - """ - weight, dlogdet = self.get_weight(input, reverse) - if not reverse: - z = F.conv2d(input, weight) - if logdet is not None: - logdet = logdet + dlogdet - return z, logdet - else: - z = F.conv2d(input, weight) - if logdet is not None: - logdet = logdet - dlogdet - return z, logdet diff --git a/codes/models/archs/srflow/RRDBNet_arch.py b/codes/models/archs/srflow/RRDBNet_arch.py deleted file mode 100644 index 033650b3..00000000 --- a/codes/models/archs/srflow/RRDBNet_arch.py +++ /dev/null @@ -1,133 +0,0 @@ -import functools -import torch -import torch.nn as nn -import torch.nn.functional as F -import models.archs.srflow.module_util as mutil -from utils.util import opt_get, checkpoint - - -class ResidualDenseBlock_5C(nn.Module): - def __init__(self, nf=64, gc=32, bias=True): - super(ResidualDenseBlock_5C, self).__init__() - # gc: growth channel, i.e. intermediate channels - self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) - self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) - self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) - self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) - self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - # initialization - mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) - - def forward(self, x): - x1 = self.lrelu(self.conv1(x)) - x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) - x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) - x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - return x5 * 0.2 + x - - -class RRDB(nn.Module): - '''Residual in Residual Dense Block''' - - def __init__(self, nf, gc=32): - super(RRDB, self).__init__() - self.RDB1 = ResidualDenseBlock_5C(nf, gc) - self.RDB2 = ResidualDenseBlock_5C(nf, gc) - self.RDB3 = ResidualDenseBlock_5C(nf, gc) - - def forward(self, x): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) - return out * 0.2 + x - - -class RRDBNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, block_outputs=[], fea_up0=True, - fea_up1=False): - super(RRDBNet, self).__init__() - RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) - self.scale = scale - self.block_outputs = block_outputs - self.fea_up0 = fea_up0 - self.fea_up1 = fea_up1 - - self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) - self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb) - self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - #### upsampling - self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - if self.scale >= 8: - self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - if self.scale >= 16: - self.upconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - if self.scale >= 32: - self.upconv5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - - self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) - - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - def forward(self, x, get_steps=False): - fea = self.conv_first(x) - - block_idxs = self.block_outputs or [] - block_results = {} - - for idx, m in enumerate(self.RRDB_trunk.children()): - fea = checkpoint(m, fea) - for b in block_idxs: - if b == idx: - block_results["block_{}".format(idx)] = fea - - trunk = self.trunk_conv(fea) - - last_lr_fea = fea + trunk - - fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest')) - fea = self.lrelu(fea_up2) - - fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')) - fea = self.lrelu(fea_up4) - - fea_up8 = None - fea_up16 = None - fea_up32 = None - - if self.scale >= 8: - fea_up8 = self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest')) - fea = self.lrelu(fea_up8) - if self.scale >= 16: - fea_up16 = self.upconv4(F.interpolate(fea, scale_factor=2, mode='nearest')) - fea = self.lrelu(fea_up16) - if self.scale >= 32: - fea_up32 = self.upconv5(F.interpolate(fea, scale_factor=2, mode='nearest')) - fea = self.lrelu(fea_up32) - - out = self.conv_last(self.lrelu(self.HRconv(fea))) - - results = {'last_lr_fea': last_lr_fea, - 'fea_up1': last_lr_fea, - 'fea_up2': fea_up2, - 'fea_up4': fea_up4, - 'fea_up8': fea_up8, - 'fea_up16': fea_up16, - 'fea_up32': fea_up32, - 'out': out} - - if self.fea_up0: - results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True) - if self.fea_up1: - results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True) - - if get_steps: - for k, v in block_results.items(): - results[k] = v - return results - else: - return out diff --git a/codes/models/archs/srflow/Split.py b/codes/models/archs/srflow/Split.py deleted file mode 100644 index c24eaf41..00000000 --- a/codes/models/archs/srflow/Split.py +++ /dev/null @@ -1,68 +0,0 @@ -import torch -from torch import nn as nn - -from models.archs.srflow import thops -from models.archs.srflow.FlowStep import FlowStep -from models.archs.srflow.flow import Conv2dZeros, GaussianDiag - - -class Split2d(nn.Module): - def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5): - super().__init__() - - self.num_channels_consume = int(round(num_channels * consume_ratio)) - self.num_channels_pass = num_channels - self.num_channels_consume - - self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels, - out_channels=self.num_channels_consume * 2) - self.logs_eps = logs_eps - self.position = position - - def split2d_prior(self, z, ft): - if ft is not None: - z = torch.cat([z, ft], dim=1) - h = self.conv(z) - return thops.split_feature(h, "cross") - - def exp_eps(self, logs): - return torch.exp(logs) + self.logs_eps - - def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None, y_onehot=None): - if not reverse: - # self.input = input - z1, z2 = self.split_ratio(input) - mean, logs = self.split2d_prior(z1, ft) - - eps = (z2 - mean) / self.exp_eps(logs) - - logdet = logdet + self.get_logdet(logs, mean, z2) - - # print(logs.shape, mean.shape, z2.shape) - # self.eps = eps - # print('split, enc eps:', eps) - return z1, logdet, eps - else: - z1 = input - mean, logs = self.split2d_prior(z1, ft) - - if eps is None: - #print("WARNING: eps is None, generating eps untested functionality!") - eps = GaussianDiag.sample_eps(mean.shape, eps_std) - - eps = eps.to(mean.device) - z2 = mean + self.exp_eps(logs) * eps - - z = thops.cat_feature(z1, z2) - logdet = logdet - self.get_logdet(logs, mean, z2) - - return z, logdet - # return z, logdet, eps - - def get_logdet(self, logs, mean, z2): - logdet_diff = GaussianDiag.logp(mean, logs, z2) - # print("Split2D: logdet diff", logdet_diff.item()) - return logdet_diff - - def split_ratio(self, input): - z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...] - return z1, z2 \ No newline at end of file diff --git a/codes/models/archs/srflow/__init__.py b/codes/models/archs/srflow/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/codes/models/archs/srflow/flow.py b/codes/models/archs/srflow/flow.py deleted file mode 100644 index 39b55284..00000000 --- a/codes/models/archs/srflow/flow.py +++ /dev/null @@ -1,150 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np - -from models.archs.srflow.FlowActNorms import ActNorm2d -from . import thops - - -class Conv2d(nn.Conv2d): - pad_dict = { - "same": lambda kernel, stride: [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)], - "valid": lambda kernel, stride: [0 for _ in kernel] - } - - @staticmethod - def get_padding(padding, kernel_size, stride): - # make paddding - if isinstance(padding, str): - if isinstance(kernel_size, int): - kernel_size = [kernel_size, kernel_size] - if isinstance(stride, int): - stride = [stride, stride] - padding = padding.lower() - try: - padding = Conv2d.pad_dict[padding](kernel_size, stride) - except KeyError: - raise ValueError("{} is not supported".format(padding)) - return padding - - def __init__(self, in_channels, out_channels, - kernel_size=[3, 3], stride=[1, 1], - padding="same", do_actnorm=True, weight_std=0.05): - padding = Conv2d.get_padding(padding, kernel_size, stride) - super().__init__(in_channels, out_channels, kernel_size, stride, - padding, bias=(not do_actnorm)) - # init weight with std - self.weight.data.normal_(mean=0.0, std=weight_std) - if not do_actnorm: - self.bias.data.zero_() - else: - self.actnorm = ActNorm2d(out_channels) - self.do_actnorm = do_actnorm - - def forward(self, input): - x = super().forward(input) - if self.do_actnorm: - x, _ = self.actnorm(x) - return x - - -class Conv2dZeros(nn.Conv2d): - def __init__(self, in_channels, out_channels, - kernel_size=[3, 3], stride=[1, 1], - padding="same", logscale_factor=3): - padding = Conv2d.get_padding(padding, kernel_size, stride) - super().__init__(in_channels, out_channels, kernel_size, stride, padding) - # logscale_factor - self.logscale_factor = logscale_factor - self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels, 1, 1))) - # init - self.weight.data.zero_() - self.bias.data.zero_() - - def forward(self, input): - output = super().forward(input) - return output * torch.exp(self.logs * self.logscale_factor) - - -class GaussianDiag: - Log2PI = float(np.log(2 * np.pi)) - - @staticmethod - def likelihood(mean, logs, x): - """ - lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) } - k = 1 (Independent) - Var = logs ** 2 - """ - if mean is None and logs is None: - return -0.5 * (x ** 2 + GaussianDiag.Log2PI) - else: - return -0.5 * (logs * 2. + ((x - mean) ** 2) / torch.exp(logs * 2.) + GaussianDiag.Log2PI) - - @staticmethod - def logp(mean, logs, x): - likelihood = GaussianDiag.likelihood(mean, logs, x) - return thops.sum(likelihood, dim=[1, 2, 3]) - - @staticmethod - def sample(mean, logs, eps_std=None): - eps_std = eps_std or 1 - eps = torch.normal(mean=torch.zeros_like(mean), - std=torch.ones_like(logs) * eps_std) - return mean + torch.exp(logs) * eps - - @staticmethod - def sample_eps(shape, eps_std, seed=None): - if seed is not None: - torch.manual_seed(seed) - eps = torch.normal(mean=torch.zeros(shape), - std=torch.ones(shape) * eps_std) - return eps - - -def squeeze2d(input, factor=2): - assert factor >= 1 and isinstance(factor, int) - if factor == 1: - return input - size = input.size() - B = size[0] - C = size[1] - H = size[2] - W = size[3] - assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor)) - x = input.view(B, C, H // factor, factor, W // factor, factor) - x = x.permute(0, 1, 3, 5, 2, 4).contiguous() - x = x.view(B, C * factor * factor, H // factor, W // factor) - return x - - -def unsqueeze2d(input, factor=2): - assert factor >= 1 and isinstance(factor, int) - factor2 = factor ** 2 - if factor == 1: - return input - size = input.size() - B = size[0] - C = size[1] - H = size[2] - W = size[3] - assert C % (factor2) == 0, "{}".format(C) - x = input.view(B, C // factor2, factor, factor, H, W) - x = x.permute(0, 1, 4, 2, 5, 3).contiguous() - x = x.view(B, C // (factor2), H * factor, W * factor) - return x - - -class SqueezeLayer(nn.Module): - def __init__(self, factor): - super().__init__() - self.factor = factor - - def forward(self, input, logdet=None, reverse=False): - if not reverse: - output = squeeze2d(input, self.factor) # Squeeze in forward - return output, logdet - else: - output = unsqueeze2d(input, self.factor) - return output, logdet diff --git a/codes/models/archs/srflow/glow_arch.py b/codes/models/archs/srflow/glow_arch.py deleted file mode 100644 index 00da3cbb..00000000 --- a/codes/models/archs/srflow/glow_arch.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch.nn as nn - - -def f_conv2d_bias(in_channels, out_channels): - def padding_same(kernel, stride): - return [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)] - - padding = padding_same([3, 3], [1, 1]) - assert padding == [1, 1], padding - return nn.Sequential( - nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=[3, 3], stride=1, padding=1, - bias=True)) diff --git a/codes/models/archs/srflow/module_util.py b/codes/models/archs/srflow/module_util.py deleted file mode 100644 index ca5d7fa9..00000000 --- a/codes/models/archs/srflow/module_util.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.init as init -import torch.nn.functional as F - - -def initialize_weights(net_l, scale=1): - if not isinstance(net_l, list): - net_l = [net_l] - for net in net_l: - for m in net.modules(): - if isinstance(m, nn.Conv2d): - init.kaiming_normal_(m.weight, a=0, mode='fan_in') - m.weight.data *= scale # for residual block - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - init.kaiming_normal_(m.weight, a=0, mode='fan_in') - m.weight.data *= scale - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - init.constant_(m.weight, 1) - init.constant_(m.bias.data, 0.0) - - -def make_layer(block, n_layers): - layers = [] - for _ in range(n_layers): - layers.append(block()) - return nn.Sequential(*layers) - - -class ResidualBlock_noBN(nn.Module): - '''Residual block w/o BN - ---Conv-ReLU-Conv-+- - |________________| - ''' - - def __init__(self, nf=64): - super(ResidualBlock_noBN, self).__init__() - self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - - # initialization - initialize_weights([self.conv1, self.conv2], 0.1) - - def forward(self, x): - identity = x - out = F.relu(self.conv1(x), inplace=True) - out = self.conv2(out) - return identity + out - - -def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): - """Warp an image or feature map with optical flow - Args: - x (Tensor): size (N, C, H, W) - flow (Tensor): size (N, H, W, 2), normal value - interp_mode (str): 'nearest' or 'bilinear' - padding_mode (str): 'zeros' or 'border' or 'reflection' - - Returns: - Tensor: warped image or feature map - """ - assert x.size()[-2:] == flow.size()[1:3] - B, C, H, W = x.size() - # mesh grid - grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) - grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 - grid.requires_grad = False - grid = grid.type_as(x) - vgrid = grid + flow - # scale grid to [-1,1] - vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 - vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 - vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) - output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) - return output diff --git a/codes/models/archs/srflow/srflow_arch.py b/codes/models/archs/srflow/srflow_arch.py deleted file mode 100644 index 2c127c3b..00000000 --- a/codes/models/archs/srflow/srflow_arch.py +++ /dev/null @@ -1,135 +0,0 @@ -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np - -from models.archs.srflow.FlowUpsamplerNet import FlowUpsamplerNet -import models.archs.srflow.thops as thops -import models.archs.srflow.flow as flow -from models.archs.srflow.RRDBNet_arch import RRDBNet - - -class SRFlowNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, quant, flow_block_maps, noise_quant, - hidden_channels=64, gc=32, scale=4, K=16, L=3, train_rrdb_at_step=0, - hr_img_shape=(128,128,3), coupling='CondAffineSeparatedAndCond'): - super(SRFlowNet, self).__init__() - - self.scale = scale - self.noise_quant = noise_quant - self.quant = quant - self.flow_block_maps = flow_block_maps - self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, flow_block_maps) - self.train_rrdb_step = train_rrdb_at_step - self.RRDB_training = True - - self.flowUpsamplerNet = FlowUpsamplerNet(image_shape=hr_img_shape, - hidden_channels=hidden_channels, - scale=scale, rrdb_blocks=flow_block_maps, - K=K, L=L, flow_coupling=coupling) - self.i = 0 - - def forward(self, gt=None, lr=None, reverse=False, z=None, eps_std=None, epses=None, reverse_with_grad=False, - lr_enc=None, - add_gt_noise=False, step=None, y_label=None): - if not reverse: - return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step, - y_onehot=y_label) - else: - assert lr.shape[1] == 3 - if reverse_with_grad: - return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc, - add_gt_noise=add_gt_noise) - else: - with torch.no_grad(): - return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc, - add_gt_noise=add_gt_noise) - - def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None): - if lr_enc is None: - lr_enc = self.rrdbPreprocessing(lr) - - logdet = torch.zeros_like(gt[:, 0, 0, 0]) - pixels = thops.pixels(gt) - - z = gt - - if add_gt_noise: - # Setup - if self.noise_quant: - z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant) - logdet = logdet + float(-np.log(self.quant) * pixels) - - # Encode - epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses, - y_onehot=y_onehot) - - objective = logdet.clone() - - if isinstance(epses, (list, tuple)): - z = epses[-1] - else: - z = epses - - objective = objective + flow.GaussianDiag.logp(None, None, z) - - nll = (-objective) / float(np.log(2.) * pixels) - - if isinstance(epses, list): - return epses, nll, logdet - return z, nll, logdet - - def rrdbPreprocessing(self, lr): - rrdbResults = self.RRDB(lr, get_steps=True) - block_idxs = self.flow_block_maps - if len(block_idxs) > 0: - concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1) - - keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4'] - if 'fea_up0' in rrdbResults.keys(): - keys.append('fea_up0') - if 'fea_up-1' in rrdbResults.keys(): - keys.append('fea_up-1') - if self.scale >= 8: - keys.append('fea_up8') - if self.scale == 16: - keys.append('fea_up16') - for k in keys: - h = rrdbResults[k].shape[2] - w = rrdbResults[k].shape[3] - rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1) - return rrdbResults - - def get_score(self, disc_loss_sigma, z): - score_real = 0.5 * (1 - 1 / (disc_loss_sigma ** 2)) * thops.sum(z ** 2, dim=[1, 2, 3]) - \ - z.shape[1] * z.shape[2] * z.shape[3] * math.log(disc_loss_sigma) - return -score_real - - def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True): - logdet = torch.zeros_like(lr[:, 0, 0, 0]) - pixels = thops.pixels(lr) * self.scale ** 2 - - if add_gt_noise: - logdet = logdet - float(-np.log(self.quant) * pixels) - - if lr_enc is None: - lr_enc = self.rrdbPreprocessing(lr) - - x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses, - logdet=logdet) - - return x, logdet - - def set_rrdb_training(self, trainable): - if self.RRDB_training != trainable: - for p in self.RRDB.parameters(): - if not trainable: - p.DO_NOT_TRAIN = True - elif hasattr(p, "DO_NOT_TRAIN"): - del p.DO_NOT_TRAIN - self.RRDB_training = trainable - - def update_for_step(self, step, experiments_path='.'): - self.set_rrdb_training(step > self.train_rrdb_step) \ No newline at end of file diff --git a/codes/models/archs/srflow/thops.py b/codes/models/archs/srflow/thops.py deleted file mode 100644 index 6cbc28b6..00000000 --- a/codes/models/archs/srflow/thops.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch - - -def sum(tensor, dim=None, keepdim=False): - if dim is None: - # sum up all dim - return torch.sum(tensor) - else: - if isinstance(dim, int): - dim = [dim] - dim = sorted(dim) - for d in dim: - tensor = tensor.sum(dim=d, keepdim=True) - if not keepdim: - for i, d in enumerate(dim): - tensor.squeeze_(d-i) - return tensor - - -def mean(tensor, dim=None, keepdim=False): - if dim is None: - # mean all dim - return torch.mean(tensor) - else: - if isinstance(dim, int): - dim = [dim] - dim = sorted(dim) - for d in dim: - tensor = tensor.mean(dim=d, keepdim=True) - if not keepdim: - for i, d in enumerate(dim): - tensor.squeeze_(d-i) - return tensor - - -def split_feature(tensor, type="split"): - """ - type = ["split", "cross"] - """ - C = tensor.size(1) - if type == "split": - return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...] - elif type == "cross": - return tensor[:, 0::2, ...], tensor[:, 1::2, ...] - - -def cat_feature(tensor_a, tensor_b): - return torch.cat((tensor_a, tensor_b), dim=1) - - -def pixels(tensor): - return int(tensor.size(2) * tensor.size(3)) \ No newline at end of file