Misc fixes & adjustments
This commit is contained in:
parent
0a9b85f239
commit
886d59d5df
|
@ -54,8 +54,8 @@ class ExtensibleTrainer(BaseModel):
|
|||
step = ConfigurableStep(step, self.env)
|
||||
self.steps.append(step)
|
||||
|
||||
# The steps rely on the networks being placed in the env, so put them there. Even though they arent wrapped
|
||||
# yet.
|
||||
# step.define_optimizers() relies on the networks being placed in the env, so put them there. Even though
|
||||
# they aren't wrapped yet.
|
||||
self.env['generators'] = self.netsG
|
||||
self.env['discriminators'] = self.netsD
|
||||
|
||||
|
|
|
@ -118,6 +118,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
d_real = net(state[self.opt['real']])
|
||||
d_fake = net(state[self.opt['fake']].detach())
|
||||
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
||||
self.metrics.append(("d_real", torch.mean(d_real)))
|
||||
|
||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan', 'crossgan_lrref']:
|
||||
l_real = self.criterion(d_real, True)
|
||||
|
@ -129,10 +130,11 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
l_total += l_mreal + l_mfake
|
||||
self.metrics.append(("l_mismatch", l_mfake + l_mreal))
|
||||
self.metrics.append(("l_fake", l_fake))
|
||||
self.metrics.append(("l_real", l_real))
|
||||
return l_total
|
||||
elif self.opt['gan_type'] == 'ragan':
|
||||
return (self.cri_gan(d_real - torch.mean(d_fake), True) +
|
||||
self.cri_gan(d_fake - torch.mean(d_real), False))
|
||||
return (self.criterion(d_real - torch.mean(d_fake), True) +
|
||||
self.criterion(d_fake - torch.mean(d_real), False))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_spsr_switched2_psnr.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_gan.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user