added total samples processed and tokens processed (len of text tokens + len of target response tokens)

This commit is contained in:
mrq 2023-08-28 11:02:45 -05:00
parent 87c4bfedba
commit 7f4388e591
5 changed files with 33 additions and 10 deletions

View File

@ -746,6 +746,12 @@ if __name__ == "__main__":
print(text, task, cfg.models.prom_levels)
print( proms.shape, resps.shape )
tokens = 0
tokens += sum([ text.shape[0] for text in batch["text"] ])
tokens += sum([ resps.shape[0] for resps in batch["resps"] ])
print( tokens )
decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
break

View File

@ -61,7 +61,8 @@ class Engine():
self.global_steps = 0
self.micro_steps = 0
self.gradient_accumulation_steps = cfg.hyperparameters.gradient_accumulation_steps
self.global_samples = 0
self.tokens_processed = 0
def freeze(self):
for p in self.module.parameters():
@ -90,6 +91,10 @@ class Engine():
def batch_size(self):
return cfg.hyperparameters.batch_size
@property
def gradient_accumulation_steps(self):
return cfg.hyperparameters.gradient_accumulation_steps
def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs)
@ -100,11 +105,14 @@ class Engine():
save_path = save_dir / tag / "state.pth"
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save({
"global_step": self.global_step,
"micro_step": self.micro_step,
"module": self.module.state_dict(),
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
"global_step": self.global_step,
"micro_step": self.micro_step,
"global_samples": self.global_samples,
"tokens_processed": self.tokens_processed,
}, save_path)
open(save_dir / "latest", 'w').write( tag )
@ -123,6 +131,8 @@ class Engine():
state = torch.load(load_path)
self.global_steps = state['global_step']
self.micro_steps = state['micro_step']
self.global_samples = state['global_samples']
self.tokens_processed = state['tokens_processed']
self.module.load_state_dict(state['module'])
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
@ -163,6 +173,7 @@ class Engine():
def step(self):
with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
self.micro_steps += 1
self.global_samples += self.batch_size
if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
self.global_steps += 1
@ -235,9 +246,11 @@ class Engines(dict[str, Engine]):
for name, engine in self.items():
outpath = cfg.ckpt_dir / name / "fp32.pth"
state_dict = {
'module': engine.module.state_dict(),
"global_step": engine.global_step,
"micro_step": engine.micro_step,
'module': engine.module.state_dict(),
"global_samples": engine.global_samples,
"tokens_processed": engine.tokens_processed,
}
state_dict.update(userdata)
torch.save(state_dict, outpath)
@ -427,6 +440,8 @@ class Engines(dict[str, Engine]):
grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
elapsed_time=elapsed_time,
engine_step=engine.global_step,
samples_processed=engine.global_samples,
tokens_processed=engine.tokens_processed,
**engine_stats,
)
}
@ -435,9 +450,8 @@ class Engines(dict[str, Engine]):
self._update()
stats["batch_size"] = self.batch_size
stats["elapsed_time"] = total_elapsed_time
stats["wall_time"] = time.time()
stats["global_step"] = self.global_step
#stats["micro_step"] = self.micro_step
return stats

View File

@ -41,6 +41,8 @@ class Engine(DeepSpeedEngine):
super().__init__(None, *args, **kwargs)
self._frozen_params = set()
self.tokens_processed = 0
def freeze(self):
for p in self.module.parameters():
if p.requires_grad:
@ -62,7 +64,7 @@ class Engine(DeepSpeedEngine):
@property
def micro_step(self):
return self.micro_steps
return self.micro_steps
@property
def batch_size(self):

View File

@ -38,6 +38,9 @@ def train_feeder(engine, batch):
stats |= {k: v.item() for k, v in losses.items()}
stats |= {k: v.item() for k, v in stat.items()}
engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ])
return loss, stats
@torch.inference_mode()

View File

@ -256,7 +256,7 @@ def train(
stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
stats['batch'] = {
'size': stats['batch_size'],
'size': len(batch['text']),
'id': batch['spkr_id'],
'index': [ index for index in batch['index'] ],
'text_len': [ text.shape[0] for text in batch['text'] ],
@ -264,8 +264,6 @@ def train(
'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
}
del stats['batch_size']
del stats['wall_time']
del stats['global_step']
elapsed_time = stats.get("elapsed_time", 0)