Add noise to teco disc

This commit is contained in:
James Betker 2020-10-27 22:48:23 -06:00
parent 4dc16d5889
commit 2ab5054d4c

View File

@ -238,6 +238,7 @@ class TecoGanLoss(ConfigurableLoss):
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0 self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors. self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
self.ff = opt['fast_forward'] if 'fast_forward' in opt.keys() else False self.ff = opt['fast_forward'] if 'fast_forward' in opt.keys() else False
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
def forward(self, _, state): def forward(self, _, state):
if self.ff: if self.ff:
@ -288,6 +289,9 @@ class TecoGanLoss(ConfigurableLoss):
def compute_loss(self, real_sext, fake_sext): def compute_loss(self, real_sext, fake_sext):
fp16 = self.env['opt']['fp16'] fp16 = self.env['opt']['fp16']
net = self.env['discriminators'][self.opt['discriminator']] net = self.env['discriminators'][self.opt['discriminator']]
if self.noise != 0:
real_sext += torch.randn_like(real_sext) * self.noise
fake_sext += torch.randn_like(fake_sext) * self.noise
with autocast(enabled=fp16): with autocast(enabled=fp16):
d_fake = net(fake_sext) d_fake = net(fake_sext)
d_real = net(real_sext) d_real = net(real_sext)