Allow "corruptor" network to be specified

This network is just a fixed (pre-trained) generator
that performs a corruption transformation that the
generator-in-training is expected to undo alongside
SR.
This commit is contained in:
James Betker 2020-05-13 15:26:55 -06:00
parent f389025b53
commit e36f22e14a
3 changed files with 23 additions and 5 deletions

View File

@ -30,6 +30,12 @@ class SRGANModel(BaseModel):
if self.is_train:
self.netD = networks.define_D(opt).to(self.device)
if 'network_C' in opt.keys():
self.netC = networks.define_G(opt, net_key='network_C').to(self.device)
self.netC.eval()
else:
self.netC = None
# define losses, optimizer and scheduler
if self.is_train:
self.mega_batch_factor = train_opt['mega_batch_factor']
@ -145,7 +151,15 @@ class SRGANModel(BaseModel):
self.load() # load G and D if needed
def feed_data(self, data, need_GT=True):
self.var_L = torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0) # LQ
# Corrupt the data with the given corruptor, if specified.
self.fed_LQ = data['LQ'].to(self.device)
if self.netC:
with torch.no_grad():
corrupted_L = self.netC(self.fed_LQ)[0].detach()
else:
corrupted_L = self.fed_LQ
self.var_L = torch.chunk(corrupted_L, chunks=self.mega_batch_factor, dim=0)
if need_GT:
self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
input_ref = data['ref'] if 'ref' in data else data['GT']
@ -272,6 +286,7 @@ class SRGANModel(BaseModel):
if step % 50 == 0:
os.makedirs("temp/hr", exist_ok=True)
os.makedirs("temp/lr", exist_ok=True)
os.makedirs("temp/lr_precorrupt", exist_ok=True)
os.makedirs("temp/gen", exist_ok=True)
os.makedirs("temp/pix", exist_ok=True)
multi_gen = False
@ -280,6 +295,9 @@ class SRGANModel(BaseModel):
os.makedirs("temp/genmr", exist_ok=True)
os.makedirs("temp/ref", exist_ok=True)
multi_gen = True
# fed_LQ is not chunked.
utils.save_image(self.fed_LQ.cpu().detach(), os.path.join("temp/lr_precorrupt", "%05i.png" % (step,)))
for i in range(self.mega_batch_factor):
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))

View File

@ -166,8 +166,8 @@ class FixupResNetV2(FixupResNet):
def forward(self, x):
if self.inject_noise:
rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device, dtype=x.dtype)
x = torch.cat([x, rand_feature], dim=1)
rand_feature = torch.randn_like(x)
x = x + rand_feature * .1
x = self.conv1(x)
x = self.lrelu(x + self.bias1)
x = self.layer1(x)

View File

@ -14,8 +14,8 @@ import models.archs.ResGen_arch as ResGen_arch
import math
# Generator
def define_G(opt):
opt_net = opt['network_G']
def define_G(opt, net_key='network_G'):
opt_net = opt[net_key]
which_model = opt_net['which_model_G']
scale = opt['scale']