forked from mrq/DL-Art-School
Allow "corruptor" network to be specified
This network is just a fixed (pre-trained) generator that performs a corruption transformation that the generator-in-training is expected to undo alongside SR.
This commit is contained in:
parent f389025b53
commit e36f22e14a
@@ -30,6 +30,12 @@ class SRGANModel(BaseModel):
         if self.is_train:
             self.netD = networks.define_D(opt).to(self.device)
 
+            if 'network_C' in opt.keys():
+                self.netC = networks.define_G(opt, net_key='network_C').to(self.device)
+                self.netC.eval()
+            else:
+                self.netC = None
+
         # define losses, optimizer and scheduler
         if self.is_train:
             self.mega_batch_factor = train_opt['mega_batch_factor']
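Note: the block above builds the corruptor with the same generator factory, moves it to the device, and switches it to eval mode so it stays fixed while the main generator trains. A minimal standalone sketch of that fixed-corruptor pattern, using a stand-in nn.Module and made-up tensor shapes rather than a real DL-Art-School network:

# Minimal sketch of the fixed-corruptor pattern; the Conv2d stand-in and the
# shapes are illustrative, not part of this commit.
import torch
import torch.nn as nn

corruptor = nn.Conv2d(3, 3, kernel_size=3, padding=1)   # stand-in for self.netC
corruptor.eval()                  # use eval-time behaviour (dropout/BN stats frozen)

lq = torch.randn(4, 3, 32, 32)    # fake low-quality batch
with torch.no_grad():             # no autograd graph is built through the corruptor
    corrupted = corruptor(lq).detach()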
@@ -145,7 +151,15 @@ class SRGANModel(BaseModel):
         self.load()  # load G and D if needed
 
     def feed_data(self, data, need_GT=True):
-        self.var_L = torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0)  # LQ
+        # Corrupt the data with the given corruptor, if specified.
+        self.fed_LQ = data['LQ'].to(self.device)
+        if self.netC:
+            with torch.no_grad():
+                corrupted_L = self.netC(self.fed_LQ)[0].detach()
+        else:
+            corrupted_L = self.fed_LQ
+
+        self.var_L = torch.chunk(corrupted_L, chunks=self.mega_batch_factor, dim=0)
         if need_GT:
             self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
             input_ref = data['ref'] if 'ref' in data else data['GT']
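Two details in the rewritten feed_data: self.netC(...)[0] appears to assume the corruptor returns a sequence of outputs, of which only the first is used as the corrupted image, and torch.chunk splits the (possibly corrupted) batch into mega_batch_factor sub-batches along the batch dimension. A toy illustration of the chunking with made-up shapes:

# Toy illustration of the torch.chunk call above; batch size and image
# dimensions are made up.
import torch

mega_batch_factor = 2
corrupted_L = torch.randn(8, 3, 32, 32)
var_L = torch.chunk(corrupted_L, chunks=mega_batch_factor, dim=0)
print([t.shape for t in var_L])   # two chunks of shape (4, 3, 32, 32)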
@@ -272,6 +286,7 @@ class SRGANModel(BaseModel):
             if step % 50 == 0:
                 os.makedirs("temp/hr", exist_ok=True)
                 os.makedirs("temp/lr", exist_ok=True)
+                os.makedirs("temp/lr_precorrupt", exist_ok=True)
                 os.makedirs("temp/gen", exist_ok=True)
                 os.makedirs("temp/pix", exist_ok=True)
                 multi_gen = False
@@ -280,6 +295,9 @@ class SRGANModel(BaseModel):
                     os.makedirs("temp/genmr", exist_ok=True)
                     os.makedirs("temp/ref", exist_ok=True)
                     multi_gen = True
+
+                # fed_LQ is not chunked.
+                utils.save_image(self.fed_LQ.cpu().detach(), os.path.join("temp/lr_precorrupt", "%05i.png" % (step,)))
                 for i in range(self.mega_batch_factor):
                     utils.save_image(self.var_H[i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
                     utils.save_image(self.var_L[i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
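The debug dump above writes the un-chunked pre-corruption batch alongside the per-chunk HR/LR images. A self-contained sketch of that save pattern, assuming utils refers to torchvision.utils and using a random tensor in place of self.fed_LQ:

# Sketch of the debug-image dump; the tensor is a random stand-in for
# self.fed_LQ and the step value is arbitrary.
import os
import torch
from torchvision import utils

step = 50
fed_LQ = torch.rand(4, 3, 32, 32)                 # stand-in batch in [0, 1]
os.makedirs("temp/lr_precorrupt", exist_ok=True)
utils.save_image(fed_LQ.cpu().detach(),
                 os.path.join("temp/lr_precorrupt", "%05i.png" % (step,)))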
@@ -166,8 +166,8 @@ class FixupResNetV2(FixupResNet):
 
     def forward(self, x):
         if self.inject_noise:
-            rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device, dtype=x.dtype)
-            x = torch.cat([x, rand_feature], dim=1)
+            rand_feature = torch.randn_like(x)
+            x = x + rand_feature * .1
         x = self.conv1(x)
         x = self.lrelu(x + self.bias1)
         x = self.layer1(x)
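This hunk changes how noise is injected in FixupResNetV2.forward: instead of concatenating one extra noise channel (which changes the channel count seen by conv1), it adds scaled element-wise Gaussian noise, leaving the tensor shape unchanged. A standalone toy comparison of the two behaviours, not the actual forward pass:

# Toy comparison of the old vs. new noise injection; tensors are made up.
import torch

x = torch.randn(2, 3, 16, 16)

# Old: concatenate a single noise channel -> shape (2, 4, 16, 16)
noise_channel = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device, dtype=x.dtype)
x_cat = torch.cat([x, noise_channel], dim=1)

# New: additive scaled noise, shape unchanged -> (2, 3, 16, 16)
x_add = x + torch.randn_like(x) * .1

print(x_cat.shape, x_add.shape)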
@@ -14,8 +14,8 @@ import models.archs.ResGen_arch as ResGen_arch
 import math
 
 # Generator
-def define_G(opt):
-    opt_net = opt['network_G']
+def define_G(opt, net_key='network_G'):
+    opt_net = opt[net_key]
     which_model = opt_net['which_model_G']
     scale = opt['scale']
 
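With the new net_key parameter, the same generator factory can build either the generator under training or the corruptor, depending on which options block it is pointed at. A self-contained sketch of that dispatch, with a placeholder return value and made-up architecture names instead of the real network construction:

# Sketch of the net_key dispatch; only the opt[net_key] indirection mirrors
# the diff, and the returned dict is a placeholder for a real network.
def define_G(opt, net_key='network_G'):
    opt_net = opt[net_key]
    which_model = opt_net['which_model_G']
    scale = opt['scale']
    return {'which_model': which_model, 'scale': scale}   # placeholder

opt = {
    'scale': 4,
    'network_G': {'which_model_G': 'generator_arch'},     # generator in training
    'network_C': {'which_model_G': 'corruptor_arch'},     # fixed corruptor
}
netG = define_G(opt)                              # reads opt['network_G']
netC = define_G(opt, net_key='network_C')         # reads opt['network_C']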