Remove RDB from srg2
Doesn't seem to work so great.
parent 77d3765364
commit 510b2f887d
@@ -313,9 +313,8 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
             multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count)
             switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
-                                                       pre_transform_block=functools.partial(nn.Sequential, ResidualDenseBlock_5C(transformation_filters),
-                                                                                             ResidualDenseBlock_5C(transformation_filters)),
-                                                       transform_block=functools.partial(ResidualDenseBlock_5C, transformation_filters),
+                                                       pre_transform_block=functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False),
+                                                       transform_block=functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers),
                                                        transform_count=trans_count, init_temp=initial_temp, enable_negative_transforms=enable_negative_transforms,
                                                        add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.01))
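In both versions, ConfigurableSwitchComputer receives pre_transform_block and transform_block as functools.partial factories rather than built modules, so it can construct trans_count independent transforms itself. (The old pre_transform_block wrapped already-instantiated ResidualDenseBlock_5C objects inside the partial, so every product of that factory would have shared the same RDB modules; the new version passes only constructors.) Below is a minimal, self-contained sketch of that factory pattern. SimpleTransform and SwitchStub are hypothetical stand-ins for the repo's MultiConvBlock and ConfigurableSwitchComputer, not their actual implementations.

import functools
import torch
import torch.nn as nn

# Hypothetical stand-in for a transform block such as MultiConvBlock:
# a small stack of 3x3 convolutions parameterized by depth.
class SimpleTransform(nn.Module):
    def __init__(self, filters, kernel_size=3, depth=3):
        super().__init__()
        pad = kernel_size // 2
        self.body = nn.Sequential(*[nn.Conv2d(filters, filters, kernel_size, padding=pad)
                                    for _ in range(depth)])

    def forward(self, x):
        return self.body(x)

# Hypothetical stand-in for ConfigurableSwitchComputer: it receives a
# *factory* (a functools.partial), not a module, so it can build
# transform_count independent transforms on its own.
class SwitchStub(nn.Module):
    def __init__(self, transform_block, transform_count):
        super().__init__()
        self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])

    def forward(self, x):
        # Naive average over the transforms; the real switch weights them
        # with a learned, temperature-controlled multiplexer.
        return torch.stack([t(x) for t in self.transforms], dim=0).mean(dim=0)

transform_fn = functools.partial(SimpleTransform, 64, kernel_size=3, depth=3)
switch = SwitchStub(transform_fn, transform_count=8)
out = switch(torch.randn(1, 64, 16, 16))
print(out.shape)  # torch.Size([1, 64, 16, 16])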
@@ -4,7 +4,7 @@ import models.archs.SwitchedResidualGenerator_arch as srg
 import models.archs.NestedSwitchGenerator as nsg
 import functools
 
-blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax, srg.Interpolate]
+blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]
 def install_forward_trace_hooks(module, id="base"):
     if type(module) in blacklisted_modules:
         return
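This hunk drops srg.Interpolate from the blacklist used by install_forward_trace_hooks. The sketch below shows how such a hook installer can be written with PyTorch's register_forward_hook, recursively skipping the blacklisted leaf types; only the function name, the blacklist, and the early return are taken from the diff, the rest of the body is an assumption about how the tracer works.

import torch
import torch.nn as nn

# Types whose activations we do not trace (mirrors the diff's new list).
blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]

trace = {}  # id -> output shape recorded during the forward pass

def install_forward_trace_hooks(module, id="base"):
    # Skip modules we consider uninteresting, exactly as in the diff.
    if type(module) in blacklisted_modules:
        return
    def hook(mod, inputs, output, id=id):
        if torch.is_tensor(output):
            trace[id] = tuple(output.shape)
    module.register_forward_hook(hook)
    # Recurse into children so nested submodules get traced too.
    for name, child in module.named_children():
        install_forward_trace_hooks(child, id=f"{id}.{name}")

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 3, 3, padding=1))
install_forward_trace_hooks(model)
model(torch.randn(1, 3, 32, 32))
print(trace)  # {'base': (1, 3, 32, 32)} -- the Conv2d/ReLU children were skipped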
@@ -96,11 +96,15 @@ if __name__ == "__main__":
                    torch.randn(1, 3, 64, 64),
                    device='cuda')
     '''
-    test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator3,
-                                     trans_counts=[8],
-                                     trans_kernel_sizes=[3],
-                                     trans_layers=[3],
+    test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
+                                     switch_filters=[16,16,16,16,16],
+                                     switch_growths=[32,32,32,32,32],
+                                     switch_reductions=[1,1,1,1,1],
+                                     switch_processing_layers=[5,5,5,5,5],
+                                     trans_counts=[8,8,8,8,8],
+                                     trans_kernel_sizes=[3,3,3,3,3],
+                                     trans_layers=[3,3,3,3,3],
+                                     transformation_filters=64,
                                      initial_temp=10),
-                   torch.randn(1, 3, 128, 128),
+                   torch.randn(1, 3, 64, 64),
                    device='cuda')
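The smoke test now exercises ConfigurableSwitchedResidualGenerator2 with five switch stages on a 64x64 input. The harness below is a hypothetical sketch of what a test_stability-style check could do, assuming it only needs to build the network from the supplied factory and verify that a forward/backward pass stays finite; only the call signature (factory, input tensor, device='cuda') comes from the diff.

import torch

def test_stability(model_factory, test_input, device='cuda'):
    # Hypothetical harness body (the repo's test_stability may differ):
    # build the network from its factory, run forward and backward, and
    # check that nothing blows up into NaN/Inf. Pass device='cpu' when
    # no GPU is available.
    model = model_factory().to(device)
    x = test_input.to(device).requires_grad_(True)
    out = model(x)
    out = out[0] if isinstance(out, (tuple, list)) else out  # some generators return tuples
    assert torch.isfinite(out).all(), "forward pass produced NaN/Inf"
    out.mean().backward()
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    assert all(torch.isfinite(g).all() for g in grads), "gradients contain NaN/Inf"
    print("stability check passed, output shape:", tuple(out.shape))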