From c4dd523b6f9305142b1a68565606926ac9f0d8d1 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 29 Jun 2024 10:10:35 -0500
Subject: [PATCH] change from chunk-slicing paths for distributed dataloader
 to instead interleave

Each rank now takes every world_size()-th path (offset by its rank)
instead of a contiguous chunk of the path list, so the tail of the list
is no longer dropped when it does not divide evenly. Also dedupes the
per-rank sampler checkpoint filename into a cached property and reworks
the check for when to discard loaded stats.
---
 vall_e/data.py             | 15 ++++++++++++---
 vall_e/engines/__init__.py |  4 ++--
 vall_e/utils/sampler.py    |  2 ++
 3 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index 767d5f7..1e2d773 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -455,11 +455,16 @@
 
 		# split dataset accordingly per GPU
 		if cfg.distributed and self.training:
+			"""
 			batches = len(self.paths) // world_size()
 			start = batches * global_rank()
 			end = batches * (global_rank() + 1)
 
 			self.paths = self.paths[start:end]
+			"""
+
+			# interleave: rank r takes paths r, r + world_size(), r + 2 * world_size(), ...
+			self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == global_rank() ]
 
 			# recreate paths_by_spkr_name
 			self.paths_by_spkr_name = {}
@@ -543,7 +548,7 @@
 		if len(self.paths) == 0:
 			raise ValueError(f"No valid path is found for {self.dataset_type}")
 
-		sampler_path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+		sampler_path = cfg.rel_path / self.sampler_state_dict_path
 
 		if self.sampler_type == "path":
 			if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
@@ -558,6 +563,10 @@
 			self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
 		self.load_state_dict()
+
+	@cached_property
+	def sampler_state_dict_path(self):
+		return f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
 
 	def get_speaker(self, path):
 		if isinstance(path, str):
@@ -590,7 +599,7 @@
 
 	def save_state_dict(self, path = None):
 		if path is None:
-			path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+			path = cfg.rel_path / self.sampler_state_dict_path
 
 		if self.sampler_type == "path":
 			state_dict = self.sampler.get_state()
@@ -603,7 +612,7 @@
 
 	def load_state_dict(self, path = None):
 		if path is None:
-			path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+			path = cfg.rel_path / self.sampler_state_dict_path
 
 		if not path.exists():
 			return
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 8a4141e..f686e3e 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -137,7 +137,7 @@ def load_engines(training=True):
 			stats = state["stats"]
 
-		# do not load stats if we're training a LoRA
-		if "lora" not in state:
+		# do not load stats if we're training a LoRA or restarting the step count
+		if "lora" in state or cfg.lora is not None or cfg.trainer.restart_step_count:
 			stats = None
 
 		if "module" in state:
diff --git a/vall_e/utils/sampler.py b/vall_e/utils/sampler.py
index d9ebe37..9e5dc1a 100644
--- a/vall_e/utils/sampler.py
+++ b/vall_e/utils/sampler.py
@@ -5,6 +5,8 @@ import random
 
 import torch
 from torch.utils.data import Sampler
+from .distributed import global_rank, local_rank, world_size
+
 # Randomly picks an index from an array of indices
 class PoolSampler():
	def __init__( self, pool = [], keep_all = False ):
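
For reference, a minimal standalone sketch of the two sharding strategies the
first data.py hunk swaps between. It is not part of the patch: chunk_shard and
interleave_shard are hypothetical helpers, and the plain rank/world_size
arguments stand in for the global_rank()/world_size() helpers from
vall_e.utils.distributed.

	# Standalone sketch: chunk-slicing vs. interleaved sharding of a path list.
	def chunk_shard(paths, rank, world_size):
		# old behavior: each rank takes a contiguous chunk; the last
		# len(paths) % world_size items are dropped on every rank
		batches = len(paths) // world_size
		return paths[batches * rank : batches * (rank + 1)]

	def interleave_shard(paths, rank, world_size):
		# new behavior: rank r takes paths r, r + world_size, r + 2*world_size, ...
		return [path for i, path in enumerate(paths) if i % world_size == rank]

	paths = [f"utt_{i:02d}" for i in range(10)]
	for rank in range(3):
		print(rank, chunk_shard(paths, rank, 3), interleave_shard(paths, rank, 3))
	# 0 ['utt_00', 'utt_01', 'utt_02'] ['utt_00', 'utt_03', 'utt_06', 'utt_09']
	# 1 ['utt_03', 'utt_04', 'utt_05'] ['utt_01', 'utt_04', 'utt_07']
	# 2 ['utt_06', 'utt_07', 'utt_08'] ['utt_02', 'utt_05', 'utt_08']

Interleaving covers every path (shard sizes differ by at most one across
ranks), and if the path list is pre-sorted, e.g. by duration, each rank
samples the full range instead of one contiguous band.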