crammed in HF attention selection mechanisms for the AR

mrq 2024-06-19 10:21:43 -05:00
parent e5136613f5
commit 268ba17485


@@ -12,6 +12,23 @@ from .arch_utils import AttentionBlock
from transformers import LogitsWarper
AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
try:
    from xformers.ops import LowerTriangularMask
    from xformers.ops.fmha import memory_efficient_attention

    AVAILABLE_ATTENTIONS.append("xformers")
except Exception as e:
    print("Error while importing `xformers`", e)

try:
    from transformers.utils import is_flash_attn_2_available

    if is_flash_attn_2_available():
        AVAILABLE_ATTENTIONS.append("flash")
except Exception as e:
    print("Error while querying for `flash_attn_2` support", e)
class TypicalLogitsWarper(LogitsWarper):
    def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
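The "mem_efficient" and "math" entries that seed AVAILABLE_ATTENTIONS presumably refer to PyTorch's scaled-dot-product-attention kernels rather than to separate packages. A minimal sketch, not part of this commit and assuming PyTorch >= 2.0, of how those kernels can be toggled and probed:

import torch

# Choose which kernels torch.nn.functional.scaled_dot_product_attention may use.
torch.backends.cuda.enable_flash_sdp(False)           # fused flash kernel
torch.backends.cuda.enable_mem_efficient_sdp(True)    # memory-efficient kernel
torch.backends.cuda.enable_math_sdp(True)             # plain "math" fallback

print(torch.backends.cuda.mem_efficient_sdp_enabled(), torch.backends.cuda.math_sdp_enabled())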
@@ -262,18 +279,21 @@ class LearnedPositionEmbeddings(nn.Module):
        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, attention_implementation="eager"):
"""
GPT-2 implemented by the HuggingFace library.
"""
from transformers import GPT2Config, GPT2Model
gpt_config = GPT2Config(vocab_size=256, # Unused.
n_positions=max_mel_seq_len+max_text_seq_len,
n_ctx=max_mel_seq_len+max_text_seq_len,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
use_cache=not checkpointing)
gpt_config = GPT2Config(
vocab_size=256, # Unused.
n_positions=max_mel_seq_len+max_text_seq_len,
n_ctx=max_mel_seq_len+max_text_seq_len,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
use_cache=not checkpointing,
attention_implementation=attention_implementation
)
gpt = GPT2Model(gpt_config)
if checkpointing:
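For reference, recent transformers releases select the attention backend through the attn_implementation keyword (the spelling used in the inference-model hunk further down), which PretrainedConfig stores and the GPT-2 modules read back when the model is built. A minimal sketch, assuming a transformers version in which GPT-2 supports SDPA; the tiny dimensions are purely illustrative:

from transformers import GPT2Config, GPT2Model

cfg = GPT2Config(n_embd=64, n_layer=2, n_head=2, attn_implementation="sdpa")
model = GPT2Model(cfg)
print(model.config._attn_implementation)  # "sdpa"; defaults to "eager" when unset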
@@ -332,7 +352,8 @@ class UnifiedVoice(nn.Module):
        train_solo_embeddings=False,
        use_mel_codes_as_input=True,
        checkpointing=True,
        types=1
        types=1,
        attention_implementation="auto",
    ):
        """
        Args:
@@ -354,7 +375,13 @@ class UnifiedVoice(nn.Module):
            checkpointing:
        """
        super().__init__()

        if attention_implementation == "auto":
            if "flash" in AVAILABLE_ATTENTIONS:
                attention_implementation = "flash_attention_2"
            else:
                attention_implementation = "mem_efficient"
        self.attention_implementation = attention_implementation
        self.number_text_tokens = number_text_tokens
        self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
        self.stop_text_token = 0
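Note that HuggingFace's attn_implementation only validated "eager", "sdpa", and "flash_attention_2" around the time of this commit, so a local name like "mem_efficient" would need translating before it reaches GPT2Config. A hedged sketch of such a mapping; the helper name is hypothetical and not something this commit adds:

# Hypothetical helper: translate local backend names from AVAILABLE_ATTENTIONS
# into the strings GPT2Config's attn_implementation accepts.
def to_hf_attn_implementation(name: str) -> str:
    if name in ("flash", "flash_attention_2"):
        return "flash_attention_2"
    if name in ("xformers", "mem_efficient", "math", "sdpa"):
        return "sdpa"  # served by torch.nn.functional.scaled_dot_product_attention
    return "eager"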
@@ -375,7 +402,7 @@ class UnifiedVoice(nn.Module):
        else:
            self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
        self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
            build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing)
            build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing, attention_implementation)
        if train_solo_embeddings:
            self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
            self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
@@ -524,14 +551,17 @@ class UnifiedVoice(nn.Module):
        seq_length = self.max_mel_tokens + self.max_text_tokens + 2
        if not hasattr(self, 'inference_model'):
            # TODO: Decouple gpt_config from this inference model.
            gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
                                    n_positions=seq_length,
                                    n_ctx=seq_length,
                                    n_embd=self.model_dim,
                                    n_layer=self.layers,
                                    n_head=self.heads,
                                    gradient_checkpointing=False,
                                    use_cache=True)
            gpt_config = GPT2Config(
                vocab_size=self.max_mel_tokens,
                n_positions=seq_length,
                n_ctx=seq_length,
                n_embd=self.model_dim,
                n_layer=self.layers,
                n_head=self.heads,
                gradient_checkpointing=False,
                use_cache=True,
                attn_implementation=self.attention_implementation,
            )
            self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
            self.gpt.wte = self.mel_embedding
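A hedged usage sketch, not part of the diff and assuming the remaining UnifiedVoice constructor arguments keep their existing defaults: "auto" prefers flash attention when AVAILABLE_ATTENTIONS reports it and otherwise falls back as shown above.

# Hypothetical usage: let the model pick its attention backend.
model = UnifiedVoice(attention_implementation="auto")
print(model.attention_implementation)  # e.g. "flash_attention_2" or "mem_efficient"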