change the distributed dataloader from chunk-slicing paths to interleaving them across ranks
This commit is contained in:
parent dd40463803
commit c4dd523b6f
@@ -455,11 +455,15 @@ class Dataset(_Dataset):
 		# split dataset accordingly per GPU
 		if cfg.distributed and self.training:
+			"""
 			batches = len(self.paths) // world_size()
 			start = batches * global_rank()
 			end = batches * (global_rank() + 1)
 
 			self.paths = self.paths[start:end]
+			"""
+
+			self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == global_rank() ]
 
 			# recreate paths_by_spkr_name
 			self.paths_by_spkr_name = {}
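The old chunk-slicing gave each rank one contiguous block of paths; the new list comprehension has each rank keep every world_size()-th path, offset by its own rank. A minimal sketch of the difference, using plain stand-in values (`paths`, `world`, `rank`) rather than anything from the codebase:

```python
# Hypothetical illustration of chunk-slicing vs. interleaving across 2 ranks.
paths = ["a", "b", "c", "d", "e", "f"]
world = 2

for rank in range(world):
    # Old: contiguous chunk per rank -> rank 0: [a, b, c], rank 1: [d, e, f]
    batches = len(paths) // world
    chunk = paths[batches * rank : batches * (rank + 1)]

    # New: strided interleave per rank -> rank 0: [a, c, e], rank 1: [b, d, f]
    interleaved = [p for i, p in enumerate(paths) if i % world == rank]

    print(rank, chunk, interleaved)
```

One likely motivation: when the path list is ordered (for example by duration, as the "duration" sampler order below suggests), interleaving gives each rank a roughly balanced mix of short and long samples, whereas contiguous chunks concentrate similar durations on a single rank.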
@@ -543,7 +547,7 @@ class Dataset(_Dataset):
 		if len(self.paths) == 0:
 			raise ValueError(f"No valid path is found for {self.dataset_type}")
 
-		sampler_path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+		sampler_path = cfg.rel_path / self.sampler_state_dict_path
 
 		if self.sampler_type == "path":
 			if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
@@ -559,6 +563,10 @@ class Dataset(_Dataset):
 
 		self.load_state_dict()
 
+	@cached_property
+	def sampler_state_dict_path(self):
+		return f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+
 	def get_speaker(self, path):
 		if isinstance(path, str):
 			path = Path(path)
@@ -590,7 +598,7 @@ class Dataset(_Dataset):
 
 	def save_state_dict(self, path = None):
 		if path is None:
-			path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+			path = cfg.rel_path / self.sampler_state_dict_path
 
 		if self.sampler_type == "path":
 			state_dict = self.sampler.get_state()
@@ -603,7 +611,7 @@ class Dataset(_Dataset):
 
 	def load_state_dict(self, path = None):
 		if path is None:
-			path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+			path = cfg.rel_path / self.sampler_state_dict_path
 
 		if not path.exists():
 			return
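The three sampler-path hunks above replace a repeated f-string with the new sampler_state_dict_path cached property, so the per-rank checkpoint filename is built in exactly one place. A minimal standalone sketch of the pattern; DummyDataset and the fixed rank value are assumptions for illustration only:

```python
from functools import cached_property

class DummyDataset:
    # Stand-in for the real Dataset; sampler_type and rank are hard-coded here.
    sampler_type = "path"
    rank = 0

    @cached_property
    def sampler_state_dict_path(self):
        # Computed on first access, then cached on the instance, so the
        # constructor, save_state_dict(), and load_state_dict() all agree.
        return f"sampler.{self.sampler_type}.rank{self.rank}.pt"

ds = DummyDataset()
print(ds.sampler_state_dict_path)  # sampler.path.rank0.pt
```

One caveat of cached_property here: the filename is frozen after first access, which is harmless in this setup since the sampler type and the process rank do not change mid-run.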
@@ -137,7 +137,7 @@ def load_engines(training=True):
 			stats = state["stats"]
 
 			# do not load stats if we're training a LoRA
-			if "lora" not in state:
+			if "lora" in state or cfg.lora is not None or cfg.trainer.restart_step_count:
 				stats = None
 
 			if "module" in state:
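The widened condition now also discards saved training stats when a LoRA is configured for the run or a step-count restart is requested, not only when the checkpoint itself carries LoRA state. A sketch of the guard as a standalone predicate; the function name and arguments are illustrative, not from the codebase:

```python
def should_discard_stats(state: dict, lora_cfg, restart_step_count: bool) -> bool:
    # Discard optimizer/step statistics when the checkpoint contains LoRA
    # state, when this run trains a LoRA, or on an explicit step-count restart.
    return "lora" in state or lora_cfg is not None or restart_step_count

stats = {"global_step": 1234}
if should_discard_stats({"module": {}}, lora_cfg="my-lora", restart_step_count=False):
    stats = None
print(stats)  # None: a LoRA is configured, so training restarts from step 0
```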
@@ -5,6 +5,8 @@ import random
 import torch
 from torch.utils.data import Sampler
 
+from .distributed import global_rank, local_rank, world_size
+
 # Randomly picks an index from an array of indices
 class PoolSampler():
 	def __init__( self, pool = [], keep_all = False ):
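The new import exposes the rank helpers to the samplers module. As a sketch of what that enables, here is a hypothetical rank-aware sampler that mirrors the interleaved path split above; InterleavedSampler is an illustration, not a class added by this commit:

```python
import random
from torch.utils.data import Sampler

class InterleavedSampler(Sampler):
    # Hypothetical: each process yields every world_size-th index,
    # offset by its global rank, matching the interleaved path split.
    def __init__(self, dataset_len: int, rank: int, world_size: int):
        self.indices = list(range(rank, dataset_len, world_size))

    def __iter__(self):
        random.shuffle(self.indices)  # shuffle within this rank's shard only
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

# Usage sketch: rank 1 of 2 over a 6-item dataset -> indices 1, 3, 5
print(sorted(InterleavedSampler(6, rank=1, world_size=2)))
```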