Add logic for performing inference using gpt_tts with dual-encoder modes
parent b43683b772
commit d3ace153af
@@ -90,11 +90,26 @@ class GptTts(nn.Module):
         if len(mel_seq) >= self.max_mel_frames:
             print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
 
-        # Prevent sending invalid tokens to the VAE. Also pad to a length of 3, which is required by the DVAE.
-        mel_seq = mel_seq * (mel_seq < 512)
-        padding_needed = 3-(mel_seq.shape[1]%3)
-        mel_seq = F.pad(mel_seq, (0,padding_needed))
-        return mel_seq
+        # Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE)
+        cleaned = []
+        for j in range(mel_seq.shape[0]):
+            s = mel_seq[j][1:-1]  # Strip out BOS and EOS tokens.
+            gt = s >= 512
+            l = len(s) // 3
+            for i in reversed(range(l)):
+                if gt[i]:
+                    l = i + 1
+                    break
+            top = s[:l]
+            top = top + (top < 512) * 512
+            bottom = s[l:l*3]
+            bottom = bottom * (bottom < 512)
+            combined = torch.cat([top, bottom], dim=0)
+            assert not torch.any(combined < 0)
+            combined = combined * (combined < 1024)
+            cleaned.append(combined)
+
+        return torch.stack(cleaned)
 
 
 @register_model
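For orientation, here is the new post-processing pulled out of the class as a standalone sketch. The 512/1024 token thresholds and the top/bottom split come straight from the diff; the function name, the toy BOS/EOS values, and the assumption that the downstream two-tiered DVAE consumes the top-tier codes (first third of the sequence) followed by the bottom-tier codes (remaining two thirds) are illustrative guesses, not confirmed by the rest of the repo.

import torch

def format_for_two_tier_dvae(mel_seq):
    """Hypothetical standalone version of the cleanup loop in the diff.

    mel_seq: (batch, length) LongTensor of GPT outputs, BOS/EOS included.
    Tokens >= 512 are treated as top-tier DVAE codes and tokens < 512 as
    bottom-tier codes, mirroring the thresholds above.
    """
    cleaned = []
    for j in range(mel_seq.shape[0]):
        s = mel_seq[j][1:-1]          # strip BOS and EOS
        gt = s >= 512                 # which positions hold top-tier codes
        l = len(s) // 3               # the top tier can occupy at most a third
        for i in reversed(range(l)):  # shrink l to the last top-tier code seen
            if gt[i]:
                l = i + 1
                break
        top = s[:l]
        top = top + (top < 512) * 512      # push stray bottom codes into the top range
        bottom = s[l:l * 3]
        bottom = bottom * (bottom < 512)   # zero out stray top codes in the bottom range
        combined = torch.cat([top, bottom], dim=0)
        assert not torch.any(combined < 0)
        combined = combined * (combined < 1024)  # zero anything outside the vocab
        cleaned.append(combined)
    # As in the diff, this assumes every batch item cleans to the same length;
    # torch.stack will raise otherwise.
    return torch.stack(cleaned)

# Toy run: 1000/1001 stand in for whatever BOS/EOS values the model uses.
seq = torch.tensor([[1000, 600, 700, 10, 20, 30, 40, 1001]])
print(format_for_two_tier_dvae(seq))  # tensor([[600, 700, 10, 20, 30, 40]])

The `(token < N)` multiplications are the same masking trick the diff relies on: rather than raising on an out-of-range token, they map it to code 0 so the DVAE always receives valid indices.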