From 63cf3d31268b53377b881ca53585c439746cf883 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 29 Dec 2020 20:58:02 -0700 Subject: [PATCH] Injector auto-registration I love it! --- codes/models/transformers/igpt/gpt2.py | 2 +- codes/trainer/ExtensibleTrainer.py | 2 +- .../progressive_zoom.py | 2 +- .../stereoscopic.py | 2 +- .../tecogan_losses.py | 2 +- codes/trainer/inject.py | 64 ++++++++++ codes/trainer/injectors/__init__.py | 0 .../base_injectors.py} | 116 ++++-------------- codes/trainer/steps.py | 2 +- 9 files changed, 96 insertions(+), 96 deletions(-) create mode 100644 codes/trainer/inject.py create mode 100644 codes/trainer/injectors/__init__.py rename codes/trainer/{injectors.py => injectors/base_injectors.py} (76%) diff --git a/codes/models/transformers/igpt/gpt2.py b/codes/models/transformers/igpt/gpt2.py index 21a7e4fd..5e0d0eeb 100644 --- a/codes/models/transformers/igpt/gpt2.py +++ b/codes/models/transformers/igpt/gpt2.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np -from trainer.injectors import Injector +from trainer.inject import Injector from trainer.networks import register_model from utils.util import checkpoint diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index 8298091f..a3eb3a69 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -8,7 +8,7 @@ import torch.nn as nn import trainer.lr_scheduler as lr_scheduler import trainer.networks as networks from trainer.base_model import BaseModel -from trainer.injectors import create_injector +from trainer.inject import create_injector from trainer.steps import ConfigurableStep from trainer.experiments.experiments import get_experiment_for_name import torchvision.utils as utils diff --git a/codes/trainer/custom_training_components/progressive_zoom.py b/codes/trainer/custom_training_components/progressive_zoom.py index 8ef19f63..63f0f980 100644 --- a/codes/trainer/custom_training_components/progressive_zoom.py +++ b/codes/trainer/custom_training_components/progressive_zoom.py @@ -6,7 +6,7 @@ import torchvision from torch.cuda.amp import autocast from data.multiscale_dataset import build_multiscale_patch_index_map -from trainer.injectors import Injector +from trainer.inject import Injector from trainer.losses import extract_params_from_state import os.path as osp diff --git a/codes/trainer/custom_training_components/stereoscopic.py b/codes/trainer/custom_training_components/stereoscopic.py index 95b7945f..d2b923bf 100644 --- a/codes/trainer/custom_training_components/stereoscopic.py +++ b/codes/trainer/custom_training_components/stereoscopic.py @@ -2,7 +2,7 @@ import torch from torch.cuda.amp import autocast from models.flownet2.networks import Resample2d from models.flownet2 import flow2img -from trainer.injectors import Injector +from trainer.inject import Injector def create_stereoscopic_injector(opt, env): diff --git a/codes/trainer/custom_training_components/tecogan_losses.py b/codes/trainer/custom_training_components/tecogan_losses.py index d093240f..31154138 100644 --- a/codes/trainer/custom_training_components/tecogan_losses.py +++ b/codes/trainer/custom_training_components/tecogan_losses.py @@ -3,7 +3,7 @@ from torch.cuda.amp import autocast from models.stylegan.stylegan2_lucidrains import gradient_penalty from trainer.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name from models.flownet2.networks import Resample2d -from trainer.injectors import Injector +from trainer.inject import Injector import torch import torch.nn.functional as F import os diff --git a/codes/trainer/inject.py b/codes/trainer/inject.py new file mode 100644 index 00000000..56ed4026 --- /dev/null +++ b/codes/trainer/inject.py @@ -0,0 +1,64 @@ +import importlib +import inspect +import pkgutil +import re +import sys + +import torch.nn + + +# Base class for all other injectors. +class Injector(torch.nn.Module): + def __init__(self, opt, env): + super(Injector, self).__init__() + self.opt = opt + self.env = env + if 'in' in opt.keys(): + self.input = opt['in'] + self.output = opt['out'] + + # This should return a dict of new state variables. + def forward(self, state): + raise NotImplementedError + + +def format_injector_name(name): + # Formats by converting from CamelCase to snake_case and removing trailing "_injector" + name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower() + return name.replace("_injector", "") + + +# Works by loading all python modules in the injectors/ directory. The end result of this will be that the Injector.__subclasses__() +# field will be properly populated. +def find_registered_injectors(base_path="trainer/injectors"): + module_iter = pkgutil.walk_packages([base_path]) + results = {} + for mod in module_iter: + if mod.ispkg: + EXCLUSION_LIST = [] + if mod.name not in EXCLUSION_LIST: + results.update(find_registered_injectors(f'{base_path}/{mod.name}')) + else: + mod_name = f'{base_path}/{mod.name}'.replace('/', '.') + importlib.import_module(mod_name) + classes = inspect.getmembers(sys.modules[mod_name], inspect.isclass) + for name, obj in classes: + if 'Injector' in [mro.__name__ for mro in inspect.getmro(obj)]: + results[format_injector_name(name)] = obj + return results + + +class CreateInjectorError(Exception): + def __init__(self, name, available): + super().__init__(f'Could not find the specified injector name: {name}. Available injectors:' + f'{available}') + + +# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions. +def create_injector(opt_inject, env): + injectors = find_registered_injectors() + type = opt_inject['type'] + if type not in injectors.keys(): + raise CreateInjectorError(type, list(injectors.keys())) + return injectors[opt_inject['type']](opt_inject, env) \ No newline at end of file diff --git a/codes/trainer/injectors/__init__.py b/codes/trainer/injectors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/codes/trainer/injectors.py b/codes/trainer/injectors/base_injectors.py similarity index 76% rename from codes/trainer/injectors.py rename to codes/trainer/injectors/base_injectors.py index 98c2b43f..bdfdb272 100644 --- a/codes/trainer/injectors.py +++ b/codes/trainer/injectors/base_injectors.py @@ -7,80 +7,14 @@ from torch.cuda.amp import autocast from utils.weight_scheduler import get_scheduler_for_opt from trainer.losses import extract_params_from_state +from trainer.inject import Injector -# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions. -def create_injector(opt_inject, env): - type = opt_inject['type'] - if 'teco_' in type: - from trainer.custom_training_components import create_teco_injector - return create_teco_injector(opt_inject, env) - elif 'progressive_' in type: - from trainer.custom_training_components import create_progressive_zoom_injector - return create_progressive_zoom_injector(opt_inject, env) - elif 'stereoscopic_' in type: - from trainer.custom_training_components import create_stereoscopic_injector - return create_stereoscopic_injector(opt_inject, env) - elif 'igpt' in type: - from models.transformers.igpt import gpt2 - return gpt2.create_injector(opt_inject, env) - elif type == 'generator': - return ImageGeneratorInjector(opt_inject, env) - elif type == 'discriminator': - return DiscriminatorInjector(opt_inject, env) - elif type == 'scheduled_scalar': - return ScheduledScalarInjector(opt_inject, env) - elif type == 'add_noise': - return AddNoiseInjector(opt_inject, env) - elif type == 'greyscale': - return GreyInjector(opt_inject, env) - elif type == 'interpolate': - return InterpolateInjector(opt_inject, env) - elif type == 'image_patch': - return ImagePatchInjector(opt_inject, env) - elif type == 'concatenate': - return ConcatenateInjector(opt_inject, env) - elif type == 'margin_removal': - return MarginRemoval(opt_inject, env) - elif type == 'foreach': - return ForEachInjector(opt_inject, env) - elif type == 'constant': - return ConstantInjector(opt_inject, env) - elif type == 'extract_indices': - return IndicesExtractor(opt_inject, env) - elif type == 'random_shift': - return RandomShiftInjector(opt_inject, env) - elif type == 'batch_rotate': - return BatchRotateInjector(opt_inject, env) - elif type == 'sr_diffs': - return SrDiffsInjector(opt_inject, env) - elif type == 'multiframe_combiner': - return MultiFrameCombiner(opt_inject, env) - elif type == 'mix_and_label': - return MixAndLabelInjector(opt_inject, env) - elif type == 'save_images': - return SaveImages(opt_inject, env) - else: - raise NotImplementedError - - -class Injector(torch.nn.Module): - def __init__(self, opt, env): - super(Injector, self).__init__() - self.opt = opt - self.env = env - if 'in' in opt.keys(): - self.input = opt['in'] - self.output = opt['out'] - - # This should return a dict of new state variables. - def forward(self, state): - raise NotImplementedError # Uses a generator to synthesize an image from [in] and injects the results into [out] # Note that results are *not* detached. -class ImageGeneratorInjector(Injector): +class GeneratorInjector(Injector): def __init__(self, opt, env): - super(ImageGeneratorInjector, self).__init__(opt, env) + super(GeneratorInjector, self).__init__(opt, env) self.grad = opt['grad'] if 'grad' in opt.keys() else True def forward(self, state): @@ -209,22 +143,23 @@ class ImagePatchInjector(Injector): def __init__(self, opt, env): super(ImagePatchInjector, self).__init__(opt, env) self.patch_size = opt['patch_size'] - self.resize = opt['resize'] if 'resize' in opt.keys() else None # If specified, the output is resized to a square with this size after patch extraction. + self.resize = opt[ + 'resize'] if 'resize' in opt.keys() else None # If specified, the output is resized to a square with this size after patch extraction. def forward(self, state): im = state[self.opt['in']] if self.env['training']: - res = { self.opt['out']: im[:, :3, :self.patch_size, :self.patch_size], - '%s_top_left' % (self.opt['out'],): im[:, :, :self.patch_size, :self.patch_size], - '%s_top_right' % (self.opt['out'],): im[:, :, :self.patch_size, -self.patch_size:], - '%s_bottom_left' % (self.opt['out'],): im[:, :, -self.patch_size:, :self.patch_size], - '%s_bottom_right' % (self.opt['out'],): im[:, :, -self.patch_size:, -self.patch_size:] } + res = {self.opt['out']: im[:, :3, :self.patch_size, :self.patch_size], + '%s_top_left' % (self.opt['out'],): im[:, :, :self.patch_size, :self.patch_size], + '%s_top_right' % (self.opt['out'],): im[:, :, :self.patch_size, -self.patch_size:], + '%s_bottom_left' % (self.opt['out'],): im[:, :, -self.patch_size:, :self.patch_size], + '%s_bottom_right' % (self.opt['out'],): im[:, :, -self.patch_size:, -self.patch_size:]} else: - res = { self.opt['out']: im, - '%s_top_left' % (self.opt['out'],): im, - '%s_top_right' % (self.opt['out'],): im, - '%s_bottom_left' % (self.opt['out'],): im, - '%s_bottom_right' % (self.opt['out'],): im } + res = {self.opt['out']: im, + '%s_top_left' % (self.opt['out'],): im, + '%s_top_right' % (self.opt['out'],): im, + '%s_bottom_left' % (self.opt['out'],): im, + '%s_bottom_right' % (self.opt['out'],): im} if self.resize is not None: res2 = {} for k, v in res.items(): @@ -259,12 +194,12 @@ class MarginRemoval(Injector): for b in range(input.shape[0]): shiftleft = random.randint(-self.random_shift_max, self.random_shift_max) shifttop = random.randint(-self.random_shift_max, self.random_shift_max) - output.append(input[b, :, self.margin+shiftleft:-(self.margin-shiftleft), - self.margin+shifttop:-(self.margin-shifttop)]) + output.append(input[b, :, self.margin + shiftleft:-(self.margin - shiftleft), + self.margin + shifttop:-(self.margin - shifttop)]) output = torch.stack(output, dim=0) else: output = input[:, :, self.margin:-self.margin, - self.margin:-self.margin] + self.margin:-self.margin] return {self.opt['out']: output} @@ -292,7 +227,7 @@ class ForEachInjector(Injector): else: return {self.output: torch.stack(injs, dim=1)} - + class ConstantInjector(Injector): def __init__(self, opt, env): super(ConstantInjector, self).__init__(opt, env) @@ -305,7 +240,7 @@ class ConstantInjector(Injector): out = torch.zeros_like(like) else: raise NotImplementedError - return { self.opt['out']: out } + return {self.opt['out']: out} class IndicesExtractor(Injector): @@ -387,17 +322,17 @@ class MultiFrameCombiner(Injector): hq = state[self.in_hq_key] b, f, c, h, w = lq.shape center = f // 2 - center_img = lq[:,center,:,:,:] + center_img = lq[:, center, :, :, :] imgs = [center_img] with torch.no_grad(): for i in range(f): if i == center: continue - nimg = lq[:,i,:,:,:] + nimg = lq[:, i, :, :, :] flowfield = flow(torch.stack([center_img, nimg], dim=2).float()) nimg = self.resampler(nimg, flowfield) imgs.append(nimg) - hq_out = hq[:,center,:,:,:] + hq_out = hq[:, center, :, :, :] return {self.out_lq_key: torch.cat(imgs, dim=1), self.out_hq_key: hq_out, self.out_lq_key + "_flow_sample": torch.cat(imgs, dim=0)} @@ -434,7 +369,7 @@ class MixAndLabelInjector(Injector): for b in range(bs): res.append(input_tensors[labels[b]][b, :, :, :]) output = torch.stack(res, dim=0) - return { self.out_labels: labels, self.output: output } + return {self.out_labels: labels, self.output: output} # Doesn't inject. Rather saves images that meet a specified criteria. Useful for performing classification filtering @@ -464,6 +399,7 @@ class SaveImages(Injector): torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg')) self.index += 1 elif self.rejectdir: - torchvision.utils.save_image(images[b], os.path.join(self.rejectdir, f'{self.run_id}_{self.rindex}.jpg')) + torchvision.utils.save_image(images[b], + os.path.join(self.rejectdir, f'{self.run_id}_{self.rindex}.jpg')) self.rindex += 1 return {} \ No newline at end of file diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py index 5f09038b..a0704d9f 100644 --- a/codes/trainer/steps.py +++ b/codes/trainer/steps.py @@ -6,7 +6,7 @@ import logging from trainer.losses import create_loss import torch from collections import OrderedDict -from trainer.injectors import create_injector +from trainer.inject import create_injector from utils.util import recursively_detach logger = logging.getLogger('base')