degrade gumbel softmax over time
This commit is contained in:
parent
3853f37257
commit
c1bdb4f9a1
|
@ -533,7 +533,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class ContrastiveTrainingWrapper(nn.Module):
|
class ContrastiveTrainingWrapper(nn.Module):
|
||||||
def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=4, num_negatives=100, **kwargs):
|
def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=4, num_negatives=100,
|
||||||
|
max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob,
|
self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob,
|
||||||
mask_time_length=mask_time_length, **kwargs)
|
mask_time_length=mask_time_length, **kwargs)
|
||||||
|
@ -541,6 +542,9 @@ class ContrastiveTrainingWrapper(nn.Module):
|
||||||
self.num_negatives = num_negatives
|
self.num_negatives = num_negatives
|
||||||
self.mask_time_prob = mask_time_prob
|
self.mask_time_prob = mask_time_prob
|
||||||
self.mask_time_length = mask_time_length
|
self.mask_time_length = mask_time_length
|
||||||
|
self.max_gumbel_temperature = max_gumbel_temperature
|
||||||
|
self.min_gumbel_temperature = min_gumbel_temperature
|
||||||
|
self.gumbel_temperature_decay = gumbel_temperature_decay
|
||||||
|
|
||||||
self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim)
|
self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim)
|
||||||
|
|
||||||
|
@ -569,7 +573,15 @@ class ContrastiveTrainingWrapper(nn.Module):
|
||||||
logits = logits / temperature
|
logits = logits / temperature
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
def update_for_step(self, step, *args):
|
||||||
|
self.quantizer.temperature = max(
|
||||||
|
self.max_gumbel_temperature * self.gumbel_temperature_decay**step,
|
||||||
|
self.min_gumbel_temperature,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, mel):
|
def forward(self, mel):
|
||||||
|
mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math.
|
||||||
|
|
||||||
features_shape = (mel.shape[0], mel.shape[-1]//4)
|
features_shape = (mel.shape[0], mel.shape[-1]//4)
|
||||||
mask_time_indices = _compute_mask_indices(features_shape, self.mask_time_prob, self.mask_time_length)
|
mask_time_indices = _compute_mask_indices(features_shape, self.mask_time_prob, self.mask_time_length)
|
||||||
sampled_negative_indices = torch.tensor(_sample_negative_indices(features_shape, self.num_negatives, mask_time_indices=mask_time_indices), device=mel.device)
|
sampled_negative_indices = torch.tensor(_sample_negative_indices(features_shape, self.num_negatives, mask_time_indices=mask_time_indices), device=mel.device)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user