forked from mrq/DL-Art-School
Add in experiments hook
parent 4f75cf0f02
commit e2a146abc7
@@ -25,7 +25,7 @@ def _get_paths_from_images(path):
     images = []
     for dirpath, _, fnames in sorted(os.walk(path)):
         for fname in sorted(fnames):
-            if is_image_file(fname):
+            if is_image_file(fname) and 'ref.jpg' not in fname:
                 img_path = os.path.join(dirpath, fname)
                 images.append(img_path)
     assert images, '{:s} has no valid image file'.format(path)
@@ -10,6 +10,7 @@ import models.lr_scheduler as lr_scheduler
 import models.networks as networks
 from models.base_model import BaseModel
 from models.steps.steps import ConfigurableStep
+from models.experiments.experiments import get_experiment_for_name
 import torchvision.utils as utils
 
 logger = logging.getLogger('base')
@@ -37,6 +38,7 @@ class ExtensibleTrainer(BaseModel):
 
         self.netsG = {}
         self.netsD = {}
+        # Note that this is on the chopping block. It should be integrated into an injection point.
         self.netF = networks.define_F().to(self.device)  # Used to compute feature loss.
         for name, net in opt['networks'].items():
             # Trainable is a required parameter, but the default is simply true. Set it here.
@@ -56,9 +58,11 @@ class ExtensibleTrainer(BaseModel):
                 new_net.eval()
 
         # Initialize the train/eval steps
+        self.step_names = []
         self.steps = []
         for step_name, step in opt['steps'].items():
             step = ConfigurableStep(step, self.env)
+            self.step_names.append(step_name)  # This could be an OrderedDict, but it's a PITA to integrate with AMP below.
             self.steps.append(step)
 
         # step.define_optimizers() relies on the networks being placed in the env, so put them there. Even though
@@ -91,11 +95,12 @@ class ExtensibleTrainer(BaseModel):
             amp_opts = self.optimizers
             self.env['amp'] = False
 
-        # Unwrap steps & netF
+        # Unwrap steps & netF & optimizers
         self.netF = amp_nets[len(total_nets)]
         assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
         self.steps = amp_nets[len(total_nets)+1:]
         amp_nets = amp_nets[:len(total_nets)]
+        self.optimizers = amp_opts
 
         # DataParallel
         dnets = []
@@ -133,6 +138,11 @@ class ExtensibleTrainer(BaseModel):
         self.print_network()  # print network
         self.load()  # load G and D if needed
 
+        # Load experiments
+        self.experiments = []
+        if 'experiments' in opt.keys():
+            self.experiments = [get_experiment_for_name(e) for e in opt['experiments']]
+
         # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
         self.updated = True
 
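As a point of reference, the sketch below shows how the new 'experiments' option might be supplied and consumed. It is a hypothetical, self-contained fragment: the experiment name is made up, and get_experiment_for_name is replaced by a local stand-in because in this commit it ignores the name and simply returns the base Experiment stub.

# Hypothetical options fragment; only the 'experiments' key matters for this commit.
opt = {
    'logger': {'print_freq': 100},
    'experiments': ['discriminator_parameter_tracker'],  # hypothetical entry
}

# Local stand-in for models.experiments.experiments.get_experiment_for_name.
def get_experiment_for_name(name):
    return object()

# Mirrors the constructor logic added above.
experiments = []
if 'experiments' in opt.keys():
    experiments = [get_experiment_for_name(e) for e in opt['experiments']]
print(len(experiments))  # 1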
@@ -185,6 +195,9 @@ class ExtensibleTrainer(BaseModel):
                 p.requires_grad = False
         assert enabled == len(nets_to_train)
 
+        # Update experiments
+        [e.before_step(self.opt, self.step_names[step_num], self.env, nets_to_train, state) for e in self.experiments]
+
         for o in s.get_optimizers():
             o.zero_grad()
 
@@ -205,7 +218,9 @@ class ExtensibleTrainer(BaseModel):
                 state[k] = v
 
             # And finally perform optimization.
+            [e.before_optimize(state) for e in self.experiments]
             s.do_step()
+            [e.after_optimize(state) for e in self.experiments]
 
         # Record visual outputs for usage in debugging and testing.
         if 'visuals' in self.opt['logger'].keys():
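To make the ordering of the new hooks explicit, here is a minimal, self-contained sketch of one optimization pass as wired above. RecordingExperiment and FakeStep are stand-ins invented for illustration; the real trainer uses ConfigurableStep and the Experiment class added in this commit.

# A simplified sketch of the hook call order established by the hunks above.
class RecordingExperiment:
    def __init__(self):
        self.calls = []
    def before_step(self, opt, step_name, env, nets_to_train, pre_state):
        self.calls.append('before_step')
    def before_optimize(self, state):
        self.calls.append('before_optimize')
    def after_optimize(self, state):
        self.calls.append('after_optimize')
    def get_log_data(self):
        return {'calls': list(self.calls)}

class FakeStep:
    def do_step(self):
        pass  # stands in for ConfigurableStep.do_step()

experiments = [RecordingExperiment()]
state, s = {}, FakeStep()

# Mirrors the ExtensibleTrainer flow shown in the hunks above.
[e.before_step({}, 'generator', {}, ['generator'], state) for e in experiments]
[e.before_optimize(state) for e in experiments]
s.do_step()
[e.after_optimize(state) for e in experiments]

log = {}
for e in experiments:
    log.update(e.get_log_data())
print(log['calls'])  # ['before_step', 'before_optimize', 'after_optimize']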
@@ -252,6 +267,9 @@ class ExtensibleTrainer(BaseModel):
         for s in self.steps:
             log.update(s.get_metrics())
 
+        for e in self.experiments:
+            log.update(e.get_log_data())
+
         # Some generators can do their own metric logging.
         for net_name, net in self.networks.items():
             if hasattr(net.module, "get_debug_values"):
codes/models/experiments/experiments.py (new file, 83 lines)
@@ -0,0 +1,83 @@
+import torch
+
+def get_experiment_for_name(name):
+    return Experiment()
+
+
+# Experiments are ways to add hooks into the ExtensibleTrainer training process with the intent of reporting the
+# inner workings of the process in a custom manner that is unsuitable for addition elsewhere.
+class Experiment:
+    def before_step(self, opt, step_name, env, nets_to_train, pre_state):
+        pass
+
+    def before_optimize(self, state):
+        pass
+
+    def after_optimize(self, state):
+        pass
+
+    def get_log_data(self):
+        return {}
+
+
+class ModelParameterDepthTrackerMetrics(Experiment):
+    # Subclasses should implement these two methods:
+    def get_network_and_step_names(self):
+        # Subclasses should return the network being debugged and the step name it is trained in. return: (net, stepname)
+        pass
+    def get_layers_to_debug(self, env, net, state):
+        # Subclasses should return a list of per-layer nn.Modules to track here.
+        pass
+
+
+    def before_step(self, opt, step_name, env, nets_to_train, pre_state):
+        self.net, step = self.get_network_and_step_names()
+        self.activate = self.net in nets_to_train and step == step_name and self.step_num % opt['logger']['print_freq'] == 0
+        if self.activate:
+            layers = self.get_layers_to_debug(env, env['networks'][self.net], pre_state)
+            self.params = []
+            for l in layers:
+                lparams = []
+                for k, v in l.named_parameters():  # collect only the parameters belonging to this layer
+                    if v.requires_grad:
+                        lparams.append(v)
+                self.params.append(lparams)
+
+    def before_optimize(self, state):
+        self.cached_params = []
+        for l in self.params:
+            lparams = []
+            for p in l:
+                lparams.append(p.detach().cpu())
+            self.cached_params.append(lparams)
+
+    def after_optimize(self, state):
+        # Compute the mean absolute difference across the params of each layer.
+        self.layer_means = []
+        for l, lc in zip(self.params, self.cached_params):
+            sum = torch.tensor(0.0)
+            for p, pc in zip(l, lc):
+                sum += torch.mean(torch.abs(pc - p.detach().cpu()))
+            self.layer_means.append(sum / len(l))
+
+    def get_log_data(self):
+        return {'%s_layer_update_means_histogram' % (self.net,): self.layer_means}
+
+
+class DiscriminatorParameterTracker(ModelParameterDepthTrackerMetrics):
+    def get_network_and_step_names(self):
+        return "feature_discriminator", "feature_discriminator"
+
+    def get_layers_to_debug(self, env, net, state):
+        return [net.ref_head.conv0_0,
+                net.ref_head.conv0_1,
+                net.ref_head.conv1_0,
+                net.ref_head.conv1_1,
+                net.ref_head.conv2_0,
+                net.ref_head.conv2_1,
+                net.ref_head.conv3_0,
+                net.ref_head.conv3_1,
+                net.ref_head.conv4_0,
+                net.ref_head.conv4_1,
+                net.linear1,
+                net.output_linears]
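For intuition about what ModelParameterDepthTrackerMetrics ends up logging, here is a small self-contained sketch (not part of the commit) that applies the same cache-then-diff idea to a toy layer: snapshot the parameters before the optimizer step, then report the mean absolute change afterwards.

import torch
import torch.nn as nn

# Toy stand-in for one tracked layer; the real code caches every tracked layer's params.
layer = nn.Linear(4, 4)
cached = [p.detach().cpu().clone() for p in layer.parameters()]

# Pretend an optimizer step nudged the weights.
with torch.no_grad():
    for p in layer.parameters():
        p.add_(0.01 * torch.randn_like(p))

# Same metric as after_optimize: mean absolute per-parameter update, averaged over the layer.
deltas = [torch.mean(torch.abs(pc - p.detach().cpu())) for p, pc in zip(layer.parameters(), cached)]
layer_mean = torch.stack(deltas).mean()
print(layer_mean.item())  # roughly 0.01 * E[|N(0,1)|], i.e. about 0.008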
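Finally, a hedged sketch of what a concrete experiment built on this hook could look like. GradNormTracker is not part of the commit; it simply exercises the before_step / after_optimize / get_log_data interface defined above, with a locally defined stand-in base class so the snippet runs on its own. A real implementation would subclass models.experiments.experiments.Experiment and be returned by get_experiment_for_name.

import torch
import torch.nn as nn

# Local stand-in for the Experiment base class added in this commit.
class Experiment:
    def before_step(self, opt, step_name, env, nets_to_train, pre_state): pass
    def before_optimize(self, state): pass
    def after_optimize(self, state): pass
    def get_log_data(self): return {}

class GradNormTracker(Experiment):
    # Hypothetical experiment: records the total gradient norm of one network after each step.
    def __init__(self, net_name):
        self.net_name = net_name
        self.grad_norm = None
        self.net = None

    def before_step(self, opt, step_name, env, nets_to_train, pre_state):
        # env['networks'] is the same dict ExtensibleTrainer places in its env.
        self.net = env['networks'][self.net_name]

    def after_optimize(self, state):
        grads = [p.grad.detach().norm() for p in self.net.parameters() if p.grad is not None]
        self.grad_norm = torch.stack(grads).norm().item() if grads else 0.0

    def get_log_data(self):
        return {'%s_grad_norm' % self.net_name: self.grad_norm}

# Toy usage with a stand-in network and env dict.
net = nn.Linear(4, 2)
loss = net(torch.randn(8, 4)).sum()
loss.backward()
exp = GradNormTracker('toy_net')
exp.before_step({}, 'toy_step', {'networks': {'toy_net': net}}, ['toy_net'], {})
exp.after_optimize({})
print(exp.get_log_data())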