Remove RDB from srg2
Doesnt seem to work so great.
This commit is contained in:
parent
77d3765364
commit
510b2f887d
|
@ -313,9 +313,8 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
|||
for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
|
||||
multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count)
|
||||
switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=functools.partial(nn.Sequential, ResidualDenseBlock_5C(transformation_filters),
|
||||
ResidualDenseBlock_5C(transformation_filters)),
|
||||
transform_block=functools.partial(ResidualDenseBlock_5C, transformation_filters),
|
||||
pre_transform_block=functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False),
|
||||
transform_block=functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers),
|
||||
transform_count=trans_count, init_temp=initial_temp, enable_negative_transforms=enable_negative_transforms,
|
||||
add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.01))
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import models.archs.SwitchedResidualGenerator_arch as srg
|
|||
import models.archs.NestedSwitchGenerator as nsg
|
||||
import functools
|
||||
|
||||
blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax, srg.Interpolate]
|
||||
blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]
|
||||
def install_forward_trace_hooks(module, id="base"):
|
||||
if type(module) in blacklisted_modules:
|
||||
return
|
||||
|
@ -96,11 +96,15 @@ if __name__ == "__main__":
|
|||
torch.randn(1, 3, 64, 64),
|
||||
device='cuda')
|
||||
'''
|
||||
test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator3,
|
||||
trans_counts=[8],
|
||||
trans_kernel_sizes=[3],
|
||||
trans_layers=[3],
|
||||
test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
|
||||
switch_filters=[16,16,16,16,16],
|
||||
switch_growths=[32,32,32,32,32],
|
||||
switch_reductions=[1,1,1,1,1],
|
||||
switch_processing_layers=[5,5,5,5,5],
|
||||
trans_counts=[8,8,8,8,8],
|
||||
trans_kernel_sizes=[3,3,3,3,3],
|
||||
trans_layers=[3,3,3,3,3],
|
||||
transformation_filters=64,
|
||||
initial_temp=10),
|
||||
torch.randn(1, 3, 128, 128),
|
||||
torch.randn(1, 3, 64, 64),
|
||||
device='cuda')
|
||||
|
|
Loading…
Reference in New Issue
Block a user