forked from mrq/DL-Art-School
More ExtensibleTrainer work
It runs now, just need to debug it to reach performance parity with SRGAN. Sweet.
This commit is contained in:
parent
afdd93fbe9
commit
dffc15184d
|
@ -48,6 +48,8 @@ def get_scheduler_for_opt(opt):
|
||||||
return LinearDecayWeightScheduler(opt['initial_weight'], opt['steps'], opt['lower_bound'], opt['start_step'])
|
return LinearDecayWeightScheduler(opt['initial_weight'], opt['steps'], opt['lower_bound'], opt['start_step'])
|
||||||
elif opt['type'] == 'sinusoidal':
|
elif opt['type'] == 'sinusoidal':
|
||||||
return SinusoidalWeightScheduler(opt['upper_weight'], opt['lower_weight'], opt['period'], opt['start_step'])
|
return SinusoidalWeightScheduler(opt['upper_weight'], opt['lower_weight'], opt['period'], opt['start_step'])
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
# Do some testing.
|
# Do some testing.
|
||||||
|
|
|
@ -23,18 +23,16 @@ class ExtensibleTrainer(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.rank = -1 # non dist training
|
self.rank = -1 # non dist training
|
||||||
train_opt = opt['train']
|
train_opt = opt['train']
|
||||||
self.mega_batch_factor = 1
|
|
||||||
|
|
||||||
# env is used as a global state to store things that subcomponents might need.
|
# env is used as a global state to store things that subcomponents might need.
|
||||||
env = {'device': self.device,
|
self.env = {'device': self.device,
|
||||||
'rank': self.rank,
|
'rank': self.rank,
|
||||||
'opt': opt}
|
'opt': opt,
|
||||||
|
'step': 0}
|
||||||
|
|
||||||
self.netsG = {}
|
self.netsG = {}
|
||||||
self.netsD = {}
|
self.netsD = {}
|
||||||
self.netF = networks.define_F().to(self.device) # Used to compute feature loss.
|
self.netF = networks.define_F().to(self.device) # Used to compute feature loss.
|
||||||
self.networks = []
|
|
||||||
self.visuals = {}
|
|
||||||
for name, net in opt['networks'].items():
|
for name, net in opt['networks'].items():
|
||||||
if net['type'] == 'generator':
|
if net['type'] == 'generator':
|
||||||
new_net = networks.define_G(net, None, opt['scale']).to(self.device)
|
new_net = networks.define_G(net, None, opt['scale']).to(self.device)
|
||||||
|
@ -44,18 +42,45 @@ class ExtensibleTrainer(BaseModel):
|
||||||
self.netsD[name] = new_net
|
self.netsD[name] = new_net
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Can only handle generators and discriminators")
|
raise NotImplementedError("Can only handle generators and discriminators")
|
||||||
self.networks.append(new_net)
|
|
||||||
|
# Initialize the train/eval steps
|
||||||
|
self.steps = []
|
||||||
|
for step_name, step in opt['steps'].items():
|
||||||
|
step = ConfigurableStep(step, self.env)
|
||||||
|
self.steps.append(step)
|
||||||
|
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
self.mega_batch_factor = train_opt['mega_batch_factor']
|
self.mega_batch_factor = train_opt['mega_batch_factor']
|
||||||
if self.mega_batch_factor is None:
|
if self.mega_batch_factor is None:
|
||||||
self.mega_batch_factor = 1
|
self.mega_batch_factor = 1
|
||||||
|
self.env['mega_batch_factor'] = self.mega_batch_factor
|
||||||
|
|
||||||
|
# The steps rely on the networks being placed in the env, so put them there. Even though they arent wrapped
|
||||||
|
# yet.
|
||||||
|
self.env['generators'] = self.netsG
|
||||||
|
self.env['discriminators'] = self.netsD
|
||||||
|
|
||||||
|
# Define the optimizers from the steps
|
||||||
|
for s in self.steps:
|
||||||
|
s.define_optimizers()
|
||||||
|
self.optimizers.extend(s.get_optimizers())
|
||||||
|
|
||||||
|
# Find the optimizers that are using the default scheduler, then build them.
|
||||||
|
def_opt = []
|
||||||
|
for s in self.steps:
|
||||||
|
def_opt.extend(s.get_optimizers_with_default_scheduler())
|
||||||
|
self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
|
||||||
|
|
||||||
# Initialize amp.
|
# Initialize amp.
|
||||||
amp_nets, amp_opts = amp.initialize(self.networks, self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
|
total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
|
||||||
# self.networks is stored unwrapped. It should never be used for forward() or backward() passes, instead use
|
amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
|
||||||
# self.netG and self.netD for that.
|
self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
|
||||||
self.networks = amp_nets
|
|
||||||
|
# Unwrap steps & netF
|
||||||
|
self.netF = amp_nets[len(total_nets)]
|
||||||
|
assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
|
||||||
|
self.steps = amp_nets[len(total_nets)+1:]
|
||||||
|
amp_nets = amp_nets[:len(total_nets)]
|
||||||
|
|
||||||
# DataParallel
|
# DataParallel
|
||||||
dnets = []
|
dnets = []
|
||||||
|
@ -71,32 +96,24 @@ class ExtensibleTrainer(BaseModel):
|
||||||
else:
|
else:
|
||||||
dnet.eval()
|
dnet.eval()
|
||||||
dnets.append(dnet)
|
dnets.append(dnet)
|
||||||
|
if not opt['dist']:
|
||||||
|
self.netF = DataParallel(self.netF)
|
||||||
|
|
||||||
# Backpush the wrapped networks into the network dicts..
|
# Backpush the wrapped networks into the network dicts..
|
||||||
|
self.networks = {}
|
||||||
found = 0
|
found = 0
|
||||||
for dnet in dnets:
|
for dnet in dnets:
|
||||||
for net_dict in [self.netsD, self.netsG]:
|
for net_dict in [self.netsD, self.netsG]:
|
||||||
for k, v in net_dict.items():
|
for k, v in net_dict.items():
|
||||||
if v == dnet.module:
|
if v == dnet.module:
|
||||||
net_dict[k] = dnet
|
net_dict[k] = dnet
|
||||||
|
self.networks[k] = dnet
|
||||||
found += 1
|
found += 1
|
||||||
assert found == len(self.networks)
|
assert found == len(self.netsG) + len(self.netsD)
|
||||||
|
|
||||||
env['generators'] = self.netsG
|
# Replace the env networks with the wrapped networks
|
||||||
env['discriminators'] = self.netsD
|
self.env['generators'] = self.netsG
|
||||||
|
self.env['discriminators'] = self.netsD
|
||||||
# Initialize the training steps
|
|
||||||
self.steps = []
|
|
||||||
for step_name, step in opt['steps'].items():
|
|
||||||
step = ConfigurableStep(step, env)
|
|
||||||
self.steps.append(step)
|
|
||||||
self.optimizers.extend(step.get_optimizers())
|
|
||||||
|
|
||||||
# Find the optimizers that are using the default scheduler, then build them.
|
|
||||||
def_opt = []
|
|
||||||
for s in self.steps:
|
|
||||||
def_opt.extend(s.get_optimizers_with_default_scheduler())
|
|
||||||
lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
|
|
||||||
|
|
||||||
self.print_network() # print network
|
self.print_network() # print network
|
||||||
self.load() # load G and D if needed
|
self.load() # load G and D if needed
|
||||||
|
@ -105,30 +122,38 @@ class ExtensibleTrainer(BaseModel):
|
||||||
self.updated = True
|
self.updated = True
|
||||||
|
|
||||||
def feed_data(self, data):
|
def feed_data(self, data):
|
||||||
self.lq = torch.chunk(corrupted_L, chunks=self.mega_batch_factor, dim=0)
|
self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
|
||||||
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
|
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
|
||||||
input_ref = data['ref'] if 'ref' in data else data['GT']
|
input_ref = data['ref'] if 'ref' in data else data['GT']
|
||||||
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
|
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
|
||||||
|
|
||||||
def optimize_parameters(self, step):
|
def optimize_parameters(self, step):
|
||||||
|
self.env['step'] = step
|
||||||
|
|
||||||
# Some models need to make parametric adjustments per-step. Do that here.
|
# Some models need to make parametric adjustments per-step. Do that here.
|
||||||
for net in self.networks.values():
|
for net in self.networks.values():
|
||||||
if hasattr(net, "update_for_step"):
|
if hasattr(net, "update_for_step"):
|
||||||
net.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
|
net.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
|
||||||
|
|
||||||
# Iterate through the steps, performing them one at a time.
|
# Iterate through the steps, performing them one at a time.
|
||||||
self.visuals = {}
|
|
||||||
state = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
|
state = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
|
||||||
for step_num, s in enumerate(self.steps):
|
for step_num, s in enumerate(self.steps):
|
||||||
# Only set requires_grad=True for the network being trained.
|
# Only set requires_grad=True for the network being trained.
|
||||||
nets_to_train = s.get_networks_trained()
|
nets_to_train = s.get_networks_trained()
|
||||||
|
enabled = 0
|
||||||
for name, net in self.networks.items():
|
for name, net in self.networks.items():
|
||||||
net_enabled = name in nets_to_train
|
net_enabled = name in nets_to_train
|
||||||
for p in self.netsG.parameters():
|
if net_enabled:
|
||||||
|
enabled += 1
|
||||||
|
for p in net.parameters():
|
||||||
if p.dtype != torch.int64 and p.dtype != torch.bool:
|
if p.dtype != torch.int64 and p.dtype != torch.bool:
|
||||||
p.requires_grad = net_enabled
|
p.requires_grad = net_enabled
|
||||||
else:
|
else:
|
||||||
p.requires_grad = False
|
p.requires_grad = False
|
||||||
|
assert enabled == len(nets_to_train)
|
||||||
|
|
||||||
|
for o in s.get_optimizers():
|
||||||
|
o.zero_grad()
|
||||||
|
|
||||||
# Now do a forward and backward pass for each gradient accumulation step.
|
# Now do a forward and backward pass for each gradient accumulation step.
|
||||||
new_states = {}
|
new_states = {}
|
||||||
|
@ -136,13 +161,13 @@ class ExtensibleTrainer(BaseModel):
|
||||||
ns = s.do_forward_backward(state, m, step_num)
|
ns = s.do_forward_backward(state, m, step_num)
|
||||||
for k, v in ns.items():
|
for k, v in ns.items():
|
||||||
if k not in new_states.keys():
|
if k not in new_states.keys():
|
||||||
new_states[k] = [v.detach()]
|
new_states[k] = [v]
|
||||||
else:
|
else:
|
||||||
new_states[k].append(v.detach())
|
new_states[k].append(v)
|
||||||
|
|
||||||
# Push the detached new state tensors into the state map for use with the next step.
|
# Push the detached new state tensors into the state map for use with the next step.
|
||||||
for k, v in new_states.items():
|
for k, v in new_states.items():
|
||||||
# Overwriting existing state keys is not supported.
|
# State is immutable to reduce complexity. Overwriting existing state keys is not supported.
|
||||||
assert k not in state.keys()
|
assert k not in state.keys()
|
||||||
state[k] = v
|
state[k] = v
|
||||||
|
|
||||||
|
@ -150,17 +175,14 @@ class ExtensibleTrainer(BaseModel):
|
||||||
s.do_step()
|
s.do_step()
|
||||||
|
|
||||||
# Record visual outputs for usage in debugging and testing.
|
# Record visual outputs for usage in debugging and testing.
|
||||||
if 'visuals' in self.opt['train'].keys():
|
if 'visuals' in self.opt['logger'].keys():
|
||||||
sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
|
sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
|
||||||
for v in self.opt['train']['visuals']:
|
for v in self.opt['logger']['visuals']:
|
||||||
self.visuals[v] = state[v].detach().cpu()
|
if step % self.opt['logger']['visual_debug_rate'] == 0:
|
||||||
if step % self.opt['train']['visual_debug_rate'] == 0:
|
for i, dbgv in enumerate(state[v]):
|
||||||
for i, dbgv in enumerate(self.visuals[v]):
|
|
||||||
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
|
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
|
||||||
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
|
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
|
||||||
|
|
||||||
# TODO: Do logging and image dumps
|
|
||||||
|
|
||||||
def compute_fea_loss(self, real, fake):
|
def compute_fea_loss(self, real, fake):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits_real = self.netF(real)
|
logits_real = self.netF(real)
|
||||||
|
@ -173,12 +195,11 @@ class ExtensibleTrainer(BaseModel):
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Iterate through the steps, performing them one at a time.
|
# Iterate through the steps, performing them one at a time.
|
||||||
self.visuals = {}
|
|
||||||
state = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
|
state = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
|
||||||
for step_num, s in enumerate(self.steps):
|
for step_num, s in enumerate(self.steps):
|
||||||
ns = s.do_forward_backward(state, 0, step_num, backward=False)
|
ns = s.do_forward_backward(state, 0, step_num, backward=False)
|
||||||
for k, v in ns.items():
|
for k, v in ns.items():
|
||||||
state[k] = [v.detach()]
|
state[k] = [v]
|
||||||
|
|
||||||
self.eval_state = state
|
self.eval_state = state
|
||||||
|
|
||||||
|
@ -192,7 +213,7 @@ class ExtensibleTrainer(BaseModel):
|
||||||
log.update(s.get_metrics())
|
log.update(s.get_metrics())
|
||||||
|
|
||||||
# Some generators can do their own metric logging.
|
# Some generators can do their own metric logging.
|
||||||
for net in self.networks:
|
for net in self.networks.values():
|
||||||
if hasattr(net.module, "get_debug_values"):
|
if hasattr(net.module, "get_debug_values"):
|
||||||
log.update(net.module.get_debug_values(step))
|
log.update(net.module.get_debug_values(step))
|
||||||
return log
|
return log
|
||||||
|
@ -204,17 +225,17 @@ class ExtensibleTrainer(BaseModel):
|
||||||
'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
|
'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
|
||||||
|
|
||||||
def print_network(self):
|
def print_network(self):
|
||||||
for net in self.networks:
|
for name, net in self.networks.items():
|
||||||
s, n = self.get_network_description(net)
|
s, n = self.get_network_description(net)
|
||||||
net_struc_str = '{}'.format(net.__class__.__name__)
|
net_struc_str = '{}'.format(net.__class__.__name__)
|
||||||
if self.rank <= 0:
|
if self.rank <= 0:
|
||||||
logger.info('Network structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
|
logger.info('Network {} structure: {}, with parameters: {:,d}'.format(name, net_struc_str, n))
|
||||||
logger.info(s)
|
logger.info(s)
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
for netdict in [self.netsG, self.netsD]:
|
for netdict in [self.netsG, self.netsD]:
|
||||||
for name, net in netdict.items():
|
for name, net in netdict.items():
|
||||||
load_path = self.opt['path'][name]
|
load_path = self.opt['path']['pretrain_model_%s' % (name,)]
|
||||||
if load_path is not None:
|
if load_path is not None:
|
||||||
logger.info('Loading model for [%s]' % (load_path))
|
logger.info('Loading model for [%s]' % (load_path))
|
||||||
self.load_network(load_path, net)
|
self.load_network(load_path, net)
|
||||||
|
@ -222,3 +243,7 @@ class ExtensibleTrainer(BaseModel):
|
||||||
def save(self, iter_step):
|
def save(self, iter_step):
|
||||||
for name, net in self.networks.items():
|
for name, net in self.networks.items():
|
||||||
self.save_network(net, name, iter_step)
|
self.save_network(net, name, iter_step)
|
||||||
|
|
||||||
|
def force_restore_swapout(self):
|
||||||
|
# Legacy method. Do nothing.
|
||||||
|
pass
|
|
@ -385,7 +385,7 @@ class SRGANModel(BaseModel):
|
||||||
print("Misc setup %f" % (time() - _t,))
|
print("Misc setup %f" % (time() - _t,))
|
||||||
_t = time()
|
_t = time()
|
||||||
|
|
||||||
if step >= self.D_init_iters:
|
if step >= self.init_iters:
|
||||||
self.optimizer_G.zero_grad()
|
self.optimizer_G.zero_grad()
|
||||||
self.fake_GenOut = []
|
self.fake_GenOut = []
|
||||||
self.fea_GenOut = []
|
self.fea_GenOut = []
|
||||||
|
|
|
@ -223,10 +223,10 @@ def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None
|
||||||
load_net_clean[k] = v
|
load_net_clean[k] = v
|
||||||
netF.load_state_dict(load_net_clean)
|
netF.load_state_dict(load_net_clean)
|
||||||
|
|
||||||
|
if not for_training:
|
||||||
# Put into eval mode, freeze the parameters and set the 'weight' field.
|
# Put into eval mode, freeze the parameters and set the 'weight' field.
|
||||||
netF.eval()
|
netF.eval()
|
||||||
for k, v in netF.named_parameters():
|
for k, v in netF.named_parameters():
|
||||||
v.requires_grad = False
|
v.requires_grad = False
|
||||||
netF.fdisc_weight = opt['weight']
|
|
||||||
|
|
||||||
return netF
|
return netF
|
||||||
|
|
|
@ -1,10 +1,15 @@
|
||||||
import torch.nn
|
import torch.nn
|
||||||
from models.archs.SPSR_arch import ImageGradientNoPadding
|
from models.archs.SPSR_arch import ImageGradientNoPadding
|
||||||
|
from data.weight_scheduler import get_scheduler_for_opt
|
||||||
|
|
||||||
# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
|
# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
|
||||||
def create_injector(opt_inject, env):
|
def create_injector(opt_inject, env):
|
||||||
type = opt_inject['type']
|
type = opt_inject['type']
|
||||||
if type == 'img_grad':
|
if type == 'generator':
|
||||||
|
return ImageGeneratorInjector(opt_inject, env)
|
||||||
|
elif type == 'scheduled_scalar':
|
||||||
|
return ScheduledScalarInjector(opt_inject, env)
|
||||||
|
elif type == 'img_grad':
|
||||||
return ImageGradientInjector(opt_inject, env)
|
return ImageGradientInjector(opt_inject, env)
|
||||||
elif type == 'add_noise':
|
elif type == 'add_noise':
|
||||||
return AddNoiseInjector(opt_inject, env)
|
return AddNoiseInjector(opt_inject, env)
|
||||||
|
@ -19,7 +24,8 @@ class Injector(torch.nn.Module):
|
||||||
super(Injector, self).__init__()
|
super(Injector, self).__init__()
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
self.env = env
|
self.env = env
|
||||||
self.input = opt['in']
|
if 'in' in opt.keys():
|
||||||
|
self.input = opt['in']
|
||||||
self.output = opt['out']
|
self.output = opt['out']
|
||||||
|
|
||||||
# This should return a dict of new state variables.
|
# This should return a dict of new state variables.
|
||||||
|
@ -27,23 +33,59 @@ class Injector(torch.nn.Module):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
# Uses a generator to synthesize an image from [in] and injects the results into [out]
|
||||||
|
# Note that results are *not* detached.
|
||||||
|
class ImageGeneratorInjector(Injector):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super(ImageGeneratorInjector, self).__init__(opt, env)
|
||||||
|
|
||||||
|
def forward(self, state):
|
||||||
|
gen = self.env['generators'][self.opt['generator']]
|
||||||
|
results = gen(state[self.input])
|
||||||
|
new_state = {}
|
||||||
|
if isinstance(self.output, list):
|
||||||
|
for i, k in enumerate(self.output):
|
||||||
|
new_state[k] = results[i]
|
||||||
|
else:
|
||||||
|
new_state[self.output] = results
|
||||||
|
|
||||||
|
return new_state
|
||||||
|
|
||||||
|
|
||||||
# Creates an image gradient from [in] and injects it into [out]
|
# Creates an image gradient from [in] and injects it into [out]
|
||||||
class ImageGradientInjector(Injector):
|
class ImageGradientInjector(Injector):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
super(ImageGradientInjector, self).__init__(opt, env)
|
super(ImageGradientInjector, self).__init__(opt, env)
|
||||||
self.img_grad_fn = ImageGradientNoPadding()
|
self.img_grad_fn = ImageGradientNoPadding().to(env['device'])
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
return {self.opt['out']: self.img_grad_fn(state[self.opt['in']])}
|
return {self.opt['out']: self.img_grad_fn(state[self.opt['in']])}
|
||||||
|
|
||||||
|
|
||||||
|
# Injects a scalar that is modulated with a specified schedule. Useful for increasing or decreasing the influence
|
||||||
|
# of something over time.
|
||||||
|
class ScheduledScalarInjector(Injector):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super(ScheduledScalarInjector, self).__init__(opt, env)
|
||||||
|
self.scheduler = get_scheduler_for_opt(opt['scheduler'])
|
||||||
|
|
||||||
|
def forward(self, state):
|
||||||
|
return {self.opt['out']: self.scheduler.get_weight_for_step(self.env['step'])}
|
||||||
|
|
||||||
|
|
||||||
# Adds gaussian noise to [in], scales it to [0,[scale]] and injects into [out]
|
# Adds gaussian noise to [in], scales it to [0,[scale]] and injects into [out]
|
||||||
class AddNoiseInjector(Injector):
|
class AddNoiseInjector(Injector):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
super(AddNoiseInjector, self).__init__(opt, env)
|
super(AddNoiseInjector, self).__init__(opt, env)
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
noise = torch.randn_like(state[self.opt['in']]) * self.opt['scale']
|
# Scale can be a fixed float, or a state key (e.g. from ScheduledScalarInjector).
|
||||||
|
if isinstance(self.opt['scale'], str):
|
||||||
|
scale = state[self.opt['scale']]
|
||||||
|
else:
|
||||||
|
scale = self.opt['scale']
|
||||||
|
|
||||||
|
noise = torch.randn_like(state[self.opt['in']], device=self.env['device']) * scale
|
||||||
return {self.opt['out']: state[self.opt['in']] + noise}
|
return {self.opt['out']: state[self.opt['in']] + noise}
|
||||||
|
|
||||||
|
|
||||||
|
@ -56,4 +98,4 @@ class GreyInjector(Injector):
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
mean = torch.mean(state[self.opt['in']], dim=1, keepdim=True)
|
mean = torch.mean(state[self.opt['in']], dim=1, keepdim=True)
|
||||||
mean = torch.repeat(mean, (-1, 3, -1, -1))
|
mean = torch.repeat(mean, (-1, 3, -1, -1))
|
||||||
return {self.opt['out']: mean}
|
return {self.opt['out']: mean}
|
||||||
|
|
|
@ -2,6 +2,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from models.networks import define_F
|
from models.networks import define_F
|
||||||
from models.loss import GANLoss
|
from models.loss import GANLoss
|
||||||
|
from torchvision.utils import save_image
|
||||||
|
|
||||||
|
|
||||||
def create_generator_loss(opt_loss, env):
|
def create_generator_loss(opt_loss, env):
|
||||||
|
@ -23,10 +24,14 @@ class ConfigurableLoss(nn.Module):
|
||||||
super(ConfigurableLoss, self).__init__()
|
super(ConfigurableLoss, self).__init__()
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
self.env = env
|
self.env = env
|
||||||
|
self.metrics = []
|
||||||
|
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def extra_metrics(self):
|
||||||
|
return self.metrics
|
||||||
|
|
||||||
|
|
||||||
def get_basic_criterion_for_name(name, device):
|
def get_basic_criterion_for_name(name, device):
|
||||||
if name == 'l1':
|
if name == 'l1':
|
||||||
|
@ -53,6 +58,8 @@ class FeatureLoss(ConfigurableLoss):
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
||||||
self.netF = define_F(which_model=opt['which_model_F']).to(self.env['device'])
|
self.netF = define_F(which_model=opt['which_model_F']).to(self.env['device'])
|
||||||
|
if not env['opt']['dist']:
|
||||||
|
self.netF = torch.nn.parallel.DataParallel(self.netF)
|
||||||
|
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -66,18 +73,18 @@ class GeneratorGanLoss(ConfigurableLoss):
|
||||||
super(GeneratorGanLoss, self).__init__(opt, env)
|
super(GeneratorGanLoss, self).__init__(opt, env)
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
||||||
self.netD = env['discriminators'][opt['discriminator']]
|
|
||||||
|
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
|
netD = self.env['discriminators'][self.opt['discriminator']]
|
||||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
|
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
|
||||||
if self.opt['gan_type'] == 'crossgan':
|
if self.opt['gan_type'] == 'crossgan':
|
||||||
pred_g_fake = self.netD(state[self.opt['fake']], state['lq'])
|
pred_g_fake = netD(state[self.opt['fake']], state['lq'])
|
||||||
else:
|
else:
|
||||||
pred_g_fake = self.netD(state[self.opt['fake']])
|
pred_g_fake = netD(state[self.opt['fake']])
|
||||||
return self.criterion(pred_g_fake, True)
|
return self.criterion(pred_g_fake, True)
|
||||||
elif self.opt['gan_type'] == 'ragan':
|
elif self.opt['gan_type'] == 'ragan':
|
||||||
pred_d_real = self.netD(state[self.opt['real']]).detach()
|
pred_d_real = netD(state[self.opt['real']]).detach()
|
||||||
pred_g_fake = self.netD(state[self.opt['fake']])
|
pred_g_fake = netD(state[self.opt['fake']])
|
||||||
return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
|
return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
|
||||||
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
||||||
else:
|
else:
|
||||||
|
@ -91,16 +98,33 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
||||||
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
||||||
|
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
|
self.metrics = []
|
||||||
|
|
||||||
|
if self.opt['gan_type'] == 'crossgan':
|
||||||
|
d_real = net(state[self.opt['real']], state['lq'])
|
||||||
|
d_fake = net(state[self.opt['fake']].detach(), state['lq'])
|
||||||
|
mismatched_lq = torch.roll(state['lq'], shifts=1, dims=0)
|
||||||
|
d_mismatch_real = net(state[self.opt['real']], mismatched_lq)
|
||||||
|
d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq)
|
||||||
|
else:
|
||||||
|
d_real = net(state[self.opt['real']])
|
||||||
|
d_fake = net(state[self.opt['fake']].detach())
|
||||||
|
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
||||||
|
|
||||||
|
if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan']:
|
||||||
|
l_real = self.criterion(d_real, True)
|
||||||
|
l_fake = self.criterion(d_fake, False)
|
||||||
|
l_total = l_real + l_fake
|
||||||
if self.opt['gan_type'] == 'crossgan':
|
if self.opt['gan_type'] == 'crossgan':
|
||||||
pred_g_fake = net(state[self.opt['fake']].detach(), state['lq'])
|
l_mreal = self.criterion(d_mismatch_real, False)
|
||||||
else:
|
l_mfake = self.criterion(d_mismatch_fake, False)
|
||||||
pred_g_fake = net(state[self.opt['fake']].detach())
|
l_total += l_mreal + l_mfake
|
||||||
return self.criterion(pred_g_fake, False)
|
self.metrics.append(("l_mismatch", l_mfake + l_mreal))
|
||||||
|
self.metrics.append(("l_fake", l_fake))
|
||||||
|
return l_total
|
||||||
elif self.opt['gan_type'] == 'ragan':
|
elif self.opt['gan_type'] == 'ragan':
|
||||||
pred_d_real = self.netD(state[self.opt['real']])
|
return (self.cri_gan(d_real - torch.mean(d_fake), True) +
|
||||||
pred_g_fake = self.netD(state[self.opt['fake']].detach())
|
self.cri_gan(d_fake - torch.mean(d_real), False))
|
||||||
return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), True) +
|
|
||||||
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), False)) / 2
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
|
@ -19,11 +19,9 @@ class ConfigurableStep(Module):
|
||||||
self.step_opt = opt_step
|
self.step_opt = opt_step
|
||||||
self.env = env
|
self.env = env
|
||||||
self.opt = env['opt']
|
self.opt = env['opt']
|
||||||
self.gen = env['generators'][opt_step['generator']]
|
|
||||||
self.discs = env['discriminators']
|
|
||||||
self.gen_outputs = opt_step['generator_outputs']
|
self.gen_outputs = opt_step['generator_outputs']
|
||||||
self.training_net = env['generators'][opt_step['training']] if opt_step['training'] in env['generators'].keys() else env['discriminators'][opt_step['training']]
|
|
||||||
self.loss_accumulator = LossAccumulator()
|
self.loss_accumulator = LossAccumulator()
|
||||||
|
self.optimizers = None
|
||||||
|
|
||||||
self.injectors = []
|
self.injectors = []
|
||||||
if 'injectors' in self.step_opt.keys():
|
if 'injectors' in self.step_opt.keys():
|
||||||
|
@ -37,12 +35,13 @@ class ConfigurableStep(Module):
|
||||||
self.weights[loss_name] = loss['weight']
|
self.weights[loss_name] = loss['weight']
|
||||||
self.losses = OrderedDict(losses)
|
self.losses = OrderedDict(losses)
|
||||||
|
|
||||||
# Intentionally abstract so subclasses can have alternative optimizers.
|
|
||||||
self.define_optimizers()
|
|
||||||
|
|
||||||
# Subclasses should override this to define individual optimizers. They should all go into self.optimizers.
|
# Subclasses should override this to define individual optimizers. They should all go into self.optimizers.
|
||||||
# This default implementation defines a single optimizer for all Generator parameters.
|
# This default implementation defines a single optimizer for all Generator parameters.
|
||||||
|
# Must be called after networks are initialized and wrapped.
|
||||||
def define_optimizers(self):
|
def define_optimizers(self):
|
||||||
|
self.training_net = self.env['generators'][self.step_opt['training']] \
|
||||||
|
if self.step_opt['training'] in self.env['generators'].keys() \
|
||||||
|
else self.env['discriminators'][self.step_opt['training']]
|
||||||
optim_params = []
|
optim_params = []
|
||||||
for k, v in self.training_net.named_parameters(): # can optimize for a part of the model
|
for k, v in self.training_net.named_parameters(): # can optimize for a part of the model
|
||||||
if v.requires_grad:
|
if v.requires_grad:
|
||||||
|
@ -73,12 +72,7 @@ class ConfigurableStep(Module):
|
||||||
# chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
|
# chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
|
||||||
# steps might use. These tensors are automatically detached and accumulated into chunks.
|
# steps might use. These tensors are automatically detached and accumulated into chunks.
|
||||||
def do_forward_backward(self, state, grad_accum_step, amp_loss_id, backward=True):
|
def do_forward_backward(self, state, grad_accum_step, amp_loss_id, backward=True):
|
||||||
# First, do a forward pass with the generator.
|
|
||||||
results = self.gen(state[self.step_opt['generator_input']][grad_accum_step])
|
|
||||||
# Extract the resultants into a "new_state" dict per the configuration.
|
|
||||||
new_state = {}
|
new_state = {}
|
||||||
for i, gen_out in enumerate(self.gen_outputs):
|
|
||||||
new_state[gen_out] = results[i]
|
|
||||||
|
|
||||||
# Prepare a de-chunked state dict which will be used for the injectors & losses.
|
# Prepare a de-chunked state dict which will be used for the injectors & losses.
|
||||||
local_state = {}
|
local_state = {}
|
||||||
|
@ -97,17 +91,26 @@ class ConfigurableStep(Module):
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
for loss_name, loss in self.losses.items():
|
for loss_name, loss in self.losses.items():
|
||||||
l = loss(self.training_net, local_state)
|
l = loss(self.training_net, local_state)
|
||||||
self.loss_accumulator.add_loss(loss_name, l)
|
|
||||||
total_loss += l * self.weights[loss_name]
|
total_loss += l * self.weights[loss_name]
|
||||||
self.loss_accumulator.add_loss("total", total_loss)
|
# Record metrics.
|
||||||
|
self.loss_accumulator.add_loss(loss_name, l)
|
||||||
|
for n, v in loss.extra_metrics():
|
||||||
|
self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
|
||||||
|
self.loss_accumulator.add_loss("%s_total" % (self.step_opt['training'],), total_loss)
|
||||||
|
# Scale the loss down by the accumulation factor.
|
||||||
|
total_loss = total_loss / self.env['mega_batch_factor']
|
||||||
|
|
||||||
# Get dem grads!
|
# Get dem grads!
|
||||||
with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
|
with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
|
||||||
scaled_loss.backward()
|
scaled_loss.backward()
|
||||||
|
|
||||||
|
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
|
||||||
|
# we must release the gradients.
|
||||||
|
for k, v in new_state.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
new_state[k] = v.detach()
|
||||||
return new_state
|
return new_state
|
||||||
|
|
||||||
|
|
||||||
# Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
|
# Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
|
||||||
# all self.optimizers.
|
# all self.optimizers.
|
||||||
def do_step(self):
|
def do_step(self):
|
||||||
|
|
|
@ -112,14 +112,21 @@ def check_resume(opt, resume_iter):
|
||||||
'pretrain_model_D', None) is not None:
|
'pretrain_model_D', None) is not None:
|
||||||
logger.warning('pretrain_model path will be ignored when resuming training.')
|
logger.warning('pretrain_model path will be ignored when resuming training.')
|
||||||
|
|
||||||
opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
|
if opt['model'] == 'extensibletrainer':
|
||||||
'{}_G.pth'.format(resume_iter))
|
for k in opt['networks'].keys():
|
||||||
logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
|
pt_key = 'pretrain_model_%s' % (k,)
|
||||||
if 'gan' in opt['model'] or 'spsr' in opt['model']:
|
opt['path'][pt_key] = osp.join(opt['path']['models'],
|
||||||
opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
|
'{}_{}.pth'.format(resume_iter, k))
|
||||||
'{}_D.pth'.format(resume_iter))
|
logger.info('Set model [%s] to %s' % (k, opt['path'][pt_key]))
|
||||||
logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
|
else:
|
||||||
if 'spsr' in opt['model']:
|
opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
|
||||||
opt['path']['pretrain_model_D_grad'] = osp.join(opt['path']['models'],
|
'{}_G.pth'.format(resume_iter))
|
||||||
'{}_D_grad.pth'.format(resume_iter))
|
logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
|
||||||
logger.info('Set [pretrain_model_D_grad] to ' + opt['path']['pretrain_model_D_grad'])
|
if 'gan' in opt['model'] or 'spsr' in opt['model']:
|
||||||
|
opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
|
||||||
|
'{}_D.pth'.format(resume_iter))
|
||||||
|
logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
|
||||||
|
if 'spsr' in opt['model']:
|
||||||
|
opt['path']['pretrain_model_D_grad'] = osp.join(opt['path']['models'],
|
||||||
|
'{}_D_grad.pth'.format(resume_iter))
|
||||||
|
logger.info('Set [pretrain_model_D_grad] to ' + opt['path']['pretrain_model_D_grad'])
|
||||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_feature_net.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_spsr_switched2.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
|
|
@ -161,7 +161,7 @@ def main():
|
||||||
current_step = resume_state['iter']
|
current_step = resume_state['iter']
|
||||||
model.resume_training(resume_state) # handle optimizers and schedulers
|
model.resume_training(resume_state) # handle optimizers and schedulers
|
||||||
else:
|
else:
|
||||||
current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
|
current_step = 0 if 'start_step' not in opt.keys() else opt['start_step']
|
||||||
start_epoch = 0
|
start_epoch = 0
|
||||||
|
|
||||||
#### training
|
#### training
|
||||||
|
@ -215,7 +215,7 @@ def main():
|
||||||
logger.info(message)
|
logger.info(message)
|
||||||
#### validation
|
#### validation
|
||||||
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
|
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
|
||||||
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan'] and rank <= 0: # image restoration validation
|
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation
|
||||||
model.force_restore_swapout()
|
model.force_restore_swapout()
|
||||||
val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
|
val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
|
||||||
# does not support multi-GPU validation
|
# does not support multi-GPU validation
|
||||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
||||||
|
|
||||||
# Utility class that stores detached, named losses in a rotating buffer for smooth metric outputting.
|
# Utility class that stores detached, named losses in a rotating buffer for smooth metric outputting.
|
||||||
class LossAccumulator:
|
class LossAccumulator:
|
||||||
def __init__(self, buffer_sz=10):
|
def __init__(self, buffer_sz=50):
|
||||||
self.buffer_sz = buffer_sz
|
self.buffer_sz = buffer_sz
|
||||||
self.buffers = {}
|
self.buffers = {}
|
||||||
|
|
||||||
|
@ -15,6 +15,6 @@ class LossAccumulator:
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
result = {}
|
result = {}
|
||||||
for k, v in self.buffers:
|
for k, v in self.buffers.items():
|
||||||
result["loss_" + k] = torch.mean(v)
|
result["loss_" + k] = torch.mean(v[1])
|
||||||
return result
|
return result
|
Loading…
Reference in New Issue
Block a user