Add logic for performing inference using gpt_tts with dual-encoder modes
parent b43683b772
commit d3ace153af
@@ -90,11 +90,26 @@ class GptTts(nn.Module):
         if len(mel_seq) >= self.max_mel_frames:
             print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
 
-        # Prevent sending invalid tokens to the VAE. Also pad to a length of 3, which is required by the DVAE.
-        mel_seq = mel_seq * (mel_seq < 512)
-        padding_needed = 3-(mel_seq.shape[1]%3)
-        mel_seq = F.pad(mel_seq, (0,padding_needed))
-        return mel_seq
+        # Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE)
+        cleaned = []
+        for j in range(mel_seq.shape[0]):
+            s = mel_seq[j][1:-1] # Strip out BOS and EOS tokens.
+            gt = s >= 512
+            l = (len(s)) // 3
+            for i in reversed(range(l)):
+                if gt[i]:
+                    l = i+1
+                    break
+            top = s[:l]
+            top = top + (top < 512) * 512
+            bottom = s[l:l*3]
+            bottom = bottom * (bottom < 512)
+            combined = torch.cat([top,bottom], dim=0)
+            assert not torch.any(combined < 0)
+            combined = combined * (combined < 1024)
+            cleaned.append(combined)
+        return torch.stack(cleaned)
+
 
 
 @register_model
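In effect, the new loop splits the flat GPT output back into the two code tiers the two-tiered DVAE expects: top-level codes live in [512, 1024) and occupy roughly the first third of the sequence, while bottom-level codes live in [0, 512) and fill the remaining two thirds. Below is a self-contained, runnable sketch of that cleanup step under those assumptions; the function name split_two_tier_codes and the toy input are illustrative, not part of the commit.

import torch

def split_two_tier_codes(mel_seq: torch.Tensor) -> torch.Tensor:
    # Sketch (assumed, not from the repo) of the commit's cleanup step for a
    # batch of GPT token sequences shaped (batch, length), each wrapped in BOS/EOS.
    cleaned = []
    for j in range(mel_seq.shape[0]):
        s = mel_seq[j][1:-1]      # strip BOS and EOS
        gt = s >= 512             # True where a token is a top-tier code
        l = len(s) // 3           # nominal end of the top tier
        # Walk backwards from the nominal boundary to where the
        # top-tier codes actually stop.
        for i in reversed(range(l)):
            if gt[i]:
                l = i + 1
                break
        top = s[:l]
        top = top + (top < 512) * 512        # re-offset top codes into [512, 1024)
        bottom = s[l:l * 3]
        bottom = bottom * (bottom < 512)     # zero any stray top codes in the bottom tier
        combined = torch.cat([top, bottom], dim=0)
        assert not torch.any(combined < 0)
        combined = combined * (combined < 1024)  # zero anything outside the 1024-token vocab
        cleaned.append(combined)
    # As in the commit, torch.stack requires every cleaned row to end up the
    # same length, which effectively assumes batch-size-1 inference.
    return torch.stack(cleaned)

# Toy batch: BOS/EOS stand-ins at the ends, two top-tier codes
# (600, 700), then four bottom-tier codes.
seq = torch.tensor([[0, 600, 700, 10, 20, 30, 40, 0]])
print(split_two_tier_codes(seq))  # tensor([[600, 700, 10, 20, 30, 40]])

One design note: the backward scan exists because the GPT may stop emitting top-tier codes before the nominal one-third boundary, so the code trusts the last token actually in [512, 1024) to mark the real tier boundary rather than the fixed length heuristic.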