From d3ace153af9945eed2b4007d1d7a99189ed17b05 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 6 Aug 2021 12:04:12 -0600
Subject: [PATCH] Add logic for performing inference using gpt_tts with
 dual-encoder modes

---
 codes/models/gpt_voice/gpt_tts.py | 25 ++++++++++++++++++++-----
 1 file changed, 20 insertions(+), 5 deletions(-)

diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py
index 70096c08..eb0b56ad 100644
--- a/codes/models/gpt_voice/gpt_tts.py
+++ b/codes/models/gpt_voice/gpt_tts.py
@@ -90,11 +90,26 @@ class GptTts(nn.Module):
             if len(mel_seq) >= self.max_mel_frames:
                 print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
 
-        # Prevent sending invalid tokens to the VAE. Also pad to a length of 3, which is required by the DVAE.
-        mel_seq = mel_seq * (mel_seq < 512)
-        padding_needed = 3-(mel_seq.shape[1]%3)
-        mel_seq = F.pad(mel_seq, (0,padding_needed))
-        return mel_seq
+        # Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE)
+        cleaned = []
+        for j in range(mel_seq.shape[0]):
+            s = mel_seq[j][1:-1]  # Strip out BOS and EOS tokens.
+            gt = s >= 512
+            l = (len(s)) // 3
+            for i in reversed(range(l)):
+                if gt[i]:
+                    l = i+1
+                    break
+            top = s[:l]
+            top = top + (top < 512) * 512
+            bottom = s[l:l*3]
+            bottom = bottom * (bottom < 512)
+            combined = torch.cat([top,bottom], dim=0)
+            assert not torch.any(combined < 0)
+            combined = combined * (combined < 1024)
+            cleaned.append(combined)
+
+        return torch.stack(cleaned)
 
 
 @register_model
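
Reviewer note: below is a rough, standalone sketch of the cleanup this patch
adds, written to make the implied token layout easier to check. It assumes the
two-tier convention the code suggests: the flat sequence handed to the DVAE is
one third top-tier codes remapped into [512, 1024), followed by two thirds
bottom-tier codes in [0, 512), all indices into a shared 1024-entry codebook.
The function name and the toy input are illustrative only, not part of the
repo's API.

import torch

def clean_dual_encoder_codes(seq: torch.Tensor) -> torch.Tensor:
    # seq: 1-D tensor of GPT-emitted codes with BOS/EOS already stripped,
    # mirroring `s = mel_seq[j][1:-1]` in the patch.
    gt = seq >= 512         # Positions already holding top-tier (offset) codes.
    l = len(seq) // 3       # Nominal boundary: top tier is the first third.
    # Walk back from the nominal boundary to just past the last position where
    # the GPT actually emitted a top-tier code, shrinking the top segment (and
    # therefore the total kept length, 3*l) if the model stopped early.
    for i in reversed(range(l)):
        if gt[i]:
            l = i + 1
            break
    top = seq[:l]
    top = top + (top < 512) * 512        # Push stray low codes into the top range.
    bottom = seq[l:l * 3]
    bottom = bottom * (bottom < 512)     # Zero stray high codes in the bottom range.
    combined = torch.cat([top, bottom], dim=0)
    assert not torch.any(combined < 0)
    return combined * (combined < 1024)  # Zero anything outside the codebook.

# Toy check: 9 codes = 3 top-tier positions + 6 bottom-tier positions. The
# stray 520 in the bottom segment is zeroed; the top segment passes through.
codes = torch.tensor([600, 700, 900, 5, 10, 520, 20, 30, 40])
print(clean_dual_encoder_codes(codes))
# tensor([600, 700, 900,   5,  10,   0,  20,  30,  40])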