maybe not
This commit is contained in:
parent
c658a7b440
commit
ed373957e2
|
@ -614,7 +614,7 @@ class Trainer:
|
||||||
|
|
||||||
amp: bool = False # automatic mixed precision
|
amp: bool = False # automatic mixed precision
|
||||||
ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
|
ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
|
||||||
scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload)
|
#scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload)
|
||||||
|
|
||||||
load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
|
load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
|
||||||
no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when logger broke because of BitNet
|
no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when logger broke because of BitNet
|
||||||
|
@ -634,14 +634,12 @@ class Trainer:
|
||||||
return torch.float8_e4m3fn
|
return torch.float8_e4m3fn
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
"""
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def scale_loss(self):
|
def scale_loss(self):
|
||||||
# currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
|
# currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
|
||||||
if self.backend != "local":
|
|
||||||
return False
|
|
||||||
return self.dtype == torch.float16
|
return self.dtype == torch.float16
|
||||||
"""
|
"""
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
|
|
|
@ -316,7 +316,7 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
|
|
||||||
def example_usage():
|
def example_usage():
|
||||||
cfg.trainer.backend = "local"
|
# cfg.trainer.backend = "local"
|
||||||
cfg.hyperparameters.gradient_accumulation_steps = 1
|
cfg.hyperparameters.gradient_accumulation_steps = 1
|
||||||
if cfg.audio_backend == "dac":
|
if cfg.audio_backend == "dac":
|
||||||
cfg.sample_rate = 44_100
|
cfg.sample_rate = 44_100
|
||||||
|
|
Loading…
Reference in New Issue
Block a user