sanitizing

mrq 2024-05-11 16:31:05 -05:00
parent 71e373064f
commit 0b6499601b
3 changed files with 9 additions and 0 deletions

View File

@@ -212,6 +212,7 @@ class Model:
 	frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
 	attention: str = "eager" # or flash_attention_2
 	audio_embedding_sums: bool = True
+	dropout: float = 0.1 # adjustable dropout value
 
 	def get(self, name=None):
 		return [ self ] if not name or self.name == name else []
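
The new field slots into the dataclass-style config the hunk shows. A minimal standalone sketch of the pattern (not the repo's actual class, which carries many more fields):

	# Sketch: a dataclass-backed model config where a top-level `dropout`
	# field has a default but can be overridden per experiment.
	from dataclasses import dataclass, field

	@dataclass
	class Model:
		frozen_params: list[str] = field(default_factory=list)  # parameters excluded from updates
		attention: str = "eager"           # or "flash_attention_2"
		audio_embedding_sums: bool = True
		dropout: float = 0.1               # adjustable dropout value, read at model construction

	cfg = Model(dropout=0.05)  # override the default when instantiating
	print(cfg.dropout)         # 0.05
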
@@ -687,6 +688,9 @@ class Config(_Config):
 		if self.dataset.prompt_duration != 0:
 			self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration]
 
+		if self.trainer.backend == "local" and self.distributed:
+			self.trainer.ddp = True
+
 # Preserves the old behavior
 class NaiveTokenizer:
 	def get_vocab( self ):
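
The added check normalizes trainer settings after the config loads: a local backend in a distributed run falls back to DDP. A hedged sketch of that behavior, with a stripped-down stand-in for the real config classes:

	# Sketch: `Trainer` here is a hypothetical stand-in, not the repo's class.
	from dataclasses import dataclass

	@dataclass
	class Trainer:
		backend: str = "local"
		ddp: bool = False

	def normalize(trainer: Trainer, distributed: bool) -> Trainer:
		if trainer.backend == "local" and distributed:
			trainer.ddp = True  # force DistributedDataParallel wrapping
		return trainer

	print(normalize(Trainer(), distributed=True).ddp)  # True
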

View File

@@ -10,6 +10,8 @@ def get_model(cfg, training=True):
 			n_layers=cfg.layers,
 			n_experts=cfg.experts,
+
+			p_dropout=cfg.dropout,
 
 			l_padding = cfg.input_alignment,
 
 			training = training,
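
A sketch of how a `p_dropout` constructor kwarg is typically consumed downstream, assuming a standard PyTorch module (the real model class is far larger; `TinyModel` is illustrative only):

	import torch
	from torch import nn

	class TinyModel(nn.Module):
		def __init__(self, d_model: int = 64, p_dropout: float = 0.1, training: bool = True):
			super().__init__()
			self.proj = nn.Linear(d_model, d_model)
			self.dropout = nn.Dropout(p_dropout)  # only active while the module is in train mode
			self.train(training)                  # mirror the `training=` kwarg

		def forward(self, x: torch.Tensor) -> torch.Tensor:
			return self.dropout(self.proj(x))

	model = TinyModel(p_dropout=0.1, training=True)
	out = model(torch.randn(2, 64))
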

View File

@@ -151,6 +151,9 @@ def train(
 	last_save_step = engines.global_step
 	last_eval_step = 0
 
+	if cfg.distributed:
+		train_dl.sampler.set_epoch(int(engines.global_samples / len(train_dl.dataset.paths)))
+
 	# Training loop
 	for batch in _make_infinite_epochs(train_dl):
 		if engines.global_step >= cfg.trainer.iterations:
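
The `set_epoch` call matters because `DistributedSampler` only reshuffles when told which epoch it is on; on resume, the hunk derives the epoch from samples already seen divided by the dataset size (its `paths` list). A runnable sketch of the mechanism, with illustrative single-rank values:

	import torch
	from torch.utils.data import DataLoader, TensorDataset
	from torch.utils.data.distributed import DistributedSampler

	dataset = TensorDataset(torch.arange(100))
	# num_replicas/rank given explicitly so no process group is needed here
	sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
	train_dl = DataLoader(dataset, batch_size=8, sampler=sampler)

	global_samples = 250                    # e.g. restored from a checkpoint
	epoch = global_samples // len(dataset)  # 250 // 100 -> resume at epoch 2
	sampler.set_epoch(epoch)                # re-seed the shuffle deterministically
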