suppress warning on exit about distributed not being cleaned up (because I updated my system)
parent 682e4387dc
commit 06e948aec1
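Background on the warning being suppressed, with a minimal sketch (not part of the commit; assumes recent PyTorch behavior and torch.distributed's public API): newer PyTorch builds complain at interpreter exit when a process group was initialized but never destroyed (for NCCL, along the lines of "process group has not been destroyed before we destruct ProcessGroupNCCL"). Destroying the group explicitly before exit is the documented way to silence it:

	import torch.distributed as dist

	# Run once per rank, just before the process exits.
	def cleanup_distributed():
		if dist.is_initialized():         # no-op for single-process runs
			dist.destroy_process_group()  # frees the backend; silences the exit warning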
@@ -28,7 +28,7 @@ def default_feeder(engine, batch):
 from ..config import cfg
 from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
-from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
+from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size, cleanup_distributed
 from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
 
 import logging
 
@@ -452,6 +452,9 @@ class Engines(dict[str, Engine]):
 		stats.update(flatten_dict({ name.split("-")[0]: stat }))
 		return stats
+
+	def quit(self):
+		cleanup_distributed()
 
 	def step(self, batch, feeder: TrainFeeder = default_feeder):
 		total_elapsed_time = 0
 
@@ -400,7 +400,7 @@ def example_usage():
 	from tqdm import tqdm
 
 	from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
-	from ..engines import Engine
+	from ..engines import Engine, Engines
 	from ..utils import wrapper as ml
 
 	import numpy as np
@@ -532,6 +532,9 @@ def example_usage():
 
 	engine = Engine(model=model, optimizer=optimizer)
+
+	engines = Engines({"ar+nar": engine})
+	engines.setup()
 
 	"""
 	torch.save( {
 		'module': model.state_dict()
@@ -622,5 +625,7 @@ def example_usage():
 	for task in tasks:
 		sample("final", task=task)
+
+	engines.quit()
 
 if __name__ == "__main__":
 	example_usage()
@@ -28,6 +28,12 @@ def init_distributed( fn, *args, **kwargs ):
 def distributed_initialized():
 	return _distributed_initialized
+
+def cleanup_distributed():
+	#if not _distributed_initialized:
+	#	return
+	dist.barrier()
+	dist.destroy_process_group()
 
 @cache
 def fix_unset_envs():
 	envs = dict(
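A note on ordering in cleanup_distributed() above: dist.barrier() makes every rank wait for its peers before dist.destroy_process_group(), so no rank tears the group down while another still has collectives in flight. The commit leaves its own _distributed_initialized guard commented out; a guarded variant is sketched below (an assumption, using torch.distributed's built-in is_initialized() instead of the module-level flag):

	import torch.distributed as dist

	def cleanup_distributed():
		if not dist.is_initialized():
			return                    # nothing to tear down in single-process runs
		dist.barrier()                # sync all ranks before teardown
		dist.destroy_process_group()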
@@ -19,8 +19,10 @@ from tqdm import tqdm
 from typing import Protocol
 
 from ..config import cfg
-from .distributed import init_distributed, distributed_initialized, world_size
+from .distributed import (
+	init_distributed,
+	distributed_initialized,
+	world_size,
+	global_leader_only,
+	global_rank,
+	is_global_leader,
@@ -116,7 +118,6 @@ def seed(seed):
 	np.random.seed(seed + global_rank())
 	torch.manual_seed(seed + global_rank())
 
-
 def train(
 	train_dl: DataLoader,
 	train_feeder: TrainFeeder = default_feeder,
@@ -141,6 +142,7 @@ def train(
 			eval_fn(engines=engines)
 
 			if command in ["quit", "eval_quit"]:
+				engines.quit()
 				return
 
 	last_save_step = engines.global_step
@@ -250,4 +252,5 @@ def train(
 	eval_fn(engines=engines)
 
 	if command in ["quit"]:
+		engines.quit()
 		return
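Taken together, the example and the trainer now share the same lifecycle around the Engines container. A hypothetical end-to-end shape (names taken from the hunks above; model, optimizer, and train_dl setup are assumed and not shown):

	engine = Engine(model=model, optimizer=optimizer)
	engines = Engines({"ar+nar": engine})
	engines.setup()
	try:
		for batch in train_dl:
			engines.step(batch)  # one optimization step per batch
	finally:
		engines.quit()           # barrier + destroy_process_group on every rank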