shuffled VALL-E continuous as a task tts-c instead, logic fixes for it
This commit is contained in:
parent
2f06166ddd
commit
57db3ccfa8
|
@ -140,8 +140,6 @@ class Dataset:
|
||||||
|
|
||||||
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
|
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
|
||||||
|
|
||||||
continuous: bool = False # VALL-E continuous, as explained in the paper
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def min_phones(self):
|
def min_phones(self):
|
||||||
return self.phones_range[0]
|
return self.phones_range[0]
|
||||||
|
|
|
@ -312,12 +312,14 @@ class Dataset(_Dataset):
|
||||||
|
|
||||||
noise_scale = 0.25
|
noise_scale = 0.25
|
||||||
# text-to-speech
|
# text-to-speech
|
||||||
if task == "tts":
|
if task == "tts" or task == "tts-c":
|
||||||
trim_length = int(cfg.dataset.prompt_duration * 75)
|
trim_length = int(cfg.dataset.prompt_duration * 75)
|
||||||
|
continuous = task == "tts-c" and trim_length * 2 < resps.shape[0]
|
||||||
|
|
||||||
# VALL-E continuous
|
# VALL-E continuous
|
||||||
# ignore if target utterance is shorter than prompt duration
|
# ignore if target utterance is shorter than prompt duration
|
||||||
# to-do: actually do this for the AR only as I don't think the paper trained the NAR for this
|
# to-do: actually do this for the AR only as I don't think the paper trained the NAR for this
|
||||||
if cfg.dataset.continuous and trim_length > resps.shape[0]:
|
if continuous:
|
||||||
proms = resps[:trim_length, :]
|
proms = resps[:trim_length, :]
|
||||||
resps = resps[trim_length:, :]
|
resps = resps[trim_length:, :]
|
||||||
else:
|
else:
|
||||||
|
@ -440,7 +442,7 @@ class Dataset(_Dataset):
|
||||||
([ post_prom ] if post_prom is not None else [])
|
([ post_prom ] if post_prom is not None else [])
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise f'Undefined task: {task}'
|
raise Exception(f'Undefined task: {task}')
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# emulate SVC
|
# emulate SVC
|
||||||
|
|
|
@ -87,7 +87,7 @@ class Engine(DeepSpeedEngine):
|
||||||
print(str(e))
|
print(str(e))
|
||||||
|
|
||||||
def traverse(self, *args, **kwargs):
|
def traverse(self, *args, **kwargs):
|
||||||
with torch.autocast(self.device, dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
||||||
self.forward(*args, **kwargs)
|
self.forward(*args, **kwargs)
|
||||||
losses = self.gather_attribute("loss")
|
losses = self.gather_attribute("loss")
|
||||||
loss = torch.stack([*losses.values()]).sum()
|
loss = torch.stack([*losses.values()]).sum()
|
||||||
|
|
|
@ -23,7 +23,7 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def train_feeder(engine, batch):
|
def train_feeder(engine, batch):
|
||||||
with torch.autocast(engine.device, dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
||||||
engine(
|
engine(
|
||||||
text_list=batch["text"],
|
text_list=batch["text"],
|
||||||
proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
|
proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
|
||||||
|
|
|
@ -15,8 +15,8 @@ def get_free_port():
|
||||||
|
|
||||||
|
|
||||||
_distributed_initialized = False
|
_distributed_initialized = False
|
||||||
def init_distributed( fn ):
|
def init_distributed( fn, *args, **kwargs ):
|
||||||
fn()
|
fn(*args, **kwargs)
|
||||||
_distributed_initialized = True
|
_distributed_initialized = True
|
||||||
|
|
||||||
def distributed_initialized():
|
def distributed_initialized():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user