forked from mrq/DL-Art-School
Revert
This commit is contained in:
parent
6dec1f5968
commit
e7be4bdff3
|
@ -185,10 +185,10 @@ class SwitchedConvHardRouting(nn.Module):
|
||||||
self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride)
|
self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride)
|
||||||
elif coupler_mode == 'lambda':
|
elif coupler_mode == 'lambda':
|
||||||
self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
|
self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
|
||||||
nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in),
|
nn.BatchNorm2d(coupler_dim_in),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
|
LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
|
||||||
nn.GroupNorm(num_groups=1, num_channels=breadth),
|
nn.BatchNorm2d(breadth),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
Conv2d(breadth, breadth, 1, stride=self.stride))
|
Conv2d(breadth, breadth, 1, stride=self.stride))
|
||||||
elif coupler_mode == 'lambda2':
|
elif coupler_mode == 'lambda2':
|
||||||
|
|
Loading…
Reference in New Issue
Block a user