Fix bugs in extensibletrainer
This commit is contained in:
parent
db52bec4ab
commit
f9b83176f1
|
@ -4,9 +4,10 @@ from models.networks import define_F
|
|||
from models.loss import GANLoss
|
||||
import random
|
||||
import functools
|
||||
import torchvision
|
||||
|
||||
|
||||
def create_generator_loss(opt_loss, env):
|
||||
def create_loss(opt_loss, env):
|
||||
type = opt_loss['type']
|
||||
if type == 'pix':
|
||||
return PixLoss(opt_loss, env)
|
||||
|
@ -149,7 +150,6 @@ class GeneratorGanLoss(ConfigurableLoss):
|
|||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
import torchvision
|
||||
|
||||
class DiscriminatorGanLoss(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from utils.loss_accumulator import LossAccumulator
|
||||
from torch.nn import Module
|
||||
import logging
|
||||
from models.steps.losses import create_generator_loss
|
||||
from models.steps.losses import create_loss
|
||||
import torch
|
||||
from apex import amp
|
||||
from collections import OrderedDict
|
||||
|
@ -34,7 +34,7 @@ class ConfigurableStep(Module):
|
|||
self.weights = {}
|
||||
if 'losses' in self.step_opt.keys():
|
||||
for loss_name, loss in self.step_opt['losses'].items():
|
||||
losses.append((loss_name, create_generator_loss(loss, env)))
|
||||
losses.append((loss_name, create_loss(loss, env)))
|
||||
self.weights[loss_name] = loss['weight']
|
||||
self.losses = OrderedDict(losses)
|
||||
|
||||
|
@ -96,6 +96,12 @@ class ConfigurableStep(Module):
|
|||
else:
|
||||
return [self.step_opt['training']]
|
||||
|
||||
def get_training_network_name(self):
|
||||
if isinstance(self.step_opt['training'], list):
|
||||
return self.step_opt['training'][0]
|
||||
else:
|
||||
return self.step_opt['training']
|
||||
|
||||
# Performs all forward and backward passes for this step given an input state. All input states are lists of
|
||||
# chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
|
||||
# steps might use. These tensors are automatically detached and accumulated into chunks.
|
||||
|
@ -145,7 +151,7 @@ class ConfigurableStep(Module):
|
|||
self.loss_accumulator.add_loss(loss_name, l)
|
||||
for n, v in loss.extra_metrics():
|
||||
self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
|
||||
self.loss_accumulator.add_loss("%s_total" % (self.step_opt['training'][0],), total_loss)
|
||||
self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
|
||||
# Scale the loss down by the accumulation factor.
|
||||
total_loss = total_loss / self.env['mega_batch_factor']
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user