From 536511fc4b734f46b7d6c6d4e579438c0e59f691 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 22 Mar 2022 11:41:13 -0600
Subject: [PATCH] unified_voice: relative position encodings

---
Review notes (free text between the three-dash line and the first file diff;
not part of the commit message):
 * NOTE(review): the line breaks and indentation of this copy were
   reconstructed from the hunk headers' line counts; confirm with
   `git apply --check` before relying on it.
 * LearnedPositionEmbeddings gains a `relative` argument (default True) and
   now stores `seq_len`. With relative=True, forward() draws
   start = random.randint(sl, seq_len) - sl on every call and embeds the
   positions [start, start + sl), so repeated calls on the same input can
   return different embeddings. The branch does not look at self.training.
 * random.randint(sl, seq_len) raises ValueError when sl > seq_len.
 * get_fixed_embedding() is not touched by this patch and still looks up the
   index it is given.
 * unified_voice2.py: the __main__ block drops the attentions.pth loader,
   removes the ''' markers that had turned the smoke test into a string
   literal, and replaces len(symbols) with the literal 256.

 codes/models/audio/tts/transformer_builders.py | 11 +++++++++--
 codes/models/audio/tts/unified_voice2.py       |  8 +-------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/codes/models/audio/tts/transformer_builders.py b/codes/models/audio/tts/transformer_builders.py
index d117c932..e215ac21 100644
--- a/codes/models/audio/tts/transformer_builders.py
+++ b/codes/models/audio/tts/transformer_builders.py
@@ -21,6 +21,7 @@ Returns:
 
 """
 import functools
+import random
 from time import time
 import torch
 import torch.nn as nn
@@ -32,15 +33,21 @@ def null_position_embeddings(range, dim):
 
 
 class LearnedPositionEmbeddings(nn.Module):
-    def __init__(self, seq_len, model_dim, init=.02):
+    def __init__(self, seq_len, model_dim, init=.02, relative=True):
         super().__init__()
         self.emb = nn.Embedding(seq_len, model_dim)
         # Initializing this way is standard for GPT-2
         self.emb.weight.data.normal_(mean=0.0, std=init)
+        self.relative = relative
+        self.seq_len = seq_len
 
     def forward(self, x):
         sl = x.shape[1]
-        return self.emb(torch.arange(0, sl, device=x.device))
+        if self.relative:
+            start = random.randint(sl, self.seq_len) - sl
+            return self.emb(torch.arange(start, start+sl, device=x.device))
+        else:
+            return self.emb(torch.arange(0, sl, device=x.device))
 
     def get_fixed_embedding(self, ind, dev):
         return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
diff --git a/codes/models/audio/tts/unified_voice2.py b/codes/models/audio/tts/unified_voice2.py
index 78beb2d0..0cd78aa2 100644
--- a/codes/models/audio/tts/unified_voice2.py
+++ b/codes/models/audio/tts/unified_voice2.py
@@ -560,22 +560,16 @@ class UnifiedVoice(nn.Module):
         return results * eos_token_mask
 
 
-
 @register_model
 def register_unified_voice2(opt_net, opt):
     return UnifiedVoice(**opt_get(opt_net, ['kwargs'], {}))
 
 
 if __name__ == '__main__':
-    ld = torch.load('attentions.pth')
-    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
-    gpt.convert_attentions_to_aligned_codes(*ld)
-    '''
     gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
     l = gpt(torch.randn(2, 3, 80, 800),
-            torch.randint(high=len(symbols), size=(2,120)),
+            torch.randint(high=256, size=(2,120)),
             torch.tensor([32, 120]),
             torch.randint(high=8192, size=(2,250)),
             torch.tensor([250*256,195*256]))
     gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
-    '''