# NOTE: file-viewer header left over from extraction: 143 lines, 5.3 KiB, Python.
import functools
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
import dlas.models.image_generation.discriminator_vgg_arch as disc
|
|
|
|
# Layer types that install_forward_trace_hooks refuses to hook or descend into.
# The check there compares the exact type, so subclasses of these are still traced.
blacklisted_modules = [
    nn.Conv2d,
    nn.ReLU,
    nn.LeakyReLU,
    nn.BatchNorm2d,
    nn.Softmax,
]
|
|
|
|
|
|
def install_forward_trace_hooks(module, id="base"):
    """Recursively attach the input-shape recording hook to a module tree.

    Every module whose exact type is not listed in ``blacklisted_modules``
    gets a forward hook bound to ``inject_input_shapes``. A blacklisted
    module is skipped together with everything beneath it.

    Args:
        module: Root of the (sub)tree to instrument.
        id: Identifier of ``module``; children are identified as
            ``"<id>:<child name>"``.
    """
    if type(module) not in blacklisted_modules:
        shape_hook = functools.partial(inject_input_shapes, mod_id=id)
        module.register_forward_hook(shape_hook)
        for child_name, child in module.named_children():
            install_forward_trace_hooks(child, "%s:%s" % (id, child_name))
|
|
|
|
|
|
def inject_input_shapes(module: nn.Module, inputs, outputs, mod_id: str):
    """Forward hook that remembers the shape of a module's sole tensor input.

    When ``inputs`` holds exactly one element and that element is a
    ``torch.Tensor``, its shape is stored on the module as ``_input_shape``.
    Any other call leaves the module untouched.

    Args:
        module: Module the hook fired on.
        inputs: Positional inputs passed to the module's forward call.
        outputs: Forward result (not used).
        mod_id: Identifier bound when the hook was installed (not used here).
    """
    # TODO: only a single tensor input is supported for now.
    if len(inputs) != 1:
        return
    sole_input = inputs[0]
    if not isinstance(sole_input, torch.Tensor):
        return
    module._input_shape = sole_input.shape
|
|
|
|
|
|
def extract_input_shapes(module, id="base"):
    """Gather the ``_input_shape`` values recorded across a module tree.

    Args:
        module: Root of the (sub)tree to read.
        id: Identifier of ``module``; children are identified as
            ``"<id>:<child name>"``.

    Returns:
        A dict mapping module identifier to recorded input shape. Modules
        without an ``_input_shape`` attribute contribute no entry, but their
        children are still visited.
    """
    collected = {id: module._input_shape} if hasattr(module, "_input_shape") else {}
    for child_name, child in module.named_children():
        collected.update(extract_input_shapes(child, "%s:%s" % (id, child_name)))
    return collected
|
|
|
|
|
|
def test_stability(mod_fn, dummy_inputs, device='cuda'):
    """Print averaged output statistics for every traced submodule.

    One module built by ``mod_fn`` is instrumented with forward-trace hooks
    and run once on ``dummy_inputs`` to learn each submodule's input shape.
    Then 20 fresh modules are built, each is measured with
    ``test_stability_per_module``, and the collected output means and
    standard deviations are averaged per module id and printed.

    Args:
        mod_fn: Zero-argument callable returning a new module.
        dummy_inputs: Tensor fed to the traced module to record input shapes.
        device: Device the modules and inputs are moved to.
    """
    # Tracing pass: discover the input shape seen by each submodule.
    traced = mod_fn().to(device)
    install_forward_trace_hooks(traced)
    traced(dummy_inputs.to(device))
    shapes_by_id = extract_input_shapes(traced)

    # Measurement passes: a new module per trial.
    mean_log = {}
    std_log = {}
    for _ in range(20):
        fresh = mod_fn().to(device)
        trial_means, trial_stds = test_stability_per_module(fresh, shapes_by_id, device)
        for key, trial_mean in trial_means.items():
            mean_log.setdefault(key, []).extend(trial_mean)
            std_log.setdefault(key, []).extend(trial_stds[key])

    for key, mean_samples in mean_log.items():
        avg_mean = torch.mean(torch.stack(mean_samples))
        avg_std = torch.mean(torch.stack(std_log[key]))
        print("%s - mean: %f std: %f" % (key, avg_mean, avg_std))
|
|
|
|
|
|
def test_stability_per_module(mod: nn.Module, input_shapes: dict, device='cuda', id="base"):
    """Measure output statistics for a module and, recursively, its children.

    A module is measured (one iteration of ``test_numeric_stability``) only
    when its identifier has an entry in ``input_shapes``. All children are
    visited regardless.

    Args:
        mod: Module to measure.
        input_shapes: Mapping of module identifier to the input shape to use.
        device: Device passed through to ``test_numeric_stability``.
        id: Identifier of ``mod``; children are identified as
            ``"<id>:<child name>"``.

    Returns:
        Two dicts ``(means, stds)`` keyed by module identifier.
    """
    means, stds = {}, {}
    if id in input_shapes:
        means[id], stds[id] = test_numeric_stability(mod, input_shapes[id], 1, device)
    for child_name, child in mod.named_children():
        child_means, child_stds = test_stability_per_module(
            child, input_shapes, device=device, id="%s:%s" % (id, child_name))
        means.update(child_means)
        stds.update(child_stds)
    return means, stds
|
|
|
|
|
|
def test_numeric_stability(mod: nn.Module, format, iterations=50, device='cuda'):
    """Repeatedly feed a module its own output and record output statistics.

    A random normal tensor of shape ``format`` is created and moved to
    ``device``. On each iteration ``mod`` is applied (without gradient
    tracking), the result becomes the next input, and its mean and standard
    deviation are recorded.

    Args:
        mod: Callable module under test.
        format: Shape of the random input tensor.
        iterations: Number of times ``mod`` is applied.
        device: Device the input tensor is moved to.

    Returns:
        Two 1-D tensors ``(means, stds)`` with one entry per iteration.
    """
    x = torch.randn(format).to(device)
    means = []
    stds = []
    with torch.no_grad():
        for _ in range(iterations):
            out = mod(x)
            # Multi-output modules: follow the first output. A bare tensor is
            # used whole; indexing it with [0] would strip dim 0, so only the
            # first sample would be measured and the next iteration would
            # receive a tensor of the wrong shape.
            x = out[0] if isinstance(out, (tuple, list)) else out
            means.append(torch.mean(x))
            stds.append(torch.std(x))
    return torch.stack(means), torch.stack(stds)


# The ``test_`` prefix matches pytest's default function-name pattern, but this
# helper takes required arguments that are not fixtures; opt out of collection.
test_numeric_stability.__test__ = False
|
|
|
|
|
|
if __name__ == "__main__":
    # The triple-quoted blocks below are inert string literals holding earlier
    # experiment invocations. NOTE(review): they reference nsg / srg / srg1,
    # which are not imported in the visible part of this file.
    '''
    test_stability(functools.partial(nsg.NestedSwitchedGenerator,
                                     switch_filters=64,
                                     switch_reductions=[3,3,3,3,3],
                                     switch_processing_layers=[1,1,1,1,1],
                                     trans_counts=[3,3,3,3,3],
                                     trans_kernel_sizes=[3,3,3,3,3],
                                     trans_layers=[3,3,3,3,3],
                                     transformation_filters=64,
                                     initial_temp=10),
                   torch.randn(1, 3, 64, 64),
                   device='cuda')
    '''
    '''
    test_stability(functools.partial(srg.DualOutputSRG,
                                     switch_depth=4,
                                     switch_filters=64,
                                     switch_reductions=4,
                                     switch_processing_layers=2,
                                     trans_counts=8,
                                     trans_kernel_sizes=3,
                                     trans_layers=4,
                                     transformation_filters=64,
                                     upsample_factor=4),
                   torch.randn(1, 3, 32, 32),
                   device='cpu')
    '''
    '''
    test_stability(functools.partial(srg1.ConfigurableSwitchedResidualGenerator,
                                     switch_filters=[32,32,32,32],
                                     switch_growths=[16,16,16,16],
                                     switch_reductions=[4,3,2,1],
                                     switch_processing_layers=[3,3,4,5],
                                     trans_counts=[16,16,16,16,16],
                                     trans_kernel_sizes=[3,3,3,3,3],
                                     trans_layers=[3,3,3,3,3],
                                     trans_filters_mid=[24,24,24,24,24],
                                     initial_temp=10),
                   torch.randn(1, 3, 64, 64),
                   device='cuda')
    '''
    '''
    test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator3,
                                     64, 16),
                   torch.randn(1, 3, 64, 64),
                   device='cuda')
    '''
    # Active experiment: the UNet feature-output discriminator, run on CPU.
    make_discriminator = functools.partial(disc.Discriminator_UNet_FeaOut, 3, 64)
    sample_batch = torch.randn(1, 3, 128, 128)
    test_stability(make_discriminator, sample_batch, device='cpu')
|