feverish cleanup

mrq 2024-06-03 21:28:49 -05:00
parent 7feeb944a0
commit 934672252b
11 changed files with 343 additions and 278 deletions

View File

@@ -1,3 +1,4 @@
+experimental: False # should probably expand this into a dict of experimental flags
sample_rate: 24_000 # 44_000 for dac
audio_backend: "vocos" # or dac
@@ -11,9 +12,10 @@ models:
tones: 1
arch_type: llama
training: True
-version: 4
+version: 5
attention: flash_attention_2
dropout: 0.1
+experimental: False
loss_factors:
text: 0.1
@@ -63,11 +65,10 @@ trainer:
keep_last_checkpoints: 4
-aggressive_optimizations: False
-load_disabled_engines: False
-#load_state_dict: True
+gradient_checkpointing: True
strict_loading: False
+#load_state_dict: True
#load_tag: "9500"
#load_states: False
#restart_step_count: True
@@ -85,8 +86,6 @@ trainer:
amp: False
-activation_checkpointing: True
load_webui: False
inference:
@@ -109,8 +108,6 @@ optimizations:
bitnet: False
fp8: False
-experimental: True # practically required now it seems
dataset:
speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
speaker_group_getter: "lambda p: f'{p.parts[-3]}'"

View File

@@ -191,6 +191,7 @@ class Dataset:
def max_duration(self):
return self.duration_range[1]
+# I really need to clean this up
@dataclass()
class Model:
_max_levels: int = 0
@@ -215,6 +216,7 @@ class Model:
dropout: float = 0.1 # adjustable dropout value
loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 })
kv_heads: int = 0
+experimental: bool = False
def get(self, name=None):
return [ self ] if not name or self.name == name else []
@@ -306,6 +308,10 @@ class Model:
def activation_checkpointing(self):
return cfg.trainer.activation_checkpointing
+@property
+def gradient_checkpointing(self):
+return cfg.trainer.gradient_checkpointing
@dataclass()
class Hyperparameters:
batch_size: int = 8
@@ -519,7 +525,8 @@ class Trainer:
load_module_only: bool = False
restart_step_count: bool = False
-activation_checkpointing: bool = True
+activation_checkpointing: bool | None = None # deprecated
+gradient_checkpointing: bool = True
aggressive_optimizations: bool = False
check_for_oom: bool = True
@@ -728,6 +735,9 @@ class Config(_Config):
if self.inference.audio_backend != "" and self.audio_backend == "":
self.audio_backend = self.inference.audio_backend
+if self.trainer.activation_checkpointing is not None:
+self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing
# Preserves the old behavior
class NaiveTokenizer:
def get_vocab( self ):
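Note on the back-compat shim added above: if an old config still sets the deprecated trainer.activation_checkpointing, it simply overwrites the new trainer.gradient_checkpointing, and the model-level gradient_checkpointing property reads the trainer value. A minimal standalone sketch of that behavior, using simplified dataclasses and an illustrative resolve_deprecations() helper rather than the repo's actual Config class:

    from dataclasses import dataclass, field

    @dataclass
    class TrainerCfg:
        gradient_checkpointing: bool = True           # new flag
        activation_checkpointing: bool | None = None  # deprecated alias

    @dataclass
    class Cfg:
        trainer: TrainerCfg = field(default_factory=TrainerCfg)

        def resolve_deprecations(self):
            # mirror of the hunk above: the old key, if present, wins
            if self.trainer.activation_checkpointing is not None:
                self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing

    old_style = Cfg(trainer=TrainerCfg(activation_checkpointing=False))
    old_style.resolve_deprecations()
    assert old_style.trainer.gradient_checkpointing is False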

View File

@@ -24,11 +24,144 @@ from typing import Any
from torch import Tensor
from torch.utils.data import DataLoader, Dataset as _Dataset
from torch.utils.data.distributed import DistributedSampler
+from torch.nn.utils.rnn import pad_sequence
from tqdm.auto import tqdm
# torch.multiprocessing.set_sharing_strategy("file_system")
_logger = logging.getLogger(__name__)
+# fold into a typical LLM sequence (one embedding rather than split embeddings)
+def fold_inputs(
+text_list = [],
+prom_list = [],
+resp_list = [],
+ignore_index = None,
+sep = 3,
+stop = 3,
+text_tokens = 256,
+audio_tokens = 1024,
+audio_rvq_levels = cfg.model.prom_levels
+):
+def _create_mask(l, device):
+seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
+stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
+return (seq < stop).float() # (b t)
+def list_to_tensor(x_list: list[Tensor]):
+l = list(map(len, x_list))
+x = pad_sequence(x_list).t()
+m = _create_mask(l, x_list[0].device)
+m = m.to(x)
+return x, m
+device = text_list[0].device
+batch_size = len(text_list)
+input_ids = [ [] for _ in range(batch_size) ]
+offset = 0
+sep = torch.Tensor([ sep ])
+stop = torch.Tensor([ stop ])
+for i, text in enumerate(text_list):
+seq = text.to("cpu", dtype=torch.int64)
+input_ids[i].append( seq )
+input_ids[i].append( sep )
+offset = text_tokens
+for i, prom in enumerate(prom_list):
+if ignore_index is not None:
+seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
+else:
+seq = prom.flatten().to("cpu", dtype=torch.int64)
+for idx, token in enumerate( seq ):
+token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
+input_ids[i].append( seq )
+input_ids[i].append( sep )
+offset = text_tokens + (audio_tokens * audio_rvq_levels)
+for i, resp in enumerate(resp_list):
+seq = resp.flatten().to("cpu", dtype=torch.int64)
+for idx, token in enumerate( seq ):
+token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
+input_ids[i].append( seq )
+input_ids[i].append( stop )
+for i, batch in enumerate(input_ids):
+input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64)
+return list_to_tensor(input_ids)
+# unfold from one unified token ID space to separate token spaces
+def unfold_outputs(
+output_ids,
+sep = 3,
+stop = 3,
+text_tokens = 256,
+audio_tokens = 1024,
+audio_rvq_levels = cfg.model.prom_levels
+):
+device = output_ids.device
+batch_size = output_ids.shape[0]
+text_list = [ [] for _ in range(batch_size) ]
+prom_list = [ [] for _ in range(batch_size) ]
+resp_list = [ [] for _ in range(batch_size) ]
+for i, batch in enumerate( output_ids ):
+for idx, token in enumerate( batch ):
+id = token.item()
+if id == sep or id == stop:
+continue
+if 0 <= id and id < text_tokens:
+text_list[i].append( id )
+elif text_tokens <= id and id < text_tokens + (audio_tokens * audio_rvq_levels):
+prom_list[i].append( (id - text_tokens) % audio_tokens )
+elif text_tokens + (audio_tokens * audio_rvq_levels) <= id:
+resp_list[i].append( (id - text_tokens) % audio_tokens )
+prom_len = len(prom_list[i])
+if prom_len % audio_rvq_levels == 0 and False:
+prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t()
+else:
+bins = [ [] for _ in range(audio_rvq_levels) ]
+for pos in range( prom_len ):
+rvq = pos % audio_rvq_levels
+bins[rvq].append( prom_list[i][pos] )
+nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
+bins = bins[:nearest]
+prom_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
+resp_len = len(resp_list[i])
+if len(resp_list[i]) % audio_rvq_levels == 0 and False:
+resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t()
+else:
+bins = [ [] for _ in range(audio_rvq_levels) ]
+for pos in range( resp_len ):
+rvq = pos % audio_rvq_levels
+bins[rvq].append( resp_list[i][pos] )
+nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
+bins = bins[:nearest]
+resp_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
+text_list[i] = torch.Tensor( text_list[i] ).to(dtype=torch.int64)
+return dict(
+text_list=text_list,
+prom_list=prom_list,
+resp_list=resp_list
+)
# to-do: clean up this symmap mess
def get_phone_symmap():
return cfg.tokenizer.get_vocab()
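For intuition about the unified ID space that fold_inputs/unfold_outputs build above: text tokens keep their raw IDs, prompt codes are shifted past the text vocabulary, and response codes are shifted past the entire prompt block, with each RVQ level owning its own audio_tokens-wide slice (assuming prompts are flattened time-major so idx % audio_rvq_levels walks the levels; the sep/stop ID, 3 by default, lives inside the text range). A small sketch of the arithmetic with the defaults above and an illustrative audio_rvq_levels of 8:

    TEXT_TOKENS = 256   # defaults from fold_inputs above
    AUDIO_TOKENS = 1024
    RVQ_LEVELS = 8      # illustrative; the real value comes from cfg.model.prom_levels

    def prom_id(code: int, level: int) -> int:
        # prompt codes start right after the text vocabulary
        return TEXT_TOKENS + AUDIO_TOKENS * level + code

    def resp_id(code: int, level: int) -> int:
        # response codes start after the whole prompt block
        return TEXT_TOKENS + AUDIO_TOKENS * RVQ_LEVELS + AUDIO_TOKENS * level + code

    assert prom_id(0, 0) == 256             # first prompt slice begins at 256
    assert resp_id(0, 0) == 256 + 1024 * 8  # responses begin at 8448
    # unfold_outputs recovers the codebook entry with (id - text_tokens) % audio_tokens
    assert (prom_id(17, 3) - TEXT_TOKENS) % AUDIO_TOKENS == 17
    assert (resp_id(17, 3) - TEXT_TOKENS) % AUDIO_TOKENS == 17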

View File

@@ -33,7 +33,7 @@ def load_engines(training=True):
optimizer = None
lr_scheduler = None
-inferencing = cfg.mode == "inferencing" or not model._cfg.training
+inferencing = cfg.mode == "inferencing" or not model.hyper_config.training
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
amp = cfg.inference.amp if inferencing else cfg.trainer.amp
@@ -43,7 +43,7 @@ def load_engines(training=True):
engine_class = _Engine if backend == "local" or inferencing else Engine
if inferencing:
-model._cfg.training = False
+model.hyper_config.training = False
if cfg.optimizations.replace and cfg.optimizations.linear:
model.model = ml.replace_linear( model.model )
@@ -83,7 +83,7 @@ def load_engines(training=True):
params.update(cfg.hyperparameters.optimizer_params)
optimizer = optimizer_class(
-[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
+[ param for name, param in model.named_parameters() if name not in model.hyper_config.frozen_params ],
**params,
)
@@ -96,7 +96,7 @@ def load_engines(training=True):
raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')
optimizer = scheduler_class(
-[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
+[ param for name, param in model.named_parameters() if name not in model.hyper_config.frozen_params ],
lr = params['lr'],
warmup_steps = cfg.hyperparameters.warmup_steps
)
@@ -144,7 +144,7 @@ def load_engines(training=True):
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
-_cfg = model._cfg
+hyper_config = model.hyper_config
# wrap if DDP is requested
if ddp:
@@ -161,7 +161,7 @@ def load_engines(training=True):
optimizer=optimizer,
lr_scheduler=lr_scheduler,
-_cfg=_cfg,
+hyper_config=hyper_config,
stats=stats
)

View File

@@ -52,9 +52,9 @@ if not distributed_initialized() and cfg.trainer.backend == "local": # and world
# A very naive engine implementation using barebones PyTorch
class Engine():
def __init__(self, *args, **kwargs):
-if '_cfg' in kwargs:
-self._cfg = kwargs['_cfg']
-kwargs.pop("_cfg")
+if 'hyper_config' in kwargs:
+self.hyper_config = kwargs['hyper_config']
+kwargs.pop("hyper_config")
self.module = kwargs['model'].to(cfg.device).to(torch.float32 if cfg.trainer.amp else cfg.trainer.dtype)
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
@@ -72,11 +72,11 @@ class Engine():
def freeze(self, freeze_all=True):
# set to freeze
-if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
-raise Exception("freeze_all=False yet self._cfg.frozen_params is None")
+if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
+raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
for name, param in self.module.named_parameters():
-if (freeze_all and param.requires_grad) or (not freeze_all and name in self._cfg.frozen_params):
+if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
param.requires_grad_(False)
self._frozen_params.add(param)
@@ -87,9 +87,9 @@ class Engine():
@property
def _training(self):
-if not hasattr(self, "_cfg"):
+if not hasattr(self, "hyper_config"):
return True
-return self._cfg.training
+return self.hyper_config.training
@property
def global_step(self):

View File

@@ -32,10 +32,10 @@ if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
class Engine(DeepSpeedEngine):
def __init__(self, *args, **kwargs):
-self._cfg = None
-if '_cfg' in kwargs:
-self._cfg = kwargs['_cfg']
-kwargs.pop("_cfg")
+self.hyper_config = None
+if 'hyper_config' in kwargs:
+self.hyper_config = kwargs['hyper_config']
+kwargs.pop("hyper_config")
kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
@@ -63,11 +63,11 @@ class Engine(DeepSpeedEngine):
self.max_nan_losses = 8
def freeze(self, freeze_all=True):
-if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
-raise Exception("freeze_all=False yet self._cfg.frozen_params is None")
+if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
+raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
for name, param in self.module.named_parameters():
-if (freeze_all and param.requires_grad) or (not freeze_all and name in self._cfg.frozen_params):
+if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
param.requires_grad_(False)
self._frozen_params.add(param)
@@ -78,7 +78,7 @@ class Engine(DeepSpeedEngine):
@property
def _training(self):
-return self._cfg.training
+return self.hyper_config.training
@property
def global_step(self):

View File

@@ -1,8 +1,10 @@
from .ar_nar import AR_NAR
+from .experimental import Model as Experimental
def get_model(cfg, training=True):
name = cfg.name
+if not cfg.experimental:
model = AR_NAR(
n_tokens=cfg.tokens,
d_model=cfg.dim,
@@ -18,6 +20,15 @@ def get_model(cfg, training=True):
config = cfg,
)
model._cfg = cfg
+else:
+model = Experimental(
+d_model=cfg.dim,
+n_layers=cfg.layers,
+n_heads=cfg.heads,
+p_dropout=cfg.dropout,
+config = cfg,
+)
print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")

View File

@@ -64,7 +64,7 @@ try:
def BitNetTransformerBlock_forward(self, x: Tensor, *args, **kwargs) -> Tensor:
skip = x
for attn, ffn in zip(self.layers, self.ffn_layers):
-if x.requires_grad and self.activation_checkpointing:
+if x.requires_grad and self.gradient_checkpointing:
x, _ = checkpoint(attn, x, x, x, is_causal=True, *args, **kwargs, use_reentrant=False)
else:
x, _ = attn(x, x, x, is_causal=True, *args, **kwargs)
@@ -83,13 +83,13 @@ try:
num_tokens: int,
heads=8,
ff_mult=4,
-activation_checkpointing = True
+gradient_checkpointing = True
):
super().__init__()
self.transformer = BitNetTransformerBlock( dim=dim, depth=depth, heads=heads, ff_mult=ff_mult )
self.norm = BitNetRMSNorm(dim)
-self.transformer.activation_checkpointing = activation_checkpointing
+self.transformer.gradient_checkpointing = gradient_checkpointing
def forward(self, x):
x = self.transformer(x)
@@ -431,9 +431,9 @@ class Base(nn.Module):
return -100
def loss_factor(self, k):
-if self.config is None:
+if self.hyper_config is None:
return 1.0
-return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0
+return self.hyper_config.loss_factors[k] if k in self.hyper_config.loss_factors else 1.0
def __init__(
self,
@@ -452,8 +452,8 @@ class Base(nn.Module):
):
super().__init__()
self.training = training
-self.config = config
+self.hyper_config = config
-self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
+self.gradient_checkpointing = self.hyper_config.gradient_checkpointing if self.hyper_config is not None else True
self.n_tokens = n_tokens
self.d_model = d_model
@@ -482,13 +482,13 @@ class Base(nn.Module):
self.proms_emb = AudioEmbedding(
[n_prom_tokens] * self.n_prom_levels, d_model,
levels=self.n_prom_levels if self.version > 3 else None,
-sums=self.config.audio_embedding_sums if self.config is not None else True,
+sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True,
)
# [1025] + [1024] * 8
self.resps_emb = AudioEmbedding(
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
levels=self.n_resp_levels if self.version > 3 else None,
-sums=self.config.audio_embedding_sums if self.config is not None else True
+sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True
)
@@ -502,20 +502,20 @@ class Base(nn.Module):
self.sep = nn.Parameter(torch.randn(d_model))
# ick, there has to be a better way
-hf_attention = self.config.attention if self.config is not None else None
+hf_attention = self.hyper_config.attention if self.hyper_config is not None else None
-if self.config.attention == "auto":
+if self.hyper_config.attention == "auto":
if "flash" in AVAILABLE_ATTENTIONS:
-self.config.attention = "flash"
+self.hyper_config.attention = "flash"
elif "xformers" in AVAILABLE_ATTENTIONS:
-self.config.attention = "xformers"
+self.hyper_config.attention = "xformers"
else:
-self.config.attention = "mem_efficient"
+self.hyper_config.attention = "mem_efficient"
-if self.config.attention in ["xformers", "mem_efficient", "math", "flash"]:
+if self.hyper_config.attention in ["xformers", "mem_efficient", "math", "flash"]:
hf_attention = None
-if self.config.attention not in AVAILABLE_ATTENTIONS:
-raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
+if self.hyper_config.attention not in AVAILABLE_ATTENTIONS:
+raise ValueError(f"Requesting attention `{self.hyper_config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
if self.arch_type == "transformer":
@@ -538,12 +538,12 @@ class Base(nn.Module):
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
-num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
+num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=hf_attention,
-#gradient_checkpointing=self.activation_checkpointing,
+#gradient_checkpointing=self.gradient_checkpointing,
))
else:
self.model = MixtralModel(MixtralConfig(
@@ -554,7 +554,7 @@ class Base(nn.Module):
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
-num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
+num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads,
sliding_window=75 * 12, # 12 second context window
output_router_logits=training,
hidden_act="gelu",
@@ -563,10 +563,10 @@ class Base(nn.Module):
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
attn_implementation=hf_attention,
-#gradient_checkpointing=self.activation_checkpointing,
+#gradient_checkpointing=self.gradient_checkpointing,
))
-if self.activation_checkpointing and not self.model.gradient_checkpointing:
+if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
@@ -589,7 +589,7 @@ class Base(nn.Module):
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=hf_attention,
-#gradient_checkpointing=self.activation_checkpointing,
+#gradient_checkpointing=self.gradient_checkpointing,
))
else:
self.model = MixtralModel(MixtralConfig(
@@ -609,10 +609,10 @@ class Base(nn.Module):
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
attn_implementation=hf_attention,
-#gradient_checkpointing=self.activation_checkpointing,
+#gradient_checkpointing=self.gradient_checkpointing,
))
-if self.activation_checkpointing and not self.model.gradient_checkpointing:
+if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
@@ -628,7 +628,7 @@ class Base(nn.Module):
decoder_ffn_embed_dim=d_model * 4,
decoder_layers=n_layers,
dropout=p_dropout if training else 0.0,
-checkpoint_activations=self.activation_checkpointing,
+checkpoint_activations=self.gradient_checkpointing,
activation_fn="gelu",
use_layernorm=self.version < 3,
use_biases=self.version < 3,
@@ -660,7 +660,7 @@ class Base(nn.Module):
decoder_ffn_embed_dim=d_model * 4,
decoder_layers=n_layers,
dropout=p_dropout if training else 0.0,
-checkpoint_activations=self.activation_checkpointing,
+checkpoint_activations=self.gradient_checkpointing,
activation_fn="gelu",
use_glu=False, # self.version >= 3,
@@ -673,7 +673,7 @@ class Base(nn.Module):
self.model = RetNetDecoder_HF(RetNetConfig_HF(**kwargs))
-if self.activation_checkpointing and not self.model.gradient_checkpointing:
+if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
@@ -684,13 +684,13 @@ class Base(nn.Module):
depth=n_layers,
heads=n_heads,
ff_mult=4,
-activation_checkpointing=self.activation_checkpointing,
+gradient_checkpointing=self.gradient_checkpointing,
)
else:
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
-if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
-self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention )
+if self.hyper_config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
+self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.hyper_config.attention )
self.classifier = nn.Linear(d_model, n_resp_tokens)
@@ -877,7 +877,7 @@ class Base(nn.Module):
quant_levels: Tensor | None = None,
):
# old, "naive" way, no loss factoring
-if not self.config.loss_factors:
+if not self.hyper_config.loss_factors:
target_list = []
for batch in inputs:
target = []

View File

@@ -1,9 +1,20 @@
+"""
+This is an experiment to:
+* entertain a thought to try and abide by HF's transformers API (to benefit from caching better)
+* conform to a single embedding (instead of a bunch of them) by folding/unfolding inputs
+* stop trying to make a mixed AR+NAR model work since it seems lobotomized if I keep trying to enforce both recurrent and parallel inferencing (despite a penalty cost)
++ I will not cave and go with codebook patterns, not yet.
+"""
from ..config import cfg
+from ..data import fold_inputs, unfold_outputs
import torch
from torch.nn.utils.rnn import pad_sequence
from torch import Tensor
from torch.nn import CrossEntropyLoss
+from torch.utils.checkpoint import checkpoint
import random
import math
@@ -21,144 +32,40 @@ except Exception as e:
pass
try:
-from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
+def MambaMixelModel_forward(self, input_ids, inference_params=None, **mixer_kwargs):
+hidden_states = self.embedding(input_ids)
+residual = None
+for layer in self.layers:
+if self.gradient_checkpointing and hidden_states.requires_grad:
+hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, use_reentrant=False )
+else:
+hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params )
+if not self.fused_add_norm:
+residual = (hidden_states + residual) if residual is not None else hidden_states
+hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
+else:
+# Set prenorm=False here since we don't need the residual
+hidden_states = MambaLayerNormFn(
+hidden_states,
+self.norm_f.weight,
+self.norm_f.bias,
+eps=self.norm_f.eps,
+residual=residual,
+prenorm=False,
+residual_in_fp32=self.residual_in_fp32,
+is_rms_norm=isinstance(self.norm_f, MambaRMSNorm)
+)
+return hidden_states
+MambaMixelModel.forward = MambaMixelModel_forward
AVAILABLE_ARCHES.append("mamba")
except Exception as e:
print("Error importing `mamba` arch:", e)
pass
-def _create_mask(l, device):
-seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
-stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
-return (seq < stop).float() # (b t)
-def list_to_tensor(x_list: list[Tensor]):
-l = list(map(len, x_list))
-x = pad_sequence(x_list).t()
-m = _create_mask(l, x_list[0].device)
-m = m.to(x)
-return x, m
-# fold into a typical LLM sequence (one embedding rather than split embeddings)
-def fold(
-text_list = [],
-proms_list = [],
-resp_list = [],
-ignore_index = None,
-sep = 3,
-stop = 3,
-text_tokens = 256,
-audio_tokens = 1024,
-audio_rvq_levels = cfg.model.prom_levels
-):
-device = text_list[0].device
-batch_size = len(text_list)
-input_ids = [ [] for _ in range(batch_size) ]
-offset = 0
-sep = torch.Tensor([ sep ])
-stop = torch.Tensor([ stop ])
-for i, text in enumerate(text_list):
-seq = text.to("cpu", dtype=torch.int64)
-input_ids[i].append( seq )
-input_ids[i].append( sep )
-offset = text_tokens
-for i, prom in enumerate(proms_list):
-if ignore_index is not None:
-seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
-else:
-seq = prom.flatten().to("cpu", dtype=torch.int64)
-for idx, token in enumerate( seq ):
-token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
-input_ids[i].append( seq )
-input_ids[i].append( sep )
-offset = text_tokens + (audio_tokens * audio_rvq_levels)
-for i, resp in enumerate(resp_list):
-seq = resp.flatten().to("cpu", dtype=torch.int64)
-for idx, token in enumerate( seq ):
-token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
-input_ids[i].append( seq )
-input_ids[i].append( stop )
-for i, batch in enumerate(input_ids):
-input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64)
-return list_to_tensor(input_ids)
-# unfold from one unified token ID space to separate token spaces
-def unfold(
-input_ids,
-sep = 3,
-stop = 3,
-text_tokens = 256,
-audio_tokens = 1024,
-audio_rvq_levels = cfg.model.prom_levels
-):
-device = input_ids.device
-batch_size = input_ids.shape[0]
-text_list = [ [] for _ in range(batch_size) ]
-prom_list = [ [] for _ in range(batch_size) ]
-resp_list = [ [] for _ in range(batch_size) ]
-for i, batch in enumerate( input_ids ):
-for idx, token in enumerate( batch ):
-id = token.item()
-if id == sep or id == stop:
-continue
-if 0 <= id and id < text_tokens:
-text_list[i].append( id )
-elif text_tokens <= id and id < text_tokens + (audio_tokens * audio_rvq_levels):
-prom_list[i].append( (id - text_tokens) % audio_tokens )
-elif text_tokens + (audio_tokens * audio_rvq_levels) <= id:
-resp_list[i].append( (id - text_tokens) % audio_tokens )
-prom_len = len(prom_list[i])
-if prom_len % audio_rvq_levels == 0 and False:
-prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t()
-else:
-bins = [ [] for _ in range(audio_rvq_levels) ]
-for pos in range( prom_len ):
-rvq = pos % audio_rvq_levels
-bins[rvq].append( prom_list[i][pos] )
-nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
-bins = bins[:nearest]
-prom_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
-resp_len = len(resp_list[i])
-if len(resp_list[i]) % audio_rvq_levels == 0 and False:
-resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t()
-else:
-bins = [ [] for _ in range(audio_rvq_levels) ]
-for pos in range( resp_len ):
-rvq = pos % audio_rvq_levels
-bins[rvq].append( resp_list[i][pos] )
-nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
-bins = bins[:nearest]
-resp_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
-text_list[i] = torch.Tensor( text_list[i] ).to(dtype=torch.int64)
-return dict(
-text_list=text_list,
-prom_list=prom_list,
-resp_list=resp_list
-)
SELECTED_ARCH = cfg.model.arch_type
if SELECTED_ARCH not in AVAILABLE_ARCHES:
@@ -179,9 +86,12 @@ class Model(LlmArchClass):
n_heads=16,
p_dropout=0.1,
-attention_backend=None,
-activation_checkpointing=True,
+config = None,
):
+self.hyper_config = config
+hf_attention = config.attention if config is not None else None
+gradient_checkpointing = config.gradient_checkpointing if config is not None else True
if SELECTED_ARCH == "llama":
super().__init__(config=LlamaConfig(
@@ -197,10 +107,10 @@ class Model(LlmArchClass):
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
-attn_implementation=attention_backend,
+attn_implementation=hf_attention,
))
-if activation_checkpointing:
+if gradient_checkpointing:
self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
@@ -209,9 +119,11 @@ class Model(LlmArchClass):
vocab_size=256 + (1024 * cfg.model.prom_levels) + (1024 * cfg.model.prom_levels) + 1,
d_model=d_model,
n_layer=n_layers*2,
-#ssm_cfg={"layer": "Mamba2"},
+#ssm_cfg={"layer": "Mamba2"}, # will ALWAYS nan
))
+self.backbone.gradient_checkpointing = gradient_checkpointing
def forward(
self,
@@ -293,9 +205,9 @@ def example_usage():
proms_list = proms_list[:1]
resps_list = resps_list[:1]
-input_ids, attention_mask = fold(text_list, proms_list, resps_list)
-target_ids, target_attention_mask = fold(text_list, proms_list, resps_list, ignore_index=-100)
-prefix_input_ids, prefix_attention_mask = fold(text_list, proms_list)
+input_ids, attention_mask = fold_inputs(text_list, proms_list, resps_list)
+target_ids, target_attention_mask = fold_inputs(text_list, proms_list, resps_list, ignore_index=-100)
+prefix_input_ids, prefix_attention_mask = fold_inputs(text_list, proms_list)
kwargs = {}
model = Model(**kwargs).to(device)
@@ -373,7 +285,7 @@ def example_usage():
else:
output = model.generate(input_ids=prefix_input_ids, attention_mask=prefix_attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
-unfolded = unfold( output )
+unfolded = unfold_outputs( output )
for i, batch in enumerate(unfolded["resp_list"]):
_ = decode_to_file(batch.to(device=device), f"data/{SELECTED_ARCH}.{cfg.audio_backend}.{i}.{name}.wav", device=device)

View File

@@ -15,7 +15,29 @@ from torch import Tensor, einsum, nn
from torch.utils.checkpoint import checkpoint
from ..utils import wrapper as ml
-from .adaln import AdaLN
+class AdaLN(nn.Module):
+def __init__(self, d_model, n_levels, eps=1e-5, k=0.1, c=2):
+super().__init__()
+self.eps = eps
+self.emb = nn.Embedding(n_levels, d_model * 2)
+self.k = k
+self.c = c
+nn.init.zeros_(self.emb.weight)
+def forward(self, x, l):
+h = F.layer_norm(x, x.shape[-1:], eps=self.eps)
+# The initial implementation (https://github.com/enhuiz/vall-e/blob/fbf023448c08e55c0422eefed7fc234cf8b76680/vall_e/vall_e/base.py#L135)
+# performed worse than vanilla LayerNorm.
+# The authors mentioned another AdaNorm paper (https://openreview.net/pdf?id=HyxndNrxLB) as they introduce AdaLN.
+# Did they use AdaNorm inside AdaLN? (as follows)
+h = self.c * (1 - (self.k * h).detach()) * h
+logγ, β = self.emb(l).unsqueeze(1).chunk(2, dim=-1)
+y = logγ.exp() * h + β
+return y
class SinusoidalEmbedding(nn.Module):
def __init__(self, d_model):
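In math form, the inlined AdaLN above computes, for input $x$ and quantizer level $l$ (with $\mathrm{sg}[\cdot]$ the detached, stop-gradient factor): $\hat h = \mathrm{LayerNorm}_{\epsilon}(x)$, $\tilde h = c\,(1 - \mathrm{sg}[k\hat h])\,\hat h$, $y = e^{\log\gamma(l)}\,\tilde h + \beta(l)$, where $\log\gamma(l)$ and $\beta(l)$ are the two halves of the level embedding; since that embedding is zero-initialized, $\gamma$ starts at 1 and $\beta$ at 0, i.e. training begins from the plain AdaNorm-style transform.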

View File

@@ -5,6 +5,7 @@ from .data import create_train_val_dataloader
from .emb import qnt
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
+from .data import fold_inputs, unfold_outputs
import auraloss
import json
@@ -25,6 +26,23 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
def train_feeder(engine, batch):
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+if engine.hyper_config.experimental:
+input_ids, attention_mask = fold_inputs(
+text_list=batch["text"],
+prom_list=batch["proms"],
+resp_list=batch["resps"],
+)
+target_ids, target_attention_mask = fold_inputs(
+text_list=batch["text"],
+prom_list=batch["proms"],
+resp_list=batch["resps"],
+ignore_index=-100
+)
+engine(
+input_ids=input_ids,
+labels=target_ids
+)
+else:
engine(
text_list=batch["text"],
proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
@@ -48,22 +66,6 @@ def train_feeder(engine, batch):
@torch.inference_mode()
def run_eval(engines, eval_name, dl):
-AR = None
-NAR = None
-AR_NAR = None
-names = []
-for name, engine in engines.items():
-if name[:6] == "ar+nar":
-AR_NAR = engine
-elif name[:2] == "ar":
-AR = engine
-elif name[:3] == "nar":
-NAR = engine
-else:
-continue
-names.append(name)
stats = defaultdict(list)
stats['loss'] = []
@@ -101,42 +103,20 @@ def run_eval(engines, eval_name, dl):
batch: dict = to_device(next(iter(dl)), cfg.device)
processed += len(batch["text"])
-# if we're training both models, provide output for both
-if AR is not None and NAR is not None:
-name = "+".join(names)
-resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
-resps_list = [ r.unsqueeze(-1) for r in resps_list ]
-resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
-process( name, batch, resps_list )
-else:
for name in engines:
-model = engines[name]
-if name.startswith("ar+nar"):
-resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
-resps_list = [ r.unsqueeze(-1) for r in resps_list ]
-resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
-elif name.startswith("ar"):
-resps_list = model(
+engine = engines[name]
+if engine.hyper_config.experimental:
+input_ids, attention_mask = fold_inputs(
text_list=batch["text"],
proms_list=batch["proms"],
-lang_list=batch["lang"],
-max_steps=cfg.evaluation.steps,
-sampling_temperature=cfg.evaluation.ar_temperature,
-)
-resps_list = [r.unsqueeze(-1) for r in resps_list]
-elif name.startswith("nar"):
-resps_list = model(
-text_list=batch["text"],
-proms_list=batch["proms"],
-lang_list=batch["lang"],
-resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
-sampling_temperature=cfg.evaluation.nar_temperature,
)
+output = engine.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
+resps_list = unfold_outputs( output )["resp_list"]
else:
-raise NotImplementedError(name)
+resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
+resps_list = [ r.unsqueeze(-1) for r in resps_list ]
+resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
process( name, batch, resps_list )
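One detail worth spelling out about the train_feeder change earlier in this file: the second fold_inputs(..., ignore_index=-100) call overwrites the prompt positions with -100, which is the default ignore_index of torch's cross-entropy (and the same label convention Hugging Face causal-LM models follow), so those positions contribute nothing to the loss while the text and response positions still do. A tiny self-contained check of that behavior (illustrative shapes, not the repo's code):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(6, 10)                      # 6 positions, 10-token vocab
    labels = torch.tensor([4, 2, -100, -100, 7, 1])  # "prompt" positions masked with -100

    masked = F.cross_entropy(logits, labels)  # ignore_index defaults to -100
    manual = F.cross_entropy(logits[[0, 1, 4, 5]], labels[[0, 1, 4, 5]])  # drop masked rows by hand
    assert torch.allclose(masked, manual)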