diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py
index 5be289d5..ff903da8 100644
--- a/codes/models/audio/mel2vec.py
+++ b/codes/models/audio/mel2vec.py
@@ -313,7 +313,6 @@ class Wav2Vec2Encoder(nn.Module):
         self,
         hidden_states,
         attention_mask=None,
-        output_attentions=False,
         output_hidden_states=False,
     ):
         all_hidden_states = () if output_hidden_states else None
@@ -361,6 +360,8 @@ class Mel2Vec(nn.Module):
                  layerdrop=0,
                  mask_time_prob=.65,
                  mask_time_length=10,
+                 disable_custom_linear_init=False,
+                 linear_init_scale=.02,
                  ):
         super().__init__()
         self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2),
@@ -376,6 +377,8 @@ class Mel2Vec(nn.Module):
         self.encoder = Wav2Vec2Encoder(inner_dim, dropout, layers, layerdrop)
         self.mask_time_prob = mask_time_prob
         self.mask_time_length = mask_time_length
+        self.disable_custom_linear_init = disable_custom_linear_init
+        self.linear_init_scale = linear_init_scale
         self.apply(self.init)

     def init(self, module):
@@ -393,7 +396,9 @@ class Mel2Vec(nn.Module):
             nn.init.uniform_(module.projection.weight, a=-k, b=k)
             nn.init.uniform_(module.projection.bias, a=-k, b=k)
         elif isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=.02)
+            if self.disable_custom_linear_init:
+                return
+            module.weight.data.normal_(mean=0.0, std=self.linear_init_scale)
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
@@ -535,7 +540,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):

 class ContrastiveTrainingWrapper(nn.Module):
-    def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=4, num_negatives=100,
+    def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.5, mask_time_length=6, num_negatives=100,
                  max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995,
                  **kwargs):
         super().__init__()
         self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob,
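
Reviewer note (not part of the patch): a minimal sketch of how the two new Mel2Vec constructor flags could be exercised, assuming the repo's codes/ directory is on the import path and that the remaining constructor arguments keep their defaults.

# Illustrative only, not applied by this diff. Assumes codes/ is on sys.path.
from models.audio.mel2vec import Mel2Vec

# Skip the custom normal_(std=.02) init and keep PyTorch's default nn.Linear initialization.
m2v_default_init = Mel2Vec(inner_dim=1024, disable_custom_linear_init=True)

# Keep the custom init but change the std of the normal distribution used for nn.Linear weights.
m2v_wider_init = Mel2Vec(inner_dim=1024, linear_init_scale=.05)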