forked from mrq/DL-Art-School
Fix bugs in extensibletrainer
parent db52bec4ab
commit f9b83176f1
@@ -4,9 +4,10 @@ from models.networks import define_F
 from models.loss import GANLoss
 import random
 import functools
+import torchvision


-def create_generator_loss(opt_loss, env):
+def create_loss(opt_loss, env):
     type = opt_loss['type']
     if type == 'pix':
         return PixLoss(opt_loss, env)
@@ -149,7 +150,6 @@ class GeneratorGanLoss(ConfigurableLoss):
         else:
             raise NotImplementedError

-import torchvision

 class DiscriminatorGanLoss(ConfigurableLoss):
     def __init__(self, opt, env):
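The first file's change renames the loss factory from create_generator_loss to create_loss and hoists the torchvision import to the top of the module. Below is a minimal, standalone sketch of the dispatch pattern the factory uses; the PixLoss body, its L1 criterion, and the 'gen'/'hq' state keys are assumptions for illustration, not the repository's actual implementation.

import torch.nn as nn

class PixLoss(nn.Module):
    # Stand-in for the real PixLoss in this module; the criterion and the
    # 'gen'/'hq' state keys are assumed purely for illustration.
    def __init__(self, opt, env):
        super().__init__()
        self.opt, self.env = opt, env
        self.criterion = nn.L1Loss()

    def forward(self, net, state):
        return self.criterion(state['gen'], state['hq'])

def create_loss(opt_loss, env):
    # Dispatch on the 'type' key, as the renamed factory does.
    type = opt_loss['type']
    if type == 'pix':
        return PixLoss(opt_loss, env)
    raise NotImplementedError(type)

pix_loss = create_loss({'type': 'pix', 'weight': 1.0}, env={})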
@@ -1,7 +1,7 @@
 from utils.loss_accumulator import LossAccumulator
 from torch.nn import Module
 import logging
-from models.steps.losses import create_generator_loss
+from models.steps.losses import create_loss
 import torch
 from apex import amp
 from collections import OrderedDict
@@ -34,7 +34,7 @@ class ConfigurableStep(Module):
         self.weights = {}
         if 'losses' in self.step_opt.keys():
             for loss_name, loss in self.step_opt['losses'].items():
-                losses.append((loss_name, create_generator_loss(loss, env)))
+                losses.append((loss_name, create_loss(loss, env)))
                 self.weights[loss_name] = loss['weight']
         self.losses = OrderedDict(losses)

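For context, the loop above consumes a per-step 'losses' section in which each entry carries the 'type' read by create_loss and the 'weight' read here. A hypothetical step configuration, written as a Python dict rather than the project's options file; only the 'losses' shape is taken from the diff, the other keys and names are illustrative.

from collections import OrderedDict

# Hypothetical step options: only the shape of 'losses' (a 'type' and a
# 'weight' per entry) is taken from the diff; the rest is illustrative.
step_opt = {
    'training': 'generator',
    'losses': {
        'generator_pix': {'type': 'pix', 'weight': 1.0},
    },
}

# Mirrors the constructor loop, reusing the toy create_loss sketched earlier.
losses, weights = [], {}
for loss_name, loss in step_opt['losses'].items():
    losses.append((loss_name, create_loss(loss, env={})))
    weights[loss_name] = loss['weight']
losses = OrderedDict(losses)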
@@ -96,6 +96,12 @@ class ConfigurableStep(Module):
         else:
             return [self.step_opt['training']]

+    def get_training_network_name(self):
+        if isinstance(self.step_opt['training'], list):
+            return self.step_opt['training'][0]
+        else:
+            return self.step_opt['training']
+
     # Performs all forward and backward passes for this step given an input state. All input states are lists of
     # chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
     # steps might use. These tensors are automatically detached and accumulated into chunks.
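The new get_training_network_name helper is what fixes the logging bug the commit message presumably refers to: step_opt['training'] may be either a list of network names or a single name, and indexing a plain string with [0] silently yields its first character. A standalone illustration of the difference (the network names are hypothetical):

def get_training_network_name(training):
    # Standalone version of the new helper, taking the 'training' value directly.
    if isinstance(training, list):
        return training[0]
    else:
        return training

training = 'generator'                                   # hypothetical network name
print(training[0])                                       # 'g'  -> old code would log "g_total"
print(get_training_network_name(training))               # 'generator'
print(get_training_network_name(['generator', 'disc']))  # 'generator'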
@@ -145,7 +151,7 @@ class ConfigurableStep(Module):
                 self.loss_accumulator.add_loss(loss_name, l)
                 for n, v in loss.extra_metrics():
                     self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
-        self.loss_accumulator.add_loss("%s_total" % (self.step_opt['training'][0],), total_loss)
+        self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
         # Scale the loss down by the accumulation factor.
         total_loss = total_loss / self.env['mega_batch_factor']

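The trailing context lines in the last hunk show the gradient-accumulation scaling: backward() runs once per chunk and gradients add up across calls, so dividing each chunk's loss by env['mega_batch_factor'] makes the accumulated gradient an average over the mega-batch rather than a sum. A minimal sketch with made-up numbers:

import torch

mega_batch_factor = 2                                 # number of accumulation chunks
w = torch.tensor([1.0], requires_grad=True)
chunks = [torch.tensor([2.0]), torch.tensor([4.0])]   # made-up per-chunk inputs

for x in chunks:
    loss = (w * x).sum() / mega_batch_factor          # scale down, as in the diff
    loss.backward()                                   # gradients accumulate in w.grad

print(w.grad)  # tensor([3.]) == average of the per-chunk gradients (2 and 4)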