forked from mrq/DL-Art-School
Make the codebook bigger: raise the default codevector_dim of Wav2Vec2GumbelVectorQuantizer from 256 to 512
This commit is contained in:
parent
9a9c3cafba
commit
b2b37453df
|
@ -454,7 +454,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
||||||
GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
|
GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, proj_dim=1024, codevector_dim=256, num_codevector_groups=2, num_codevectors_per_group=320):
|
def __init__(self, proj_dim=1024, codevector_dim=512, num_codevector_groups=2, num_codevectors_per_group=320):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.codevector_dim = codevector_dim
|
self.codevector_dim = codevector_dim
|
||||||
self.num_groups = num_codevector_groups
|
self.num_groups = num_codevector_groups
|
||||||
|
@ -501,7 +501,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
||||||
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
|
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
|
||||||
|
|
||||||
if self.training:
|
if self.training:
|
||||||
# sample code vector probs via gumbel in differentiateable way
|
# sample code vector probs via gumbel in differentiable way
|
||||||
codevector_probs = nn.functional.gumbel_softmax(
|
codevector_probs = nn.functional.gumbel_softmax(
|
||||||
hidden_states.float(), tau=self.temperature, hard=True
|
hidden_states.float(), tau=self.temperature, hard=True
|
||||||
).type_as(hidden_states)
|
).type_as(hidden_states)
|
||||||
|
@ -513,7 +513,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
||||||
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
|
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
|
||||||
else:
|
else:
|
||||||
# take argmax in non-differentiable way
|
# take argmax in non-differentiable way
|
||||||
# comptute hard codevector distribution (one hot)
|
# compute hard codevector distribution (one hot)
|
||||||
codevector_idx = hidden_states.argmax(dim=-1)
|
codevector_idx = hidden_states.argmax(dim=-1)
|
||||||
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
|
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
|
||||||
-1, codevector_idx.view(-1, 1), 1.0
|
-1, codevector_idx.view(-1, 1), 1.0
|
||||||
|
|
Loading…
Reference in New Issue
Block a user