add an option to avoid dying ReLU
parent dcb45dfecf
commit fccba4729d
@@ -32,7 +32,6 @@ class HypernetworkModule(torch.nn.Module):
         assert layer_structure is not None, "layer_structure must not be None"
         assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
         assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
-        assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"

         linears = []
         for i in range(len(layer_structure) - 1):
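Why the assert was dropped: in Python 3, self.activation_dict.keys() + "linear" raises a TypeError (dict_keys cannot be concatenated with a str), and the "not in" test was inverted besides, so this check could never pass for a valid function. The unsupported-function case is now handled at runtime by the else/raise branch in the next hunk. A corrected assert, had one been kept, might have looked like this (hypothetical sketch, not part of the commit):

    valid_funcs = list(self.activation_dict.keys()) + ["linear"]
    assert activation_func in valid_funcs, f"Valid activation funcs: {valid_funcs}"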
@@ -43,12 +42,13 @@ class HypernetworkModule(torch.nn.Module):
             # Add an activation func
             if activation_func == "linear" or activation_func is None:
                 pass
+            # If ReLU, Skip adding it to the first layer to avoid dying ReLU
+            elif activation_func == "relu" and i < 1:
+                pass
             elif activation_func in self.activation_dict:
                 linears.append(self.activation_dict[activation_func]())
             else:
-                raise RuntimeError(
-                    "Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
-                )
+                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')

             # Add dropout
             if use_dropout:
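A "dying" ReLU is a unit whose pre-activation becomes negative for every input, after which its output and gradient are permanently zero; the first layer is presumably the most exposed since it sees the raw, unnormalized inputs. The change above therefore skips the activation after the first linear layer (i < 1) when ReLU is selected, while keeping it on all later layers. A minimal standalone sketch of the resulting layer-building loop; build_linears, dim, and the 0.3 dropout rate are illustrative assumptions, not verbatim from the file:

    import torch

    activation_dict = {
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
    }

    def build_linears(dim, layer_structure, activation_func, use_dropout):
        linears = []
        for i in range(len(layer_structure) - 1):
            # Linear layer sized by the multiplier sequence
            linears.append(torch.nn.Linear(int(dim * layer_structure[i]),
                                           int(dim * layer_structure[i + 1])))
            # Add an activation func
            if activation_func == "linear" or activation_func is None:
                pass
            # If ReLU, skip adding it to the first layer to avoid dying ReLU
            elif activation_func == "relu" and i < 1:
                pass
            elif activation_func in activation_dict:
                linears.append(activation_dict[activation_func]())
            else:
                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
            # Add dropout (rate assumed for illustration)
            if use_dropout:
                linears.append(torch.nn.Dropout(p=0.3))
        return torch.nn.Sequential(*linears)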
@@ -166,8 +166,8 @@ class Hypernetwork:
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
                 )

         self.name = state_dict.get('name', self.name)
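Both call sites in Hypernetwork's state_dict-loading path now pass arguments in the constructor's new order and forward use_dropout, so modules rebuilt from a saved checkpoint match the training-time configuration. The implied signature, reconstructed from these calls (parameter order from the diff; the defaults are assumptions):

    # Assumed from the updated calls above, not verbatim from the file:
    # HypernetworkModule(dim, state_dict, layer_structure=None,
    #                    activation_func=None, add_layer_norm=False, use_dropout=False)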