James Betker 2022-05-23 08:27:54 -06:00
parent 9f16b25ce5
commit 40f844657b

@@ -641,7 +641,7 @@ class ContrastiveTrainingWrapper(nn.Module):
         mel[:, freq_start:freq_start+freq_mask_width] = 0
         # Build input masks from inp_lengths if possible.
-        attention_mask = torch.ones_like(mel)
+        attention_mask = torch.ones_like(mel).long()
         if inp_lengths is not None:
             inp_lengths = inp_lengths // self.inp_length_factor
             for i, l in enumerate(inp_lengths):
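
For context, a minimal standalone sketch of the masking behavior this hunk touches: casting with .long() yields an integer attention mask rather than a float one, which is the dtype transformer-style attention_mask arguments typically expect. The tensor shapes, the inp_length_factor value, and the loop body below are assumptions for illustration, not taken from the repository.

import torch

# Hypothetical batch of 2 sequences, 10 frames each.
mel = torch.randn(2, 10)
inp_lengths = torch.tensor([8, 4])
inp_length_factor = 1  # stand-in for self.inp_length_factor

# .long() casts the all-ones mask from float to int64.
attention_mask = torch.ones_like(mel).long()
if inp_lengths is not None:
    inp_lengths = inp_lengths // inp_length_factor
    for i, l in enumerate(inp_lengths):
        # Hypothetical continuation: zero out positions past each sequence length.
        attention_mask[i, l:] = 0

print(attention_mask.dtype)  # torch.int64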