forked from mrq/DL-Art-School
lambda2
This commit is contained in:
parent
025a5867c4
commit
336f807c8e
|
@@ -191,6 +191,17 @@ class SwitchedConvHardRouting(nn.Module):
|
||||||
nn.BatchNorm2d(breadth),
|
nn.BatchNorm2d(breadth),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
Conv2d(breadth, breadth, 1, stride=self.stride))
|
Conv2d(breadth, breadth, 1, stride=self.stride))
|
||||||
|
elif coupler_mode == 'lambda2':
|
||||||
|
self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
|
||||||
|
nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in),
|
||||||
|
nn.ReLU(),
|
||||||
|
LambdaLayer(dim=coupler_dim_in, dim_out=coupler_dim_in, r=23, dim_k=16, heads=2, dim_u=1),
|
||||||
|
nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in),
|
||||||
|
nn.ReLU(),
|
||||||
|
LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
|
||||||
|
nn.GroupNorm(num_groups=2, num_channels=breadth),
|
||||||
|
nn.ReLU(),
|
||||||
|
Conv2d(breadth, breadth, 1, stride=self.stride))
|
||||||
else:
|
else:
|
||||||
self.coupler = None
|
self.coupler = None
|
||||||
self.gate = HardRoutingGate(breadth, hard_en=hard_en)
|
self.gate = HardRoutingGate(breadth, hard_en=hard_en)
|
||||||
|
@@ -240,7 +251,7 @@ class SwitchedConvHardRouting(nn.Module):
|
||||||
self.last_select = selector.detach().clone()
|
self.last_select = selector.detach().clone()
|
||||||
self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,self.breadth,1,1) == selector).float().argmax(dim=1)
|
self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,self.breadth,1,1) == selector).float().argmax(dim=1)
|
||||||
|
|
||||||
if True:
|
if False:
|
||||||
# This is a custom CUDA implementation which should be faster and less memory intensive (once completed).
|
# This is a custom CUDA implementation which should be faster and less memory intensive (once completed).
|
||||||
return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
|
return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
|
||||||
else:
|
else:
|
||||||
|
|
|
@@ -232,7 +232,7 @@ def register_vqvae3_hard_switch(opt_net, opt):
|
||||||
|
|
||||||
def performance_test():
|
def performance_test():
|
||||||
cfg = {
|
cfg = {
|
||||||
'mode': 'lambda',
|
'mode': 'lambda2',
|
||||||
'breadth': 8,
|
'breadth': 8,
|
||||||
'hard_enabled': True,
|
'hard_enabled': True,
|
||||||
'dropout': 0.4
|
'dropout': 0.4
|
||||||
|
|
Loading…
Reference in New Issue
Block a user