From 268ba17485b30dfbc337d55330c0279a5216acb8 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 19 Jun 2024 10:21:43 -0500
Subject: [PATCH] crammed in HF attention selection mechanisms for the AR

---
 tortoise_tts/models/unified_voice.py | 66 ++++++++++++++++++++--------
 1 file changed, 48 insertions(+), 18 deletions(-)

diff --git a/tortoise_tts/models/unified_voice.py b/tortoise_tts/models/unified_voice.py
index 7b08a5b..7828ad4 100644
--- a/tortoise_tts/models/unified_voice.py
+++ b/tortoise_tts/models/unified_voice.py
@@ -12,6 +12,23 @@ from .arch_utils import AttentionBlock
 
 from transformers import LogitsWarper
 
+AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
+
+try:
+	from xformers.ops import LowerTriangularMask
+	from xformers.ops.fmha import memory_efficient_attention
+
+	AVAILABLE_ATTENTIONS.append("xformers")
+except Exception as e:
+	print("Error while importing `xformers`", e)
+
+try:
+	from transformers.utils import is_flash_attn_2_available
+
+	if is_flash_attn_2_available():
+		AVAILABLE_ATTENTIONS.append("flash")
+except Exception as e:
+	print("Error while querying for `flash_attn_2` support", e)
 
 class TypicalLogitsWarper(LogitsWarper):
 	def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@@ -262,18 +279,21 @@ class LearnedPositionEmbeddings(nn.Module):
 		return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
 
 
-def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
+def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, attention_implementation="eager"):
 	"""
 	GPT-2 implemented by the HuggingFace library.
 	"""
 	from transformers import GPT2Config, GPT2Model
 
-	gpt_config = GPT2Config(vocab_size=256, # Unused.
-				n_positions=max_mel_seq_len+max_text_seq_len,
-				n_ctx=max_mel_seq_len+max_text_seq_len,
-				n_embd=model_dim,
-				n_layer=layers,
-				n_head=heads,
-				use_cache=not checkpointing)
+	gpt_config = GPT2Config(
+		vocab_size=256, # Unused.
+		n_positions=max_mel_seq_len+max_text_seq_len,
+		n_ctx=max_mel_seq_len+max_text_seq_len,
+		n_embd=model_dim,
+		n_layer=layers,
+		n_head=heads,
+		use_cache=not checkpointing,
+		attention_implementation=attention_implementation
+	)
 	gpt = GPT2Model(gpt_config)
 	if checkpointing:
@@ -332,7 +352,8 @@ class UnifiedVoice(nn.Module):
 		train_solo_embeddings=False,
 		use_mel_codes_as_input=True,
 		checkpointing=True,
-		types=1
+		types=1,
+		attention_implementation="auto",
 	):
 		"""
 		Args:
@@ -354,7 +375,13 @@ class UnifiedVoice(nn.Module):
 			checkpointing:
 		"""
 		super().__init__()
 
+		if attention_implementation == "auto":
+			if "flash" in AVAILABLE_ATTENTIONS:
+				attention_implementation = "flash_attention_2"
+			else:
+				attention_implementation = "mem_efficient"
+		self.attention_implementation = attention_implementation
 		self.number_text_tokens = number_text_tokens
 		self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
 		self.stop_text_token = 0
@@ -375,7 +402,7 @@ class UnifiedVoice(nn.Module):
 		else:
 			self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
 		self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
-			build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing)
+			build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing, attention_implementation)
 		if train_solo_embeddings:
 			self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
 			self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
@@ -524,14 +551,17 @@ class UnifiedVoice(nn.Module):
 		seq_length = self.max_mel_tokens + self.max_text_tokens + 2
 		if not hasattr(self, 'inference_model'):
 			# TODO: Decouple gpt_config from this inference model.
-			gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
-						n_positions=seq_length,
-						n_ctx=seq_length,
-						n_embd=self.model_dim,
-						n_layer=self.layers,
-						n_head=self.heads,
-						gradient_checkpointing=False,
-						use_cache=True)
+			gpt_config = GPT2Config(
+				vocab_size=self.max_mel_tokens,
+				n_positions=seq_length,
+				n_ctx=seq_length,
+				n_embd=self.model_dim,
+				n_layer=self.layers,
+				n_head=self.heads,
+				gradient_checkpointing=False,
+				use_cache=True,
+				attn_implementation=self.attention_implementation,
+			)
 			self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
 			self.gpt.wte = self.mel_embedding
 
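
A short usage sketch of the selection hook added above, for reference: it assumes
the import path matches the file touched by this patch and that the remaining
`UnifiedVoice` constructor defaults are usable as-is, so treat it as a sketch of
the intended flow rather than code taken from the commit.

	from tortoise_tts.models.unified_voice import AVAILABLE_ATTENTIONS, UnifiedVoice

	# Backends detected at import time ("xformers"/"flash" are only appended
	# when their imports succeed).
	print("available attention backends:", AVAILABLE_ATTENTIONS)

	# "auto" resolves to "flash_attention_2" when flash-attn 2 is available,
	# otherwise it falls back to "mem_efficient" before being handed to GPT2Config.
	model = UnifiedVoice(attention_implementation="auto")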