forked from mrq/DL-Art-School
Fix nv_tacotron_dataset bug which incorrectly mapped filenames
dammit..
This commit is contained in:
parent
a2afb25e42
commit
690d7e86d3
|
@ -98,9 +98,11 @@ class TextMelCollate():
|
||||||
|
|
||||||
text_padded = torch.LongTensor(len(batch), max_input_len)
|
text_padded = torch.LongTensor(len(batch), max_input_len)
|
||||||
text_padded.zero_()
|
text_padded.zero_()
|
||||||
|
filenames = []
|
||||||
for i in range(len(ids_sorted_decreasing)):
|
for i in range(len(ids_sorted_decreasing)):
|
||||||
text = batch[ids_sorted_decreasing[i]][0]
|
text = batch[ids_sorted_decreasing[i]][0]
|
||||||
text_padded[i, :text.size(0)] = text
|
text_padded[i, :text.size(0)] = text
|
||||||
|
filenames.append(batch[ids_sorted_decreasing[i]][2])
|
||||||
|
|
||||||
# Right zero-pad mel-spec
|
# Right zero-pad mel-spec
|
||||||
num_mels = batch[0][1].size(0)
|
num_mels = batch[0][1].size(0)
|
||||||
|
@ -121,7 +123,6 @@ class TextMelCollate():
|
||||||
gate_padded[i, mel.size(1)-1:] = 1
|
gate_padded[i, mel.size(1)-1:] = 1
|
||||||
output_lengths[i] = mel.size(1)
|
output_lengths[i] = mel.size(1)
|
||||||
|
|
||||||
filenames = [j[2] for j in batch]
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'padded_text': text_padded,
|
'padded_text': text_padded,
|
||||||
|
|
|
@ -97,6 +97,9 @@ class GptTts(nn.Module):
|
||||||
|
|
||||||
return mel_seq
|
return mel_seq
|
||||||
|
|
||||||
|
def inference_beam(self, text_inputs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_gpt_tts(opt_net, opt):
|
def register_gpt_tts(opt_net, opt):
|
||||||
|
|
|
@ -63,6 +63,6 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
wavfiles = data['filenames']
|
wavfiles = data['filenames']
|
||||||
quantized = model.eval_state[opt['eval']['quantized_mels']][0]
|
quantized = model.eval_state[opt['eval']['quantized_mels']][0]
|
||||||
for i, wavfile in enumerate(wavfiles):
|
for i, filename in enumerate(wavfiles):
|
||||||
qmelfile = wavfile.replace('wavs/', 'quantized_mels/') + '.pth'
|
qmelfile = filename.replace('wavs/', 'quantized_mels/') + '.pth'
|
||||||
torch.save(quantized[i], os.path.join(outpath, qmelfile))
|
torch.save(quantized[i], os.path.join(outpath, qmelfile))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user