import math import torch import torch.nn as nn import torch.nn.functional as F import torchvision import numpy as np from models.archs.srflow_orig.RRDBNet_arch import RRDBNet from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet import models.archs.srflow_orig.thops as thops import models.archs.srflow_orig.flow as flow from utils.util import opt_get class SRFlowNet(nn.Module): def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, K=None, opt=None, step=None): super(SRFlowNet, self).__init__() self.opt = opt self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \ None else opt_get(opt, ['datasets', 'train', 'quant']) initial_stride = opt_get(opt, ['networks', 'generator', 'initial_stride'], 1) self.RRDB = RRDBNet(in_nc, out_nc, nf=nf, nb=nb, gc=gc, scale=scale, opt=opt, initial_conv_stride=initial_stride) if 'pretrain_rrdb' in opt['networks']['generator'].keys(): rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb']) self.RRDB.load_state_dict(rrdb_state_dict, strict=True) hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels']) hidden_channels = hidden_channels or 64 self.RRDB_training = opt_get(self.opt, ['networks', 'generator','train_RRDB'], default=False) self.flow_scale = opt_get(self.opt, ['networks', 'generator', 'flow_scale'], default=opt['scale']) #