From 7f4388e59138cf3d57313c8261459d5ae0904ae5 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 28 Aug 2023 11:02:45 -0500
Subject: [PATCH] added total samples processed and tokens processed (len of text tokens + len of target response tokens)

---
 vall_e/data.py              |  6 ++++++
 vall_e/engines/base.py      | 26 ++++++++++++++++++++------
 vall_e/engines/deepspeed.py |  4 +++-
 vall_e/train.py             |  3 +++
 vall_e/utils/trainer.py     |  4 +---
 5 files changed, 33 insertions(+), 10 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index 5432662..3ed852d 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -746,6 +746,12 @@ if __name__ == "__main__":
 
 			print(text, task, cfg.models.prom_levels)
 			print( proms.shape, resps.shape )
+
+			tokens = 0
+			tokens += sum([ text.shape[0] for text in batch["text"] ])
+			tokens += sum([ resps.shape[0] for resps in batch["resps"] ])
+			print( tokens )
+
 			decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
 			decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
 			break
\ No newline at end of file
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index a697b83..b8cca40 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -61,7 +61,8 @@ class Engine():
 		self.global_steps = 0
 		self.micro_steps = 0
-		self.gradient_accumulation_steps = cfg.hyperparameters.gradient_accumulation_steps
+		self.global_samples = 0
+		self.tokens_processed = 0
 
 	def freeze(self):
 		for p in self.module.parameters():
@@ -90,6 +91,10 @@ class Engine():
 	def batch_size(self):
 		return cfg.hyperparameters.batch_size
 
+	@property
+	def gradient_accumulation_steps(self):
+		return cfg.hyperparameters.gradient_accumulation_steps
+
 	def gather_attribute(self, *args, **kwargs):
 		return gather_attribute(self.module, *args, **kwargs)
 
@@ -100,11 +105,14 @@ class Engine():
 		save_path = save_dir / tag / "state.pth"
 		save_path.parent.mkdir(parents=True, exist_ok=True)
 		torch.save({
-			"global_step": self.global_step,
-			"micro_step": self.micro_step,
 			"module": self.module.state_dict(),
 			"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
 			"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
+
+			"global_step": self.global_step,
+			"micro_step": self.micro_step,
+			"global_samples": self.global_samples,
+			"tokens_processed": self.tokens_processed,
 		}, save_path)
 
 		open(save_dir / "latest", 'w').write( tag )
@@ -123,6 +131,8 @@ class Engine():
 		state = torch.load(load_path)
 		self.global_steps = state['global_step']
 		self.micro_steps = state['micro_step']
+		self.global_samples = state['global_samples']
+		self.tokens_processed = state['tokens_processed']
 		self.module.load_state_dict(state['module'])
 
 		load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
@@ -163,6 +173,7 @@ class Engine():
 	def step(self):
 		with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
 			self.micro_steps += 1
+			self.global_samples += self.batch_size
 
 			if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
 				self.global_steps += 1
@@ -235,9 +246,11 @@ class Engines(dict[str, Engine]):
 		for name, engine in self.items():
 			outpath = cfg.ckpt_dir / name / "fp32.pth"
 			state_dict = {
+				'module': engine.module.state_dict(),
 				"global_step": engine.global_step,
 				"micro_step": engine.micro_step,
-				'module': engine.module.state_dict(),
+				"global_samples": engine.global_samples,
+				"tokens_processed": engine.tokens_processed,
 			}
 			state_dict.update(userdata)
 			torch.save(state_dict, outpath)
@@ -427,6 +440,8 @@ class Engines(dict[str, Engine]):
 					grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
 					elapsed_time=elapsed_time,
 					engine_step=engine.global_step,
+					samples_processed=engine.global_samples,
+					tokens_processed=engine.tokens_processed,
 					**engine_stats,
 				)
 			}
@@ -435,9 +450,8 @@ class Engines(dict[str, Engine]):
 
 		self._update()
 
-		stats["batch_size"] = self.batch_size
 		stats["elapsed_time"] = total_elapsed_time
-		stats["wall_time"] = time.time()
 		stats["global_step"] = self.global_step
+		#stats["micro_step"] = self.micro_step
 
 		return stats
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index 84b5243..5196afb 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -41,6 +41,8 @@ class Engine(DeepSpeedEngine):
 		super().__init__(None, *args, **kwargs)
 		self._frozen_params = set()
 
+		self.tokens_processed = 0
+
 	def freeze(self):
 		for p in self.module.parameters():
 			if p.requires_grad:
@@ -62,7 +64,7 @@ class Engine(DeepSpeedEngine):
 
 	@property
 	def micro_step(self):
-		return self.micro_steps
+		return self.micro_steps
 
 	@property
 	def batch_size(self):
diff --git a/vall_e/train.py b/vall_e/train.py
index 135406e..c284c15 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -38,6 +38,9 @@ def train_feeder(engine, batch):
 		stats |= {k: v.item() for k, v in losses.items()}
 		stats |= {k: v.item() for k, v in stat.items()}
 
+	engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
+	engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ])
+
 	return loss, stats
 
 @torch.inference_mode()
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index b58e2dc..15d3828 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -256,7 +256,7 @@ def train(
 		stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
 
 		stats['batch'] = {
-			'size': stats['batch_size'],
+			'size': len(batch['text']),
 			'id': batch['spkr_id'],
 			'index': [ index for index in batch['index'] ],
 			'text_len': [ text.shape[0] for text in batch['text'] ],
@@ -264,8 +264,6 @@ def train(
 			'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
 		}
 
-		del stats['batch_size']
-		del stats['wall_time']
 		del stats['global_step']
 
 		elapsed_time = stats.get("elapsed_time", 0)
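
As a standalone illustration, not part of the commit itself: a minimal sketch of the bookkeeping this patch adds, assuming batch["text"] and batch["resps"] are lists of token tensors with one entry per sample and the sequence length on dim 0, as in the dataloader the diff touches. The helper name count_batch is hypothetical.

# Illustrative only; mirrors the counters the patch introduces.
import torch

def count_batch(batch: dict) -> tuple[int, int]:
	# samples processed grows by the batch size...
	samples = len(batch["text"])
	# ...tokens processed by the summed text and response lengths (dim 0 of each tensor).
	tokens  = sum(text.shape[0] for text in batch["text"])
	tokens += sum(resps.shape[0] for resps in batch["resps"])
	return samples, tokens

# Two samples: 5 + 7 text tokens, 100 + 80 response frames -> (2, 192).
batch = {
	"text":  [torch.zeros(5, dtype=torch.long), torch.zeros(7, dtype=torch.long)],
	"resps": [torch.zeros(100, 8, dtype=torch.long), torch.zeros(80, 8, dtype=torch.long)],
}
print(count_batch(batch))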