Fix assertion error
This commit is contained in:
parent
8f92c0a088
commit
2636d3b620
|
@ -66,7 +66,7 @@ class SwitchComputer(nn.Module):
|
||||||
tc = transform_count
|
tc = transform_count
|
||||||
if self.enable_negative_transforms:
|
if self.enable_negative_transforms:
|
||||||
tc = transform_count * 2
|
tc = transform_count * 2
|
||||||
assert filters > transform_count * 2
|
assert filters > tc
|
||||||
self.final_switch_conv = nn.Conv2d(filters, tc, 1, 1, 0)
|
self.final_switch_conv = nn.Conv2d(filters, tc, 1, 1, 0)
|
||||||
|
|
||||||
self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
|
self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user