final tweaks (again) before training restarts

This commit is contained in:
mrq 2024-05-08 02:11:38 -05:00
parent 215800484d
commit c6e0f905b5
3 changed files with 40 additions and 14 deletions

View File

@ -64,6 +64,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
waveform, sample_rate = None, None waveform, sample_rate = None, None
language = metadata[filename]["language"] if "language" in metadata[filename] else "english" language = metadata[filename]["language"] if "language" in metadata[filename] else "english"
if f'{dataset_name}/{speaker_id}' not in dataset:
dataset.append(f'{dataset_name}/{speaker_id}') dataset.append(f'{dataset_name}/{speaker_id}')
if len(metadata[filename]["segments"]) == 0 or not use_slices: if len(metadata[filename]["segments"]) == 0 or not use_slices:

View File

@ -763,7 +763,13 @@ def create_dataset_metadata( skip_existing=True ):
name = str(dir) name = str(dir)
name = name.replace(root, "") name = name.replace(root, "")
metadata_path = Path(f"{metadata_root}/{name}.json") # yucky
speaker_name = name
if "LbriTTS-R" in speaker_name:
speaker_name = speaker_name.replace("LbriTTS-R", "LibriVox")
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read()) metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
@ -783,9 +789,9 @@ def create_dataset_metadata( skip_existing=True ):
if not audio_exists or not text_exists: if not audio_exists or not text_exists:
continue continue
key = f'{type}/{name}/{id}' key = f'{type}/{speaker_name}/{id}'
if skip_existing and key in metadata: if skip_existing and id in metadata:
continue continue
if id not in metadata: if id not in metadata:
@ -816,15 +822,18 @@ def create_dataset_metadata( skip_existing=True ):
json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read()) json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
content = json_metadata["phonemes"] content = json_metadata["phonemes"]
txt = json_metadata["text"] txt = json_metadata["text"]
lang = json_metadata["language"][:2]
else: else:
content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ") content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ")
txt = "" txt = ""
lang = "en"
phn = cfg.tokenizer.encode("".join(content)) phn = cfg.tokenizer.encode("".join(content))
phn = np.array(phn).astype(np.uint8) phn = np.array(phn).astype(np.uint8)
metadata[id]["phones"] = len(phn) metadata[id]["phones"] = len(phn)
metadata[id]["transcription"] = txt metadata[id]["transcription"] = txt
metadata[id]["language"] = lang
except Exception as e: except Exception as e:
#raise e #raise e
print(id, e) print(id, e)
@ -849,20 +858,25 @@ def create_dataset_metadata( skip_existing=True ):
def create_dataset_hdf5( skip_existing=True ): def create_dataset_hdf5( skip_existing=True ):
cfg.dataset.use_hdf5 = True cfg.dataset.use_hdf5 = True
cfg.load_hdf5(write=True) cfg.load_hdf5(write=True)
hf = cfg.hdf5
symmap = get_phone_symmap() symmap = get_phone_symmap()
root = str(cfg.data_dir) root = str(cfg.data_dir)
metadata_root = str(cfg.metadata_dir) metadata_root = str(cfg.metadata_dir)
hf = cfg.hdf5
cfg.metadata_dir.mkdir(parents=True, exist_ok=True)
def add( dir, type="training", audios=True, texts=True ): def add( dir, type="training", audios=True, texts=True ):
name = str(dir) name = str(dir)
name = name.replace(root, "") name = name.replace(root, "")
metadata_path = Path(f"{metadata_root}/{name}.json") # yucky
speaker_name = name
if "LbriTTS-R" in speaker_name:
speaker_name = speaker_name.replace("LbriTTS-R", "LibriVox")
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read()) metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
@ -882,7 +896,8 @@ def create_dataset_hdf5( skip_existing=True ):
if not audio_exists or not text_exists: if not audio_exists or not text_exists:
continue continue
key = f'{type}/{name}/{id}'
key = f'{type}/{speaker_name}/{id}'
""" """
if skip_existing and key in hf: if skip_existing and key in hf:
@ -893,7 +908,7 @@ def create_dataset_hdf5( skip_existing=True ):
group.attrs['id'] = id group.attrs['id'] = id
group.attrs['type'] = type group.attrs['type'] = type
group.attrs['speaker'] = name group.attrs['speaker'] = speaker_name
if id not in metadata: if id not in metadata:
metadata[id] = {} metadata[id] = {}
@ -930,9 +945,11 @@ def create_dataset_hdf5( skip_existing=True ):
json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read()) json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
content = json_metadata["phonemes"] content = json_metadata["phonemes"]
txt = json_metadata["text"] txt = json_metadata["text"]
lang = json_metadata["language"][:2]
else: else:
content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ") content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ")
txt = "" txt = ""
lang = "en"
phn = cfg.tokenizer.encode("".join(content)) phn = cfg.tokenizer.encode("".join(content))
phn = np.array(phn).astype(np.uint8) phn = np.array(phn).astype(np.uint8)
@ -942,9 +959,11 @@ def create_dataset_hdf5( skip_existing=True ):
group.attrs['phonemes'] = len(phn) group.attrs['phonemes'] = len(phn)
group.attrs['transcription'] = txt group.attrs['transcription'] = txt
group.attrs['language'] = lang
metadata[id]["phones"] = len(phn) metadata[id]["phones"] = len(phn)
metadata[id]["transcription"] = txt metadata[id]["transcription"] = txt
metadata[id]["language"] = lang
else: else:
group.attrs['phonemes'] = 0 group.attrs['phonemes'] = 0
metadata[id]["phones"] = 0 metadata[id]["phones"] = 0
@ -958,15 +977,15 @@ def create_dataset_hdf5( skip_existing=True ):
# training # training
for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"): for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
add( data_dir, type="training" ) add( data_dir, type="training" )
# validation # validation
for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'): for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
add( data_dir, type="validation" ) add( data_dir, type="validation" )
# noise # noise
for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'): for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
add( data_dir, type="noise", texts=False ) add( data_dir, type="noise", texts=False )
# write symmap # write symmap

View File

@ -272,11 +272,17 @@ def _replace_file_extension(path, suffix):
@torch.inference_mode() @torch.inference_mode()
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True): def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
if cfg.inference.audio_backend == "dac": if cfg.inference.audio_backend == "dac":
model = _load_dac_model(device, levels=levels) model = _load_dac_model(device, levels=levels )
signal = AudioSignal(wav, sample_rate=sr) signal = AudioSignal(wav, sample_rate=sr)
artifact = model.compress(signal, 5.0, verbose=False, n_quantizers=levels if isinstance(levels, int) else None)
if not isinstance(levels, int):
levels = 8 if model.model_type == "24khz" else None
with torch.autocast("cuda", dtype=torch.bfloat16, enabled=False): # or True for about 2x speed, not enabling by default for systems that do not have bfloat16
artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
# trim to 8 codebooks if 24Khz # trim to 8 codebooks if 24Khz
# probably redundant with levels, should rewrite logic eventually
if model.model_type == "24khz": if model.model_type == "24khz":
artifact.codes = artifact.codes[:, :8, :] artifact.codes = artifact.codes[:, :8, :]