Injector auto-registration
I love it!
This commit is contained in:
parent
a777c1e4f9
commit
63cf3d3126
|
@ -2,7 +2,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from trainer.injectors import Injector
|
||||
from trainer.inject import Injector
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch.nn as nn
|
|||
import trainer.lr_scheduler as lr_scheduler
|
||||
import trainer.networks as networks
|
||||
from trainer.base_model import BaseModel
|
||||
from trainer.injectors import create_injector
|
||||
from trainer.inject import create_injector
|
||||
from trainer.steps import ConfigurableStep
|
||||
from trainer.experiments.experiments import get_experiment_for_name
|
||||
import torchvision.utils as utils
|
||||
|
|
|
@ -6,7 +6,7 @@ import torchvision
|
|||
from torch.cuda.amp import autocast
|
||||
|
||||
from data.multiscale_dataset import build_multiscale_patch_index_map
|
||||
from trainer.injectors import Injector
|
||||
from trainer.inject import Injector
|
||||
from trainer.losses import extract_params_from_state
|
||||
import os.path as osp
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
|||
from torch.cuda.amp import autocast
|
||||
from models.flownet2.networks import Resample2d
|
||||
from models.flownet2 import flow2img
|
||||
from trainer.injectors import Injector
|
||||
from trainer.inject import Injector
|
||||
|
||||
|
||||
def create_stereoscopic_injector(opt, env):
|
||||
|
|
|
@ -3,7 +3,7 @@ from torch.cuda.amp import autocast
|
|||
from models.stylegan.stylegan2_lucidrains import gradient_penalty
|
||||
from trainer.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
|
||||
from models.flownet2.networks import Resample2d
|
||||
from trainer.injectors import Injector
|
||||
from trainer.inject import Injector
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import os
|
||||
|
|
64
codes/trainer/inject.py
Normal file
64
codes/trainer/inject.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
import importlib
|
||||
import inspect
|
||||
import pkgutil
|
||||
import re
|
||||
import sys
|
||||
|
||||
import torch.nn
|
||||
|
||||
|
||||
# Base class for all other injectors.
|
||||
class Injector(torch.nn.Module):
|
||||
def __init__(self, opt, env):
|
||||
super(Injector, self).__init__()
|
||||
self.opt = opt
|
||||
self.env = env
|
||||
if 'in' in opt.keys():
|
||||
self.input = opt['in']
|
||||
self.output = opt['out']
|
||||
|
||||
# This should return a dict of new state variables.
|
||||
def forward(self, state):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def format_injector_name(name):
|
||||
# Formats by converting from CamelCase to snake_case and removing trailing "_injector"
|
||||
name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
|
||||
name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
|
||||
return name.replace("_injector", "")
|
||||
|
||||
|
||||
# Works by loading all python modules in the injectors/ directory. The end result of this will be that the Injector.__subclasses__()
|
||||
# field will be properly populated.
|
||||
def find_registered_injectors(base_path="trainer/injectors"):
|
||||
module_iter = pkgutil.walk_packages([base_path])
|
||||
results = {}
|
||||
for mod in module_iter:
|
||||
if mod.ispkg:
|
||||
EXCLUSION_LIST = []
|
||||
if mod.name not in EXCLUSION_LIST:
|
||||
results.update(find_registered_injectors(f'{base_path}/{mod.name}'))
|
||||
else:
|
||||
mod_name = f'{base_path}/{mod.name}'.replace('/', '.')
|
||||
importlib.import_module(mod_name)
|
||||
classes = inspect.getmembers(sys.modules[mod_name], inspect.isclass)
|
||||
for name, obj in classes:
|
||||
if 'Injector' in [mro.__name__ for mro in inspect.getmro(obj)]:
|
||||
results[format_injector_name(name)] = obj
|
||||
return results
|
||||
|
||||
|
||||
class CreateInjectorError(Exception):
|
||||
def __init__(self, name, available):
|
||||
super().__init__(f'Could not find the specified injector name: {name}. Available injectors:'
|
||||
f'{available}')
|
||||
|
||||
|
||||
# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
|
||||
def create_injector(opt_inject, env):
|
||||
injectors = find_registered_injectors()
|
||||
type = opt_inject['type']
|
||||
if type not in injectors.keys():
|
||||
raise CreateInjectorError(type, list(injectors.keys()))
|
||||
return injectors[opt_inject['type']](opt_inject, env)
|
0
codes/trainer/injectors/__init__.py
Normal file
0
codes/trainer/injectors/__init__.py
Normal file
|
@ -7,80 +7,14 @@ from torch.cuda.amp import autocast
|
|||
|
||||
from utils.weight_scheduler import get_scheduler_for_opt
|
||||
from trainer.losses import extract_params_from_state
|
||||
from trainer.inject import Injector
|
||||
|
||||
# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
|
||||
def create_injector(opt_inject, env):
|
||||
type = opt_inject['type']
|
||||
if 'teco_' in type:
|
||||
from trainer.custom_training_components import create_teco_injector
|
||||
return create_teco_injector(opt_inject, env)
|
||||
elif 'progressive_' in type:
|
||||
from trainer.custom_training_components import create_progressive_zoom_injector
|
||||
return create_progressive_zoom_injector(opt_inject, env)
|
||||
elif 'stereoscopic_' in type:
|
||||
from trainer.custom_training_components import create_stereoscopic_injector
|
||||
return create_stereoscopic_injector(opt_inject, env)
|
||||
elif 'igpt' in type:
|
||||
from models.transformers.igpt import gpt2
|
||||
return gpt2.create_injector(opt_inject, env)
|
||||
elif type == 'generator':
|
||||
return ImageGeneratorInjector(opt_inject, env)
|
||||
elif type == 'discriminator':
|
||||
return DiscriminatorInjector(opt_inject, env)
|
||||
elif type == 'scheduled_scalar':
|
||||
return ScheduledScalarInjector(opt_inject, env)
|
||||
elif type == 'add_noise':
|
||||
return AddNoiseInjector(opt_inject, env)
|
||||
elif type == 'greyscale':
|
||||
return GreyInjector(opt_inject, env)
|
||||
elif type == 'interpolate':
|
||||
return InterpolateInjector(opt_inject, env)
|
||||
elif type == 'image_patch':
|
||||
return ImagePatchInjector(opt_inject, env)
|
||||
elif type == 'concatenate':
|
||||
return ConcatenateInjector(opt_inject, env)
|
||||
elif type == 'margin_removal':
|
||||
return MarginRemoval(opt_inject, env)
|
||||
elif type == 'foreach':
|
||||
return ForEachInjector(opt_inject, env)
|
||||
elif type == 'constant':
|
||||
return ConstantInjector(opt_inject, env)
|
||||
elif type == 'extract_indices':
|
||||
return IndicesExtractor(opt_inject, env)
|
||||
elif type == 'random_shift':
|
||||
return RandomShiftInjector(opt_inject, env)
|
||||
elif type == 'batch_rotate':
|
||||
return BatchRotateInjector(opt_inject, env)
|
||||
elif type == 'sr_diffs':
|
||||
return SrDiffsInjector(opt_inject, env)
|
||||
elif type == 'multiframe_combiner':
|
||||
return MultiFrameCombiner(opt_inject, env)
|
||||
elif type == 'mix_and_label':
|
||||
return MixAndLabelInjector(opt_inject, env)
|
||||
elif type == 'save_images':
|
||||
return SaveImages(opt_inject, env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Injector(torch.nn.Module):
|
||||
def __init__(self, opt, env):
|
||||
super(Injector, self).__init__()
|
||||
self.opt = opt
|
||||
self.env = env
|
||||
if 'in' in opt.keys():
|
||||
self.input = opt['in']
|
||||
self.output = opt['out']
|
||||
|
||||
# This should return a dict of new state variables.
|
||||
def forward(self, state):
|
||||
raise NotImplementedError
|
||||
|
||||
# Uses a generator to synthesize an image from [in] and injects the results into [out]
|
||||
# Note that results are *not* detached.
|
||||
class ImageGeneratorInjector(Injector):
|
||||
class GeneratorInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(ImageGeneratorInjector, self).__init__(opt, env)
|
||||
super(GeneratorInjector, self).__init__(opt, env)
|
||||
self.grad = opt['grad'] if 'grad' in opt.keys() else True
|
||||
|
||||
def forward(self, state):
|
||||
|
@ -209,22 +143,23 @@ class ImagePatchInjector(Injector):
|
|||
def __init__(self, opt, env):
|
||||
super(ImagePatchInjector, self).__init__(opt, env)
|
||||
self.patch_size = opt['patch_size']
|
||||
self.resize = opt['resize'] if 'resize' in opt.keys() else None # If specified, the output is resized to a square with this size after patch extraction.
|
||||
self.resize = opt[
|
||||
'resize'] if 'resize' in opt.keys() else None # If specified, the output is resized to a square with this size after patch extraction.
|
||||
|
||||
def forward(self, state):
|
||||
im = state[self.opt['in']]
|
||||
if self.env['training']:
|
||||
res = { self.opt['out']: im[:, :3, :self.patch_size, :self.patch_size],
|
||||
'%s_top_left' % (self.opt['out'],): im[:, :, :self.patch_size, :self.patch_size],
|
||||
'%s_top_right' % (self.opt['out'],): im[:, :, :self.patch_size, -self.patch_size:],
|
||||
'%s_bottom_left' % (self.opt['out'],): im[:, :, -self.patch_size:, :self.patch_size],
|
||||
'%s_bottom_right' % (self.opt['out'],): im[:, :, -self.patch_size:, -self.patch_size:] }
|
||||
res = {self.opt['out']: im[:, :3, :self.patch_size, :self.patch_size],
|
||||
'%s_top_left' % (self.opt['out'],): im[:, :, :self.patch_size, :self.patch_size],
|
||||
'%s_top_right' % (self.opt['out'],): im[:, :, :self.patch_size, -self.patch_size:],
|
||||
'%s_bottom_left' % (self.opt['out'],): im[:, :, -self.patch_size:, :self.patch_size],
|
||||
'%s_bottom_right' % (self.opt['out'],): im[:, :, -self.patch_size:, -self.patch_size:]}
|
||||
else:
|
||||
res = { self.opt['out']: im,
|
||||
'%s_top_left' % (self.opt['out'],): im,
|
||||
'%s_top_right' % (self.opt['out'],): im,
|
||||
'%s_bottom_left' % (self.opt['out'],): im,
|
||||
'%s_bottom_right' % (self.opt['out'],): im }
|
||||
res = {self.opt['out']: im,
|
||||
'%s_top_left' % (self.opt['out'],): im,
|
||||
'%s_top_right' % (self.opt['out'],): im,
|
||||
'%s_bottom_left' % (self.opt['out'],): im,
|
||||
'%s_bottom_right' % (self.opt['out'],): im}
|
||||
if self.resize is not None:
|
||||
res2 = {}
|
||||
for k, v in res.items():
|
||||
|
@ -259,12 +194,12 @@ class MarginRemoval(Injector):
|
|||
for b in range(input.shape[0]):
|
||||
shiftleft = random.randint(-self.random_shift_max, self.random_shift_max)
|
||||
shifttop = random.randint(-self.random_shift_max, self.random_shift_max)
|
||||
output.append(input[b, :, self.margin+shiftleft:-(self.margin-shiftleft),
|
||||
self.margin+shifttop:-(self.margin-shifttop)])
|
||||
output.append(input[b, :, self.margin + shiftleft:-(self.margin - shiftleft),
|
||||
self.margin + shifttop:-(self.margin - shifttop)])
|
||||
output = torch.stack(output, dim=0)
|
||||
else:
|
||||
output = input[:, :, self.margin:-self.margin,
|
||||
self.margin:-self.margin]
|
||||
self.margin:-self.margin]
|
||||
|
||||
return {self.opt['out']: output}
|
||||
|
||||
|
@ -305,7 +240,7 @@ class ConstantInjector(Injector):
|
|||
out = torch.zeros_like(like)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return { self.opt['out']: out }
|
||||
return {self.opt['out']: out}
|
||||
|
||||
|
||||
class IndicesExtractor(Injector):
|
||||
|
@ -387,17 +322,17 @@ class MultiFrameCombiner(Injector):
|
|||
hq = state[self.in_hq_key]
|
||||
b, f, c, h, w = lq.shape
|
||||
center = f // 2
|
||||
center_img = lq[:,center,:,:,:]
|
||||
center_img = lq[:, center, :, :, :]
|
||||
imgs = [center_img]
|
||||
with torch.no_grad():
|
||||
for i in range(f):
|
||||
if i == center:
|
||||
continue
|
||||
nimg = lq[:,i,:,:,:]
|
||||
nimg = lq[:, i, :, :, :]
|
||||
flowfield = flow(torch.stack([center_img, nimg], dim=2).float())
|
||||
nimg = self.resampler(nimg, flowfield)
|
||||
imgs.append(nimg)
|
||||
hq_out = hq[:,center,:,:,:]
|
||||
hq_out = hq[:, center, :, :, :]
|
||||
return {self.out_lq_key: torch.cat(imgs, dim=1),
|
||||
self.out_hq_key: hq_out,
|
||||
self.out_lq_key + "_flow_sample": torch.cat(imgs, dim=0)}
|
||||
|
@ -434,7 +369,7 @@ class MixAndLabelInjector(Injector):
|
|||
for b in range(bs):
|
||||
res.append(input_tensors[labels[b]][b, :, :, :])
|
||||
output = torch.stack(res, dim=0)
|
||||
return { self.out_labels: labels, self.output: output }
|
||||
return {self.out_labels: labels, self.output: output}
|
||||
|
||||
|
||||
# Doesn't inject. Rather saves images that meet a specified criteria. Useful for performing classification filtering
|
||||
|
@ -464,6 +399,7 @@ class SaveImages(Injector):
|
|||
torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg'))
|
||||
self.index += 1
|
||||
elif self.rejectdir:
|
||||
torchvision.utils.save_image(images[b], os.path.join(self.rejectdir, f'{self.run_id}_{self.rindex}.jpg'))
|
||||
torchvision.utils.save_image(images[b],
|
||||
os.path.join(self.rejectdir, f'{self.run_id}_{self.rindex}.jpg'))
|
||||
self.rindex += 1
|
||||
return {}
|
|
@ -6,7 +6,7 @@ import logging
|
|||
from trainer.losses import create_loss
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from trainer.injectors import create_injector
|
||||
from trainer.inject import create_injector
|
||||
from utils.util import recursively_detach
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
|
Loading…
Reference in New Issue
Block a user