Support old school crossgan in extensibletrainer

This commit is contained in:
James Betker 2020-08-26 17:52:35 -06:00
parent b593d8e7c3
commit cbd5e7a986

View File

@ -76,9 +76,11 @@ class GeneratorGanLoss(ConfigurableLoss):
def forward(self, net, state): def forward(self, net, state):
netD = self.env['discriminators'][self.opt['discriminator']] netD = self.env['discriminators'][self.opt['discriminator']]
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']: if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan', 'crossgan_lrref']:
if self.opt['gan_type'] == 'crossgan': if self.opt['gan_type'] == 'crossgan':
pred_g_fake = netD(state[self.opt['fake']], state['lq_fullsize_ref']) pred_g_fake = netD(state[self.opt['fake']], state['lq_fullsize_ref'])
elif self.opt['gan_type'] == 'crossgan_lrref':
pred_g_fake = netD(state[self.opt['fake']], state['lq'])
else: else:
pred_g_fake = netD(state[self.opt['fake']]) pred_g_fake = netD(state[self.opt['fake']])
return self.criterion(pred_g_fake, True) return self.criterion(pred_g_fake, True)
@ -106,16 +108,22 @@ class DiscriminatorGanLoss(ConfigurableLoss):
mismatched_lq = torch.roll(state['lq_fullsize_ref'], shifts=1, dims=0) mismatched_lq = torch.roll(state['lq_fullsize_ref'], shifts=1, dims=0)
d_mismatch_real = net(state[self.opt['real']], mismatched_lq) d_mismatch_real = net(state[self.opt['real']], mismatched_lq)
d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq) d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq)
elif self.opt['gan_type'] == 'crossgan_lrref':
d_real = net(state[self.opt['real']], state['lq'])
d_fake = net(state[self.opt['fake']].detach(), state['lq'])
mismatched_lq = torch.roll(state['lq'], shifts=1, dims=0)
d_mismatch_real = net(state[self.opt['real']], mismatched_lq)
d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq)
else: else:
d_real = net(state[self.opt['real']]) d_real = net(state[self.opt['real']])
d_fake = net(state[self.opt['fake']].detach()) d_fake = net(state[self.opt['fake']].detach())
self.metrics.append(("d_fake", torch.mean(d_fake))) self.metrics.append(("d_fake", torch.mean(d_fake)))
if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan']: if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan', 'crossgan_lrref']:
l_real = self.criterion(d_real, True) l_real = self.criterion(d_real, True)
l_fake = self.criterion(d_fake, False) l_fake = self.criterion(d_fake, False)
l_total = l_real + l_fake l_total = l_real + l_fake
if self.opt['gan_type'] == 'crossgan': if 'crossgan' in self.opt['gan_type']:
l_mreal = self.criterion(d_mismatch_real, False) l_mreal = self.criterion(d_mismatch_real, False)
l_mfake = self.criterion(d_mismatch_fake, False) l_mfake = self.criterion(d_mismatch_fake, False)
l_total += l_mreal + l_mfake l_total += l_mreal + l_mfake