Get rid of absolute positional embeddings in unifiedvoice

parent 6700f8851d
commit 8d01f7685c
@@ -1,3 +1,5 @@
+import functools
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -32,6 +34,10 @@ class ConditioningEncoder(nn.Module):
         return h[:, :, 0]
 
 
+def null_position_embeddings(range, dim):
+    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
+
+
 class UnifiedGptVoice(nn.Module):
     """
     Derived from GptTtsHf, but offers multiple modes of autoregressive operation:
@@ -74,6 +80,10 @@ class UnifiedGptVoice(nn.Module):
                                      gradient_checkpointing=checkpointing,
                                      use_cache=not checkpointing)
         self.gpt = GPT2Model(self.gpt_config)
+        # Override the built in positional embeddings
+        del self.gpt.wpe
+        self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
+
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.number_text_tokens)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
@@ -143,7 +153,7 @@ class UnifiedGptVoice(nn.Module):
         mel_inputs: long tensor, (b,m)
         wav_lengths: long tensor, (b,)
         """
-        assert self.max_symbols_per_phrase >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
+        assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
         assert self.max_symbols_per_phrase >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
         assert self.max_total_tokens >= mel_inputs.shape[1] + text_inputs.shape[1], f'{mel_inputs.shape[1]}, {text_inputs.shape[1]}'
 
@@ -187,7 +197,7 @@ class UnifiedGptVoice(nn.Module):
         """
         Performs autoregressive modeling on only speech data.
         """
-        assert self.max_symbols_per_phrase >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
+        assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
 
         mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
         speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
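The override in the second and third hunks works because HuggingFace's GPT2Model only calls self.wpe(position_ids) and adds the result to the token embeddings, so swapping the learned nn.Embedding for a callable that returns zeros of shape (batch, sequence, dim) removes absolute positional information without touching the rest of the forward pass. Below is a minimal sketch of the same pattern outside this repository; model_dim, the GPT2Config sizes, and the dummy inputs are illustrative assumptions, and exact GPT2Model internals can vary slightly across transformers versions.

# Sketch only, not part of the diff: null positional embeddings on a stock GPT2Model.
import functools

import torch
from transformers import GPT2Config, GPT2Model


def null_position_embeddings(range, dim):
    # Same output shape an nn.Embedding would produce, but carries no positional signal.
    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)


model_dim = 512  # illustrative value, not taken from the diff
gpt = GPT2Model(GPT2Config(n_embd=model_dim, n_layer=2, n_head=8))

# Drop the learned absolute positional table and substitute the null callable,
# exactly as the diff does inside UnifiedGptVoice.__init__.
del gpt.wpe
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)

# The model still runs; position_embeds is now all zeros.
emb = torch.randn(2, 10, model_dim)
out = gpt(inputs_embeds=emb).last_hidden_state
print(out.shape)  # torch.Size([2, 10, 512])

After this change, any positional information the autoregressive model needs has to come from components other than GPT-2's built-in absolute table.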