Logger broke for some reason, so added a flag (cfg.trainer.no_logger) to just use tqdm.write instead; also made YAMLs with cfg.bitsandbytes.bitnet == True denoted (by adding "bitnet" to the model name), since I'm sure they're not interoperable

master
mrq 2024-03-01 10:32:35 +07:00
parent 35d78a2bb0
commit 0427d8d076
4 changed files with 13 additions and 3 deletions

@ -189,6 +189,9 @@ class Model:
else:
name.append(self.arch_type.replace("/", "-"))
if cfg.bitsandbytes.bitnet:
name.append("bitnet")
if self.interleave:
name.append("interleaved")
else:
@ -488,6 +491,7 @@ class Trainer:
amp: bool = False
load_webui: bool = False
no_logger: bool = False
backend: str = "local"
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)

@ -146,7 +146,10 @@ def run_eval(engines, eval_name, dl):
}
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
if cfg.trainer.no_logger:
tqdm.write(f"Validation Metrics: {json.dumps(engines_stats)}.")
else:
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
def train():

@ -16,7 +16,7 @@ def get_free_port():
_distributed_initialized = False
def init_distributed( fn, *args, **kwargs ):
print("Initializing distributed...")
#print("Initializing distributed...")
fn(*args, **kwargs)
_distributed_initialized = True

@ -175,7 +175,10 @@ def train(
elapsed_time = stats.get("elapsed_time", 0)
metrics = json.dumps(stats)
_logger.info(f"Training Metrics: {truncate_json(metrics)}.")
if cfg.trainer.no_logger:
tqdm.write(f"Training Metrics: {truncate_json(metrics)}.")
else:
_logger.info(f"Training Metrics: {truncate_json(metrics)}.")
command = _non_blocking_input()