5f2c722a10
Big update to the SRG2 architecture to pull in a lot of things that have been learned:
- Use group norm instead of batch norm.
- Initialize the weights on the transformations low, as is done in RRDB, rather than using the scalar. Models live or die by their early stages, and this one's early stage is pretty weak.
- Switch the transform multiplexer to a U-net-like architecture.
- Just use one set of configuration variables instead of a list; flat networks performed fine in this regard.
132 lines · 5.6 KiB · Python
import functools

import torch
from torch import nn

import models.archs.SRG1_arch as srg1
import models.archs.SwitchedResidualGenerator_arch as srg
import models.archs.NestedSwitchGenerator as nsg


# Primitive module types that do not need their own stability measurements.
blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]
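
# The commit message above swaps batch norm for group norm and initializes the
# transformation weights low, in the style of RRDB, rather than gating them
# with a scalar. A minimal sketch of those two moves, for illustration only:
# the helper names, the 0.1 scale, and the group count of 8 are assumptions
# and are not used elsewhere in this file.
def _rrdb_style_low_init(module, scale=0.1):
    # Standard Kaiming init, scaled down so each transformation starts as a
    # near-identity residual and cannot blow up activations early in training.
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            m.weight.data *= scale
            if m.bias is not None:
                m.bias.data.zero_()


def _group_norm_for(channels):
    # GroupNorm computes statistics per-sample over channel groups, so it is
    # insensitive to batch size, unlike BatchNorm2d.
    return nn.GroupNorm(num_groups=8, num_channels=channels)
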

def install_forward_trace_hooks(module, id="base"):
    # Recursively registers a shape-recording forward hook on this module and
    # all of its children, skipping the blacklisted primitive types above.
    if type(module) in blacklisted_modules:
        return
    module.register_forward_hook(functools.partial(inject_input_shapes, mod_id=id))
    for name, m in module.named_children():
        cid = "%s:%s" % (id, name)
        install_forward_trace_hooks(m, cid)


def inject_input_shapes(module: nn.Module, inputs, outputs, mod_id: str):
    # Forward hook that stashes the shape of the module's input on the module
    # itself so it can be collected later by extract_input_shapes().
    if len(inputs) == 1 and isinstance(inputs[0], torch.Tensor):
        # Only single-tensor inputs are currently supported. TODO: fix.
        module._input_shape = inputs[0].shape


def extract_input_shapes(module, id="base"):
    # Walks the module tree and collects the shapes recorded by the hooks,
    # keyed by colon-separated paths through named_children(), e.g. "base:body".
    shapes = {}
    if hasattr(module, "_input_shape"):
        shapes[id] = module._input_shape
    for n, m in module.named_children():
        cid = "%s:%s" % (id, n)
        shapes.update(extract_input_shapes(m, cid))
    return shapes


def test_stability(mod_fn, dummy_inputs, device='cuda'):
    # Reports the per-submodule output mean and std across 20 fresh
    # instantiations of the model, giving a rough picture of numeric
    # stability at initialization.
    # First, trace one forward pass to discover every submodule's input shape.
    base_module = mod_fn().to(device)
    dummy_inputs = dummy_inputs.to(device)
    install_forward_trace_hooks(base_module)
    base_module(dummy_inputs)
    input_shapes = extract_input_shapes(base_module)

    means = {}
    stds = {}
    for _ in range(20):
        # Re-instantiate each round so the statistics span many random inits.
        mod = mod_fn().to(device)
        t_means, t_stds = test_stability_per_module(mod, input_shapes, device)
        for k in t_means.keys():
            if k not in means.keys():
                means[k] = []
                stds[k] = []
            means[k].extend(t_means[k])
            stds[k].extend(t_stds[k])

    for k in means.keys():
        print("%s - mean: %f std: %f" % (k, torch.mean(torch.stack(means[k])),
                                         torch.mean(torch.stack(stds[k]))))


def test_stability_per_module(mod: nn.Module, input_shapes: dict, device='cuda', id="base"):
    # Runs the numeric stability test on this module (if its input shape was
    # traced) and then recurses into its children.
    means = {}
    stds = {}
    if id in input_shapes.keys():
        shape = input_shapes[id]
        mean, std = test_numeric_stability(mod, shape, 1, device)
        means[id] = mean
        stds[id] = std
    for name, child in mod.named_children():
        cid = "%s:%s" % (id, name)
        m, s = test_stability_per_module(child, input_shapes, device=device, id=cid)
        means.update(m)
        stds.update(s)
    return means, stds


def test_numeric_stability(mod: nn.Module, shape, iterations=50, device='cuda'):
    # Repeatedly feeds the module's own output back into it and records the
    # mean and std after each pass. Assumes the module returns a tuple whose
    # first element is the output tensor.
    x = torch.randn(shape).to(device)
    means = []
    stds = []
    with torch.no_grad():
        for _ in range(iterations):
            x = mod(x)[0]
            means.append(torch.mean(x).detach())
            stds.append(torch.std(x).detach())
    return torch.stack(means), torch.stack(stds)

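
# Usage sketch for the harness above. The toy module is illustrative (it is
# not one of the repo archs); note that it returns a tuple so the mod(x)[0]
# convention in test_numeric_stability holds for the root module.
class _ToyGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Conv2d(3, 3, kernel_size=3, padding=1),
                                  nn.LeakyReLU())

    def forward(self, x):
        return (self.body(x),)

# Example invocation (device='cpu' so it runs without a GPU); prints one
# "id - mean: ... std: ..." line per traced submodule, e.g. "base" and
# "base:body":
#   test_stability(_ToyGenerator, torch.randn(1, 3, 32, 32), device='cpu')
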

# Constructor signature kept here for reference:
'''
def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
             trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
             heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False):
'''

if __name__ == "__main__":
    '''
    test_stability(functools.partial(nsg.NestedSwitchedGenerator,
                                     switch_filters=64,
                                     switch_reductions=[3,3,3,3,3],
                                     switch_processing_layers=[1,1,1,1,1],
                                     trans_counts=[3,3,3,3,3],
                                     trans_kernel_sizes=[3,3,3,3,3],
                                     trans_layers=[3,3,3,3,3],
                                     transformation_filters=64,
                                     initial_temp=10),
                   torch.randn(1, 3, 64, 64),
                   device='cuda')
    '''
    test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
                                     switch_depth=4,
                                     switch_filters=64,
                                     switch_reductions=4,
                                     switch_processing_layers=2,
                                     trans_counts=8,
                                     trans_kernel_sizes=3,
                                     trans_layers=4,
                                     transformation_filters=64,
                                     upsample_factor=4),
                   torch.randn(1, 3, 64, 64),
                   device='cuda')

    '''
    test_stability(functools.partial(srg1.ConfigurableSwitchedResidualGenerator,
                                     switch_filters=[32,32,32,32],
                                     switch_growths=[16,16,16,16],
                                     switch_reductions=[4,3,2,1],
                                     switch_processing_layers=[3,3,4,5],
                                     trans_counts=[16,16,16,16,16],
                                     trans_kernel_sizes=[3,3,3,3,3],
                                     trans_layers=[3,3,3,3,3],
                                     trans_filters_mid=[24,24,24,24,24],
                                     initial_temp=10),
                   torch.randn(1, 3, 64, 64),
                   device='cuda')
    '''
    '''
    test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator3,
                                     64, 16),
                   torch.randn(1, 3, 64, 64),
                   device='cuda')
    '''