Add logic for performing inference using gpt_tts with dual-encoder modes

James Betker 2021-08-06 12:04:12 -06:00
parent b43683b772
commit d3ace153af


@@ -90,11 +90,26 @@ class GptTts(nn.Module):
         if len(mel_seq) >= self.max_mel_frames:
             print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
-        # Prevent sending invalid tokens to the VAE. Also pad to a multiple of 3, which the DVAE requires.
-        mel_seq = mel_seq * (mel_seq < 512)
-        padding_needed = 3 - (mel_seq.shape[1] % 3)
-        mel_seq = F.pad(mel_seq, (0, padding_needed))
-        return mel_seq
+        # Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE).
+        cleaned = []
+        for j in range(mel_seq.shape[0]):
+            s = mel_seq[j][1:-1]  # Strip out BOS and EOS tokens.
+            gt = s >= 512  # Tokens in [512, 1024) belong to the top tier.
+            l = len(s) // 3
+            # Walk back from the end of the first third to find the last top-tier token.
+            for i in reversed(range(l)):
+                if gt[i]:
+                    l = i + 1
+                    break
+            top = s[:l]
+            top = top + (top < 512) * 512  # Map any stray bottom-tier tokens into the top-tier range.
+            bottom = s[l:l*3]
+            bottom = bottom * (bottom < 512)  # Zero any stray top-tier tokens in the bottom sequence.
+            combined = torch.cat([top, bottom], dim=0)
+            assert not torch.any(combined < 0)
+            combined = combined * (combined < 1024)  # Zero anything outside the 1024-entry codebook.
+            cleaned.append(combined)
+        return torch.stack(cleaned)
 @register_model
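
For reference, the added cleanup can be read as a standalone function. The sketch below is one reading of the diff, not code from the repo: it assumes a 1024-entry codebook split into a bottom tier [0, 512) and a top tier [512, 1024), a 1:2 top-to-bottom tier ratio (so the top codes occupy the first third of the sequence), and a made-up sentinel id (1025) standing in for the BOS/EOS/stop tokens the real model emits.

import torch

def format_for_two_tier_dvae(mel_seq):
    # Mirrors the cleanup added in this commit: reorder a flat GPT token
    # sequence into [top codes | bottom codes] for a two-tiered DVAE.
    cleaned = []
    for j in range(mel_seq.shape[0]):
        s = mel_seq[j][1:-1]                      # strip BOS/EOS sentinels
        l = len(s) // 3                           # upper bound on the top-tier length
        gt = s >= 512                             # top-tier tokens live in [512, 1024)
        for i in reversed(range(l)):              # shrink l to the last real top-tier token
            if gt[i]:
                l = i + 1
                break
        top = s[:l] + (s[:l] < 512) * 512         # force top codes into [512, 1024)
        bottom = s[l:l * 3] * (s[l:l * 3] < 512)  # force bottom codes into [0, 512)
        combined = torch.cat([top, bottom], dim=0)
        cleaned.append(combined * (combined < 1024))  # zero out-of-vocabulary ids
    return torch.stack(cleaned)

# Toy batch: hypothetical sentinel id 1025, two top codes, four bottom codes.
seq = torch.tensor([[1025, 520, 530, 10, 20, 30, 40, 1025]])
print(format_for_two_tier_dvae(seq))  # tensor([[520, 530, 10, 20, 30, 40]])

The final multiplication presumably guards against sentinel ids at or above 1024 surviving the earlier steps; everything the tier logic itself produces already lands inside the codebook.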