From 8f2840464503eea76a234a2ba565527f1875c9a7 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 22 May 2022 21:32:43 -0600 Subject: [PATCH] another fix --- codes/models/audio/mel2vec.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index 42e6ba31..b654ba4c 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -394,6 +394,7 @@ class Mel2Vec(nn.Module): self.mask_time_length = mask_time_length self.disable_custom_linear_init = disable_custom_linear_init self.linear_init_scale = linear_init_scale + self.dim_reduction_mult = dim_reduction_multiplier self.apply(self.init) def init(self, module): @@ -630,7 +631,7 @@ class ContrastiveTrainingWrapper(nn.Module): def forward(self, mel): mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math. - features_shape = (mel.shape[0], mel.shape[-1]//4) + features_shape = (mel.shape[0], mel.shape[-1]//self.m2v.dim_reduction_mult) mask_time_indices = _compute_mask_indices(features_shape, self.mask_time_prob, self.mask_time_length) sampled_negative_indices = torch.tensor(_sample_negative_indices(features_shape, self.num_negatives, mask_time_indices=mask_time_indices), device=mel.device) mask_time_indices = torch.tensor(mask_time_indices, device=mel.device)