Change the distributed dataloader from chunk-slicing paths per rank to interleaving them across ranks
This commit is contained in:
parent
dd40463803
commit
c4dd523b6f
|
@ -455,11 +455,15 @@ class Dataset(_Dataset):
|
||||||
|
|
||||||
# split dataset accordingly per GPU
|
# split dataset accordingly per GPU
|
||||||
if cfg.distributed and self.training:
|
if cfg.distributed and self.training:
|
||||||
|
"""
|
||||||
batches = len(self.paths) // world_size()
|
batches = len(self.paths) // world_size()
|
||||||
start = batches * global_rank()
|
start = batches * global_rank()
|
||||||
end = batches * (global_rank() + 1)
|
end = batches * (global_rank() + 1)
|
||||||
|
|
||||||
self.paths = self.paths[start:end]
|
self.paths = self.paths[start:end]
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Interleave the paths across ranks: rank r keeps indices r, r + W, r + 2W, ...
# where W is world_size(). Comparing the remainder against 0 instead of
# global_rank() would give every rank the exact same subset and leave the
# rest of the dataset unused.
self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == global_rank() ]
|
||||||
|
|
||||||
# recreate paths_by_spkr_name
|
# recreate paths_by_spkr_name
|
||||||
self.paths_by_spkr_name = {}
|
self.paths_by_spkr_name = {}
|
||||||
|
@ -543,7 +547,7 @@ class Dataset(_Dataset):
|
||||||
if len(self.paths) == 0:
|
if len(self.paths) == 0:
|
||||||
raise ValueError(f"No valid path is found for {self.dataset_type}")
|
raise ValueError(f"No valid path is found for {self.dataset_type}")
|
||||||
|
|
||||||
sampler_path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
|
sampler_path = cfg.rel_path / self.sampler_state_dict_path
|
||||||
|
|
||||||
if self.sampler_type == "path":
|
if self.sampler_type == "path":
|
||||||
if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
|
if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
|
||||||
|
@ -558,6 +562,10 @@ class Dataset(_Dataset):
|
||||||
self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
|
self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
|
||||||
|
|
||||||
self.load_state_dict()
|
self.load_state_dict()
|
||||||
|
|
||||||
|
@cached_property
def sampler_state_dict_path(self):
    """Return the file name used for this rank's sampler state.

    The name combines ``self.sampler_type`` with the value of
    ``global_rank()``. It is computed once per instance because the
    property is cached.
    """
    sampler_kind = self.sampler_type
    rank = global_rank()
    return f"sampler.{sampler_kind}.rank{rank}.pt"
|
||||||
|
|
||||||
def get_speaker(self, path):
|
def get_speaker(self, path):
|
||||||
if isinstance(path, str):
|
if isinstance(path, str):
|
||||||
|
@ -590,7 +598,7 @@ class Dataset(_Dataset):
|
||||||
|
|
||||||
def save_state_dict(self, path = None):
|
def save_state_dict(self, path = None):
|
||||||
if path is None:
|
if path is None:
|
||||||
path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
|
path = cfg.rel_path / self.sampler_state_dict_path
|
||||||
|
|
||||||
if self.sampler_type == "path":
|
if self.sampler_type == "path":
|
||||||
state_dict = self.sampler.get_state()
|
state_dict = self.sampler.get_state()
|
||||||
|
@ -603,7 +611,7 @@ class Dataset(_Dataset):
|
||||||
|
|
||||||
def load_state_dict(self, path = None):
|
def load_state_dict(self, path = None):
|
||||||
if path is None:
|
if path is None:
|
||||||
path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
|
path = cfg.rel_path / self.sampler_state_dict_path
|
||||||
|
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return
|
return
|
||||||
|
|
|
@ -137,7 +137,7 @@ def load_engines(training=True):
|
||||||
stats = state["stats"]
|
stats = state["stats"]
|
||||||
|
|
||||||
# do not load stats if we're training a LoRA
|
# do not load stats if we're training a LoRA
|
||||||
if "lora" not in state:
|
if "lora" in state or cfg.lora is not None or cfg.trainer.restart_step_count:
|
||||||
stats = None
|
stats = None
|
||||||
|
|
||||||
if "module" in state:
|
if "module" in state:
|
||||||
|
|
|
@ -5,6 +5,8 @@ import random
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Sampler
|
from torch.utils.data import Sampler
|
||||||
|
|
||||||
|
from .distributed import global_rank, local_rank, world_size
|
||||||
|
|
||||||
# Randomly picks an index from an array of indices
|
# Randomly picks an index from an array of indices
|
||||||
class PoolSampler():
|
class PoolSampler():
|
||||||
def __init__( self, pool = [], keep_all = False ):
|
def __init__( self, pool = [], keep_all = False ):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user