Fix inference, always flow full text tokens through transformer
parent 4c678172d6
commit a2afb25e42
@@ -57,7 +57,8 @@ class GptTtsCollater():
     def __call__(self, batch):
         text_lens = [len(x[0]) for x in batch]
-        max_text_len = max(text_lens)
+        #max_text_len = max(text_lens)
+        max_text_len = self.MAX_SYMBOLS_PER_PHRASE # This forces all outputs to have the full 200 characters. Testing if this makes a difference.
         mel_lens = [len(x[1]) for x in batch]
         max_mel_len = max(mel_lens)
         texts = []
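Not part of the commit: a minimal sketch of the fixed-length padding this hunk switches to. The pad index of 0 and the exact value of MAX_SYMBOLS_PER_PHRASE (200, per the inline comment) are assumptions, not verified against the rest of the file.

import torch
import torch.nn.functional as F

MAX_SYMBOLS_PER_PHRASE = 200  # assumed from the inline comment above

def pad_text_batch(text_seqs, pad_value=0):
    # Pad every sequence to the same fixed length, so the transformer
    # always consumes the full run of text token positions.
    return torch.stack([
        F.pad(t, (0, MAX_SYMBOLS_PER_PHRASE - len(t)), value=pad_value)
        for t in text_seqs
    ])

# Three variable-length symbol tensors all collate to shape (3, 200).
batch = [torch.randint(1, 100, (n,)) for n in (12, 57, 180)]
print(pad_text_batch(batch).shape)  # torch.Size([3, 200])

This is what makes the commit title literally true: every sample flows the full text-token width through the transformer, padding included.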
@@ -71,11 +71,12 @@ class GptTts(nn.Module):
         return loss_text.mean(), loss_mel.mean(), mel_codes, mel_targets

     def inference(self, text_inputs):
+        b, _ = text_inputs.shape
         text_emb = self.text_embedding(text_inputs)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))

-        mel_seq = torch.full((text_emb.shape[0],1), fill_value=self.MEL_START_TOKEN, device=text_emb.device)
-        stop_encountered = torch.zeros((text_emb.shape[0],), device=text_emb.device)
+        mel_seq = torch.full((b,1), fill_value=self.MEL_START_TOKEN, device=text_emb.device)
+        stop_encountered = torch.zeros((b,), device=text_emb.device)
         while not torch.all(stop_encountered) and len(mel_seq) < self.max_mel_frames:
             mel_emb = self.mel_embedding(mel_seq)
             mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
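Not part of the commit: a shape sketch of the new batch handling. b is unpacked once from text_inputs.shape and reused for the BOS seed and the per-sample stop mask; the MEL_START_TOKEN value below is a made-up placeholder, the real constant lives on GptTts.

import torch

b = 3                    # batch size, as unpacked from text_inputs.shape
MEL_START_TOKEN = 8192   # placeholder value for illustration only

mel_seq = torch.full((b, 1), fill_value=MEL_START_TOKEN)  # one BOS per sample
stop_encountered = torch.zeros((b,))                      # one stop flag per sample
print(mel_seq.shape)      # torch.Size([3, 1])
print(len(mel_seq))       # 3 -- the batch dimension
print(mel_seq.shape[1])   # 1 -- frames generated so far

One caveat worth noting: len(mel_seq) in the while guard measures the batch dimension, not the frame count, so mel_seq.shape[1] is probably what the max_mel_frames comparison intends.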
@@ -91,25 +92,10 @@ class GptTts(nn.Module):
             print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")

         # Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE)
-        cleaned = []
-        for j in range(mel_seq.shape[0]):
-            s = mel_seq[j][1:-1] # Strip out BOS and EOS tokens.
-            gt = s >= 512
-            l = (len(s)) // 3
-            for i in reversed(range(l)):
-                if gt[i]:
-                    l = i+1
-                    break
-            top = s[:l]
-            top = top + (top < 512) * 512
-            bottom = s[l:l*3]
-            bottom = bottom * (bottom < 512)
-            combined = torch.cat([top,bottom], dim=0)
-            assert not torch.any(combined < 0)
-            combined = combined * (combined < 1024)
-            cleaned.append(combined)
-        return torch.stack(cleaned)
+        mel_seq = mel_seq[:, 1:-1] # Remove first and last tokens, which were artificially added for GPT
+        mel_seq = mel_seq * (mel_seq < 512) # The DVAE doesn't understand BOS/EOS/PAD tokens.
+        return mel_seq


 @register_model
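Not part of the commit: a sketch of what the two new masking lines do to a decoded batch. The 512 code-vocabulary boundary comes from the diff itself; the token values are invented for illustration.

import torch

# Fake output: BOS (8192), real codes (< 512), EOS (8193), PAD (8194).
mel_seq = torch.tensor([[8192,   5, 300,  511, 8193],
                        [8192,   7, 100, 8193, 8194]])

mel_seq = mel_seq[:, 1:-1]           # drop the framing tokens GPT needed
mel_seq = mel_seq * (mel_seq < 512)  # anything >= 512 is masked to code 0
print(mel_seq)
# tensor([[  5, 300, 511],
#         [  7, 100,   0]])

Note that the mask maps every special token onto code index 0, so the DVAE sees them as ordinary zero codes; the old per-sample two-tier reassembly loop is gone entirely.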
@@ -54,7 +54,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
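Not part of the commit: how the new default resolves when the test script is launched without -opt. option.parse, dict_to_nonedict, and the YAML file are the repo's own and are not reproduced here.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.',
                    default='../options/test_gpt_tts_lj.yml')

args = parser.parse_args([])  # no CLI arguments -> the default is used
print(args.opt)               # ../options/test_gpt_tts_lj.yml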