Fixes: throw an exception when using the NAR-only model with non-unified position IDs, since for some reason it outputs garbage for the NAR

This commit is contained in:
mrq 2024-08-02 22:25:49 -05:00
parent 4456d3172b
commit 97c5241bef
4 changed files with 13 additions and 4 deletions

View File

@ -808,7 +808,8 @@ class Dataset(_Dataset):
if key not in cfg.hdf5:
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
metadata = cfg.hdf5[key].attrs
#metadata = cfg.hdf5[key].attrs
metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
text = cfg.hdf5[key]["text"][:]
resps = cfg.hdf5[key]["audio"][:, :]
@ -915,7 +916,8 @@ class Dataset(_Dataset):
if key not in cfg.hdf5:
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
metadata = cfg.hdf5[key].attrs
# I need to do some weird coercion to a normal dict because it'll complain about HDF5 objects not being pickleable in worker processes
metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
text = cfg.hdf5[key]["text"][:]
resps = cfg.hdf5[key]["audio"][:, :]

View File

@ -536,8 +536,10 @@ def example_usage():
engines = Engines({"ar+nar": engine})
engines.setup()
"""
if cfg.optimizations.model_offloading:
model = ml.offload_model( model, policy=cfg.optimizations.model_offloading )
"""
"""
torch.save( {

View File

@ -438,6 +438,10 @@ class Base(nn.Module):
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
# there seems to be a problem with the NAR-only model with non-unified position IDs.
if "len" in self.capabilities and not unified_position_ids:
raise Exception("ERROR: model instability for NAR-only model when not using unified position IDs.")
self.unified_position_ids = unified_position_ids
self.text_emb = Embedding(n_text_tokens, d_model)
@ -1081,11 +1085,12 @@ class Base(nn.Module):
if not isinstance(input, torch.Tensor):
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
# ending input will not have a separator later
return input.shape[0] + (0 if name in ["resp", "len"] else 1)
for batch_index, batch_input in enumerate(inputs):
batch = torch.cat( [
torch.Tensor([*range(get_input_token_length(name, input))])
torch.Tensor([*range(get_input_token_length(name, input))]).to(dtype=torch.int32)
for name, input in batch_input if name != "task"
] )

View File

@ -390,7 +390,7 @@ def example_usage():
"""
model = NAR(**kwargs).to(device)
steps = 500
steps = 250
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""