change the distributed dataloader from chunk-slicing paths to interleaving them instead

mrq 2024-06-29 10:10:35 -05:00
parent dd40463803
commit c4dd523b6f
3 changed files with 14 additions and 4 deletions

View File

@@ -455,11 +455,15 @@ class Dataset(_Dataset):
 		# split dataset accordingly per GPU
 		if cfg.distributed and self.training:
+			"""
 			batches = len(self.paths) // world_size()
 			start = batches * global_rank()
 			end = batches * (global_rank() + 1)
 			self.paths = self.paths[start:end]
+			"""
+			self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]
 
 			# recreate paths_by_spkr_name
 			self.paths_by_spkr_name = {}
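
The change above swaps contiguous chunk-slicing for a strided (interleaved) split of self.paths, keeping the old slicing code around inside a triple-quoted string. A minimal sketch of the two schemes, assuming only a plain list and the world_size()/global_rank() helpers; the rank-keyed modulo shown for interleaving is the conventional form (the diff as committed compares against 0):

# illustrative sketch, not the repository's code
def shard_chunked(paths, rank, world):
	# contiguous block per rank: drops the remainder when len(paths) % world != 0
	per_rank = len(paths) // world
	return paths[rank * per_rank : (rank + 1) * per_rank]

def shard_interleaved(paths, rank, world):
	# strided split: rank r keeps indices r, r + world, r + 2*world, ...
	return [ path for i, path in enumerate(paths) if i % world == rank ]

Interleaving keeps each rank's subset statistically similar when the path list is grouped (for example, sorted by speaker or duration), whereas contiguous chunks can hand each rank a skewed slice.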
@@ -543,7 +547,7 @@ class Dataset(_Dataset):
 		if len(self.paths) == 0:
 			raise ValueError(f"No valid path is found for {self.dataset_type}")
 
-		sampler_path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+		sampler_path = cfg.rel_path / self.sampler_state_dict_path
 
 		if self.sampler_type == "path":
 			if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
@@ -558,6 +562,10 @@ class Dataset(_Dataset):
 			self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
 		self.load_state_dict()
 
+	@cached_property
+	def sampler_state_dict_path(self):
+		return f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+
 	def get_speaker(self, path):
 		if isinstance(path, str):
@@ -590,7 +598,7 @@ class Dataset(_Dataset):
 	def save_state_dict(self, path = None):
 		if path is None:
-			path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+			path = cfg.rel_path / self.sampler_state_dict_path
 
 		if self.sampler_type == "path":
 			state_dict = self.sampler.get_state()
@@ -603,7 +611,7 @@ class Dataset(_Dataset):
 	def load_state_dict(self, path = None):
 		if path is None:
-			path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+			path = cfg.rel_path / self.sampler_state_dict_path
 
 		if not path.exists():
 			return
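
The remaining hunks in this file replace three copies of the per-rank sampler filename with a single cached_property, so the constructor, save_state_dict(), and load_state_dict() all derive the same path. A standalone sketch of the pattern, using hypothetical stand-ins for the real Dataset, cfg, and global_rank():

# illustrative sketch with made-up names; the real property lives on Dataset
from functools import cached_property
from pathlib import Path

class SamplerCheckpoint:
	def __init__(self, sampler_type: str, rank: int, rel_path: Path):
		self.sampler_type = sampler_type
		self.rank = rank
		self.rel_path = rel_path

	@cached_property
	def sampler_state_dict_path(self) -> str:
		# computed once per instance, then reused by every call site
		return f"sampler.{self.sampler_type}.rank{self.rank}.pt"

	def full_path(self) -> Path:
		return self.rel_path / self.sampler_state_dict_path

Because cached_property memoizes on first access, the filename stays fixed for the lifetime of the dataset object, which is the behaviour the three former call sites already assumed.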

View File

@@ -137,7 +137,7 @@ def load_engines(training=True):
 			stats = state["stats"]
 
 			# do not load stats if we're training a LoRA
-			if "lora" not in state:
+			if "lora" in state or cfg.lora is not None or cfg.trainer.restart_step_count:
 				stats = None
 
 		if "module" in state:

View File

@@ -5,6 +5,8 @@ import random
 import torch
 from torch.utils.data import Sampler
 
+from .distributed import global_rank, local_rank, world_size
+
 # Randomly picks an index from an array of indices
 class PoolSampler():
 	def __init__( self, pool = [], keep_all = False ):
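
This file only gains the import of the distributed helpers; nothing in the hunk uses them yet. One plausible use, sketched here as an assumption rather than anything in this commit, is an index iterator that yields only the indices belonging to the current rank:

# hypothetical sketch: a rank-aware index iterator built on the imported helpers
class InterleavedIndices:
	def __init__(self, length: int):
		self.length = length

	def __iter__(self):
		# each rank walks indices rank, rank + world, rank + 2*world, ...
		return iter(range(global_rank(), self.length, world_size()))

	def __len__(self):
		return max(0, (self.length - global_rank() + world_size() - 1) // world_size())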