Bug fixes

This commit is contained in:
James Betker 2020-07-06 22:25:40 -06:00
parent 3c31bea1ac
commit d4d4f85fc0
2 changed files with 15 additions and 1 deletions

View File

@@ -155,7 +155,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
switches = []
self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False, lelu=False, bias=True)
self.sw_conv = ConvBnLelu(transformation_filters, transformation_filters, lelu=False, bias=True)
self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, biasd=True)
self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False, bias=True)

View File

@@ -97,6 +97,19 @@ if __name__ == "__main__":
torch.randn(1, 3, 64, 64),
device='cuda')
'''
test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
switch_filters=[32,32,32,32],
switch_growths=[16,16,16,16],
switch_reductions=[4,3,2,1],
switch_processing_layers=[3,3,4,5],
trans_counts=[16,16,16,16,16],
trans_kernel_sizes=[3,3,3,3,3],
trans_layers=[3,3,3,3,3],
transformation_filters=64,
initial_temp=10),
torch.randn(1, 3, 64, 64),
device='cuda')
'''
test_stability(functools.partial(srg1.ConfigurableSwitchedResidualGenerator,
switch_filters=[32,32,32,32],
switch_growths=[16,16,16,16],
@@ -109,3 +122,4 @@ if __name__ == "__main__":
initial_temp=10),
torch.randn(1, 3, 64, 64),
device='cuda')
'''