forked from mrq/DL-Art-School
Support old school crossgan in extensibletrainer
This commit is contained in:
parent
b593d8e7c3
commit
cbd5e7a986
|
@ -76,9 +76,11 @@ class GeneratorGanLoss(ConfigurableLoss):
|
|||
|
||||
def forward(self, net, state):
|
||||
netD = self.env['discriminators'][self.opt['discriminator']]
|
||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
|
||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan', 'crossgan_lrref']:
|
||||
if self.opt['gan_type'] == 'crossgan':
|
||||
pred_g_fake = netD(state[self.opt['fake']], state['lq_fullsize_ref'])
|
||||
elif self.opt['gan_type'] == 'crossgan_lrref':
|
||||
pred_g_fake = netD(state[self.opt['fake']], state['lq'])
|
||||
else:
|
||||
pred_g_fake = netD(state[self.opt['fake']])
|
||||
return self.criterion(pred_g_fake, True)
|
||||
|
@ -106,16 +108,22 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
mismatched_lq = torch.roll(state['lq_fullsize_ref'], shifts=1, dims=0)
|
||||
d_mismatch_real = net(state[self.opt['real']], mismatched_lq)
|
||||
d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq)
|
||||
elif self.opt['gan_type'] == 'crossgan_lrref':
|
||||
d_real = net(state[self.opt['real']], state['lq'])
|
||||
d_fake = net(state[self.opt['fake']].detach(), state['lq'])
|
||||
mismatched_lq = torch.roll(state['lq'], shifts=1, dims=0)
|
||||
d_mismatch_real = net(state[self.opt['real']], mismatched_lq)
|
||||
d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq)
|
||||
else:
|
||||
d_real = net(state[self.opt['real']])
|
||||
d_fake = net(state[self.opt['fake']].detach())
|
||||
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
||||
|
||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan']:
|
||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan', 'crossgan_lrref']:
|
||||
l_real = self.criterion(d_real, True)
|
||||
l_fake = self.criterion(d_fake, False)
|
||||
l_total = l_real + l_fake
|
||||
if self.opt['gan_type'] == 'crossgan':
|
||||
if 'crossgan' in self.opt['gan_type']:
|
||||
l_mreal = self.criterion(d_mismatch_real, False)
|
||||
l_mfake = self.criterion(d_mismatch_fake, False)
|
||||
l_total += l_mreal + l_mfake
|
||||
|
|
Loading…
Reference in New Issue
Block a user