added total samples processed and tokens processed (len of text tokens + len of target response tokens)
This commit is contained in:
parent
87c4bfedba
commit
7f4388e591
|
@ -746,6 +746,12 @@ if __name__ == "__main__":
|
|||
|
||||
print(text, task, cfg.models.prom_levels)
|
||||
print( proms.shape, resps.shape )
|
||||
|
||||
tokens = 0
|
||||
tokens += sum([ text.shape[0] for text in batch["text"] ])
|
||||
tokens += sum([ resps.shape[0] for resps in batch["resps"] ])
|
||||
print( tokens )
|
||||
|
||||
decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
|
||||
decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
|
||||
break
|
|
@ -61,7 +61,8 @@ class Engine():
|
|||
|
||||
self.global_steps = 0
|
||||
self.micro_steps = 0
|
||||
self.gradient_accumulation_steps = cfg.hyperparameters.gradient_accumulation_steps
|
||||
self.global_samples = 0
|
||||
self.tokens_processed = 0
|
||||
|
||||
def freeze(self):
|
||||
for p in self.module.parameters():
|
||||
|
@ -90,6 +91,10 @@ class Engine():
|
|||
def batch_size(self):
|
||||
return cfg.hyperparameters.batch_size
|
||||
|
||||
@property
|
||||
def gradient_accumulation_steps(self):
|
||||
return cfg.hyperparameters.gradient_accumulation_steps
|
||||
|
||||
def gather_attribute(self, *args, **kwargs):
|
||||
return gather_attribute(self.module, *args, **kwargs)
|
||||
|
||||
|
@ -100,11 +105,14 @@ class Engine():
|
|||
save_path = save_dir / tag / "state.pth"
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save({
|
||||
"global_step": self.global_step,
|
||||
"micro_step": self.micro_step,
|
||||
"module": self.module.state_dict(),
|
||||
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
|
||||
"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
|
||||
|
||||
"global_step": self.global_step,
|
||||
"micro_step": self.micro_step,
|
||||
"global_samples": self.global_samples,
|
||||
"tokens_processed": self.tokens_processed,
|
||||
}, save_path)
|
||||
|
||||
open(save_dir / "latest", 'w').write( tag )
|
||||
|
@ -123,6 +131,8 @@ class Engine():
|
|||
state = torch.load(load_path)
|
||||
self.global_steps = state['global_step']
|
||||
self.micro_steps = state['micro_step']
|
||||
self.global_samples = state['global_samples']
|
||||
self.tokens_processed = state['tokens_processed']
|
||||
self.module.load_state_dict(state['module'])
|
||||
|
||||
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
|
||||
|
@ -163,6 +173,7 @@ class Engine():
|
|||
def step(self):
|
||||
with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
|
||||
self.micro_steps += 1
|
||||
self.global_samples += self.batch_size
|
||||
|
||||
if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
|
||||
self.global_steps += 1
|
||||
|
@ -235,9 +246,11 @@ class Engines(dict[str, Engine]):
|
|||
for name, engine in self.items():
|
||||
outpath = cfg.ckpt_dir / name / "fp32.pth"
|
||||
state_dict = {
|
||||
'module': engine.module.state_dict(),
|
||||
"global_step": engine.global_step,
|
||||
"micro_step": engine.micro_step,
|
||||
'module': engine.module.state_dict(),
|
||||
"global_samples": engine.global_samples,
|
||||
"tokens_processed": engine.tokens_processed,
|
||||
}
|
||||
state_dict.update(userdata)
|
||||
torch.save(state_dict, outpath)
|
||||
|
@ -427,6 +440,8 @@ class Engines(dict[str, Engine]):
|
|||
grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
|
||||
elapsed_time=elapsed_time,
|
||||
engine_step=engine.global_step,
|
||||
samples_processed=engine.global_samples,
|
||||
tokens_processed=engine.tokens_processed,
|
||||
**engine_stats,
|
||||
)
|
||||
}
|
||||
|
@ -435,9 +450,8 @@ class Engines(dict[str, Engine]):
|
|||
|
||||
self._update()
|
||||
|
||||
stats["batch_size"] = self.batch_size
|
||||
stats["elapsed_time"] = total_elapsed_time
|
||||
stats["wall_time"] = time.time()
|
||||
stats["global_step"] = self.global_step
|
||||
#stats["micro_step"] = self.micro_step
|
||||
|
||||
return stats
|
||||
|
|
|
@ -41,6 +41,8 @@ class Engine(DeepSpeedEngine):
|
|||
super().__init__(None, *args, **kwargs)
|
||||
self._frozen_params = set()
|
||||
|
||||
self.tokens_processed = 0
|
||||
|
||||
def freeze(self):
|
||||
for p in self.module.parameters():
|
||||
if p.requires_grad:
|
||||
|
|
|
@ -38,6 +38,9 @@ def train_feeder(engine, batch):
|
|||
stats |= {k: v.item() for k, v in losses.items()}
|
||||
stats |= {k: v.item() for k, v in stat.items()}
|
||||
|
||||
engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
|
||||
engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ])
|
||||
|
||||
return loss, stats
|
||||
|
||||
@torch.inference_mode()
|
||||
|
|
|
@ -256,7 +256,7 @@ def train(
|
|||
stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
|
||||
|
||||
stats['batch'] = {
|
||||
'size': stats['batch_size'],
|
||||
'size': len(batch['text']),
|
||||
'id': batch['spkr_id'],
|
||||
'index': [ index for index in batch['index'] ],
|
||||
'text_len': [ text.shape[0] for text in batch['text'] ],
|
||||
|
@ -264,8 +264,6 @@ def train(
|
|||
'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
|
||||
}
|
||||
|
||||
del stats['batch_size']
|
||||
del stats['wall_time']
|
||||
del stats['global_step']
|
||||
|
||||
elapsed_time = stats.get("elapsed_time", 0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user