diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index ff903da8..e02768af 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -541,7 +541,9 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): class ContrastiveTrainingWrapper(nn.Module): def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.5, mask_time_length=6, num_negatives=100, - max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995, **kwargs): + max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995, + codebook_size=320, codebook_groups=2, + **kwargs): super().__init__() self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob, mask_time_length=mask_time_length, **kwargs) @@ -551,7 +553,7 @@ class ContrastiveTrainingWrapper(nn.Module): self.max_gumbel_temperature = max_gumbel_temperature self.min_gumbel_temperature = min_gumbel_temperature self.gumbel_temperature_decay = gumbel_temperature_decay - self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim) + self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim, num_codevector_groups=codebook_groups, num_codevectors_per_group=codebook_size) self.num_losses_record = [] # make sure that project_hid & project_q are initialized like normal linear layers