Add in experiments hook

This commit is contained in:
James Betker 2020-09-19 10:05:25 -06:00
parent 4f75cf0f02
commit e2a146abc7
3 changed files with 103 additions and 2 deletions

View File

@ -25,7 +25,7 @@ def _get_paths_from_images(path):
images = [] images = []
for dirpath, _, fnames in sorted(os.walk(path)): for dirpath, _, fnames in sorted(os.walk(path)):
for fname in sorted(fnames): for fname in sorted(fnames):
if is_image_file(fname): if is_image_file(fname) and 'ref.jpg' not in fname:
img_path = os.path.join(dirpath, fname) img_path = os.path.join(dirpath, fname)
images.append(img_path) images.append(img_path)
assert images, '{:s} has no valid image file'.format(path) assert images, '{:s} has no valid image file'.format(path)

View File

@ -10,6 +10,7 @@ import models.lr_scheduler as lr_scheduler
import models.networks as networks import models.networks as networks
from models.base_model import BaseModel from models.base_model import BaseModel
from models.steps.steps import ConfigurableStep from models.steps.steps import ConfigurableStep
from models.experiments.experiments import get_experiment_for_name
import torchvision.utils as utils import torchvision.utils as utils
logger = logging.getLogger('base') logger = logging.getLogger('base')
@ -37,6 +38,7 @@ class ExtensibleTrainer(BaseModel):
self.netsG = {} self.netsG = {}
self.netsD = {} self.netsD = {}
# Note that this is on the chopping block. It should be integrated into an injection point.
self.netF = networks.define_F().to(self.device) # Used to compute feature loss. self.netF = networks.define_F().to(self.device) # Used to compute feature loss.
for name, net in opt['networks'].items(): for name, net in opt['networks'].items():
# Trainable is a required parameter, but the default is simply true. Set it here. # Trainable is a required parameter, but the default is simply true. Set it here.
@ -56,9 +58,11 @@ class ExtensibleTrainer(BaseModel):
new_net.eval() new_net.eval()
# Initialize the train/eval steps # Initialize the train/eval steps
self.step_names = []
self.steps = [] self.steps = []
for step_name, step in opt['steps'].items(): for step_name, step in opt['steps'].items():
step = ConfigurableStep(step, self.env) step = ConfigurableStep(step, self.env)
self.step_names.append(step_name) # This could be an OrderedDict, but it's a PITA to integrate with AMP below.
self.steps.append(step) self.steps.append(step)
# step.define_optimizers() relies on the networks being placed in the env, so put them there. Even though # step.define_optimizers() relies on the networks being placed in the env, so put them there. Even though
@ -91,11 +95,12 @@ class ExtensibleTrainer(BaseModel):
amp_opts = self.optimizers amp_opts = self.optimizers
self.env['amp'] = False self.env['amp'] = False
# Unwrap steps & netF # Unwrap steps & netF & optimizers
self.netF = amp_nets[len(total_nets)] self.netF = amp_nets[len(total_nets)]
assert(len(self.steps) == len(amp_nets[len(total_nets)+1:])) assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
self.steps = amp_nets[len(total_nets)+1:] self.steps = amp_nets[len(total_nets)+1:]
amp_nets = amp_nets[:len(total_nets)] amp_nets = amp_nets[:len(total_nets)]
self.optimizers = amp_opts
# DataParallel # DataParallel
dnets = [] dnets = []
@ -133,6 +138,11 @@ class ExtensibleTrainer(BaseModel):
self.print_network() # print network self.print_network() # print network
self.load() # load G and D if needed self.load() # load G and D if needed
# Load experiments
self.experiments = []
if 'experiments' in opt.keys():
self.experiments = [get_experiment_for_name(e) for e in op['experiments']]
# Setting this to false triggers SRGAN to call the models update_model() function on the first iteration. # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
self.updated = True self.updated = True
@ -185,6 +195,9 @@ class ExtensibleTrainer(BaseModel):
p.requires_grad = False p.requires_grad = False
assert enabled == len(nets_to_train) assert enabled == len(nets_to_train)
# Update experiments
[e.before_step(self.opt, self.step_names[step_num], self.env, nets_to_train, state) for e in self.experiments]
for o in s.get_optimizers(): for o in s.get_optimizers():
o.zero_grad() o.zero_grad()
@ -205,7 +218,9 @@ class ExtensibleTrainer(BaseModel):
state[k] = v state[k] = v
# And finally perform optimization. # And finally perform optimization.
[e.before_optimize(state) for e in self.experiments]
s.do_step() s.do_step()
[e.after_optimize(state) for e in self.experiments]
# Record visual outputs for usage in debugging and testing. # Record visual outputs for usage in debugging and testing.
if 'visuals' in self.opt['logger'].keys(): if 'visuals' in self.opt['logger'].keys():
@ -252,6 +267,9 @@ class ExtensibleTrainer(BaseModel):
for s in self.steps: for s in self.steps:
log.update(s.get_metrics()) log.update(s.get_metrics())
for e in self.experiments:
log.update(e.get_log_data())
# Some generators can do their own metric logging. # Some generators can do their own metric logging.
for net_name, net in self.networks.items(): for net_name, net in self.networks.items():
if hasattr(net.module, "get_debug_values"): if hasattr(net.module, "get_debug_values"):

View File

@ -0,0 +1,83 @@
import torch
def get_experiment_for_name(name):
return Experiment()
# Experiments are ways to add hooks into the ExtensibleTrainer training process with the intent of reporting the
# inner workings of the process in a custom manner that is unsuitable for addition elsewhere.
class Experiment:
def before_step(self, opt, step_name, env, nets_to_train, pre_state):
pass
def before_optimize(self, state):
pass
def after_optimize(self, state):
pass
def get_log_data(self):
pass
class ModelParameterDepthTrackerMetrics(Experiment):
# Subclasses should implement these two methods:
def get_network_and_step_names(self):
# Subclasses should return the network being debugged and the step name it is trained in. return: (net, stepname)
pass
def get_layers_to_debug(self, env, net, state):
# Subclasses should populate self.layers with a list of per-layer nn.Modules here.
pass
def before_step(self, opt, step_name, env, nets_to_train, pre_state):
self.net, step = self.get_network_and_step_names()
self.activate = self.net in nets_to_train and step == step_name and self.step_num % opt['logger']['print_freq'] == 0
if self.activate:
layers = self.get_layers_to_debug(env, env['networks'][self.net], pre_state)
self.params = []
for l in layers:
lparams = []
for k, v in env['networks'][self.net].named_parameters(): # can optimize for a part of the model
if v.requires_grad:
lparams.append(v)
self.params.append(lparams)
def before_optimize(self, state):
self.cached_params = []
for l in self.params:
lparams = []
for p in l:
lparams.append(p.value().cpu())
self.cached_params.append(lparams)
def after_optimize(self, state):
# Compute the abs mean difference across the params.
self.layer_means = []
for l, lc in zip(self.params, self.cached_params):
sum = torch.tensor(0)
for p, pc in zip(l, lc):
sum += torch.abs(pc - p.value().cpu())
self.layer_means.append(sum / len(l))
def get_log_data(self):
return {'%s_layer_update_means_histogram' % (self.net,): self.layer_means}
class DiscriminatorParameterTracker(ModelParameterDepthTrackerMetrics):
def get_network_and_step_names(self):
return "feature_discriminator", "feature_discriminator"
def get_layers_to_debug(self, env, net, state):
return [net.ref_head.conv0_0,
net.ref_head.conv0_1,
net.ref_head.conv1_0,
net.ref_head.conv1_1,
net.ref_head.conv2_0,
net.ref_head.conv2_1,
net.ref_head.conv3_0,
net.ref_head.conv3_1,
net.ref_head.conv4_0,
net.ref_head.conv4_1,
net.linear1,
net.output_linears]