diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index 3b11bac2..d10baf9a 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -641,7 +641,7 @@ class ContrastiveTrainingWrapper(nn.Module): mel[:, freq_start:freq_start+freq_mask_width] = 0 # Build input masks from inp_lengths if possible. - attention_mask = torch.ones_like(mel) + attention_mask = torch.ones_like(mel).long() if inp_lengths is not None: inp_lengths = inp_lengths // self.inp_length_factor for i, l in enumerate(inp_lengths):