Enable testing in ExtensibleTrainer, fix it in SRGAN_model

Also compute fea loss for this.
This commit is contained in:
James Betker 2020-08-31 09:41:48 -06:00
parent b2091cb698
commit 4b4d08bdec
7 changed files with 86 additions and 68 deletions

View File

@ -30,6 +30,11 @@ class ExtensibleTrainer(BaseModel):
'opt': opt,
'step': 0}
self.mega_batch_factor = 1
if self.is_train:
self.mega_batch_factor = train_opt['mega_batch_factor']
self.env['mega_batch_factor'] = self.mega_batch_factor
self.netsG = {}
self.netsD = {}
self.netF = networks.define_F().to(self.device) # Used to compute feature loss.
@ -49,71 +54,68 @@ class ExtensibleTrainer(BaseModel):
step = ConfigurableStep(step, self.env)
self.steps.append(step)
# The steps rely on the networks being placed in the env, so put them there. Even though they arent wrapped
# yet.
self.env['generators'] = self.netsG
self.env['discriminators'] = self.netsD
# Define the optimizers from the steps
for s in self.steps:
s.define_optimizers()
self.optimizers.extend(s.get_optimizers())
if self.is_train:
self.mega_batch_factor = train_opt['mega_batch_factor']
if self.mega_batch_factor is None:
self.mega_batch_factor = 1
self.env['mega_batch_factor'] = self.mega_batch_factor
# The steps rely on the networks being placed in the env, so put them there. Even though they arent wrapped
# yet.
self.env['generators'] = self.netsG
self.env['discriminators'] = self.netsD
# Define the optimizers from the steps
for s in self.steps:
s.define_optimizers()
self.optimizers.extend(s.get_optimizers())
# Find the optimizers that are using the default scheduler, then build them.
def_opt = []
for s in self.steps:
def_opt.extend(s.get_optimizers_with_default_scheduler())
self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
else:
self.schedulers = []
# Initialize amp.
total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
# Initialize amp.
total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
# Unwrap steps & netF
self.netF = amp_nets[len(total_nets)]
assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
self.steps = amp_nets[len(total_nets)+1:]
amp_nets = amp_nets[:len(total_nets)]
# Unwrap steps & netF
self.netF = amp_nets[len(total_nets)]
assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
self.steps = amp_nets[len(total_nets)+1:]
amp_nets = amp_nets[:len(total_nets)]
# DataParallel
dnets = []
for anet in amp_nets:
if opt['dist']:
dnet = DistributedDataParallel(anet,
device_ids=[torch.cuda.current_device()],
find_unused_parameters=True)
else:
dnet = DataParallel(anet)
if self.is_train:
dnet.train()
else:
dnet.eval()
dnets.append(dnet)
if not opt['dist']:
self.netF = DataParallel(self.netF)
# DataParallel
dnets = []
for anet in amp_nets:
if opt['dist']:
dnet = DistributedDataParallel(anet,
device_ids=[torch.cuda.current_device()],
find_unused_parameters=True)
else:
dnet = DataParallel(anet)
if self.is_train:
dnet.train()
else:
dnet.eval()
dnets.append(dnet)
if not opt['dist']:
self.netF = DataParallel(self.netF)
# Backpush the wrapped networks into the network dicts..
self.networks = {}
found = 0
for dnet in dnets:
for net_dict in [self.netsD, self.netsG]:
for k, v in net_dict.items():
if v == dnet.module:
net_dict[k] = dnet
self.networks[k] = dnet
found += 1
assert found == len(self.netsG) + len(self.netsD)
# Backpush the wrapped networks into the network dicts..
self.networks = {}
found = 0
for dnet in dnets:
for net_dict in [self.netsD, self.netsG]:
for k, v in net_dict.items():
if v == dnet.module:
net_dict[k] = dnet
self.networks[k] = dnet
found += 1
assert found == len(self.netsG) + len(self.netsD)
# Replace the env networks with the wrapped networks
self.env['generators'] = self.netsG
self.env['discriminators'] = self.netsD
# Replace the env networks with the wrapped networks
self.env['generators'] = self.netsG
self.env['discriminators'] = self.netsD
self.print_network() # print network
self.load() # load G and D if needed
@ -121,7 +123,12 @@ class ExtensibleTrainer(BaseModel):
# Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
self.updated = True
def feed_data(self, data):
def feed_data(self, data, need_GT=False):
self.eval_state = {}
for o in self.optimizers:
o.zero_grad()
torch.cuda.empty_cache()
self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
input_ref = data['ref'] if 'ref' in data else data['GT']
@ -206,7 +213,9 @@ class ExtensibleTrainer(BaseModel):
for k, v in ns.items():
state[k] = [v]
self.eval_state = state
self.eval_state = {}
for k, v in state.items():
self.eval_state[k] = [s.detach().cpu() if isinstance(s, torch.Tensor) else s for s in v]
for net in self.netsG.values():
net.train()

View File

@ -352,6 +352,9 @@ class SRGANModel(BaseModel):
self.gan_lq_img_use_prob = train_opt['gan_lowres_use_probability'] if train_opt['gan_lowres_use_probability'] else 0
self.img_debug_steps = opt['logger']['img_debug_steps'] if 'img_debug_steps' in opt['logger'].keys() else 50
else:
self.netF = networks.define_F(use_bn=False).to(self.device)
self.cri_fea = nn.L1Loss().to(self.device)
#self.print_network() # print network
self.load() # load G and D if needed

View File

@ -181,7 +181,7 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'],
final_temperature_step=opt_net['final_temperature_step'])
elif which_model == "cross_compare_vgg128":
netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'], nf=opt_net['nf'], scale=opt_net['scale'])
netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'] if 'ref_channels' in opt_net.keys() else 3, nf=opt_net['nf'], scale=opt_net['scale'])
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD

View File

@ -30,9 +30,10 @@ class ConfigurableStep(Module):
losses = []
self.weights = {}
for loss_name, loss in self.step_opt['losses'].items():
losses.append((loss_name, create_generator_loss(loss, env)))
self.weights[loss_name] = loss['weight']
if 'losses' in self.step_opt.keys():
for loss_name, loss in self.step_opt['losses'].items():
losses.append((loss_name, create_generator_loss(loss, env)))
self.weights[loss_name] = loss['weight']
self.losses = OrderedDict(losses)
# Subclasses should override this to define individual optimizers. They should all go into self.optimizers.

View File

@ -61,10 +61,8 @@ def forward_pass(model, output_dir, alteration_suffix=''):
model.feed_data(data, need_GT=need_GT)
model.test()
if isinstance(model.fake_GenOut[0], tuple):
visuals = model.fake_GenOut[0][0].detach().float().cpu()
else:
visuals = model.fake_GenOut[0].detach().float().cpu()
visuals = model.get_current_visuals()['rlt'].cpu()
fea_loss = 0
for i in range(visuals.shape[0]):
img_path = data['GT_path'][i] if need_GT else data['LQ_path'][i]
img_name = osp.splitext(osp.basename(img_path))[0]
@ -78,7 +76,10 @@ def forward_pass(model, output_dir, alteration_suffix=''):
else:
save_img_path = osp.join(output_dir, img_name + '.png')
fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i])
util.save_img(sr_img, save_img_path)
return fea_loss
if __name__ == "__main__":
@ -87,7 +88,7 @@ if __name__ == "__main__":
want_just_images = True
srg_analyze = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='../options/analyze_srg.yml')
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
@ -108,6 +109,7 @@ if __name__ == "__main__":
test_loaders.append(test_loader)
model = create_model(opt)
fea_loss = 0
for test_loader in test_loaders:
test_set_name = test_loader.dataset.opt['name']
logger.info('\nTesting [{:s}]...'.format(test_set_name))
@ -143,4 +145,7 @@ if __name__ == "__main__":
model_copy.load_state_dict(orig_model.state_dict())
model.netG = model_copy
else:
forward_pass(model, dataset_dir, opt['name'])
fea_loss += forward_pass(model, dataset_dir, opt['name'])
# log
logger.info('# Validation # Fea: {:.4e}'.format(fea_loss / len(test_loader)))

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_xlbatch_ragan.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/srgan_compute_feature.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_fullimgref.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_spsr_switched2_fullimgref.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)