shuffled VALL-E continuous as a task tts-c instead, logic fixes for it

This commit is contained in:
mrq 2023-09-02 12:23:40 -05:00
parent 2f06166ddd
commit 57db3ccfa8
5 changed files with 9 additions and 9 deletions

View File

@@ -140,8 +140,6 @@ class Dataset:
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
continuous: bool = False # VALL-E continuous, as explained in the paper
@property
def min_phones(self):
return self.phones_range[0]

View File

@@ -312,12 +312,14 @@ class Dataset(_Dataset):
noise_scale = 0.25
# text-to-speech
if task == "tts":
if task == "tts" or task == "tts-c":
trim_length = int(cfg.dataset.prompt_duration * 75)
continuous = task == "tts-c" and trim_length * 2 < resps.shape[0]
# VALL-E continuous
# ignore if target utterance is shorter than prompt duration
# to-do: actually do this for the AR only as I don't think the paper trained the NAR for this
if cfg.dataset.continuous and trim_length > resps.shape[0]:
if continuous:
proms = resps[:trim_length, :]
resps = resps[trim_length:, :]
else:
@@ -440,7 +442,7 @@ class Dataset(_Dataset):
([ post_prom ] if post_prom is not None else [])
)
else:
raise f'Undefined task: {task}'
raise Exception(f'Undefined task: {task}')
"""
# emulate SVC

View File

@@ -87,7 +87,7 @@ class Engine(DeepSpeedEngine):
print(str(e))
def traverse(self, *args, **kwargs):
with torch.autocast(self.device, dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
self.forward(*args, **kwargs)
losses = self.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum()

View File

@@ -23,7 +23,7 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
_logger = logging.getLogger(__name__)
def train_feeder(engine, batch):
with torch.autocast(engine.device, dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
engine(
text_list=batch["text"],
proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level

View File

@@ -15,8 +15,8 @@ def get_free_port():
_distributed_initialized = False
def init_distributed( fn ):
fn()
def init_distributed( fn, *args, **kwargs ):
fn(*args, **kwargs)
_distributed_initialized = True
def distributed_initialized():