From 70b17da193720c210632ea5bfc49780c8535e22f Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 8 Jan 2022 22:18:25 -0700
Subject: [PATCH] Alter unified_voice to use extensible transformer (still WIP)

---
 .../models/gpt_voice/transformer_builders.py | 70 +++++++++++++++++++
 codes/models/gpt_voice/unified_voice.py      | 24 +------
 2 files changed, 73 insertions(+), 21 deletions(-)
 create mode 100644 codes/models/gpt_voice/transformer_builders.py

diff --git a/codes/models/gpt_voice/transformer_builders.py b/codes/models/gpt_voice/transformer_builders.py
new file mode 100644
index 00000000..28d1b355
--- /dev/null
+++ b/codes/models/gpt_voice/transformer_builders.py
@@ -0,0 +1,70 @@
+"""
+A list of functions that map a unified set of arguments to a fully built transformer. Also includes some testing
+utilities for measuring parameter count, FLOPS, and general performance of each type.
+
+Every function contains the following arguments:
+
+    layers: Net number of layers in the transformer.
+    model_dim: Hidden dimensionality of the model.
+    heads: Number of attention heads.
+    num_tokens: Number of possible tokens in the transformer's dictionary. Do not use this in future releases.
+    max_seq_len: Maximum sequence length to attend to.
+    checkpointing: Whether or not the underlying implementation should support gradient checkpointing.
+"""
+import functools
+import torch
+
+
+def null_position_embeddings(range, dim):
+    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
+
+
+def build_hf_gpt_transformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing):
+    """
+    GPT-2 implemented by the HuggingFace library.
+    """
+    from transformers import GPT2Config, GPT2Model
+    gpt_config = GPT2Config(vocab_size=num_tokens,
+                            n_positions=max_seq_len,
+                            n_ctx=max_seq_len,
+                            n_embd=model_dim,
+                            n_layer=layers,
+                            n_head=heads,
+                            gradient_checkpointing=checkpointing,
+                            use_cache=not checkpointing)
+    gpt = GPT2Model(gpt_config)
+    # Override the built in positional embeddings
+    del gpt.wpe
+    gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
+    return gpt
+
+
+def build_lr_performer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing):
+    """
+    lucidrains Performer implementation, https://github.com/lucidrains/performer-pytorch
+    """
+    pass
+
+
+def build_lr_reformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing):
+    """
+    lucidrains Reformer implementation, https://github.com/lucidrains/reformer-pytorch
+    """
+    pass
+
+
+def build_lr_xformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing):
+    """
+    lucidrains x-transformer implementation, https://github.com/lucidrains/x-transformers
+    """
+    pass
+
+
+def test_all_performance(**kwargs):
+    transformer_builders = [build_hf_gpt_transformer, build_lr_performer, build_lr_reformer, build_lr_xformer]
+    for builder in transformer_builders:
+        model = builder(**kwargs)
+
+
+if __name__ == '__main__':
+    test_all_performance(layers=12, model_dim=512, heads=8, num_tokens=8192, max_seq_len=1000, checkpointing=False)
\ No newline at end of file
diff --git a/codes/models/gpt_voice/unified_voice.py b/codes/models/gpt_voice/unified_voice.py
index 6e5548a4..1689d13e 100644
--- a/codes/models/gpt_voice/unified_voice.py
+++ b/codes/models/gpt_voice/unified_voice.py
@@ -8,6 +8,7 @@ from transformers import GPT2Model, GPT2Config
 from models.arch_util import AttentionBlock
 from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
 from models.gpt_voice.gpt_asr_hf2 import ResBlock
+from models.gpt_voice.transformer_builders import build_hf_gpt_transformer
 from models.tacotron2.text import symbols
 from trainer.networks import register_model
 from utils.util import opt_get
@@ -59,22 +60,6 @@ class MelEncoder(nn.Module):
         return x.permute(0,2,1)
 
 
-def null_position_embeddings(range, dim):
-    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
-
-
-def build_hf_gpt_transformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing):
-    gpt_config = GPT2Config(vocab_size=num_tokens,
-                            n_positions=max_seq_len,
-                            n_ctx=max_seq_len,
-                            n_embd=model_dim,
-                            n_layer=layers,
-                            n_head=heads,
-                            gradient_checkpointing=checkpointing,
-                            use_cache=not checkpointing)
-    return GPT2Model(gpt_config)
-
-
 class UnifiedGptVoice(nn.Module):
     """
     Derived from GptTtsHf, but offers multiple modes of autoregressive operation:
@@ -133,14 +118,11 @@ class UnifiedGptVoice(nn.Module):
         self.seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
         self.gpt = build_hf_gpt_transformer(layers, model_dim, heads, number_mel_codes, self.seq_length, checkpointing)
         if train_solo_embeddings:
-            self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
-            self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
+            self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
+            self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
         else:
             self.mel_solo_embedding = 0
             self.text_solo_embedding = 0
-        # Override the built in positional embeddings
-        del self.gpt.wpe
-        self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
         if not use_mel_codes_as_input:
             self.gpt.wte = MelEncoder(model_dim, resblocks_per_reduction=1)
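
Usage note (not part of the patch): the sketch below shows how the unified builder
interface documented in transformer_builders.py is intended to be called. It assumes the
HuggingFace transformers package is installed and that the repository's codes/ directory
is on PYTHONPATH; the batch size, sequence length, and model dimensions are illustrative
assumptions, not values taken from the commit.

    import torch
    from models.gpt_voice.transformer_builders import build_hf_gpt_transformer

    # Every builder accepts the same unified argument set described in the module docstring.
    gpt = build_hf_gpt_transformer(layers=12, model_dim=512, heads=8,
                                   num_tokens=8192, max_seq_len=1000, checkpointing=False)

    tokens = torch.randint(0, 8192, (2, 120))   # (batch, sequence) of token ids, sizes assumed
    out = gpt(input_ids=tokens, return_dict=True)
    print(out.last_hidden_state.shape)          # torch.Size([2, 120, 512])

    # Because wpe is replaced by null_position_embeddings(), the backbone adds all-zero
    # positional embeddings; callers that need position information must inject it themselves.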