fiddle with init

This commit is contained in:
James Betker 2022-05-18 10:56:01 -06:00
parent 208a703080
commit efc2657b48

View File

@ -313,7 +313,6 @@ class Wav2Vec2Encoder(nn.Module):
self,
hidden_states,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
):
all_hidden_states = () if output_hidden_states else None
@ -361,6 +360,8 @@ class Mel2Vec(nn.Module):
layerdrop=0,
mask_time_prob=.65,
mask_time_length=10,
disable_custom_linear_init=False,
linear_init_scale=.02,
):
super().__init__()
self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2),
@ -376,6 +377,8 @@ class Mel2Vec(nn.Module):
self.encoder = Wav2Vec2Encoder(inner_dim, dropout, layers, layerdrop)
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
self.disable_custom_linear_init = disable_custom_linear_init
self.linear_init_scale = linear_init_scale
self.apply(self.init)
def init(self, module):
@ -393,7 +396,9 @@ class Mel2Vec(nn.Module):
nn.init.uniform_(module.projection.weight, a=-k, b=k)
nn.init.uniform_(module.projection.bias, a=-k, b=k)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=.02)
if self.disable_custom_linear_init:
return
module.weight.data.normal_(mean=0.0, std=self.linear_init_scale)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
@ -535,7 +540,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
class ContrastiveTrainingWrapper(nn.Module):
def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=4, num_negatives=100,
def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.5, mask_time_length=6, num_negatives=100,
max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995, **kwargs):
super().__init__()
self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob,