This commit is contained in:
mrq 2024-05-12 07:30:59 -05:00
parent 3774fcbdee
commit 14709ac67f
3 changed files with 25 additions and 6 deletions

View File

@ -1034,13 +1034,29 @@ if __name__ == "__main__":
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ], "validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
} }
Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
for k, v in samples.items(): for k, v in samples.items():
for i in range(len(v)): for i in range(len(v)):
del v[i]['proms'] for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."):
del v[i]['resps'] """
print(f'{k}:', v) try:
decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
except Exception as e:
print(f"Error while decoding prom {k}.{i}.{j}.wav:", str(e))
try:
decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
except Exception as e:
print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e))
"""
v[i]['proms'][j] = v[i]['proms'][j].shape
v[i]['resps'][j] = v[i]['resps'][j].shape
train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt") for k, v in samples.items():
for i in range(len(v)):
print(f'{k}[{i}]:', v[i])
#train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
elif args.action == "tasks": elif args.action == "tasks":
index = 0 index = 0

View File

@ -212,12 +212,12 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
dummy = False dummy = False
if metadata is None: if metadata is None:
metadata = dict( metadata = dict(
chunk_length=120, chunk_length= codes.shape[-1],
original_length=0, original_length=0,
input_db=-12, input_db=-12,
channels=1, channels=1,
sample_rate=model.sample_rate, sample_rate=model.sample_rate,
padding=False, padding=True,
dac_version='1.0.0', dac_version='1.0.0',
) )
dummy = True dummy = True

View File

@ -658,8 +658,11 @@ class Base(nn.Module):
else: else:
raise RuntimeError(f'Unknown arch specified: {self.arch_type}') raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
# Disabling for now, it might be broken
"""
if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]: if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention ) self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention )
"""
self.classifier = nn.Linear(d_model, n_resp_tokens) self.classifier = nn.Linear(d_model, n_resp_tokens)