Misc fixes & adjustments

This commit is contained in:
James Betker 2020-09-01 07:58:11 -06:00
parent 0a9b85f239
commit 886d59d5df
3 changed files with 7 additions and 5 deletions

View File

@ -54,8 +54,8 @@ class ExtensibleTrainer(BaseModel):
step = ConfigurableStep(step, self.env)
self.steps.append(step)
# The steps rely on the networks being placed in the env, so put them there. Even though they arent wrapped
# yet.
# step.define_optimizers() relies on the networks being placed in the env, so put them there. Even though
# they aren't wrapped yet.
self.env['generators'] = self.netsG
self.env['discriminators'] = self.netsD

View File

@ -118,6 +118,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
d_real = net(state[self.opt['real']])
d_fake = net(state[self.opt['fake']].detach())
self.metrics.append(("d_fake", torch.mean(d_fake)))
self.metrics.append(("d_real", torch.mean(d_real)))
if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan', 'crossgan_lrref']:
l_real = self.criterion(d_real, True)
@ -129,10 +130,11 @@ class DiscriminatorGanLoss(ConfigurableLoss):
l_total += l_mreal + l_mfake
self.metrics.append(("l_mismatch", l_mfake + l_mreal))
self.metrics.append(("l_fake", l_fake))
self.metrics.append(("l_real", l_real))
return l_total
elif self.opt['gan_type'] == 'ragan':
return (self.cri_gan(d_real - torch.mean(d_fake), True) +
self.cri_gan(d_fake - torch.mean(d_real), False))
return (self.criterion(d_real - torch.mean(d_fake), True) +
self.criterion(d_fake - torch.mean(d_real), False))
else:
raise NotImplementedError

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_spsr_switched2_psnr.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_gan.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)