diff --git a/vall_e/config.py b/vall_e/config.py
index 6751977..8474189 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -140,8 +140,6 @@ class Dataset:
 
 	tasks_list: list[str] = field(default_factory=lambda: ["tts"])
 
-	continuous: bool = False # VALL-E continuous, as explained in the paper
-
 	@property
 	def min_phones(self):
 		return self.phones_range[0]
diff --git a/vall_e/data.py b/vall_e/data.py
index b217fee..17b9ee0 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -312,12 +312,14 @@ class Dataset(_Dataset):
 		noise_scale = 0.25
 
 		# text-to-speech
-		if task == "tts":
+		if task == "tts" or task == "tts-c":
 			trim_length = int(cfg.dataset.prompt_duration * 75)
+			continuous = task == "tts-c" and trim_length * 2 < resps.shape[0]
+
 			# VALL-E continuous
 			# ignore if target utterance is shorter than prompt duration
 			# to-do: actually do this for the AR only as I don't think the paper trained the NAR for this
-			if cfg.dataset.continuous and trim_length > resps.shape[0]:
+			if continuous:
 				proms = resps[:trim_length, :]
 				resps = resps[trim_length:, :]
 			else:
@@ -440,7 +442,7 @@
 				([ post_prom ] if post_prom is not None else [])
 			)
 		else:
-			raise f'Undefined task: {task}'
+			raise Exception(f'Undefined task: {task}')
 
 		"""
 		# emulate SVC
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index 0d88de7..6e2b760 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -87,7 +87,7 @@ class Engine(DeepSpeedEngine):
 			print(str(e))
 
 	def traverse(self, *args, **kwargs):
-		with torch.autocast(self.device, dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+		with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
 			self.forward(*args, **kwargs)
 			losses = self.gather_attribute("loss")
 			loss = torch.stack([*losses.values()]).sum()
diff --git a/vall_e/train.py b/vall_e/train.py
index 3268916..83f8f97 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -23,7 +23,7 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000,
 	device="cpu")
 _logger = logging.getLogger(__name__)
 def train_feeder(engine, batch):
-	with torch.autocast(engine.device, dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+	with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
 		engine(
 			text_list=batch["text"],
 			proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
diff --git a/vall_e/utils/distributed.py b/vall_e/utils/distributed.py
index e43c7e5..b6364dd 100755
--- a/vall_e/utils/distributed.py
+++ b/vall_e/utils/distributed.py
@@ -15,8 +15,8 @@ def get_free_port():
 
 _distributed_initialized = False
 
-def init_distributed( fn ):
-	fn()
+def init_distributed( fn, *args, **kwargs ):
+	fn(*args, **kwargs)
 	_distributed_initialized = True
 
 def distributed_initialized():