Fix up max lengths to save memory
This commit is contained in:
parent
9e47e64d5a
commit
729c1fd5a9
|
@ -66,6 +66,9 @@ class GptSegmentor(nn.Module):
|
|||
self.stop_head = nn.Linear(model_dim, 1)
|
||||
|
||||
def forward(self, mel_inputs, mel_lengths):
|
||||
max_len = mel_lengths.max() # This can be done in the dataset layer, but it is easier to do here.
|
||||
mel_inputs = mel_inputs[:, :, :max_len]
|
||||
|
||||
mel_emb = self.mel_encoder(mel_inputs)
|
||||
mel_lengths = mel_lengths // 4 # The encoder decimates the mel by a factor of 4.
|
||||
mel_emb = mel_emb.permute(0,2,1).contiguous()
|
||||
|
|
Loading…
Reference in New Issue
Block a user