Fix assertion error
This commit is contained in:
parent
8f92c0a088
commit
2636d3b620
|
@ -66,7 +66,7 @@ class SwitchComputer(nn.Module):
|
|||
tc = transform_count
|
||||
if self.enable_negative_transforms:
|
||||
tc = transform_count * 2
|
||||
assert filters > transform_count * 2
|
||||
assert filters > tc
|
||||
self.final_switch_conv = nn.Conv2d(filters, tc, 1, 1, 0)
|
||||
|
||||
self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
|
||||
|
|
Loading…
Reference in New Issue
Block a user