sanitizing
This commit is contained in:
parent
71e373064f
commit
0b6499601b
|
@ -212,6 +212,7 @@ class Model:
|
|||
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
|
||||
attention: str = "eager" # or flash_attention_2
|
||||
audio_embedding_sums: bool = True
|
||||
dropout: float = 0.1 # adjustable dropout value
|
||||
|
||||
def get(self, name=None):
    """Filter this object by name.

    Returns a one-element list containing self when no name is supplied
    (falsy filter) or when the filter equals this object's name; otherwise
    returns an empty list. The list form lets callers aggregate results
    from several objects uniformly.
    """
    # Falsy filter means "match everything" — short-circuit before
    # touching self.name, exactly like the original conditional.
    if not name:
        return [self]
    if self.name == name:
        return [self]
    return []
|
||||
|
@ -687,6 +688,9 @@ class Config(_Config):
|
|||
if self.dataset.prompt_duration != 0:
|
||||
self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration]
|
||||
|
||||
if self.trainer.backend == "local" and self.distributed:
|
||||
self.trainer.ddp = True
|
||||
|
||||
# Preserves the old behavior
|
||||
class NaiveTokenizer:
|
||||
def get_vocab( self ):
|
||||
|
|
|
@ -10,6 +10,8 @@ def get_model(cfg, training=True):
|
|||
n_layers=cfg.layers,
|
||||
n_experts=cfg.experts,
|
||||
|
||||
p_dropout=cfg.dropout,
|
||||
|
||||
l_padding = cfg.input_alignment,
|
||||
|
||||
training = training,
|
||||
|
|
|
@ -151,6 +151,9 @@ def train(
|
|||
last_save_step = engines.global_step
|
||||
last_eval_step = 0
|
||||
|
||||
if cfg.distributed:
|
||||
train_dl.sampler.set_epoch(int(engines.global_samples / len(train_dl.dataset.paths)))
|
||||
|
||||
# Training loop
|
||||
for batch in _make_infinite_epochs(train_dl):
|
||||
if engines.global_step >= cfg.trainer.iterations:
|
||||
|
|
Loading…
Reference in New Issue
Block a user