crammed in HF attention selection mechanisms for the AR
parent e5136613f5
commit 268ba17485
@@ -12,6 +12,23 @@ from .arch_utils import AttentionBlock
 from transformers import LogitsWarper
 
+AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
+
+try:
+    from xformers.ops import LowerTriangularMask
+    from xformers.ops.fmha import memory_efficient_attention
+
+    AVAILABLE_ATTENTIONS.append("xformers")
+except Exception as e:
+    print("Error while importing `xformers`", e)
+
+try:
+    from transformers.utils import is_flash_attn_2_available
+
+    if is_flash_attn_2_available():
+        AVAILABLE_ATTENTIONS.append("flash")
+except Exception as e:
+    print("Error while querying for `flash_attn_2` support", e)
 
 class TypicalLogitsWarper(LogitsWarper):
     def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
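The block above only records which attention backends can be imported; nothing is selected at import time. A minimal sketch of inspecting the result before building a model (the tortoise.models.autoregressive import path is an assumption, adjust it to wherever this file lives):

    # Hypothetical check of the module-level list added above.
    from tortoise.models.autoregressive import AVAILABLE_ATTENTIONS

    # "mem_efficient" and "math" are always present; "xformers" and "flash"
    # are appended only when their imports succeed.
    print(AVAILABLE_ATTENTIONS)
    if "flash" in AVAILABLE_ATTENTIONS:
        print("flash-attn 2 can be used for the AR")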
@@ -262,18 +279,21 @@ class LearnedPositionEmbeddings(nn.Module):
         return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
 
 
-def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
+def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, attention_implementation="eager"):
     """
     GPT-2 implemented by the HuggingFace library.
     """
     from transformers import GPT2Config, GPT2Model
-    gpt_config = GPT2Config(vocab_size=256, # Unused.
-                            n_positions=max_mel_seq_len+max_text_seq_len,
-                            n_ctx=max_mel_seq_len+max_text_seq_len,
-                            n_embd=model_dim,
-                            n_layer=layers,
-                            n_head=heads,
-                            use_cache=not checkpointing)
+    gpt_config = GPT2Config(
+        vocab_size=256, # Unused.
+        n_positions=max_mel_seq_len+max_text_seq_len,
+        n_ctx=max_mel_seq_len+max_text_seq_len,
+        n_embd=model_dim,
+        n_layer=layers,
+        n_head=heads,
+        use_cache=not checkpointing,
+        attention_implementation=attention_implementation
+    )
     gpt = GPT2Model(gpt_config)
 
     if checkpointing:
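The new keyword is forwarded into the GPT2Config that backs the AR transformer. A sketch of calling the builder directly, with illustrative sizes and an assumed import path:

    # Hypothetical direct call; UnifiedVoice wires this up the same way below.
    from tortoise.models.autoregressive import build_hf_gpt_transformer

    gpt, mel_pos_emb, text_pos_emb, _, _ = build_hf_gpt_transformer(
        layers=30, model_dim=1024, heads=16,          # illustrative sizes
        max_mel_seq_len=604 + 2 + 1, max_text_seq_len=402 + 2,
        checkpointing=False,                          # use_cache=True
        attention_implementation="eager",
    )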
@@ -332,7 +352,8 @@ class UnifiedVoice(nn.Module):
                  train_solo_embeddings=False,
                  use_mel_codes_as_input=True,
                  checkpointing=True,
-                 types=1
+                 types=1,
+                 attention_implementation="auto",
                  ):
         """
         Args:
@@ -354,7 +375,13 @@ class UnifiedVoice(nn.Module):
             checkpointing:
         """
         super().__init__()
+        if attention_implementation == "auto":
+            if "flash" in AVAILABLE_ATTENTIONS:
+                attention_implementation = "flash_attention_2"
+            else:
+                attention_implementation = "mem_efficient"
 
+        self.attention_implementation = attention_implementation
         self.number_text_tokens = number_text_tokens
         self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
         self.stop_text_token = 0
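With the constructor change, callers can leave the default "auto" (which prefers flash_attention_2 when flash-attn is importable and otherwise falls back to "mem_efficient") or pin a backend explicitly. A short usage sketch, assuming the upstream constructor defaults and the same import path as above:

    from tortoise.models.autoregressive import UnifiedVoice

    # Let the constructor resolve the backend from AVAILABLE_ATTENTIONS.
    ar = UnifiedVoice(attention_implementation="auto")
    print(ar.attention_implementation)

    # Or pin a backend explicitly, e.g. when flash-attn is known to be installed.
    ar_flash = UnifiedVoice(attention_implementation="flash_attention_2")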
@@ -375,7 +402,7 @@ class UnifiedVoice(nn.Module):
         else:
             self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
         self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
-            build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing)
+            build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing, attention_implementation)
         if train_solo_embeddings:
             self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
             self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
@@ -524,14 +551,17 @@ class UnifiedVoice(nn.Module):
         seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         if not hasattr(self, 'inference_model'):
             # TODO: Decouple gpt_config from this inference model.
-            gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
-                                    n_positions=seq_length,
-                                    n_ctx=seq_length,
-                                    n_embd=self.model_dim,
-                                    n_layer=self.layers,
-                                    n_head=self.heads,
-                                    gradient_checkpointing=False,
-                                    use_cache=True)
+            gpt_config = GPT2Config(
+                vocab_size=self.max_mel_tokens,
+                n_positions=seq_length,
+                n_ctx=seq_length,
+                n_embd=self.model_dim,
+                n_layer=self.layers,
+                n_head=self.heads,
+                gradient_checkpointing=False,
+                use_cache=True,
+                attn_implementation=self.attention_implementation,
+            )
             self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
             self.gpt.wte = self.mel_embedding
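On the inference side, the selected string is forwarded as transformers' attn_implementation. A minimal sketch of what that kwarg does on a bare GPT2Config (the private _attn_implementation attribute follows current transformers conventions, so treat this as illustrative):

    from transformers import GPT2Config

    # The config only stores the requested backend; whether it is actually
    # usable is checked later, when the model itself is instantiated.
    cfg = GPT2Config(n_embd=1024, n_layer=30, n_head=16,  # illustrative sizes
                     use_cache=True,
                     attn_implementation="flash_attention_2")
    print(cfg._attn_implementation)  # "flash_attention_2"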