forked from mrq/DL-Art-School
make the codebook bigger
This commit is contained in:
parent
9a9c3cafba
commit
b2b37453df
|
@ -454,7 +454,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
|
||||
"""
|
||||
|
||||
def __init__(self, proj_dim=1024, codevector_dim=256, num_codevector_groups=2, num_codevectors_per_group=320):
|
||||
def __init__(self, proj_dim=1024, codevector_dim=512, num_codevector_groups=2, num_codevectors_per_group=320):
|
||||
super().__init__()
|
||||
self.codevector_dim = codevector_dim
|
||||
self.num_groups = num_codevector_groups
|
||||
|
@ -501,7 +501,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
|
||||
|
||||
if self.training:
|
||||
# sample code vector probs via gumbel in differentiateable way
|
||||
# sample code vector probs via gumbel in differentiable way
|
||||
codevector_probs = nn.functional.gumbel_softmax(
|
||||
hidden_states.float(), tau=self.temperature, hard=True
|
||||
).type_as(hidden_states)
|
||||
|
@ -513,7 +513,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
|
||||
else:
|
||||
# take argmax in non-differentiable way
|
||||
# comptute hard codevector distribution (one hot)
|
||||
# compute hard codevector distribution (one hot)
|
||||
codevector_idx = hidden_states.argmax(dim=-1)
|
||||
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
|
||||
-1, codevector_idx.view(-1, 1), 1.0
|
||||
|
|
Loading…
Reference in New Issue
Block a user