ughh
This commit is contained in:
parent
3774fcbdee
commit
14709ac67f
|
@ -1034,13 +1034,29 @@ if __name__ == "__main__":
|
|||
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
|
||||
}
|
||||
|
||||
Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for k, v in samples.items():
|
||||
for i in range(len(v)):
|
||||
del v[i]['proms']
|
||||
del v[i]['resps']
|
||||
print(f'{k}:', v)
|
||||
for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."):
|
||||
"""
|
||||
try:
|
||||
decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
|
||||
except Exception as e:
|
||||
print(f"Error while decoding prom {k}.{i}.{j}.wav:", str(e))
|
||||
try:
|
||||
decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
|
||||
except Exception as e:
|
||||
print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e))
|
||||
"""
|
||||
v[i]['proms'][j] = v[i]['proms'][j].shape
|
||||
v[i]['resps'][j] = v[i]['resps'][j].shape
|
||||
|
||||
for k, v in samples.items():
|
||||
for i in range(len(v)):
|
||||
print(f'{k}[{i}]:', v[i])
|
||||
|
||||
train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
|
||||
#train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
|
||||
|
||||
elif args.action == "tasks":
|
||||
index = 0
|
||||
|
|
|
@ -212,12 +212,12 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
|
|||
dummy = False
|
||||
if metadata is None:
|
||||
metadata = dict(
|
||||
chunk_length=120,
|
||||
chunk_length= codes.shape[-1],
|
||||
original_length=0,
|
||||
input_db=-12,
|
||||
channels=1,
|
||||
sample_rate=model.sample_rate,
|
||||
padding=False,
|
||||
padding=True,
|
||||
dac_version='1.0.0',
|
||||
)
|
||||
dummy = True
|
||||
|
|
|
@ -658,8 +658,11 @@ class Base(nn.Module):
|
|||
else:
|
||||
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
|
||||
|
||||
# Disabling for now, it might be broken
|
||||
"""
|
||||
if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
|
||||
self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention )
|
||||
"""
|
||||
|
||||
self.classifier = nn.Linear(d_model, n_resp_tokens)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user