DL-Art-School/codes/utils/numeric_stability.py
James Betker 5f2c722a10 SRG2 revival
Big update to SRG2 architecture to pull in a lot of things that have been learned:
- Use group norm instead of batch norm
- Initialize the weights on the transformations low, as is done in RRDB, rather than using the scalar. Models live or die by their early stages, and this one's early stage is pretty weak
- Transform multiplexer to use u-net like architecture.
- Just use one set of configuration variables instead of a list - flat networks performed fine in this regard.
2020-07-09 17:34:51 -06:00

132 lines
5.6 KiB
Python

import torch
from torch import nn
import models.archs.SRG1_arch as srg1
import models.archs.SwitchedResidualGenerator_arch as srg
import models.archs.NestedSwitchGenerator as nsg
import functools
# Primitive/leaf module types that are skipped by the forward-trace hook
# installer below: only composite modules get input-shape instrumentation.
blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]
def install_forward_trace_hooks(module, id="base"):
    """Recursively attach shape-recording forward hooks to `module` and its children.

    Modules whose exact type is listed in `blacklisted_modules` are skipped
    entirely (neither hooked nor descended into). Each hooked module is
    identified by a colon-separated path rooted at `id`, which is passed to
    the hook as `mod_id`.
    """
    if type(module) in blacklisted_modules:
        return
    hook = functools.partial(inject_input_shapes, mod_id=id)
    module.register_forward_hook(hook)
    for child_name, child in module.named_children():
        install_forward_trace_hooks(child, "%s:%s" % (id, child_name))
def inject_input_shapes(module: nn.Module, inputs, outputs, mod_id: str):
    """Forward hook that records the input tensor's shape on the module itself.

    Stores `inputs[0].shape` as `module._input_shape`. Only single-tensor
    inputs are currently supported; anything else is silently ignored. TODO: fix.
    """
    is_single_tensor = len(inputs) == 1 and isinstance(inputs[0], torch.Tensor)
    if is_single_tensor:
        module._input_shape = inputs[0].shape
def extract_input_shapes(module, id="base"):
    """Collect `_input_shape` attributes recorded by the forward-trace hooks.

    Walks `module` and all descendants, returning a dict that maps each
    colon-separated module path (rooted at `id`) to the recorded input shape.
    Modules without a recorded shape are simply omitted.
    """
    collected = {}
    if hasattr(module, "_input_shape"):
        collected[id] = module._input_shape
    for child_name, child in module.named_children():
        collected.update(extract_input_shapes(child, "%s:%s" % (id, child_name)))
    return collected
def test_stability(mod_fn, dummy_inputs, device='cuda', trials=20):
    """Probe numeric stability of an architecture across fresh initializations.

    mod_fn:       zero-argument factory that returns a freshly initialized nn.Module.
    dummy_inputs: example input tensor used once to trace submodule input shapes.
    device:       device models and inputs are moved to.
    trials:       number of independently initialized models to aggregate over
                  (was hard-coded to 20; parameterized with the same default).

    For every traced submodule, prints the mean of the per-iteration activation
    means and stds gathered over all trials.
    """
    base_module = mod_fn().to(device)
    dummy_inputs = dummy_inputs.to(device)
    install_forward_trace_hooks(base_module)
    # This forward pass exists only to fire the shape-recording hooks; there is
    # no need to build an autograd graph for it.
    with torch.no_grad():
        base_module(dummy_inputs)
    input_shapes = extract_input_shapes(base_module)
    means = {}
    stds = {}
    for i in range(trials):
        mod = mod_fn().to(device)
        t_means, t_stds = test_stability_per_module(mod, input_shapes, device)
        for k in t_means.keys():
            if k not in means.keys():
                means[k] = []
                stds[k] = []
            means[k].extend(t_means[k])
            stds[k].extend(t_stds[k])
    for k in means.keys():
        print("%s - mean: %f std: %f" % (k, torch.mean(torch.stack(means[k])),
                                         torch.mean(torch.stack(stds[k]))))
def test_stability_per_module(mod: nn.Module, input_shapes: dict, device='cuda', id="base"):
    """Run the numeric-stability probe on `mod` and every descendant with a known input shape.

    `input_shapes` maps colon-separated module paths (as produced by
    extract_input_shapes) to input tensor shapes. Returns two dicts keyed by
    module path: per-iteration activation means and stds.
    """
    means, stds = {}, {}
    shape = input_shapes.get(id)
    if shape is not None:
        module_mean, module_std = test_numeric_stability(mod, shape, 1, device)
        means[id] = module_mean
        stds[id] = module_std
    for child_name, child in mod.named_children():
        child_id = "%s:%s" % (id, child_name)
        child_means, child_stds = test_stability_per_module(child, input_shapes,
                                                            device=device, id=child_id)
        means.update(child_means)
        stds.update(child_stds)
    return means, stds
def test_numeric_stability(mod: nn.Module, format, iterations=50, device='cuda'):
x = torch.randn(format).to(device)
means = []
stds = []
with torch.no_grad():
for i in range(iterations):
x = mod(x)[0]
measure = x
means.append(torch.mean(measure).detach())
stds.append(torch.std(measure).detach())
return torch.stack(means), torch.stack(stds)
# Dead reference: a copy of the old ConfigurableSwitchedResidualGenerator
# constructor signature, kept as an unused module-level string for convenience
# when editing the experiment invocations below.
'''
def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False):
'''
# Manual stability experiments. Exactly one invocation below is active at a
# time; the alternatives are disabled by wrapping them in bare triple-quoted
# string literals (evaluated and discarded, never executed).
if __name__ == "__main__":
    # Disabled: NestedSwitchedGenerator configuration.
    '''
    test_stability(functools.partial(nsg.NestedSwitchedGenerator,
                                     switch_filters=64,
                                     switch_reductions=[3,3,3,3,3],
                                     switch_processing_layers=[1,1,1,1,1],
                                     trans_counts=[3,3,3,3,3],
                                     trans_kernel_sizes=[3,3,3,3,3],
                                     trans_layers=[3,3,3,3,3],
                                     transformation_filters=64,
                                     initial_temp=10),
                   torch.randn(1, 3, 64, 64),
                   device='cuda')
    '''
    # Active: SRG2 configuration currently under test.
    test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
                                     switch_depth=4,
                                     switch_filters=64,
                                     switch_reductions=4,
                                     switch_processing_layers=2,
                                     trans_counts=8,
                                     trans_kernel_sizes=3,
                                     trans_layers=4,
                                     transformation_filters=64,
                                     upsample_factor=4),
                   torch.randn(1, 3, 64, 64),
                   device='cuda')
    # Disabled: original SRG1 configuration.
    '''
    test_stability(functools.partial(srg1.ConfigurableSwitchedResidualGenerator,
                                     switch_filters=[32,32,32,32],
                                     switch_growths=[16,16,16,16],
                                     switch_reductions=[4,3,2,1],
                                     switch_processing_layers=[3,3,4,5],
                                     trans_counts=[16,16,16,16,16],
                                     trans_kernel_sizes=[3,3,3,3,3],
                                     trans_layers=[3,3,3,3,3],
                                     trans_filters_mid=[24,24,24,24,24],
                                     initial_temp=10),
                   torch.randn(1, 3, 64, 64),
                   device='cuda')
    '''
    # Disabled: SRG3 configuration (positional args).
    '''
    test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator3,
                                     64, 16),
                   torch.randn(1, 3, 64, 64),
                   device='cuda')
    '''