From c1bdb4f9a1cd256a2ce51f9537bc1ea972a02551 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 17 May 2022 16:23:04 -0600 Subject: [PATCH] degrade gumbel softmax over time --- codes/models/audio/mel2vec.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index 3f4637a5..2ca71877 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -533,7 +533,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): class ContrastiveTrainingWrapper(nn.Module): - def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=4, num_negatives=100, **kwargs): + def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=4, num_negatives=100, + max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995, **kwargs): super().__init__() self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob, mask_time_length=mask_time_length, **kwargs) @@ -541,6 +542,9 @@ class ContrastiveTrainingWrapper(nn.Module): self.num_negatives = num_negatives self.mask_time_prob = mask_time_prob self.mask_time_length = mask_time_length + self.max_gumbel_temperature = max_gumbel_temperature + self.min_gumbel_temperature = min_gumbel_temperature + self.gumbel_temperature_decay = gumbel_temperature_decay self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim) @@ -569,7 +573,15 @@ class ContrastiveTrainingWrapper(nn.Module): logits = logits / temperature return logits + def update_for_step(self, step, *args): + self.quantizer.temperature = max( + self.max_gumbel_temperature * self.gumbel_temperature_decay**step, + self.min_gumbel_temperature, + ) + def forward(self, mel): + mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math. + features_shape = (mel.shape[0], mel.shape[-1]//4) mask_time_indices = _compute_mask_indices(features_shape, self.mask_time_prob, self.mask_time_length) sampled_negative_indices = torch.tensor(_sample_negative_indices(features_shape, self.num_negatives, mask_time_indices=mask_time_indices), device=mel.device)