final tweaks, hopefully

This commit is contained in:
mrq 2024-04-28 22:28:29 -05:00
parent ffc334cf58
commit caad7ee3c9
6 changed files with 220 additions and 81 deletions

View File

@ -0,0 +1,99 @@
import os
import json
import torch
import torchaudio
from tqdm.auto import tqdm
from pathlib import Path
input_dataset = "metadata"
output_dataset = "metadata-cleaned"
def pad(num, zeroes):
return str(num).zfill(zeroes+1)
for dataset_name in os.listdir(f'./{input_dataset}/'):
if not os.path.isdir(f'./{input_dataset}/{dataset_name}/'):
print("Is not dir:", f'./{input_dataset}/{dataset_name}/')
continue
for speaker_id in tqdm(os.listdir(f'./{input_dataset}/{dataset_name}/'), desc=f"Processing speaker: {dataset_name}"):
if not os.path.isdir(f'./{input_dataset}/{dataset_name}/{speaker_id}'):
print("Is not dir:", f'./{input_dataset}/{dataset_name}/{speaker_id}')
continue
inpath = Path(f'./{input_dataset}/{dataset_name}/{speaker_id}/whisper.json')
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
if not inpath.exists():
continue
if outpath.exists():
continue
os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
try:
in_metadata = json.loads(open(inpath, 'r', encoding='utf-8').read())
except Exception as e:
print("Failed to open metadata file:", inpath)
continue
out_metadata = {}
speaker_metadatas = {}
for filename, result in in_metadata.items():
language = result["language"] if "language" in result else "en"
out_metadata[filename] = {
"segments": [],
"language": language,
"text": "",
"start": 0,
"end": 0,
}
segments = []
text = []
start = 0
end = 0
diarized = False
for segment in result["segments"]:
# diarize split
if "speaker" in segment:
diarized = True
speaker_id = segment["speaker"]
if speaker_id not in speaker_metadatas:
speaker_metadatas[speaker_id] = {}
if filename not in speaker_metadatas[speaker_id]:
speaker_metadatas[speaker_id][filename] = {
"segments": [],
"language": language,
"text": "",
"start": 0,
"end": 0,
}
speaker_metadatas[speaker_id][filename]["segments"].append( segment )
else:
segments.append( segment )
text.append( segment["text"] )
start = min( start, segment["start"] )
end = max( end, segment["end"] )
out_metadata[filename]["segments"] = segments
out_metadata[filename]["text"] = " ".join(text).strip()
out_metadata[filename]["start"] = start
out_metadata[filename]["end"] = end
if len(segments) == 0:
del out_metadata[filename]
open(outpath, 'w', encoding='utf-8').write(json.dumps(out_metadata))
for speaker_id, out_metadata in speaker_metadatas.items():
os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
open(outpath, 'w', encoding='utf-8').write(json.dumps(out_metadata))

View File

@ -8,26 +8,27 @@ from pathlib import Path
from vall_e.emb.g2p import encode as valle_phonemize from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
input_audio = "voice" # things that could be args
input_audio = "voices"
input_metadata = "metadata" input_metadata = "metadata"
output_dataset = "training-24K" output_dataset = "training-24K"
device = "cuda"
slice = "auto"
missing = { missing = {
"transcription": [], "transcription": [],
"audio": [] "audio": []
} }
device = "cuda"
def pad(num, zeroes): def pad(num, zeroes):
return str(num).zfill(zeroes+1) return str(num).zfill(zeroes+1)
for dataset_name in os.listdir(f'./{input_audio}/'): for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'): if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
print("Is not dir:", f'./{input_audio}/{dataset_name}/') print("Is not dir:", f'./{input_audio}/{dataset_name}/')
continue continue
for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"): for speaker_id in tqdm(sorted(os.listdir(f'./{input_audio}/{dataset_name}/')), desc=f"Processing speaker in {dataset_name}"):
if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'): if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
print("Is not dir:", f'./{input_audio}/{dataset_name}/{speaker_id}') print("Is not dir:", f'./{input_audio}/{dataset_name}/{speaker_id}')
continue continue
@ -36,24 +37,23 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/whisper.json') metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/whisper.json')
if not metadata_path.exists(): if not metadata_path.exists():
#print("Does not exist:", metadata_path)
missing["transcription"].append(str(metadata_path)) missing["transcription"].append(str(metadata_path))
continue continue
try: try:
metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read()) metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
except Exception as e: except Exception as e:
#print("Failed to load metadata:", metadata_path, e)
missing["transcription"].append(str(metadata_path)) missing["transcription"].append(str(metadata_path))
continue continue
txts = [] txts = []
wavs = [] wavs = []
for filename in metadata.keys(): use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or dataset_name in ["LibriVox", "Audiobooks"]
for filename in sorted(metadata.keys()):
inpath = Path(f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}') inpath = Path(f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}')
if not inpath.exists(): if not inpath.exists():
#print("Does not exist:", inpath)
missing["audio"].append(str(inpath)) missing["audio"].append(str(inpath))
continue continue
@ -63,9 +63,8 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
waveform, sample_rate = None, None waveform, sample_rate = None, None
language = metadata[filename]["language"] if "language" in metadata[filename] else "english" language = metadata[filename]["language"] if "language" in metadata[filename] else "english"
if len(metadata[filename]["segments"]) == 0: if len(metadata[filename]["segments"]) == 0 or not use_slices:
id = pad(0, 4) outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}.{extension}')
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
text = metadata[filename]["text"] text = metadata[filename]["text"]
if len(text) == 0: if len(text) == 0:
@ -91,8 +90,10 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
sample_rate sample_rate
)) ))
else: else:
i = 0
for segment in metadata[filename]["segments"]: for segment in metadata[filename]["segments"]:
id = pad(segment['id'], 4) id = pad(i, 4)
i = i + 1
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}') outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists(): if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():

View File

@ -7,30 +7,35 @@ import whisperx
from tqdm.auto import tqdm from tqdm.auto import tqdm
from pathlib import Path from pathlib import Path
device = "cuda" # should be args
batch_size = 16 batch_size = 16
device = "cuda"
dtype = "float16" dtype = "float16"
model_size = "large-v2" model_name = "large-v3"
input_audio = "voice" input_audio = "voices"
output_dataset = "metadata" output_dataset = "metadata"
skip_existing = True skip_existing = True
diarize = False
model = whisperx.load_model(model_size, device, compute_type=dtype) #
model = whisperx.load_model(model_name, device, compute_type=dtype)
align_model, align_model_metadata, align_model_language = (None, None, None) align_model, align_model_metadata, align_model_language = (None, None, None)
if diarize:
diarize_model = whisperx.DiarizationPipeline(device=device)
else:
diarize_model = None
def pad(num, zeroes): def pad(num, zeroes):
return str(num).zfill(zeroes+1) return str(num).zfill(zeroes+1)
for dataset_name in os.listdir(f'./{input_audio}/'): for dataset_name in os.listdir(f'./{input_audio}/'):
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'): if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
print("Is not dir:", f'./{input_audio}/{dataset_name}/')
continue continue
for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"): for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"):
if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'): if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
print("Is not dir:", f'./{input_audio}/{dataset_name}/{speaker_id}')
continue continue
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json') outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
@ -46,18 +51,29 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
if skip_existing and filename in metadata: if skip_existing and filename in metadata:
continue continue
if ".json" in filename:
continue
inpath = f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}' inpath = f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}'
if os.path.isdir(inpath):
continue
metadata[filename] = { metadata[filename] = {
"segments": [], "segments": [],
"language": "", "language": "",
"text": [], "text": "",
"start": 0,
"end": 0,
} }
audio = whisperx.load_audio(inpath) audio = whisperx.load_audio(inpath)
result = model.transcribe(audio, batch_size=batch_size) result = model.transcribe(audio, batch_size=batch_size)
language = result["language"] language = result["language"]
if language[:2] not in ["ja"]:
language = "en"
if align_model_language != language: if align_model_language != language:
tqdm.write(f'Loading language: {language}') tqdm.write(f'Loading language: {language}')
align_model, align_model_metadata = whisperx.load_align_model(language_code=language, device=device) align_model, align_model_metadata = whisperx.load_align_model(language_code=language, device=device)
@ -68,12 +84,20 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
metadata[filename]["segments"] = result["segments"] metadata[filename]["segments"] = result["segments"]
metadata[filename]["language"] = language metadata[filename]["language"] = language
if diarize_model is not None:
diarize_segments = diarize_model(audio)
result = whisperx.assign_word_speakers(diarize_segments, result)
text = [] text = []
start = 0
end = 0
for segment in result["segments"]: for segment in result["segments"]:
id = len(text)
text.append( segment["text"] ) text.append( segment["text"] )
metadata[filename]["segments"][id]["id"] = id start = min( start, segment["start"] )
end = max( end, segment["end"] )
metadata[filename]["text"] = " ".join(text).strip() metadata[filename]["text"] = " ".join(text).strip()
metadata[filename]["start"] = start
metadata[filename]["end"] = end
open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata)) open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata))

View File

@ -33,6 +33,14 @@ class _Config:
def cache_dir(self): def cache_dir(self):
return self.relpath / ".cache" return self.relpath / ".cache"
@property
def data_dir(self):
return self.relpath / "data"
@property
def metadata_dir(self):
return self.relpath / "metadata"
@property @property
def ckpt_dir(self): def ckpt_dir(self):
return self.relpath / "ckpt" return self.relpath / "ckpt"

View File

@ -85,46 +85,45 @@ def _calculate_durations( type="training" ):
@cfg.diskcache() @cfg.diskcache()
def _load_paths(dataset, type="training"): def _load_paths(dataset, type="training"):
return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") } return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
def _load_paths_from_metadata(dataset_name, type="training", validate=False):
data_dir = dataset_name if cfg.dataset.use_hdf5 else cfg.data_dir / dataset_name
def _load_paths_from_metadata(data_dir, type="training", validate=False):
_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions _fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
def _validate( entry ): def _validate( entry ):
if "phones" not in entry or "duration" not in entry: phones = entry['phones'] if "phones" in entry else 0
return False duration = entry['duration'] if "duration" in entry else 0
phones = entry['phones']
duration = entry['duration']
if type not in _total_durations: if type not in _total_durations:
_total_durations[type] = 0 _total_durations[type] = 0
_total_durations[type] += entry['duration'] _total_durations[type] += duration
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
metadata_path = data_dir / "metadata.json" metadata_path = cfg.metadata_dir / f'{dataset_name}.json'
metadata = {} metadata = {}
if cfg.dataset.use_metadata and metadata_path.exists(): if cfg.dataset.use_metadata and metadata_path.exists():
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
if len(metadata) == 0: if len(metadata) == 0:
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate ) return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
def key( dir, id ): def key( dir, id ):
if not cfg.dataset.use_hdf5: if not cfg.dataset.use_hdf5:
return data_dir / id return data_dir / id
return f"/{type}{_get_hdf5_path(data_dir)}/{id}" return f"/{type}/{_get_hdf5_path(data_dir)}/{id}"
return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ] return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ]
def _get_hdf5_path(path): def _get_hdf5_path(path):
path = str(path) # to-do: better validation
if path[:2] != "./": #print(path)
path = f'./{path}' return str(path)
res = path.replace(cfg.cfg_path, "")
return res
def _get_hdf5_paths( data_dir, type="training", validate=False ): def _get_hdf5_paths( data_dir, type="training", validate=False ):
data_dir = str(data_dir) data_dir = str(data_dir)
@ -137,7 +136,7 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
_total_durations[type] += child.attrs['duration'] _total_durations[type] += child.attrs['duration']
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
key = f"/{type}{_get_hdf5_path(data_dir)}" key = f"/{type}/{_get_hdf5_path(data_dir)}"
return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else [] return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ): def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
@ -427,6 +426,9 @@ class Dataset(_Dataset):
if cfg.dataset.use_hdf5: if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path) key = _get_hdf5_path(path)
if key not in cfg.hdf5:
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
text = cfg.hdf5[key]["text"][:] text = cfg.hdf5[key]["text"][:]
resps = cfg.hdf5[key]["audio"][:, :] resps = cfg.hdf5[key]["audio"][:, :]
@ -752,6 +754,10 @@ def create_train_val_dataloader():
# parse dataset into better to sample metadata # parse dataset into better to sample metadata
def create_dataset_metadata(): def create_dataset_metadata():
# need to fix
if True:
return
cfg.dataset.validate = False cfg.dataset.validate = False
cfg.dataset.use_hdf5 = False cfg.dataset.use_hdf5 = False
@ -805,14 +811,19 @@ def create_dataset_hdf5( skip_existing=True ):
symmap = get_phone_symmap() symmap = get_phone_symmap()
root = cfg.cfg_path root = str(cfg.data_dir)
metadata_root = str(cfg.metadata_dir)
hf = cfg.hdf5 hf = cfg.hdf5
cfg.metadata_dir.mkdir(parents=True, exist_ok=True)
def add( dir, type="training", audios=True, texts=True ): def add( dir, type="training", audios=True, texts=True ):
name = "./" + str(dir) name = str(dir)
name = name .replace(root, "") name = name.replace(root, "")
metadata = {}
metadata_path = Path(f"{metadata_root}/{name}.json")
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
if not os.path.isdir(f'{root}/{name}/'): if not os.path.isdir(f'{root}/{name}/'):
return return
@ -831,36 +842,38 @@ def create_dataset_hdf5( skip_existing=True ):
continue continue
key = f'{type}/{name}/{id}' key = f'{type}/{name}/{id}'
if key in hf:
if skip_existing:
continue
del hf[key]
group = hf.create_group(key) if skip_existing and key in hf:
continue
group = hf.create_group(key) if key not in hf else hf[key]
group.attrs['id'] = id group.attrs['id'] = id
group.attrs['type'] = type group.attrs['type'] = type
group.attrs['speaker'] = name group.attrs['speaker'] = name
if id not in metadata:
metadata[id] = {} metadata[id] = {}
# audio # audio
if audios: if audios:
qnt = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
codes = torch.from_numpy(qnt["codes"].astype(int))[0].t().to(dtype=torch.int16)
if _get_quant_extension() == ".dac": if _get_quant_extension() == ".dac":
if "audio" in group: dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
del group["audio"] qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
duration = qnt["metadata"]["original_length"] / qnt["metadata"]["sample_rate"]
duration = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
metadata[id]["metadata"] = { metadata[id]["metadata"] = {
"original_length": qnt["metadata"]["original_length"], "original_length": dac["metadata"]["original_length"],
"sample_rate": qnt["metadata"]["sample_rate"], "sample_rate": dac["metadata"]["sample_rate"],
} }
else: else:
qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t() qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
duration = qnt.shape[0] / 75 duration = qnt.shape[0] / 75
group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf') qnt = qnt.numpy().astype(np.int16)
if "audio" not in group:
group.create_dataset('audio', data=qnt, compression='lzf')
group.attrs['duration'] = duration group.attrs['duration'] = duration
metadata[id]["duration"] = duration metadata[id]["duration"] = duration
@ -870,52 +883,46 @@ def create_dataset_hdf5( skip_existing=True ):
# text # text
if texts: if texts:
if _get_quant_extension() == ".json": if _get_phone_extension() == ".json":
json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read()) json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
content = json_metadata["phonemes"] content = json_metadata["phonemes"]
txt = json_metadata["text"]
else: else:
content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ") content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ")
txt = ""
"""
phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
for s in set(phones):
if s not in symmap:
symmap[s] = len(symmap.keys())
phn = [ symmap[s] for s in phones ]
"""
phn = cfg.tokenizer.encode("".join(content)) phn = cfg.tokenizer.encode("".join(content))
phn = np.array(phn).astype(np.uint8) phn = np.array(phn).astype(np.uint8)
if "text" in group: if "text" not in group:
del group["text"] group.create_dataset('text', data=phn, compression='lzf')
group.create_dataset('text', data=phn, compression='lzf', chunks=True)
group.create_dataset('transcription', data=txt, compression='lzf', chunks=True)
group.attrs['phonemes'] = len(phn) group.attrs['phonemes'] = len(phn)
group.attrs['transcription'] = txt
metadata[id]["phones"] = len(phn) metadata[id]["phones"] = len(phn)
metadata[id]["transcription"] = txt
else: else:
group.attrs['phonemes'] = 0 group.attrs['phonemes'] = 0
metadata[id]["phones"] = 0 metadata[id]["phones"] = 0
except Exception as e: except Exception as e:
pass raise e
#pass
with open(dir / "metadata.json", "w", encoding="utf-8") as f: with open(str(metadata_path), "w", encoding="utf-8") as f:
f.write( json.dumps( metadata ) ) f.write( json.dumps( metadata ) )
# training # training
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"): for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
add( data_dir, type="training" ) add( data_dir, type="training" )
# validation # validation
for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'): for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
add( data_dir, type="validation" ) add( data_dir, type="validation" )
# noise # noise
for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'): for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
add( data_dir, type="noise", texts=False ) add( data_dir, type="noise", texts=False )
# write symmap # write symmap

View File

@ -340,7 +340,7 @@ def example_usage():
def _load_quants(path) -> Tensor: def _load_quants(path) -> Tensor:
if cfg.inference.audio_backend == "dac": if cfg.inference.audio_backend == "dac":
qnt = np.load(f'{path}.dac', allow_pickle=True)[()] qnt = np.load(f'{path}.dac', allow_pickle=True)[()]
return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16) return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
return torch.load(f'{path}.pt')[0][:, :cfg.model.prom_levels].t().to(torch.int16) return torch.load(f'{path}.pt')[0][:, :cfg.model.prom_levels].t().to(torch.int16)
qnt = _load_quants("./data/qnt") qnt = _load_quants("./data/qnt")
@ -350,7 +350,7 @@ def example_usage():
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device), tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
] ]
proms_list = [ proms_list = [
qnt[:75*3, :].to(device), qnt.to(device),
] ]
resps_list = [ resps_list = [
qnt.to(device), qnt.to(device),
@ -407,7 +407,7 @@ def example_usage():
frozen_params = set() frozen_params = set()
for k in list(embeddings.keys()): for k in list(embeddings.keys()):
if re.findall(r'_emb\.', k): if re.findall(r'_emb.', k):
frozen_params.add(k) frozen_params.add(k)
else: else:
del embeddings[k] del embeddings[k]