From b2b37453dff130b9cef13f04b20ca23dd3a01570 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 17 May 2022 20:58:56 -0600 Subject: [PATCH] make the codebook bigger --- codes/models/audio/mel2vec.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index 827e347b..5901ba03 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -454,7 +454,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information. """ - def __init__(self, proj_dim=1024, codevector_dim=256, num_codevector_groups=2, num_codevectors_per_group=320): + def __init__(self, proj_dim=1024, codevector_dim=512, num_codevector_groups=2, num_codevectors_per_group=320): super().__init__() self.codevector_dim = codevector_dim self.num_groups = num_codevector_groups @@ -501,7 +501,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) if self.training: - # sample code vector probs via gumbel in differentiateable way + # sample code vector probs via gumbel in differentiable way codevector_probs = nn.functional.gumbel_softmax( hidden_states.float(), tau=self.temperature, hard=True ).type_as(hidden_states) @@ -513,7 +513,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices) else: # take argmax in non-differentiable way - # comptute hard codevector distribution (one hot) + # compute hard codevector distribution (one hot) codevector_idx = hidden_states.argmax(dim=-1) codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( -1, codevector_idx.view(-1, 1), 1.0