sort duration buckets to ensure that paths sorted-by-duration are actually sorted by duration (because I didn't know that Python dicts can have non-string keys); also add batching of samples based on total duration to get the best training throughput
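A minimal sketch of the dict-key pitfall described above (illustrative only, not code from this commit): with string keys the duration buckets sort lexicographically, so the "10 second" bucket lands before the "2 second" bucket; with integer keys they sort numerically.

buckets = { "10": [], "2": [], "3": [] }
print(sorted(buckets.keys()))   # ['10', '2', '3'] -- lexicographic, out of order
buckets = { 10: [], 2: [], 3: [] }
print(sorted(buckets.keys()))   # [2, 3, 10] -- numeric, matches actual durations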
parent 5176ced35f
commit 83075c1505
@@ -158,6 +158,8 @@ class Dataset:
     sample_type: str = "path" # path | speaker
     sample_order: str = "shuffle" # duration
+    sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batch, 0 to disable
+    # for a full-sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough

     tasks_list: list[str] = field(default_factory=lambda: ["tts"])

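As a rough illustration of how these fields might be set together (field names come from the dataclasses in this diff; treating cfg as the loaded config object is an assumption, and the values are made up):

# hypothetical configuration, mirroring the dataclass fields above
cfg.dataset.sample_order = "duration"              # sort paths by utterance duration
cfg.dataset.sample_max_duration_batch = 120.0      # fill each batch with up to ~120 seconds of audio
cfg.hyperparameters.batch_size = 16                # also caps the number of utterances per batch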
@@ -197,28 +199,29 @@ class Model:

 @dataclass()
 class Model:
     name: str = "" # vanity name for the model
-    version: int = 1 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding
+    version: int = 5 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding, 3+ = additional embeddings
     size: str | dict = "full" # preset string or explicitly defined dimensionality
     resp_levels: int = 1 # RVQ-bin levels this model targets for outputs
     prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt
-    tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
+    tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc") (unused)
-    langs: int = 1 # defined languages
+    langs: int = 1 # defined languages (semi-unused)
-    tones: int = 1 # defined tones
+    tones: int = 1 # defined tones (unused)
-    experts: int = 1
+    experts: int = 1 # for mixtral / retnet-ts
-    arch_type: str = "retnet" # or "transformer"
+    arch_type: str = "llama" # underlying LM architecture used
-    training: bool = True # unneeded now
+    training: bool = True # I really need to attend to this
     interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
-    p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
     frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
-    attention: str = "auto"
+    attention: str = "auto" # for llama arch_types: attention used
-    audio_embedding_sums: bool = True
+    audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
-    split_classifiers: bool = False
+    split_classifiers: bool = False # experimental, but each RVQ level gets its own classifier / output proj / LM head
     dropout: float = 0.1 # adjustable dropout value
     #loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
     loss_factors: dict = field(default_factory=lambda: {})
     capabilities: list = field(default_factory=lambda: ["ar", "nar"])
     experimental: str | None = None # for now it sets things to be HF compatible
     kv_heads: int = 0 # MHA or GQA (for supported backends)

+    p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
     rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range

     def get(self, name=None):
@@ -338,8 +341,8 @@ class Model:

 class LoRA:
     name: str = "lora" # vanity name
     # to-do: find sane default values
-    rank: int = 8 # rank for the LoRA
+    rank: int = 128 # rank for the LoRA
-    alpha: int = 16 # alpha for the LoRA
+    alpha: int = 128 # alpha for the LoRA
     training: bool = True #
     parametrize: bool = False #
     rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
@@ -349,6 +352,7 @@ class LoRA:
         name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
         return "-".join(name)

+    # actually not needed anymore
     def active_level( self, level ):
         if not self.rvq_levels:
             return True
@@ -360,10 +364,10 @@ class Hyperparameters:
     gradient_accumulation_steps: int = 32
     gradient_clipping: int | float = 100

-    optimizer: str = "Adamw"
+    optimizer: str = "Adamw" # should be "Prodigyopt" now
     optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config

-    learning_rate: float = 3.25e-4
+    learning_rate: float = 3.25e-4 # should be 1.0 for ProdigyOpt
     warmup_steps: int = 0

     scheduler: str = ""
@@ -384,18 +388,18 @@ class Evaluation:

     steps: int = 500
     ar_temperature: float = 1.0
-    nar_temperature: float = 0.2
+    nar_temperature: float = 0.0

     load_disabled_engines: bool = True

 @dataclass()
 class DeepSpeed:
     zero_optimization_level: int = 0
-    use_compression_training: bool = False
+    use_compression_training: bool = False # cope
-    compression_bits: int = 8
+    compression_bits: int = 8 # cope
-    inferencing: bool = False
+    inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead

-    amp: bool = False
+    amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)

     config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config

@@ -567,7 +571,7 @@ class Trainer:
     load_module_only: bool = False
     restart_step_count: bool = False

-    activation_checkpointing: bool | None = None # deprecated
+    activation_checkpointing: bool | None = None # deprecated; should technically only apply to activations and not the entire gradients, but HF only has gradient checkpointing
     gradient_checkpointing: bool = True

     aggressive_optimizations: bool = False
@@ -612,17 +616,13 @@ class Inference:
     amp: bool = False

     normalize: bool = False # do NOT enable this unless you know exactly what you're doing
-    audio_backend: str = "" # encodec, vocos, dac

     # legacy / backwards compat
+    audio_backend: str = "" # encodec, vocos, dac
     use_vocos: bool = True
     use_encodec: bool = True
     use_dac: bool = True

-    # shit that doesn't work
-    recurrent_chunk_size: int = 0
-    recurrent_forward: bool = False

     @cached_property
     def dtype(self):
         if self.weight_dtype == "float16":
@@ -726,6 +726,7 @@ class Config(BaseConfig):
             print("Error while opening HDF5 file:", f'{self.rel_path}/{self.dataset.hdf5_name}', str(e))
             self.dataset.use_hdf5 = False

+    # to-do: prune unused keys
     def format( self, training=True ):
         if isinstance(self.dataset, type):
             self.dataset = dict()
@@ -12,7 +12,7 @@ import itertools

 from .config import cfg
 from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
-from .utils.sampler import PoolSampler, OrderedSampler, RandomSampler
+from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size

 from collections import defaultdict
@@ -483,23 +483,29 @@ class Dataset(_Dataset):
             if self.sampler_order != "duration":
                 continue

-            bucket = str(int(round(duration)))
+            bucket = int(round(duration))
             if bucket not in self.duration_buckets:
                 self.duration_buckets[bucket] = []
             self.duration_buckets[bucket].append( ( Path(path), duration ) )

+        # ensure they're ordered
+        self.duration_buckets = dict(sorted(self.duration_buckets.items()))

         # sort by duration
         if self.sampler_order == "duration":
+            flattened = {}
             # sort and interleave
             for bucket in self.duration_buckets:
                 # sort by duration
                 self.duration_buckets[bucket].sort( key=lambda x: x[1] )
+                # split to retain tuples
+                flattened[bucket] = self.duration_buckets[bucket]
                 # replace with path
-                self.duration_buckets[bucket] = [ x[0] for x in self.duration_buckets[bucket] ]
+                flattened[bucket] = [ x[0] for x in flattened[bucket] ]
                 # flatten by paths
-                self.duration_buckets[bucket] = [*_interleaved_reorder(self.duration_buckets[bucket], self.get_speaker)]
+                flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
             # flatten paths
-            self.paths = list(itertools.chain.from_iterable(self.duration_buckets.values()))
+            self.paths = list(itertools.chain.from_iterable(flattened.values()))
         elif self.sampler_order == "shuffle":
             # just interleave
             self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
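A condensed, standalone sketch of the corrected bucketing flow above, using toy paths and durations (the real code's _interleaved_reorder / speaker handling is omitted here):

from pathlib import Path

samples = [ ("b.wav", 2.1), ("j.wav", 10.3), ("a.wav", 1.8), ("k.wav", 10.1) ]

duration_buckets = {}
for path, duration in samples:
    bucket = int(round(duration))                 # int keys sort numerically, not lexicographically
    duration_buckets.setdefault(bucket, []).append( ( Path(path), duration ) )

# ensure buckets are ordered by duration
duration_buckets = dict(sorted(duration_buckets.items()))

# keep the (path, duration) tuples around for the batched sampler, flatten plain paths separately
flattened = { bucket: [ x[0] for x in sorted(pairs, key=lambda x: x[1]) ] for bucket, pairs in duration_buckets.items() }
paths = [ p for bucket in flattened.values() for p in bucket ]
print(paths)                                      # a.wav, b.wav, k.wav, j.wav -- shortest to longest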
@@ -536,12 +542,14 @@ class Dataset(_Dataset):

         if len(self.paths) == 0:
             raise ValueError(f"No valid path is found for {self.dataset_type}")

         sampler_path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"

         if self.sampler_type == "path":
-            self.sampler = OrderedSampler( len(self) )
+            if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
+                self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size )
+            else:
+                self.sampler = OrderedSampler( len(self) )
             self.samplers = {}
             self.spkr_samplers = {}
         else:
|
||||||
shuffle = False
|
shuffle = False
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
kwargs = dict(
|
||||||
|
shuffle=dataset.shuffle,
|
||||||
|
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
|
||||||
|
drop_last=training,
|
||||||
|
sampler=dataset.sampler,
|
||||||
|
) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
|
||||||
|
batch_sampler=dataset.sampler,
|
||||||
|
)
|
||||||
|
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
|
|
||||||
shuffle=dataset.shuffle,
|
|
||||||
drop_last=training,
|
|
||||||
num_workers=cfg.dataset.workers,
|
num_workers=cfg.dataset.workers,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
persistent_workers=cfg.dataset.workers > 1,
|
persistent_workers=cfg.dataset.workers > 1,
|
||||||
pin_memory=False, # True,
|
pin_memory=False, # True,
|
||||||
worker_init_fn=_seed_worker,
|
worker_init_fn=_seed_worker,
|
||||||
sampler=dataset.sampler,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_datasets():
|
def create_datasets():
|
||||||
|
|
|
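The kwargs split above is needed because torch's DataLoader treats batch_sampler as mutually exclusive with batch_size, shuffle, sampler, and drop_last: a batch sampler yields whole batches of indices itself, so passing the other arguments alongside it raises a ValueError. A minimal sketch with a toy dataset (names here are illustrative, not from this repo):

from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __init__(self):
        self.data = list(range(10))
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

# any iterable of index lists works as a batch_sampler, e.g. the batches built by BatchedOrderedSampler
batches = [ [0, 1, 2], [3, 4], [5, 6, 7, 8], [9] ]

loader = DataLoader( ToyDataset(), batch_sampler=batches, collate_fn=lambda xs: xs )
for batch in loader:
    print(batch)   # variable-sized batches: [0, 1, 2], then [3, 4], ...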
@@ -72,6 +72,12 @@ class AR_NAR(Base):
         if hasattr(self, "config") and self.config:
             return self.config.tasks
         return cfg.model.tasks

+    @property
+    def p_rvq_levels(self) -> int:
+        if hasattr(self, "config") and self.config:
+            return self.config.p_rvq_levels
+        return cfg.model.p_rvq_levels

     @property
     def n_langs(self) -> int:
@@ -163,7 +169,10 @@ class AR_NAR(Base):
         # determines which RVQ level to target per batch
         quant_level_range = self.quant_level_range

-        if cfg.experimental:
+        if self.p_rvq_levels == "equal":
+            # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+            quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
+        else: # if self.p_rvq_levels == "auto":
             # makes higher levels less likely
             def generate( lo=0, hi=8 ):
                 index = lo
@@ -174,9 +183,6 @@ class AR_NAR(Base):
                 return int(index)

             quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
-        else:
-            # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-            quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]

         resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]

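For intuition, the "equal" branch above draws each batch item's RVQ level uniformly from the allowed range, while "auto" keeps the older behaviour of biasing toward lower levels. A toy sketch of the uniform case (the range values here are hypothetical):

import random
from collections import Counter

quant_level_range = [0, 8]   # hypothetical: level 0 is the AR, 1..7 are NAR levels
batch_size = 10_000

quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for _ in range(batch_size) ]
print(Counter(quant_levels))  # roughly equal counts for levels 0 through 7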
@@ -74,6 +74,53 @@ class OrderedSampler(Sampler):
         self.position = state["position"]
         self.length = state["length"]

+# Like the above, but will batch based on total duration
+class BatchedOrderedSampler(Sampler):
+    def __init__( self, buckets, max_duration=0, max_batch_size=0 ):
+        self.position = 0
+        self.batches = []
+
+        assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
+
+        current_batch = []
+        current_size = 0
+        current_index = 0
+        for key, bucket in buckets.items():
+            for path, duration in bucket:
+                # flush
+                should_flush = False
+                if max_duration > 0 and current_size + duration > max_duration:
+                    should_flush = True
+                elif max_batch_size > 0 and len(current_batch) >= max_batch_size:
+                    should_flush = True
+
+                if should_flush and len(current_batch) > 0:
+                    self.batches.append( current_batch )
+                    current_batch = []
+                    current_size = 0
+
+                current_batch.append( current_index )
+                current_index += 1
+                current_size += duration
+
+    def __len__(self):
+        return len(self.batches)
+
+    def __iter__(self):
+        if self.position >= len(self.batches):
+            self.position = 0
+
+        while self.position < len(self.batches):
+            yield self.batches[self.position]
+            self.position += 1
+
+    def get_state(self):
+        return { "position": self.position, "batches": self.batches }
+
+    def set_state(self, state):
+        self.position = state["position"]
+        self.batches = state["batches"]
+
 # Randomly samples indices from a given sequence from 0 to length
 # Allows saving and loading state
 class RandomSampler(Sampler):
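A small usage sketch for the new sampler, with toy buckets shaped like the Dataset's duration_buckets, i.e. {rounded_duration: [(path, duration), ...]} (the package/module path below is assumed from the imports earlier in the diff):

from vall_e.utils.sampler import BatchedOrderedSampler   # module path assumed

buckets = {
    2:  [ ("a.wav", 1.8), ("b.wav", 2.1) ],
    5:  [ ("c.wav", 4.9), ("d.wav", 5.2) ],
    10: [ ("e.wav", 9.7) ],
}

sampler = BatchedOrderedSampler( buckets, max_duration=8, max_batch_size=4 )
for batch in sampler:
    print(batch)
# -> [0, 1]  (1.8s + 2.1s fits under 8s)
#    [2]     (adding d.wav's 5.2s would exceed 8s)
#    [3]     (adding e.wav's 9.7s would exceed 8s)
# note: as written above, the trailing partial batch ([4]) is never flushed at the end of __init__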