crammed in HF attention selection mechanisms for the AR
commit 268ba17485
parent e5136613f5
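In short: the module now probes at import time which attention backends can actually be loaded (xformers, flash-attn), and UnifiedVoice gains an attention_implementation argument ("auto" by default) that is threaded through build_hf_gpt_transformer into both HuggingFace GPT2Configs. A minimal sketch of the "auto" resolution step on its own, assuming the AVAILABLE_ATTENTIONS list built by the probes in the diff below; the helper name resolve_attention is illustrative and not part of the commit:

# Illustrative helper mirroring the "auto" branch added to UnifiedVoice.__init__ below.
# `available` stands in for the AVAILABLE_ATTENTIONS list built by the try/except probes.
def resolve_attention(requested: str, available: list[str]) -> str:
    if requested != "auto":
        return requested
    # Prefer HF's flash-attention-2 path when flash-attn is installed,
    # otherwise fall back to this fork's "mem_efficient" label.
    return "flash_attention_2" if "flash" in available else "mem_efficient"

# e.g. resolve_attention("auto", ["mem_efficient", "math", "xformers"]) == "mem_efficient"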
@@ -12,6 +12,23 @@ from .arch_utils import AttentionBlock
 
 from transformers import LogitsWarper
 
+AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
+
+try:
+    from xformers.ops import LowerTriangularMask
+    from xformers.ops.fmha import memory_efficient_attention
+
+    AVAILABLE_ATTENTIONS.append("xformers")
+except Exception as e:
+    print("Error while importing `xformers`", e)
+
+try:
+    from transformers.utils import is_flash_attn_2_available
+
+    if is_flash_attn_2_available():
+        AVAILABLE_ATTENTIONS.append("flash")
+except Exception as e:
+    print("Error while querying for `flash_attn_2` support", e)
 
 class TypicalLogitsWarper(LogitsWarper):
     def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@@ -262,18 +279,21 @@ class LearnedPositionEmbeddings(nn.Module):
         return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
 
 
-def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
+def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, attention_implementation="eager"):
     """
     GPT-2 implemented by the HuggingFace library.
     """
     from transformers import GPT2Config, GPT2Model
-    gpt_config = GPT2Config(vocab_size=256, # Unused.
-                            n_positions=max_mel_seq_len+max_text_seq_len,
-                            n_ctx=max_mel_seq_len+max_text_seq_len,
-                            n_embd=model_dim,
-                            n_layer=layers,
-                            n_head=heads,
-                            use_cache=not checkpointing)
+    gpt_config = GPT2Config(
+        vocab_size=256, # Unused.
+        n_positions=max_mel_seq_len+max_text_seq_len,
+        n_ctx=max_mel_seq_len+max_text_seq_len,
+        n_embd=model_dim,
+        n_layer=layers,
+        n_head=heads,
+        use_cache=not checkpointing,
+        attention_implementation=attention_implementation
+    )
     gpt = GPT2Model(gpt_config)
 
     if checkpointing:
@@ -332,7 +352,8 @@ class UnifiedVoice(nn.Module):
                  train_solo_embeddings=False,
                  use_mel_codes_as_input=True,
                  checkpointing=True,
-                 types=1
+                 types=1,
+                 attention_implementation="auto",
                  ):
         """
         Args:
@@ -354,7 +375,13 @@ class UnifiedVoice(nn.Module):
             checkpointing:
         """
         super().__init__()
+        if attention_implementation == "auto":
+            if "flash" in AVAILABLE_ATTENTIONS:
+                attention_implementation = "flash_attention_2"
+            else:
+                attention_implementation = "mem_efficient"
 
+        self.attention_implementation = attention_implementation
         self.number_text_tokens = number_text_tokens
         self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
         self.stop_text_token = 0
@@ -375,7 +402,7 @@ class UnifiedVoice(nn.Module):
         else:
             self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
         self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
-            build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing)
+            build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing, attention_implementation)
         if train_solo_embeddings:
             self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
             self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
@@ -524,14 +551,17 @@ class UnifiedVoice(nn.Module):
         seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         if not hasattr(self, 'inference_model'):
             # TODO: Decouple gpt_config from this inference model.
-            gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
-                                    n_positions=seq_length,
-                                    n_ctx=seq_length,
-                                    n_embd=self.model_dim,
-                                    n_layer=self.layers,
-                                    n_head=self.heads,
-                                    gradient_checkpointing=False,
-                                    use_cache=True)
+            gpt_config = GPT2Config(
+                vocab_size=self.max_mel_tokens,
+                n_positions=seq_length,
+                n_ctx=seq_length,
+                n_embd=self.model_dim,
+                n_layer=self.layers,
+                n_head=self.heads,
+                gradient_checkpointing=False,
+                use_cache=True,
+                attn_implementation=self.attention_implementation,
+            )
             self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
             self.gpt.wte = self.mel_embedding
 
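For reference, a standalone sanity check (not part of the commit) of how the selected string reaches HuggingFace: recent transformers releases (4.36+) accept attn_implementation directly on the config and expose the resolved value as config._attn_implementation; values such as "eager", "sdpa", and "flash_attention_2" are the ones transformers itself recognizes, and "flash_attention_2" additionally requires the flash-attn package:

# Assumes transformers >= 4.36 and a recent PyTorch; not part of the diff above.
from transformers import GPT2Config, GPT2Model

config = GPT2Config(
    vocab_size=256,
    n_positions=64,
    n_embd=64,      # must be divisible by n_head
    n_layer=2,
    n_head=2,
    use_cache=True,
    attn_implementation="sdpa",  # "eager"/"sdpa" need no extra packages
)
model = GPT2Model(config)
print(model.config._attn_implementation)  # expected: "sdpa"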