From cbd5e7a9866ff54fc84152c4d36569b3f8b9f1fe Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 26 Aug 2020 17:52:35 -0600 Subject: [PATCH] Support old school crossgan in extensibletrainer --- codes/models/steps/losses.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 7b161a4d..b99d5f73 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -76,9 +76,11 @@ class GeneratorGanLoss(ConfigurableLoss): def forward(self, net, state): netD = self.env['discriminators'][self.opt['discriminator']] - if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']: + if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan', 'crossgan_lrref']: if self.opt['gan_type'] == 'crossgan': pred_g_fake = netD(state[self.opt['fake']], state['lq_fullsize_ref']) + elif self.opt['gan_type'] == 'crossgan_lrref': + pred_g_fake = netD(state[self.opt['fake']], state['lq']) else: pred_g_fake = netD(state[self.opt['fake']]) return self.criterion(pred_g_fake, True) @@ -106,16 +108,22 @@ class DiscriminatorGanLoss(ConfigurableLoss): mismatched_lq = torch.roll(state['lq_fullsize_ref'], shifts=1, dims=0) d_mismatch_real = net(state[self.opt['real']], mismatched_lq) d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq) + elif self.opt['gan_type'] == 'crossgan_lrref': + d_real = net(state[self.opt['real']], state['lq']) + d_fake = net(state[self.opt['fake']].detach(), state['lq']) + mismatched_lq = torch.roll(state['lq'], shifts=1, dims=0) + d_mismatch_real = net(state[self.opt['real']], mismatched_lq) + d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq) else: d_real = net(state[self.opt['real']]) d_fake = net(state[self.opt['fake']].detach()) self.metrics.append(("d_fake", torch.mean(d_fake))) - if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan']: + if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan', 'crossgan_lrref']: l_real = self.criterion(d_real, True) l_fake = self.criterion(d_fake, False) l_total = l_real + l_fake - if self.opt['gan_type'] == 'crossgan': + if 'crossgan' in self.opt['gan_type']: l_mreal = self.criterion(d_mismatch_real, False) l_mfake = self.criterion(d_mismatch_fake, False) l_total += l_mreal + l_mfake