Fix up max lengths to save memory

James Betker 2021-08-15 21:29:28 -06:00
parent 9e47e64d5a
commit 729c1fd5a9


@@ -66,6 +66,9 @@ class GptSegmentor(nn.Module):
         self.stop_head = nn.Linear(model_dim, 1)

     def forward(self, mel_inputs, mel_lengths):
+        max_len = mel_lengths.max()  # This can be done in the dataset layer, but it is easier to do here.
+        mel_inputs = mel_inputs[:, :, :max_len]
         mel_emb = self.mel_encoder(mel_inputs)
         mel_lengths = mel_lengths // 4  # The encoder decimates the mel by a factor of 4.
         mel_emb = mel_emb.permute(0,2,1).contiguous()
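
For context (not part of the commit), a minimal sketch of the memory-saving trick the diff applies: when a dataset pads every mel spectrogram out to a fixed cap, most batches carry long runs of trailing padding, so slicing each batch down to its own longest sample before encoding shrinks the activations the encoder has to produce. The shapes and values below are illustrative assumptions, not taken from the repository.

import torch

# Hypothetical batch: 4 mel spectrograms padded to a dataset-wide cap of
# 1000 frames, although the longest clip in this batch is only 400 frames.
mel_inputs = torch.randn(4, 80, 1000)           # (batch, n_mel_channels, frames)
mel_lengths = torch.tensor([250, 400, 120, 310])

# The trick from the commit: trim the padding to this batch's max length.
max_len = mel_lengths.max()
mel_inputs = mel_inputs[:, :, :max_len]         # now (4, 80, 400)

print(mel_inputs.shape)  # torch.Size([4, 80, 400])

Doing this in the model (rather than the dataset layer, as the added comment notes) keeps the dataset's collation simple while still ensuring downstream layers never see more padding than the current batch requires.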