Fixes: throw an exception when using a NAR-only model with non-unified position IDs, since for some reason the NAR output is garbage in that configuration
This commit is contained in:
parent
4456d3172b
commit
97c5241bef
|
@ -808,7 +808,8 @@ class Dataset(_Dataset):
|
|||
if key not in cfg.hdf5:
|
||||
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
|
||||
|
||||
metadata = cfg.hdf5[key].attrs
|
||||
#metadata = cfg.hdf5[key].attrs
|
||||
metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
|
||||
|
||||
text = cfg.hdf5[key]["text"][:]
|
||||
resps = cfg.hdf5[key]["audio"][:, :]
|
||||
|
@ -915,7 +916,8 @@ class Dataset(_Dataset):
|
|||
if key not in cfg.hdf5:
|
||||
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
|
||||
|
||||
metadata = cfg.hdf5[key].attrs
|
||||
# I need to do some weird coercion to a normal dict because it'll complain about HDF5 objects not being picklable in worker processes
|
||||
metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
|
||||
|
||||
text = cfg.hdf5[key]["text"][:]
|
||||
resps = cfg.hdf5[key]["audio"][:, :]
|
||||
|
|
|
@ -536,8 +536,10 @@ def example_usage():
|
|||
engines = Engines({"ar+nar": engine})
|
||||
engines.setup()
|
||||
|
||||
"""
|
||||
if cfg.optimizations.model_offloading:
|
||||
model = ml.offload_model( model, policy=cfg.optimizations.model_offloading )
|
||||
"""
|
||||
|
||||
"""
|
||||
torch.save( {
|
||||
|
|
|
@ -438,6 +438,10 @@ class Base(nn.Module):
|
|||
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
|
||||
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
|
||||
|
||||
# there seems to be a problem with the NAR-only model when using non-unified position IDs
|
||||
if "len" in self.capabilities and not unified_position_ids:
|
||||
raise Exception("ERROR: model instability for NAR-only model when not using unified position IDs.")
|
||||
|
||||
self.unified_position_ids = unified_position_ids
|
||||
|
||||
self.text_emb = Embedding(n_text_tokens, d_model)
|
||||
|
@ -1081,11 +1085,12 @@ class Base(nn.Module):
|
|||
if not isinstance(input, torch.Tensor):
|
||||
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
|
||||
|
||||
# ending input will not have a separator later
|
||||
return input.shape[0] + (0 if name in ["resp", "len"] else 1)
|
||||
|
||||
for batch_index, batch_input in enumerate(inputs):
|
||||
batch = torch.cat( [
|
||||
torch.Tensor([*range(get_input_token_length(name, input))])
|
||||
torch.Tensor([*range(get_input_token_length(name, input))]).to(dtype=torch.int32)
|
||||
for name, input in batch_input if name != "task"
|
||||
] )
|
||||
|
||||
|
|
|
@ -390,7 +390,7 @@ def example_usage():
|
|||
"""
|
||||
|
||||
model = NAR(**kwargs).to(device)
|
||||
steps = 500
|
||||
steps = 250
|
||||
|
||||
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
|
||||
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
|
||||
|
|
Loading…
Reference in New Issue
Block a user