diff --git a/vall_e/config.py b/vall_e/config.py
index 63306cb..2c773c5 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -480,6 +480,7 @@ class DeepSpeed:
 			"fp16": {
 				"enabled": cfg.trainer.weight_dtype.lower() == "float16",
 				"auto_cast": True, # ???
+				"loss_scale": 0.0 if cfg.trainer.scale_loss else 1.0,
 			},
 			"bf16": {
 				"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
@@ -613,6 +614,7 @@ class Trainer:
 	amp: bool = False # automatic mixed precision
 	ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
 
+	scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload)
 	load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
 	no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when logger broke because of BitNet
 
@@ -632,12 +634,14 @@ class Trainer:
 			return torch.float8_e4m3fn
 		return torch.float32
 
+	"""
 	@cached_property
 	def scale_loss(self):
 		# currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
 		if self.backend != "local":
 			return False
 		return self.dtype == torch.float16
+	"""
 
 
 @dataclass()
diff --git a/vall_e/data.py b/vall_e/data.py
index f3c4b4f..edcf618 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -1540,6 +1540,7 @@ if __name__ == "__main__":
 		for k, v in samples.items():
 			for i in range(len(v)):
 				for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."):
+					"""
 					"""
 					try:
 						decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
@@ -1549,7 +1550,6 @@
 						decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
 					except Exception as e:
 						print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e))
-					"""
 					v[i]['proms'][j] = v[i]['proms'][j].shape
 					v[i]['resps'][j] = v[i]['resps'][j].shape