forked from mrq/DL-Art-School
Add in experiments hook
This commit is contained in:
parent
4f75cf0f02
commit
e2a146abc7
|
@ -25,7 +25,7 @@ def _get_paths_from_images(path):
|
||||||
images = []
|
images = []
|
||||||
for dirpath, _, fnames in sorted(os.walk(path)):
|
for dirpath, _, fnames in sorted(os.walk(path)):
|
||||||
for fname in sorted(fnames):
|
for fname in sorted(fnames):
|
||||||
if is_image_file(fname):
|
if is_image_file(fname) and 'ref.jpg' not in fname:
|
||||||
img_path = os.path.join(dirpath, fname)
|
img_path = os.path.join(dirpath, fname)
|
||||||
images.append(img_path)
|
images.append(img_path)
|
||||||
assert images, '{:s} has no valid image file'.format(path)
|
assert images, '{:s} has no valid image file'.format(path)
|
||||||
|
|
|
@ -10,6 +10,7 @@ import models.lr_scheduler as lr_scheduler
|
||||||
import models.networks as networks
|
import models.networks as networks
|
||||||
from models.base_model import BaseModel
|
from models.base_model import BaseModel
|
||||||
from models.steps.steps import ConfigurableStep
|
from models.steps.steps import ConfigurableStep
|
||||||
|
from models.experiments.experiments import get_experiment_for_name
|
||||||
import torchvision.utils as utils
|
import torchvision.utils as utils
|
||||||
|
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
||||||
|
@ -37,6 +38,7 @@ class ExtensibleTrainer(BaseModel):
|
||||||
|
|
||||||
self.netsG = {}
|
self.netsG = {}
|
||||||
self.netsD = {}
|
self.netsD = {}
|
||||||
|
# Note that this is on the chopping block. It should be integrated into an injection point.
|
||||||
self.netF = networks.define_F().to(self.device) # Used to compute feature loss.
|
self.netF = networks.define_F().to(self.device) # Used to compute feature loss.
|
||||||
for name, net in opt['networks'].items():
|
for name, net in opt['networks'].items():
|
||||||
# Trainable is a required parameter, but the default is simply true. Set it here.
|
# Trainable is a required parameter, but the default is simply true. Set it here.
|
||||||
|
@ -56,9 +58,11 @@ class ExtensibleTrainer(BaseModel):
|
||||||
new_net.eval()
|
new_net.eval()
|
||||||
|
|
||||||
# Initialize the train/eval steps
|
# Initialize the train/eval steps
|
||||||
|
self.step_names = []
|
||||||
self.steps = []
|
self.steps = []
|
||||||
for step_name, step in opt['steps'].items():
|
for step_name, step in opt['steps'].items():
|
||||||
step = ConfigurableStep(step, self.env)
|
step = ConfigurableStep(step, self.env)
|
||||||
|
self.step_names.append(step_name) # This could be an OrderedDict, but it's a PITA to integrate with AMP below.
|
||||||
self.steps.append(step)
|
self.steps.append(step)
|
||||||
|
|
||||||
# step.define_optimizers() relies on the networks being placed in the env, so put them there. Even though
|
# step.define_optimizers() relies on the networks being placed in the env, so put them there. Even though
|
||||||
|
@ -91,11 +95,12 @@ class ExtensibleTrainer(BaseModel):
|
||||||
amp_opts = self.optimizers
|
amp_opts = self.optimizers
|
||||||
self.env['amp'] = False
|
self.env['amp'] = False
|
||||||
|
|
||||||
# Unwrap steps & netF
|
# Unwrap steps & netF & optimizers
|
||||||
self.netF = amp_nets[len(total_nets)]
|
self.netF = amp_nets[len(total_nets)]
|
||||||
assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
|
assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
|
||||||
self.steps = amp_nets[len(total_nets)+1:]
|
self.steps = amp_nets[len(total_nets)+1:]
|
||||||
amp_nets = amp_nets[:len(total_nets)]
|
amp_nets = amp_nets[:len(total_nets)]
|
||||||
|
self.optimizers = amp_opts
|
||||||
|
|
||||||
# DataParallel
|
# DataParallel
|
||||||
dnets = []
|
dnets = []
|
||||||
|
@ -133,6 +138,11 @@ class ExtensibleTrainer(BaseModel):
|
||||||
self.print_network() # print network
|
self.print_network() # print network
|
||||||
self.load() # load G and D if needed
|
self.load() # load G and D if needed
|
||||||
|
|
||||||
|
# Load experiments
|
||||||
|
self.experiments = []
|
||||||
|
if 'experiments' in opt.keys():
|
||||||
|
self.experiments = [get_experiment_for_name(e) for e in op['experiments']]
|
||||||
|
|
||||||
# Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
|
# Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
|
||||||
self.updated = True
|
self.updated = True
|
||||||
|
|
||||||
|
@ -185,6 +195,9 @@ class ExtensibleTrainer(BaseModel):
|
||||||
p.requires_grad = False
|
p.requires_grad = False
|
||||||
assert enabled == len(nets_to_train)
|
assert enabled == len(nets_to_train)
|
||||||
|
|
||||||
|
# Update experiments
|
||||||
|
[e.before_step(self.opt, self.step_names[step_num], self.env, nets_to_train, state) for e in self.experiments]
|
||||||
|
|
||||||
for o in s.get_optimizers():
|
for o in s.get_optimizers():
|
||||||
o.zero_grad()
|
o.zero_grad()
|
||||||
|
|
||||||
|
@ -205,7 +218,9 @@ class ExtensibleTrainer(BaseModel):
|
||||||
state[k] = v
|
state[k] = v
|
||||||
|
|
||||||
# And finally perform optimization.
|
# And finally perform optimization.
|
||||||
|
[e.before_optimize(state) for e in self.experiments]
|
||||||
s.do_step()
|
s.do_step()
|
||||||
|
[e.after_optimize(state) for e in self.experiments]
|
||||||
|
|
||||||
# Record visual outputs for usage in debugging and testing.
|
# Record visual outputs for usage in debugging and testing.
|
||||||
if 'visuals' in self.opt['logger'].keys():
|
if 'visuals' in self.opt['logger'].keys():
|
||||||
|
@ -252,6 +267,9 @@ class ExtensibleTrainer(BaseModel):
|
||||||
for s in self.steps:
|
for s in self.steps:
|
||||||
log.update(s.get_metrics())
|
log.update(s.get_metrics())
|
||||||
|
|
||||||
|
for e in self.experiments:
|
||||||
|
log.update(e.get_log_data())
|
||||||
|
|
||||||
# Some generators can do their own metric logging.
|
# Some generators can do their own metric logging.
|
||||||
for net_name, net in self.networks.items():
|
for net_name, net in self.networks.items():
|
||||||
if hasattr(net.module, "get_debug_values"):
|
if hasattr(net.module, "get_debug_values"):
|
||||||
|
|
83
codes/models/experiments/experiments.py
Normal file
83
codes/models/experiments/experiments.py
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def get_experiment_for_name(name):
    """Factory mapping an experiment name from the training config to an Experiment hook.

    NOTE(review): the original stub ignored *name* and always returned the no-op
    base Experiment. Known experiments are dispatched here; unknown names keep
    the original fallback behavior. TODO confirm the config key used for the
    discriminator tracker.
    """
    known = {
        'discriminator_parameter_tracker': DiscriminatorParameterTracker,
    }
    return known.get(name, Experiment)()
|
||||||
|
|
||||||
|
|
||||||
|
# Experiments are ways to add hooks into the ExtensibleTrainer training process with the intent of reporting the
# inner workings of the process in a custom manner that is unsuitable for addition elsewhere.
class Experiment:
    """No-op base class for training hooks. Subclasses override any subset of
    the callbacks; ExtensibleTrainer invokes them around each optimizer step."""

    def before_step(self, opt, step_name, env, nets_to_train, pre_state):
        """Called before a ConfigurableStep runs. No-op by default."""
        pass

    def before_optimize(self, state):
        """Called with the accumulated state just before s.do_step(). No-op by default."""
        pass

    def after_optimize(self, state):
        """Called just after s.do_step(). No-op by default."""
        pass

    def get_log_data(self):
        """Return a dict of metrics to merge into the training log.

        Must return a dict (not None): the trainer calls
        log.update(e.get_log_data()), which raises TypeError on None.
        """
        return {}
|
||||||
|
|
||||||
|
|
||||||
|
class ModelParameterDepthTrackerMetrics(Experiment):
    """Tracks the mean absolute parameter change applied by an optimizer step,
    broken down per layer, to debug how deeply updates reach into a model."""

    # Subclasses should implement these two methods:
    def get_network_and_step_names(self):
        # Subclasses should return the network being debugged and the step name it is trained in. return: (net, stepname)
        pass

    def get_layers_to_debug(self, env, net, state):
        # Subclasses should return a list of per-layer nn.Modules here.
        pass

    def before_step(self, opt, step_name, env, nets_to_train, pre_state):
        self.net, step = self.get_network_and_step_names()
        # NOTE(review): step_num was read but never initialized anywhere visible,
        # which raised AttributeError on the first call. Default it to 0; TODO
        # confirm who is expected to advance it between iterations.
        self.step_num = getattr(self, 'step_num', 0)
        self.activate = self.net in nets_to_train and step == step_name and self.step_num % opt['logger']['print_freq'] == 0
        if self.activate:
            layers = self.get_layers_to_debug(env, env['networks'][self.net], pre_state)
            self.params = []
            for layer in layers:
                lparams = []
                # Iterate the *layer's* parameters, not the whole network's
                # (the original iterated env['networks'][self.net] here, so every
                # "layer" entry contained the full parameter list).
                for _, v in layer.named_parameters():
                    if v.requires_grad:
                        lparams.append(v)
                self.params.append(lparams)

    def before_optimize(self, state):
        # Snapshot current parameter values on the CPU so after_optimize can diff.
        # torch Parameters have no .value(); detach+clone+cpu is the correct,
        # optimizer-safe snapshot (clone prevents aliasing when already on CPU).
        self.cached_params = []
        for lparams in self.params:
            self.cached_params.append([p.detach().clone().cpu() for p in lparams])

    def after_optimize(self, state):
        # Compute the abs mean difference across the params.
        self.layer_means = []
        for lparams, lcached in zip(self.params, self.cached_params):
            total = torch.tensor(0.0)
            for p, pc in zip(lparams, lcached):
                # Reduce each parameter to a scalar mean so differently-shaped
                # tensors can be accumulated (raw abs-diff tensors cannot be summed
                # across parameters of different shapes).
                total = total + torch.mean(torch.abs(pc - p.detach().cpu()))
            # Guard layers with no trainable params against division by zero.
            self.layer_means.append(total / max(len(lparams), 1))

    def get_log_data(self):
        return {'%s_layer_update_means_histogram' % (self.net,): self.layer_means}
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorParameterTracker(ModelParameterDepthTrackerMetrics):
    """Depth tracker wired to the feature discriminator network and its step."""

    def get_network_and_step_names(self):
        return "feature_discriminator", "feature_discriminator"

    def get_layers_to_debug(self, env, net, state):
        # One entry per conv stage of the reference head (conv0_0 .. conv4_1),
        # followed by the two linear heads.
        head = net.ref_head
        stages = [getattr(head, 'conv%d_%d' % (block, half))
                  for block in range(5)
                  for half in range(2)]
        return stages + [net.linear1, net.output_linears]
|
Loading…
Reference in New Issue
Block a user