forked from mrq/DL-Art-School
Allow "corruptor" network to be specified
This network is just a fixed (pre-trained) generator that performs a corruption transformation that the generator-in-training is expected to undo alongside SR.
This commit is contained in:
parent
f389025b53
commit
e36f22e14a
|
@ -30,6 +30,12 @@ class SRGANModel(BaseModel):
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
self.netD = networks.define_D(opt).to(self.device)
|
self.netD = networks.define_D(opt).to(self.device)
|
||||||
|
|
||||||
|
if 'network_C' in opt.keys():
|
||||||
|
self.netC = networks.define_G(opt, net_key='network_C').to(self.device)
|
||||||
|
self.netC.eval()
|
||||||
|
else:
|
||||||
|
self.netC = None
|
||||||
|
|
||||||
# define losses, optimizer and scheduler
|
# define losses, optimizer and scheduler
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
self.mega_batch_factor = train_opt['mega_batch_factor']
|
self.mega_batch_factor = train_opt['mega_batch_factor']
|
||||||
|
@ -145,7 +151,15 @@ class SRGANModel(BaseModel):
|
||||||
self.load() # load G and D if needed
|
self.load() # load G and D if needed
|
||||||
|
|
||||||
def feed_data(self, data, need_GT=True):
|
def feed_data(self, data, need_GT=True):
|
||||||
self.var_L = torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0) # LQ
|
# Corrupt the data with the given corruptor, if specified.
|
||||||
|
self.fed_LQ = data['LQ'].to(self.device)
|
||||||
|
if self.netC:
|
||||||
|
with torch.no_grad():
|
||||||
|
corrupted_L = self.netC(self.fed_LQ)[0].detach()
|
||||||
|
else:
|
||||||
|
corrupted_L = self.fed_LQ
|
||||||
|
|
||||||
|
self.var_L = torch.chunk(corrupted_L, chunks=self.mega_batch_factor, dim=0)
|
||||||
if need_GT:
|
if need_GT:
|
||||||
self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
|
self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
|
||||||
input_ref = data['ref'] if 'ref' in data else data['GT']
|
input_ref = data['ref'] if 'ref' in data else data['GT']
|
||||||
|
@ -272,6 +286,7 @@ class SRGANModel(BaseModel):
|
||||||
if step % 50 == 0:
|
if step % 50 == 0:
|
||||||
os.makedirs("temp/hr", exist_ok=True)
|
os.makedirs("temp/hr", exist_ok=True)
|
||||||
os.makedirs("temp/lr", exist_ok=True)
|
os.makedirs("temp/lr", exist_ok=True)
|
||||||
|
os.makedirs("temp/lr_precorrupt", exist_ok=True)
|
||||||
os.makedirs("temp/gen", exist_ok=True)
|
os.makedirs("temp/gen", exist_ok=True)
|
||||||
os.makedirs("temp/pix", exist_ok=True)
|
os.makedirs("temp/pix", exist_ok=True)
|
||||||
multi_gen = False
|
multi_gen = False
|
||||||
|
@ -280,6 +295,9 @@ class SRGANModel(BaseModel):
|
||||||
os.makedirs("temp/genmr", exist_ok=True)
|
os.makedirs("temp/genmr", exist_ok=True)
|
||||||
os.makedirs("temp/ref", exist_ok=True)
|
os.makedirs("temp/ref", exist_ok=True)
|
||||||
multi_gen = True
|
multi_gen = True
|
||||||
|
|
||||||
|
# fed_LQ is not chunked.
|
||||||
|
utils.save_image(self.fed_LQ.cpu().detach(), os.path.join("temp/lr_precorrupt", "%05i.png" % (step,)))
|
||||||
for i in range(self.mega_batch_factor):
|
for i in range(self.mega_batch_factor):
|
||||||
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
|
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
|
||||||
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
|
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
|
||||||
|
|
|
@ -166,8 +166,8 @@ class FixupResNetV2(FixupResNet):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.inject_noise:
|
if self.inject_noise:
|
||||||
rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device, dtype=x.dtype)
|
rand_feature = torch.randn_like(x)
|
||||||
x = torch.cat([x, rand_feature], dim=1)
|
x = x + rand_feature * .1
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.lrelu(x + self.bias1)
|
x = self.lrelu(x + self.bias1)
|
||||||
x = self.layer1(x)
|
x = self.layer1(x)
|
||||||
|
|
|
@ -14,8 +14,8 @@ import models.archs.ResGen_arch as ResGen_arch
|
||||||
import math
|
import math
|
||||||
|
|
||||||
# Generator
|
# Generator
|
||||||
def define_G(opt):
|
def define_G(opt, net_key='network_G'):
|
||||||
opt_net = opt['network_G']
|
opt_net = opt[net_key]
|
||||||
which_model = opt_net['which_model_G']
|
which_model = opt_net['which_model_G']
|
||||||
scale = opt['scale']
|
scale = opt['scale']
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user