Add numeric stability computation script
This commit is contained in:
parent
c0bb123504
commit
ee6443ad7d
110
codes/utils/numeric_stability.py
Normal file
110
codes/utils/numeric_stability.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import models.archs.SwitchedResidualGenerator_arch as srg
|
||||
import models.archs.NestedSwitchGenerator as nsg
|
||||
import functools
|
||||
|
||||
blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]
|
||||
def install_forward_trace_hooks(module, id="base"):
|
||||
if type(module) in blacklisted_modules:
|
||||
return
|
||||
module.register_forward_hook(functools.partial(inject_input_shapes, mod_id=id))
|
||||
for name, m in module.named_children():
|
||||
cid = "%s:%s" % (id, name)
|
||||
install_forward_trace_hooks(m, cid)
|
||||
|
||||
def inject_input_shapes(module: nn.Module, inputs, outputs, mod_id: str):
|
||||
if len(inputs) == 1 and isinstance(inputs[0], torch.Tensor):
|
||||
# Only single tensor inputs currently supported. TODO: fix.
|
||||
module._input_shape = inputs[0].shape
|
||||
|
||||
def extract_input_shapes(module, id="base"):
|
||||
shapes = {}
|
||||
if hasattr(module, "_input_shape"):
|
||||
shapes[id] = module._input_shape
|
||||
for n, m in module.named_children():
|
||||
cid = "%s:%s" % (id, n)
|
||||
shapes.update(extract_input_shapes(m, cid))
|
||||
return shapes
|
||||
|
||||
def test_stability(mod_fn, dummy_inputs, device='cuda'):
|
||||
base_module = mod_fn().to(device)
|
||||
dummy_inputs = dummy_inputs.to(device)
|
||||
install_forward_trace_hooks(base_module)
|
||||
base_module(dummy_inputs)
|
||||
input_shapes = extract_input_shapes(base_module)
|
||||
|
||||
means = {}
|
||||
stds = {}
|
||||
for i in range(20):
|
||||
mod = mod_fn().to(device)
|
||||
t_means, t_stds = test_stability_per_module(mod, input_shapes, device)
|
||||
for k in t_means.keys():
|
||||
if k not in means.keys():
|
||||
means[k] = []
|
||||
stds[k] = []
|
||||
means[k].extend(t_means[k])
|
||||
stds[k].extend(t_stds[k])
|
||||
|
||||
for k in means.keys():
|
||||
print("%s - mean: %f std: %f" % (k, torch.mean(torch.stack(means[k])),
|
||||
torch.mean(torch.stack(stds[k]))))
|
||||
|
||||
def test_stability_per_module(mod: nn.Module, input_shapes: dict, device='cuda', id="base"):
|
||||
means = {}
|
||||
stds = {}
|
||||
if id in input_shapes.keys():
|
||||
format = input_shapes[id]
|
||||
mean, std = test_numeric_stability(mod, format, 1, device)
|
||||
means[id] = mean
|
||||
stds[id] = std
|
||||
for name, child in mod.named_children():
|
||||
cid = "%s:%s" % (id, name)
|
||||
m, s = test_stability_per_module(child, input_shapes, device=device, id=cid)
|
||||
means.update(m)
|
||||
stds.update(s)
|
||||
return means, stds
|
||||
|
||||
def test_numeric_stability(mod: nn.Module, format, iterations=50, device='cuda'):
|
||||
x = torch.randn(format).to(device)
|
||||
means = []
|
||||
stds = []
|
||||
with torch.no_grad():
|
||||
for i in range(iterations):
|
||||
x = mod(x)[0]
|
||||
measure = x
|
||||
means.append(torch.mean(measure).detach())
|
||||
stds.append(torch.std(measure).detach())
|
||||
return torch.stack(means), torch.stack(stds)
|
||||
|
||||
'''
|
||||
def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
|
||||
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
|
||||
heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False):
|
||||
'''
|
||||
if __name__ == "__main__":
|
||||
'''
|
||||
test_stability(functools.partial(nsg.NestedSwitchedGenerator,
|
||||
switch_filters=64,
|
||||
switch_reductions=[3,3,3,3,3],
|
||||
switch_processing_layers=[1,1,1,1,1],
|
||||
trans_counts=[3,3,3,3,3],
|
||||
trans_kernel_sizes=[3,3,3,3,3],
|
||||
trans_layers=[3,3,3,3,3],
|
||||
transformation_filters=64,
|
||||
initial_temp=10),
|
||||
torch.randn(1, 3, 64, 64),
|
||||
device='cuda')
|
||||
'''
|
||||
test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
|
||||
switch_filters=[16,16,16,16,16],
|
||||
switch_growths=[32,32,32,32,32],
|
||||
switch_reductions=[1,1,1,1,1],
|
||||
switch_processing_layers=[5,5,5,5,5],
|
||||
trans_counts=[8,8,8,8,8],
|
||||
trans_kernel_sizes=[3,3,3,3,3],
|
||||
trans_layers=[3,3,3,3,3],
|
||||
transformation_filters=64,
|
||||
initial_temp=10),
|
||||
torch.randn(1, 3, 64, 64),
|
||||
device='cuda')
|
Loading…
Reference in New Issue
Block a user