From 690d7e86d3d852e21735db09666ab7fbcd7f4e31 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 8 Aug 2021 11:38:52 -0600 Subject: [PATCH] Fix nv_tacotron_dataset bug which incorrectly mapped filenames dammit.. --- codes/data/audio/nv_tacotron_dataset.py | 3 ++- codes/models/gpt_voice/gpt_tts.py | 3 +++ codes/scripts/audio/generate_quantized_mels.py | 4 ++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/codes/data/audio/nv_tacotron_dataset.py b/codes/data/audio/nv_tacotron_dataset.py index 0f0c7432..21c4b6ba 100644 --- a/codes/data/audio/nv_tacotron_dataset.py +++ b/codes/data/audio/nv_tacotron_dataset.py @@ -98,9 +98,11 @@ class TextMelCollate(): text_padded = torch.LongTensor(len(batch), max_input_len) text_padded.zero_() + filenames = [] for i in range(len(ids_sorted_decreasing)): text = batch[ids_sorted_decreasing[i]][0] text_padded[i, :text.size(0)] = text + filenames.append(batch[ids_sorted_decreasing[i]][2]) # Right zero-pad mel-spec num_mels = batch[0][1].size(0) @@ -121,7 +123,6 @@ class TextMelCollate(): gate_padded[i, mel.size(1)-1:] = 1 output_lengths[i] = mel.size(1) - filenames = [j[2] for j in batch] return { 'padded_text': text_padded, diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py index a6e6cd2f..96ad2ce4 100644 --- a/codes/models/gpt_voice/gpt_tts.py +++ b/codes/models/gpt_voice/gpt_tts.py @@ -97,6 +97,9 @@ class GptTts(nn.Module): return mel_seq + def inference_beam(self, text_inputs): + pass + @register_model def register_gpt_tts(opt_net, opt): diff --git a/codes/scripts/audio/generate_quantized_mels.py b/codes/scripts/audio/generate_quantized_mels.py index 6adf7ea1..b331ef12 100644 --- a/codes/scripts/audio/generate_quantized_mels.py +++ b/codes/scripts/audio/generate_quantized_mels.py @@ -63,6 +63,6 @@ if __name__ == "__main__": wavfiles = data['filenames'] quantized = model.eval_state[opt['eval']['quantized_mels']][0] - for i, wavfile in enumerate(wavfiles): - qmelfile = wavfile.replace('wavs/', 'quantized_mels/') + '.pth' + for i, filename in enumerate(wavfiles): + qmelfile = filename.replace('wavs/', 'quantized_mels/') + '.pth' torch.save(quantized[i], os.path.join(outpath, qmelfile))