ughh
parent 3774fcbdee
commit 14709ac67f
@@ -1034,13 +1034,29 @@ if __name__ == "__main__":
 			"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
 		}
 
+		Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
+
 		for k, v in samples.items():
 			for i in range(len(v)):
-				del v[i]['proms']
-				del v[i]['resps']
-			print(f'{k}:', v)
+				for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."):
+					"""
+					try:
+						decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
+					except Exception as e:
+						print(f"Error while decoding prom {k}.{i}.{j}.wav:", str(e))
+					try:
+						decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
+					except Exception as e:
+						print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e))
+					"""
+					v[i]['proms'][j] = v[i]['proms'][j].shape
+					v[i]['resps'][j] = v[i]['resps'][j].shape
 
-		train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
+		for k, v in samples.items():
+			for i in range(len(v)):
+				print(f'{k}[{i}]:', v[i])
+
+		#train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
 
 	elif args.action == "tasks":
 		index = 0
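Note: a minimal, self-contained sketch of the shape-summary pass introduced above, assuming `samples` maps split names to lists of dicts whose 'proms'/'resps' entries are lists of torch tensors (the tensor sizes and the "training" key are made up for illustration; the decode_to_file path stays commented out, as in the commit):

import torch

# Hypothetical stand-in for the samples dict built from the dataloaders.
samples = {
	"training": [ { "proms": [ torch.zeros(75, 8) ], "resps": [ torch.zeros(120, 8) ] } ],
}

# Replace the raw code tensors with their shapes so the printout stays readable.
for k, v in samples.items():
	for i in range(len(v)):
		for j in range(len(v[i]["proms"])):
			v[i]["proms"][j] = v[i]["proms"][j].shape
			v[i]["resps"][j] = v[i]["resps"][j].shape

for k, v in samples.items():
	for i in range(len(v)):
		print(f"{k}[{i}]:", v[i])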
@@ -212,12 +212,12 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
 	dummy = False
 	if metadata is None:
 		metadata = dict(
-			chunk_length=120,
+			chunk_length= codes.shape[-1],
 			original_length=0,
 			input_db=-12,
 			channels=1,
 			sample_rate=model.sample_rate,
-			padding=False,
+			padding=True,
 			dac_version='1.0.0',
 		)
 		dummy = True
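Note: the two changes in this hunk make the dummy metadata follow the input instead of a fixed 120-frame chunk. A hedged sketch of the resulting dict, with 44100 standing in for model.sample_rate (not shown in this hunk) and an arbitrary codes shape chosen for illustration:

import torch

codes = torch.zeros(1, 9, 437, dtype=torch.long)  # (batch, quantizer levels, frames) -- illustrative shape

metadata = dict(
	chunk_length=codes.shape[-1],  # previously hardcoded to 120
	original_length=0,
	input_db=-12,
	channels=1,
	sample_rate=44100,             # stand-in for model.sample_rate
	padding=True,                  # previously False
	dac_version='1.0.0',
)
print(metadata["chunk_length"])    # 437, i.e. the actual frame count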
@@ -658,8 +658,11 @@ class Base(nn.Module):
 		else:
 			raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
 
+		# Disabling for now, it might be broken
+		"""
 		if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
 			self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention )
+		"""
 
 		self.classifier = nn.Linear(d_model, n_resp_tokens)
 
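Note: the attention swap is fenced off behind a docstring here. For illustration only (this is not the repo's ml.replace_attention), the usual pattern behind that kind of swap is to walk the module tree and replace instances of a target class in their parent modules; a generic sketch, with nn.ReLU -> nn.GELU standing in for the attention classes:

import torch.nn as nn

def replace_module(model: nn.Module, target: type, factory) -> nn.Module:
	# Recursively swap every child that is an instance of `target`.
	for name, child in model.named_children():
		if isinstance(child, target):
			setattr(model, name, factory(child))
		else:
			replace_module(child, target, factory)
	return model

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
replace_module(model, nn.ReLU, lambda old: nn.GELU())
print(model)  # the ReLU is now a GELU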