Get rid of absolute positional embeddings in unifiedvoice

This commit is contained in:
James Betker 2021-12-26 00:10:24 -07:00
parent 6700f8851d
commit 8d01f7685c

View File

@ -1,3 +1,5 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -32,6 +34,10 @@ class ConditioningEncoder(nn.Module):
return h[:, :, 0]
def null_position_embeddings(range, dim):
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
class UnifiedGptVoice(nn.Module):
"""
Derived from GptTtsHf, but offers multiple modes of autoregressive operation:
@ -74,6 +80,10 @@ class UnifiedGptVoice(nn.Module):
gradient_checkpointing=checkpointing,
use_cache=not checkpointing)
self.gpt = GPT2Model(self.gpt_config)
# Override the built in positional embeddings
del self.gpt.wpe
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
@ -143,7 +153,7 @@ class UnifiedGptVoice(nn.Module):
mel_inputs: long tensor, (b,m)
wav_lengths: long tensor, (b,)
"""
assert self.max_symbols_per_phrase >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
assert self.max_symbols_per_phrase >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
assert self.max_total_tokens >= mel_inputs.shape[1] + text_inputs.shape[1], f'{mel_inputs.shape[1]}, {text_inputs.shape[1]}'
@ -187,7 +197,7 @@ class UnifiedGptVoice(nn.Module):
"""
Performs autoregressive modeling on only speech data.
"""
assert self.max_symbols_per_phrase >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)