another fix
This commit is contained in:
parent
41809a6330
commit
8f28404645
|
@ -394,6 +394,7 @@ class Mel2Vec(nn.Module):
|
||||||
self.mask_time_length = mask_time_length
|
self.mask_time_length = mask_time_length
|
||||||
self.disable_custom_linear_init = disable_custom_linear_init
|
self.disable_custom_linear_init = disable_custom_linear_init
|
||||||
self.linear_init_scale = linear_init_scale
|
self.linear_init_scale = linear_init_scale
|
||||||
|
self.dim_reduction_mult = dim_reduction_multiplier
|
||||||
self.apply(self.init)
|
self.apply(self.init)
|
||||||
|
|
||||||
def init(self, module):
|
def init(self, module):
|
||||||
|
@ -630,7 +631,7 @@ class ContrastiveTrainingWrapper(nn.Module):
|
||||||
def forward(self, mel):
|
def forward(self, mel):
|
||||||
mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math.
|
mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math.
|
||||||
|
|
||||||
features_shape = (mel.shape[0], mel.shape[-1]//4)
|
features_shape = (mel.shape[0], mel.shape[-1]//self.m2v.dim_reduction_mult)
|
||||||
mask_time_indices = _compute_mask_indices(features_shape, self.mask_time_prob, self.mask_time_length)
|
mask_time_indices = _compute_mask_indices(features_shape, self.mask_time_prob, self.mask_time_length)
|
||||||
sampled_negative_indices = torch.tensor(_sample_negative_indices(features_shape, self.num_negatives, mask_time_indices=mask_time_indices), device=mel.device)
|
sampled_negative_indices = torch.tensor(_sample_negative_indices(features_shape, self.num_negatives, mask_time_indices=mask_time_indices), device=mel.device)
|
||||||
mask_time_indices = torch.tensor(mask_time_indices, device=mel.device)
|
mask_time_indices = torch.tensor(mask_time_indices, device=mel.device)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user