diff --git a/vall_e/demo.py b/vall_e/demo.py index ab3fdc2..21e1ea1 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -64,6 +64,10 @@ def main(): parser.add_argument("--mirostat-tau", type=float, default=0) parser.add_argument("--mirostat-eta", type=float, default=0) + + parser.add_argument("--dry-multiplier", type=float, default=0) + parser.add_argument("--dry-base", type=float, default=1.75) + parser.add_argument("--dry-allowed-length", type=int, default=2) parser.add_argument("--seed", type=int, default=None) @@ -99,6 +103,7 @@ def main(): length_penalty=args.length_penalty, beam_width=args.beam_width, mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta, + dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length, )) ) # pull from provided samples diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py index ed73888..c62cfc0 100755 --- a/vall_e/engines/base.py +++ b/vall_e/engines/base.py @@ -338,12 +338,6 @@ class Engines(dict[str, Engine]): lora = None save_path = cfg.ckpt_dir / name / f"fp32.{format}" config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config - - # coerce - if not isinstance(config, dict): - config = config.__dict__ - if not isinstance(config['experimental'], dict): - config['experimental'] = config['experimental'].__dict__ # safety for k, v in module.items(): @@ -363,7 +357,7 @@ class Engines(dict[str, Engine]): "tokens_processed": engine.tokens_processed, }, "userdata": userdata, - "config": config + "config": config.__dict__ | {"experimental": config.experimental.__dict__} # i hate implicit aliasing rules } if lora is None: diff --git a/vall_e/export.py b/vall_e/export.py index a807b22..934393a 100755 --- a/vall_e/export.py +++ b/vall_e/export.py @@ -98,8 +98,8 @@ def split_classifier_heads( state_dict, config = cfg.model, save_path = None, dt tokens = 1025 if i == 0 else 1024 # trim per RVQ level (since level 0 has a stop token) - state_dict['module'][f'classifiers.proj.{i}.weight'] = state_dict['module']['classifier.weight'][:tokens, :] - state_dict['module'][f'classifiers.proj.{i}.bias'] = state_dict['module']['classifier.bias'][:tokens] + state_dict['module'][f'classifiers.proj.{i}.weight'] = state_dict['module']['classifier.weight'][:tokens, :].clone() + state_dict['module'][f'classifiers.proj.{i}.bias'] = state_dict['module']['classifier.bias'][:tokens].clone() # delete old weights del state_dict['module']['classifier.weight'] diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py new file mode 100644 index 0000000..9a861be --- /dev/null +++ b/vall_e/models/ar.py @@ -0,0 +1,516 @@ +""" +# an AR model that (should) handle: +* handling all RVQ levels, but does it in an autoregressive manner + +It's in a mess of a state, because I want this to be an interleaved model, but it just seems better to use the vall_e.models.experimental model. +""" +from .base import Base, list_to_tensor, Categorical +from ..config import cfg + +import torch +from torch.nn.utils.rnn import pad_sequence + +import random +import math +from einops import rearrange +from torch import Tensor +from tqdm import trange + +from ..emb.qnt import trim, encode_as_embedding + +from .lora import enable_lora + +def clamp(n, lo, hi): + return max(lo, min(n, hi)) + +class AR(Base): + def forward( + self, + text_list: list[Tensor], + proms_list: list[Tensor], + resps_list: list[Tensor] | None = None, + + task_list: list[Tensor] | None = None, + lang_list: list[Tensor] | None = None, + tone_list: list[Tensor] | None = None, + len_list: list[Tensor] | None = None, + + training: bool | None = None, + + max_steps: int = 1000, + max_levels: int = 0, + + sampling_temperature: float = 1.0, + sampling_min_temperature: float = -1.0, + sampling_top_k: int = -100, + sampling_top_p: float = 1.0, + sampling_repetition_penalty: float = 1.0, + sampling_repetition_penalty_decay: float = 0.0, + sampling_length_penalty: float = 0.0, + sampling_beam_width: int = 0, + sampling_mirostat_tau: float = 0.0, + sampling_mirostat_eta: float = 0.1, + sampling_dry_multiplier=0.0, + sampling_dry_base=1.75, + sampling_dry_allowed_length=2, + + disable_tqdm=False, + ): + device = text_list[0].device + batch_size = len(text_list) + + # generate task list if not provided + if task_list is None: + task_list = [ "tts" for _ in range(batch_size) ] + + # is training or NAR + if resps_list is not None: + n_levels_set = {r.shape[-1] for r in resps_list} + n_levels = next(iter(n_levels_set)) + + if training is None: + training = n_levels == self.n_resp_levels + + # is training + if training: + # specifies how to sample probabilities of which RVQ levels to train against + p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal" + # determines which RVQ level to target per batch + quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ] + # rate to perform token dropout errors + token_dropout_error = self.config.experimental.token_dropout_error + # RVQ levels to apply token dropout on + token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels + # implicitly set it to all levels + if not token_dropout_rvq_levels: + token_dropout_rvq_levels = [0, self.resp_levels - 1] + # allow passing a specific distribution of RVQ levels + p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else [] + if not p_rvq_levels: + lo, hi = quant_level_range[0], quant_level_range[1] + 1 + # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) + if p_rvq_levels == "equal": + p_rvq_levels = [ i for i in range( lo, hi ) ] + else: + # yuck + p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], []) + + # input RVQ levels + if not self.interleave: + quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ] + # trim resps to only contain all levels below the target level + resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)] + else: + quant_levels = [ 0 for i in range(batch_size) ] + + # tensor to cat for RVQ level 0 + # I hate python's value/reference semantics so much + for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list): + # cap quant_level if it exceeds its corresponding resp/prom + if quant_level >= resps.shape[-1]: + quant_levels[i] = resps.shape[-1] - 1 + + # proms could be a Tensor, list[Tensor], or None + if isinstance( proms, torch.Tensor ): + if quant_level >= proms.shape[-1]: + quant_levels[i] = proms.shape[-1] - 1 + + elif isinstance( proms, list ): + for j, prom in enumerate( proms ): + if not isinstance( prom, torch.Tensor ): + continue + + if quant_level >= prom.shape[-1]: + quant_levels[i] = prom.shape[-1] - 1 + + # apply token dropout error compensation + if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]): + steps = resps.shape[0] + for l in range( quant_level ): + for t in range( steps ): + token = resps[t, l].item() + + if random.random() < token_dropout_error: + offset = 1 * ( 1 if random.random() < 0.5 else -1 ) + resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1 + + # only apply stop token for RVQ level 0 + stop_sequence = torch.tensor([[self.stop_token] * resps.shape[-1]], device=device, dtype=torch.int16) + resps_list[i] = torch.cat([ resps, stop_sequence ]) + + + inputs = self.inputs( + text_list=text_list, + proms_list=proms_list, + resps_list=resps_list, + lang_list=lang_list, + tone_list=tone_list, + task_list=task_list, + ) + + return super().forward( + inputs=inputs, + ) + + # is AR + if cfg.lora is not None: + enable_lora( self, cfg.lora.active_level( 0 ) ) + + sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ] + stopped = torch.zeros(batch_size, device=device).bool() + + stop_token = self.stop_token + + + state = None + mirostat = [ + {"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0} + ] * batch_size if sampling_mirostat_tau > 0.0 else None + + scores = [ 1.0 ] * sampling_beam_width + + # get next in sequence + for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm): + resps_list = [x.unsqueeze(dim=-1) for x in sequence_list] + + inputs = self.inputs( + text_list=text_list, + proms_list=proms_list, + resps_list=resps_list, + lang_list=lang_list, + tone_list=tone_list, + len_list=len_list, + task_list=task_list, + quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ] + ) + + if state is not None: + logits, state = super().forward( + inputs=inputs, + state=state, + ) + else: + logits = super().forward( + inputs=inputs, + state=state, + ) + + r = super().sample( + logits=logits, + resps_list=resps_list, + + temperature=sampling_temperature, + min_temperature=sampling_min_temperature, + top_p=sampling_top_p, + top_k=sampling_top_k, + repetition_penalty=sampling_repetition_penalty, + repetition_penalty_decay=sampling_repetition_penalty_decay, + length_penalty=sampling_length_penalty, + beam_width=sampling_beam_width, + + mirostat=mirostat, + + dry_multiplier=sampling_dry_multiplier, + dry_base=sampling_dry_base, + dry_allowed_length=sampling_dry_allowed_length, + ) + + if mirostat is not None: + # r is the state + mirostat = r + # extract token from state + r = [ state["token"] for state in mirostat ] + # we do it here because the sampler will already expand our logits list + elif sampling_beam_width > 0: + # expand tuple + r, s = r + # first step, expand batch + if batch_size == 1: + batch_size = sampling_beam_width + text_list = text_list * sampling_beam_width + proms_list = proms_list * sampling_beam_width + sequence_list = sequence_list * sampling_beam_width + stopped = torch.zeros(batch_size, device=device).bool() + + scores = [ scores[i] + score for i, score in enumerate(s) ] + + # append tokens + for i, ri in enumerate(r): + if stop_token in ri: + stopped[i] = True + sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)]) + + # stop token found + stopped |= r == stop_token + if stopped.all().item(): + break + + # pick the best scoring candidate + # desu this is always going to be candidate 0 + if sampling_beam_width: + sequence_list = [ sequence_list[0] ] + + sequence_list = [self._prune(r, stop_token) for r in sequence_list] + + for i, seq in enumerate( sequence_list ): + steps = seq.shape[0] // self.n_resp_levels + nearest_steps = steps * self.n_resp_levels + sequence_list[i] = seq[:nearest_steps].view(( steps, self.n_resp_levels )) + + return sequence_list + + +def example_usage(): + cfg.trainer.backend = "local" + cfg.hyperparameters.gradient_accumulation_steps = 1 + if cfg.audio_backend == "dac": + cfg.sample_rate = 44_100 + + from functools import partial + from einops import repeat + from tqdm import tqdm + + from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio + from ..engines import Engine, Engines + from ..utils import wrapper as ml + + import numpy as np + import re + + device = "cuda" + + # mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it) + """ + if "mamba" in cfg.model.arch_type: + cfg.model.resp_levels = 1 + """ + # cfg.model.loss_factors = {} + + def tokenize(content): + return torch.tensor( cfg.tokenizer.encode(content) ) + + def _load_quants(path) -> Tensor: + qnt = np.load(path, allow_pickle=True)[()] + return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16) + + qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}") + noise = _load_quants(f"./data/noise.{'dac' if cfg.audio_backend == 'dac' else 'enc'}") + + text_list = [ + tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device), + #tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device), + ] + proms_list = [ + qnt[:cfg.dataset.frames_per_second, :].to(device), + #qnt[:cfg.dataset.frames_per_second, :].to(device), + ] + resps_list = [ + qnt[:, :].to(device), + #qnt[:cfg.dataset.frames_per_second, :].to(device), + ] + + text_list = text_list[:1] + proms_list = proms_list[:1] + resps_list = resps_list[:1] + + batch_size = len(text_list) + + # rentet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise + kwargs = { + 'n_text_tokens': 256, + 'n_audio_tokens': 1024, + + 'd_model': 1024, # 256, # 1024, # 1536 + 'n_heads': 16, # 4, # 16, # 24 + 'n_layers': 12, # 32 + 'n_experts': 1, + + 'p_dropout': 0.1, + + 'l_padding': 8 if cfg.optimizations.fp8 else 0, + + 'config': cfg.model + } + + """ + try: + kwargs['config'] = cfg.model + except Exception as e: + pass + """ + + bos_id, space_id, eos_id = cfg.tokenizer.encode( " " ) + tasks = cfg.dataset.tasks_list + + model = AR(**kwargs).to(device) + steps = 75 * len(tasks) * cfg.model.experimental.causal_size + + optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy" + scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else "" + learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None + + if cfg.optimizations.dadaptation: + # do not combine the two + if scheduler == "schedulefree": + scheduler = "" + + learning_rate = 1.0 + + if optimizer == "prodigy": + if learning_rate is None: + learning_rate = 1.0 + + optimizer = ml.Prodigy + elif optimizer == "adagrad": + if learning_rate is None: + learning_rate = 1.0e-2 + + optimizer = ml.Adagrad + elif optimizer == "adamw": + if learning_rate is None: + learning_rate = 1.0e-4 + + optimizer = ml.AdamW + elif optimizer == "sdg": + if learning_rate is None: + learning_rate = 1.0e-4 + + optimizer = ml.SGD + else: + raise ValueError(f"Unrecognized optimizer: {optimizer}") + + print("Optimizer:", optimizer, "\tLearning rate:", learning_rate) + + optimizer = optimizer(model.parameters(), lr=learning_rate) + + if scheduler == "schedulefree": + if isinstance(optimizer, ml.AdamW): + scheduler = ml.schedulefree.AdamWScheduleFree + elif isinstance(optimizer, ml.SGD): + scheduler = ml.schedulefree.SGDScheduleFree + else: + scheduler = None + + if scheduler is not None: + print("Scheduler:", scheduler) + optimizer = scheduler( model.parameters(), lr = learning_rate ) + + if cfg.optimizations.replace and cfg.optimizations.linear: + model = ml.replace_linear( model ) + + if cfg.optimizations.replace and cfg.optimizations.embedding: + model = ml.replace_embedding( model ) + + """ + cfg.optimizations.model_offloading = { + "devices": ["cuda:0", "cpu"], + # "limits": [ 0.9, -1 ], + "assign": [[ f'layers.{i}.' for i in range(0,10) ], [ f'layers.{i}.' for i in range(11,12) ] + [ "model.norm" ]], + # "limits": [ 256 * (1024 ** 2), -1 ] + } + """ + + engine = Engine(model=model, optimizer=optimizer) + engines = Engines({"ar": engine}) + engines.setup() + + """ + if cfg.optimizations.model_offloading: + model = ml.offload_model( model, policy=cfg.optimizations.model_offloading ) + """ + + """ + torch.save( { + 'module': model.state_dict() + }, f"./data/{cfg.model.arch_type}.pth" ) + """ + + print(f"AR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") + + @torch.no_grad() + def sample_data(task=None): + texts = [] + proms = [] + resps = [] + + for i in range(batch_size): + if task is None: + task = random.choice(tasks) + + text = text_list[i] + prom = proms_list[i] + resp = resps_list[i] + + # do nothing + if task == "tts": + ... + elif task == "tts-c": + trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) + + prom = resp[:trim_length] + resp = resp[trim_length:] + elif task == "ns" or task == "sr": + # extend the noise to fill the target audio + noise_ext = repeat_extend_audio( noise, resp.shape[0] ) + # create the input prompt by merging the target audio with the noise + prom = merge_audio( resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device ) + # set the target to just be the noise if + if task == "sr": + resp = noise_ext + + # set the text prompt to empty to train without a guided text prompt + if random.random() < 0.5: + text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8) + + texts.append( text.to(device) ) + proms.append( prom.to(device) ) + resps.append( resp.to(device) ) + + return texts, proms, resps + + @torch.inference_mode() + def sample( name, steps=1000, task=None ): + engine.eval() + + texts, proms, resps = sample_data( task ) + + resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 ) + + for i, o in enumerate(resps): + _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device) + + unload_model() + + def train(): + engine.train() + t = trange(steps) + for i in t: + texts, proms, resps = sample_data() + + stats = {"step": i} + stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps) + stats |= {"grad_norm": engine.get_global_grad_norm()} + + tqdm.write(f"{stats}") + + """ + torch.save( { + 'module': model.state_dict() + }, f"./data/{cfg.model.arch_type}.pth" ) + """ + + #sample("init", 5) + train() + + """ + if cfg.optimizations.compile: + model = ml.compile_model(model, backend=cfg.optimizations.compile) + """ + + for task in tasks: + sample("final", task=task) + + engines.quit() + +if __name__ == "__main__": + example_usage() \ No newline at end of file diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index cf84924..4b5c404 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -531,6 +531,8 @@ def example_usage(): if "ar" in cfg.model.capabilities: resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 ) + else: + resps = [ resp[:, 0] for resp in resps ] if "nar" in cfg.model.capabilities: resps = engine( texts, proms, resps, sampling_temperature=0.2 ) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 895ab1a..4b43e85 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -73,6 +73,13 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"): m = m.to(x) return x, m +def _interleave_sequence_reshape( input: list[torch.Tensor], dim=-1 ): + shape = (input[0].shape[0] * len(input), input[0].shape[dim] ) + return torch.concat( [ i.t() for i in input ] ).t().reshape( shape ) + +def _interleave_sequence_flatten( input: list[torch.Tensor] ): + return torch.concat( [ i.t() for i in input ] ).t().flatten() + # automagically parses a batch-list and returns it as a list """ class Embedding(nn.Embedding): @@ -158,6 +165,8 @@ class AudioEmbedding(nn.Module): token_dim: int, # dimensionality of the embedding sums: bool = True, # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better) external_mode: str | None = None, # "exclusive" | "inclusive", whether to include the original audio backend's embeddings + + capabilities: list[str] | None = None, # helper shit ): super().__init__() # array of embeddings @@ -169,6 +178,7 @@ class AudioEmbedding(nn.Module): self.sums = sums self.external_mode = external_mode + self.capabilities = capabilities # set initial weights to zero if self.external_mode == "inclusive": @@ -213,7 +223,19 @@ class AudioEmbedding(nn.Module): return embedding - def internal_forward(self, xi: Tensor, offset: int = 0, quant_level: int | None = None ) -> Tensor: + def internal_forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None ) -> Tensor: + if offset is None: + # prom + if self.capabilities is None: + offset = 0 + # resp + elif "len" in self.capabilities: + offset = 1 + elif "nar" not in self.capabilities: + offset = 0 + elif quant_level > 0: + offset = 1 + if quant_level is None: quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1 @@ -225,7 +247,7 @@ class AudioEmbedding(nn.Module): return x - def forward(self, xi: Tensor, offset: int = 0, quant_level: int | None = None ) -> Tensor: + def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None ) -> Tensor: x = self.internal_forward( xi, offset = offset, quant_level = quant_level ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None if self.external_mode and xi.shape[0] > 0: @@ -403,15 +425,22 @@ class Base(nn.Module): tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else "" unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True + interleave = self.config.experimental.interleave if self.config is not None else False n_tasks = self.config.tasks if self.config is not None else 8 n_langs = self.config.langs if self.config is not None else 2 n_tones = self.config.tones if self.config is not None else 1 - if "len" not in self.capabilities: + # pure AR + if "nar" not in self.capabilities: + n_resp_tokens = n_audio_tokens + 1 + l_tokens = [n_resp_tokens] * self.n_resp_levels + # NAR-len model + elif "len" not in self.capabilities: # +1 to include the stop token n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 ) l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + # AR+NAR model else: n_resp_tokens = n_audio_tokens l_tokens = [n_resp_tokens] * (self.n_resp_levels + (1 if split_classifiers else 0)) @@ -423,6 +452,7 @@ class Base(nn.Module): """ self.unified_position_ids = unified_position_ids + self.interleave = interleave self.text_emb = Embedding(n_text_tokens, d_model) self.langs_emb = None @@ -455,11 +485,13 @@ class Base(nn.Module): [n_audio_tokens] * self.n_resp_levels, d_model, sums=audio_embedding_sums, external_mode=audio_embedding_mode, + capabilities=None, ) self.resps_emb = AudioEmbedding( l_tokens, d_model, sums=audio_embedding_sums, external_mode=audio_embedding_mode, + capabilities=self.capabilities, ) # useless since I actually removed using these with the input processing overhaul... @@ -893,7 +925,7 @@ class Base(nn.Module): if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None: inputs[i].append( ( "lang", lang_list[i] ) ) # insert RVQ level guidance token if the model is versioned for it - if self.rvq_l_emb is not None: + if self.rvq_l_emb is not None and not self.interleave: inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) ) # insert input audio prompt if proms_list is not None and proms_list[i] is not None: @@ -1007,7 +1039,15 @@ class Base(nn.Module): elif name == "tone" and self.tones_emb is not None: embedding = self.tones_emb( input ) elif name == "resp": - if "len" in self.capabilities and quant_level == 0: + if self.interleave: + embeddings = [ self.resps_emb( + input[:, :l+1], + offset = 0, + quant_level = l + ) for l in range( input.shape[-1] ) ] + + embedding = _interleave_sequence_reshape( embeddings ) + elif "len" in self.capabilities and quant_level == 0: if input_prom is not None: # fill with the prom as the initial condition repeat = (input.shape[0] // input_prom.shape[0]) + 1 @@ -1020,9 +1060,10 @@ class Base(nn.Module): ) else: # fill with "stop" token from the len layer for the NAR-only model + filler_token = 12 embedding = self.resps_emb( # self.dropout_token.repeat((input.shape[0], 1)), - torch.full_like(input if input.dim() == 1 else input[..., 0], 12), + torch.full_like(input if input.dim() == 1 else input[..., 0], filler_token), offset = 0, quant_level = 0, ) @@ -1035,9 +1076,17 @@ class Base(nn.Module): quant_level ) else: + offset = 0 + if "len" in self.capabilities: + offset = 1 + elif "nar" not in self.capabilities: + offset = 0 + elif quant_level > 0: + offset = 1 + embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], - offset = 1 if "len" in self.capabilities else (0 if quant_level == 0 else 1), + offset = offset, quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level ) @@ -1087,6 +1136,10 @@ class Base(nn.Module): if not isinstance(input, torch.Tensor): return sum( [ i.shape[0] for i in input if isinstance(i, torch.tensor) ] ) + 1 + # interleaved model + if self.interleave and name == "resp": + return input.shape[0] * input.shape[1] + # ending input will not have a separator later return input.shape[0] + (0 if name in ["resp", "len"] else 1) @@ -1142,7 +1195,10 @@ class Base(nn.Module): proms = [ input ] if isinstance(input, torch.Tensor) else input target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) ) elif name == "resp": - target.append( input if input.dim() == 1 else input[:, quant_level] ) + if self.interleave: + target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) ) + else: + target.append( input if input.dim() == 1 else input[:, quant_level] ) elif name in ["text", "quant_level", "lang", "tone", "len"]: target.append( input ) diff --git a/vall_e/utils/__init__.py b/vall_e/utils/__init__.py index 77c0698..d79e335 100755 --- a/vall_e/utils/__init__.py +++ b/vall_e/utils/__init__.py @@ -8,5 +8,6 @@ from .utils import ( tree_map, do_gc, set_seed, - passes_policy + passes_policy, + get_devices ) \ No newline at end of file diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py index 4cc6133..5d5193d 100755 --- a/vall_e/utils/utils.py +++ b/vall_e/utils/utils.py @@ -379,6 +379,9 @@ def resize_weight( weight, target, dim=0, random=True ): return weight +def get_devices(): + return [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu'] + # grabs the memory properties of a given device def get_device_properties( device ): if 'cuda' in device: @@ -416,7 +419,7 @@ def get_model_offload_policy(module, policy=None): policy["assign"] = [] if "devices" not in policy: - policy["devices"] = [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu'] # + cpu to spill the remainder on CPU if overbudget + policy["devices"] = get_devices() # + cpu to spill the remainder on CPU if overbudget # create initial device info devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ] diff --git a/vall_e/webui.py b/vall_e/webui.py index dd4888e..cf5eecc 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -13,6 +13,7 @@ from pathlib import Path from .inference import TTS, cfg from .train import train +from .utils import get_devices tts = None @@ -70,8 +71,11 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/")] ): return yamls +def get_dtypes(): + return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"] + #@gradio_wrapper(inputs=layout["settings"]["inputs"].keys()) -def load_model( yaml ): +def load_model( yaml, device, dtype ): gr.Info(f"Loading: {yaml}") try: init_tts( yaml=Path(yaml), restart=True ) @@ -79,7 +83,7 @@ def load_model( yaml ): raise gr.Error(e) gr.Info(f"Loaded model") -def init_tts(yaml=None, restart=False): +def init_tts(yaml=None, restart=False, device="cuda", dtype="auto"): global tts if tts is not None: @@ -91,9 +95,9 @@ def init_tts(yaml=None, restart=False): parser = argparse.ArgumentParser(allow_abbrev=False) parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too - parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--device", type=str, default=device) parser.add_argument("--amp", action="store_true") - parser.add_argument("--dtype", type=str, default="auto") + parser.add_argument("--dtype", type=str, default=dtype) args, unknown = parser.parse_known_args() tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp ) @@ -307,7 +311,10 @@ with ui: with gr.Tab("Settings"): with gr.Row(): with gr.Column(scale=7): - layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model") + with gr.Row(): + layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model") + layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda", label="Device") + layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision") with gr.Column(scale=1): layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")