feverish cleanup

This commit is contained in:
mrq 2024-06-03 21:28:49 -05:00
parent 7feeb944a0
commit 934672252b
11 changed files with 343 additions and 278 deletions

View File

@ -1,3 +1,4 @@
experimental: False # should probably expand this into a dict of experimental flags
sample_rate: 24_000 # 44_000 for dac
audio_backend: "vocos" # or dac
@ -11,9 +12,10 @@ models:
tones: 1
arch_type: llama
training: True
version: 4
version: 5
attention: flash_attention_2
dropout: 0.1
experimental: False
loss_factors:
text: 0.1
@ -63,11 +65,10 @@ trainer:
keep_last_checkpoints: 4
aggressive_optimizations: False
load_disabled_engines: False
gradient_checkpointing: True
#load_state_dict: True
strict_loading: False
#load_state_dict: True
#load_tag: "9500"
#load_states: False
#restart_step_count: True
@ -85,8 +86,6 @@ trainer:
amp: False
activation_checkpointing: True
load_webui: False
inference:
@ -109,8 +108,6 @@ optimizations:
bitnet: False
fp8: False
experimental: True # practically required now it seems
dataset:
speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
speaker_group_getter: "lambda p: f'{p.parts[-3]}'"

View File

@ -191,6 +191,7 @@ class Dataset:
def max_duration(self):
return self.duration_range[1]
# I really need to clean this up
@dataclass()
class Model:
_max_levels: int = 0
@ -215,6 +216,7 @@ class Model:
dropout: float = 0.1 # adjustable dropout value
loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 })
kv_heads: int = 0
experimental: bool = False
def get(self, name=None):
return [ self ] if not name or self.name == name else []
@ -306,6 +308,10 @@ class Model:
def activation_checkpointing(self):
return cfg.trainer.activation_checkpointing
@property
def gradient_checkpointing(self):
return cfg.trainer.gradient_checkpointing
@dataclass()
class Hyperparameters:
batch_size: int = 8
@ -519,7 +525,8 @@ class Trainer:
load_module_only: bool = False
restart_step_count: bool = False
activation_checkpointing: bool = True
activation_checkpointing: bool | None = None # deprecated
gradient_checkpointing: bool = True
aggressive_optimizations: bool = False
check_for_oom: bool = True
@ -728,6 +735,9 @@ class Config(_Config):
if self.inference.audio_backend != "" and self.audio_backend == "":
self.audio_backend = self.inference.audio_backend
if self.trainer.activation_checkpointing is not None:
self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing
# Preserves the old behavior
class NaiveTokenizer:
def get_vocab( self ):

View File

@ -24,11 +24,144 @@ from typing import Any
from torch import Tensor
from torch.utils.data import DataLoader, Dataset as _Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.utils.rnn import pad_sequence
from tqdm.auto import tqdm
# torch.multiprocessing.set_sharing_strategy("file_system")
_logger = logging.getLogger(__name__)
# fold into a typical LLM sequence (one embedding rather than split embeddings)
def fold_inputs(
text_list = [],
prom_list = [],
resp_list = [],
ignore_index = None,
sep = 3,
stop = 3,
text_tokens = 256,
audio_tokens = 1024,
audio_rvq_levels = cfg.model.prom_levels
):
def _create_mask(l, device):
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
return (seq < stop).float() # (b t)
def list_to_tensor(x_list: list[Tensor]):
l = list(map(len, x_list))
x = pad_sequence(x_list).t()
m = _create_mask(l, x_list[0].device)
m = m.to(x)
return x, m
device = text_list[0].device
batch_size = len(text_list)
input_ids = [ [] for _ in range(batch_size) ]
offset = 0
sep = torch.Tensor([ sep ])
stop = torch.Tensor([ stop ])
for i, text in enumerate(text_list):
seq = text.to("cpu", dtype=torch.int64)
input_ids[i].append( seq )
input_ids[i].append( sep )
offset = text_tokens
for i, prom in enumerate(prom_list):
if ignore_index is not None:
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
else:
seq = prom.flatten().to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
input_ids[i].append( seq )
input_ids[i].append( sep )
offset = text_tokens + (audio_tokens * audio_rvq_levels)
for i, resp in enumerate(resp_list):
seq = resp.flatten().to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
input_ids[i].append( seq )
input_ids[i].append( stop )
for i, batch in enumerate(input_ids):
input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64)
return list_to_tensor(input_ids)
# unfold from one unified token ID space to separate token spaces
def unfold_outputs(
output_ids,
sep = 3,
stop = 3,
text_tokens = 256,
audio_tokens = 1024,
audio_rvq_levels = cfg.model.prom_levels
):
device = output_ids.device
batch_size = output_ids.shape[0]
text_list = [ [] for _ in range(batch_size) ]
prom_list = [ [] for _ in range(batch_size) ]
resp_list = [ [] for _ in range(batch_size) ]
for i, batch in enumerate( output_ids ):
for idx, token in enumerate( batch ):
id = token.item()
if id == sep or id == stop:
continue
if 0 <= id and id < text_tokens:
text_list[i].append( id )
elif text_tokens <= id and id < text_tokens + (audio_tokens * audio_rvq_levels):
prom_list[i].append( (id - text_tokens) % audio_tokens )
elif text_tokens + (audio_tokens * audio_rvq_levels) <= id:
resp_list[i].append( (id - text_tokens) % audio_tokens )
prom_len = len(prom_list[i])
if prom_len % audio_rvq_levels == 0 and False:
prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t()
else:
bins = [ [] for _ in range(audio_rvq_levels) ]
for pos in range( prom_len ):
rvq = pos % audio_rvq_levels
bins[rvq].append( prom_list[i][pos] )
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
bins = bins[:nearest]
prom_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
resp_len = len(resp_list[i])
if len(resp_list[i]) % audio_rvq_levels == 0 and False:
resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t()
else:
bins = [ [] for _ in range(audio_rvq_levels) ]
for pos in range( resp_len ):
rvq = pos % audio_rvq_levels
bins[rvq].append( resp_list[i][pos] )
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
bins = bins[:nearest]
resp_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
text_list[i] = torch.Tensor( text_list[i] ).to(dtype=torch.int64)
return dict(
text_list=text_list,
prom_list=prom_list,
resp_list=resp_list
)
# to-do: clean up this symmap mess
def get_phone_symmap():
return cfg.tokenizer.get_vocab()

View File

@ -33,7 +33,7 @@ def load_engines(training=True):
optimizer = None
lr_scheduler = None
inferencing = cfg.mode == "inferencing" or not model._cfg.training
inferencing = cfg.mode == "inferencing" or not model.hyper_config.training
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
amp = cfg.inference.amp if inferencing else cfg.trainer.amp
@ -43,7 +43,7 @@ def load_engines(training=True):
engine_class = _Engine if backend == "local" or inferencing else Engine
if inferencing:
model._cfg.training = False
model.hyper_config.training = False
if cfg.optimizations.replace and cfg.optimizations.linear:
model.model = ml.replace_linear( model.model )
@ -83,7 +83,7 @@ def load_engines(training=True):
params.update(cfg.hyperparameters.optimizer_params)
optimizer = optimizer_class(
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
[ param for name, param in model.named_parameters() if name not in model.hyper_config.frozen_params ],
**params,
)
@ -96,7 +96,7 @@ def load_engines(training=True):
raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')
optimizer = scheduler_class(
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
[ param for name, param in model.named_parameters() if name not in model.hyper_config.frozen_params ],
lr = params['lr'],
warmup_steps = cfg.hyperparameters.warmup_steps
)
@ -144,7 +144,7 @@ def load_engines(training=True):
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
_cfg = model._cfg
hyper_config = model.hyper_config
# wrap if DDP is requested
if ddp:
@ -161,7 +161,7 @@ def load_engines(training=True):
optimizer=optimizer,
lr_scheduler=lr_scheduler,
_cfg=_cfg,
hyper_config=hyper_config,
stats=stats
)

View File

@ -52,9 +52,9 @@ if not distributed_initialized() and cfg.trainer.backend == "local": # and world
# A very naive engine implementation using barebones PyTorch
class Engine():
def __init__(self, *args, **kwargs):
if '_cfg' in kwargs:
self._cfg = kwargs['_cfg']
kwargs.pop("_cfg")
if 'hyper_config' in kwargs:
self.hyper_config = kwargs['hyper_config']
kwargs.pop("hyper_config")
self.module = kwargs['model'].to(cfg.device).to(torch.float32 if cfg.trainer.amp else cfg.trainer.dtype)
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
@ -72,11 +72,11 @@ class Engine():
def freeze(self, freeze_all=True):
# set to freeze
if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
raise Exception("freeze_all=False yet self._cfg.frozen_params is None")
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
for name, param in self.module.named_parameters():
if (freeze_all and param.requires_grad) or (not freeze_all and name in self._cfg.frozen_params):
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
param.requires_grad_(False)
self._frozen_params.add(param)
@ -87,9 +87,9 @@ class Engine():
@property
def _training(self):
if not hasattr(self, "_cfg"):
if not hasattr(self, "hyper_config"):
return True
return self._cfg.training
return self.hyper_config.training
@property
def global_step(self):

View File

@ -32,10 +32,10 @@ if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
class Engine(DeepSpeedEngine):
def __init__(self, *args, **kwargs):
self._cfg = None
if '_cfg' in kwargs:
self._cfg = kwargs['_cfg']
kwargs.pop("_cfg")
self.hyper_config = None
if 'hyper_config' in kwargs:
self.hyper_config = kwargs['hyper_config']
kwargs.pop("hyper_config")
kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
@ -63,11 +63,11 @@ class Engine(DeepSpeedEngine):
self.max_nan_losses = 8
def freeze(self, freeze_all=True):
if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
raise Exception("freeze_all=False yet self._cfg.frozen_params is None")
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
for name, param in self.module.named_parameters():
if (freeze_all and param.requires_grad) or (not freeze_all and name in self._cfg.frozen_params):
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
param.requires_grad_(False)
self._frozen_params.add(param)
@ -78,7 +78,7 @@ class Engine(DeepSpeedEngine):
@property
def _training(self):
return self._cfg.training
return self.hyper_config.training
@property
def global_step(self):

View File

@ -1,23 +1,34 @@
from .ar_nar import AR_NAR
from .experimental import Model as Experimental
def get_model(cfg, training=True):
name = cfg.name
model = AR_NAR(
n_tokens=cfg.tokens,
d_model=cfg.dim,
n_heads=cfg.heads,
n_layers=cfg.layers,
n_experts=cfg.experts,
if not cfg.experimental:
model = AR_NAR(
n_tokens=cfg.tokens,
d_model=cfg.dim,
n_heads=cfg.heads,
n_layers=cfg.layers,
n_experts=cfg.experts,
p_dropout=cfg.dropout,
p_dropout=cfg.dropout,
l_padding = cfg.input_alignment,
l_padding = cfg.input_alignment,
training = training,
config = cfg,
)
model._cfg = cfg
training = training,
config = cfg,
)
model._cfg = cfg
else:
model = Experimental(
d_model=cfg.dim,
n_layers=cfg.layers,
n_heads=cfg.heads,
p_dropout=cfg.dropout,
config = cfg,
)
print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")

View File

@ -64,7 +64,7 @@ try:
def BitNetTransformerBlock_forward(self, x: Tensor, *args, **kwargs) -> Tensor:
skip = x
for attn, ffn in zip(self.layers, self.ffn_layers):
if x.requires_grad and self.activation_checkpointing:
if x.requires_grad and self.gradient_checkpointing:
x, _ = checkpoint(attn, x, x, x, is_causal=True, *args, **kwargs, use_reentrant=False)
else:
x, _ = attn(x, x, x, is_causal=True, *args, **kwargs)
@ -83,13 +83,13 @@ try:
num_tokens: int,
heads=8,
ff_mult=4,
activation_checkpointing = True
gradient_checkpointing = True
):
super().__init__()
self.transformer = BitNetTransformerBlock( dim=dim, depth=depth, heads=heads, ff_mult=ff_mult )
self.norm = BitNetRMSNorm(dim)
self.transformer.activation_checkpointing = activation_checkpointing
self.transformer.gradient_checkpointing = gradient_checkpointing
def forward(self, x):
x = self.transformer(x)
@ -431,9 +431,9 @@ class Base(nn.Module):
return -100
def loss_factor(self, k):
if self.config is None:
if self.hyper_config is None:
return 1.0
return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0
return self.hyper_config.loss_factors[k] if k in self.hyper_config.loss_factors else 1.0
def __init__(
self,
@ -452,8 +452,8 @@ class Base(nn.Module):
):
super().__init__()
self.training = training
self.config = config
self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
self.hyper_config = config
self.gradient_checkpointing = self.hyper_config.gradient_checkpointing if self.hyper_config is not None else True
self.n_tokens = n_tokens
self.d_model = d_model
@ -482,13 +482,13 @@ class Base(nn.Module):
self.proms_emb = AudioEmbedding(
[n_prom_tokens] * self.n_prom_levels, d_model,
levels=self.n_prom_levels if self.version > 3 else None,
sums=self.config.audio_embedding_sums if self.config is not None else True,
sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True,
)
# [1025] + [1024] * 8
self.resps_emb = AudioEmbedding(
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
levels=self.n_resp_levels if self.version > 3 else None,
sums=self.config.audio_embedding_sums if self.config is not None else True
sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True
)
@ -502,20 +502,20 @@ class Base(nn.Module):
self.sep = nn.Parameter(torch.randn(d_model))
# ick, there has to be a better way
hf_attention = self.config.attention if self.config is not None else None
hf_attention = self.hyper_config.attention if self.hyper_config is not None else None
if self.config.attention == "auto":
if self.hyper_config.attention == "auto":
if "flash" in AVAILABLE_ATTENTIONS:
self.config.attention = "flash"
self.hyper_config.attention = "flash"
elif "xformers" in AVAILABLE_ATTENTIONS:
self.config.attention = "xformers"
self.hyper_config.attention = "xformers"
else:
self.config.attention = "mem_efficient"
self.hyper_config.attention = "mem_efficient"
if self.config.attention in ["xformers", "mem_efficient", "math", "flash"]:
if self.hyper_config.attention in ["xformers", "mem_efficient", "math", "flash"]:
hf_attention = None
if self.config.attention not in AVAILABLE_ATTENTIONS:
raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
if self.hyper_config.attention not in AVAILABLE_ATTENTIONS:
raise ValueError(f"Requesting attention `{self.hyper_config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
if self.arch_type == "transformer":
@ -538,12 +538,12 @@ class Base(nn.Module):
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=hf_attention,
#gradient_checkpointing=self.activation_checkpointing,
#gradient_checkpointing=self.gradient_checkpointing,
))
else:
self.model = MixtralModel(MixtralConfig(
@ -554,7 +554,7 @@ class Base(nn.Module):
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads,
sliding_window=75 * 12, # 12 second context window
output_router_logits=training,
hidden_act="gelu",
@ -563,10 +563,10 @@ class Base(nn.Module):
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
attn_implementation=hf_attention,
#gradient_checkpointing=self.activation_checkpointing,
#gradient_checkpointing=self.gradient_checkpointing,
))
if self.activation_checkpointing and not self.model.gradient_checkpointing:
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
@ -589,7 +589,7 @@ class Base(nn.Module):
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=hf_attention,
#gradient_checkpointing=self.activation_checkpointing,
#gradient_checkpointing=self.gradient_checkpointing,
))
else:
self.model = MixtralModel(MixtralConfig(
@ -609,10 +609,10 @@ class Base(nn.Module):
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
attn_implementation=hf_attention,
#gradient_checkpointing=self.activation_checkpointing,
#gradient_checkpointing=self.gradient_checkpointing,
))
if self.activation_checkpointing and not self.model.gradient_checkpointing:
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
@ -628,7 +628,7 @@ class Base(nn.Module):
decoder_ffn_embed_dim=d_model * 4,
decoder_layers=n_layers,
dropout=p_dropout if training else 0.0,
checkpoint_activations=self.activation_checkpointing,
checkpoint_activations=self.gradient_checkpointing,
activation_fn="gelu",
use_layernorm=self.version < 3,
use_biases=self.version < 3,
@ -660,7 +660,7 @@ class Base(nn.Module):
decoder_ffn_embed_dim=d_model * 4,
decoder_layers=n_layers,
dropout=p_dropout if training else 0.0,
checkpoint_activations=self.activation_checkpointing,
checkpoint_activations=self.gradient_checkpointing,
activation_fn="gelu",
use_glu=False, # self.version >= 3,
@ -673,7 +673,7 @@ class Base(nn.Module):
self.model = RetNetDecoder_HF(RetNetConfig_HF(**kwargs))
if self.activation_checkpointing and not self.model.gradient_checkpointing:
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
@ -684,13 +684,13 @@ class Base(nn.Module):
depth=n_layers,
heads=n_heads,
ff_mult=4,
activation_checkpointing=self.activation_checkpointing,
gradient_checkpointing=self.gradient_checkpointing,
)
else:
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention )
if self.hyper_config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.hyper_config.attention )
self.classifier = nn.Linear(d_model, n_resp_tokens)
@ -877,7 +877,7 @@ class Base(nn.Module):
quant_levels: Tensor | None = None,
):
# old, "naive" way, no loss factoring
if not self.config.loss_factors:
if not self.hyper_config.loss_factors:
target_list = []
for batch in inputs:
target = []

View File

@ -1,9 +1,20 @@
"""
This is an experiment to:
* entertain a thought to try and abide by HF's transformers API (to benefit from caching better)
* conform to a single embedding (instead of a bunch of them) by folding/unfolding inputs
* stop trying to make a mixed AR+NAR model work since it seems lobotomized if I keep trying to enforce both recurrent and parallel inferencing (despite a penalty cost)
+ I will not cave and go with codebook patterns, not yet.
"""
from ..config import cfg
from ..data import fold_inputs, unfold_outputs
import torch
from torch.nn.utils.rnn import pad_sequence
from torch import Tensor
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint
import random
import math
@ -21,144 +32,40 @@ except Exception as e:
pass
try:
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
def MambaMixelModel_forward(self, input_ids, inference_params=None, **mixer_kwargs):
hidden_states = self.embedding(input_ids)
residual = None
for layer in self.layers:
if self.gradient_checkpointing and hidden_states.requires_grad:
hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, use_reentrant=False )
else:
hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params )
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
else:
# Set prenorm=False here since we don't need the residual
hidden_states = MambaLayerNormFn(
hidden_states,
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual,
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm_f, MambaRMSNorm)
)
return hidden_states
MambaMixelModel.forward = MambaMixelModel_forward
AVAILABLE_ARCHES.append("mamba")
except Exception as e:
print("Error importing `mamba` arch:", e)
pass
def _create_mask(l, device):
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
return (seq < stop).float() # (b t)
def list_to_tensor(x_list: list[Tensor]):
l = list(map(len, x_list))
x = pad_sequence(x_list).t()
m = _create_mask(l, x_list[0].device)
m = m.to(x)
return x, m
# fold into a typical LLM sequence (one embedding rather than split embeddings)
def fold(
text_list = [],
proms_list = [],
resp_list = [],
ignore_index = None,
sep = 3,
stop = 3,
text_tokens = 256,
audio_tokens = 1024,
audio_rvq_levels = cfg.model.prom_levels
):
device = text_list[0].device
batch_size = len(text_list)
input_ids = [ [] for _ in range(batch_size) ]
offset = 0
sep = torch.Tensor([ sep ])
stop = torch.Tensor([ stop ])
for i, text in enumerate(text_list):
seq = text.to("cpu", dtype=torch.int64)
input_ids[i].append( seq )
input_ids[i].append( sep )
offset = text_tokens
for i, prom in enumerate(proms_list):
if ignore_index is not None:
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
else:
seq = prom.flatten().to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
input_ids[i].append( seq )
input_ids[i].append( sep )
offset = text_tokens + (audio_tokens * audio_rvq_levels)
for i, resp in enumerate(resp_list):
seq = resp.flatten().to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
input_ids[i].append( seq )
input_ids[i].append( stop )
for i, batch in enumerate(input_ids):
input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64)
return list_to_tensor(input_ids)
# unfold from one unified token ID space to separate token spaces
def unfold(
input_ids,
sep = 3,
stop = 3,
text_tokens = 256,
audio_tokens = 1024,
audio_rvq_levels = cfg.model.prom_levels
):
device = input_ids.device
batch_size = input_ids.shape[0]
text_list = [ [] for _ in range(batch_size) ]
prom_list = [ [] for _ in range(batch_size) ]
resp_list = [ [] for _ in range(batch_size) ]
for i, batch in enumerate( input_ids ):
for idx, token in enumerate( batch ):
id = token.item()
if id == sep or id == stop:
continue
if 0 <= id and id < text_tokens:
text_list[i].append( id )
elif text_tokens <= id and id < text_tokens + (audio_tokens * audio_rvq_levels):
prom_list[i].append( (id - text_tokens) % audio_tokens )
elif text_tokens + (audio_tokens * audio_rvq_levels) <= id:
resp_list[i].append( (id - text_tokens) % audio_tokens )
prom_len = len(prom_list[i])
if prom_len % audio_rvq_levels == 0 and False:
prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t()
else:
bins = [ [] for _ in range(audio_rvq_levels) ]
for pos in range( prom_len ):
rvq = pos % audio_rvq_levels
bins[rvq].append( prom_list[i][pos] )
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
bins = bins[:nearest]
prom_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
resp_len = len(resp_list[i])
if len(resp_list[i]) % audio_rvq_levels == 0 and False:
resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t()
else:
bins = [ [] for _ in range(audio_rvq_levels) ]
for pos in range( resp_len ):
rvq = pos % audio_rvq_levels
bins[rvq].append( resp_list[i][pos] )
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
bins = bins[:nearest]
resp_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
text_list[i] = torch.Tensor( text_list[i] ).to(dtype=torch.int64)
return dict(
text_list=text_list,
prom_list=prom_list,
resp_list=resp_list
)
SELECTED_ARCH = cfg.model.arch_type
if SELECTED_ARCH not in AVAILABLE_ARCHES:
@ -179,9 +86,12 @@ class Model(LlmArchClass):
n_heads=16,
p_dropout=0.1,
attention_backend=None,
activation_checkpointing=True,
config = None,
):
self.hyper_config = config
hf_attention = config.attention if config is not None else None
gradient_checkpointing = config.gradient_checkpointing if config is not None else True
if SELECTED_ARCH == "llama":
super().__init__(config=LlamaConfig(
@ -197,10 +107,10 @@ class Model(LlmArchClass):
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=attention_backend,
attn_implementation=hf_attention,
))
if activation_checkpointing:
if gradient_checkpointing:
self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
@ -209,9 +119,11 @@ class Model(LlmArchClass):
vocab_size=256 + (1024 * cfg.model.prom_levels) + (1024 * cfg.model.prom_levels) + 1,
d_model=d_model,
n_layer=n_layers*2,
#ssm_cfg={"layer": "Mamba2"},
#ssm_cfg={"layer": "Mamba2"}, # will ALWAYS nan
))
self.backbone.gradient_checkpointing = gradient_checkpointing
def forward(
self,
@ -293,9 +205,9 @@ def example_usage():
proms_list = proms_list[:1]
resps_list = resps_list[:1]
input_ids, attention_mask = fold(text_list, proms_list, resps_list)
target_ids, target_attention_mask = fold(text_list, proms_list, resps_list, ignore_index=-100)
prefix_input_ids, prefix_attention_mask = fold(text_list, proms_list)
input_ids, attention_mask = fold_inputs(text_list, proms_list, resps_list)
target_ids, target_attention_mask = fold_inputs(text_list, proms_list, resps_list, ignore_index=-100)
prefix_input_ids, prefix_attention_mask = fold_inputs(text_list, proms_list)
kwargs = {}
model = Model(**kwargs).to(device)
@ -373,7 +285,7 @@ def example_usage():
else:
output = model.generate(input_ids=prefix_input_ids, attention_mask=prefix_attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
unfolded = unfold( output )
unfolded = unfold_outputs( output )
for i, batch in enumerate(unfolded["resp_list"]):
_ = decode_to_file(batch.to(device=device), f"data/{SELECTED_ARCH}.{cfg.audio_backend}.{i}.{name}.wav", device=device)

View File

@ -15,7 +15,29 @@ from torch import Tensor, einsum, nn
from torch.utils.checkpoint import checkpoint
from ..utils import wrapper as ml
from .adaln import AdaLN
class AdaLN(nn.Module):
def __init__(self, d_model, n_levels, eps=1e-5, k=0.1, c=2):
super().__init__()
self.eps = eps
self.emb = nn.Embedding(n_levels, d_model * 2)
self.k = k
self.c = c
nn.init.zeros_(self.emb.weight)
def forward(self, x, l):
h = F.layer_norm(x, x.shape[-1:], eps=self.eps)
# The initial implementation (https://github.com/enhuiz/vall-e/blob/fbf023448c08e55c0422eefed7fc234cf8b76680/vall_e/vall_e/base.py#L135)
# performed worse than vanilla LayerNorm.
# The authors mentioned another AdaNorm paper (https://openreview.net/pdf?id=HyxndNrxLB) as they introduce AdaLN.
# Did they use AdaNorm inside AdaLN? (as follows)
h = self.c * (1 - (self.k * h).detach()) * h
logγ, β = self.emb(l).unsqueeze(1).chunk(2, dim=-1)
y = logγ.exp() * h + β
return y
class SinusoidalEmbedding(nn.Module):
def __init__(self, d_model):

View File

@ -5,6 +5,7 @@ from .data import create_train_val_dataloader
from .emb import qnt
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
from .data import fold_inputs, unfold_outputs
import auraloss
import json
@ -25,12 +26,29 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
def train_feeder(engine, batch):
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
engine(
text_list=batch["text"],
proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
resps_list=batch["resps"],
lang_list=batch["lang"],
)
if engine.hyper_config.experimental:
input_ids, attention_mask = fold_inputs(
text_list=batch["text"],
prom_list=batch["proms"],
resp_list=batch["resps"],
)
target_ids, target_attention_mask = fold_inputs(
text_list=batch["text"],
prom_list=batch["proms"],
resp_list=batch["resps"],
ignore_index=-100
)
engine(
input_ids=input_ids,
labels=target_ids
)
else:
engine(
text_list=batch["text"],
proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
resps_list=batch["resps"],
lang_list=batch["lang"],
)
losses = engine.gather_attribute("loss")
stat = engine.gather_attribute("stats")
@ -48,22 +66,6 @@ def train_feeder(engine, batch):
@torch.inference_mode()
def run_eval(engines, eval_name, dl):
AR = None
NAR = None
AR_NAR = None
names = []
for name, engine in engines.items():
if name[:6] == "ar+nar":
AR_NAR = engine
elif name[:2] == "ar":
AR = engine
elif name[:3] == "nar":
NAR = engine
else:
continue
names.append(name)
stats = defaultdict(list)
stats['loss'] = []
@ -101,44 +103,22 @@ def run_eval(engines, eval_name, dl):
batch: dict = to_device(next(iter(dl)), cfg.device)
processed += len(batch["text"])
# if we're training both models, provide output for both
if AR is not None and NAR is not None:
name = "+".join(names)
for name in engines:
engine = engines[name]
resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
if engine.hyper_config.experimental:
input_ids, attention_mask = fold_inputs(
text_list=batch["text"],
proms_list=batch["proms"],
)
output = engine.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
resps_list = unfold_outputs( output )["resp_list"]
else:
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
process( name, batch, resps_list )
else:
for name in engines:
model = engines[name]
if name.startswith("ar+nar"):
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
elif name.startswith("ar"):
resps_list = model(
text_list=batch["text"],
proms_list=batch["proms"],
lang_list=batch["lang"],
max_steps=cfg.evaluation.steps,
sampling_temperature=cfg.evaluation.ar_temperature,
)
resps_list = [r.unsqueeze(-1) for r in resps_list]
elif name.startswith("nar"):
resps_list = model(
text_list=batch["text"],
proms_list=batch["proms"],
lang_list=batch["lang"],
resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
sampling_temperature=cfg.evaluation.nar_temperature,
)
else:
raise NotImplementedError(name)
process( name, batch, resps_list )
stats = {k: sum(v) / len(v) for k, v in stats.items()}