sanitizing
parent 71e373064f
commit 0b6499601b
@@ -212,6 +212,7 @@ class Model:
 	frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
 	attention: str = "eager" # or flash_attention_2
 	audio_embedding_sums: bool = True
+	dropout: float = 0.1 # adjustable dropout value
 
 	def get(self, name=None):
 		return [ self ] if not name or self.name == name else []
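The new `dropout` field slots into the `Model` dataclass as a plain float. A minimal standalone sketch of how the dataclass behaves (the `name` field and the trimmed-down class body are assumptions; only the fields shown in the hunk are real). Note that `frozen_params` needs a `default_factory` because lists are mutable defaults, while the float `dropout` can take a bare default:

from dataclasses import dataclass, field

@dataclass
class Model:
	name: str = "model"  # assumed for the sketch; the real class has many more fields
	frozen_params: list[str] = field(default_factory=lambda: [])  # mutable default needs a factory
	attention: str = "eager"  # or flash_attention_2
	audio_embedding_sums: bool = True
	dropout: float = 0.1  # adjustable dropout value

	def get(self, name=None):
		# return this model in a one-element list when the name matches (or none is given)
		return [ self ] if not name or self.name == name else []

cfg = Model(dropout=0.3)
assert cfg.get("model")[0].dropout == 0.3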
@@ -687,6 +688,9 @@ class Config(_Config):
 		if self.dataset.prompt_duration != 0:
 			self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration]
 
+		if self.trainer.backend == "local" and self.distributed:
+			self.trainer.ddp = True
+
 		# Preserves the old behavior
 		class NaiveTokenizer:
 			def get_vocab( self ):
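A minimal sketch of the new auto-enable, assuming `distributed` is true when the run was launched across multiple processes (how that flag is derived lives outside this hunk): running the `local` backend in a distributed launch now flips `trainer.ddp` on automatically.

from dataclasses import dataclass, field

@dataclass
class Trainer:
	backend: str = "local"
	ddp: bool = False

@dataclass
class Config:
	trainer: Trainer = field(default_factory=Trainer)
	distributed: bool = False  # assumed to mirror a multi-process launch (e.g. WORLD_SIZE > 1)

	def setup(self):
		# local backend in a distributed launch implies DistributedDataParallel
		if self.trainer.backend == "local" and self.distributed:
			self.trainer.ddp = True

cfg = Config(distributed=True)
cfg.setup()
assert cfg.trainer.ddp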
@@ -10,6 +10,8 @@ def get_model(cfg, training=True):
 		n_layers=cfg.layers,
 		n_experts=cfg.experts,
 
+		p_dropout=cfg.dropout,
+
 		l_padding = cfg.input_alignment,
 
 		training = training,
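`cfg.dropout` is now forwarded to the model constructor as `p_dropout`. A toy sketch of a constructor consuming it (the `ToyModel` name and layer shapes are invented; only the `p_dropout` and `training` plumbing mirrors the diff):

import torch
from torch import nn

class ToyModel(nn.Module):
	def __init__(self, n_layers: int, p_dropout: float = 0.1, training: bool = True):
		super().__init__()
		# the dropout probability flows from the config into every block
		self.blocks = nn.Sequential(*[
			nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Dropout(p_dropout))
			for _ in range(n_layers)
		])
		self.train(training)  # dropout is only active in training mode

	def forward(self, x):
		return self.blocks(x)

model = ToyModel(n_layers=2, p_dropout=0.3)
out = model(torch.randn(4, 16))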
@@ -151,6 +151,9 @@ def train(
 	last_save_step = engines.global_step
 	last_eval_step = 0
 
+	if cfg.distributed:
+		train_dl.sampler.set_epoch(int(engines.global_samples / len(train_dl.dataset.paths)))
+
 	# Training loop
 	for batch in _make_infinite_epochs(train_dl):
 		if engines.global_step >= cfg.trainer.iterations:
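The `set_epoch` call matters because `DistributedSampler` only reshuffles when its epoch changes; deriving the epoch as total samples seen divided by dataset size restores the correct shuffle when resuming mid-run. A standalone sketch with made-up numbers, passing `num_replicas`/`rank` explicitly so no process group is needed:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8).float())
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

global_samples = 24  # stand-in for engines.global_samples
epoch = global_samples // len(dataset)  # same arithmetic as the diff
sampler.set_epoch(epoch)  # deterministic, epoch-dependent shuffle order

loader = DataLoader(dataset, sampler=sampler, batch_size=2)
for (batch,) in loader:
	pass  # each rank iterates over its own shard of this epoch's permutation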