ugh
This commit is contained in:
parent
d343bde09b
commit
19410a919e
|
@ -602,17 +602,17 @@ class Dataset(_Dataset):
|
||||||
|
|
||||||
state_dict = torch.load(path)
|
state_dict = torch.load(path)
|
||||||
if self.sampler_type == "path":
|
if self.sampler_type == "path":
|
||||||
state_dict = self.sampler.load_state(state_dict)
|
state_dict = self.sampler.set_state(state_dict)
|
||||||
else:
|
else:
|
||||||
for name, sampler in state_dict["samplers"].items():
|
for name, sampler in state_dict["samplers"].items():
|
||||||
if name not in self.samplers:
|
if name not in self.samplers:
|
||||||
continue
|
continue
|
||||||
self.samplers[name].load_state( sampler )
|
self.samplers[name].set_state( sampler )
|
||||||
|
|
||||||
for name, sampler in state_dict["spkr_samplers"].items():
|
for name, sampler in state_dict["spkr_samplers"].items():
|
||||||
if name not in self.spkr_samplers:
|
if name not in self.spkr_samplers:
|
||||||
continue
|
continue
|
||||||
self.spkr_samplers[name].load_state( sampler )
|
self.spkr_samplers[name].set_state( sampler )
|
||||||
|
|
||||||
def _get_phone_symmap(self):
|
def _get_phone_symmap(self):
|
||||||
return get_phone_symmap()
|
return get_phone_symmap()
|
||||||
|
|
|
@ -111,6 +111,7 @@ def load_engines(training=True):
|
||||||
|
|
||||||
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
|
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
|
||||||
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
||||||
|
|
||||||
if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
|
if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
|
||||||
print("DeepSpeed checkpoint missing, but weights found.")
|
print("DeepSpeed checkpoint missing, but weights found.")
|
||||||
loads_state_dict = True
|
loads_state_dict = True
|
||||||
|
|
|
@ -244,6 +244,15 @@ class AR_NAR(Base):
|
||||||
#mirostat=mirostat,
|
#mirostat=mirostat,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# filter
|
||||||
|
"""
|
||||||
|
if self.arch_type in ["mamba2-hf"]:
|
||||||
|
for batch_index, resp in enumerate(resps_list):
|
||||||
|
for i, token in enumerate(resp):
|
||||||
|
if token >= 1024:
|
||||||
|
resps_list[batch_index][i] = 1023
|
||||||
|
"""
|
||||||
|
|
||||||
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
|
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
|
||||||
|
|
||||||
return prev_list
|
return prev_list
|
||||||
|
|
Loading…
Reference in New Issue
Block a user