Fix inference, always flow full text tokens through transformer

This commit is contained in:
James Betker 2021-08-07 20:11:10 -06:00
parent 4c678172d6
commit a2afb25e42
3 changed files with 9 additions and 22 deletions

View File

@ -57,7 +57,8 @@ class GptTtsCollater():
def __call__(self, batch):
text_lens = [len(x[0]) for x in batch]
max_text_len = max(text_lens)
#max_text_len = max(text_lens)
max_text_len = self.MAX_SYMBOLS_PER_PHRASE # This forces all outputs to have the full 200 characters. Testing if this makes a difference.
mel_lens = [len(x[1]) for x in batch]
max_mel_len = max(mel_lens)
texts = []

View File

@ -71,11 +71,12 @@ class GptTts(nn.Module):
return loss_text.mean(), loss_mel.mean(), mel_codes, mel_targets
def inference(self, text_inputs):
b, _ = text_inputs.shape
text_emb = self.text_embedding(text_inputs)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
mel_seq = torch.full((text_emb.shape[0],1), fill_value=self.MEL_START_TOKEN, device=text_emb.device)
stop_encountered = torch.zeros((text_emb.shape[0],), device=text_emb.device)
mel_seq = torch.full((b,1), fill_value=self.MEL_START_TOKEN, device=text_emb.device)
stop_encountered = torch.zeros((b,), device=text_emb.device)
while not torch.all(stop_encountered) and len(mel_seq) < self.max_mel_frames:
mel_emb = self.mel_embedding(mel_seq)
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
@ -91,25 +92,10 @@ class GptTts(nn.Module):
print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
# Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE)
cleaned = []
for j in range(mel_seq.shape[0]):
s = mel_seq[j][1:-1] # Strip out BOS and EOS tokens.
gt = s >= 512
l = (len(s)) // 3
for i in reversed(range(l)):
if gt[i]:
l = i+1
break
top = s[:l]
top = top + (top < 512) * 512
bottom = s[l:l*3]
bottom = bottom * (bottom < 512)
combined = torch.cat([top,bottom], dim=0)
assert not torch.any(combined < 0)
combined = combined * (combined < 1024)
cleaned.append(combined)
mel_seq = mel_seq[:, 1:-1] # Remove first and last tokens, which were artificially added for GPT
mel_seq = mel_seq * (mel_seq < 512) # The DVAE doesn't understand BOS/EOS/PAD tokens.
return torch.stack(cleaned)
return mel_seq
@register_model

View File

@ -54,7 +54,7 @@ if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
want_metrics = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt