This commit is contained in:
mrq 2024-05-12 07:30:59 -05:00
parent 3774fcbdee
commit 14709ac67f
3 changed files with 25 additions and 6 deletions

View File

@ -1034,13 +1034,29 @@ if __name__ == "__main__":
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
}
Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
for k, v in samples.items():
for i in range(len(v)):
del v[i]['proms']
del v[i]['resps']
print(f'{k}:', v)
for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."):
"""
try:
decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
except Exception as e:
print(f"Error while decoding prom {k}.{i}.{j}.wav:", str(e))
try:
decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
except Exception as e:
print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e))
"""
v[i]['proms'][j] = v[i]['proms'][j].shape
v[i]['resps'][j] = v[i]['resps'][j].shape
for k, v in samples.items():
for i in range(len(v)):
print(f'{k}[{i}]:', v[i])
train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
#train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
elif args.action == "tasks":
index = 0

View File

@ -212,12 +212,12 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
dummy = False
if metadata is None:
metadata = dict(
chunk_length=120,
chunk_length= codes.shape[-1],
original_length=0,
input_db=-12,
channels=1,
sample_rate=model.sample_rate,
padding=False,
padding=True,
dac_version='1.0.0',
)
dummy = True

View File

@ -658,8 +658,11 @@ class Base(nn.Module):
else:
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
# Disabling for now, it might be broken
"""
if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention )
"""
self.classifier = nn.Linear(d_model, n_resp_tokens)