diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index da4ed187..afecd82b 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -4,9 +4,10 @@ from models.networks import define_F from models.loss import GANLoss import random import functools +import torchvision -def create_generator_loss(opt_loss, env): +def create_loss(opt_loss, env): type = opt_loss['type'] if type == 'pix': return PixLoss(opt_loss, env) @@ -149,7 +150,6 @@ class GeneratorGanLoss(ConfigurableLoss): else: raise NotImplementedError -import torchvision class DiscriminatorGanLoss(ConfigurableLoss): def __init__(self, opt, env): diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index 4037fb09..f8e3b326 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -1,7 +1,7 @@ from utils.loss_accumulator import LossAccumulator from torch.nn import Module import logging -from models.steps.losses import create_generator_loss +from models.steps.losses import create_loss import torch from apex import amp from collections import OrderedDict @@ -34,7 +34,7 @@ class ConfigurableStep(Module): self.weights = {} if 'losses' in self.step_opt.keys(): for loss_name, loss in self.step_opt['losses'].items(): - losses.append((loss_name, create_generator_loss(loss, env))) + losses.append((loss_name, create_loss(loss, env))) self.weights[loss_name] = loss['weight'] self.losses = OrderedDict(losses) @@ -96,6 +96,12 @@ class ConfigurableStep(Module): else: return [self.step_opt['training']] + def get_training_network_name(self): + if isinstance(self.step_opt['training'], list): + return self.step_opt['training'][0] + else: + return self.step_opt['training'] + # Performs all forward and backward passes for this step given an input state. All input states are lists of # chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later # steps might use. These tensors are automatically detached and accumulated into chunks. @@ -145,7 +151,7 @@ class ConfigurableStep(Module): self.loss_accumulator.add_loss(loss_name, l) for n, v in loss.extra_metrics(): self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v) - self.loss_accumulator.add_loss("%s_total" % (self.step_opt['training'][0],), total_loss) + self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss) # Scale the loss down by the accumulation factor. total_loss = total_loss / self.env['mega_batch_factor']