forked from mrq/DL-Art-School
parent 40f844657b
commit 2270c89fdc
@@ -633,6 +633,7 @@ class ContrastiveTrainingWrapper(nn.Module):
     def forward(self, mel, inp_lengths=None):
         mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math.
+        features_shape = (mel.shape[0], mel.shape[-1]//self.m2v.dim_reduction_mult)

         # Frequency masking
         freq_mask_width = int(random.random() * self.freq_mask_percent * mel.shape[1])
@@ -641,13 +642,12 @@ class ContrastiveTrainingWrapper(nn.Module):
             mel[:, freq_start:freq_start+freq_mask_width] = 0

         # Build input masks from inp_lengths if possible.
-        attention_mask = torch.ones_like(mel).long()
+        attention_mask = torch.ones(features_shape, device=mel.device, dtype=torch.long)
         if inp_lengths is not None:
-            inp_lengths = inp_lengths // self.inp_length_factor
+            inp_lengths = inp_lengths // (self.inp_length_factor*self.m2v.dim_reduction_mult)
             for i, l in enumerate(inp_lengths):
                 attention_mask[i, l:] = 0

-        features_shape = (mel.shape[0], mel.shape[-1]//self.m2v.dim_reduction_mult)
         mask_time_indices = _compute_mask_indices(features_shape, self.mask_time_prob, self.mask_time_length, attention_mask=attention_mask)
         sampled_negative_indices = torch.tensor(_sample_negative_indices(features_shape, self.num_negatives, mask_time_indices=mask_time_indices), device=mel.device)
         mask_time_indices = torch.tensor(mask_time_indices, device=mel.device)
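
For context, a minimal standalone sketch of what the reworked mask construction does: the attention mask is now allocated at the reduced feature resolution (features_shape) rather than at mel resolution, and inp_lengths is additionally divided by dim_reduction_mult, so the mask lines up with the feature frames that _compute_mask_indices and _sample_negative_indices operate on. The sketch assumes those helpers behave like the functions of the same name in HuggingFace transformers' wav2vec2 module; dim_reduction_mult, inp_length_factor, and the masking hyperparameters below are illustrative stand-ins, not values taken from this repo.

# Hypothetical sketch of the post-change mask construction (not the repo's code).
# Helpers come from HuggingFace transformers' wav2vec2 module; all numeric
# values here are illustrative assumptions.
import torch
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    _compute_mask_indices,
    _sample_negative_indices,
)

dim_reduction_mult = 4    # assumed time-axis reduction of the m2v encoder
inp_length_factor = 256   # assumed samples-per-mel-frame factor

mel = torch.randn(2, 80, 1024)                # (batch, n_mels, frames), already trimmed
inp_lengths = torch.tensor([262144, 131072])  # valid audio lengths in samples

# Shape of the encoder's output sequence: one entry per *feature* frame.
features_shape = (mel.shape[0], mel.shape[-1] // dim_reduction_mult)

# The attention mask is built at feature resolution so it matches features_shape.
attention_mask = torch.ones(features_shape, device=mel.device, dtype=torch.long)
feat_lengths = inp_lengths // (inp_length_factor * dim_reduction_mult)
for i, l in enumerate(feat_lengths):
    attention_mask[i, l:] = 0  # zero out padding frames past each sample's true length

# Sample the spans to mask and the negatives for the contrastive loss.
mask_time_indices = _compute_mask_indices(
    features_shape, mask_prob=0.65, mask_length=10, attention_mask=attention_mask
)
sampled_negative_indices = torch.tensor(
    _sample_negative_indices(features_shape, 100, mask_time_indices=mask_time_indices),
    device=mel.device,
)
mask_time_indices = torch.tensor(mask_time_indices, device=mel.device)

The only functional difference from the pre-change code is visible in the diff itself: the mask used to be shaped like the mel tensor (torch.ones_like(mel)) and the lengths were scaled only by inp_length_factor, which no longer matched the features_shape that the masking helpers consume.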