fix issue with sft and shared tensors...
This commit is contained in:
parent
23f3b56fda
commit
3a65cc4b22
|
@ -64,6 +64,10 @@ def main():
|
|||
|
||||
parser.add_argument("--mirostat-tau", type=float, default=0)
|
||||
parser.add_argument("--mirostat-eta", type=float, default=0)
|
||||
|
||||
parser.add_argument("--dry-multiplier", type=float, default=0)
|
||||
parser.add_argument("--dry-base", type=float, default=1.75)
|
||||
parser.add_argument("--dry-allowed-length", type=int, default=2)
|
||||
|
||||
parser.add_argument("--seed", type=int, default=None)
|
||||
|
||||
|
@ -99,6 +103,7 @@ def main():
|
|||
length_penalty=args.length_penalty,
|
||||
beam_width=args.beam_width,
|
||||
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
||||
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
|
||||
)) )
|
||||
|
||||
# pull from provided samples
|
||||
|
|
|
@ -338,12 +338,6 @@ class Engines(dict[str, Engine]):
|
|||
lora = None
|
||||
save_path = cfg.ckpt_dir / name / f"fp32.{format}"
|
||||
config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
|
||||
|
||||
# coerce
|
||||
if not isinstance(config, dict):
|
||||
config = config.__dict__
|
||||
if not isinstance(config['experimental'], dict):
|
||||
config['experimental'] = config['experimental'].__dict__
|
||||
|
||||
# safety
|
||||
for k, v in module.items():
|
||||
|
@ -363,7 +357,7 @@ class Engines(dict[str, Engine]):
|
|||
"tokens_processed": engine.tokens_processed,
|
||||
},
|
||||
"userdata": userdata,
|
||||
"config": config
|
||||
"config": config.__dict__ | {"experimental": config.experimental.__dict__} # i hate implicit aliasing rules
|
||||
}
|
||||
|
||||
if lora is None:
|
||||
|
|
|
@ -98,8 +98,8 @@ def split_classifier_heads( state_dict, config = cfg.model, save_path = None, dt
|
|||
tokens = 1025 if i == 0 else 1024
|
||||
|
||||
# trim per RVQ level (since level 0 has a stop token)
|
||||
state_dict['module'][f'classifiers.proj.{i}.weight'] = state_dict['module']['classifier.weight'][:tokens, :]
|
||||
state_dict['module'][f'classifiers.proj.{i}.bias'] = state_dict['module']['classifier.bias'][:tokens]
|
||||
state_dict['module'][f'classifiers.proj.{i}.weight'] = state_dict['module']['classifier.weight'][:tokens, :].clone()
|
||||
state_dict['module'][f'classifiers.proj.{i}.bias'] = state_dict['module']['classifier.bias'][:tokens].clone()
|
||||
|
||||
# delete old weights
|
||||
del state_dict['module']['classifier.weight']
|
||||
|
|
516
vall_e/models/ar.py
Normal file
516
vall_e/models/ar.py
Normal file
|
@ -0,0 +1,516 @@
|
|||
"""
|
||||
# an AR model that (should) handle:
|
||||
* handling all RVQ levels, but does it in an autoregressive manner
|
||||
|
||||
It's in a mess of a state, because I want this to be an interleaved model, but it just seems better to use the vall_e.models.experimental model.
|
||||
"""
|
||||
from .base import Base, list_to_tensor, Categorical
|
||||
from ..config import cfg
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
import random
|
||||
import math
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
from tqdm import trange
|
||||
|
||||
from ..emb.qnt import trim, encode_as_embedding
|
||||
|
||||
from .lora import enable_lora
|
||||
|
||||
def clamp(n, lo, hi):
|
||||
return max(lo, min(n, hi))
|
||||
|
||||
class AR(Base):
|
||||
def forward(
|
||||
self,
|
||||
text_list: list[Tensor],
|
||||
proms_list: list[Tensor],
|
||||
resps_list: list[Tensor] | None = None,
|
||||
|
||||
task_list: list[Tensor] | None = None,
|
||||
lang_list: list[Tensor] | None = None,
|
||||
tone_list: list[Tensor] | None = None,
|
||||
len_list: list[Tensor] | None = None,
|
||||
|
||||
training: bool | None = None,
|
||||
|
||||
max_steps: int = 1000,
|
||||
max_levels: int = 0,
|
||||
|
||||
sampling_temperature: float = 1.0,
|
||||
sampling_min_temperature: float = -1.0,
|
||||
sampling_top_k: int = -100,
|
||||
sampling_top_p: float = 1.0,
|
||||
sampling_repetition_penalty: float = 1.0,
|
||||
sampling_repetition_penalty_decay: float = 0.0,
|
||||
sampling_length_penalty: float = 0.0,
|
||||
sampling_beam_width: int = 0,
|
||||
sampling_mirostat_tau: float = 0.0,
|
||||
sampling_mirostat_eta: float = 0.1,
|
||||
sampling_dry_multiplier=0.0,
|
||||
sampling_dry_base=1.75,
|
||||
sampling_dry_allowed_length=2,
|
||||
|
||||
disable_tqdm=False,
|
||||
):
|
||||
device = text_list[0].device
|
||||
batch_size = len(text_list)
|
||||
|
||||
# generate task list if not provided
|
||||
if task_list is None:
|
||||
task_list = [ "tts" for _ in range(batch_size) ]
|
||||
|
||||
# is training or NAR
|
||||
if resps_list is not None:
|
||||
n_levels_set = {r.shape[-1] for r in resps_list}
|
||||
n_levels = next(iter(n_levels_set))
|
||||
|
||||
if training is None:
|
||||
training = n_levels == self.n_resp_levels
|
||||
|
||||
# is training
|
||||
if training:
|
||||
# specifies how to sample probabilities of which RVQ levels to train against
|
||||
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
|
||||
# determines which RVQ level to target per batch
|
||||
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
|
||||
# rate to perform token dropout errors
|
||||
token_dropout_error = self.config.experimental.token_dropout_error
|
||||
# RVQ levels to apply token dropout on
|
||||
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
|
||||
# implicitly set it to all levels
|
||||
if not token_dropout_rvq_levels:
|
||||
token_dropout_rvq_levels = [0, self.resp_levels - 1]
|
||||
# allow passing a specific distribution of RVQ levels
|
||||
p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else []
|
||||
if not p_rvq_levels:
|
||||
lo, hi = quant_level_range[0], quant_level_range[1] + 1
|
||||
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||
if p_rvq_levels == "equal":
|
||||
p_rvq_levels = [ i for i in range( lo, hi ) ]
|
||||
else:
|
||||
# yuck
|
||||
p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
|
||||
|
||||
# input RVQ levels
|
||||
if not self.interleave:
|
||||
quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
|
||||
# trim resps to only contain all levels below the target level
|
||||
resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
||||
else:
|
||||
quant_levels = [ 0 for i in range(batch_size) ]
|
||||
|
||||
# tensor to cat for RVQ level 0
|
||||
# I hate python's value/reference semantics so much
|
||||
for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
|
||||
# cap quant_level if it exceeds its corresponding resp/prom
|
||||
if quant_level >= resps.shape[-1]:
|
||||
quant_levels[i] = resps.shape[-1] - 1
|
||||
|
||||
# proms could be a Tensor, list[Tensor], or None
|
||||
if isinstance( proms, torch.Tensor ):
|
||||
if quant_level >= proms.shape[-1]:
|
||||
quant_levels[i] = proms.shape[-1] - 1
|
||||
|
||||
elif isinstance( proms, list ):
|
||||
for j, prom in enumerate( proms ):
|
||||
if not isinstance( prom, torch.Tensor ):
|
||||
continue
|
||||
|
||||
if quant_level >= prom.shape[-1]:
|
||||
quant_levels[i] = prom.shape[-1] - 1
|
||||
|
||||
# apply token dropout error compensation
|
||||
if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
|
||||
steps = resps.shape[0]
|
||||
for l in range( quant_level ):
|
||||
for t in range( steps ):
|
||||
token = resps[t, l].item()
|
||||
|
||||
if random.random() < token_dropout_error:
|
||||
offset = 1 * ( 1 if random.random() < 0.5 else -1 )
|
||||
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
|
||||
|
||||
# only apply stop token for RVQ level 0
|
||||
stop_sequence = torch.tensor([[self.stop_token] * resps.shape[-1]], device=device, dtype=torch.int16)
|
||||
resps_list[i] = torch.cat([ resps, stop_sequence ])
|
||||
|
||||
|
||||
inputs = self.inputs(
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=resps_list,
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
task_list=task_list,
|
||||
)
|
||||
|
||||
return super().forward(
|
||||
inputs=inputs,
|
||||
)
|
||||
|
||||
# is AR
|
||||
if cfg.lora is not None:
|
||||
enable_lora( self, cfg.lora.active_level( 0 ) )
|
||||
|
||||
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
|
||||
stopped = torch.zeros(batch_size, device=device).bool()
|
||||
|
||||
stop_token = self.stop_token
|
||||
|
||||
|
||||
state = None
|
||||
mirostat = [
|
||||
{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
|
||||
] * batch_size if sampling_mirostat_tau > 0.0 else None
|
||||
|
||||
scores = [ 1.0 ] * sampling_beam_width
|
||||
|
||||
# get next in sequence
|
||||
for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
|
||||
resps_list = [x.unsqueeze(dim=-1) for x in sequence_list]
|
||||
|
||||
inputs = self.inputs(
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=resps_list,
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
len_list=len_list,
|
||||
task_list=task_list,
|
||||
quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
|
||||
)
|
||||
|
||||
if state is not None:
|
||||
logits, state = super().forward(
|
||||
inputs=inputs,
|
||||
state=state,
|
||||
)
|
||||
else:
|
||||
logits = super().forward(
|
||||
inputs=inputs,
|
||||
state=state,
|
||||
)
|
||||
|
||||
r = super().sample(
|
||||
logits=logits,
|
||||
resps_list=resps_list,
|
||||
|
||||
temperature=sampling_temperature,
|
||||
min_temperature=sampling_min_temperature,
|
||||
top_p=sampling_top_p,
|
||||
top_k=sampling_top_k,
|
||||
repetition_penalty=sampling_repetition_penalty,
|
||||
repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||
length_penalty=sampling_length_penalty,
|
||||
beam_width=sampling_beam_width,
|
||||
|
||||
mirostat=mirostat,
|
||||
|
||||
dry_multiplier=sampling_dry_multiplier,
|
||||
dry_base=sampling_dry_base,
|
||||
dry_allowed_length=sampling_dry_allowed_length,
|
||||
)
|
||||
|
||||
if mirostat is not None:
|
||||
# r is the state
|
||||
mirostat = r
|
||||
# extract token from state
|
||||
r = [ state["token"] for state in mirostat ]
|
||||
# we do it here because the sampler will already expand our logits list
|
||||
elif sampling_beam_width > 0:
|
||||
# expand tuple
|
||||
r, s = r
|
||||
# first step, expand batch
|
||||
if batch_size == 1:
|
||||
batch_size = sampling_beam_width
|
||||
text_list = text_list * sampling_beam_width
|
||||
proms_list = proms_list * sampling_beam_width
|
||||
sequence_list = sequence_list * sampling_beam_width
|
||||
stopped = torch.zeros(batch_size, device=device).bool()
|
||||
|
||||
scores = [ scores[i] + score for i, score in enumerate(s) ]
|
||||
|
||||
# append tokens
|
||||
for i, ri in enumerate(r):
|
||||
if stop_token in ri:
|
||||
stopped[i] = True
|
||||
sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
|
||||
|
||||
# stop token found
|
||||
stopped |= r == stop_token
|
||||
if stopped.all().item():
|
||||
break
|
||||
|
||||
# pick the best scoring candidate
|
||||
# desu this is always going to be candidate 0
|
||||
if sampling_beam_width:
|
||||
sequence_list = [ sequence_list[0] ]
|
||||
|
||||
sequence_list = [self._prune(r, stop_token) for r in sequence_list]
|
||||
|
||||
for i, seq in enumerate( sequence_list ):
|
||||
steps = seq.shape[0] // self.n_resp_levels
|
||||
nearest_steps = steps * self.n_resp_levels
|
||||
sequence_list[i] = seq[:nearest_steps].view(( steps, self.n_resp_levels ))
|
||||
|
||||
return sequence_list
|
||||
|
||||
|
||||
def example_usage():
|
||||
cfg.trainer.backend = "local"
|
||||
cfg.hyperparameters.gradient_accumulation_steps = 1
|
||||
if cfg.audio_backend == "dac":
|
||||
cfg.sample_rate = 44_100
|
||||
|
||||
from functools import partial
|
||||
from einops import repeat
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
|
||||
from ..engines import Engine, Engines
|
||||
from ..utils import wrapper as ml
|
||||
|
||||
import numpy as np
|
||||
import re
|
||||
|
||||
device = "cuda"
|
||||
|
||||
# mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
|
||||
"""
|
||||
if "mamba" in cfg.model.arch_type:
|
||||
cfg.model.resp_levels = 1
|
||||
"""
|
||||
# cfg.model.loss_factors = {}
|
||||
|
||||
def tokenize(content):
|
||||
return torch.tensor( cfg.tokenizer.encode(content) )
|
||||
|
||||
def _load_quants(path) -> Tensor:
|
||||
qnt = np.load(path, allow_pickle=True)[()]
|
||||
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
|
||||
|
||||
qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
|
||||
noise = _load_quants(f"./data/noise.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
|
||||
|
||||
text_list = [
|
||||
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
|
||||
#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
|
||||
]
|
||||
proms_list = [
|
||||
qnt[:cfg.dataset.frames_per_second, :].to(device),
|
||||
#qnt[:cfg.dataset.frames_per_second, :].to(device),
|
||||
]
|
||||
resps_list = [
|
||||
qnt[:, :].to(device),
|
||||
#qnt[:cfg.dataset.frames_per_second, :].to(device),
|
||||
]
|
||||
|
||||
text_list = text_list[:1]
|
||||
proms_list = proms_list[:1]
|
||||
resps_list = resps_list[:1]
|
||||
|
||||
batch_size = len(text_list)
|
||||
|
||||
# rentet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
|
||||
kwargs = {
|
||||
'n_text_tokens': 256,
|
||||
'n_audio_tokens': 1024,
|
||||
|
||||
'd_model': 1024, # 256, # 1024, # 1536
|
||||
'n_heads': 16, # 4, # 16, # 24
|
||||
'n_layers': 12, # 32
|
||||
'n_experts': 1,
|
||||
|
||||
'p_dropout': 0.1,
|
||||
|
||||
'l_padding': 8 if cfg.optimizations.fp8 else 0,
|
||||
|
||||
'config': cfg.model
|
||||
}
|
||||
|
||||
"""
|
||||
try:
|
||||
kwargs['config'] = cfg.model
|
||||
except Exception as e:
|
||||
pass
|
||||
"""
|
||||
|
||||
bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
|
||||
tasks = cfg.dataset.tasks_list
|
||||
|
||||
model = AR(**kwargs).to(device)
|
||||
steps = 75 * len(tasks) * cfg.model.experimental.causal_size
|
||||
|
||||
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
|
||||
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
|
||||
learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None
|
||||
|
||||
if cfg.optimizations.dadaptation:
|
||||
# do not combine the two
|
||||
if scheduler == "schedulefree":
|
||||
scheduler = ""
|
||||
|
||||
learning_rate = 1.0
|
||||
|
||||
if optimizer == "prodigy":
|
||||
if learning_rate is None:
|
||||
learning_rate = 1.0
|
||||
|
||||
optimizer = ml.Prodigy
|
||||
elif optimizer == "adagrad":
|
||||
if learning_rate is None:
|
||||
learning_rate = 1.0e-2
|
||||
|
||||
optimizer = ml.Adagrad
|
||||
elif optimizer == "adamw":
|
||||
if learning_rate is None:
|
||||
learning_rate = 1.0e-4
|
||||
|
||||
optimizer = ml.AdamW
|
||||
elif optimizer == "sdg":
|
||||
if learning_rate is None:
|
||||
learning_rate = 1.0e-4
|
||||
|
||||
optimizer = ml.SGD
|
||||
else:
|
||||
raise ValueError(f"Unrecognized optimizer: {optimizer}")
|
||||
|
||||
print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
|
||||
|
||||
optimizer = optimizer(model.parameters(), lr=learning_rate)
|
||||
|
||||
if scheduler == "schedulefree":
|
||||
if isinstance(optimizer, ml.AdamW):
|
||||
scheduler = ml.schedulefree.AdamWScheduleFree
|
||||
elif isinstance(optimizer, ml.SGD):
|
||||
scheduler = ml.schedulefree.SGDScheduleFree
|
||||
else:
|
||||
scheduler = None
|
||||
|
||||
if scheduler is not None:
|
||||
print("Scheduler:", scheduler)
|
||||
optimizer = scheduler( model.parameters(), lr = learning_rate )
|
||||
|
||||
if cfg.optimizations.replace and cfg.optimizations.linear:
|
||||
model = ml.replace_linear( model )
|
||||
|
||||
if cfg.optimizations.replace and cfg.optimizations.embedding:
|
||||
model = ml.replace_embedding( model )
|
||||
|
||||
"""
|
||||
cfg.optimizations.model_offloading = {
|
||||
"devices": ["cuda:0", "cpu"],
|
||||
# "limits": [ 0.9, -1 ],
|
||||
"assign": [[ f'layers.{i}.' for i in range(0,10) ], [ f'layers.{i}.' for i in range(11,12) ] + [ "model.norm" ]],
|
||||
# "limits": [ 256 * (1024 ** 2), -1 ]
|
||||
}
|
||||
"""
|
||||
|
||||
engine = Engine(model=model, optimizer=optimizer)
|
||||
engines = Engines({"ar": engine})
|
||||
engines.setup()
|
||||
|
||||
"""
|
||||
if cfg.optimizations.model_offloading:
|
||||
model = ml.offload_model( model, policy=cfg.optimizations.model_offloading )
|
||||
"""
|
||||
|
||||
"""
|
||||
torch.save( {
|
||||
'module': model.state_dict()
|
||||
}, f"./data/{cfg.model.arch_type}.pth" )
|
||||
"""
|
||||
|
||||
print(f"AR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_data(task=None):
|
||||
texts = []
|
||||
proms = []
|
||||
resps = []
|
||||
|
||||
for i in range(batch_size):
|
||||
if task is None:
|
||||
task = random.choice(tasks)
|
||||
|
||||
text = text_list[i]
|
||||
prom = proms_list[i]
|
||||
resp = resps_list[i]
|
||||
|
||||
# do nothing
|
||||
if task == "tts":
|
||||
...
|
||||
elif task == "tts-c":
|
||||
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
|
||||
|
||||
prom = resp[:trim_length]
|
||||
resp = resp[trim_length:]
|
||||
elif task == "ns" or task == "sr":
|
||||
# extend the noise to fill the target audio
|
||||
noise_ext = repeat_extend_audio( noise, resp.shape[0] )
|
||||
# create the input prompt by merging the target audio with the noise
|
||||
prom = merge_audio( resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
|
||||
# set the target to just be the noise if <sr>
|
||||
if task == "sr":
|
||||
resp = noise_ext
|
||||
|
||||
# set the text prompt to empty to train without a guided text prompt
|
||||
if random.random() < 0.5:
|
||||
text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8)
|
||||
|
||||
texts.append( text.to(device) )
|
||||
proms.append( prom.to(device) )
|
||||
resps.append( resp.to(device) )
|
||||
|
||||
return texts, proms, resps
|
||||
|
||||
@torch.inference_mode()
|
||||
def sample( name, steps=1000, task=None ):
|
||||
engine.eval()
|
||||
|
||||
texts, proms, resps = sample_data( task )
|
||||
|
||||
resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
|
||||
|
||||
for i, o in enumerate(resps):
|
||||
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
|
||||
|
||||
unload_model()
|
||||
|
||||
def train():
|
||||
engine.train()
|
||||
t = trange(steps)
|
||||
for i in t:
|
||||
texts, proms, resps = sample_data()
|
||||
|
||||
stats = {"step": i}
|
||||
stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps)
|
||||
stats |= {"grad_norm": engine.get_global_grad_norm()}
|
||||
|
||||
tqdm.write(f"{stats}")
|
||||
|
||||
"""
|
||||
torch.save( {
|
||||
'module': model.state_dict()
|
||||
}, f"./data/{cfg.model.arch_type}.pth" )
|
||||
"""
|
||||
|
||||
#sample("init", 5)
|
||||
train()
|
||||
|
||||
"""
|
||||
if cfg.optimizations.compile:
|
||||
model = ml.compile_model(model, backend=cfg.optimizations.compile)
|
||||
"""
|
||||
|
||||
for task in tasks:
|
||||
sample("final", task=task)
|
||||
|
||||
engines.quit()
|
||||
|
||||
if __name__ == "__main__":
|
||||
example_usage()
|
|
@ -531,6 +531,8 @@ def example_usage():
|
|||
|
||||
if "ar" in cfg.model.capabilities:
|
||||
resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
|
||||
else:
|
||||
resps = [ resp[:, 0] for resp in resps ]
|
||||
|
||||
if "nar" in cfg.model.capabilities:
|
||||
resps = engine( texts, proms, resps, sampling_temperature=0.2 )
|
||||
|
|
|
@ -73,6 +73,13 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
|
|||
m = m.to(x)
|
||||
return x, m
|
||||
|
||||
def _interleave_sequence_reshape( input: list[torch.Tensor], dim=-1 ):
|
||||
shape = (input[0].shape[0] * len(input), input[0].shape[dim] )
|
||||
return torch.concat( [ i.t() for i in input ] ).t().reshape( shape )
|
||||
|
||||
def _interleave_sequence_flatten( input: list[torch.Tensor] ):
|
||||
return torch.concat( [ i.t() for i in input ] ).t().flatten()
|
||||
|
||||
# automagically parses a batch-list and returns it as a list
|
||||
"""
|
||||
class Embedding(nn.Embedding):
|
||||
|
@ -158,6 +165,8 @@ class AudioEmbedding(nn.Module):
|
|||
token_dim: int, # dimensionality of the embedding
|
||||
sums: bool = True, # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
|
||||
external_mode: str | None = None, # "exclusive" | "inclusive", whether to include the original audio backend's embeddings
|
||||
|
||||
capabilities: list[str] | None = None, # helper shit
|
||||
):
|
||||
super().__init__()
|
||||
# array of embeddings
|
||||
|
@ -169,6 +178,7 @@ class AudioEmbedding(nn.Module):
|
|||
self.sums = sums
|
||||
|
||||
self.external_mode = external_mode
|
||||
self.capabilities = capabilities
|
||||
|
||||
# set initial weights to zero
|
||||
if self.external_mode == "inclusive":
|
||||
|
@ -213,7 +223,19 @@ class AudioEmbedding(nn.Module):
|
|||
|
||||
return embedding
|
||||
|
||||
def internal_forward(self, xi: Tensor, offset: int = 0, quant_level: int | None = None ) -> Tensor:
|
||||
def internal_forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None ) -> Tensor:
|
||||
if offset is None:
|
||||
# prom
|
||||
if self.capabilities is None:
|
||||
offset = 0
|
||||
# resp
|
||||
elif "len" in self.capabilities:
|
||||
offset = 1
|
||||
elif "nar" not in self.capabilities:
|
||||
offset = 0
|
||||
elif quant_level > 0:
|
||||
offset = 1
|
||||
|
||||
if quant_level is None:
|
||||
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
|
||||
|
||||
|
@ -225,7 +247,7 @@ class AudioEmbedding(nn.Module):
|
|||
|
||||
return x
|
||||
|
||||
def forward(self, xi: Tensor, offset: int = 0, quant_level: int | None = None ) -> Tensor:
|
||||
def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None ) -> Tensor:
|
||||
x = self.internal_forward( xi, offset = offset, quant_level = quant_level ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
|
||||
|
||||
if self.external_mode and xi.shape[0] > 0:
|
||||
|
@ -403,15 +425,22 @@ class Base(nn.Module):
|
|||
tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
|
||||
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
|
||||
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
|
||||
interleave = self.config.experimental.interleave if self.config is not None else False
|
||||
|
||||
n_tasks = self.config.tasks if self.config is not None else 8
|
||||
n_langs = self.config.langs if self.config is not None else 2
|
||||
n_tones = self.config.tones if self.config is not None else 1
|
||||
|
||||
if "len" not in self.capabilities:
|
||||
# pure AR
|
||||
if "nar" not in self.capabilities:
|
||||
n_resp_tokens = n_audio_tokens + 1
|
||||
l_tokens = [n_resp_tokens] * self.n_resp_levels
|
||||
# NAR-len model
|
||||
elif "len" not in self.capabilities:
|
||||
# +1 to include the stop token
|
||||
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
|
||||
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
|
||||
# AR+NAR model
|
||||
else:
|
||||
n_resp_tokens = n_audio_tokens
|
||||
l_tokens = [n_resp_tokens] * (self.n_resp_levels + (1 if split_classifiers else 0))
|
||||
|
@ -423,6 +452,7 @@ class Base(nn.Module):
|
|||
"""
|
||||
|
||||
self.unified_position_ids = unified_position_ids
|
||||
self.interleave = interleave
|
||||
|
||||
self.text_emb = Embedding(n_text_tokens, d_model)
|
||||
self.langs_emb = None
|
||||
|
@ -455,11 +485,13 @@ class Base(nn.Module):
|
|||
[n_audio_tokens] * self.n_resp_levels, d_model,
|
||||
sums=audio_embedding_sums,
|
||||
external_mode=audio_embedding_mode,
|
||||
capabilities=None,
|
||||
)
|
||||
self.resps_emb = AudioEmbedding(
|
||||
l_tokens, d_model,
|
||||
sums=audio_embedding_sums,
|
||||
external_mode=audio_embedding_mode,
|
||||
capabilities=self.capabilities,
|
||||
)
|
||||
|
||||
# useless since I actually removed using these with the input processing overhaul...
|
||||
|
@ -893,7 +925,7 @@ class Base(nn.Module):
|
|||
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
|
||||
inputs[i].append( ( "lang", lang_list[i] ) )
|
||||
# insert RVQ level guidance token if the model is versioned for it
|
||||
if self.rvq_l_emb is not None:
|
||||
if self.rvq_l_emb is not None and not self.interleave:
|
||||
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
|
||||
# insert input audio prompt
|
||||
if proms_list is not None and proms_list[i] is not None:
|
||||
|
@ -1007,7 +1039,15 @@ class Base(nn.Module):
|
|||
elif name == "tone" and self.tones_emb is not None:
|
||||
embedding = self.tones_emb( input )
|
||||
elif name == "resp":
|
||||
if "len" in self.capabilities and quant_level == 0:
|
||||
if self.interleave:
|
||||
embeddings = [ self.resps_emb(
|
||||
input[:, :l+1],
|
||||
offset = 0,
|
||||
quant_level = l
|
||||
) for l in range( input.shape[-1] ) ]
|
||||
|
||||
embedding = _interleave_sequence_reshape( embeddings )
|
||||
elif "len" in self.capabilities and quant_level == 0:
|
||||
if input_prom is not None:
|
||||
# fill with the prom as the initial condition
|
||||
repeat = (input.shape[0] // input_prom.shape[0]) + 1
|
||||
|
@ -1020,9 +1060,10 @@ class Base(nn.Module):
|
|||
)
|
||||
else:
|
||||
# fill with "stop" token from the len layer for the NAR-only model
|
||||
filler_token = 12
|
||||
embedding = self.resps_emb(
|
||||
# self.dropout_token.repeat((input.shape[0], 1)),
|
||||
torch.full_like(input if input.dim() == 1 else input[..., 0], 12),
|
||||
torch.full_like(input if input.dim() == 1 else input[..., 0], filler_token),
|
||||
offset = 0,
|
||||
quant_level = 0,
|
||||
)
|
||||
|
@ -1035,9 +1076,17 @@ class Base(nn.Module):
|
|||
quant_level
|
||||
)
|
||||
else:
|
||||
offset = 0
|
||||
if "len" in self.capabilities:
|
||||
offset = 1
|
||||
elif "nar" not in self.capabilities:
|
||||
offset = 0
|
||||
elif quant_level > 0:
|
||||
offset = 1
|
||||
|
||||
embedding = self.resps_emb(
|
||||
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
|
||||
offset = 1 if "len" in self.capabilities else (0 if quant_level == 0 else 1),
|
||||
offset = offset,
|
||||
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
|
||||
)
|
||||
|
||||
|
@ -1087,6 +1136,10 @@ class Base(nn.Module):
|
|||
if not isinstance(input, torch.Tensor):
|
||||
return sum( [ i.shape[0] for i in input if isinstance(i, torch.tensor) ] ) + 1
|
||||
|
||||
# interleaved model
|
||||
if self.interleave and name == "resp":
|
||||
return input.shape[0] * input.shape[1]
|
||||
|
||||
# ending input will not have a separator later
|
||||
return input.shape[0] + (0 if name in ["resp", "len"] else 1)
|
||||
|
||||
|
@ -1142,7 +1195,10 @@ class Base(nn.Module):
|
|||
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
||||
target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) )
|
||||
elif name == "resp":
|
||||
target.append( input if input.dim() == 1 else input[:, quant_level] )
|
||||
if self.interleave:
|
||||
target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
|
||||
else:
|
||||
target.append( input if input.dim() == 1 else input[:, quant_level] )
|
||||
elif name in ["text", "quant_level", "lang", "tone", "len"]:
|
||||
target.append( input )
|
||||
|
||||
|
|
|
@ -8,5 +8,6 @@ from .utils import (
|
|||
tree_map,
|
||||
do_gc,
|
||||
set_seed,
|
||||
passes_policy
|
||||
passes_policy,
|
||||
get_devices
|
||||
)
|
|
@ -379,6 +379,9 @@ def resize_weight( weight, target, dim=0, random=True ):
|
|||
|
||||
return weight
|
||||
|
||||
def get_devices():
|
||||
return [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu']
|
||||
|
||||
# grabs the memory properties of a given device
|
||||
def get_device_properties( device ):
|
||||
if 'cuda' in device:
|
||||
|
@ -416,7 +419,7 @@ def get_model_offload_policy(module, policy=None):
|
|||
policy["assign"] = []
|
||||
|
||||
if "devices" not in policy:
|
||||
policy["devices"] = [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu'] # + cpu to spill the remainder on CPU if overbudget
|
||||
policy["devices"] = get_devices() # + cpu to spill the remainder on CPU if overbudget
|
||||
|
||||
# create initial device info
|
||||
devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ]
|
||||
|
|
|
@ -13,6 +13,7 @@ from pathlib import Path
|
|||
|
||||
from .inference import TTS, cfg
|
||||
from .train import train
|
||||
from .utils import get_devices
|
||||
|
||||
tts = None
|
||||
|
||||
|
@ -70,8 +71,11 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
|
|||
|
||||
return yamls
|
||||
|
||||
def get_dtypes():
|
||||
return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
|
||||
|
||||
#@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
|
||||
def load_model( yaml ):
|
||||
def load_model( yaml, device, dtype ):
|
||||
gr.Info(f"Loading: {yaml}")
|
||||
try:
|
||||
init_tts( yaml=Path(yaml), restart=True )
|
||||
|
@ -79,7 +83,7 @@ def load_model( yaml ):
|
|||
raise gr.Error(e)
|
||||
gr.Info(f"Loaded model")
|
||||
|
||||
def init_tts(yaml=None, restart=False):
|
||||
def init_tts(yaml=None, restart=False, device="cuda", dtype="auto"):
|
||||
global tts
|
||||
|
||||
if tts is not None:
|
||||
|
@ -91,9 +95,9 @@ def init_tts(yaml=None, restart=False):
|
|||
|
||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
|
||||
parser.add_argument("--device", type=str, default="cuda")
|
||||
parser.add_argument("--device", type=str, default=device)
|
||||
parser.add_argument("--amp", action="store_true")
|
||||
parser.add_argument("--dtype", type=str, default="auto")
|
||||
parser.add_argument("--dtype", type=str, default=dtype)
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
|
||||
|
@ -307,7 +311,10 @@ with ui:
|
|||
with gr.Tab("Settings"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=7):
|
||||
layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
|
||||
with gr.Row():
|
||||
layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
|
||||
layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda", label="Device")
|
||||
layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
|
||||
with gr.Column(scale=1):
|
||||
layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user