forked from mrq/DL-Art-School
fiddle with init
This commit is contained in:
parent
208a703080
commit
efc2657b48
|
@@ -313,7 +313,6 @@ class Wav2Vec2Encoder(nn.Module):
|
|||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
@@ -361,6 +360,8 @@ class Mel2Vec(nn.Module):
|
|||
layerdrop=0,
|
||||
mask_time_prob=.65,
|
||||
mask_time_length=10,
|
||||
disable_custom_linear_init=False,
|
||||
linear_init_scale=.02,
|
||||
):
|
||||
super().__init__()
|
||||
self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2),
|
||||
|
@@ -376,6 +377,8 @@ class Mel2Vec(nn.Module):
|
|||
self.encoder = Wav2Vec2Encoder(inner_dim, dropout, layers, layerdrop)
|
||||
self.mask_time_prob = mask_time_prob
|
||||
self.mask_time_length = mask_time_length
|
||||
self.disable_custom_linear_init = disable_custom_linear_init
|
||||
self.linear_init_scale = linear_init_scale
|
||||
self.apply(self.init)
|
||||
|
||||
def init(self, module):
|
||||
|
@@ -393,7 +396,9 @@ class Mel2Vec(nn.Module):
|
|||
nn.init.uniform_(module.projection.weight, a=-k, b=k)
|
||||
nn.init.uniform_(module.projection.bias, a=-k, b=k)
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=.02)
|
||||
if self.disable_custom_linear_init:
|
||||
return
|
||||
module.weight.data.normal_(mean=0.0, std=self.linear_init_scale)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
|
||||
|
@@ -535,7 +540,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
|
||||
|
||||
class ContrastiveTrainingWrapper(nn.Module):
|
||||
def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=4, num_negatives=100,
|
||||
def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.5, mask_time_length=6, num_negatives=100,
|
||||
max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995, **kwargs):
|
||||
super().__init__()
|
||||
self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob,
|
||||
|
|
Loading…
Reference in New Issue
Block a user