Patch for codes/models/gpt_voice/unified_voice.py.

Summary of the change (everything listed here is visible in the hunks below):
  * imports functools;
  * adds the module-level function null_position_embeddings(range, dim), which
    returns a zero tensor of shape (range.shape[0], range.shape[1], dim)
    allocated on range.device;
  * in UnifiedGptVoice, deletes self.gpt.wpe and replaces it with
    functools.partial(null_position_embeddings, dim=model_dim), so whatever
    calls gpt.wpe receives zeros;
  * changes two assertions on mel_inputs.shape[1] to compare against
    self.max_mel_tokens instead of self.max_symbols_per_phrase.

NOTE(review): the parameter name `range` shadows the Python builtin inside
null_position_embeddings; harmless here, but easy to trip over later.
NOTE(review): torch.zeros is called without dtype=, so it uses the torch
default dtype. Confirm that matches the token-embedding dtype when the model
runs in reduced precision.
NOTE(review): the replacement wpe is a plain callable, not an nn.Module, so it
presumably contributes no parameters to the state dict. Checkpoints saved
before this patch may need non-strict loading - TODO confirm.
NOTE(review): this patch text was recovered from a copy whose line breaks and
indentation had been collapsed. Indentation of context lines is reconstructed
from Python syntax (blank context lines are written as empty lines), and the
last hunk declares 7 old lines while only 6 non-blank ones survived; a blank
context line was assumed directly after the changed assert. Verify against the
target file, or apply with a whitespace-tolerant option.

diff --git a/codes/models/gpt_voice/unified_voice.py b/codes/models/gpt_voice/unified_voice.py
index fd72aef1..b526f5dc 100644
--- a/codes/models/gpt_voice/unified_voice.py
+++ b/codes/models/gpt_voice/unified_voice.py
@@ -1,3 +1,5 @@
+import functools
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -32,6 +34,10 @@ class ConditioningEncoder(nn.Module):
         return h[:, :, 0]


+def null_position_embeddings(range, dim):
+    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
+
+
 class UnifiedGptVoice(nn.Module):
     """
     Derived from GptTtsHf, but offers multiple modes of autoregressive operation:
@@ -74,6 +80,10 @@ class UnifiedGptVoice(nn.Module):
                                      gradient_checkpointing=checkpointing,
                                      use_cache=not checkpointing)
         self.gpt = GPT2Model(self.gpt_config)
+        # Override the built in positional embeddings
+        del self.gpt.wpe
+        self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
+
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.number_text_tokens)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
@@ -143,7 +153,7 @@ class UnifiedGptVoice(nn.Module):
         mel_inputs: long tensor, (b,m)
         wav_lengths: long tensor, (b,)
         """
-        assert self.max_symbols_per_phrase >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
+        assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
         assert self.max_symbols_per_phrase >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
         assert self.max_total_tokens >= mel_inputs.shape[1] + text_inputs.shape[1], f'{mel_inputs.shape[1]}, {text_inputs.shape[1]}'

@@ -187,7 +197,7 @@ class UnifiedGptVoice(nn.Module):
         """
         Performs autoregressive modeling on only speech data.
         """
-        assert self.max_symbols_per_phrase >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
+        assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'

         mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
         speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)