logger broke for some reason, added flag to just tqdm.write instead, make cfg.bitsandbytes.bitnet==True yamls denoted since I'm sure they're not interoperable
This commit is contained in:
parent
35d78a2bb0
commit
0427d8d076
|
@ -189,6 +189,9 @@ class Model:
|
||||||
else:
|
else:
|
||||||
name.append(self.arch_type.replace("/", "-"))
|
name.append(self.arch_type.replace("/", "-"))
|
||||||
|
|
||||||
|
if cfg.bitsandbytes.bitnet:
|
||||||
|
name.append("bitnet")
|
||||||
|
|
||||||
if self.interleave:
|
if self.interleave:
|
||||||
name.append("interleaved")
|
name.append("interleaved")
|
||||||
else:
|
else:
|
||||||
|
@ -488,6 +491,7 @@ class Trainer:
|
||||||
amp: bool = False
|
amp: bool = False
|
||||||
|
|
||||||
load_webui: bool = False
|
load_webui: bool = False
|
||||||
|
no_logger: bool = False
|
||||||
|
|
||||||
backend: str = "local"
|
backend: str = "local"
|
||||||
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
|
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
|
||||||
|
|
|
@ -146,7 +146,10 @@ def run_eval(engines, eval_name, dl):
|
||||||
}
|
}
|
||||||
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
|
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
|
||||||
|
|
||||||
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
if cfg.trainer.no_logger:
|
||||||
|
tqdm.write(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
||||||
|
else:
|
||||||
|
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
||||||
|
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
|
|
|
@ -16,7 +16,7 @@ def get_free_port():
|
||||||
|
|
||||||
_distributed_initialized = False
|
_distributed_initialized = False
|
||||||
def init_distributed( fn, *args, **kwargs ):
|
def init_distributed( fn, *args, **kwargs ):
|
||||||
print("Initializing distributed...")
|
#print("Initializing distributed...")
|
||||||
fn(*args, **kwargs)
|
fn(*args, **kwargs)
|
||||||
_distributed_initialized = True
|
_distributed_initialized = True
|
||||||
|
|
||||||
|
|
|
@ -175,7 +175,10 @@ def train(
|
||||||
elapsed_time = stats.get("elapsed_time", 0)
|
elapsed_time = stats.get("elapsed_time", 0)
|
||||||
metrics = json.dumps(stats)
|
metrics = json.dumps(stats)
|
||||||
|
|
||||||
_logger.info(f"Training Metrics: {truncate_json(metrics)}.")
|
if cfg.trainer.no_logger:
|
||||||
|
tqdm.write(f"Training Metrics: {truncate_json(metrics)}.")
|
||||||
|
else:
|
||||||
|
_logger.info(f"Training Metrics: {truncate_json(metrics)}.")
|
||||||
|
|
||||||
command = _non_blocking_input()
|
command = _non_blocking_input()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user