Fix bugs in extensibletrainer

James Betker 2020-09-28 22:09:42 -06:00
parent db52bec4ab
commit f9b83176f1
2 changed files with 11 additions and 5 deletions

View File

@@ -4,9 +4,10 @@ from models.networks import define_F
 from models.loss import GANLoss
 import random
 import functools
+import torchvision
-def create_generator_loss(opt_loss, env):
+def create_loss(opt_loss, env):
     type = opt_loss['type']
     if type == 'pix':
         return PixLoss(opt_loss, env)
@@ -149,7 +150,6 @@ class GeneratorGanLoss(ConfigurableLoss):
         else:
             raise NotImplementedError
-import torchvision
 class DiscriminatorGanLoss(ConfigurableLoss):
     def __init__(self, opt, env):
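Note on the rename in this file: the factory now serves more than generator-side losses (DiscriminatorGanLoss lives in the same file), so the generator-specific name was presumably dropped. A hedged sketch of the dispatch pattern the factory follows; only the 'pix' branch is confirmed by the hunk above, and the 'generator_gan' type string is an assumption:

    # Sketch only -- not code from this commit. Dispatch a config block to a loss class.
    def create_loss(opt_loss, env):
        type = opt_loss['type']
        if type == 'pix':
            return PixLoss(opt_loss, env)
        elif type == 'generator_gan':                  # assumed type string
            return GeneratorGanLoss(opt_loss, env)
        else:
            raise NotImplementedError("Unknown loss type: %s" % type)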

View File

@@ -1,7 +1,7 @@
 from utils.loss_accumulator import LossAccumulator
 from torch.nn import Module
 import logging
-from models.steps.losses import create_generator_loss
+from models.steps.losses import create_loss
 import torch
 from apex import amp
 from collections import OrderedDict
@@ -34,7 +34,7 @@ class ConfigurableStep(Module):
         self.weights = {}
         if 'losses' in self.step_opt.keys():
             for loss_name, loss in self.step_opt['losses'].items():
-                losses.append((loss_name, create_generator_loss(loss, env)))
+                losses.append((loss_name, create_loss(loss, env)))
                 self.weights[loss_name] = loss['weight']
         self.losses = OrderedDict(losses)
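For context, the 'losses' block iterated above maps each loss name to its own config dict, and each config carries the 'weight' stored in self.weights. A hedged example of the shape this loop expects; the names and values below are assumptions, not taken from this commit:

    # Illustrative (assumed) step options fragment.
    step_opt = {
        'losses': {
            'pix': {'type': 'pix', 'weight': 1.0},
            'gan': {'type': 'generator_gan', 'weight': 0.01},   # assumed type string
        }
    }
    # After __init__, self.losses is an OrderedDict of constructed loss objects
    # and self.weights == {'pix': 1.0, 'gan': 0.01}.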
@@ -96,6 +96,12 @@ class ConfigurableStep(Module):
         else:
             return [self.step_opt['training']]
 
+    def get_training_network_name(self):
+        if isinstance(self.step_opt['training'], list):
+            return self.step_opt['training'][0]
+        else:
+            return self.step_opt['training']
+
     # Performs all forward and backward passes for this step given an input state. All input states are lists of
     # chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
     # steps might use. These tensors are automatically detached and accumulated into chunks.
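The new get_training_network_name() exists because the 'training' option may be either a list of network names or a single string. The bug it addresses shows up in the hunk below: indexing with ['training'][0] returns the first list entry as intended, but on a plain string it returns the first character, mislabeling the accumulated total loss. A standalone, runnable mirror of the helper for illustration; the option values are assumptions:

    # Standalone copy of the new helper's logic, for illustration only.
    def get_training_network_name(step_opt):
        if isinstance(step_opt['training'], list):
            return step_opt['training'][0]
        return step_opt['training']

    assert get_training_network_name({'training': ['generator', 'disc']}) == 'generator'
    assert get_training_network_name({'training': 'generator'}) == 'generator'
    # The old expression step_opt['training'][0] yields 'g' in the string case,
    # so the "%s_total" loss would have been logged as "g_total".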
@@ -145,7 +151,7 @@ class ConfigurableStep(Module):
                 self.loss_accumulator.add_loss(loss_name, l)
                 for n, v in loss.extra_metrics():
                     self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
-            self.loss_accumulator.add_loss("%s_total" % (self.step_opt['training'][0],), total_loss)
+            self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
             # Scale the loss down by the accumulation factor.
             total_loss = total_loss / self.env['mega_batch_factor']
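The final division relates to gradient accumulation: the batch is split into env['mega_batch_factor'] chunks, each chunk's loss is scaled down by that factor, and the gradients summed across chunk backward passes then match one backward pass over the full batch (for a mean loss over equal-sized chunks). A small self-contained sketch of that equivalence; the factor and tensor values are assumptions unrelated to this repository:

    import torch

    mega_batch_factor = 4                               # assumed chunk count
    w = torch.ones(1, requires_grad=True)
    chunks = [torch.full((2,), float(i + 1)) for i in range(mega_batch_factor)]

    for chunk in chunks:
        loss = (w * chunk).mean()                       # per-chunk loss
        (loss / mega_batch_factor).backward()           # scaled; gradients accumulate in w.grad

    # w.grad now equals the gradient of (w * full_batch).mean() taken over all
    # eight elements at once, i.e. accumulation reproduces the full-batch step.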