nightmare of spaghetti that might break compat; mechanism to increase RVQ bins of an existing model without retraining, keeps sampled proms/resps at the max RVQ level and trims off excess levels according to which model receives them, some other things I already forgot (I really hope no one else has weights being baked right now)
parent f7f6d3bf6d · commit 2d1a9f10c0
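Note: the "increase RVQ bins of an existing model without retraining" mechanism described in the commit message shows up in the `load_engines` hunk at the end of this diff: when a checkpoint's `proms_emb` / `resps_emb` weights are smaller than the freshly constructed model's, the trained slice is copied into the enlarged embedding and the state-dict entry is replaced with the full tensor. A minimal sketch of that idea, with illustrative shapes and names (`old_weight`, `new_emb`) that are not from the repo:

import torch
import torch.nn as nn

# hypothetical shapes: (rvq_levels, n_tokens, d_model), mirroring the prom/resp embeddings
old_weight = torch.randn(2, 1024, 512)              # weights loaded from an existing checkpoint
new_emb = nn.Parameter(torch.randn(8, 1024, 512))   # freshly initialized model with more RVQ levels

with torch.no_grad():
    o_levels, o_tokens, _ = old_weight.shape
    # copy the trained portion into the enlarged embedding;
    # the extra levels keep their fresh initialization
    new_emb.data[:o_levels, :o_tokens, :] = old_weight

# hand the full tensor back to the state dict so load_state_dict() sees matching shapes
state = {"proms_emb.weight": new_emb.data}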
@@ -1,7 +1,14 @@
dataset:
  training: []
  training: [
    # "./training/valle/data/LibriTTS/994/",
  ]

  validation: []
  validation: [
    # "./training/valle/data/Validation/1188/",
  ]
  noise: [
    # "./training/valle/data/Other/noise/",
  ]

  speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"

@@ -12,7 +19,7 @@ dataset:
  workers: 4
  cache: True

  phones_range: [4, 256]
  phones_range: [4, 512]
  duration_range: [1.0, 16.0]

  random_utterance: 1.0

@@ -20,9 +27,11 @@ dataset:
  prompt_duration: 3.0

  sample_type: speaker
  tasks_list: ["tts"] # do NOT change this until you're ready to train for SpeechX tasks # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"]

  tasks_list: ["tts"] # ["tts", "ns", "sr", "tse", "cse", "nse"]

models:
  _max_levels: 8
  _models:
  - name: "ar"
    size: "full"

@@ -38,13 +47,14 @@ models:
    tasks: 8
    arch_type: "retnet"

hyperparameters:
  batch_size: 16
  gradient_accumulation_steps: 2
  gradient_accumulation_steps: 4
  gradient_clipping: 100

  optimizer: Adamw
  learning_rate: 1.0e-5
  optimizer: AdamW
  learning_rate: 1.0e-4

  scheduler_type: ""
  #scheduler_type: OneCycle

@@ -66,13 +76,13 @@ hyperparameters:
  # decay_mom_rate: 0.0

evaluation:
  batch_size: 32
  batch_size: 16
  frequency: 500
  size: 32
  size: 16

  steps: 300
  ar_temperature: 1.0
  nar_temperature: 0.2
  ar_temperature: 0.95
  nar_temperature: 0.25

trainer:
  iterations: 1_000_000

@@ -106,4 +116,7 @@ inference:
  normalize: False # do NOT change this unless you know exactly what you are doing.

bitsandbytes:
  enabled: false
  enabled: False
  injects: True
  linear: True
  embedding: True

@@ -196,6 +196,8 @@ class Model:

@dataclass()
class Models:
    _max_levels: int = 0

    _models: list[Model] = field(default_factory=lambda: [
        Model(name="ar", resp_levels=1, prom_levels=8, tasks=1),
        Model(name="nar", resp_levels=7, prom_levels=8, tasks=1),

@@ -232,6 +234,10 @@ class Models:
        for model in self._models:
            tasks = max(tasks, model.tasks)
        return tasks

    @property
    def max_levels(self):
        return self._max_levels if self._max_levels > 0 else self.prom_levels

@dataclass()
class Hyperparameters:

@@ -261,7 +267,8 @@ class DeepSpeed:
    use_compression_training: bool = False
    compression_bits: int = 8

    def get_ds_cfg(self, model):
    @cached_property
    def ds_cfg(self):
        scheduler_params = {}
        for k in cfg.hyperparameters.scheduler_params:
            scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]

@@ -277,7 +284,7 @@ class DeepSpeed:
                "params": {
                    "lr": cfg.hyperparameters.learning_rate,
                }
            },
            } if not cfg.hyperparameters.optimizer.endswith("-torch") else None,
            "scheduler": {
                "type": cfg.hyperparameters.scheduler_type,
                "params": scheduler_params,

@@ -351,8 +358,8 @@ class DeepSpeed:
        for k in null_keys:
            del ds_cfg[k]

        if os.path.exists("./config/ds_config.json"):
            ds_cfg.update(json.load(open("./config/ds_config.json", "r", encoding="utf-8")))
        if os.path.exists("./data/ds_config.json"):
            ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))

        return ds_cfg


@@ -404,8 +411,8 @@ class BitsAndBytes:
    enabled: bool = False
    injects: bool = False

    linear: bool = False
    embedding: bool = False
    linear: bool = True
    embedding: bool = True

@dataclass()
class Config(_Config):

@@ -63,11 +63,8 @@ def _get_quant_path(path):
def _get_phone_path(path):
    return _replace_file_extension(path, ".phn.txt")


def _load_quants(path) -> Tensor:
    path = _get_quant_path(path)
    return torch.load(path)[0][:cfg.models.prom_levels, :].t().to(torch.int16)

    return torch.load(path)[0][:, :].t().to(torch.int16)

@cache
def _get_phones(path, lang_marker="en"):

@@ -215,12 +212,12 @@ class Dataset(_Dataset):
    def _get_task_symmap(self):
        return get_task_symmap()

    def get_task_token( self, token ):
    def get_task_token( self, token, levels=cfg.models.max_levels ):
        if not hasattr(self, "task_symmap"):
            self.task_symmap = self._get_task_symmap()
        return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(cfg.models.prom_levels) ]]).to(dtype=torch.int16)
        return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)

    def sample_noise(self):
        paths = []
        for data_dir in cfg.dataset.noise:
            paths.extend(data_dir.rglob("*.qnt.pt"))

@@ -228,7 +225,7 @@ class Dataset(_Dataset):

        if False and cfg.dataset.use_hdf5:
            key = f'/noise/{_get_hdf5_path(path)}'
            qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
            qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
        else:
            qnt = _load_quants(path)
        return qnt

@@ -260,7 +257,7 @@ class Dataset(_Dataset):
        path = random.choice(choices)
        if cfg.dataset.use_hdf5:
            key = _get_hdf5_path(path)
            qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
            qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
        else:
            qnt = _load_quants(path)


@@ -293,7 +290,7 @@ class Dataset(_Dataset):
        if cfg.dataset.use_hdf5:
            key = _get_hdf5_path(path)
            text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(self.text_dtype)
            resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
            resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
        else:
            text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
            resps = _load_quants(path)

@@ -316,7 +313,7 @@ class Dataset(_Dataset):
            # extend the noise to fill the target audio
            noise = repeat_extend_audio(noise, resps.shape[0])
            # create the input prompt by merging the target audio with the noise
            proms = merge_audio(resps, noise, scale=[1, noise_scale], device="cpu")
            proms = merge_audio( resps, noise, scale=[1, noise_scale], device="cpu" )
            # set the target to just be the noise if <sr>
            if task == "sr":
                resps = noise

@@ -358,7 +355,7 @@ class Dataset(_Dataset):

            if cfg.dataset.use_hdf5:
                texts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["text"][:]).to(self.text_dtype) for path in sampled ]
                qnts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["audio"][:, :cfg.models.prom_levels]).to(torch.int16) for path in sampled ]
                qnts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["audio"][:, :]).to(torch.int16) for path in sampled ]
            else:
                texts = [ torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype) for path in sampled ]
                qnts = [ _load_quants(path) for path in sampled ]

@@ -394,15 +391,15 @@ class Dataset(_Dataset):

            # it might be better to extend the noise to the sum of the pre+mid+post or pre+edit+post to keep the noise truly coherent
            # but it's noise, it's supposed to be random
            def noise_proms( proms ):
            def noise_proms( p ):
                # ignore if we turned it off
                if proms is None:
                if p is None:
                    return None

                # extend the noise to fill the target audio
                n = repeat_extend_audio(noise, proms.shape[0])
                n = repeat_extend_audio(noise, p.shape[0])
                # merge the noise over the utterance
                return merge_audio(proms, n, scale=[1, noise_scale], device="cpu")
                return merge_audio(p, n, scale=[1, noise_scale], device="cpu")

            # apply noise to all pieces
            pre_prom = noise_proms( pre_prom )

@@ -426,6 +423,8 @@ class Dataset(_Dataset):
                [ edit_prom ] +
                ([ post_prom ] if post_prom is not None else [])
            )
        else:
            raise f'Undefined task: {task}'

        """
        # emulate SVC

@@ -450,6 +449,10 @@ class Dataset(_Dataset):
        text = torch.tensor([1, 2]).to(self.text_dtype)
        """

        # trim to fit to requested prom/resps levels
        proms = proms[:, :cfg.models.prom_levels]
        resps = resps[:, :cfg.models.prom_levels]

        return dict(
            index=index,

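Note: per the last hunk above, sampled proms/resps now stay at the full RVQ level count of the stored .qnt.pt / HDF5 data and are only sliced down to `cfg.models.prom_levels` right before being returned, so one dataset can serve models with different level counts. A toy illustration (the shapes and the `prom_levels` value are made up):

import torch

# a quantized utterance kept at the maximum stored level count: (timesteps, levels)
resps = torch.randint(0, 1024, (500, 8), dtype=torch.int16)

prom_levels = 2                  # stand-in for cfg.models.prom_levels
proms = resps[:, :prom_levels]   # trim only at the point of use
print(proms.shape)               # torch.Size([500, 2])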
@@ -21,16 +21,16 @@ except Exception as e:
    cfg.inference.use_vocos = False

@cache
def _load_encodec_model(device="cuda"):
def _load_encodec_model(device="cuda", levels=cfg.models.max_levels):
    # Instantiate a pretrained EnCodec model
    assert cfg.sample_rate == 24_000

    # too lazy to un-if ladder this shit
    if cfg.models.prom_levels == 2:
    if levels == 2:
        bandwidth_id = 1.5
    elif cfg.models.prom_levels == 4:
    elif levels == 4:
        bandwidth_id = 3.0
    elif cfg.models.prom_levels == 8:
    elif levels == 8:
        bandwidth_id = 6.0

    model = EncodecModel.encodec_model_24khz().to(device)

@@ -43,18 +43,18 @@ def _load_encodec_model(device="cuda"):
    return model

@cache
def _load_vocos_model(device="cuda"):
def _load_vocos_model(device="cuda", levels=cfg.models.max_levels):
    assert cfg.sample_rate == 24_000

    model = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
    model = model.to(device)

    # too lazy to un-if ladder this shit
    if cfg.models.prom_levels == 2:
    if levels == 2:
        bandwidth_id = 0
    elif cfg.models.prom_levels == 4:
    elif levels == 4:
        bandwidth_id = 1
    elif cfg.models.prom_levels == 8:
    elif levels == 8:
        bandwidth_id = 2

    model.bandwidth_id = torch.tensor([bandwidth_id], device=device)

@@ -64,11 +64,11 @@ def _load_vocos_model(device="cuda"):
    return model

@cache
def _load_model(device="cuda", vocos=cfg.inference.use_vocos):
def _load_model(device="cuda", vocos=cfg.inference.use_vocos, levels=cfg.models.max_levels):
    if vocos:
        model = _load_vocos_model(device)
        model = _load_vocos_model(device, levels=levels)
    else:
        model = _load_encodec_model(device)
        model = _load_encodec_model(device, levels=levels)

    return model


@@ -78,7 +78,7 @@ def unload_model():


@torch.inference_mode()
def decode(codes: Tensor, device="cuda"):
def decode(codes: Tensor, device="cuda", levels=cfg.models.max_levels):
    """
    Args:
        codes: (b q t)

@@ -94,7 +94,7 @@ def decode(codes: Tensor, device="cuda"):
        codes = rearrange(codes, "t q -> 1 q t")

    assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
    model = _load_model(device)
    model = _load_model(device, levels=levels)

    # upcast so it won't whine
    if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8:

@@ -115,8 +115,8 @@ def decode(codes: Tensor, device="cuda"):
    return wav, model.sample_rate

# huh
def decode_to_wave(resps: Tensor, device="cuda"):
    return decode(resps, device=device)
def decode_to_wave(resps: Tensor, device="cuda", levels=cfg.models.max_levels):
    return decode(resps, device=device, levels=levels)

def decode_to_file(resps: Tensor, path: Path, device="cuda"):
    wavs, sr = decode(resps, device=device)

@@ -129,14 +129,14 @@ def _replace_file_extension(path, suffix):


@torch.inference_mode()
def encode(wav: Tensor, sr: int = 24_000, device="cuda"):
def encode(wav: Tensor, sr: int = 24_000, device="cuda", levels=cfg.models.max_levels):
    """
    Args:
        wav: (t)
        sr: int
    """

    model = _load_encodec_model(device)
    model = _load_encodec_model(device, levels=levels)
    wav = wav.unsqueeze(0)
    wav = convert_audio(wav, sr, model.sample_rate, model.channels)
    wav = wav.to(device)

@@ -203,16 +203,16 @@ def repeat_extend_audio( qnt, target ):

# merges two quantized audios together
# I don't know if this works
def merge_audio( *args, device="cpu", scale=[] ):
def merge_audio( *args, device="cpu", scale=[], levels=cfg.models.max_levels ):
    qnts = [*args]
    decoded = [ decode_to_wave(qnt, device=device)[0] for qnt in qnts ]
    decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]

    if len(scale) == len(decoded):
        for i in range(len(scale)):
            decoded[i] = decoded[i] * scale[i]

    combined = sum(decoded) / len(decoded)
    return encode(combined, 24_000, device="cpu")[0].t()
    return encode(combined, 24_000, device="cpu", levels=levels)[0].t()

def main():
    parser = argparse.ArgumentParser()

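Note on the `levels` → `bandwidth_id` if-ladder above: at 24 kHz EnCodec emits 75 frames per second with 10-bit codebooks, so each RVQ level costs roughly 0.75 kbps; 2, 4, and 8 levels therefore correspond to the 1.5, 3.0, and 6.0 kbps targets (and to Vocos bandwidth indices 0, 1, 2). A table-driven equivalent, purely illustrative and not part of the commit:

# level count -> EnCodec target bandwidth in kbps (24 kHz model)
LEVELS_TO_BANDWIDTH = {2: 1.5, 4: 3.0, 8: 6.0}

def bandwidth_for(levels: int) -> float:
    """Look up the EnCodec bandwidth that yields `levels` codebooks at 24 kHz."""
    return LEVELS_TO_BANDWIDTH[levels]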
@@ -39,6 +39,7 @@ import os
from torch import Tensor
from torch.distributed import all_reduce
from typing import Any, Protocol
from functools import cached_property

from .base import TrainFeeder


@@ -50,6 +51,10 @@ if not distributed_initialized() and cfg.trainer.backend == "local":
# A very naive engine implementation using barebones PyTorch
class Engine():
    def __init__(self, *args, **kwargs):
        if '_cfg' in kwargs:
            self._cfg = kwargs['_cfg']
            kwargs.pop("_cfg")

        self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
        self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
        self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None

@@ -137,6 +142,10 @@ class Engine():
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    @cached_property
    def device(self):
        return next(self.module.parameters()).device

    def forward(self, *args, **kwargs):
        return self.module.forward(*args, **kwargs)

@@ -31,7 +31,11 @@ if not distributed_initialized() and cfg.trainer.backend == "deepspeed":

class Engine(DeepSpeedEngine):
    def __init__(self, *args, **kwargs):
        kwargs['config'] = cfg.trainer.deepspeed.get_ds_cfg(model=kwargs['model'])
        if '_cfg' in kwargs:
            self._cfg = kwargs['_cfg']
            kwargs.pop("_cfg")

        kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
        kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])

        super().__init__(None, *args, **kwargs)

@@ -8,10 +8,6 @@ from torch import Tensor
from tqdm import trange

class AR(Base):
    @property
    def n_resp_levels(self) -> int:
        return cfg.models.ar.resp_levels

    @property
    def causal(self):
        return True

@@ -32,6 +28,14 @@ class AR(Base):
    def n_prom_levels(self) -> int:
        return cfg.models.prom_levels

    @property
    def n_resp_levels(self) -> int:
        return cfg.models.ar.resp_levels

    @property
    def n_max_levels(self) -> int:
        return cfg.models.max_levels

    @property
    def n_tasks(self) -> int:
        return cfg.models.tasks

@@ -113,6 +113,10 @@ class Base(nn.Module):
    @property
    def n_prom_levels(self) -> int:
        raise NotImplementedError

    @property
    def n_max_levels(self) -> int:
        raise NotImplementedError

    @property
    def n_tasks(self) -> int:

@@ -7,10 +7,6 @@ from torch import Tensor
from tqdm import trange

class NAR(Base):
    @property
    def n_resp_levels(self) -> int:
        return cfg.models.nar.resp_levels

    @property
    def causal(self):
        return False

@@ -31,6 +27,14 @@ class NAR(Base):
    def n_prom_levels(self) -> int:
        return cfg.models.prom_levels

    @property
    def n_resp_levels(self) -> int:
        return cfg.models.nar.resp_levels

    @property
    def n_max_levels(self) -> int:
        return cfg.models.max_levels

    @property
    def n_tasks(self) -> int:
        return cfg.models.tasks

@@ -23,7 +23,11 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
_logger = logging.getLogger(__name__)

def train_feeder(engine, batch):
    engine( text_list=batch["text"], proms_list=batch["proms"], resps_list=batch["resps"] )
    engine(
        text_list=batch["text"],
        proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
        resps_list=batch["resps"]
    )

    losses = engine.gather_attribute("loss")
    stat = engine.gather_attribute("stats")

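Note: with the `train_feeder` change above, each engine slices the batch's prompts down to its own model's `prom_levels` (carried in through `_cfg`), so a single dataloader emitting max-level prompts can feed the AR and NAR even when they consume different level counts. A rough sketch with made-up shapes and a stand-in for the per-model config:

import torch

# prompts sampled at the maximum RVQ level count (8 here); lengths are arbitrary
batch_proms = [torch.randint(0, 1024, (t, 8), dtype=torch.int16) for t in (250, 320)]

class _ModelCfg:        # stand-in for the per-model config attached as engine._cfg
    prom_levels = 2     # e.g. a model that only consumes two levels

_cfg = _ModelCfg()
proms_list = [prom[:, :_cfg.prom_levels] for prom in batch_proms]
print([tuple(p.shape) for p in proms_list])   # [(250, 2), (320, 2)]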
@@ -88,19 +88,33 @@ def load_engines():

        # extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
        if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
            n_prom_levels, n_prom_tokens, d_model = state['proms_emb.weight'].shape
            o_prom_levels, o_prom_tokens, d_model = state['proms_emb.weight'].shape

            # copy weights from the dict into the old portion
            model.proms_emb.weight.data[:n_prom_levels, :n_prom_tokens, :] = state['proms_emb.weight'].data[:n_prom_levels, :n_prom_tokens, :]
            model.proms_emb.weight.data[:o_prom_levels, :o_prom_tokens, :] = state['proms_emb.weight'].data[:o_prom_levels, :o_prom_tokens, :]
            # copy the full tensors back
            state['proms_emb.weight'] = model.proms_emb.weight

        # extend the resps_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
        if model.resps_emb.weight.shape[0] > state['resps_emb.weight'].shape[0] or model.resps_emb.weight.shape[1] > state['resps_emb.weight'].shape[1]:
            o_resp_levels, o_resp_tokens, d_model = state['resps_emb.weight'].shape
            n_resp_levels, n_resp_tokens, d_model = model.resps_emb.weight.shape

            # copy weights from the dict into the old portion
            model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :]
            # reuse additional levels, probably bad
            for n in range(o_resp_tokens, n_resp_tokens):
                model.resps_emb.weight.data[n] = model.resps_emb.weight.data[o_resp_tokens-1]
            # copy the full tensors back
            state['resps_emb.weight'] = model.resps_emb.weight

        model.load_state_dict(state, strict=cfg.trainer.strict_loading)

        engines[name] = Engine(
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            _cfg=model._cfg,
        )

    engines = Engines(engines)