sort duration buckets so that paths sorted-by-duration are actually sorted by duration (I didn't realize Python dicts can have non-string keys), and add batching of samples by total duration to improve training throughput
This commit is contained in:
parent 5176ced35f
commit 83075c1505
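The core of the key-type fix is that the duration buckets were previously keyed by strings, and strings sort lexicographically rather than numerically. A minimal standalone demonstration (toy durations, not repository code):

# Why string bucket keys break duration ordering.
string_buckets = {str(d): [] for d in (2, 10, 3, 25)}
int_buckets    = {int(d): [] for d in (2, 10, 3, 25)}

print(sorted(string_buckets))  # ['10', '2', '25', '3'] -- lexicographic, out of order
print(sorted(int_buckets))     # [2, 3, 10, 25]         -- numeric, as intended

# dict(sorted(...)) then keeps that numeric order, since dicts preserve insertion order.
print(dict(sorted(int_buckets.items())))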
@@ -158,6 +158,8 @@ class Dataset:

sample_type: str = "path" # path | speaker
sample_order: str = "shuffle" # duration
sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batch, 0 to disable
# for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough

tasks_list: list[str] = field(default_factory=lambda: ["tts"])

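A minimal sketch of how these knobs combine (an illustrative stand-in dataclass with only the fields above; the real config plumbing is assumed): duration-ordered sampling plus a nonzero duration cap is what later selects the batched sampler.

from dataclasses import dataclass

# Illustrative stand-in, not the repository's Dataset config class.
@dataclass
class ToyDatasetConfig:
    sample_type: str = "path"               # path | speaker
    sample_order: str = "shuffle"           # or "duration"
    sample_max_duration_batch: float = 0.0  # seconds of utterances per batch, 0 to disable

cfg_dataset = ToyDatasetConfig(sample_order="duration", sample_max_duration_batch=120.0)
# 120 s matches the comment above about a full-sized model with 12GiB of VRAM for Encodec.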
@@ -197,28 +199,29 @@ class Dataset:
@dataclass()
class Model:
name: str = "" # vanity name for the model
version: int = 1 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding
version: int = 5 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding, 3+ = additional embeddings
size: str | dict = "full" # preset string or explicitly defined dimensionality
resp_levels: int = 1 # RVQ-bin levels this model targets for outputs
prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
langs: int = 1 # defined languages
tones: int = 1 # defined tones
experts: int = 1
arch_type: str = "retnet" # or "transformer""
training: bool = True # unneeded now
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc") (unused)
langs: int = 1 # defined languages (semi-unused)
tones: int = 1 # defined tones (unused)
experts: int = 1 # for mixtral / retnet-ts
arch_type: str = "llama" # underlying LM architecture used
training: bool = True # I really need to attend to this
interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
attention: str = "auto"
audio_embedding_sums: bool = True
split_classifiers: bool = False
attention: str = "auto" # for llama arch_types: attention used
audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
split_classifiers: bool = False # experimental, but each RVQ level gets its own classifier / output proj / LM head
dropout: float = 0.1 # adjustable dropout value
#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
loss_factors: dict = field(default_factory=lambda: {})
capabilities: list = field(default_factory=lambda: ["ar", "nar"])
experimental: str | None = None # for now it sets things to be HF compatible
kv_heads: int = 0 # MHA or GQA (for supported backends)

p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range

def get(self, name=None):
@@ -338,8 +341,8 @@ class Model:
class LoRA:
name: str = "lora" # vanity name
# to-do: find sane default values
rank: int = 8 # rank for the LoRA
alpha: int = 16 # rank for the LoRA
rank: int = 128 # rank for the LoRA
alpha: int = 128 # alpha scaling for the LoRA
training: bool = True #
parametrize: bool = False #
rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
@@ -349,6 +352,7 @@ class LoRA:
name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
return "-".join(name)

# actually not needed anymore
def active_level( self, level ):
if not self.rvq_levels:
return True
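With the new defaults above (rank 128, alpha 128), the joined name evaluates as below; a trivial standalone check, not repository code:

name = ["lora", "r128", "a128"]   # self.name, f"r{self.rank}", f"a{self.alpha}"
print("-".join(name))             # lora-r128-a128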
@@ -360,10 +364,10 @@ class Hyperparameters:
gradient_accumulation_steps: int = 32
gradient_clipping: int | float = 100

optimizer: str = "Adamw"
optimizer: str = "Adamw" # should be "Prodigyopt" now
optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config

learning_rate: float = 3.25e-4
learning_rate: float = 3.25e-4 # should be 1.0 for ProdigyOpt
warmup_steps: int = 0

scheduler: str = ""
@@ -384,18 +388,18 @@ class Evaluation:

steps: int = 500
ar_temperature: float = 1.0
nar_temperature: float = 0.2
nar_temperature: float = 0.0

load_disabled_engines: bool = True

@dataclass()
class DeepSpeed:
zero_optimization_level: int = 0
use_compression_training: bool = False
compression_bits: int = 8
inferencing: bool = False
use_compression_training: bool = False # cope
compression_bits: int = 8 # cope
inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead

amp: bool = False
amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)

config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
@@ -567,7 +571,7 @@ class Trainer:
load_module_only: bool = False
restart_step_count: bool = False

activation_checkpointing: bool | None = None # deprecated
activation_checkpointing: bool | None = None # deprecated; technically this should only checkpoint activations and not the entire gradients, but HF only has gradient checkpointing
gradient_checkpointing: bool = True

aggressive_optimizations: bool = False
@@ -612,17 +616,13 @@ class Inference:
amp: bool = False

normalize: bool = False # do NOT enable this unless you know exactly what you're doing
audio_backend: str = "" # encodec, vocos, dac

# legacy / backwards compat
audio_backend: str = "" # encodec, vocos, dac
use_vocos: bool = True
use_encodec: bool = True
use_dac: bool = True

# shit that doesn't work
recurrent_chunk_size: int = 0
recurrent_forward: bool = False

@cached_property
def dtype(self):
if self.weight_dtype == "float16":
@@ -726,6 +726,7 @@ class Config(BaseConfig):
print("Error while opening HDF5 file:", f'{self.rel_path}/{self.dataset.hdf5_name}', str(e))
self.dataset.use_hdf5 = False

# to-do: prune unused keys
def format( self, training=True ):
if isinstance(self.dataset, type):
self.dataset = dict()

@@ -12,7 +12,7 @@ import itertools

from .config import cfg
from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
from .utils.sampler import PoolSampler, OrderedSampler, RandomSampler
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
from .utils.distributed import global_rank, local_rank, world_size

from collections import defaultdict
@@ -483,23 +483,29 @@ class Dataset(_Dataset):
if self.sampler_order != "duration":
continue

bucket = str(int(round(duration)))
bucket = int(round(duration))
if bucket not in self.duration_buckets:
self.duration_buckets[bucket] = []
self.duration_buckets[bucket].append( ( Path(path), duration ) )

# ensure they're ordered
self.duration_buckets = dict(sorted(self.duration_buckets.items()))

# sort by duration
if self.sampler_order == "duration":
flattened = {}
# sort and interleave
for bucket in self.duration_buckets:
# sort by duration
self.duration_buckets[bucket].sort( key=lambda x: x[1] )
# split to retain tuples
flattened[bucket] = self.duration_buckets[bucket]
# replace with path
self.duration_buckets[bucket] = [ x[0] for x in self.duration_buckets[bucket] ]
flattened[bucket] = [ x[0] for x in flattened[bucket] ]
# flatten by paths
self.duration_buckets[bucket] = [*_interleaved_reorder(self.duration_buckets[bucket], self.get_speaker)]
flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
# flatten paths
self.paths = list(itertools.chain.from_iterable(self.duration_buckets.values()))
self.paths = list(itertools.chain.from_iterable(flattened.values()))
elif self.sampler_order == "shuffle":
# just interleave
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
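A standalone toy run of the same idea, simplified so it is self-contained (speaker interleaving via _interleaved_reorder/get_speaker is omitted and paths are plain strings; none of this is repository code). It shows that integer bucket keys plus the per-bucket sort yield paths that really are ordered by duration:

samples = [("c.wav", 9.7), ("a.wav", 2.2), ("d.wav", 10.4), ("b.wav", 2.6)]

buckets = {}
for path, duration in samples:
    bucket = int(round(duration))          # int keys, so sorted() orders them numerically
    buckets.setdefault(bucket, []).append((path, duration))

buckets = dict(sorted(buckets.items()))    # "ensure they're ordered"

flattened = {}
for bucket in buckets:
    buckets[bucket].sort(key=lambda x: x[1])               # sort within the bucket by duration
    flattened[bucket] = [x[0] for x in buckets[bucket]]    # keep only the paths

paths = [p for bucket_paths in flattened.values() for p in bucket_paths]
print(paths)  # ['a.wav', 'b.wav', 'c.wav', 'd.wav'] -- shortest to longest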
@@ -536,12 +542,14 @@ class Dataset(_Dataset):

if len(self.paths) == 0:
raise ValueError(f"No valid path is found for {self.dataset_type}")

sampler_path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"

if self.sampler_type == "path":
self.sampler = OrderedSampler( len(self) )
if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size )
else:
self.sampler = OrderedSampler( len(self) )
self.samplers = {}
self.spkr_samplers = {}
else:
@@ -1001,17 +1009,23 @@ def _create_dataloader(dataset, training):
shuffle = False
"""

kwargs = dict(
shuffle=dataset.shuffle,
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
drop_last=training,
sampler=dataset.sampler,
) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
batch_sampler=dataset.sampler,
)

return DataLoader(
dataset=dataset,
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
shuffle=dataset.shuffle,
drop_last=training,
num_workers=cfg.dataset.workers,
collate_fn=collate_fn,
persistent_workers=cfg.dataset.workers > 1,
pin_memory=False, # True,
worker_init_fn=_seed_worker,
sampler=dataset.sampler,
**kwargs,
)

def create_datasets():

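The kwargs switch above is needed because torch.utils.data.DataLoader treats batch_sampler as mutually exclusive with non-default batch_size, shuffle, sampler, and drop_last (passing them together raises a ValueError). A minimal sketch with a toy dataset, not repository code:

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __len__(self):
        return 6
    def __getitem__(self, i):
        return torch.tensor(i)

# A batch sampler yields whole batches of indices, so the loader must not batch again.
toy_batches = [[0, 1, 2], [3, 4], [5]]
loader = DataLoader(ToyDataset(), batch_sampler=toy_batches, num_workers=0)

for batch in loader:
    print(batch)  # tensor([0, 1, 2]), then tensor([3, 4]), then tensor([5]) -- variable batch sizes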
@@ -72,6 +72,12 @@ class AR_NAR(Base):
if hasattr(self, "config") and self.config:
return self.config.tasks
return cfg.model.tasks

@property
def p_rvq_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.p_rvq_levels
return cfg.model.p_rvq_levels

@property
def n_langs(self) -> int:
@@ -163,7 +169,10 @@ class AR_NAR(Base):
# determines which RVQ level to target per batch
quant_level_range = self.quant_level_range

if cfg.experimental:
if self.p_rvq_levels == "equal":
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
else: # if self.p_rvq_levels == "auto":
# makes higher levels less likely
def generate( lo=0, hi=8 ):
index = lo
@@ -174,9 +183,6 @@ class AR_NAR(Base):
return int(index)

quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
else:
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]

resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]

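For the level-selection branch above: "equal" draws each RVQ level uniformly, while "auto" is meant to make higher levels less likely (the body of generate() is only partly visible in this diff). A hedged, self-contained sketch of both behaviors; the decaying-probability walk here is an illustrative stand-in, not the repository's exact generate():

import random
from collections import Counter

def pick_equal(lo=0, hi=8):
    # uniform over [lo, hi): level 0 is the AR, 1+ are NAR levels
    return random.randint(lo, hi - 1)

def pick_auto(lo=0, hi=8, p_stop=0.5):
    # illustrative decaying walk: each higher level is roughly half as likely as the previous one
    index = lo
    while index < hi - 1 and random.random() > p_stop:
        index += 1
    return index

random.seed(0)
print(Counter(pick_equal() for _ in range(10000)))  # roughly flat across 0..7
print(Counter(pick_auto() for _ in range(10000)))   # mass concentrated on the low levels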
@@ -74,6 +74,53 @@ class OrderedSampler(Sampler):
self.position = state["position"]
self.length = state["length"]

# Like the above, but will batch based on total duration
class BatchedOrderedSampler(Sampler):
def __init__( self, buckets, max_duration=0, max_batch_size=0 ):
self.position = 0
self.batches = []

assert max_duration != 0 or max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"

current_batch = []
current_size = 0
current_index = 0
for key, bucket in buckets.items():
for path, duration in bucket:
# flush if adding this utterance would exceed the duration or batch-size cap
should_flush = False
if max_duration > 0 and current_size + duration > max_duration:
should_flush = True
elif max_batch_size > 0 and len(current_batch) >= max_batch_size:
should_flush = True

if should_flush and len(current_batch) > 0:
self.batches.append( current_batch )
current_batch = []
current_size = 0

current_batch.append( current_index )
current_index += 1
current_size += duration

def __len__(self):
return len(self.batches)

def __iter__(self):
if self.position >= len(self.batches):
self.position = 0

while self.position < len(self.batches):
yield self.batches[self.position]
self.position += 1

def get_state(self):
return { "position": self.position, "batches": self.batches }

def set_state(self, state):
self.position = state["position"]
self.batches = state["batches"]

# Randomly samples indices from a given sequence from 0 to length
# Allows saving and loading state
class RandomSampler(Sampler):
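A hedged usage sketch of the sampler above, with BatchedOrderedSampler as defined in this hunk in scope and toy buckets keyed by rounded duration as built in data.py (paths and durations invented for illustration). The yielded index lists refer to positions in the flattened bucket order, which is how self.paths ends up ordered:

buckets = {
    2: [("a.wav", 2.1), ("b.wav", 2.4)],
    5: [("c.wav", 4.8), ("d.wav", 5.2)],
    9: [("e.wav", 9.0)],
}

sampler = BatchedOrderedSampler(buckets, max_duration=10, max_batch_size=0)
print(list(sampler))
# [[0, 1, 2], [3]] -- batch sizes vary with clip length; note the trailing partial
# batch ([4]) is not emitted, since flushing only happens when a cap is exceeded.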