Bug fixes
This commit is contained in:
parent
3c31bea1ac
commit
d4d4f85fc0
|
@ -155,7 +155,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
||||||
switches = []
|
switches = []
|
||||||
self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False, lelu=False, bias=True)
|
self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False, lelu=False, bias=True)
|
||||||
self.sw_conv = ConvBnLelu(transformation_filters, transformation_filters, lelu=False, bias=True)
|
self.sw_conv = ConvBnLelu(transformation_filters, transformation_filters, lelu=False, bias=True)
|
||||||
self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, biasd=True)
|
self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
|
||||||
self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
|
self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
|
||||||
self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
|
self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
|
||||||
self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False, bias=True)
|
self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False, bias=True)
|
||||||
|
|
|
@ -97,6 +97,19 @@ if __name__ == "__main__":
|
||||||
torch.randn(1, 3, 64, 64),
|
torch.randn(1, 3, 64, 64),
|
||||||
device='cuda')
|
device='cuda')
|
||||||
'''
|
'''
|
||||||
|
test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
|
||||||
|
switch_filters=[32,32,32,32],
|
||||||
|
switch_growths=[16,16,16,16],
|
||||||
|
switch_reductions=[4,3,2,1],
|
||||||
|
switch_processing_layers=[3,3,4,5],
|
||||||
|
trans_counts=[16,16,16,16,16],
|
||||||
|
trans_kernel_sizes=[3,3,3,3,3],
|
||||||
|
trans_layers=[3,3,3,3,3],
|
||||||
|
transformation_filters=64,
|
||||||
|
initial_temp=10),
|
||||||
|
torch.randn(1, 3, 64, 64),
|
||||||
|
device='cuda')
|
||||||
|
'''
|
||||||
test_stability(functools.partial(srg1.ConfigurableSwitchedResidualGenerator,
|
test_stability(functools.partial(srg1.ConfigurableSwitchedResidualGenerator,
|
||||||
switch_filters=[32,32,32,32],
|
switch_filters=[32,32,32,32],
|
||||||
switch_growths=[16,16,16,16],
|
switch_growths=[16,16,16,16],
|
||||||
|
@ -109,3 +122,4 @@ if __name__ == "__main__":
|
||||||
initial_temp=10),
|
initial_temp=10),
|
||||||
torch.randn(1, 3, 64, 64),
|
torch.randn(1, 3, 64, 64),
|
||||||
device='cuda')
|
device='cuda')
|
||||||
|
'''
|
||||||
|
|
Loading…
Reference in New Issue
Block a user