Injector auto-registration

I love it!
This commit is contained in:
James Betker 2020-12-29 20:58:02 -07:00
parent a777c1e4f9
commit 63cf3d3126
9 changed files with 96 additions and 96 deletions

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np import numpy as np
from trainer.injectors import Injector from trainer.inject import Injector
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import checkpoint from utils.util import checkpoint

View File

@ -8,7 +8,7 @@ import torch.nn as nn
import trainer.lr_scheduler as lr_scheduler import trainer.lr_scheduler as lr_scheduler
import trainer.networks as networks import trainer.networks as networks
from trainer.base_model import BaseModel from trainer.base_model import BaseModel
from trainer.injectors import create_injector from trainer.inject import create_injector
from trainer.steps import ConfigurableStep from trainer.steps import ConfigurableStep
from trainer.experiments.experiments import get_experiment_for_name from trainer.experiments.experiments import get_experiment_for_name
import torchvision.utils as utils import torchvision.utils as utils

View File

@ -6,7 +6,7 @@ import torchvision
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from data.multiscale_dataset import build_multiscale_patch_index_map from data.multiscale_dataset import build_multiscale_patch_index_map
from trainer.injectors import Injector from trainer.inject import Injector
from trainer.losses import extract_params_from_state from trainer.losses import extract_params_from_state
import os.path as osp import os.path as osp

View File

@ -2,7 +2,7 @@ import torch
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from models.flownet2.networks import Resample2d from models.flownet2.networks import Resample2d
from models.flownet2 import flow2img from models.flownet2 import flow2img
from trainer.injectors import Injector from trainer.inject import Injector
def create_stereoscopic_injector(opt, env): def create_stereoscopic_injector(opt, env):

View File

@ -3,7 +3,7 @@ from torch.cuda.amp import autocast
from models.stylegan.stylegan2_lucidrains import gradient_penalty from models.stylegan.stylegan2_lucidrains import gradient_penalty
from trainer.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name from trainer.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
from models.flownet2.networks import Resample2d from models.flownet2.networks import Resample2d
from trainer.injectors import Injector from trainer.inject import Injector
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import os import os

64
codes/trainer/inject.py Normal file
View File

@ -0,0 +1,64 @@
import importlib
import inspect
import pkgutil
import re
import sys
import torch.nn
# Base class for all other injectors.
class Injector(torch.nn.Module):
    """Base class for all injectors.

    An injector reads variables out of the training state dict and produces
    new ones. The opt dict's 'in' key (optional — some injectors take no
    input) names the state variable(s) consumed; 'out' names what is produced.
    """

    def __init__(self, opt, env):
        super().__init__()
        self.opt = opt
        self.env = env
        # 'in' is optional, so only bind self.input when it is present.
        if 'in' in opt:
            self.input = opt['in']
        self.output = opt['out']

    def forward(self, state):
        """Return a dict of new state variables. Subclasses must override."""
        raise NotImplementedError
def format_injector_name(name):
    """Convert a CamelCase injector class name to its registry key.

    E.g. "ImagePatchInjector" -> "image_patch": CamelCase is converted to
    snake_case, then a trailing "_injector" suffix is stripped.
    """
    name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
    # Anchor the suffix removal to the end of the string. The previous
    # str.replace("_injector", "") stripped the token anywhere in the name,
    # contradicting the documented "trailing" semantics.
    return re.sub(r'_injector$', '', name)
# Works by loading all python modules in the injectors/ directory. The end result of this
# will be that the Injector.__subclasses__() field will be properly populated.
def find_registered_injectors(base_path="trainer/injectors"):
    """Import every module under base_path and return a dict mapping registry
    names (via format_injector_name) to the Injector subclasses they define."""
    found = {}
    excluded_packages = []  # sub-packages listed here are skipped entirely
    for module_info in pkgutil.walk_packages([base_path]):
        if module_info.ispkg:
            # Recurse into sub-packages that are not explicitly excluded.
            if module_info.name not in excluded_packages:
                found.update(find_registered_injectors(f'{base_path}/{module_info.name}'))
            continue
        dotted = f'{base_path}/{module_info.name}'.replace('/', '.')
        importlib.import_module(dotted)
        for cls_name, cls in inspect.getmembers(sys.modules[dotted], inspect.isclass):
            # Any class with "Injector" somewhere in its MRO counts as an injector.
            if any(ancestor.__name__ == 'Injector' for ancestor in inspect.getmro(cls)):
                found[format_injector_name(cls_name)] = cls
    return found
class CreateInjectorError(Exception):
    """Raised by create_injector() when the requested injector type is not registered."""

    def __init__(self, name, available):
        # List the available registry keys so config typos are easy to diagnose.
        # A separator space is inserted before the list: the original adjacent
        # f-strings ran "injectors:" directly into the list text.
        super().__init__(f'Could not find the specified injector name: {name}. Available injectors: '
                         f'{available}')
# Injectors are a way to synthesize data within a step that can then be used (and reused) by loss functions.
def create_injector(opt_inject, env):
    """Instantiate the injector described by opt_inject.

    opt_inject['type'] selects the registered injector class (keys produced by
    format_injector_name); the class is constructed with (opt_inject, env).
    Raises CreateInjectorError when 'type' names no registered injector.
    """
    injectors = find_registered_injectors()
    injector_type = opt_inject['type']  # renamed: don't shadow the `type` builtin
    if injector_type not in injectors:
        raise CreateInjectorError(injector_type, list(injectors.keys()))
    return injectors[injector_type](opt_inject, env)

View File

View File

@ -7,80 +7,14 @@ from torch.cuda.amp import autocast
from utils.weight_scheduler import get_scheduler_for_opt from utils.weight_scheduler import get_scheduler_for_opt
from trainer.losses import extract_params_from_state from trainer.losses import extract_params_from_state
from trainer.inject import Injector
# Injectors are a way to synthesize data within a step that can then be used (and reused) by loss functions.
def create_injector(opt_inject, env):
type = opt_inject['type']
if 'teco_' in type:
from trainer.custom_training_components import create_teco_injector
return create_teco_injector(opt_inject, env)
elif 'progressive_' in type:
from trainer.custom_training_components import create_progressive_zoom_injector
return create_progressive_zoom_injector(opt_inject, env)
elif 'stereoscopic_' in type:
from trainer.custom_training_components import create_stereoscopic_injector
return create_stereoscopic_injector(opt_inject, env)
elif 'igpt' in type:
from models.transformers.igpt import gpt2
return gpt2.create_injector(opt_inject, env)
elif type == 'generator':
return ImageGeneratorInjector(opt_inject, env)
elif type == 'discriminator':
return DiscriminatorInjector(opt_inject, env)
elif type == 'scheduled_scalar':
return ScheduledScalarInjector(opt_inject, env)
elif type == 'add_noise':
return AddNoiseInjector(opt_inject, env)
elif type == 'greyscale':
return GreyInjector(opt_inject, env)
elif type == 'interpolate':
return InterpolateInjector(opt_inject, env)
elif type == 'image_patch':
return ImagePatchInjector(opt_inject, env)
elif type == 'concatenate':
return ConcatenateInjector(opt_inject, env)
elif type == 'margin_removal':
return MarginRemoval(opt_inject, env)
elif type == 'foreach':
return ForEachInjector(opt_inject, env)
elif type == 'constant':
return ConstantInjector(opt_inject, env)
elif type == 'extract_indices':
return IndicesExtractor(opt_inject, env)
elif type == 'random_shift':
return RandomShiftInjector(opt_inject, env)
elif type == 'batch_rotate':
return BatchRotateInjector(opt_inject, env)
elif type == 'sr_diffs':
return SrDiffsInjector(opt_inject, env)
elif type == 'multiframe_combiner':
return MultiFrameCombiner(opt_inject, env)
elif type == 'mix_and_label':
return MixAndLabelInjector(opt_inject, env)
elif type == 'save_images':
return SaveImages(opt_inject, env)
else:
raise NotImplementedError
class Injector(torch.nn.Module):
def __init__(self, opt, env):
super(Injector, self).__init__()
self.opt = opt
self.env = env
if 'in' in opt.keys():
self.input = opt['in']
self.output = opt['out']
# This should return a dict of new state variables.
def forward(self, state):
raise NotImplementedError
# Uses a generator to synthesize an image from [in] and injects the results into [out] # Uses a generator to synthesize an image from [in] and injects the results into [out]
# Note that results are *not* detached. # Note that results are *not* detached.
class ImageGeneratorInjector(Injector): class GeneratorInjector(Injector):
def __init__(self, opt, env): def __init__(self, opt, env):
super(ImageGeneratorInjector, self).__init__(opt, env) super(GeneratorInjector, self).__init__(opt, env)
self.grad = opt['grad'] if 'grad' in opt.keys() else True self.grad = opt['grad'] if 'grad' in opt.keys() else True
def forward(self, state): def forward(self, state):
@ -209,22 +143,23 @@ class ImagePatchInjector(Injector):
def __init__(self, opt, env): def __init__(self, opt, env):
super(ImagePatchInjector, self).__init__(opt, env) super(ImagePatchInjector, self).__init__(opt, env)
self.patch_size = opt['patch_size'] self.patch_size = opt['patch_size']
self.resize = opt['resize'] if 'resize' in opt.keys() else None # If specified, the output is resized to a square with this size after patch extraction. self.resize = opt[
'resize'] if 'resize' in opt.keys() else None # If specified, the output is resized to a square with this size after patch extraction.
def forward(self, state): def forward(self, state):
im = state[self.opt['in']] im = state[self.opt['in']]
if self.env['training']: if self.env['training']:
res = { self.opt['out']: im[:, :3, :self.patch_size, :self.patch_size], res = {self.opt['out']: im[:, :3, :self.patch_size, :self.patch_size],
'%s_top_left' % (self.opt['out'],): im[:, :, :self.patch_size, :self.patch_size], '%s_top_left' % (self.opt['out'],): im[:, :, :self.patch_size, :self.patch_size],
'%s_top_right' % (self.opt['out'],): im[:, :, :self.patch_size, -self.patch_size:], '%s_top_right' % (self.opt['out'],): im[:, :, :self.patch_size, -self.patch_size:],
'%s_bottom_left' % (self.opt['out'],): im[:, :, -self.patch_size:, :self.patch_size], '%s_bottom_left' % (self.opt['out'],): im[:, :, -self.patch_size:, :self.patch_size],
'%s_bottom_right' % (self.opt['out'],): im[:, :, -self.patch_size:, -self.patch_size:] } '%s_bottom_right' % (self.opt['out'],): im[:, :, -self.patch_size:, -self.patch_size:]}
else: else:
res = { self.opt['out']: im, res = {self.opt['out']: im,
'%s_top_left' % (self.opt['out'],): im, '%s_top_left' % (self.opt['out'],): im,
'%s_top_right' % (self.opt['out'],): im, '%s_top_right' % (self.opt['out'],): im,
'%s_bottom_left' % (self.opt['out'],): im, '%s_bottom_left' % (self.opt['out'],): im,
'%s_bottom_right' % (self.opt['out'],): im } '%s_bottom_right' % (self.opt['out'],): im}
if self.resize is not None: if self.resize is not None:
res2 = {} res2 = {}
for k, v in res.items(): for k, v in res.items():
@ -259,12 +194,12 @@ class MarginRemoval(Injector):
for b in range(input.shape[0]): for b in range(input.shape[0]):
shiftleft = random.randint(-self.random_shift_max, self.random_shift_max) shiftleft = random.randint(-self.random_shift_max, self.random_shift_max)
shifttop = random.randint(-self.random_shift_max, self.random_shift_max) shifttop = random.randint(-self.random_shift_max, self.random_shift_max)
output.append(input[b, :, self.margin+shiftleft:-(self.margin-shiftleft), output.append(input[b, :, self.margin + shiftleft:-(self.margin - shiftleft),
self.margin+shifttop:-(self.margin-shifttop)]) self.margin + shifttop:-(self.margin - shifttop)])
output = torch.stack(output, dim=0) output = torch.stack(output, dim=0)
else: else:
output = input[:, :, self.margin:-self.margin, output = input[:, :, self.margin:-self.margin,
self.margin:-self.margin] self.margin:-self.margin]
return {self.opt['out']: output} return {self.opt['out']: output}
@ -292,7 +227,7 @@ class ForEachInjector(Injector):
else: else:
return {self.output: torch.stack(injs, dim=1)} return {self.output: torch.stack(injs, dim=1)}
class ConstantInjector(Injector): class ConstantInjector(Injector):
def __init__(self, opt, env): def __init__(self, opt, env):
super(ConstantInjector, self).__init__(opt, env) super(ConstantInjector, self).__init__(opt, env)
@ -305,7 +240,7 @@ class ConstantInjector(Injector):
out = torch.zeros_like(like) out = torch.zeros_like(like)
else: else:
raise NotImplementedError raise NotImplementedError
return { self.opt['out']: out } return {self.opt['out']: out}
class IndicesExtractor(Injector): class IndicesExtractor(Injector):
@ -387,17 +322,17 @@ class MultiFrameCombiner(Injector):
hq = state[self.in_hq_key] hq = state[self.in_hq_key]
b, f, c, h, w = lq.shape b, f, c, h, w = lq.shape
center = f // 2 center = f // 2
center_img = lq[:,center,:,:,:] center_img = lq[:, center, :, :, :]
imgs = [center_img] imgs = [center_img]
with torch.no_grad(): with torch.no_grad():
for i in range(f): for i in range(f):
if i == center: if i == center:
continue continue
nimg = lq[:,i,:,:,:] nimg = lq[:, i, :, :, :]
flowfield = flow(torch.stack([center_img, nimg], dim=2).float()) flowfield = flow(torch.stack([center_img, nimg], dim=2).float())
nimg = self.resampler(nimg, flowfield) nimg = self.resampler(nimg, flowfield)
imgs.append(nimg) imgs.append(nimg)
hq_out = hq[:,center,:,:,:] hq_out = hq[:, center, :, :, :]
return {self.out_lq_key: torch.cat(imgs, dim=1), return {self.out_lq_key: torch.cat(imgs, dim=1),
self.out_hq_key: hq_out, self.out_hq_key: hq_out,
self.out_lq_key + "_flow_sample": torch.cat(imgs, dim=0)} self.out_lq_key + "_flow_sample": torch.cat(imgs, dim=0)}
@ -434,7 +369,7 @@ class MixAndLabelInjector(Injector):
for b in range(bs): for b in range(bs):
res.append(input_tensors[labels[b]][b, :, :, :]) res.append(input_tensors[labels[b]][b, :, :, :])
output = torch.stack(res, dim=0) output = torch.stack(res, dim=0)
return { self.out_labels: labels, self.output: output } return {self.out_labels: labels, self.output: output}
# Doesn't inject. Rather saves images that meet a specified criteria. Useful for performing classification filtering # Doesn't inject. Rather saves images that meet a specified criteria. Useful for performing classification filtering
@ -464,6 +399,7 @@ class SaveImages(Injector):
torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg')) torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg'))
self.index += 1 self.index += 1
elif self.rejectdir: elif self.rejectdir:
torchvision.utils.save_image(images[b], os.path.join(self.rejectdir, f'{self.run_id}_{self.rindex}.jpg')) torchvision.utils.save_image(images[b],
os.path.join(self.rejectdir, f'{self.run_id}_{self.rindex}.jpg'))
self.rindex += 1 self.rindex += 1
return {} return {}

View File

@ -6,7 +6,7 @@ import logging
from trainer.losses import create_loss from trainer.losses import create_loss
import torch import torch
from collections import OrderedDict from collections import OrderedDict
from trainer.injectors import create_injector from trainer.inject import create_injector
from utils.util import recursively_detach from utils.util import recursively_detach
logger = logging.getLogger('base') logger = logging.getLogger('base')