make codebooks specifiable

This commit is contained in:
James Betker 2022-05-18 11:07:12 -06:00
parent efc2657b48
commit 10378fc37f

View File

@ -541,7 +541,9 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
class ContrastiveTrainingWrapper(nn.Module): class ContrastiveTrainingWrapper(nn.Module):
def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.5, mask_time_length=6, num_negatives=100, def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.5, mask_time_length=6, num_negatives=100,
max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995, **kwargs): max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995,
codebook_size=320, codebook_groups=2,
**kwargs):
super().__init__() super().__init__()
self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob, self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob,
mask_time_length=mask_time_length, **kwargs) mask_time_length=mask_time_length, **kwargs)
@ -551,7 +553,7 @@ class ContrastiveTrainingWrapper(nn.Module):
self.max_gumbel_temperature = max_gumbel_temperature self.max_gumbel_temperature = max_gumbel_temperature
self.min_gumbel_temperature = min_gumbel_temperature self.min_gumbel_temperature = min_gumbel_temperature
self.gumbel_temperature_decay = gumbel_temperature_decay self.gumbel_temperature_decay = gumbel_temperature_decay
self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim) self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim, num_codevector_groups=codebook_groups, num_codevectors_per_group=codebook_size)
self.num_losses_record = [] self.num_losses_record = []
# make sure that project_hid & project_q are initialized like normal linear layers # make sure that project_hid & project_q are initialized like normal linear layers