From 06e948aec1d5a3370cfa6482907bd8d62ea2cfc0 Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 25 Jul 2024 16:50:47 -0500
Subject: [PATCH] suppress warning on exit about distributed not being cleaned
 up (because I updated my system)

---
 vall_e/engines/base.py      | 5 ++++-
 vall_e/models/ar_nar.py     | 7 ++++++-
 vall_e/utils/distributed.py | 6 ++++++
 vall_e/utils/trainer.py     | 7 +++++--
 4 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index e41ed4c..7313ab8 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -28,7 +28,7 @@ def default_feeder(engine, batch):
 
 from ..config import cfg
 from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
-from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
+from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size, cleanup_distributed
 from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
 
 import logging
@@ -452,6 +452,9 @@ class Engines(dict[str, Engine]):
 			stats.update(flatten_dict({ name.split("-")[0]: stat }))
 		return stats
 
+	def quit(self):
+		cleanup_distributed()
+
 	def step(self, batch, feeder: TrainFeeder = default_feeder):
 		total_elapsed_time = 0
 
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 0d15863..68d9282 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -400,7 +400,7 @@ def example_usage():
 	from tqdm import tqdm
 	from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
 
-	from ..engines import Engine
+	from ..engines import Engine, Engines
 	from ..utils import wrapper as ml
 
 	import numpy as np
@@ -532,6 +532,9 @@ def example_usage():
 
 	engine = Engine(model=model, optimizer=optimizer)
 
+	engines = Engines({"ar+nar": engine})
+	engines.setup()
+
 	"""
 	torch.save( {
 		'module': model.state_dict()
@@ -622,5 +625,7 @@ def example_usage():
 	for task in tasks:
 		sample("final", task=task)
 
+	engines.quit()
+
 if __name__ == "__main__":
 	example_usage()
\ No newline at end of file
diff --git a/vall_e/utils/distributed.py b/vall_e/utils/distributed.py
index 167bda4..244c383 100755
--- a/vall_e/utils/distributed.py
+++ b/vall_e/utils/distributed.py
@@ -28,6 +28,12 @@ def init_distributed( fn, *args, **kwargs ):
 def distributed_initialized():
 	return _distributed_initialized
 
+def cleanup_distributed():
+	# guard: barrier() raises if no process group was ever initialized
+	if not dist.is_initialized():
+		return
+	dist.barrier()
+	dist.destroy_process_group()
+
 @cache
 def fix_unset_envs():
 	envs = dict(
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index d190125..19946c5 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -19,8 +19,10 @@ from tqdm import tqdm
 from typing import Protocol
 
 from ..config import cfg
-from .distributed import init_distributed, distributed_initialized, world_size
 from .distributed import (
+	init_distributed,
+	distributed_initialized,
+	world_size,
 	global_leader_only,
 	global_rank,
 	is_global_leader,
@@ -116,7 +118,6 @@ def seed(seed):
 	np.random.seed(seed + global_rank())
 	torch.manual_seed(seed + global_rank())
 
-
 def train(
 	train_dl: DataLoader,
 	train_feeder: TrainFeeder = default_feeder,
@@ -141,6 +142,7 @@ def train(
 			eval_fn(engines=engines)
 
 		if command in ["quit", "eval_quit"]:
+			engines.quit()
 			return
 
 		last_save_step = engines.global_step
@@ -250,4 +252,5 @@ def train(
 				eval_fn(engines=engines)
 
 			if command in ["quit"]:
+				engines.quit()
 				return
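
Background on the pattern above: recent PyTorch versions warn at interpreter
shutdown when a process group is still alive, which is the warning this commit
suppresses by tearing the group down explicitly through cleanup_distributed().
Below is a minimal, self-contained sketch of the same barrier-then-destroy
teardown, not the project's code: it assumes a single-process gloo group, the
MASTER_ADDR/MASTER_PORT values are placeholders, and torch's own
dist.is_initialized() is used as the guard.

	import os
	import torch.distributed as dist

	def main():
		# illustrative rendezvous settings for a one-process group
		os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
		os.environ.setdefault("MASTER_PORT", "29500")
		dist.init_process_group(backend="gloo", rank=0, world_size=1)

		# ... training / inference work would run here ...

		# same teardown order as cleanup_distributed(): skip if no group
		# exists, otherwise synchronize every rank, then destroy the group
		if dist.is_initialized():
			dist.barrier()
			dist.destroy_process_group()

	if __name__ == "__main__":
		main()

The barrier before destroy_process_group() matters once world_size > 1:
destroying the group on one rank while another rank is still inside a
collective can hang or error, so all ranks synchronize first.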