diff --git a/vall_e/config.py b/vall_e/config.py
index 8ae3309..8185d47 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -212,6 +212,7 @@ class Model:
 	frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
 	attention: str = "eager" # or flash_attention_2
 	audio_embedding_sums: bool = True
+	dropout: float = 0.1 # adjustable dropout value
 
 	def get(self, name=None):
 		return [ self ] if not name or self.name == name else []
@@ -687,6 +688,9 @@
 		if self.dataset.prompt_duration != 0:
 			self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration]
 
+		if self.trainer.backend == "local" and self.distributed:
+			self.trainer.ddp = True
+
 # Preserves the old behavior
 class NaiveTokenizer:
 	def get_vocab( self ):
diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py
index 5979e44..59a86d7 100755
--- a/vall_e/models/__init__.py
+++ b/vall_e/models/__init__.py
@@ -10,6 +10,8 @@ def get_model(cfg, training=True):
 		n_layers=cfg.layers,
 		n_experts=cfg.experts,
 
+		p_dropout=cfg.dropout,
+
 		l_padding = cfg.input_alignment,
 
 		training = training,
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 7d4b3b5..230e49b 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -151,6 +151,9 @@ def train(
 	last_save_step = engines.global_step
 	last_eval_step = 0
 
+	if cfg.distributed:
+		train_dl.sampler.set_epoch(int(engines.global_samples / len(train_dl.dataset.paths)))
+
 	# Training loop
 	for batch in _make_infinite_epochs(train_dl):
 		if engines.global_step >= cfg.trainer.iterations: