added option to specify frames per second for the given audio representation (Encodec is 75Hz, DAC is 41Hz at 24K sources)

mrq 2024-05-04 12:05:41 -05:00
parent c494894261
commit ffa200eec7
5 changed files with 22 additions and 108 deletions


@@ -156,6 +156,14 @@ class Dataset:
 	sample_type: str = "path" # path | speaker
 	tasks_list: list[str] = field(default_factory=lambda: ["tts"])
+
+	_frames_per_second: int = 0 # in encodec, each frame is 75 codes, in dac, each frame is 41
+	@cached_property
+	def frames_per_second(self):
+		if self._frames_per_second > 0:
+			return self._frames_per_second
+		return 41 if cfg.inference.audio_backend == "dac" else 75
+
 	@property
 	def min_phones(self):
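The resolution order of the new setting: an explicit _frames_per_second wins, otherwise the rate is derived from the configured audio backend. A minimal, self-contained sketch of that behavior (the dataclass below is a stand-in for the real Dataset config, and audio_backend stands in for cfg.inference.audio_backend):

from dataclasses import dataclass
from functools import cached_property

@dataclass
class DatasetSketch:
	audio_backend: str = "encodec"  # stand-in for cfg.inference.audio_backend
	_frames_per_second: int = 0     # 0 means "derive from the backend"

	@cached_property
	def frames_per_second(self):
		if self._frames_per_second > 0:
			return self._frames_per_second
		return 41 if self.audio_backend == "dac" else 75

print(DatasetSketch().frames_per_second)                      # 75
print(DatasetSketch(audio_backend="dac").frames_per_second)   # 41
print(DatasetSketch(_frames_per_second=86).frames_per_second) # explicit value wins

Note that cached_property means the backend lookup happens once per config instance; changing the backend afterwards won't update an already-read rate.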


@@ -63,10 +63,10 @@ def _replace_file_extension(path, suffix):
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)

 def _get_quant_extension():
-	return ".dac"
+	return ".dac" if cfg.inference.audio_backend == "dac" else ".qnt.pt"

 def _get_phone_extension():
-	return ".json"
+	return ".json" if cfg.inference.audio_backend == "dac" else ".phn.txt"

 def _get_quant_path(path):
 	return _replace_file_extension(path, _get_quant_extension())

@@ -371,10 +371,10 @@ class Dataset(_Dataset):
 		# shuffle it up a bit
 		prom_length = 0
 		if cfg.experimental:
-			trim_length = random.randint(75 * 3, 75 * 6) # [3 seconds, 6 seconds]
-			#trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * 75))
+			trim_length = random.randint(cfg.dataset.frames_per_second * 3, cfg.dataset.frames_per_second * 6) # [3 seconds, 6 seconds]
+			#trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * cfg.dataset.frames_per_second))
 		else:
-			trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
+			trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second) + random.randint(-cfg.dataset.frames_per_second, cfg.dataset.frames_per_second)

 		for _ in range(cfg.dataset.max_prompts):
 			path = random.choice(choices)

@@ -470,7 +470,7 @@ class Dataset(_Dataset):
 			resps = torch.concat([ resps, qnt ])

 		task = "tts"
-		trim_length = int(cfg.dataset.prompt_duration * 75)
+		trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second)

 		proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps

@@ -484,7 +484,7 @@ class Dataset(_Dataset):
 					task = "tts"
 					noise_scale = 0.25

 		if task == "tts" or task == "tts-c":
-			trim_length = int(cfg.dataset.prompt_duration * 75)
+			trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second)
 			# demote if the target is too short
 			if task == "tts-c" and trim_length * 2 >= resps.shape[0]:
 				task = "tts"

@@ -805,7 +805,7 @@ def create_dataset_metadata( skip_existing=True ):
 				}
 			else:
 				qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
-				duration = qnt.shape[0] / 75
+				duration = qnt.shape[0] / cfg.dataset.frames_per_second
 				metadata[id]["duration"] = duration
 		else:

@@ -912,7 +912,7 @@ def create_dataset_hdf5( skip_existing=True ):
 				}
 			else:
 				qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
-				duration = qnt.shape[0] / 75
+				duration = qnt.shape[0] / cfg.dataset.frames_per_second

 				qnt = qnt.numpy().astype(np.int16)
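A worked check on the arithmetic these hunks generalize (a hedged sketch; 75 and 41 are the defaults the new property falls back to):

prompt_duration = 3.0 # seconds
for backend, fps in {"encodec": 75, "dac": 41}.items():
	trim_length = int(prompt_duration * fps)
	print(backend, trim_length) # encodec -> 225 frames, dac -> 123 frames

With the old hardcoded 75, a DAC prompt "trimmed to 3 seconds" would actually have kept 225 / 41 ≈ 5.5 seconds of audio, and the computed durations (qnt.shape[0] / 75) would have been off by the same 75/41 factor.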


@@ -115,7 +115,7 @@ class TTS():
 		res = torch.cat([qnt.encode_from_file(path)[0][:, :].t().to(torch.int16) for path in paths])

 		if trim_length:
-			res = trim( res, int( 75 * trim_length ) )
+			res = trim( res, int( cfg.dataset.frames_per_second * trim_length ) )

 		return res

@@ -125,7 +125,7 @@ class TTS():
 		text,
 		references,
 		language="en",
-		max_ar_steps=6 * 75,
+		max_ar_steps=6 * cfg.dataset.frames_per_second,
 		max_ar_context=-1,
 		max_nar_levels=7,
 		input_prompt_length=0.0,
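Because the AR decodes one frame per step, the new default max_ar_steps=6 * cfg.dataset.frames_per_second caps generation at roughly six seconds of audio under either backend. A tiny hypothetical helper (not in the repo) making the conversion explicit:

def seconds_to_frames(seconds: float, fps: int) -> int:
	# one frame == one AR decode step for these codecs
	return int(seconds * fps)

assert seconds_to_frames(6, 75) == 450 # encodec step budget
assert seconds_to_frames(6, 41) == 246 # dac step budget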


@@ -150,7 +150,7 @@ class AR_NAR(Base):
 		"""
 		if cfg.experimental:
-			proms_list = [ r if l == 0 else trim(r, 75 * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds
+			proms_list = [ r if l == 0 else trim(r, cfg.dataset.frames_per_second * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds
 		"""

 		# append stop tokens for AR

@@ -350,7 +350,7 @@ def example_usage():
 		tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
 	]
 	proms_list = [
-		qnt[:75, :].to(device),
+		qnt[:cfg.dataset.frames_per_second, :].to(device),
 	]
 	resps_list = [
 		qnt.to(device),
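The example now slices one second of codes as the prompt (75 frames under encodec, 41 under dac) instead of a hardcoded 75. A sketch of the equivalent slice, assuming a [frames, levels] code tensor (the 8-level shape here is illustrative):

import torch

fps = 75 # encodec default; 41 for dac
qnt = torch.zeros(450, 8, dtype=torch.int16) # ~6 s of codes, 8 RVQ levels
prompt = qnt[:fps, :] # first second of codes used as the prom
print(prompt.shape) # torch.Size([75, 8])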


@@ -873,98 +873,4 @@ class Base(nn.Module):
 			return res, scores

 		# and sample
 		return [ Categorical(logits=logit).sample() for logit in logits ]
-
-def example_usage():
-	from ..config import cfg
-	cfg.trainer.backend = "local"
-	cfg.trainer.check_for_oom = False
-
-	from functools import partial
-	from einops import repeat
-
-	from ..emb.qnt import decode_to_file
-	from ..engines import Engine, Engines
-	from tqdm import tqdm, trange
-	from ..utils import wrapper as ml
-
-	from .ar import AR
-	from .nar import NAR
-
-	device = "cuda"
-	x8 = partial(repeat, pattern="t -> t l", l=cfg.model.prom_levels)
-	symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
-	def tokenize(content, lang_marker="en"):
-		split = content.split(" ")
-		phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
-		return torch.tensor([*map(symmap.get, phones)]).to()
-
-	kwargs = {
-		'n_tokens': 1024,
-		'd_model': 1024,
-		'n_heads': 16,
-		'n_layers': 12,
-	}
-	models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) }
-
-	for name, model in models.items():
-		print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
-
-	engines = Engines({ name: Engine(model=model, optimizer=ml.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
-
-	train = True
-
-	qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.model.prom_levels].to(device)
-	text_list = [
-		tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
-		#tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device),
-	]
-	proms_list = [
-		qnt.to(device),
-	]
-	resps_list = [
-		qnt.to(device),
-	]
-
-	def sample( name, steps=600 ):
-		AR = None
-		NAR = None
-
-		engines.eval()
-		for name, engine in engines.items():
-			if name[:2] == "ar":
-				AR = engine
-			elif name[:3] == "nar":
-				NAR = engine
-
-		resps_list = AR(text_list, proms_list, max_steps=steps, sampling_temperature=1.0)
-		resps_list = [r.unsqueeze(-1) for r in resps_list]
-		codes = NAR( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
-
-		decode_to_file(resps_list[0], f"./data/ar.{name}.wav", device=device)
-		decode_to_file(codes[0], f"./data/ar+nar.{name}.wav", device=device)
-
-	if train:
-		sample("init", 15)
-
-		engines.train()
-		t = trange(500)
-		for i in t:
-			stats = {"step": i}
-			"""
-			for name, engine in engines.items():
-				stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
-			"""
-			stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list})
-			tqdm.write(f"{stats}")
-	else:
-		for name, engine in engines.items():
-			engine.module.load_state_dict(torch.load(f"./data/{name}.pth"))
-
-	sample("final")
-
-if __name__ == "__main__":
-	example_usage()