diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py
index a8648e8f..9b680fc0 100644
--- a/codes/models/gpt_voice/gpt_tts.py
+++ b/codes/models/gpt_voice/gpt_tts.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from tqdm import tqdm
 
 from models.arch_util import ConvGnSilu
 from models.tacotron2.taco_utils import get_mask_from_lengths
@@ -9,6 +10,18 @@ from models.gpt_voice.min_gpt import GPT, GPTConfig
 from trainer.networks import register_model
 
 
+# A Conv1d that masks out kernel elements ahead of the current location.
+class CausalConv1d(nn.Conv1d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.kernel_mask = torch.ones_like(self.weight)
+        self.kernel_mask[:, :, -(self.kernel_size[0]//2):] = 0
+
+    def forward(self, input):
+        self.kernel_mask = self.kernel_mask.to(input.device)
+        return self._conv_forward(input, self.weight * self.kernel_mask, self.bias)
+
+
 class GptTts(nn.Module):
     def __init__(self):
         super().__init__()
@@ -18,21 +31,28 @@ class GptTts(nn.Module):
         max_mel_frames = 900
         mel_dim=80
 
+        self.model_dim = model_dim
+        self.max_mel_frames = max_mel_frames
         self.text_embedding = nn.Embedding(number_symbols, model_dim)
-        self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=3, convnd=nn.Conv1d),
-                                         ConvGnSilu(model_dim//2, model_dim, kernel_size=3, stride=2, convnd=nn.Conv1d))
+        # Whenever we process MEL frames, we need to be careful to use causally masked convolutions to avoid adding bias
+        # into the model which we cannot provide at inference time.
+        self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=5, convnd=CausalConv1d),
+                                         ConvGnSilu(model_dim//2, model_dim, kernel_size=5, stride=2, convnd=CausalConv1d))
+        # *_tags are additively applied to the corresponding embeddings.
         self.text_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
+        self.separator = nn.Parameter(torch.randn(1, 1, model_dim))
         self.audio_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
-        self.gpt = GPT(GPTConfig(max_symbols_per_phrase+max_mel_frames//2, n_embd=model_dim, n_head=8))
+        self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames//2, n_embd=model_dim, n_head=8))
 
-        self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=nn.Conv1d),
+        self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=CausalConv1d),
                                        nn.Upsample(scale_factor=2, mode='nearest'),
-                                       ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=nn.Conv1d),
+                                       ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=CausalConv1d),
+                                       # No need for causal convolutions when kernel_size=1.
                                        nn.Conv1d(model_dim//2, 1, kernel_size=1))
-        self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=nn.Conv1d),
+        self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=CausalConv1d),
                                       nn.Upsample(scale_factor=2, mode='nearest'),
-                                      ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=nn.Conv1d),
-                                      ConvGnSilu(model_dim//2, model_dim//2, kernel_size=5, convnd=nn.Conv1d),
+                                      ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=CausalConv1d),
+                                      ConvGnSilu(model_dim//2, model_dim//2, kernel_size=5, convnd=CausalConv1d),
                                       ConvGnSilu(model_dim//2, mel_dim, kernel_size=1, activation=False, norm=False, convnd=nn.Conv1d))
 
     def forward(self, text_inputs, mel_targets, output_lengths):
@@ -45,9 +65,11 @@ class GptTts(nn.Module):
         text_emb = text_emb + self.text_tags
         mel_emb = self.mel_encoder(mel_targets).permute(0,2,1)
         mel_emb = mel_emb + self.audio_tags
-        emb = torch.cat([text_emb, mel_emb], dim=1)
+        emb = torch.cat([text_emb,
+                         self.separator.repeat(text_emb.shape[0],1,1),
+                         mel_emb], dim=1)
         enc = self.gpt(emb)
-        mel_portion = enc[:, text_emb.shape[1]:].permute(0,2,1)
+        mel_portion = enc[:, text_emb.shape[1]+1:].permute(0,2,1)
         gates = self.gate_head(mel_portion).squeeze(1)
         mel_pred = self.mel_head(mel_portion)
 
@@ -62,6 +84,53 @@ class GptTts(nn.Module):
         gates = gates[:, :-1]
         return mel_pred, gates
 
+    def test_guide(self, mel_guide, amount=50):
+        mel_guide = mel_guide[:,:,:amount]
+        mel_emb = self.mel_encoder(mel_guide).permute(0,2,1)
+        mel_emb = mel_emb + self.audio_tags
+        return mel_emb
+
+    def inference(self, text_inputs, mel_guide=None):
+        MEL_HEAD_EXPANSION = 2
+        GATE_THRESHOLD = .95
+
+        text_emb = self.text_embedding(text_inputs)
+        text_emb = text_emb + self.text_tags
+        b,s,c = text_emb.shape
+        emb = torch.cat([text_emb,
+                         self.separator.repeat(text_emb.shape[0],1,1)], dim=1)
+                         #self.test_guide(mel_guide)], dim=1)
+        completed = torch.zeros((b,), device=text_inputs.device, dtype=torch.bool)
+        output = None
+        for i in tqdm(range(self.max_mel_frames)):
+            enc = self.gpt(emb)
+            inferred = enc[:,s:,:].permute(0,2,1)
+            # Create output frames.
+            inferred_mel_frame = self.mel_head(inferred)[:,:,-MEL_HEAD_EXPANSION:]
+            inferred_mel_frame = inferred_mel_frame * (~completed).float().view(b,1,1)
+            if output is None:
+                output = inferred_mel_frame
+            else:
+                output = torch.cat([output, inferred_mel_frame], dim=2)
+
+            # Test termination condition.
+            gate = torch.sigmoid(self.gate_head(inferred)).max(dim=-1).values  # TODO: accept single-frame terminations.
+            completed = completed.logical_or((gate > GATE_THRESHOLD).squeeze(1))  # This comprises a latch - but that may not be wise.
+            if torch.all(completed):
+                break
+
+            # Apply inferred mel_frames to emb for the next pass.
+            mel_emb = self.mel_encoder(output).permute(0,2,1)
+            mel_emb = mel_emb + self.audio_tags
+            emb = torch.cat([text_emb,
+                             self.separator.repeat(text_emb.shape[0],1,1),
+                             mel_emb], dim=1)
+            if i == self.max_mel_frames//2:
+                print("Warning! Inference hit mel frame cap without encountering a stop token.")
+                break
+
+        return output
+
 
 @register_model
 def register_gpt_tts(opt_net, opt):
@@ -74,4 +143,9 @@ if __name__ == '__main__':
                   torch.randn(2,80,747),
                   torch.tensor([600,747]))
     print(m.shape)
-    print(g.shape)
\ No newline at end of file
+    print(g.shape)
+
+    o = gpt.inference(torch.randint(high=24, size=(2,60)))
+    print(o.shape)
+
+
diff --git a/codes/scripts/audio/test_audio_gen.py b/codes/scripts/audio/test_audio_gen.py
index 39fb6e40..f67ec2e6 100644
--- a/codes/scripts/audio/test_audio_gen.py
+++ b/codes/scripts/audio/test_audio_gen.py
@@ -40,7 +40,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt