Injector auto-registration

I love it!
This commit is contained in:
James Betker 2020-12-29 20:58:02 -07:00
parent a777c1e4f9
commit 63cf3d3126
9 changed files with 96 additions and 96 deletions

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from trainer.injectors import Injector
from trainer.inject import Injector
from trainer.networks import register_model
from utils.util import checkpoint

View File

@ -8,7 +8,7 @@ import torch.nn as nn
import trainer.lr_scheduler as lr_scheduler
import trainer.networks as networks
from trainer.base_model import BaseModel
from trainer.injectors import create_injector
from trainer.inject import create_injector
from trainer.steps import ConfigurableStep
from trainer.experiments.experiments import get_experiment_for_name
import torchvision.utils as utils

View File

@ -6,7 +6,7 @@ import torchvision
from torch.cuda.amp import autocast
from data.multiscale_dataset import build_multiscale_patch_index_map
from trainer.injectors import Injector
from trainer.inject import Injector
from trainer.losses import extract_params_from_state
import os.path as osp

View File

@ -2,7 +2,7 @@ import torch
from torch.cuda.amp import autocast
from models.flownet2.networks import Resample2d
from models.flownet2 import flow2img
from trainer.injectors import Injector
from trainer.inject import Injector
def create_stereoscopic_injector(opt, env):

View File

@ -3,7 +3,7 @@ from torch.cuda.amp import autocast
from models.stylegan.stylegan2_lucidrains import gradient_penalty
from trainer.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
from models.flownet2.networks import Resample2d
from trainer.injectors import Injector
from trainer.inject import Injector
import torch
import torch.nn.functional as F
import os

64
codes/trainer/inject.py Normal file
View File

@ -0,0 +1,64 @@
import importlib
import inspect
import pkgutil
import re
import sys
import torch.nn
# Base class for all other injectors.
class Injector(torch.nn.Module):
def __init__(self, opt, env):
super(Injector, self).__init__()
self.opt = opt
self.env = env
if 'in' in opt.keys():
self.input = opt['in']
self.output = opt['out']
# This should return a dict of new state variables.
def forward(self, state):
raise NotImplementedError
def format_injector_name(name):
# Formats by converting from CamelCase to snake_case and removing trailing "_injector"
name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
return name.replace("_injector", "")
# Works by loading all python modules in the injectors/ directory. The end result of this will be that the Injector.__subclasses__()
# field will be properly populated.
def find_registered_injectors(base_path="trainer/injectors"):
module_iter = pkgutil.walk_packages([base_path])
results = {}
for mod in module_iter:
if mod.ispkg:
EXCLUSION_LIST = []
if mod.name not in EXCLUSION_LIST:
results.update(find_registered_injectors(f'{base_path}/{mod.name}'))
else:
mod_name = f'{base_path}/{mod.name}'.replace('/', '.')
importlib.import_module(mod_name)
classes = inspect.getmembers(sys.modules[mod_name], inspect.isclass)
for name, obj in classes:
if 'Injector' in [mro.__name__ for mro in inspect.getmro(obj)]:
results[format_injector_name(name)] = obj
return results
class CreateInjectorError(Exception):
def __init__(self, name, available):
super().__init__(f'Could not find the specified injector name: {name}. Available injectors:'
f'{available}')
# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
def create_injector(opt_inject, env):
injectors = find_registered_injectors()
type = opt_inject['type']
if type not in injectors.keys():
raise CreateInjectorError(type, list(injectors.keys()))
return injectors[opt_inject['type']](opt_inject, env)

View File

View File

@ -7,80 +7,14 @@ from torch.cuda.amp import autocast
from utils.weight_scheduler import get_scheduler_for_opt
from trainer.losses import extract_params_from_state
from trainer.inject import Injector
# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
def create_injector(opt_inject, env):
type = opt_inject['type']
if 'teco_' in type:
from trainer.custom_training_components import create_teco_injector
return create_teco_injector(opt_inject, env)
elif 'progressive_' in type:
from trainer.custom_training_components import create_progressive_zoom_injector
return create_progressive_zoom_injector(opt_inject, env)
elif 'stereoscopic_' in type:
from trainer.custom_training_components import create_stereoscopic_injector
return create_stereoscopic_injector(opt_inject, env)
elif 'igpt' in type:
from models.transformers.igpt import gpt2
return gpt2.create_injector(opt_inject, env)
elif type == 'generator':
return ImageGeneratorInjector(opt_inject, env)
elif type == 'discriminator':
return DiscriminatorInjector(opt_inject, env)
elif type == 'scheduled_scalar':
return ScheduledScalarInjector(opt_inject, env)
elif type == 'add_noise':
return AddNoiseInjector(opt_inject, env)
elif type == 'greyscale':
return GreyInjector(opt_inject, env)
elif type == 'interpolate':
return InterpolateInjector(opt_inject, env)
elif type == 'image_patch':
return ImagePatchInjector(opt_inject, env)
elif type == 'concatenate':
return ConcatenateInjector(opt_inject, env)
elif type == 'margin_removal':
return MarginRemoval(opt_inject, env)
elif type == 'foreach':
return ForEachInjector(opt_inject, env)
elif type == 'constant':
return ConstantInjector(opt_inject, env)
elif type == 'extract_indices':
return IndicesExtractor(opt_inject, env)
elif type == 'random_shift':
return RandomShiftInjector(opt_inject, env)
elif type == 'batch_rotate':
return BatchRotateInjector(opt_inject, env)
elif type == 'sr_diffs':
return SrDiffsInjector(opt_inject, env)
elif type == 'multiframe_combiner':
return MultiFrameCombiner(opt_inject, env)
elif type == 'mix_and_label':
return MixAndLabelInjector(opt_inject, env)
elif type == 'save_images':
return SaveImages(opt_inject, env)
else:
raise NotImplementedError
class Injector(torch.nn.Module):
def __init__(self, opt, env):
super(Injector, self).__init__()
self.opt = opt
self.env = env
if 'in' in opt.keys():
self.input = opt['in']
self.output = opt['out']
# This should return a dict of new state variables.
def forward(self, state):
raise NotImplementedError
# Uses a generator to synthesize an image from [in] and injects the results into [out]
# Note that results are *not* detached.
class ImageGeneratorInjector(Injector):
class GeneratorInjector(Injector):
def __init__(self, opt, env):
super(ImageGeneratorInjector, self).__init__(opt, env)
super(GeneratorInjector, self).__init__(opt, env)
self.grad = opt['grad'] if 'grad' in opt.keys() else True
def forward(self, state):
@ -209,22 +143,23 @@ class ImagePatchInjector(Injector):
def __init__(self, opt, env):
super(ImagePatchInjector, self).__init__(opt, env)
self.patch_size = opt['patch_size']
self.resize = opt['resize'] if 'resize' in opt.keys() else None # If specified, the output is resized to a square with this size after patch extraction.
self.resize = opt[
'resize'] if 'resize' in opt.keys() else None # If specified, the output is resized to a square with this size after patch extraction.
def forward(self, state):
im = state[self.opt['in']]
if self.env['training']:
res = { self.opt['out']: im[:, :3, :self.patch_size, :self.patch_size],
'%s_top_left' % (self.opt['out'],): im[:, :, :self.patch_size, :self.patch_size],
'%s_top_right' % (self.opt['out'],): im[:, :, :self.patch_size, -self.patch_size:],
'%s_bottom_left' % (self.opt['out'],): im[:, :, -self.patch_size:, :self.patch_size],
'%s_bottom_right' % (self.opt['out'],): im[:, :, -self.patch_size:, -self.patch_size:] }
res = {self.opt['out']: im[:, :3, :self.patch_size, :self.patch_size],
'%s_top_left' % (self.opt['out'],): im[:, :, :self.patch_size, :self.patch_size],
'%s_top_right' % (self.opt['out'],): im[:, :, :self.patch_size, -self.patch_size:],
'%s_bottom_left' % (self.opt['out'],): im[:, :, -self.patch_size:, :self.patch_size],
'%s_bottom_right' % (self.opt['out'],): im[:, :, -self.patch_size:, -self.patch_size:]}
else:
res = { self.opt['out']: im,
'%s_top_left' % (self.opt['out'],): im,
'%s_top_right' % (self.opt['out'],): im,
'%s_bottom_left' % (self.opt['out'],): im,
'%s_bottom_right' % (self.opt['out'],): im }
res = {self.opt['out']: im,
'%s_top_left' % (self.opt['out'],): im,
'%s_top_right' % (self.opt['out'],): im,
'%s_bottom_left' % (self.opt['out'],): im,
'%s_bottom_right' % (self.opt['out'],): im}
if self.resize is not None:
res2 = {}
for k, v in res.items():
@ -259,12 +194,12 @@ class MarginRemoval(Injector):
for b in range(input.shape[0]):
shiftleft = random.randint(-self.random_shift_max, self.random_shift_max)
shifttop = random.randint(-self.random_shift_max, self.random_shift_max)
output.append(input[b, :, self.margin+shiftleft:-(self.margin-shiftleft),
self.margin+shifttop:-(self.margin-shifttop)])
output.append(input[b, :, self.margin + shiftleft:-(self.margin - shiftleft),
self.margin + shifttop:-(self.margin - shifttop)])
output = torch.stack(output, dim=0)
else:
output = input[:, :, self.margin:-self.margin,
self.margin:-self.margin]
self.margin:-self.margin]
return {self.opt['out']: output}
@ -305,7 +240,7 @@ class ConstantInjector(Injector):
out = torch.zeros_like(like)
else:
raise NotImplementedError
return { self.opt['out']: out }
return {self.opt['out']: out}
class IndicesExtractor(Injector):
@ -387,17 +322,17 @@ class MultiFrameCombiner(Injector):
hq = state[self.in_hq_key]
b, f, c, h, w = lq.shape
center = f // 2
center_img = lq[:,center,:,:,:]
center_img = lq[:, center, :, :, :]
imgs = [center_img]
with torch.no_grad():
for i in range(f):
if i == center:
continue
nimg = lq[:,i,:,:,:]
nimg = lq[:, i, :, :, :]
flowfield = flow(torch.stack([center_img, nimg], dim=2).float())
nimg = self.resampler(nimg, flowfield)
imgs.append(nimg)
hq_out = hq[:,center,:,:,:]
hq_out = hq[:, center, :, :, :]
return {self.out_lq_key: torch.cat(imgs, dim=1),
self.out_hq_key: hq_out,
self.out_lq_key + "_flow_sample": torch.cat(imgs, dim=0)}
@ -434,7 +369,7 @@ class MixAndLabelInjector(Injector):
for b in range(bs):
res.append(input_tensors[labels[b]][b, :, :, :])
output = torch.stack(res, dim=0)
return { self.out_labels: labels, self.output: output }
return {self.out_labels: labels, self.output: output}
# Doesn't inject. Rather saves images that meet a specified criteria. Useful for performing classification filtering
@ -464,6 +399,7 @@ class SaveImages(Injector):
torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg'))
self.index += 1
elif self.rejectdir:
torchvision.utils.save_image(images[b], os.path.join(self.rejectdir, f'{self.run_id}_{self.rindex}.jpg'))
torchvision.utils.save_image(images[b],
os.path.join(self.rejectdir, f'{self.run_id}_{self.rindex}.jpg'))
self.rindex += 1
return {}

View File

@ -6,7 +6,7 @@ import logging
from trainer.losses import create_loss
import torch
from collections import OrderedDict
from trainer.injectors import create_injector
from trainer.inject import create_injector
from utils.util import recursively_detach
logger = logging.getLogger('base')