Remove RDB from srg2

Doesnt seem to work so great.
This commit is contained in:
James Betker 2020-07-03 22:31:20 -06:00
parent 77d3765364
commit 510b2f887d
2 changed files with 12 additions and 9 deletions

View File

@ -313,9 +313,8 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers): for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count) multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count)
switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn, switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=functools.partial(nn.Sequential, ResidualDenseBlock_5C(transformation_filters), pre_transform_block=functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False),
ResidualDenseBlock_5C(transformation_filters)), transform_block=functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers),
transform_block=functools.partial(ResidualDenseBlock_5C, transformation_filters),
transform_count=trans_count, init_temp=initial_temp, enable_negative_transforms=enable_negative_transforms, transform_count=trans_count, init_temp=initial_temp, enable_negative_transforms=enable_negative_transforms,
add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.01)) add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.01))

View File

@ -4,7 +4,7 @@ import models.archs.SwitchedResidualGenerator_arch as srg
import models.archs.NestedSwitchGenerator as nsg import models.archs.NestedSwitchGenerator as nsg
import functools import functools
blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax, srg.Interpolate] blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]
def install_forward_trace_hooks(module, id="base"): def install_forward_trace_hooks(module, id="base"):
if type(module) in blacklisted_modules: if type(module) in blacklisted_modules:
return return
@ -96,11 +96,15 @@ if __name__ == "__main__":
torch.randn(1, 3, 64, 64), torch.randn(1, 3, 64, 64),
device='cuda') device='cuda')
''' '''
test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator3, test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
trans_counts=[8], switch_filters=[16,16,16,16,16],
trans_kernel_sizes=[3], switch_growths=[32,32,32,32,32],
trans_layers=[3], switch_reductions=[1,1,1,1,1],
switch_processing_layers=[5,5,5,5,5],
trans_counts=[8,8,8,8,8],
trans_kernel_sizes=[3,3,3,3,3],
trans_layers=[3,3,3,3,3],
transformation_filters=64, transformation_filters=64,
initial_temp=10), initial_temp=10),
torch.randn(1, 3, 128, 128), torch.randn(1, 3, 64, 64),
device='cuda') device='cuda')