forked from mrq/DL-Art-School
Support old school crossgan in extensibletrainer
This commit is contained in:
parent
b593d8e7c3
commit
cbd5e7a986
|
@ -76,9 +76,11 @@ class GeneratorGanLoss(ConfigurableLoss):
|
||||||
|
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
netD = self.env['discriminators'][self.opt['discriminator']]
|
netD = self.env['discriminators'][self.opt['discriminator']]
|
||||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
|
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan', 'crossgan_lrref']:
|
||||||
if self.opt['gan_type'] == 'crossgan':
|
if self.opt['gan_type'] == 'crossgan':
|
||||||
pred_g_fake = netD(state[self.opt['fake']], state['lq_fullsize_ref'])
|
pred_g_fake = netD(state[self.opt['fake']], state['lq_fullsize_ref'])
|
||||||
|
elif self.opt['gan_type'] == 'crossgan_lrref':
|
||||||
|
pred_g_fake = netD(state[self.opt['fake']], state['lq'])
|
||||||
else:
|
else:
|
||||||
pred_g_fake = netD(state[self.opt['fake']])
|
pred_g_fake = netD(state[self.opt['fake']])
|
||||||
return self.criterion(pred_g_fake, True)
|
return self.criterion(pred_g_fake, True)
|
||||||
|
@ -106,16 +108,22 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
||||||
mismatched_lq = torch.roll(state['lq_fullsize_ref'], shifts=1, dims=0)
|
mismatched_lq = torch.roll(state['lq_fullsize_ref'], shifts=1, dims=0)
|
||||||
d_mismatch_real = net(state[self.opt['real']], mismatched_lq)
|
d_mismatch_real = net(state[self.opt['real']], mismatched_lq)
|
||||||
d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq)
|
d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq)
|
||||||
|
elif self.opt['gan_type'] == 'crossgan_lrref':
|
||||||
|
d_real = net(state[self.opt['real']], state['lq'])
|
||||||
|
d_fake = net(state[self.opt['fake']].detach(), state['lq'])
|
||||||
|
mismatched_lq = torch.roll(state['lq'], shifts=1, dims=0)
|
||||||
|
d_mismatch_real = net(state[self.opt['real']], mismatched_lq)
|
||||||
|
d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq)
|
||||||
else:
|
else:
|
||||||
d_real = net(state[self.opt['real']])
|
d_real = net(state[self.opt['real']])
|
||||||
d_fake = net(state[self.opt['fake']].detach())
|
d_fake = net(state[self.opt['fake']].detach())
|
||||||
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
||||||
|
|
||||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan']:
|
if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan', 'crossgan_lrref']:
|
||||||
l_real = self.criterion(d_real, True)
|
l_real = self.criterion(d_real, True)
|
||||||
l_fake = self.criterion(d_fake, False)
|
l_fake = self.criterion(d_fake, False)
|
||||||
l_total = l_real + l_fake
|
l_total = l_real + l_fake
|
||||||
if self.opt['gan_type'] == 'crossgan':
|
if 'crossgan' in self.opt['gan_type']:
|
||||||
l_mreal = self.criterion(d_mismatch_real, False)
|
l_mreal = self.criterion(d_mismatch_real, False)
|
||||||
l_mfake = self.criterion(d_mismatch_fake, False)
|
l_mfake = self.criterion(d_mismatch_fake, False)
|
||||||
l_total += l_mreal + l_mfake
|
l_total += l_mreal + l_mfake
|
||||||
|
|
Loading…
Reference in New Issue
Block a user