make the codebook bigger

This commit is contained in:
James Betker 2022-05-17 20:58:56 -06:00
parent 9a9c3cafba
commit b2b37453df

View File

@@ -454,7 +454,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
"""
-def __init__(self, proj_dim=1024, codevector_dim=256, num_codevector_groups=2, num_codevectors_per_group=320):
+def __init__(self, proj_dim=1024, codevector_dim=512, num_codevector_groups=2, num_codevectors_per_group=320):
super().__init__()
self.codevector_dim = codevector_dim
self.num_groups = num_codevector_groups
@@ -501,7 +501,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
if self.training:
-# sample code vector probs via gumbel in differentiateable way
+# sample code vector probs via gumbel in differentiable way
codevector_probs = nn.functional.gumbel_softmax(
hidden_states.float(), tau=self.temperature, hard=True
).type_as(hidden_states)
@@ -513,7 +513,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
else:
# take argmax in non-differentiable way
-# comptute hard codevector distribution (one hot)
+# compute hard codevector distribution (one hot)
codevector_idx = hidden_states.argmax(dim=-1)
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
-1, codevector_idx.view(-1, 1), 1.0