From e36f22e14a9dce7f6b029a36677ad4c3a382fa4b Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 13 May 2020 15:26:55 -0600 Subject: [PATCH] Allow "corruptor" network to be specified This network is just a fixed (pre-trained) generator that performs a corruption transformation that the generator-in-training is expected to undo alongside SR. --- codes/models/SRGAN_model.py | 20 +++++++++++++++++++- codes/models/archs/ResGen_arch.py | 4 ++-- codes/models/networks.py | 4 ++-- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index a52cf685..1b715b05 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -30,6 +30,12 @@ class SRGANModel(BaseModel): if self.is_train: self.netD = networks.define_D(opt).to(self.device) + if 'network_C' in opt.keys(): + self.netC = networks.define_G(opt, net_key='network_C').to(self.device) + self.netC.eval() + else: + self.netC = None + # define losses, optimizer and scheduler if self.is_train: self.mega_batch_factor = train_opt['mega_batch_factor'] @@ -145,7 +151,15 @@ class SRGANModel(BaseModel): self.load() # load G and D if needed def feed_data(self, data, need_GT=True): - self.var_L = torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0) # LQ + # Corrupt the data with the given corruptor, if specified. + self.fed_LQ = data['LQ'].to(self.device) + if self.netC: + with torch.no_grad(): + corrupted_L = self.netC(self.fed_LQ)[0].detach() + else: + corrupted_L = self.fed_LQ + + self.var_L = torch.chunk(corrupted_L, chunks=self.mega_batch_factor, dim=0) if need_GT: self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)] input_ref = data['ref'] if 'ref' in data else data['GT'] @@ -272,6 +286,7 @@ class SRGANModel(BaseModel): if step % 50 == 0: os.makedirs("temp/hr", exist_ok=True) os.makedirs("temp/lr", exist_ok=True) + os.makedirs("temp/lr_precorrupt", exist_ok=True) os.makedirs("temp/gen", exist_ok=True) os.makedirs("temp/pix", exist_ok=True) multi_gen = False @@ -280,6 +295,9 @@ class SRGANModel(BaseModel): os.makedirs("temp/genmr", exist_ok=True) os.makedirs("temp/ref", exist_ok=True) multi_gen = True + + # fed_LQ is not chunked. + utils.save_image(self.fed_LQ.cpu().detach(), os.path.join("temp/lr_precorrupt", "%05i.png" % (step,))) for i in range(self.mega_batch_factor): utils.save_image(self.var_H[i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i))) utils.save_image(self.var_L[i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i))) diff --git a/codes/models/archs/ResGen_arch.py b/codes/models/archs/ResGen_arch.py index dc857c58..d4352e00 100644 --- a/codes/models/archs/ResGen_arch.py +++ b/codes/models/archs/ResGen_arch.py @@ -166,8 +166,8 @@ class FixupResNetV2(FixupResNet): def forward(self, x): if self.inject_noise: - rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device, dtype=x.dtype) - x = torch.cat([x, rand_feature], dim=1) + rand_feature = torch.randn_like(x) + x = x + rand_feature * .1 x = self.conv1(x) x = self.lrelu(x + self.bias1) x = self.layer1(x) diff --git a/codes/models/networks.py b/codes/models/networks.py index 9c8b4847..1bf95aa0 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -14,8 +14,8 @@ import models.archs.ResGen_arch as ResGen_arch import math # Generator -def define_G(opt): - opt_net = opt['network_G'] +def define_G(opt, net_key='network_G'): + opt_net = opt[net_key] which_model = opt_net['which_model_G'] scale = opt['scale']