Fix nv_tacotron_dataset bug which incorrectly mapped filenames

dammit..
James Betker 2021-08-08 11:38:52 -06:00
parent a2afb25e42
commit 690d7e86d3
3 changed files with 7 additions and 3 deletions


@@ -98,9 +98,11 @@ class TextMelCollate():
         text_padded = torch.LongTensor(len(batch), max_input_len)
         text_padded.zero_()
+        filenames = []
         for i in range(len(ids_sorted_decreasing)):
             text = batch[ids_sorted_decreasing[i]][0]
             text_padded[i, :text.size(0)] = text
+            filenames.append(batch[ids_sorted_decreasing[i]][2])

         # Right zero-pad mel-spec
         num_mels = batch[0][1].size(0)
@@ -121,7 +123,6 @@ class TextMelCollate():
             gate_padded[i, mel.size(1)-1:] = 1
             output_lengths[i] = mel.size(1)
-        filenames = [j[2] for j in batch]
         return {
             'padded_text': text_padded,
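
To make the mapping bug concrete: the collater sorts the batch by text length before padding, so any list built from the unsorted batch no longer lines up with the padded tensors. Below is a minimal, self-contained sketch, not the repository's code; the toy batch, tensor shapes, and filenames are invented, but the (text, mel, filename) tuple layout and the sorting logic mirror the collater above.

import torch

# Toy batch of (text, mel, filename) tuples; contents are made up.
batch = [
    (torch.LongTensor([1, 2]),       torch.zeros(80, 5), 'wavs/a.wav'),
    (torch.LongTensor([1, 2, 3, 4]), torch.zeros(80, 9), 'wavs/b.wav'),
    (torch.LongTensor([1, 2, 3]),    torch.zeros(80, 7), 'wavs/c.wav'),
]

# Sort by text length, descending, as the collater does.
input_lengths, ids_sorted_decreasing = torch.sort(
    torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True)
max_input_len = int(input_lengths[0])

text_padded = torch.LongTensor(len(batch), max_input_len)
text_padded.zero_()
filenames = []
for i in range(len(ids_sorted_decreasing)):
    text = batch[ids_sorted_decreasing[i]][0]
    text_padded[i, :text.size(0)] = text
    filenames.append(batch[ids_sorted_decreasing[i]][2])  # same order as text_padded

# The removed line built the list in the original, unsorted batch order,
# so filenames[i] pointed at a different clip than text_padded[i]:
filenames_unsorted = [j[2] for j in batch]
print(filenames)           # ['wavs/b.wav', 'wavs/c.wav', 'wavs/a.wav']
print(filenames_unsorted)  # ['wavs/a.wav', 'wavs/b.wav', 'wavs/c.wav']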


@@ -97,6 +97,9 @@ class GptTts(nn.Module):
         return mel_seq

+    def inference_beam(self, text_inputs):
+        pass
+
 @register_model
 def register_gpt_tts(opt_net, opt):


@@ -63,6 +63,6 @@ if __name__ == "__main__":
     wavfiles = data['filenames']
     quantized = model.eval_state[opt['eval']['quantized_mels']][0]
-    for i, wavfile in enumerate(wavfiles):
-        qmelfile = wavfile.replace('wavs/', 'quantized_mels/') + '.pth'
+    for i, filename in enumerate(wavfiles):
+        qmelfile = filename.replace('wavs/', 'quantized_mels/') + '.pth'
         torch.save(quantized[i], os.path.join(outpath, qmelfile))
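
For context on why the ordering matters downstream: each quantized[i] is saved under filenames[i], so a misaligned filename list would write mel codes under the wrong clip names. Here is a hedged sketch of the path mapping this loop performs for a single entry; outpath, the example filename, and the dummy tensor are stand-ins, not values from the project.

import os
import torch

outpath = 'out'                   # stand-in output directory
filename = 'wavs/LJ001-0001.wav'  # made-up entry from data['filenames']

# Same substitution as above: wavs/<name>.wav -> quantized_mels/<name>.wav.pth
qmelfile = filename.replace('wavs/', 'quantized_mels/') + '.pth'
target = os.path.join(outpath, qmelfile)  # out/quantized_mels/LJ001-0001.wav.pth

os.makedirs(os.path.dirname(target), exist_ok=True)
torch.save(torch.zeros(10, dtype=torch.long), target)  # dummy quantized mel codes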