Fix assertion error

This commit is contained in:
James Betker 2020-07-06 09:23:53 -06:00
parent 8f92c0a088
commit 2636d3b620

View File

@ -66,7 +66,7 @@ class SwitchComputer(nn.Module):
tc = transform_count tc = transform_count
if self.enable_negative_transforms: if self.enable_negative_transforms:
tc = transform_count * 2 tc = transform_count * 2
assert filters > transform_count * 2 assert filters > tc
self.final_switch_conv = nn.Conv2d(filters, tc, 1, 1, 0) self.final_switch_conv = nn.Conv2d(filters, tc, 1, 1, 0)
self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)]) self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])