unified more things with training the AR+NAR monolithic model
This commit is contained in:
parent 40ef34e1ca
commit d07c63b9d8
@@ -131,6 +131,7 @@ class Dataset:
 	phones_range: list[int] = field(default_factory=lambda: [4, 256])
 	duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0])
+	min_utterances: int = 0
 
 	random_utterance: float = 1.0
 	max_prompts: int = 3
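Note: the `field(default_factory=...)` dance in this dataclass is mandatory for mutable defaults; Python's dataclasses reject a bare list default outright. A minimal standalone illustration (not project code):

    from dataclasses import dataclass, field

    @dataclass
    class Example:
        # each instance gets its own fresh list this way
        phones_range: list[int] = field(default_factory=lambda: [4, 256])
        # writing `phones_range: list[int] = [4, 256]` instead raises
        # ValueError: mutable default <class 'list'> for field phones_range is not allowed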
vall_e/data.py
@@ -59,24 +59,30 @@ def _get_quant_path(path):
 def _get_phone_path(path):
 	return _replace_file_extension(path, ".phn.txt")
 
+_total_durations = {}
+
+@cfg.diskcache()
+def _calculate_durations( type="training" ):
+	if type in _total_durations:
+		return _total_durations[type]
+	return 0
+
 @cfg.diskcache()
 def _load_paths(dataset, type="training"):
 	return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
 
-"""
-def _load_paths_from_hdf5(dataset, type="training"):
-	return { cfg.get_spkr( data_dir / "dummy" ): _get_hdf5_paths( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
-
-def _load_paths_from_disk(dataset, type="training"):
-	return { cfg.get_spkr( data_dir / "dummy" ): _get_paths_of_extensions( data_dir, ".qnt.pt", validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
-"""
-
 def _load_paths_from_metadata(data_dir, type="training", validate=False):
 	_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
 
 	def _validate( entry ):
+		if "phones" not in entry or "duration" not in entry:
+			return False
 		phones = entry['phones']
 		duration = entry['duration']
+		if type not in _total_durations:
+			_total_durations[type] = 0
+		_total_durations[type] += entry['duration']
+
 		return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
 
 	metadata_path = data_dir / "metadata.json"
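Note: `cfg.diskcache()` is project-specific and its implementation is not part of this diff; judging by its use on `_load_paths` and `_calculate_durations`, it memoizes expensive results to disk so the dataset walk only happens once per configuration. A rough sketch of what such a decorator typically looks like (names and cache layout are assumptions, not the project's code):

    import hashlib, pickle
    from functools import wraps
    from pathlib import Path

    def diskcache(cache_dir=".cache"):
        def decorator(fn):
            @wraps(fn)
            def wrapper(*args, **kwargs):
                # key the cache entry on the function name and its arguments
                key = hashlib.sha256(pickle.dumps((fn.__name__, args, kwargs))).hexdigest()
                path = Path(cache_dir) / f"{key}.pkl"
                if path.exists():
                    return pickle.loads(path.read_bytes())
                result = fn(*args, **kwargs)
                path.parent.mkdir(parents=True, exist_ok=True)
                path.write_bytes(pickle.dumps(result))
                return result
            return wrapper
        return decorator

One consequence: `_calculate_durations` only reads totals that the validation pass accumulates as a side effect, so the disk cache is presumably what lets later runs recover the totals without re-validating.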
@@ -107,6 +113,9 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
 	def _validate(child):
 		phones = child.attrs['phonemes']
 		duration = child.attrs['duration']
+		if type not in _total_durations:
+			_total_durations[type] = 0
+		_total_durations[type] += duration
 		return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
 
 	key = f"/{type}{_get_hdf5_path(data_dir)}"
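Note: in the HDF5 branch, validation reads per-utterance `phonemes`/`duration` attributes instead of opening files. A self-contained sketch of that pattern with h5py (in-memory file and bounds are made up for illustration):

    import h5py

    with h5py.File("example.h5", "w", driver="core", backing_store=False) as hf:
        for i, (phones, duration) in enumerate([(10, 2.0), (2, 0.5), (300, 20.0)]):
            g = hf.create_group(f"training/speaker/{i}")
            g.attrs['phonemes'] = phones
            g.attrs['duration'] = duration

        def _validate(child):
            # mirrors the bounds check above, with illustrative limits
            return 1.0 <= child.attrs['duration'] <= 12.0 and 4 <= child.attrs['phonemes'] <= 256

        keys = [k for k in hf["training/speaker"] if _validate(hf["training/speaker"][k])]
        print(keys)  # ['0'] — the other two fall outside the bounds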
@@ -172,6 +181,14 @@ class Dataset(_Dataset):
 			self.dataset = cfg.dataset.training
 
 		self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
 
+		# cull speakers if they do not have enough utterances
+		if cfg.dataset.min_utterances > 0:
+			keys = list(self.paths_by_spkr_name.keys())
+			for key in keys:
+				if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
+					del self.paths_by_spkr_name[key]
+
 		self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
 
 		self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
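Note: the culling loop has to snapshot `keys` first, because deleting from a dict while iterating it directly raises RuntimeError. An equivalent formulation as a comprehension, shown only to illustrate the semantics:

    paths_by_spkr_name = { "a": ["u1", "u2", "u3"], "b": ["u1"] }
    min_utterances = 2

    # keep only speakers with at least min_utterances utterances
    paths_by_spkr_name = {
        name: paths
        for name, paths in paths_by_spkr_name.items()
        if len(paths) >= min_utterances
    }
    print(paths_by_spkr_name)  # {'a': ['u1', 'u2', 'u3']}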
@@ -192,13 +209,8 @@ class Dataset(_Dataset):
 		if len(self.paths) == 0 and training:
 			raise ValueError("No valid path is found for training.")
 
-		# would be a better cost saving if we could fetch the duration during the validation pass but oh well
-		self.duration = 0
-		"""
-		if cfg.dataset.use_hdf5:
-			for path in tqdm(self.paths, desc="Calculating duration"):
-				self.duration += cfg.hdf5[_get_hdf5_path(path)].attrs['duration']
-		"""
+		#self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
+		self.duration = _calculate_durations(self.dataset_type)
 
 	@cached_property
 	def phones(self):
@@ -663,57 +675,59 @@ def create_dataset_hdf5( skip_existing=True ):
 			# grab IDs for every file
 			ids = { ".".join(file.split(".")[:-2]) for file in files }
 			for id in tqdm(ids, desc=f"Processing {name}"):
-				audio_exists = os.path.exists(f'{root}/{name}/{id}.qnt.pt') if audios else True
-				text_exists = os.path.exists(f'{root}/{name}/{id}.phn.txt') if texts else True
-
-				if not audio_exists or not text_exists:
-					continue
-
-				key = f'{type}/{name}/{id}'
-				if key in hf:
-					if skip_existing:
-						continue
-					del hf[key]
-
-				group = hf.create_group(key)
-				group.attrs['id'] = id
-				group.attrs['type'] = type
-				group.attrs['speaker'] = name
-
-				metadata[id] = {}
-
-				# audio
-				if audios:
-					qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
-
-					if "audio" in group:
-						del group["audio"]
-					group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
-					group.attrs['duration'] = qnt.shape[0] / 75
-					metadata[id]["duration"] = qnt.shape[0] / 75
-				else:
-					group.attrs['duration'] = 0
-					metadata[id]["duration"] = 0
-
-				# text
-				if texts:
-					with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8") as f:
-						content = f.read().split(" ")
-					phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
-					for s in set(phones):
-						if s not in symmap:
-							symmap[s] = len(symmap.keys())
-
-					phn = [ symmap[s] for s in phones ]
-
-					if "text" in group:
-						del group["text"]
-					group.create_dataset('text', data=phn, compression='lzf', chunks=True)
-					group.attrs['phonemes'] = len(phn)
-					metadata[id]["phones"] = len(phn)
-				else:
-					group.attrs['phonemes'] = 0
-					metadata[id]["phones"] = 0
+				try:
+					audio_exists = os.path.exists(f'{root}/{name}/{id}.qnt.pt') if audios else True
+					text_exists = os.path.exists(f'{root}/{name}/{id}.phn.txt') if texts else True
+
+					if not audio_exists or not text_exists:
+						continue
+
+					key = f'{type}/{name}/{id}'
+					if key in hf:
+						if skip_existing:
+							continue
+						del hf[key]
+
+					group = hf.create_group(key)
+					group.attrs['id'] = id
+					group.attrs['type'] = type
+					group.attrs['speaker'] = name
+
+					metadata[id] = {}
+
+					# audio
+					if audios:
+						qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
+
+						if "audio" in group:
+							del group["audio"]
+						group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
+						group.attrs['duration'] = qnt.shape[0] / 75
+						metadata[id]["duration"] = qnt.shape[0] / 75
+					else:
+						group.attrs['duration'] = 0
+						metadata[id]["duration"] = 0
+
+					# text
+					if texts:
+						content = open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8").read().split(" ")
+						phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
+						for s in set(phones):
+							if s not in symmap:
+								symmap[s] = len(symmap.keys())
+
+						phn = [ symmap[s] for s in phones ]
+
+						if "text" in group:
+							del group["text"]
+						group.create_dataset('text', data=phn, compression='lzf', chunks=True)
+						group.attrs['phonemes'] = len(phn)
+						metadata[id]["phones"] = len(phn)
+					else:
+						group.attrs['phonemes'] = 0
+						metadata[id]["phones"] = 0
+				except Exception as e:
+					pass
 
 		with open(dir / "metadata.json", "w", encoding="utf-8") as f:
 			f.write( json.dumps( metadata ) )
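Note: the layout this routine produces is one HDF5 group per utterance at `{type}/{name}/{id}`, holding an `audio` dataset of RVQ codes (75 frames per second, hence the `/ 75` duration), a `text` dataset of symbol ids, and `duration`/`phonemes` attributes, with a `metadata.json` mirror beside the data. A read-back sketch (file names are placeholders; the variable names mirror the diff's own style):

    import json
    import h5py

    with h5py.File("dataset.h5", "r") as hf:
        for type in hf:                    # e.g. "training"
            for name in hf[type]:          # speaker
                for id in hf[type][name]:  # utterance
                    group = hf[type][name][id]
                    qnt = group['audio'][:]  # (frames, levels) RVQ codes
                    phn = group['text'][:]   # phoneme symbol ids
                    print(id, group.attrs['duration'], group.attrs['phonemes'])

    # the JSON mirror of the per-utterance duration/phones written above
    with open("metadata.json", "r", encoding="utf-8") as f:
        metadata = json.load(f)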
@@ -94,19 +94,10 @@ class AR_NAR(Base):
 
 		# is training
 		if n_levels == self.n_resp_levels:
-			if random.random() < cfg.models.ar_nar.p_ar_nar:
-				quant_levels = None
-
-				targ_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
-				resps_list = self._unsqueeze_list(targ_list)
-			else:
-				quant_levels = torch.randint(1, self.n_resp_levels, (batch_size,))
-
-				targ_list = [o[..., l] for o, l in zip(resps_list, quant_levels)]
-				resps_list = [o[..., : l] for o, l in zip(resps_list, quant_levels)]
-
-			if quant_levels is not None:
-				quant_levels.to(device=device)
+			quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,))
+			targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)]
+			resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)]
+			quant_levels.to(device=device)
 
 		return super().forward(
 			text_list=text_list,
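Note: this is the heart of the unification. Instead of a coin flip (`p_ar_nar`) choosing between an AR pass (level 0 only) and an NAR pass (levels 1+) for the whole batch, every sample now draws a quantizer level from [0, n_resp_levels); level 0 is the AR task, and higher levels are NAR tasks conditioned on the codebooks below them. Note also that `Tensor.to` is not in-place, so the bare `quant_levels.to(device=device)` call has no effect unless its result is reassigned. A toy sketch of the indexing (shapes only, not the model code):

    import torch

    n_resp_levels, batch_size = 4, 3
    # per-utterance RVQ codes: (timesteps, levels)
    resps_list = [torch.randint(0, 1024, (t, n_resp_levels)) for t in (5, 7, 6)]

    quant_levels = torch.randint(0, n_resp_levels, (batch_size,))
    # target: the codebook at the sampled level
    targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)]
    # input: the codebooks below the sampled level (full sequence for the AR case, per the diff)
    resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)]

    for r, t, l in zip(resps_list, targ_list, quant_levels):
        print(int(l), tuple(r.shape), tuple(t.shape))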
@@ -246,8 +237,6 @@ def example_usage():
 	engine = Engine(model=model, optimizer=optimizer)
 
 	print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
-
-	print([ name for name, _ in model.named_parameters()])
 
 	@torch.inference_mode()
 	def sample( name, steps=600 ):
@@ -392,7 +392,6 @@ class Base(nn.Module):
 		# compute loss if the target is given
 		if targ_list is not None:
 			ignore_sep = torch.tensor(self.ignore_index, device=device)
-
 			# create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not needed for computing the loss against
 			prom_list = [ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ]
 			# remake input sequence
@@ -401,23 +400,16 @@ class Base(nn.Module):
 			# process each batch
 			for i in range(len(text_prom_list)):
 				# for the AR, shift the text/input prompt into the future by 1, and ignore the rolled back text token
-				if quant_levels is None:
+				if quant_levels is None or quant_levels[i] == 0:
 					text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
+					targ_list[i] = targ_list[i].clone().roll(-1, dims=0)
 
 					text_prom_list[i][-1] = self.ignore_index
+					targ_list[i][-1] = self.stop_token
 				# for the NAR, ignore completely computing the loss against the text prompt
 				else:
 					text_prom_list[i][:] = self.ignore_index
 
-			# adjust the target sequence if needed for the AR
-			if quant_levels is None:
-				# creates a copy because this is aliased against input response sequence
-				targ_list = [*targ_list]
-				# shift the target response into the future by 1, and mark the rolled back token / last token as a stop token
-				# this prepares the AR to actually generate autoregressive sequences
-				for i in range(len(targ_list)):
-					targ_list[i] = targ_list[i].roll(-1, dims=0)
-					targ_list[i][-1] = self.stop_token
 			# create the new target sequence to compute the loss against
 			target = torch.cat( self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep ) )
 			inputs = torch.cat( logits )
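Note: the loss-target assembly now handles the AR case per batch element (`quant_levels[i] == 0`) instead of in a separate whole-batch pass, which is what lets AR and NAR samples coexist in one batch. A toy illustration of the two treatments (values made up, not the model code):

    import torch
    import torch.nn.functional as F

    ignore_index, stop_token = -100, 9

    text_prom = torch.tensor([1, 2, 3])  # text + prompt positions
    targ = torch.tensor([5, 6, 7, 8])    # response tokens at the sampled level

    # AR case (level 0): shift one step into the future, mark the end
    ar_text = text_prom.roll(-1, dims=0); ar_text[-1] = ignore_index
    ar_targ = targ.clone().roll(-1, dims=0); ar_targ[-1] = stop_token

    # NAR case (level >= 1): mask the text/prompt out of the loss entirely
    nar_text = torch.full_like(text_prom, ignore_index)

    # merged target for one AR sample, against stand-in logits
    target = torch.cat([ar_text, ar_targ])
    logits = torch.randn(len(target), 10)
    print(F.cross_entropy(logits, target, ignore_index=ignore_index))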