small fix

This commit is contained in:
discus0434 2022-10-22 13:44:39 +00:00
parent 97749b7c7d
commit 6a4fa73a38

View File

@ -51,10 +51,9 @@ class HypernetworkModule(torch.nn.Module):
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
# Add dropout
if use_dropout:
p = 0.5 if 0 <= i <= len(layer_structure) - 3 else 0.2
linears.append(torch.nn.Dropout(p=p))
# Add dropout expect last layer
if use_dropout and i < len(layer_structure) - 3:
linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears)