Get rid of absolute positional embeddings in unifiedvoice
This commit is contained in:
parent
6700f8851d
commit
8d01f7685c
|
@ -1,3 +1,5 @@
|
|||
import functools
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
@ -32,6 +34,10 @@ class ConditioningEncoder(nn.Module):
|
|||
return h[:, :, 0]
|
||||
|
||||
|
||||
def null_position_embeddings(range, dim):
|
||||
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
||||
|
||||
|
||||
class UnifiedGptVoice(nn.Module):
|
||||
"""
|
||||
Derived from GptTtsHf, but offers multiple modes of autoregressive operation:
|
||||
|
@ -74,6 +80,10 @@ class UnifiedGptVoice(nn.Module):
|
|||
gradient_checkpointing=checkpointing,
|
||||
use_cache=not checkpointing)
|
||||
self.gpt = GPT2Model(self.gpt_config)
|
||||
# Override the built in positional embeddings
|
||||
del self.gpt.wpe
|
||||
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
||||
|
||||
self.final_norm = nn.LayerNorm(model_dim)
|
||||
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
|
||||
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
|
||||
|
@ -143,7 +153,7 @@ class UnifiedGptVoice(nn.Module):
|
|||
mel_inputs: long tensor, (b,m)
|
||||
wav_lengths: long tensor, (b,)
|
||||
"""
|
||||
assert self.max_symbols_per_phrase >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
|
||||
assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
|
||||
assert self.max_symbols_per_phrase >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
|
||||
assert self.max_total_tokens >= mel_inputs.shape[1] + text_inputs.shape[1], f'{mel_inputs.shape[1]}, {text_inputs.shape[1]}'
|
||||
|
||||
|
@ -187,7 +197,7 @@ class UnifiedGptVoice(nn.Module):
|
|||
"""
|
||||
Performs autoregressive modeling on only speech data.
|
||||
"""
|
||||
assert self.max_symbols_per_phrase >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
|
||||
assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
|
||||
|
||||
mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
|
||||
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
||||
|
|
Loading…
Reference in New Issue
Block a user