Added tracking of the total number of samples processed and the total number of tokens processed (tokens = length of the text tokens + length of the target response tokens).
This commit is contained in:
parent
87c4bfedba
commit
7f4388e591
|
@ -746,6 +746,12 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
print(text, task, cfg.models.prom_levels)
|
print(text, task, cfg.models.prom_levels)
|
||||||
print( proms.shape, resps.shape )
|
print( proms.shape, resps.shape )
|
||||||
|
|
||||||
|
tokens = 0
|
||||||
|
tokens += sum([ text.shape[0] for text in batch["text"] ])
|
||||||
|
tokens += sum([ resps.shape[0] for resps in batch["resps"] ])
|
||||||
|
print( tokens )
|
||||||
|
|
||||||
decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
|
decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
|
||||||
decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
|
decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
|
||||||
break
|
break
|
|
@ -61,7 +61,8 @@ class Engine():
|
||||||
|
|
||||||
self.global_steps = 0
|
self.global_steps = 0
|
||||||
self.micro_steps = 0
|
self.micro_steps = 0
|
||||||
self.gradient_accumulation_steps = cfg.hyperparameters.gradient_accumulation_steps
|
self.global_samples = 0
|
||||||
|
self.tokens_processed = 0
|
||||||
|
|
||||||
def freeze(self):
|
def freeze(self):
|
||||||
for p in self.module.parameters():
|
for p in self.module.parameters():
|
||||||
|
@ -90,6 +91,10 @@ class Engine():
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
return cfg.hyperparameters.batch_size
|
return cfg.hyperparameters.batch_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gradient_accumulation_steps(self):
|
||||||
|
return cfg.hyperparameters.gradient_accumulation_steps
|
||||||
|
|
||||||
def gather_attribute(self, *args, **kwargs):
|
def gather_attribute(self, *args, **kwargs):
|
||||||
return gather_attribute(self.module, *args, **kwargs)
|
return gather_attribute(self.module, *args, **kwargs)
|
||||||
|
|
||||||
|
@ -100,11 +105,14 @@ class Engine():
|
||||||
save_path = save_dir / tag / "state.pth"
|
save_path = save_dir / tag / "state.pth"
|
||||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
torch.save({
|
torch.save({
|
||||||
"global_step": self.global_step,
|
|
||||||
"micro_step": self.micro_step,
|
|
||||||
"module": self.module.state_dict(),
|
"module": self.module.state_dict(),
|
||||||
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
|
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
|
||||||
"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
|
"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
|
||||||
|
|
||||||
|
"global_step": self.global_step,
|
||||||
|
"micro_step": self.micro_step,
|
||||||
|
"global_samples": self.global_samples,
|
||||||
|
"tokens_processed": self.tokens_processed,
|
||||||
}, save_path)
|
}, save_path)
|
||||||
|
|
||||||
open(save_dir / "latest", 'w').write( tag )
|
open(save_dir / "latest", 'w').write( tag )
|
||||||
|
@ -123,6 +131,8 @@ class Engine():
|
||||||
state = torch.load(load_path)
|
state = torch.load(load_path)
|
||||||
self.global_steps = state['global_step']
|
self.global_steps = state['global_step']
|
||||||
self.micro_steps = state['micro_step']
|
self.micro_steps = state['micro_step']
|
||||||
|
self.global_samples = state['global_samples']
|
||||||
|
self.tokens_processed = state['tokens_processed']
|
||||||
self.module.load_state_dict(state['module'])
|
self.module.load_state_dict(state['module'])
|
||||||
|
|
||||||
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
|
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
|
||||||
|
@ -163,6 +173,7 @@ class Engine():
|
||||||
def step(self):
|
def step(self):
|
||||||
with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
|
with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
|
||||||
self.micro_steps += 1
|
self.micro_steps += 1
|
||||||
|
self.global_samples += self.batch_size
|
||||||
|
|
||||||
if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
|
if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
|
||||||
self.global_steps += 1
|
self.global_steps += 1
|
||||||
|
@ -235,9 +246,11 @@ class Engines(dict[str, Engine]):
|
||||||
for name, engine in self.items():
|
for name, engine in self.items():
|
||||||
outpath = cfg.ckpt_dir / name / "fp32.pth"
|
outpath = cfg.ckpt_dir / name / "fp32.pth"
|
||||||
state_dict = {
|
state_dict = {
|
||||||
|
'module': engine.module.state_dict(),
|
||||||
"global_step": engine.global_step,
|
"global_step": engine.global_step,
|
||||||
"micro_step": engine.micro_step,
|
"micro_step": engine.micro_step,
|
||||||
'module': engine.module.state_dict(),
|
"global_samples": engine.global_samples,
|
||||||
|
"tokens_processed": engine.tokens_processed,
|
||||||
}
|
}
|
||||||
state_dict.update(userdata)
|
state_dict.update(userdata)
|
||||||
torch.save(state_dict, outpath)
|
torch.save(state_dict, outpath)
|
||||||
|
@ -427,6 +440,8 @@ class Engines(dict[str, Engine]):
|
||||||
grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
|
grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
|
||||||
elapsed_time=elapsed_time,
|
elapsed_time=elapsed_time,
|
||||||
engine_step=engine.global_step,
|
engine_step=engine.global_step,
|
||||||
|
samples_processed=engine.global_samples,
|
||||||
|
tokens_processed=engine.tokens_processed,
|
||||||
**engine_stats,
|
**engine_stats,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -435,9 +450,8 @@ class Engines(dict[str, Engine]):
|
||||||
|
|
||||||
self._update()
|
self._update()
|
||||||
|
|
||||||
stats["batch_size"] = self.batch_size
|
|
||||||
stats["elapsed_time"] = total_elapsed_time
|
stats["elapsed_time"] = total_elapsed_time
|
||||||
stats["wall_time"] = time.time()
|
|
||||||
stats["global_step"] = self.global_step
|
stats["global_step"] = self.global_step
|
||||||
|
#stats["micro_step"] = self.micro_step
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
|
@ -41,6 +41,8 @@ class Engine(DeepSpeedEngine):
|
||||||
super().__init__(None, *args, **kwargs)
|
super().__init__(None, *args, **kwargs)
|
||||||
self._frozen_params = set()
|
self._frozen_params = set()
|
||||||
|
|
||||||
|
self.tokens_processed = 0
|
||||||
|
|
||||||
def freeze(self):
|
def freeze(self):
|
||||||
for p in self.module.parameters():
|
for p in self.module.parameters():
|
||||||
if p.requires_grad:
|
if p.requires_grad:
|
||||||
|
|
|
@ -38,6 +38,9 @@ def train_feeder(engine, batch):
|
||||||
stats |= {k: v.item() for k, v in losses.items()}
|
stats |= {k: v.item() for k, v in losses.items()}
|
||||||
stats |= {k: v.item() for k, v in stat.items()}
|
stats |= {k: v.item() for k, v in stat.items()}
|
||||||
|
|
||||||
|
engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
|
||||||
|
engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ])
|
||||||
|
|
||||||
return loss, stats
|
return loss, stats
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
|
|
|
@ -256,7 +256,7 @@ def train(
|
||||||
stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
|
stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
|
||||||
|
|
||||||
stats['batch'] = {
|
stats['batch'] = {
|
||||||
'size': stats['batch_size'],
|
'size': len(batch['text']),
|
||||||
'id': batch['spkr_id'],
|
'id': batch['spkr_id'],
|
||||||
'index': [ index for index in batch['index'] ],
|
'index': [ index for index in batch['index'] ],
|
||||||
'text_len': [ text.shape[0] for text in batch['text'] ],
|
'text_len': [ text.shape[0] for text in batch['text'] ],
|
||||||
|
@ -264,8 +264,6 @@ def train(
|
||||||
'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
|
'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
|
||||||
}
|
}
|
||||||
|
|
||||||
del stats['batch_size']
|
|
||||||
del stats['wall_time']
|
|
||||||
del stats['global_step']
|
del stats['global_step']
|
||||||
|
|
||||||
elapsed_time = stats.get("elapsed_time", 0)
|
elapsed_time = stats.get("elapsed_time", 0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user