suppress warning on exit about distributed not being cleaned up (because I updated my system)
This commit is contained in:
parent
682e4387dc
commit
06e948aec1
|
@ -28,7 +28,7 @@ def default_feeder(engine, batch):
|
||||||
|
|
||||||
from ..config import cfg
|
from ..config import cfg
|
||||||
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
|
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
|
||||||
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
|
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size, cleanup_distributed
|
||||||
from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
|
from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
@ -452,6 +452,9 @@ class Engines(dict[str, Engine]):
|
||||||
stats.update(flatten_dict({ name.split("-")[0]: stat }))
|
stats.update(flatten_dict({ name.split("-")[0]: stat }))
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
def quit(self):
|
||||||
|
cleanup_distributed()
|
||||||
|
|
||||||
def step(self, batch, feeder: TrainFeeder = default_feeder):
|
def step(self, batch, feeder: TrainFeeder = default_feeder):
|
||||||
total_elapsed_time = 0
|
total_elapsed_time = 0
|
||||||
|
|
||||||
|
|
|
@ -400,7 +400,7 @@ def example_usage():
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
|
from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
|
||||||
from ..engines import Engine
|
from ..engines import Engine, Engines
|
||||||
from ..utils import wrapper as ml
|
from ..utils import wrapper as ml
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -532,6 +532,9 @@ def example_usage():
|
||||||
|
|
||||||
engine = Engine(model=model, optimizer=optimizer)
|
engine = Engine(model=model, optimizer=optimizer)
|
||||||
|
|
||||||
|
engines = Engines({"ar+nar": engine})
|
||||||
|
engines.setup()
|
||||||
|
|
||||||
"""
|
"""
|
||||||
torch.save( {
|
torch.save( {
|
||||||
'module': model.state_dict()
|
'module': model.state_dict()
|
||||||
|
@ -622,5 +625,7 @@ def example_usage():
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
sample("final", task=task)
|
sample("final", task=task)
|
||||||
|
|
||||||
|
engines.quit()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
example_usage()
|
example_usage()
|
|
@ -28,6 +28,12 @@ def init_distributed( fn, *args, **kwargs ):
|
||||||
def distributed_initialized():
|
def distributed_initialized():
|
||||||
return _distributed_initialized
|
return _distributed_initialized
|
||||||
|
|
||||||
|
def cleanup_distributed():
|
||||||
|
#if not _distributed_initialized:
|
||||||
|
# return
|
||||||
|
dist.barrier()
|
||||||
|
dist.destroy_process_group()
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def fix_unset_envs():
|
def fix_unset_envs():
|
||||||
envs = dict(
|
envs = dict(
|
||||||
|
|
|
@ -19,8 +19,10 @@ from tqdm import tqdm
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
from ..config import cfg
|
from ..config import cfg
|
||||||
from .distributed import init_distributed, distributed_initialized, world_size
|
|
||||||
from .distributed import (
|
from .distributed import (
|
||||||
|
init_distributed,
|
||||||
|
distributed_initialized,
|
||||||
|
world_size,
|
||||||
global_leader_only,
|
global_leader_only,
|
||||||
global_rank,
|
global_rank,
|
||||||
is_global_leader,
|
is_global_leader,
|
||||||
|
@ -116,7 +118,6 @@ def seed(seed):
|
||||||
np.random.seed(seed + global_rank())
|
np.random.seed(seed + global_rank())
|
||||||
torch.manual_seed(seed + global_rank())
|
torch.manual_seed(seed + global_rank())
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
train_dl: DataLoader,
|
train_dl: DataLoader,
|
||||||
train_feeder: TrainFeeder = default_feeder,
|
train_feeder: TrainFeeder = default_feeder,
|
||||||
|
@ -141,6 +142,7 @@ def train(
|
||||||
eval_fn(engines=engines)
|
eval_fn(engines=engines)
|
||||||
|
|
||||||
if command in ["quit", "eval_quit"]:
|
if command in ["quit", "eval_quit"]:
|
||||||
|
engines.quit()
|
||||||
return
|
return
|
||||||
|
|
||||||
last_save_step = engines.global_step
|
last_save_step = engines.global_step
|
||||||
|
@ -250,4 +252,5 @@ def train(
|
||||||
eval_fn(engines=engines)
|
eval_fn(engines=engines)
|
||||||
|
|
||||||
if command in ["quit"]:
|
if command in ["quit"]:
|
||||||
|
engines.quit()
|
||||||
return
|
return
|
||||||
|
|
Loading…
Reference in New Issue
Block a user