diff --git a/vall_e/config.py b/vall_e/config.py
index 2c773c5..2968992 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -614,7 +614,7 @@ class Trainer:
 	amp: bool = False # automatic mixed precision
 	ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
-	scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload)
+	#scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload)
 	load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
 	no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when logger broke because of BitNet
@@ -634,14 +634,12 @@ class Trainer:
 			return torch.float8_e4m3fn
 		return torch.float32
 
-	"""
 	@cached_property
 	def scale_loss(self):
 		# currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
-		if self.backend != "local":
-			return False
 		return self.dtype == torch.float16
 	"""
+	"""
 
 @dataclass()
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index c70f0ff..18bea78 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -316,7 +316,7 @@ class AR_NAR(Base):
 
 def example_usage():
-	cfg.trainer.backend = "local"
+	# cfg.trainer.backend = "local"
 	cfg.hyperparameters.gradient_accumulation_steps = 1
 	if cfg.audio_backend == "dac":
 		cfg.sample_rate = 44_100
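
For context, a minimal sketch of how the re-enabled cfg.trainer.scale_loss property would typically gate loss scaling in a training step. The loop and the model/optimizer/batch names are hypothetical stand-ins, not the repo's actual engine code; only the cfg.trainer.scale_loss flag itself comes from the diff above.

import torch
from vall_e.config import cfg

# Hypothetical training-step sketch: `model`, `optimizer`, and `batch`
# stand in for the repo's real engine objects.
scaler = torch.cuda.amp.GradScaler(enabled=cfg.trainer.scale_loss)

def train_step(model, optimizer, batch):
	optimizer.zero_grad()
	loss = model(batch)
	if cfg.trainer.scale_loss:
		# FP16 gradients can underflow to zero, so the loss is multiplied
		# by a scale factor before backward() ...
		scaler.scale(loss).backward()
		# ... and gradients are unscaled before stepping; steps with
		# inf/NaN gradients are skipped and the scale factor is adjusted.
		scaler.step(optimizer)
		scaler.update()
	else:
		loss.backward()
		optimizer.step()
	return loss.detach()

With the DeepSpeed backend none of this applies: as the in-code comment notes, DeepSpeed manages FP16 loss scaling itself, which is why the property now only checks dtype == torch.float16 rather than also gating on the backend.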