fixed segfault from the tts-c task token being too big (moved it into the hypothetical svc task token's slot, because in reality svc is never ever going to be a feasible task to train against)

This commit is contained in:
mrq 2023-09-02 19:25:43 -05:00
parent 4613781e23
commit 922404285c

View File

@ -45,8 +45,7 @@ def get_task_symmap():
"<soe>": start + 3, "<soe>": start + 3,
"<mask>": start + 4, "<mask>": start + 4,
"<eoe>": start + 5, "<eoe>": start + 5,
"<svc>": start + 6, "<tts-c>": start + 6,
"<tts-c>": start + 7,
} }
return symmap return symmap
@ -320,8 +319,6 @@ class Dataset(_Dataset):
if task == "tts-c" and trim_length * 2 >= resps.shape[0]: if task == "tts-c" and trim_length * 2 >= resps.shape[0]:
task = "tts" task = "tts"
task = "tts"
# VALL-E continuous # VALL-E continuous
# ignore if target utterance is shorter than prompt duration # ignore if target utterance is shorter than prompt duration
# to-do: actually do this for the AR only as I don't think the paper trained the NAR for this # to-do: actually do this for the AR only as I don't think the paper trained the NAR for this