Enable testing in ExtensibleTrainer and fix it in SRGAN_model
Also compute the feature (fea) loss during testing and log it as a validation metric.
This commit is contained in:
parent
b2091cb698
commit
4b4d08bdec
|
@ -30,6 +30,11 @@ class ExtensibleTrainer(BaseModel):
|
||||||
'opt': opt,
|
'opt': opt,
|
||||||
'step': 0}
|
'step': 0}
|
||||||
|
|
||||||
|
self.mega_batch_factor = 1
|
||||||
|
if self.is_train:
|
||||||
|
self.mega_batch_factor = train_opt['mega_batch_factor']
|
||||||
|
self.env['mega_batch_factor'] = self.mega_batch_factor
|
||||||
|
|
||||||
self.netsG = {}
|
self.netsG = {}
|
||||||
self.netsD = {}
|
self.netsD = {}
|
||||||
self.netF = networks.define_F().to(self.device) # Used to compute feature loss.
|
self.netF = networks.define_F().to(self.device) # Used to compute feature loss.
|
||||||
|
@ -49,71 +54,68 @@ class ExtensibleTrainer(BaseModel):
|
||||||
step = ConfigurableStep(step, self.env)
|
step = ConfigurableStep(step, self.env)
|
||||||
self.steps.append(step)
|
self.steps.append(step)
|
||||||
|
|
||||||
|
# The steps rely on the networks being placed in the env, so put them there. Even though they arent wrapped
|
||||||
|
# yet.
|
||||||
|
self.env['generators'] = self.netsG
|
||||||
|
self.env['discriminators'] = self.netsD
|
||||||
|
|
||||||
|
# Define the optimizers from the steps
|
||||||
|
for s in self.steps:
|
||||||
|
s.define_optimizers()
|
||||||
|
self.optimizers.extend(s.get_optimizers())
|
||||||
|
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
self.mega_batch_factor = train_opt['mega_batch_factor']
|
|
||||||
if self.mega_batch_factor is None:
|
|
||||||
self.mega_batch_factor = 1
|
|
||||||
self.env['mega_batch_factor'] = self.mega_batch_factor
|
|
||||||
|
|
||||||
# The steps rely on the networks being placed in the env, so put them there. Even though they arent wrapped
|
|
||||||
# yet.
|
|
||||||
self.env['generators'] = self.netsG
|
|
||||||
self.env['discriminators'] = self.netsD
|
|
||||||
|
|
||||||
# Define the optimizers from the steps
|
|
||||||
for s in self.steps:
|
|
||||||
s.define_optimizers()
|
|
||||||
self.optimizers.extend(s.get_optimizers())
|
|
||||||
|
|
||||||
# Find the optimizers that are using the default scheduler, then build them.
|
# Find the optimizers that are using the default scheduler, then build them.
|
||||||
def_opt = []
|
def_opt = []
|
||||||
for s in self.steps:
|
for s in self.steps:
|
||||||
def_opt.extend(s.get_optimizers_with_default_scheduler())
|
def_opt.extend(s.get_optimizers_with_default_scheduler())
|
||||||
self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
|
self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
|
||||||
|
else:
|
||||||
|
self.schedulers = []
|
||||||
|
|
||||||
# Initialize amp.
|
# Initialize amp.
|
||||||
total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
|
total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
|
||||||
amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
|
amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
|
||||||
self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
|
self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
|
||||||
|
|
||||||
# Unwrap steps & netF
|
# Unwrap steps & netF
|
||||||
self.netF = amp_nets[len(total_nets)]
|
self.netF = amp_nets[len(total_nets)]
|
||||||
assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
|
assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
|
||||||
self.steps = amp_nets[len(total_nets)+1:]
|
self.steps = amp_nets[len(total_nets)+1:]
|
||||||
amp_nets = amp_nets[:len(total_nets)]
|
amp_nets = amp_nets[:len(total_nets)]
|
||||||
|
|
||||||
# DataParallel
|
# DataParallel
|
||||||
dnets = []
|
dnets = []
|
||||||
for anet in amp_nets:
|
for anet in amp_nets:
|
||||||
if opt['dist']:
|
if opt['dist']:
|
||||||
dnet = DistributedDataParallel(anet,
|
dnet = DistributedDataParallel(anet,
|
||||||
device_ids=[torch.cuda.current_device()],
|
device_ids=[torch.cuda.current_device()],
|
||||||
find_unused_parameters=True)
|
find_unused_parameters=True)
|
||||||
else:
|
else:
|
||||||
dnet = DataParallel(anet)
|
dnet = DataParallel(anet)
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
dnet.train()
|
dnet.train()
|
||||||
else:
|
else:
|
||||||
dnet.eval()
|
dnet.eval()
|
||||||
dnets.append(dnet)
|
dnets.append(dnet)
|
||||||
if not opt['dist']:
|
if not opt['dist']:
|
||||||
self.netF = DataParallel(self.netF)
|
self.netF = DataParallel(self.netF)
|
||||||
|
|
||||||
# Backpush the wrapped networks into the network dicts..
|
# Backpush the wrapped networks into the network dicts..
|
||||||
self.networks = {}
|
self.networks = {}
|
||||||
found = 0
|
found = 0
|
||||||
for dnet in dnets:
|
for dnet in dnets:
|
||||||
for net_dict in [self.netsD, self.netsG]:
|
for net_dict in [self.netsD, self.netsG]:
|
||||||
for k, v in net_dict.items():
|
for k, v in net_dict.items():
|
||||||
if v == dnet.module:
|
if v == dnet.module:
|
||||||
net_dict[k] = dnet
|
net_dict[k] = dnet
|
||||||
self.networks[k] = dnet
|
self.networks[k] = dnet
|
||||||
found += 1
|
found += 1
|
||||||
assert found == len(self.netsG) + len(self.netsD)
|
assert found == len(self.netsG) + len(self.netsD)
|
||||||
|
|
||||||
# Replace the env networks with the wrapped networks
|
# Replace the env networks with the wrapped networks
|
||||||
self.env['generators'] = self.netsG
|
self.env['generators'] = self.netsG
|
||||||
self.env['discriminators'] = self.netsD
|
self.env['discriminators'] = self.netsD
|
||||||
|
|
||||||
self.print_network() # print network
|
self.print_network() # print network
|
||||||
self.load() # load G and D if needed
|
self.load() # load G and D if needed
|
||||||
|
@ -121,7 +123,12 @@ class ExtensibleTrainer(BaseModel):
|
||||||
# Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
|
# Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
|
||||||
self.updated = True
|
self.updated = True
|
||||||
|
|
||||||
def feed_data(self, data):
|
def feed_data(self, data, need_GT=False):
|
||||||
|
self.eval_state = {}
|
||||||
|
for o in self.optimizers:
|
||||||
|
o.zero_grad()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
|
self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
|
||||||
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
|
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
|
||||||
input_ref = data['ref'] if 'ref' in data else data['GT']
|
input_ref = data['ref'] if 'ref' in data else data['GT']
|
||||||
|
@ -206,7 +213,9 @@ class ExtensibleTrainer(BaseModel):
|
||||||
for k, v in ns.items():
|
for k, v in ns.items():
|
||||||
state[k] = [v]
|
state[k] = [v]
|
||||||
|
|
||||||
self.eval_state = state
|
self.eval_state = {}
|
||||||
|
for k, v in state.items():
|
||||||
|
self.eval_state[k] = [s.detach().cpu() if isinstance(s, torch.Tensor) else s for s in v]
|
||||||
|
|
||||||
for net in self.netsG.values():
|
for net in self.netsG.values():
|
||||||
net.train()
|
net.train()
|
||||||
|
|
|
@ -352,6 +352,9 @@ class SRGANModel(BaseModel):
|
||||||
self.gan_lq_img_use_prob = train_opt['gan_lowres_use_probability'] if train_opt['gan_lowres_use_probability'] else 0
|
self.gan_lq_img_use_prob = train_opt['gan_lowres_use_probability'] if train_opt['gan_lowres_use_probability'] else 0
|
||||||
|
|
||||||
self.img_debug_steps = opt['logger']['img_debug_steps'] if 'img_debug_steps' in opt['logger'].keys() else 50
|
self.img_debug_steps = opt['logger']['img_debug_steps'] if 'img_debug_steps' in opt['logger'].keys() else 50
|
||||||
|
else:
|
||||||
|
self.netF = networks.define_F(use_bn=False).to(self.device)
|
||||||
|
self.cri_fea = nn.L1Loss().to(self.device)
|
||||||
|
|
||||||
#self.print_network() # print network
|
#self.print_network() # print network
|
||||||
self.load() # load G and D if needed
|
self.load() # load G and D if needed
|
||||||
|
|
|
@ -181,7 +181,7 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
|
||||||
netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'],
|
netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'],
|
||||||
final_temperature_step=opt_net['final_temperature_step'])
|
final_temperature_step=opt_net['final_temperature_step'])
|
||||||
elif which_model == "cross_compare_vgg128":
|
elif which_model == "cross_compare_vgg128":
|
||||||
netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'], nf=opt_net['nf'], scale=opt_net['scale'])
|
netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'] if 'ref_channels' in opt_net.keys() else 3, nf=opt_net['nf'], scale=opt_net['scale'])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||||
return netD
|
return netD
|
||||||
|
|
|
@ -30,9 +30,10 @@ class ConfigurableStep(Module):
|
||||||
|
|
||||||
losses = []
|
losses = []
|
||||||
self.weights = {}
|
self.weights = {}
|
||||||
for loss_name, loss in self.step_opt['losses'].items():
|
if 'losses' in self.step_opt.keys():
|
||||||
losses.append((loss_name, create_generator_loss(loss, env)))
|
for loss_name, loss in self.step_opt['losses'].items():
|
||||||
self.weights[loss_name] = loss['weight']
|
losses.append((loss_name, create_generator_loss(loss, env)))
|
||||||
|
self.weights[loss_name] = loss['weight']
|
||||||
self.losses = OrderedDict(losses)
|
self.losses = OrderedDict(losses)
|
||||||
|
|
||||||
# Subclasses should override this to define individual optimizers. They should all go into self.optimizers.
|
# Subclasses should override this to define individual optimizers. They should all go into self.optimizers.
|
||||||
|
|
|
@ -61,10 +61,8 @@ def forward_pass(model, output_dir, alteration_suffix=''):
|
||||||
model.feed_data(data, need_GT=need_GT)
|
model.feed_data(data, need_GT=need_GT)
|
||||||
model.test()
|
model.test()
|
||||||
|
|
||||||
if isinstance(model.fake_GenOut[0], tuple):
|
visuals = model.get_current_visuals()['rlt'].cpu()
|
||||||
visuals = model.fake_GenOut[0][0].detach().float().cpu()
|
fea_loss = 0
|
||||||
else:
|
|
||||||
visuals = model.fake_GenOut[0].detach().float().cpu()
|
|
||||||
for i in range(visuals.shape[0]):
|
for i in range(visuals.shape[0]):
|
||||||
img_path = data['GT_path'][i] if need_GT else data['LQ_path'][i]
|
img_path = data['GT_path'][i] if need_GT else data['LQ_path'][i]
|
||||||
img_name = osp.splitext(osp.basename(img_path))[0]
|
img_name = osp.splitext(osp.basename(img_path))[0]
|
||||||
|
@ -78,7 +76,10 @@ def forward_pass(model, output_dir, alteration_suffix=''):
|
||||||
else:
|
else:
|
||||||
save_img_path = osp.join(output_dir, img_name + '.png')
|
save_img_path = osp.join(output_dir, img_name + '.png')
|
||||||
|
|
||||||
|
fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i])
|
||||||
|
|
||||||
util.save_img(sr_img, save_img_path)
|
util.save_img(sr_img, save_img_path)
|
||||||
|
return fea_loss
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -87,7 +88,7 @@ if __name__ == "__main__":
|
||||||
want_just_images = True
|
want_just_images = True
|
||||||
srg_analyze = False
|
srg_analyze = False
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='../options/analyze_srg.yml')
|
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml')
|
||||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||||
opt = option.dict_to_nonedict(opt)
|
opt = option.dict_to_nonedict(opt)
|
||||||
|
|
||||||
|
@ -108,6 +109,7 @@ if __name__ == "__main__":
|
||||||
test_loaders.append(test_loader)
|
test_loaders.append(test_loader)
|
||||||
|
|
||||||
model = create_model(opt)
|
model = create_model(opt)
|
||||||
|
fea_loss = 0
|
||||||
for test_loader in test_loaders:
|
for test_loader in test_loaders:
|
||||||
test_set_name = test_loader.dataset.opt['name']
|
test_set_name = test_loader.dataset.opt['name']
|
||||||
logger.info('\nTesting [{:s}]...'.format(test_set_name))
|
logger.info('\nTesting [{:s}]...'.format(test_set_name))
|
||||||
|
@ -143,4 +145,7 @@ if __name__ == "__main__":
|
||||||
model_copy.load_state_dict(orig_model.state_dict())
|
model_copy.load_state_dict(orig_model.state_dict())
|
||||||
model.netG = model_copy
|
model.netG = model_copy
|
||||||
else:
|
else:
|
||||||
forward_pass(model, dataset_dir, opt['name'])
|
fea_loss += forward_pass(model, dataset_dir, opt['name'])
|
||||||
|
|
||||||
|
# log
|
||||||
|
logger.info('# Validation # Fea: {:.4e}'.format(fea_loss / len(test_loader)))
|
||||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_xlbatch_ragan.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/srgan_compute_feature.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_fullimgref.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_spsr_switched2_fullimgref.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user