Get rid of absolute positional embeddings in unifiedvoice

James Betker 2021-12-26 00:10:24 -07:00
parent 6700f8851d
commit 8d01f7685c


@@ -1,3 +1,5 @@
+import functools
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -32,6 +34,10 @@ class ConditioningEncoder(nn.Module):
         return h[:, :, 0]
 
 
+def null_position_embeddings(range, dim):
+    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
+
+
 class UnifiedGptVoice(nn.Module):
     """
     Derived from GptTtsHf, but offers multiple modes of autoregressive operation:
@@ -74,6 +80,10 @@ class UnifiedGptVoice(nn.Module):
                                     gradient_checkpointing=checkpointing,
                                     use_cache=not checkpointing)
         self.gpt = GPT2Model(self.gpt_config)
+        # Override the built in positional embeddings
+        del self.gpt.wpe
+        self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
+
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.number_text_tokens)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
@@ -143,7 +153,7 @@ class UnifiedGptVoice(nn.Module):
         mel_inputs: long tensor, (b,m)
         wav_lengths: long tensor, (b,)
         """
-        assert self.max_symbols_per_phrase >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
+        assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
         assert self.max_symbols_per_phrase >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
         assert self.max_total_tokens >= mel_inputs.shape[1] + text_inputs.shape[1], f'{mel_inputs.shape[1]}, {text_inputs.shape[1]}'
@@ -187,7 +197,7 @@ class UnifiedGptVoice(nn.Module):
         """
         Performs autoregressive modeling on only speech data.
         """
-        assert self.max_symbols_per_phrase >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
+        assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
         mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
         speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
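
For reference, here is a minimal, self-contained sketch of the technique the diff applies, written against the HuggingFace GPT2Model API. The toy config sizes and variable names below are illustrative assumptions, not taken from the repository:

    import functools

    import torch
    from transformers import GPT2Config, GPT2Model


    def null_position_embeddings(range, dim):
        # Same helper the commit adds: a zero tensor shaped like an embedding
        # lookup over `range`, so adding it to the input embeddings is a no-op.
        return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)


    model_dim = 64  # illustrative width, not the model's real size
    gpt = GPT2Model(GPT2Config(n_embd=model_dim, n_layer=2, n_head=2, vocab_size=256))

    # Override the built-in positional embeddings, as the diff does.
    del gpt.wpe
    gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)

    # GPT2Model computes self.wpe(position_ids) and adds it to the input
    # embeddings; with the override, that addition contributes nothing.
    inputs = torch.randn(2, 10, model_dim)
    out = gpt(inputs_embeds=inputs).last_hidden_state
    print(out.shape)  # torch.Size([2, 10, 64])

Because the replacement is a plain callable rather than an nn.Module, it adds no parameters; it simply turns the absolute position-embedding addition inside GPT2Model into a no-op.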