suppress warning on exit about distributed not being cleaned up (because I updated my system)

mrq 2024-07-25 16:50:47 -05:00
parent 682e4387dc
commit 06e948aec1
4 changed files with 21 additions and 4 deletions

View File

@@ -28,7 +28,7 @@ def default_feeder(engine, batch):
 from ..config import cfg
 from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
-from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
+from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size, cleanup_distributed
 from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
 
 import logging
@@ -452,6 +452,9 @@ class Engines(dict[str, Engine]):
 			stats.update(flatten_dict({ name.split("-")[0]: stat }))
 
 		return stats
 
+	def quit(self):
+		cleanup_distributed()
+
 	def step(self, batch, feeder: TrainFeeder = default_feeder):
 		total_elapsed_time = 0

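For context, a minimal sketch of how the new quit() hook is meant to be used (Engines, setup, step, and quit come from this diff; train_dl and the loop body are illustrative), so teardown runs even if training aborts:

	engines = Engines({"ar+nar": engine})
	engines.setup()
	try:
		for batch in train_dl:
			engines.step(batch)
	finally:
		# barrier, then destroy_process_group(), so torch.distributed
		# does not warn about an undestroyed process group at exit
		engines.quit()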
View File

@@ -400,7 +400,7 @@ def example_usage():
 	from tqdm import tqdm
 	from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
-	from ..engines import Engine
+	from ..engines import Engine, Engines
 	from ..utils import wrapper as ml
 
 	import numpy as np
@@ -532,6 +532,9 @@ def example_usage():
 	engine = Engine(model=model, optimizer=optimizer)
+	engines = Engines({"ar+nar": engine})
+	engines.setup()
+
 	"""
 	torch.save( {
 		'module': model.state_dict()
@@ -622,5 +625,7 @@ def example_usage():
 	for task in tasks:
 		sample("final", task=task)
 
+	engines.quit()
+
 if __name__ == "__main__":
 	example_usage()

View File

@@ -28,6 +28,12 @@ def init_distributed( fn, *args, **kwargs ):
 def distributed_initialized():
 	return _distributed_initialized
 
+def cleanup_distributed():
+	#if not _distributed_initialized:
+	#	return
+	dist.barrier()
+	dist.destroy_process_group()
+
 @cache
 def fix_unset_envs():
 	envs = dict(

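Note that cleanup_distributed() calls dist.barrier() and dist.destroy_process_group() unconditionally, with the _distributed_initialized guard left commented out; both calls raise if no process group was ever created (e.g. a plain single-process run). A guarded sketch, using torch.distributed.is_initialized() instead of the module-level flag:

	def cleanup_distributed():
		# no-op when init_distributed() was never reached
		if not dist.is_initialized():
			return
		dist.barrier()
		dist.destroy_process_group()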
View File

@@ -19,8 +19,10 @@ from tqdm import tqdm
 from typing import Protocol
 
 from ..config import cfg
-from .distributed import init_distributed, distributed_initialized, world_size
 from .distributed import (
+	init_distributed,
+	distributed_initialized,
+	world_size,
 	global_leader_only,
 	global_rank,
 	is_global_leader,
@@ -116,7 +118,6 @@ def seed(seed):
 	np.random.seed(seed + global_rank())
 	torch.manual_seed(seed + global_rank())
 
-
 def train(
 	train_dl: DataLoader,
 	train_feeder: TrainFeeder = default_feeder,
@@ -141,6 +142,7 @@ def train(
 			eval_fn(engines=engines)
 
 		if command in ["quit", "eval_quit"]:
+			engines.quit()
 			return
 
 	last_save_step = engines.global_step
@@ -250,4 +252,5 @@ def train(
 		eval_fn(engines=engines)
 
 	if command in ["quit"]:
+		engines.quit()
 		return
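Both quit branches now tear down the process group before returning. An alternative arrangement (illustrative only; run_training_loop is a hypothetical wrapper for the loop body) would funnel every exit path through a single teardown instead of calling engines.quit() in each branch:

	try:
		run_training_loop()
	finally:
		engines.quit()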