Fix inference, always flow full text tokens through transformer
parent 4c678172d6
commit a2afb25e42
@@ -57,7 +57,8 @@ class GptTtsCollater():
     def __call__(self, batch):
         text_lens = [len(x[0]) for x in batch]
-        max_text_len = max(text_lens)
+        #max_text_len = max(text_lens)
+        max_text_len = self.MAX_SYMBOLS_PER_PHRASE # This forces all outputs to have the full 200 characters. Testing if this makes a difference.
         mel_lens = [len(x[1]) for x in batch]
         max_mel_len = max(mel_lens)
         texts = []
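The collater now pads every phrase to the fixed MAX_SYMBOLS_PER_PHRASE length instead of the longest phrase in the batch, so the transformer always sees the full 200 text-token positions. A minimal sketch of what that padding amounts to, assuming right-padding and a pad id of 0 (neither is confirmed by the diff):

import torch

MAX_SYMBOLS_PER_PHRASE = 200  # fixed length the collater now pads to
PAD_TOKEN = 0                 # assumed pad id; the real collater may use another value

def pad_texts(token_lists):
    # Right-pad every tokenized phrase to the fixed maximum length.
    padded = torch.full((len(token_lists), MAX_SYMBOLS_PER_PHRASE), PAD_TOKEN, dtype=torch.long)
    for i, tokens in enumerate(token_lists):
        padded[i, :len(tokens)] = torch.tensor(tokens, dtype=torch.long)
    return padded

# Every batch comes out as (B, 200) regardless of the longest phrase in it.
batch = pad_texts([[5, 12, 9], [7, 3, 3, 1, 8]])
assert batch.shape == (2, MAX_SYMBOLS_PER_PHRASE)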
@@ -71,11 +71,12 @@ class GptTts(nn.Module):
         return loss_text.mean(), loss_mel.mean(), mel_codes, mel_targets
 
     def inference(self, text_inputs):
+        b, _ = text_inputs.shape
         text_emb = self.text_embedding(text_inputs)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
 
-        mel_seq = torch.full((text_emb.shape[0],1), fill_value=self.MEL_START_TOKEN, device=text_emb.device)
-        stop_encountered = torch.zeros((text_emb.shape[0],), device=text_emb.device)
+        mel_seq = torch.full((b,1), fill_value=self.MEL_START_TOKEN, device=text_emb.device)
+        stop_encountered = torch.zeros((b,), device=text_emb.device)
         while not torch.all(stop_encountered) and len(mel_seq) < self.max_mel_frames:
             mel_emb = self.mel_embedding(mel_seq)
             mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
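The loop itself is a standard greedy autoregressive decode: start each sequence with MEL_START_TOKEN and append one predicted code per step until every item in the batch has emitted a stop token or the frame limit is hit. One subtlety: `len(mel_seq)` on a 2-D tensor returns the batch size (dim 0), not the number of generated frames, so the sketch below uses `shape[1]` for the length check. The token ids, frame limit, and placeholder `step` function standing in for the transformer forward pass are all assumptions:

import torch

MEL_START_TOKEN, MEL_STOP_TOKEN = 1024, 1025  # assumed special-token ids
MAX_MEL_FRAMES = 16                           # assumed frame limit

def greedy_decode(step, b):
    # step(mel_seq) -> (B, vocab) logits for the next mel code.
    mel_seq = torch.full((b, 1), MEL_START_TOKEN, dtype=torch.long)
    stopped = torch.zeros(b, dtype=torch.bool)
    while not torch.all(stopped) and mel_seq.shape[1] < MAX_MEL_FRAMES:
        next_code = step(mel_seq).argmax(dim=-1, keepdim=True)  # (B, 1) greedy pick
        mel_seq = torch.cat([mel_seq, next_code], dim=1)
        stopped |= next_code.squeeze(1) == MEL_STOP_TOKEN
    return mel_seq

# Usage with a dummy model that always predicts the stop token:
out = greedy_decode(lambda seq: torch.nn.functional.one_hot(
    torch.full((seq.shape[0],), MEL_STOP_TOKEN), num_classes=1026).float(), b=2)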
@@ -91,25 +92,10 @@ class GptTts(nn.Module):
             print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
 
-        # Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE)
-        cleaned = []
-        for j in range(mel_seq.shape[0]):
-            s = mel_seq[j][1:-1] # Strip out BOS and EOS tokens.
-            gt = s >= 512
-            l = (len(s)) // 3
-            for i in reversed(range(l)):
-                if gt[i]:
-                    l = i+1
-                    break
-            top = s[:l]
-            top = top + (top < 512) * 512
-            bottom = s[l:l*3]
-            bottom = bottom * (bottom < 512)
-            combined = torch.cat([top,bottom], dim=0)
-            assert not torch.any(combined < 0)
-            combined = combined * (combined < 1024)
-            cleaned.append(combined)
+        mel_seq = mel_seq[:, 1:-1] # Remove first and last tokens, which were artificially added for GPT
+        mel_seq = mel_seq * (mel_seq < 512) # The DVAE doesn't understand BOS/EOS/PAD tokens.
 
-        return torch.stack(cleaned)
+        return mel_seq
 
 
 @register_model
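The removed block reassembled the two tiers of DVAE codes per sample; the replacement treats the output as a flat (B, T) code tensor and simply masks special tokens. Because the DVAE codebook ids sit below 512 here, multiplying by the boolean mask `(mel_seq < 512)` zeroes every BOS/EOS/PAD id in one vectorized step. A small demonstration with made-up token ids:

import torch

# Assumed layout: DVAE codes occupy [0, 512); ids >= 512 are BOS/EOS/PAD.
mel_seq = torch.tensor([[600,  3,  17, 100, 601],
                        [600, 44,   9, 601, 602]])

mel_seq = mel_seq[:, 1:-1]           # drop the BOS/EOS columns added for GPT
mel_seq = mel_seq * (mel_seq < 512)  # zero any special ids left behind
print(mel_seq)  # tensor([[  3,  17, 100],
                #         [ 44,   9,   0]])

Note the mask maps special tokens to code 0 rather than removing them, so trailing padding becomes code 0 in the DVAE's vocabulary; this commit evidently accepts that as a simplification over the per-sample cleaning loop.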
@ -54,7 +54,7 @@ if __name__ == "__main__":
|
|||
torch.backends.cudnn.benchmark = True
|
||||
want_metrics = False
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
utils.util.loaded_options = opt
|
||||
|
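The test-harness change only swaps the default options file, so the GPT-TTS config is exercised when no flag is given; an explicit `-opt` still overrides it. A minimal demonstration of that argparse behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.',
                    default='../options/test_gpt_tts_lj.yml')

print(parser.parse_args([]).opt)                          # default: test_gpt_tts_lj.yml
print(parser.parse_args(['-opt', 'my_options.yml']).opt)  # explicit override wins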