diff --git a/codes/utils/numeric_stability.py b/codes/utils/numeric_stability.py
new file mode 100644
index 00000000..3e7855f9
--- /dev/null
+++ b/codes/utils/numeric_stability.py
@@ -0,0 +1,110 @@
+import functools
+
+import torch
+from torch import nn
+
+import models.archs.NestedSwitchGenerator as nsg
+import models.archs.SwitchedResidualGenerator_arch as srg
+
+# Leaf modules that are not interesting to probe on their own; no hooks are
+# installed on them.
+blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]
+
+def install_forward_trace_hooks(module, id="base"):
+    # Recursively register forward hooks that record the input shape seen by
+    # every non-blacklisted submodule.
+    if type(module) in blacklisted_modules:
+        return
+    module.register_forward_hook(functools.partial(inject_input_shapes, mod_id=id))
+    for name, m in module.named_children():
+        cid = "%s:%s" % (id, name)
+        install_forward_trace_hooks(m, cid)
+
+def inject_input_shapes(module: nn.Module, inputs, outputs, mod_id: str):
+    if len(inputs) == 1 and isinstance(inputs[0], torch.Tensor):
+        # Only single-tensor inputs are currently supported. TODO: fix.
+        module._input_shape = inputs[0].shape
+
+def extract_input_shapes(module, id="base"):
+    # Walk the module tree and collect the shapes recorded by the hooks above,
+    # keyed by the colon-separated module path.
+    shapes = {}
+    if hasattr(module, "_input_shape"):
+        shapes[id] = module._input_shape
+    for n, m in module.named_children():
+        cid = "%s:%s" % (id, n)
+        shapes.update(extract_input_shapes(m, cid))
+    return shapes
+
+def test_stability(mod_fn, dummy_inputs, device='cuda'):
+    # Build one instance to discover per-module input shapes, then instantiate
+    # 20 fresh copies and report the mean/std of each submodule's activations
+    # when fed random noise of the recorded shape.
+    base_module = mod_fn().to(device)
+    dummy_inputs = dummy_inputs.to(device)
+    install_forward_trace_hooks(base_module)
+    base_module(dummy_inputs)
+    input_shapes = extract_input_shapes(base_module)
+
+    means = {}
+    stds = {}
+    for i in range(20):
+        mod = mod_fn().to(device)
+        t_means, t_stds = test_stability_per_module(mod, input_shapes, device)
+        for k in t_means.keys():
+            if k not in means.keys():
+                means[k] = []
+                stds[k] = []
+            means[k].extend(t_means[k])
+            stds[k].extend(t_stds[k])
+
+    for k in means.keys():
+        print("%s - mean: %f std: %f" % (k, torch.mean(torch.stack(means[k])),
+                                         torch.mean(torch.stack(stds[k]))))
+
+def test_stability_per_module(mod: nn.Module, input_shapes: dict, device='cuda', id="base"):
+    means = {}
+    stds = {}
+    if id in input_shapes.keys():
+        format = input_shapes[id]
+        mean, std = test_numeric_stability(mod, format, 1, device)
+        means[id] = mean
+        stds[id] = std
+    for name, child in mod.named_children():
+        cid = "%s:%s" % (id, name)
+        m, s = test_stability_per_module(child, input_shapes, device=device, id=cid)
+        means.update(m)
+        stds.update(s)
+    return means, stds
+
+def test_numeric_stability(mod: nn.Module, format, iterations=50, device='cuda'):
+    # Repeatedly feed the module's own (first) output back into it and record
+    # the mean/std of the activations at every step.
+    x = torch.randn(format).to(device)
+    means = []
+    stds = []
+    with torch.no_grad():
+        for i in range(iterations):
+            x = mod(x)[0]
+            measure = x
+            means.append(torch.mean(measure).detach())
+            stds.append(torch.std(measure).detach())
+    return torch.stack(means), torch.stack(stds)
+
+# Reference: NestedSwitchedGenerator constructor signature, kept for the
+# commented-out invocation below.
+#     def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
+#                  trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
+#                  heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False):
+if __name__ == "__main__":
+    '''
+    test_stability(functools.partial(nsg.NestedSwitchedGenerator,
+                                     switch_filters=64,
+                                     switch_reductions=[3,3,3,3,3],
+                                     switch_processing_layers=[1,1,1,1,1],
+                                     trans_counts=[3,3,3,3,3],
+                                     trans_kernel_sizes=[3,3,3,3,3],
+                                     trans_layers=[3,3,3,3,3],
+                                     transformation_filters=64,
+                                     initial_temp=10),
+                   torch.randn(1, 3, 64, 64),
+                   device='cuda')
+    '''
+    test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
+                                     switch_filters=[16,16,16,16,16],
+                                     switch_growths=[32,32,32,32,32],
+                                     switch_reductions=[1,1,1,1,1],
+                                     switch_processing_layers=[5,5,5,5,5],
+                                     trans_counts=[8,8,8,8,8],
+                                     trans_kernel_sizes=[3,3,3,3,3],
+                                     trans_layers=[3,3,3,3,3],
+                                     transformation_filters=64,
+                                     initial_temp=10),
+                   torch.randn(1, 3, 64, 64),
+                   device='cuda')
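+
+    # A minimal usage sketch, commented out like the nsg test above. It is not
+    # part of the original test set-up: ToyTupleGenerator is purely illustrative
+    # and assumes nothing beyond torch itself. It returns its output as a
+    # 1-tuple so the mod(x)[0] feedback loop in test_numeric_stability keeps a
+    # constant shape across iterations.
+    '''
+    class ToyTupleGenerator(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.body = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),
+                                      nn.LeakyReLU(),
+                                      nn.Conv2d(16, 3, 3, padding=1))
+
+        def forward(self, x):
+            return (self.body(x),)
+
+    test_stability(ToyTupleGenerator, torch.randn(1, 3, 64, 64), device='cuda')
+    '''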