diff --git a/vall_e/config.py b/vall_e/config.py
index 8d11bd9..b870280 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -158,6 +158,8 @@ class Dataset:
 	sample_type: str = "path" # path | speaker
 	sample_order: str = "shuffle" # duration
+	sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batch, 0 to disable
+	# for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough
 
 	tasks_list: list[str] = field(default_factory=lambda: ["tts"])
 
@@ -197,28 +199,29 @@ class Dataset:
 @dataclass()
 class Model:
 	name: str = "" # vanity name for the model
-	version: int = 1 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding
+	version: int = 5 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding, 3+ = additional embeddings
 	size: str | dict = "full" # preset string or explicitly defined dimensionality
 	resp_levels: int = 1 # RVQ-bin levels this model targets for outputs
 	prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt
-	tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
-	langs: int = 1 # defined languages
-	tones: int = 1 # defined tones
-	experts: int = 1
-	arch_type: str = "retnet" # or "transformer""
-	training: bool = True # unneeded now
+	tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc") (unused)
+	langs: int = 1 # defined languages (semi-unused)
+	tones: int = 1 # defined tones (unused)
+	experts: int = 1 # for mixtral / retnet-ts
+	arch_type: str = "llama" # underlying LM architecture used
+	training: bool = True # I really need to attend to this
 	interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
-	p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
 	frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
-	attention: str = "auto"
-	audio_embedding_sums: bool = True
-	split_classifiers: bool = False
+	attention: str = "auto" # attention mechanism to use (for llama arch_types)
+	audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
+	split_classifiers: bool = False # experimental, but each RVQ level gets its own classifier / output proj / LM head
 	dropout: float = 0.1 # adjustable dropout value
 	#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
 	loss_factors: dict = field(default_factory=lambda: {})
 	capabilities: list = field(default_factory=lambda: ["ar", "nar"])
 	experimental: str | None = None # for now it sets things to be HF compatible
 	kv_heads: int = 0 # MHA or GQA (for supported backends)
+
+	p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
 	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
 
 	def get(self, name=None):
@@ -338,8 +341,8 @@ class Model:
 class LoRA:
 	name: str = "lora" # vanity name
 	# to-do: find sane default values
-	rank: int = 8 # rank for the LoRA
-	alpha: int = 16 # rank for the LoRA
+	rank: int = 128 # rank for the LoRA
+	alpha: int = 128 # alpha for the LoRA
 	training: bool = True # 
 	parametrize: bool = False # 
 	rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
@@ -349,6 +352,7 @@ class LoRA:
 		name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
 		return "-".join(name)
 
+	# actually not needed anymore
 	def active_level( self, level ):
 		if not self.rvq_levels:
 			return True
@@ -360,10 +364,10 @@ class Hyperparameters:
 	gradient_accumulation_steps: int = 32
 	gradient_clipping: int | float = 100
 
-	optimizer: str = "Adamw"
+	optimizer: str = "Adamw" # should be 'Prodigyopt' now
 	optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
 
-	learning_rate: float = 3.25e-4
+	learning_rate: float = 3.25e-4 # should be 1.0 for ProdigyOpt
 	warmup_steps: int = 0
 
 	scheduler: str = ""
@@ -384,18 +388,18 @@ class Evaluation:
 	steps: int = 500
 
 	ar_temperature: float = 1.0
-	nar_temperature: float = 0.2
+	nar_temperature: float = 0.0
 
 	load_disabled_engines: bool = True
 
 @dataclass()
 class DeepSpeed:
 	zero_optimization_level: int = 0
-	use_compression_training: bool = False
-	compression_bits: int = 8
-	inferencing: bool = False
+	use_compression_training: bool = False # cope
+	compression_bits: int = 8 # cope
+	inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
 
-	amp: bool = False
+	amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)
 
 	config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
@@ -567,7 +571,7 @@ class Trainer:
 	load_module_only: bool = False
 	restart_step_count: bool = False
 
-	activation_checkpointing: bool | None = None # deprecated
+	activation_checkpointing: bool | None = None # deprecated, technically this should only apply to activations and not the entire gradients, but HF only has gradient checkpointing
 	gradient_checkpointing: bool = True
 
 	aggressive_optimizations: bool = False
@@ -612,17 +616,13 @@ class Inference:
 	amp: bool = False
 
 	normalize: bool = False # do NOT enable this unless you know exactly what you're doing
 
-	audio_backend: str = "" # encodec, vocos, dac # legacy / backwards compat
+	audio_backend: str = "" # encodec, vocos, dac
 
 	use_vocos: bool = True
 	use_encodec: bool = True
 	use_dac: bool = True
 
-	# shit that doesn't work
-	recurrent_chunk_size: int = 0
-	recurrent_forward: bool = False
-
 	@cached_property
 	def dtype(self):
 		if self.weight_dtype == "float16":
@@ -726,6 +726,7 @@ class Config(BaseConfig):
 				print("Error while opening HDF5 file:", f'{self.rel_path}/{self.dataset.hdf5_name}', str(e))
 				self.dataset.use_hdf5 = False
 
+	# to-do: prune unused keys
 	def format( self, training=True ):
 		if isinstance(self.dataset, type):
 			self.dataset = dict()
diff --git a/vall_e/data.py b/vall_e/data.py
index 1800180..6ef6392 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -12,7 +12,7 @@
 import itertools
 
 from .config import cfg
 from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
-from .utils.sampler import PoolSampler, OrderedSampler, RandomSampler
+from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size
 from collections import defaultdict
@@ -483,23 +483,29 @@ class Dataset(_Dataset):
 			if self.sampler_order != "duration":
 				continue
 
-			bucket = str(int(round(duration)))
+			bucket = int(round(duration))
 			if bucket not in self.duration_buckets:
 				self.duration_buckets[bucket] = []
 			self.duration_buckets[bucket].append( ( Path(path), duration ) )
 
+		# ensure they're ordered
+		self.duration_buckets = dict(sorted(self.duration_buckets.items()))
+
 		# sort by duration
 		if self.sampler_order == "duration":
+			flattened = {}
 			# sort and interleave
 			for bucket in self.duration_buckets:
 				# sort by duration
 				self.duration_buckets[bucket].sort( key=lambda x: x[1] )
+				# split to retain tuples
+				flattened[bucket] = self.duration_buckets[bucket]
 				# replace with path
-				self.duration_buckets[bucket] = [ x[0] for x in self.duration_buckets[bucket] ]
+				flattened[bucket] = [ x[0] for x in flattened[bucket] ]
 				# flatten by paths
-				self.duration_buckets[bucket] = [*_interleaved_reorder(self.duration_buckets[bucket], self.get_speaker)]
+				flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
 			# flatten paths
-			self.paths = list(itertools.chain.from_iterable(self.duration_buckets.values()))
+			self.paths = list(itertools.chain.from_iterable(flattened.values()))
 		elif self.sampler_order == "shuffle":
 			# just interleave
 			self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
@@ -536,12 +542,14 @@ class Dataset(_Dataset):
 		if len(self.paths) == 0:
 			raise ValueError(f"No valid path is found for {self.dataset_type}")
 
-		sampler_path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
 
 		if self.sampler_type == "path":
-			self.sampler = OrderedSampler( len(self) )
+			if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
+				self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size )
+			else:
+				self.sampler = OrderedSampler( len(self) )
 			self.samplers = {}
 			self.spkr_samplers = {}
 		else:
@@ -1001,17 +1009,23 @@ def _create_dataloader(dataset, training):
 		shuffle = False
 	"""
 
+	kwargs = dict(
+		shuffle=dataset.shuffle,
+		batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
+		drop_last=training,
+		sampler=dataset.sampler,
+	) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
+		batch_sampler=dataset.sampler,
+	)
+
 	return DataLoader(
 		dataset=dataset,
-		batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
-		shuffle=dataset.shuffle,
-		drop_last=training,
 		num_workers=cfg.dataset.workers,
 		collate_fn=collate_fn,
 		persistent_workers=cfg.dataset.workers > 1,
 		pin_memory=False, # True,
 		worker_init_fn=_seed_worker,
-		sampler=dataset.sampler,
+		**kwargs,
 	)
 
 def create_datasets():
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index ec38d38..72b9831 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -72,6 +72,12 @@ class AR_NAR(Base):
 		if hasattr(self, "config") and self.config:
 			return self.config.tasks
 		return cfg.model.tasks
+
+	@property
+	def p_rvq_levels(self) -> str:
+		if hasattr(self, "config") and self.config:
+			return self.config.p_rvq_levels
+		return cfg.model.p_rvq_levels
 
 	@property
 	def n_langs(self) -> int:
@@ -163,7 +169,10 @@ class AR_NAR(Base):
 		# determines which RVQ level to target per batch
 		quant_level_range = self.quant_level_range
 
-		if cfg.experimental:
+		if self.p_rvq_levels == "equal":
+			# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+			quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
+		else: # if self.p_rvq_levels == "auto":
 			# makes higher levels less likely
 			def generate( lo=0, hi=8 ):
 				index = lo
@@ -174,9 +183,6 @@
 				return int(index)
 
 			quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
-		else:
-			# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-			quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
 
 		resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
diff --git a/vall_e/utils/sampler.py b/vall_e/utils/sampler.py
index b584bd5..d9ebe37 100644
--- a/vall_e/utils/sampler.py
+++ b/vall_e/utils/sampler.py
@@ -74,6 +74,53 @@ class OrderedSampler(Sampler):
 		self.position = state["position"]
 		self.length = state["length"]
 
+# Like the above, but batches based on the total duration of the utterances in each batch
+class BatchedOrderedSampler(Sampler):
+	def __init__( self, buckets, max_duration=0, max_batch_size=0 ):
+		self.position = 0
+		self.batches = []
+
+		assert max_duration != 0 or max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
+
+		current_batch = []
+		current_size = 0
+		current_index = 0
+		for key, bucket in buckets.items():
+			for path, duration in bucket:
+				# flush
+				should_flush = False
+				if max_duration > 0 and current_size + duration > max_duration:
+					should_flush = True
+				elif max_batch_size > 0 and len(current_batch) >= max_batch_size:
+					should_flush = True
+
+				if should_flush and len(current_batch) > 0:
+					self.batches.append( current_batch )
+					current_batch = []
+					current_size = 0
+
+				current_batch.append( current_index )
+				current_index += 1
+				current_size += duration
+
+	def __len__(self):
+		return len(self.batches)
+
+	def __iter__(self):
+		if self.position >= len(self.batches):
+			self.position = 0
+
+		while self.position < len(self.batches):
+			yield self.batches[self.position]
+			self.position += 1
+
+	def get_state(self):
+		return { "position": self.position, "batches": self.batches }
+
+	def set_state(self, state):
+		self.position = state["position"]
+		self.batches = state["batches"]
+
 # Randomly samples indices from a given sequence from 0 to length
 # Allows saving and loading state
 class RandomSampler(Sampler):
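
As a quick illustration of what the new sampler does, the sketch below (hypothetical and not part of the patch; the bucket contents and duration cap are made up) walks duration-sorted buckets shaped like the ones data.py builds and yields lists of dataset indices whose summed duration stays under the cap, additionally limited to max_batch_size entries, mirroring how cfg.dataset.sample_max_duration_batch and cfg.hyperparameters.batch_size are passed in:

# hypothetical usage sketch, not part of the patch
from pathlib import Path
from vall_e.utils.sampler import BatchedOrderedSampler

# duration buckets shaped like data.py builds them:
# { rounded_seconds: [ (path, duration), ... ] }, keys sorted, entries sorted by duration
buckets = {
	2: [ (Path("a.enc"), 1.8), (Path("b.enc"), 2.1) ],
	5: [ (Path("c.enc"), 4.9), (Path("d.enc"), 5.2) ],
	10: [ (Path("e.enc"), 9.7) ],
}

# cap each batch at ~10 seconds total and at most 4 utterances,
# standing in for cfg.dataset.sample_max_duration_batch and cfg.hyperparameters.batch_size
sampler = BatchedOrderedSampler( buckets, max_duration=10, max_batch_size=4 )

for batch in sampler:
	# each batch is a list of indices into the duration-ordered self.paths,
	# e.g. [0, 1, 2] then [3] for the toy buckets above
	print( batch )

Note that, as written, any trailing partial batch (index 4 here) is never flushed into self.batches, so it is silently dropped; whether that is intended drop-last behavior or an oversight is worth confirming.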