logger broke for some reason, added flag to just tqdm.write instead, make cfg.bitsandbytes.bitnet==True yamls denoted since I'm sure they're not interoperable
This commit is contained in:
parent
35d78a2bb0
commit
0427d8d076
|
@ -189,6 +189,9 @@ class Model:
|
|||
else:
|
||||
name.append(self.arch_type.replace("/", "-"))
|
||||
|
||||
if cfg.bitsandbytes.bitnet:
|
||||
name.append("bitnet")
|
||||
|
||||
if self.interleave:
|
||||
name.append("interleaved")
|
||||
else:
|
||||
|
@ -488,6 +491,7 @@ class Trainer:
|
|||
amp: bool = False
|
||||
|
||||
load_webui: bool = False
|
||||
no_logger: bool = False
|
||||
|
||||
backend: str = "local"
|
||||
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
|
||||
|
|
|
@ -146,7 +146,10 @@ def run_eval(engines, eval_name, dl):
|
|||
}
|
||||
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
|
||||
|
||||
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
||||
if cfg.trainer.no_logger:
|
||||
tqdm.write(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
||||
else:
|
||||
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
||||
|
||||
|
||||
def train():
|
||||
|
|
|
@ -16,7 +16,7 @@ def get_free_port():
|
|||
|
||||
_distributed_initialized = False
|
||||
def init_distributed( fn, *args, **kwargs ):
|
||||
print("Initializing distributed...")
|
||||
#print("Initializing distributed...")
|
||||
fn(*args, **kwargs)
|
||||
_distributed_initialized = True
|
||||
|
||||
|
|
|
@ -175,7 +175,10 @@ def train(
|
|||
elapsed_time = stats.get("elapsed_time", 0)
|
||||
metrics = json.dumps(stats)
|
||||
|
||||
_logger.info(f"Training Metrics: {truncate_json(metrics)}.")
|
||||
if cfg.trainer.no_logger:
|
||||
tqdm.write(f"Training Metrics: {truncate_json(metrics)}.")
|
||||
else:
|
||||
_logger.info(f"Training Metrics: {truncate_json(metrics)}.")
|
||||
|
||||
command = _non_blocking_input()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user