make codebooks specifiable
This commit is contained in:
parent
efc2657b48
commit
10378fc37f
|
@ -541,7 +541,9 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
||||||
|
|
||||||
class ContrastiveTrainingWrapper(nn.Module):
|
class ContrastiveTrainingWrapper(nn.Module):
|
||||||
def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.5, mask_time_length=6, num_negatives=100,
|
def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.5, mask_time_length=6, num_negatives=100,
|
||||||
max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995, **kwargs):
|
max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995,
|
||||||
|
codebook_size=320, codebook_groups=2,
|
||||||
|
**kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob,
|
self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob,
|
||||||
mask_time_length=mask_time_length, **kwargs)
|
mask_time_length=mask_time_length, **kwargs)
|
||||||
|
@ -551,7 +553,7 @@ class ContrastiveTrainingWrapper(nn.Module):
|
||||||
self.max_gumbel_temperature = max_gumbel_temperature
|
self.max_gumbel_temperature = max_gumbel_temperature
|
||||||
self.min_gumbel_temperature = min_gumbel_temperature
|
self.min_gumbel_temperature = min_gumbel_temperature
|
||||||
self.gumbel_temperature_decay = gumbel_temperature_decay
|
self.gumbel_temperature_decay = gumbel_temperature_decay
|
||||||
self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim)
|
self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim, num_codevector_groups=codebook_groups, num_codevectors_per_group=codebook_size)
|
||||||
self.num_losses_record = []
|
self.num_losses_record = []
|
||||||
|
|
||||||
# make sure that project_hid & project_q are initialized like normal linear layers
|
# make sure that project_hid & project_q are initialized like normal linear layers
|
||||||
|
|
Loading…
Reference in New Issue
Block a user