This commit is contained in:
mrq 2024-06-15 12:29:03 -05:00
parent d343bde09b
commit 19410a919e
3 changed files with 15 additions and 5 deletions

View File

@@ -602,17 +602,17 @@ class Dataset(_Dataset):
state_dict = torch.load(path)
if self.sampler_type == "path":
-state_dict = self.sampler.load_state(state_dict)
+state_dict = self.sampler.set_state(state_dict)
else:
for name, sampler in state_dict["samplers"].items():
if name not in self.samplers:
continue
-self.samplers[name].load_state( sampler )
+self.samplers[name].set_state( sampler )
for name, sampler in state_dict["spkr_samplers"].items():
if name not in self.spkr_samplers:
continue
-self.spkr_samplers[name].load_state( sampler )
+self.spkr_samplers[name].set_state( sampler )
def _get_phone_symmap(self):
    """Return the phone symbol map (delegates to the module-level get_phone_symmap())."""
    return get_phone_symmap()

View File

@@ -111,10 +111,11 @@ def load_engines(training=True):
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
load_path = cfg.ckpt_dir / name / "fp32.pth"
if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
print("DeepSpeed checkpoint missing, but weights found.")
loads_state_dict = True
stats = None
if loads_state_dict:
state = torch.load(load_path, map_location=torch.device(cfg.device))

View File

@@ -207,7 +207,7 @@ class AR_NAR(Base):
prev_list = resps_list
-for n in trange( max_levels, desc="NAR" ):
+for n in trange( max_levels, desc="NAR" ):
level = prev_list[0].shape[-1]
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
break
@@ -244,6 +244,15 @@ class AR_NAR(Base):
#mirostat=mirostat,
)
# filter
"""
if self.arch_type in ["mamba2-hf"]:
for batch_index, resp in enumerate(resps_list):
for i, token in enumerate(resp):
if token >= 1024:
resps_list[batch_index][i] = 1023
"""
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
return prev_list