Misc fixes & adjustments
This commit is contained in:
parent
0a9b85f239
commit
886d59d5df
|
@ -54,8 +54,8 @@ class ExtensibleTrainer(BaseModel):
|
||||||
step = ConfigurableStep(step, self.env)
|
step = ConfigurableStep(step, self.env)
|
||||||
self.steps.append(step)
|
self.steps.append(step)
|
||||||
|
|
||||||
# The steps rely on the networks being placed in the env, so put them there. Even though they arent wrapped
|
# step.define_optimizers() relies on the networks being placed in the env, so put them there. Even though
|
||||||
# yet.
|
# they aren't wrapped yet.
|
||||||
self.env['generators'] = self.netsG
|
self.env['generators'] = self.netsG
|
||||||
self.env['discriminators'] = self.netsD
|
self.env['discriminators'] = self.netsD
|
||||||
|
|
||||||
|
|
|
@ -118,6 +118,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
||||||
d_real = net(state[self.opt['real']])
|
d_real = net(state[self.opt['real']])
|
||||||
d_fake = net(state[self.opt['fake']].detach())
|
d_fake = net(state[self.opt['fake']].detach())
|
||||||
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
||||||
|
self.metrics.append(("d_real", torch.mean(d_real)))
|
||||||
|
|
||||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan', 'crossgan_lrref']:
|
if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan', 'crossgan_lrref']:
|
||||||
l_real = self.criterion(d_real, True)
|
l_real = self.criterion(d_real, True)
|
||||||
|
@ -129,10 +130,11 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
||||||
l_total += l_mreal + l_mfake
|
l_total += l_mreal + l_mfake
|
||||||
self.metrics.append(("l_mismatch", l_mfake + l_mreal))
|
self.metrics.append(("l_mismatch", l_mfake + l_mreal))
|
||||||
self.metrics.append(("l_fake", l_fake))
|
self.metrics.append(("l_fake", l_fake))
|
||||||
|
self.metrics.append(("l_real", l_real))
|
||||||
return l_total
|
return l_total
|
||||||
elif self.opt['gan_type'] == 'ragan':
|
elif self.opt['gan_type'] == 'ragan':
|
||||||
return (self.cri_gan(d_real - torch.mean(d_fake), True) +
|
return (self.criterion(d_real - torch.mean(d_fake), True) +
|
||||||
self.cri_gan(d_fake - torch.mean(d_real), False))
|
self.criterion(d_fake - torch.mean(d_real), False))
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_spsr_switched2_psnr.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_gan.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user