Validated that the SpeechX tasks `cse` and `nse` work; added a way to test each task by invoking `python3 -m vall_e.data --action=tasks --tasks='sr,se,cse,nse'`
parent 6ca347e1e1
commit f7f6d3bf6d

README.md (26 lines changed)
@@ -63,7 +63,7 @@ If you're wanting to increase the `prom_levels` for a given model, or increase t
 
 4. Customize your configuration and define the dataset by modifying `./data/config.yaml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`.
 
-If you're interested in creating an HDF5 copy of your dataset, simply invoke: `python -m vall_e.data --create-hdf5 yaml='./data/config.yaml'`
+If you're interested in creating an HDF5 copy of your dataset, simply invoke: `python -m vall_e.data --action='hdf5' yaml='./data/config.yaml'`
 
 5. Train the AR and NAR models using the following scripts: `python -m vall_e.train yaml=./data/config.yaml`
 
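Since the HDF5 layout produced by `--action='hdf5'` comes up again in the `vall_e/data.py` changes below, here is a small, hedged sketch for sanity-checking the generated file. The path expression and the `audio` dataset name are taken from `create_dataset_hdf5()` in this commit; the script framing, and reaching `cfg` via `from vall_e.config import cfg`, are assumptions:

```python
# Sketch: walk the generated HDF5 file and print the per-utterance metadata attributes.
# Assumes the same yaml= configuration is loadable; this script is not part of the commit.
import h5py
from vall_e.config import cfg

def dump(name, obj):
    # leaf groups written by create_dataset_hdf5() carry an 'audio' dataset plus
    # id/type/speaker/duration/phonemes attributes
    if isinstance(obj, h5py.Group) and "audio" in obj:
        print(name, dict(obj.attrs))

with h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r') as hf:
    hf.visititems(dump)
```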
@@ -81,32 +81,18 @@ Two dataset formats are supported:
 
 ## Export
 
-Both trained models *can* be exported, but is only required if loading them on systems without DeepSpeed for inferencing (Windows systems). To export the models, run:
-
-```
-python -m vall_e.export yaml=./data/config.yaml
-```
-
-This will export the latest checkpoints.
+Both trained models *can* be exported, but it is only required if loading them on systems without DeepSpeed for inferencing (Windows systems). To export the models, run: `python -m vall_e.export yaml=./data/config.yaml`.
+
+This will export the latest checkpoints under `./data/ckpt/ar-retnet-2/fp32.pth` and `./data/ckpt/nar-retnet-2/fp32.pth` to be loaded on any system with PyTorch.
 
 ## Synthesis
 
-To synthesize speech, invoke either (if exported the models):
-
-```
-python -m vall_e <text> <ref_path> <out_path> --ar-ckpt ./models/ar.pt --nar-ckpt ./models/nar.pt
-```
-
-or:
-
-```
-python -m vall_e <text> <ref_path> <out_path> yaml=<yaml_path>
-```
+To synthesize speech, invoke either (if you exported the models): `python -m vall_e <text> <ref_path> <out_path> --ar-ckpt ./models/ar.pt --nar-ckpt ./models/nar.pt` or `python -m vall_e <text> <ref_path> <out_path> yaml=<yaml_path>`
 
 Some additional flags you can pass are:
 * `--max-ar-steps`: maximum steps for inferencing through the AR model. Each second is 75 steps.
-* `--ar-temp`: sampling temperature to use for the AR pass.
+* `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output.
-* `--nar-temp`: sampling temperature to use for the NAR pass.
+* `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, `0.2` provides the cleanest output.
 * `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)
 
 ## To-Do
data/config.yaml

@@ -72,7 +72,7 @@ evaluation:
   steps: 300
   ar_temperature: 1.0
-  nar_temperature: 1.0
+  nar_temperature: 0.2
 
 trainer:
   iterations: 1_000_000
vall_e/config.py

@@ -446,6 +446,18 @@ class Config(_Config):
         tmp = Config.from_yaml( config_path )
         self.__dict__.update(tmp.__dict__)
 
+    def load_hdf5( self, write=False ):
+        if hasattr(self, 'hdf5'):
+            self.hdf5.close()
+
+        if self.distributed:
+            self.dataset.hdf5_flag = "r"
+        try:
+            self.hdf5 = h5py.File(f'{self.cfg_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
+        except Exception as e:
+            print("Error while opening HDF5 file:", f'{self.cfg_path}/{self.dataset.hdf5_name}', str(e))
+            self.dataset.use_hdf5 = False
+
     def format( self ):
         self.dataset = Dataset(**self.dataset)
         self.models = Models(**self.models)
@@ -466,13 +478,7 @@ try:
 
     # cached_property stopped working...
     if cfg.dataset.use_hdf5:
-        if cfg.distributed:
-            cfg.dataset.hdf5_flag = "r"
-        try:
-            cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', cfg.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
-        except Exception as e:
-            print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
-            cfg.dataset.use_hdf5 = False
+        cfg.load_hdf5()
 
     if not cfg.dataset.use_hdf5:
         cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
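A minimal sketch of the two call sites the new helper is meant to serve; both calls appear verbatim elsewhere in this commit, while the import line and the surrounding framing are assumptions:

```python
from vall_e.config import cfg  # assumption: cfg is the module-level Config instance

# training / inferencing path (config.py above): open with the configured hdf5_flag,
# which is forced to "r" when running distributed
if cfg.dataset.use_hdf5:
    cfg.load_hdf5()

# dataset-creation path (vall_e.data.create_dataset_hdf5, below): reopen writable
cfg.dataset.use_hdf5 = True
cfg.load_hdf5(write=True)
```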
vall_e/data.py (170 lines changed)
@@ -300,6 +300,11 @@ class Dataset(_Dataset):
 
         task = random.choice(self.tasks)
 
+        # ensure a speaker has at least four utterances
+        # default to tts if not
+        if len(set(self.paths_by_spkr_name[spkr_name]) - {path}) < 4:
+            task = "tts"
+
         noise_scale = 0.125
         # text-to-speech
         if task == "tts":
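The threshold of four ties into the `cse`/`nse` branch further down, which samples four other utterances from the same speaker and unpacks them as the pre/mid/post/edit segments. A standalone illustration of that guard (placeholder paths, not repo code):

```python
import random

path   = "speaker_a/utt_0.qnt.pt"
others = {"speaker_a/utt_1.qnt.pt", "speaker_a/utt_2.qnt.pt",
          "speaker_a/utt_3.qnt.pt", "speaker_a/utt_4.qnt.pt"}

task = "cse"
if len(others - {path}) < 4:
    # mirrors the new guard: fall back to plain text-to-speech
    task = "tts"
else:
    # mirrors random.sample([*choices], 4) in the cse/nse branch below
    pre, mid, post, edit = random.sample([*(others - {path})], 4)
```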
@@ -349,7 +354,7 @@
         # clean speech editing
         elif task == "cse" or task == "nse":
             choices = set(self.paths_by_spkr_name[spkr_name]) - {path}
-            sampled = random.choice([*choices], 4)
+            sampled = random.sample([*choices], 4)
 
             if cfg.dataset.use_hdf5:
                 texts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["text"][:]).to(self.text_dtype) for path in sampled ]
@@ -359,8 +364,8 @@
             qnts = [ _load_quants(path) for path in sampled ]
 
             # remove <s></s>
-            for text in texts:
-                text = text[1:-1]
+            for i in range(len(texts)):
+                texts[i] = texts[i][1:-1]
 
             pre_text, mid_text, post_text, edit_text = texts
             pre_prom, mid_prom, post_prom, edit_prom = qnts
@@ -376,11 +381,11 @@
 
             # create new text
             text = torch.cat(
-                [ 1 ] + # <s>
-                ([ pre_text ] if pre_text is not None else []) +
-                [ edit_text ] +
-                ([ post_post ] if post_post is not None else []) +
-                [ 2 ] # </s>
+                [ torch.Tensor( [ 1 ] ).to(dtype=self.text_dtype) ] + # <s>
+                ([ pre_text, torch.Tensor( [ 3 ] ).to(dtype=self.text_dtype) ] if pre_text is not None else []) + # pre_text + 'space'
+                [ edit_text ] + # edit_text
+                ([ torch.Tensor( [ 3 ] ).to(dtype=self.text_dtype), post_text ] if post_text is not None else []) + # 'space' + post_text
+                [ torch.Tensor( [ 2 ] ).to(dtype=self.text_dtype) ] # </s>
             )
 
             if task == "nse":
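For clarity, a standalone sketch of the sequence this concatenation produces; the token ids `1`/`2`/`3` for `<s>`/`</s>`/space come from the comments in this hunk, while the dtype and the placeholder phoneme ids are illustrative:

```python
import torch

text_dtype = torch.uint8                                # assumption for illustration
pre_text   = torch.tensor([10, 11], dtype=text_dtype)   # placeholder phoneme ids
edit_text  = torch.tensor([20, 21], dtype=text_dtype)
post_text  = torch.tensor([30, 31], dtype=text_dtype)

text = torch.cat(
    [ torch.tensor([1], dtype=text_dtype) ] +                                              # <s>
    ([ pre_text, torch.tensor([3], dtype=text_dtype) ] if pre_text is not None else []) +  # pre_text + 'space'
    [ edit_text ] +                                                                        # edit_text
    ([ torch.tensor([3], dtype=text_dtype), post_text ] if post_text is not None else []) + # 'space' + post_text
    [ torch.tensor([2], dtype=text_dtype) ]                                                # </s>
)
# -> [1, 10, 11, 3, 20, 21, 3, 30, 31, 2]
```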
@@ -397,7 +402,7 @@
                 # extend the noise to fill the target audio
                 n = repeat_extend_audio(noise, proms.shape[0])
                 # merge the noise over the utterance
-                return merge_audio(proms, noise, scale=[1, noise_scale], device="cpu")
+                return merge_audio(proms, n, scale=[1, noise_scale], device="cpu")
 
             # apply noise to all pieces
             pre_prom = noise_proms( pre_prom )
@@ -649,14 +654,17 @@ def create_train_val_dataloader():
 
     return train_dl, subtrain_dl, val_dl
 
-# parse yaml to create an hdf5 tile
+# parse yaml to create an hdf5 file
 def create_dataset_hdf5():
+    cfg.dataset.use_hdf5 = True
+    cfg.load_hdf5(write=True)
+
     symmap = get_phone_symmap()
 
     root = cfg.cfg_path
     hf = cfg.hdf5
 
-    def add( dir, type="training" ):
+    def add( dir, type="training", audios=True, texts=True ):
         dir = "./" + str(dir)
         name = dir.replace(root, "")
 
@@ -670,7 +678,10 @@ def create_dataset_hdf5():
         # grab IDs for every file
        ids = { ".".join(file.split(".")[:-2]) for file in files }
         for id in tqdm(ids, desc=f"Processing {name}"):
-            if not os.path.exists(f'{root}/{name}/{id}.qnt.pt') or not os.path.exists(f'{root}/{name}/{id}.phn.txt'):
+            audio_exists = os.path.exists(f'{root}/{name}/{id}.qnt.pt') if audios else True
+            text_exists = os.path.exists(f'{root}/{name}/{id}.phn.txt') if texts else True
+
+            if not audio_exists or not text_exists:
                 continue
 
             key = f'{type}/{name}/{id}'
@@ -681,27 +692,29 @@
             group = hf.create_group(key)
 
             # audio
-            qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
-            group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
+            if audios:
+                qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
+                group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
 
             # text
-            with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf8") as f:
-                content = f.read()
-                split = content.split(" ")
-                phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
-                for s in set(phones):
-                    if s not in symmap:
-                        symmap[s] = len(symmap.keys())
-                phn = [ symmap[s] for s in phones ]
+            if texts:
+                with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf8") as f:
+                    content = f.read()
+                    split = content.split(" ")
+                    phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
+                    for s in set(phones):
+                        if s not in symmap:
+                            symmap[s] = len(symmap.keys())
+                    phn = [ symmap[s] for s in phones ]
 
             group.create_dataset('text', data=phn, compression='lzf', chunks=True)
 
             # metadata
             group.attrs['id'] = id
             group.attrs['type'] = type
             group.attrs['speaker'] = name
             group.attrs['duration'] = qnt.shape[0] / 75
             group.attrs['phonemes'] = len(phn)
 
     # training
     for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
@@ -713,10 +726,13 @@
 
     # noise
     for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
-        add( data_dir, type="noise" )
+        add( data_dir, type="noise", texts=False )
 
     # write symmap
-    hf.create_dataset('symmap', data=json.dumps(symmap))
+    try:
+        hf.create_dataset('symmap', data=json.dumps(symmap))
+    except Exception as e:
+        pass
 
     hf.close()
 
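A standalone mirror of what the new `audios`/`texts` flags change in the builder's skip logic, so noise-only folders (which carry no `.phn.txt` transcriptions) are no longer skipped; the directory layout and id list here are placeholders, not repo code:

```python
import os

def add(name, type="training", audios=True, texts=True, root=".", ids=()):
    for id in ids:
        # a missing file only disqualifies an id when that modality is requested
        audio_exists = os.path.exists(f'{root}/{name}/{id}.qnt.pt') if audios else True
        text_exists = os.path.exists(f'{root}/{name}/{id}.phn.txt') if texts else True
        if not audio_exists or not text_exists:
            continue
        print("would index", f'{type}/{name}/{id}')

add("noise/fsd50k", type="noise", texts=False, ids=["clip_0001"])  # only needs {id}.qnt.pt
add("train/speaker_a", type="training", ids=["utt_0001"])          # needs both .qnt.pt and .phn.txt
```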
@@ -724,14 +740,15 @@ if __name__ == "__main__":
     import argparse
 
     parser = argparse.ArgumentParser("Save trained model to path.")
-    parser.add_argument("--task", type=str)
+    parser.add_argument("--action", type=str)
+    parser.add_argument("--tasks", type=str)
     args = parser.parse_args()
 
-    task = args.task
+    task = args.action
 
-    if args.task == "hdf5":
+    if args.action == "hdf5":
         create_dataset_hdf5()
-    elif args.task == "sample":
+    elif args.action == "sample":
         train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
 
         samples = {
@@ -745,77 +762,18 @@ if __name__ == "__main__":
                 del v[i]['proms']
                 del v[i]['resps']
            print(f'{k}:', v)
-    """
-    elif args.task == "tasks":
+    elif args.action == "tasks":
         index = 0
-        task = "ns"
+        cfg.dataset.tasks_list = args.tasks.split(",")
 
-        train_dataset, val_dataset = create_datasets()
-        train_dataset.task_symmap = get_task_symmap()
-
-        if cfg.dataset.sample_type == "speaker":
-            spkr_name = train_dataset.spkrs[index]
-            spkr_id = train_dataset.spkr_symmap[spkr_name]
-            path = random.choice([*set(train_dataset.paths_by_spkr_name[spkr_name])])
-        else:
-            path = train_dataset.paths[index]
-            spkr_name = cfg.get_spkr(path)
-            spkr_id = train_dataset.spkr_symmap[spkr_name]
-
-        if cfg.dataset.use_hdf5:
-            key = _get_hdf5_path(path)
-            text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(train_dataset.text_dtype)
-            resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
-        else:
-            text = torch.tensor([*map(train_dataset.phone_symmap.get, _get_phones(path))]).to(train_dataset.text_dtype)
-            resps = _load_quants(path)
-
-        noise = None
-        if task == "ns" or task == "sr":
-            # sample random noise
-            noise = train_dataset.sample_noise()
-
-            decode_to_file( noise, "./.noise.wav", device="cpu" )
-
-            # extend the noise to fill the target audio
-            noise = repeat_extend_audio(noise, resps.shape[0])
-            # create the input prompt by merging the target audio with the noise
-            proms = merge_audio(resps, noise, scale=[1, 0.125])
-            # set the target to just be the noise if <sr>
-            if task == "sr":
-                resps = noise
-            # prepend the task token
-            proms = torch.cat( [train_dataset.get_task_token(task), proms] )
-
-            # set the text prompt to empty to train without a guided text prompt
-            if random.random() < 0.5:
-                text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
-        # target speech extraction
-        elif task == "tse":
-            # sample a random, clean, utterance for the target speaker
-            clean_proms = train_dataset.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
-            # sample a random, clean utterance from a different speaker
-            other_proms = train_dataset.sample_prompts(train_dataset.sample_speakers(ignore=[spkr_name]), ignore="")
-            # overlay the random speaker over the target audio
-
-            smallest_size = min(resps.shape[0], other_proms.shape[0])
-            if other_proms.shape[0] == smallest_size:
-                noisy_proms = merge_audio( resps[:smallest_size, :], other_proms )
-                noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
-            else:
-                noisy_proms = merge_audio( resps, other_proms[:smallest_size, :] )
-                noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
-
-            # stitch together the promps
-            proms = torch.cat( [clean_proms, train_dataset.get_task_token(task), noisy_proms] )
-
-            # set the text prompt to empty to train without a guided text prompt
-            if random.random() < 0.5:
-                text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
-
-        decode_to_file( proms, "./.proms.wav", device="cpu" )
-        decode_to_file( resps, "./.resps.wav", device="cpu" )
-
-        if noise is not None:
-            decode_to_file( noise, "./.noise-fill.wav", device="cpu" )
-    """
+
+        train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
+        batch = next(iter(train_dl))
+
+        for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):
+            if task not in cfg.dataset.tasks_list:
+                continue
+
+            print(text, task)
+            decode_to_file( proms, f"./.{task}.proms.wav", device="cpu" )
+            decode_to_file( resps, f"./.{task}.resps.wav", device="cpu" )
+            break
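As a usage note for the new `tasks` action (the command is the one from the commit message, and the output filenames follow the `f"./.{task}.proms.wav"` / `f"./.{task}.resps.wav"` calls above; the checking script itself is illustrative): because the loop breaks after the first matching sample, a single run decodes at most one task from the batch.

```python
from pathlib import Path

# after running: python3 -m vall_e.data --action=tasks --tasks='sr,se,cse,nse'
for task in "sr,se,cse,nse".split(","):
    for kind in ("proms", "resps"):
        wav = Path(f"./.{task}.{kind}.wav")
        status = "written" if wav.exists() else "not written (task not sampled in this run)"
        print(wav, status)
```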