From 510b2f887d6f28acd42eb58620eb2edd359f09c2 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 3 Jul 2020 22:31:20 -0600 Subject: [PATCH] Remove RDB from srg2 Doesnt seem to work so great. --- .../archs/SwitchedResidualGenerator_arch.py | 5 ++--- codes/utils/numeric_stability.py | 16 ++++++++++------ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 47436bee..73ea56b5 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -313,9 +313,8 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers): multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count) switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=functools.partial(nn.Sequential, ResidualDenseBlock_5C(transformation_filters), - ResidualDenseBlock_5C(transformation_filters)), - transform_block=functools.partial(ResidualDenseBlock_5C, transformation_filters), + pre_transform_block=functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False), + transform_block=functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers), transform_count=trans_count, init_temp=initial_temp, enable_negative_transforms=enable_negative_transforms, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.01)) diff --git a/codes/utils/numeric_stability.py b/codes/utils/numeric_stability.py index 30f4c4c9..3e7855f9 100644 --- a/codes/utils/numeric_stability.py +++ b/codes/utils/numeric_stability.py @@ -4,7 +4,7 @@ import models.archs.SwitchedResidualGenerator_arch as srg import models.archs.NestedSwitchGenerator as nsg import functools -blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax, srg.Interpolate] +blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax] def install_forward_trace_hooks(module, id="base"): if type(module) in blacklisted_modules: return @@ -96,11 +96,15 @@ if __name__ == "__main__": torch.randn(1, 3, 64, 64), device='cuda') ''' - test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator3, - trans_counts=[8], - trans_kernel_sizes=[3], - trans_layers=[3], + test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2, + switch_filters=[16,16,16,16,16], + switch_growths=[32,32,32,32,32], + switch_reductions=[1,1,1,1,1], + switch_processing_layers=[5,5,5,5,5], + trans_counts=[8,8,8,8,8], + trans_kernel_sizes=[3,3,3,3,3], + trans_layers=[3,3,3,3,3], transformation_filters=64, initial_temp=10), - torch.randn(1, 3, 128, 128), + torch.randn(1, 3, 64, 64), device='cuda')