shuffled VALL-E continuous into its own tts-c task instead; logic fixes for it
commit 57db3ccfa8
parent 2f06166ddd
@@ -140,8 +140,6 @@ class Dataset:
 	tasks_list: list[str] = field(default_factory=lambda: ["tts"])
 
-	continuous: bool = False # VALL-E continuous, as explained in the paper
-
 	@property
 	def min_phones(self):
 		return self.phones_range[0]
 
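The dedicated continuous flag on the dataset config is gone; continuous mode now rides along as a tts-c entry in tasks_list. A minimal sketch of the new configuration surface, assuming a prompt_duration field like the one referenced later in this diff (its default here is illustrative, as are the surrounding fields):

from dataclasses import dataclass, field

@dataclass
class Dataset:
	tasks_list: list[str] = field(default_factory=lambda: ["tts"])
	prompt_duration: float = 3.0  # assumed default; seconds of audio used as the acoustic prompt

# before: Dataset(continuous=True) flipped a dataset-wide flag
# after: continuous mode is opted into per-task
cfg_dataset = Dataset(tasks_list=["tts", "tts-c"])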
@@ -312,12 +312,14 @@ class Dataset(_Dataset):
 		noise_scale = 0.25
 
 		# text-to-speech
-		if task == "tts":
+		if task == "tts" or task == "tts-c":
 			trim_length = int(cfg.dataset.prompt_duration * 75)
+			continuous = task == "tts-c" and trim_length * 2 < resps.shape[0]
+
 			# VALL-E continuous
 			# ignore if target utterance is shorter than prompt duration
 			# to-do: actually do this for the AR only as I don't think the paper trained the NAR for this
-			if cfg.dataset.continuous and trim_length > resps.shape[0]:
+			if continuous:
 				proms = resps[:trim_length, :]
 				resps = resps[trim_length:, :]
 			else:
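This hunk carries the commit's logic fix. trim_length converts the configured prompt duration from seconds into EnCodec frames (the * 75 matches EnCodec's 75 frames per second at 24kHz), and the old gate, which fired when the utterance was shorter than the prompt (the inverse of what its own comment intends), is replaced by a check that the utterance is more than twice the prompt length. The first trim_length frames then become the acoustic prompt and the remainder the target, per the continuation setup in the VALL-E paper. A standalone sketch of that split, with the tensor shapes and the 75 frame rate stated as assumptions:

import torch

# Sketch of the tts-c split; resps stands in for an EnCodec code sequence of
# shape [timesteps, quantizer levels], and prompt_duration seconds maps to
# frames via EnCodec's 75 Hz frame rate at 24kHz.
def split_continuous(resps: torch.Tensor, prompt_duration: float = 3.0):
	trim_length = int(prompt_duration * 75)
	# only go continuous when the utterance is more than twice the prompt
	# length, so the target half is never shorter than the prompt half
	if trim_length * 2 >= resps.shape[0]:
		return None  # fall back to the plain tts path
	proms = resps[:trim_length, :]  # leading frames become the acoustic prompt
	resps = resps[trim_length:, :]  # the remainder is the prediction target
	return proms, resps

resps = torch.randint(0, 1024, (1000, 8))  # ~13.3s of fake codes, 8 levels
proms, target = split_continuous(resps)
assert proms.shape == (225, 8) and target.shape == (775, 8)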
@@ -440,7 +442,7 @@ class Dataset(_Dataset):
 				([ post_prom ] if post_prom is not None else [])
 			)
 		else:
-			raise f'Undefined task: {task}'
+			raise Exception(f'Undefined task: {task}')
 
 		"""
 		# emulate SVC
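The raise fix is a straight Python 3 correctness repair: a bare string cannot be raised, so the old line would surface a TypeError instead of the intended message. A quick demonstration:

task = "svc"

try:
	raise f'Undefined task: {task}'  # old line
except TypeError as e:
	print(e)  # "exceptions must derive from BaseException"; the real message is lost

try:
	raise Exception(f'Undefined task: {task}')  # fixed line
except Exception as e:
	print(e)  # "Undefined task: svc"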
@@ -87,7 +87,7 @@ class Engine(DeepSpeedEngine):
 			print(str(e))
 
 	def traverse(self, *args, **kwargs):
-		with torch.autocast(self.device, dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+		with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
 			self.forward(*args, **kwargs)
 			losses = self.gather_attribute("loss")
 			loss = torch.stack([*losses.values()]).sum()
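torch.autocast takes the device type as a plain string ("cuda" or "cpu"), while self.device on a DeepSpeed engine is typically an indexed torch.device such as cuda:0, which autocast rejects on the PyTorch builds this code targets, hence the hardcoded "cuda". A sketch of the difference:

import torch

# autocast wants the device *type*; gating enabled on CUDA availability
# keeps this runnable on CPU-only builds
with torch.autocast("cuda", dtype=torch.float16, enabled=torch.cuda.is_available()):
	pass  # mixed-precision region

# what the old call amounted to: an indexed torch.device rather than a
# device-type string, which autocast refuses
try:
	with torch.autocast(torch.device("cuda:0"), dtype=torch.float16):
		pass
except Exception as e:
	print(e)  # complains that device_type must be a str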
@@ -23,7 +23,7 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
 _logger = logging.getLogger(__name__)
 
 def train_feeder(engine, batch):
-	with torch.autocast(engine.device, dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+	with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
 		engine(
 			text_list=batch["text"],
 			proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
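The same autocast fix lands in train_feeder. The proms_list comprehension in the context lines is worth a note as well: each prompt is a [timesteps, levels] code tensor, and slicing to engine._cfg.prom_levels drops whatever quantizer levels the model was not configured to consume. A sketch, where prom_levels = 2 and the batch contents are assumptions for illustration:

import torch

# each prompt is [timesteps, levels]; the model only sees the first
# prom_levels quantizer levels of each
prom_levels = 2
batch_proms = [torch.randint(0, 1024, (t, 8)) for t in (150, 300)]

proms_list = [prom[:, :prom_levels] for prom in batch_proms]
assert all(p.shape[1] == prom_levels for p in proms_list)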
@@ -15,8 +15,8 @@ def get_free_port():
 
 
 _distributed_initialized = False
-def init_distributed( fn ):
-	fn()
+def init_distributed( fn, *args, **kwargs ):
+	fn(*args, **kwargs)
 	_distributed_initialized = True
 
 def distributed_initialized():
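init_distributed now forwards arbitrary positional and keyword arguments to the wrapped initializer instead of only accepting zero-argument callables. A usage sketch; the global statement here is an assumption about the surrounding module, since the context line shown in the hunk assigns the flag without it, which would leave the module-level value untouched:

# usage sketch for the widened signature
_distributed_initialized = False

def init_distributed( fn, *args, **kwargs ):
	global _distributed_initialized  # assumed; not visible in the hunk's context lines
	fn(*args, **kwargs)
	_distributed_initialized = True

def distributed_initialized():
	return _distributed_initialized

# before this commit, only zero-argument callables could be wrapped
init_distributed(print, "forwarded", "args", sep=", ")
assert distributed_initialized()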