ugh
This commit is contained in:
parent
d343bde09b
commit
19410a919e
|
@ -602,17 +602,17 @@ class Dataset(_Dataset):
|
|||
|
||||
state_dict = torch.load(path)
|
||||
if self.sampler_type == "path":
|
||||
state_dict = self.sampler.load_state(state_dict)
|
||||
state_dict = self.sampler.set_state(state_dict)
|
||||
else:
|
||||
for name, sampler in state_dict["samplers"].items():
|
||||
if name not in self.samplers:
|
||||
continue
|
||||
self.samplers[name].load_state( sampler )
|
||||
self.samplers[name].set_state( sampler )
|
||||
|
||||
for name, sampler in state_dict["spkr_samplers"].items():
|
||||
if name not in self.spkr_samplers:
|
||||
continue
|
||||
self.spkr_samplers[name].load_state( sampler )
|
||||
self.spkr_samplers[name].set_state( sampler )
|
||||
|
||||
def _get_phone_symmap(self):
|
||||
return get_phone_symmap()
|
||||
|
|
|
@ -111,10 +111,11 @@ def load_engines(training=True):
|
|||
|
||||
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
|
||||
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
||||
|
||||
if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
|
||||
print("DeepSpeed checkpoint missing, but weights found.")
|
||||
loads_state_dict = True
|
||||
|
||||
|
||||
stats = None
|
||||
if loads_state_dict:
|
||||
state = torch.load(load_path, map_location=torch.device(cfg.device))
|
||||
|
|
|
@ -207,7 +207,7 @@ class AR_NAR(Base):
|
|||
|
||||
prev_list = resps_list
|
||||
|
||||
for n in trange( max_levels, desc="NAR" ):
|
||||
for n in trange( max_levels, desc="NAR" ):
|
||||
level = prev_list[0].shape[-1]
|
||||
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
|
||||
break
|
||||
|
@ -244,6 +244,15 @@ class AR_NAR(Base):
|
|||
#mirostat=mirostat,
|
||||
)
|
||||
|
||||
# filter
|
||||
"""
|
||||
if self.arch_type in ["mamba2-hf"]:
|
||||
for batch_index, resp in enumerate(resps_list):
|
||||
for i, token in enumerate(resp):
|
||||
if token >= 1024:
|
||||
resps_list[batch_index][i] = 1023
|
||||
"""
|
||||
|
||||
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
|
||||
|
||||
return prev_list
|
||||
|
|
Loading…
Reference in New Issue
Block a user