added option to specify frames per second for the given audio representation (Encodec is 75Hz, DAC is 41Hz (at 24K sources))
This commit is contained in:
parent
c494894261
commit
ffa200eec7
|
@ -156,6 +156,14 @@ class Dataset:
|
|||
sample_type: str = "path" # path | speaker
|
||||
|
||||
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
|
||||
|
||||
_frames_per_second: int = 0 # in encodec, each frame is 75 codes, in dac, each frame is 41
|
||||
|
||||
@cached_property
|
||||
def frames_per_second(self):
|
||||
if self._frames_per_second > 0:
|
||||
return self._frames_per_second
|
||||
return 41 if cfg.inference.audio_backend == "dac" else 75
|
||||
|
||||
@property
|
||||
def min_phones(self):
|
||||
|
|
|
@ -63,10 +63,10 @@ def _replace_file_extension(path, suffix):
|
|||
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
||||
|
||||
def _get_quant_extension():
|
||||
return ".dac"
|
||||
return ".dac" if cfg.inference.audio_backend == "dac" else ".qnt.pt"
|
||||
|
||||
def _get_phone_extension():
|
||||
return ".json"
|
||||
return ".json" if cfg.inference.audio_backend == "dac" else ".phn.txt"
|
||||
|
||||
def _get_quant_path(path):
|
||||
return _replace_file_extension(path, _get_quant_extension())
|
||||
|
@ -371,10 +371,10 @@ class Dataset(_Dataset):
|
|||
# shuffle it up a bit
|
||||
prom_length = 0
|
||||
if cfg.experimental:
|
||||
trim_length = random.randint(75 * 3, 75 * 6) # [3 seconds, 6 seconds]
|
||||
#trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * 75))
|
||||
trim_length = random.randint(cfg.dataset.frames_per_second * 3, cfg.dataset.frames_per_second * 6) # [3 seconds, 6 seconds]
|
||||
#trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * cfg.dataset.frames_per_second))
|
||||
else:
|
||||
trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
|
||||
trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second) + random.randint(-cfg.dataset.frames_per_second, cfg.dataset.frames_per_second)
|
||||
|
||||
for _ in range(cfg.dataset.max_prompts):
|
||||
path = random.choice(choices)
|
||||
|
@ -470,7 +470,7 @@ class Dataset(_Dataset):
|
|||
resps = torch.concat([ resps, qnt ])
|
||||
|
||||
task = "tts"
|
||||
trim_length = int(cfg.dataset.prompt_duration * 75)
|
||||
trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second)
|
||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
|
||||
|
||||
|
@ -484,7 +484,7 @@ class Dataset(_Dataset):
|
|||
task = "tts"
|
||||
noise_scale = 0.25
|
||||
if task == "tts" or task == "tts-c":
|
||||
trim_length = int(cfg.dataset.prompt_duration * 75)
|
||||
trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second)
|
||||
# demote if the target is too short
|
||||
if task == "tts-c" and trim_length * 2 >= resps.shape[0]:
|
||||
task = "tts"
|
||||
|
@ -805,7 +805,7 @@ def create_dataset_metadata( skip_existing=True ):
|
|||
}
|
||||
else:
|
||||
qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
|
||||
duration = qnt.shape[0] / 75
|
||||
duration = qnt.shape[0] / cfg.dataset.frames_per_second
|
||||
|
||||
metadata[id]["duration"] = duration
|
||||
else:
|
||||
|
@ -912,7 +912,7 @@ def create_dataset_hdf5( skip_existing=True ):
|
|||
}
|
||||
else:
|
||||
qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
|
||||
duration = qnt.shape[0] / 75
|
||||
duration = qnt.shape[0] / cfg.dataset.frames_per_second
|
||||
|
||||
qnt = qnt.numpy().astype(np.int16)
|
||||
|
||||
|
|
|
@ -115,7 +115,7 @@ class TTS():
|
|||
res = torch.cat([qnt.encode_from_file(path)[0][:, :].t().to(torch.int16) for path in paths])
|
||||
|
||||
if trim_length:
|
||||
res = trim( res, int( 75 * trim_length ) )
|
||||
res = trim( res, int( cfg.dataset.frames_per_second * trim_length ) )
|
||||
|
||||
return res
|
||||
|
||||
|
@ -125,7 +125,7 @@ class TTS():
|
|||
text,
|
||||
references,
|
||||
language="en",
|
||||
max_ar_steps=6 * 75,
|
||||
max_ar_steps=6 * cfg.dataset.frames_per_second,
|
||||
max_ar_context=-1,
|
||||
max_nar_levels=7,
|
||||
input_prompt_length=0.0,
|
||||
|
|
|
@ -150,7 +150,7 @@ class AR_NAR(Base):
|
|||
|
||||
"""
|
||||
if cfg.experimental:
|
||||
proms_list = [ r if l == 0 else trim(r, 75 * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds
|
||||
proms_list = [ r if l == 0 else trim(r, cfg.dataset.frames_per_second * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds
|
||||
"""
|
||||
|
||||
# append stop tokens for AR
|
||||
|
@ -350,7 +350,7 @@ def example_usage():
|
|||
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
|
||||
]
|
||||
proms_list = [
|
||||
qnt[:75, :].to(device),
|
||||
qnt[:cfg.dataset.frames_per_second, :].to(device),
|
||||
]
|
||||
resps_list = [
|
||||
qnt.to(device),
|
||||
|
|
|
@ -873,98 +873,4 @@ class Base(nn.Module):
|
|||
return res, scores
|
||||
|
||||
# and sample
|
||||
return [ Categorical(logits=logit).sample() for logit in logits ]
|
||||
|
||||
def example_usage():
|
||||
from ..config import cfg
|
||||
cfg.trainer.backend = "local"
|
||||
cfg.trainer.check_for_oom = False
|
||||
|
||||
from functools import partial
|
||||
|
||||
from einops import repeat
|
||||
|
||||
from ..emb.qnt import decode_to_file
|
||||
from ..engines import Engine, Engines
|
||||
from tqdm import tqdm, trange
|
||||
from ..utils import wrapper as ml
|
||||
|
||||
from .ar import AR
|
||||
from .nar import NAR
|
||||
|
||||
device = "cuda"
|
||||
x8 = partial(repeat, pattern="t -> t l", l=cfg.model.prom_levels)
|
||||
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
|
||||
def tokenize(content, lang_marker="en"):
|
||||
split = content.split(" ")
|
||||
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
|
||||
return torch.tensor([*map(symmap.get, phones)]).to()
|
||||
|
||||
kwargs = {
|
||||
'n_tokens': 1024,
|
||||
'd_model': 1024,
|
||||
'n_heads': 16,
|
||||
'n_layers': 12,
|
||||
}
|
||||
models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) }
|
||||
|
||||
for name, model in models.items():
|
||||
print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
||||
|
||||
engines = Engines({ name: Engine(model=model, optimizer=ml.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
|
||||
|
||||
train = True
|
||||
|
||||
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.model.prom_levels].to(device)
|
||||
text_list = [
|
||||
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
|
||||
#tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device),
|
||||
]
|
||||
|
||||
proms_list = [
|
||||
qnt.to(device),
|
||||
]
|
||||
resps_list = [
|
||||
qnt.to(device),
|
||||
]
|
||||
|
||||
def sample( name, steps=600 ):
|
||||
AR = None
|
||||
NAR = None
|
||||
|
||||
engines.eval()
|
||||
for name, engine in engines.items():
|
||||
if name[:2] == "ar":
|
||||
AR = engine
|
||||
elif name[:3] == "nar":
|
||||
NAR = engine
|
||||
|
||||
resps_list = AR(text_list, proms_list, max_steps=steps, sampling_temperature=1.0)
|
||||
resps_list = [r.unsqueeze(-1) for r in resps_list]
|
||||
codes = NAR( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
|
||||
|
||||
decode_to_file(resps_list[0], f"./data/ar.{name}.wav", device=device)
|
||||
decode_to_file(codes[0], f"./data/ar+nar.{name}.wav", device=device)
|
||||
|
||||
if train:
|
||||
sample("init", 15)
|
||||
|
||||
engines.train()
|
||||
t = trange(500)
|
||||
for i in t:
|
||||
stats = {"step": i}
|
||||
"""
|
||||
for name, engine in engines.items():
|
||||
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
|
||||
"""
|
||||
stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list})
|
||||
tqdm.write(f"{stats}")
|
||||
else:
|
||||
for name, engine in engines.items():
|
||||
engine.module.load_state_dict(torch.load(f"./data/{name}.pth"))
|
||||
|
||||
sample("final")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
example_usage()
|
||||
return [ Categorical(logits=logit).sample() for logit in logits ]
|
Loading…
Reference in New Issue
Block a user