preparing for SpeechX extensions

commit 2a71486cb6 (parent ced31fd9b7)
@@ -137,6 +137,8 @@ class Model:
     name: str = ""
     size: str = "full"
     resp_levels: int = 1
+    prom_levels: int = 8
+    tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
     arch_type: str = "transformer"

     @property
@@ -157,7 +159,7 @@ class Model:
         if self.arch_type != "transformer":
             name.append(self.arch_type.replace("/", "-"))

-        name.append(f'{cfg.models.levels}')
+        name.append(f'{cfg.models.prom_levels}')

         return "-".join(name)
@@ -192,8 +194,8 @@ class Model:
 @dataclass()
 class Models:
     _models: list[Model] = field(default_factory=lambda: [
-        Model(name="ar", resp_levels=1),
-        Model(name="nar", resp_levels=7),
+        Model(name="ar", resp_levels=1, prom_levels=8, tasks=1),
+        Model(name="nar", resp_levels=7, prom_levels=8, tasks=1),
     ])

     def get(self, name=None):
@@ -215,11 +217,19 @@ class Models:
         return self.get("nar")

     @property
-    def levels(self):
-        return self.prom_levels
-
-    prom_levels: int = 8
+    def prom_levels(self):
+        prom_levels = 1
+        for model in self._models:
+            prom_levels = max(prom_levels, model.prom_levels)
+        return prom_levels
+
+    @property
+    def tasks(self):
+        tasks = 1
+        for model in self._models:
+            tasks = max(tasks, model.tasks)
+        return tasks

 @dataclass()
 class Hyperparameters:
     batch_size: int = 8
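The two properties above replace the old catch-all levels accessor: the effective prompt-level and task counts are now derived from whatever the registered models declare. A minimal, self-contained sketch of the same aggregation pattern (simplified from the dataclasses in this hunk):

from dataclasses import dataclass, field

@dataclass
class Model:
    name: str = ""
    resp_levels: int = 1
    prom_levels: int = 8
    tasks: int = 8

@dataclass
class Models:
    _models: list = field(default_factory=lambda: [
        Model(name="ar", resp_levels=1, prom_levels=8, tasks=1),
        Model(name="nar", resp_levels=7, prom_levels=8, tasks=1),
    ])

    @property
    def prom_levels(self) -> int:
        # equivalent to the seed-with-1-then-max loop above
        return max(model.prom_levels for model in self._models)

    @property
    def tasks(self) -> int:
        return max(model.tasks for model in self._models)

models = Models()
assert models.prom_levels == 8  # both models request all 8 prompt levels
assert models.tasks == 1        # only "tts" is enabled for the ar/nar pair so far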
@@ -246,11 +256,9 @@ class Evaluation:
 class DeepSpeed:
     zero_optimization_level: int = 0
     use_compression_training: bool = False
+    compression_bits: int = 8

     def get_ds_cfg(self, model):
-        weights = [ name[0] for name in model.named_parameters() ]
-        bits = 8
-
         scheduler_params = {}
         for k in cfg.hyperparameters.scheduler_params:
             scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
@@ -298,30 +306,17 @@ class DeepSpeed:
                 "different_groups": {
                     "wq1": {
                         "params": {
-                            "start_bits": bits,
-                            "target_bits": bits,
+                            "start_bits": self.compression_bits,
+                            "target_bits": self.compression_bits,
                             "quantization_period": 0
                         },
-                        "modules": weights
+                        "modules": [
+                            "blocks",
+                            "retnet",
+                        ]
                     }
                 }
             },
-            "activation_quantization": {
-                "shared_parameters":{
-                    "enabled": True,
-                    "quantization_type": "symmetric",
-                    "range_calibration": "dynamic",
-                    "schedule_offset": 0
-                },
-                "different_groups": {
-                    "aq1": {
-                        "params": {
-                            "bits": bits
-                        },
-                        "modules": weights
-                    }
-                }
-            }
         } if self.use_compression_training else None,
         "zero_optimization": {
             "stage": self.zero_optimization_level,
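For reference, the surviving weight-quantization group can be read as a small builder; the key names follow the compression_training schema already present in this config, but the function itself is an illustrative sketch, not a helper the repo defines:

def weight_quantization_group(compression_bits: int = 8) -> dict:
    return {
        "different_groups": {
            "wq1": {
                "params": {
                    "start_bits": compression_bits,
                    "target_bits": compression_bits,
                    "quantization_period": 0,
                },
                # module-name substrings instead of an exhaustive
                # named_parameters() dump: matches the transformer
                # blocks and the retnet stack alike
                "modules": ["blocks", "retnet"],
            }
        }
    }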
@@ -467,6 +462,8 @@ try:

     # cached_property stopped working...
     if cfg.dataset.use_hdf5:
+        if cfg.distributed:
+            cfg.dataset.hdf5_flag = "r"
         try:
             cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', cfg.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
         except Exception as e:
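The new guard exists because every rank in a distributed run opens the same HDF5 file, and concurrent write-intent opens fight over the file lock. A hedged sketch of the idea in isolation (open_hdf5 is illustrative, not a repo function):

import h5py

def open_hdf5(path: str, flag: str = "a", distributed: bool = False) -> h5py.File:
    if distributed:
        flag = "r"  # readers can share the file; writers cannot
    return h5py.File(path, flag)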
@@ -51,7 +51,7 @@ def _get_phone_path(path):

 def _load_quants(path) -> Tensor:
     path = _get_quant_path(path)
-    return torch.load(path)[0][:cfg.models.levels, :].t().to(torch.int16)
+    return torch.load(path)[0][:cfg.models.prom_levels, :].t().to(torch.int16)


 @cache
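The slicing in _load_quants is easiest to follow as shapes: the saved tensor is (1, levels, frames); keep the first prom_levels codebooks, transpose to (frames, levels), and downcast to int16. A dummy-data walkthrough (the (1, 8, 75) shape is illustrative, not something the repo guarantees):

import torch

prom_levels = 8
saved = torch.randint(0, 1024, (1, 8, 75))  # e.g. 8 codebooks over 75 frames
qnt = saved[0][:prom_levels, :].t().to(torch.int16)
assert qnt.shape == (75, prom_levels)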
@@ -118,7 +118,6 @@ class Dataset(_Dataset):
         max_duration=cfg.dataset.duration_range[1],
         training=False,
         extra_paths_by_spkr_name: dict[str, list] = {},
-        sample_type=cfg.dataset.sample_type # path | speaker
     ):
         super().__init__()
         self._head = None
@@ -126,7 +125,7 @@ class Dataset(_Dataset):
         self.max_phones = max_phones
         self.min_duration = min_duration
         self.max_duration = max_duration
-        self.sample_type = sample_type
+        self.sampler = None

         if cfg.dataset.validate:
             self.paths = [
@@ -149,6 +148,9 @@ class Dataset(_Dataset):
                 p for p in self.paths if len(self.paths_by_spkr_name[cfg.get_spkr(p)]) > 1
             ]

+        if cfg.dataset.sample_type == "path":
+            self.paths = [*_interleaved_reorder(self.paths, cfg.get_spkr)]
+
         if len(self.paths) == 0 and training:
             raise ValueError("No valid path is found for training.")
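Since reordering only matters for path-level sampling, it now happens once at construction instead of through the removed interleaved_reorder_ method. A sketch of what an interleaved reorder does; the real _interleaved_reorder is imported elsewhere in data.py, so this reimplementation is an assumption about its semantics:

from itertools import groupby, zip_longest

def interleaved_reorder(items, key):
    # round-robin items across groups so truncating the head of the
    # list (e.g. via head_) still covers many speakers
    groups = {k: list(g) for k, g in groupby(sorted(items, key=key), key=key)}
    for batch in zip_longest(*groups.values()):
        yield from (item for item in batch if item is not None)

paths = ["a/1", "a/2", "a/3", "b/1", "b/2"]
assert [*interleaved_reorder(paths, lambda p: p.split("/")[0])] == \
    ["a/1", "b/1", "a/2", "b/2", "a/3"]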
@@ -227,7 +229,7 @@ class Dataset(_Dataset):
         if cfg.dataset.use_hdf5:
             key = _get_hdf5_path(path)
             #qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16)
-            qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.levels]).to(torch.int16)
+            qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
         else:
             qnt = _load_quants(path)
@@ -255,7 +257,7 @@ class Dataset(_Dataset):
         return ["tts"] # "ns", "sr", "tse", "cse", "nse"

     def __getitem__(self, index):
-        if hasattr(self, "sample_type") and self.sample_type == "speaker":
+        if cfg.dataset.sample_type == "speaker":
             spkr_name = self.spkrs[index]
             spkr_id = self.spkr_symmap[spkr_name]
             path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
@@ -267,7 +269,7 @@ class Dataset(_Dataset):
         if cfg.dataset.use_hdf5:
             key = _get_hdf5_path(path)
             text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(self.text_dtype)
-            resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.levels]).to(torch.int16)
+            resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
         else:
             text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
             resps = _load_quants(path)
@@ -321,11 +323,8 @@ class Dataset(_Dataset):
     def training_(self, value):
         self.training = value

-    def interleaved_reorder_(self, fn):
-        self.paths = [*_interleaved_reorder(self.paths, fn)]
-
     def __len__(self):
-        if hasattr(self, "sample_type") and self.sample_type == "speaker":
+        if cfg.dataset.sample_type == "speaker":
             return min(len(self.spkrs), self._head or len(self.spkrs))
         return min(len(self.paths), self._head or len(self.paths))
@@ -472,7 +471,6 @@ def create_datasets():
         #extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
     )

-    val_dataset.interleaved_reorder_(cfg.get_spkr)
    val_dataset.head_(cfg.evaluation.size)

     return train_dataset, val_dataset
@@ -480,12 +478,10 @@ def create_datasets():

 def create_train_val_dataloader():
     train_dataset, val_dataset = create_datasets()
-    train_dataset.sample_type = cfg.dataset.sample_type #"speaker"

     subtrain_dataset = copy.deepcopy(train_dataset)
-    if subtrain_dataset.sample_type == "path":
+    if cfg.dataset.sample_type == "path":
         subtrain_dataset.head_(cfg.evaluation.size)
-        subtrain_dataset.interleaved_reorder_(cfg.get_spkr)

     train_dl = _create_dataloader(train_dataset, training=True)
     val_dl = _create_dataloader(val_dataset, training=False)
@@ -32,6 +32,10 @@ class AR(Base):
     def n_prom_levels(self) -> int:
         return cfg.models.prom_levels

+    @property
+    def n_tasks(self) -> int:
+        return cfg.models.tasks
+
     @property
     def resp_loss_only(self):
         return False
@@ -113,6 +113,10 @@ class Base(nn.Module):
     @property
     def n_prom_levels(self) -> int:
         raise NotImplementedError

+    @property
+    def n_tasks(self) -> int:
+        raise NotImplementedError
+
     @property
     def resp_loss_only(self):
@@ -120,7 +124,7 @@ class Base(nn.Module):

     def __init__(
         self,
-        n_tokens: int,
+        n_tokens: int = 1024,
         d_model: int = 512,
         n_heads: int = 8,
         n_layers: int = 12,
@@ -132,16 +136,12 @@ class Base(nn.Module):
         self.n_heads = n_heads
         self.n_layers = n_layers

-        causal = self.causal
-
         # +1 to include the stop token
-        n_stop_tokens = 1 if self.use_stop_token else 0
-        n_resp_tokens = n_tokens + n_stop_tokens
+        n_prom_tokens = n_tokens + (self.n_tasks - 1) # - 1 because tts is an inherent task
+        n_resp_tokens = n_tokens + (1 if self.use_stop_token else 0) # AR requires a stop token to... know when to stop

         self.text_emb = Embedding(n_tokens, d_model)
-        # Here I simply use all prom levels
-        self.proms_emb = MultiEmbedding(self.n_prom_levels, n_tokens, d_model)
+        self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
         self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)

         self.sep = nn.Parameter(torch.randn(d_model))
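To make the token arithmetic concrete: the prompt embedding grows by one id per non-tts task, while the response embedding only ever grows by the stop token. A worked example with the new n_tokens default of 1024:

n_tokens = 1024

# with the full task set ["tts", "ns", "sr", "tse", "cse", "nse"] plus two spares:
n_tasks = 8
n_prom_tokens = n_tokens + (n_tasks - 1)  # 1031; tts is inherent, the other 7 reserve one id each

# with the ar/nar definitions in this commit (tasks=1) the prompt vocab is unchanged:
assert n_tokens + (1 - 1) == 1024

# the AR needs a stop token, the NAR does not:
n_resp_tokens_ar = n_tokens + 1   # 1025
n_resp_tokens_nar = n_tokens      # 1024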
@@ -152,14 +152,13 @@ class Base(nn.Module):
                 d_model=d_model,
                 n_heads=n_heads,
                 p_dropout=p_dropout,
-                causal=causal,
+                causal=self.causal,
                 norm_type=self.norm_type,
                 n_levels=self.n_resp_levels,
-                #tention="retention" if self.use_retnet else "attention"
             ) for _ in range(n_layers) ])

         elif self.arch_type == "retnet":
-            self.retnet_config = RetNetConfig(
+            self.retnet = RetNetDecoder(RetNetConfig(
                 vocab_size=n_tokens,
                 decoder_embed_dim=d_model,
                 decoder_retention_heads=n_heads,
@@ -169,13 +168,10 @@ class Base(nn.Module):
                 checkpoint_activations=True,

                 chunkwise_recurrent=self.causal,
-                recurrent_chunkwise_size=128,
+                recurrent_chunkwise_size=64,
                 no_output_layer=True,
                 decoder_normalize_before=True,
-            )
-            self.retnet = RetNetDecoder(
-                self.retnet_config
-            )
+            ))

         self.classifier = nn.Linear(d_model, n_resp_tokens)
@@ -281,13 +277,14 @@ class Base(nn.Module):
         )

         x, m = list_to_tensor(x_list)
+        device = x.device

         if self.arch_type == "transformer":
             x = self.sin_emb.add_pe(x)
             for block in self.blocks:
                 x = block(x, m, quant_levels)
         elif self.arch_type == "retnet":
+            # to-do: actually make this work and verify it works with recurrent_forward / chunkwise_forward
             x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
             state = self.retnet.get_incremental_state( state, 'prev_state' )
@@ -301,7 +298,7 @@ class Base(nn.Module):
         if any([l == 0 for l in map(len, targ_list)]):
             raise ValueError("Cannot compute loss given empty targ_list.")

-        ignore_sep = torch.tensor(self.ignore_index, device=x.device)
+        ignore_sep = torch.tensor(self.ignore_index, device=device)

         # ignore the prompt when computing loss
         prom_list = [
@@ -348,11 +345,6 @@ class Base(nn.Module):
             acc = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) ),
             precision = self.precision_metric( torch.cat(h_list), torch.cat(y_list) ),
         )
-        del targ_list
-        del prom_list
-        del text_prom_list
-        del y_list

         # return the entire generated token string
         if return_all:
@@ -366,9 +358,6 @@ class Base(nn.Module):
         else:
             logits = torch.stack([hi[-1] for hi in h_list])
             ret = Categorical(logits=logits / sampling_temperature).sample()

-        del x_list
-        del h_list
-
         return ret, state
@@ -31,6 +31,10 @@ class NAR(Base):
     def n_prom_levels(self) -> int:
         return cfg.models.prom_levels

+    @property
+    def n_tasks(self) -> int:
+        return cfg.models.tasks
+
     @property
     def resp_loss_only(self):
         return True
@@ -87,8 +87,8 @@ def run_eval(engines, eval_name, dl):

     # pseudo loss calculation since we don't get the logits during eval
     min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
-    ref_audio = ref_audio[..., 0:min_length]
-    hyp_audio = hyp_audio[..., 0:min_length]
+    ref_audio = center_crop(ref_audio, min_length) #ref_audio[..., 0:min_length]
+    hyp_audio = center_crop(hyp_audio, min_length) #hyp_audio[..., 0:min_length]
     try:
         stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
     except Exception as e:
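Cropping both clips from the center rather than the head keeps the compared windows roughly aligned when one decode runs long. A minimal sketch of a center crop over the trailing (time) dimension, assuming the imported center_crop helper behaves this way:

import torch

def center_crop(x: torch.Tensor, length: int) -> torch.Tensor:
    start = (x.shape[-1] - length) // 2
    return x[..., start:start + length]

ref, hyp = torch.randn(1, 24000), torch.randn(1, 22050)
n = min(ref.shape[-1], hyp.shape[-1])
ref, hyp = center_crop(ref, n), center_crop(hyp, n)
assert ref.shape[-1] == hyp.shape[-1] == n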
@@ -141,7 +141,7 @@ def run_eval(engines, eval_name, dl):

     iteration = engines.global_step
     engines_stats['it'] = iteration
-    engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
+    #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)

     _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
vall_e/utils/sampler.py (new file, 2 lines)
@@ -0,0 +1,2 @@
+class Sampler():
+    ...
@@ -81,9 +81,21 @@ def load_engines():
         if cfg.trainer.load_state_dict:
             load_path = cfg.ckpt_dir / name / "fp32.pth"
             state = torch.load(load_path)
+            # exporting the model from the zero_to_fp32.py exports the actual module's dict
+            # exporting with vall_e.export exports the state dict under .module
             if "module" in state:
                 state = state["module"]
-            model.load_state_dict(state)
+
+            print(model.proms_emb.weight.shape, state['proms_emb.weight'].shape)
+
+            # extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
+            if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
+                n_prom_levels, n_prom_tokens, d_model = state['proms_emb.weight'].shape
+
+                model.proms_emb.weight.data[:n_prom_levels, :n_prom_tokens, :] = state['proms_emb.weight'].data[:n_prom_levels, :n_prom_tokens, :]
+                state['proms_emb.weight'] = model.proms_emb.weight
+
+            model.load_state_dict(state, strict=cfg.trainer.strict_loading)

         engines[name] = Engine(
             model=model,
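The state-dict surgery above generalizes to any grown embedding: copy the old (levels, tokens, d_model) weight into the freshly initialized larger one, and the rows added for new levels or task tokens keep their random init. A standalone sketch (extend_prom_emb is a hypothetical helper, not part of the repo):

import torch

def extend_prom_emb(new_weight: torch.Tensor, old_weight: torch.Tensor) -> torch.Tensor:
    n_levels, n_tokens, d_model = old_weight.shape
    new_weight.data[:n_levels, :n_tokens, :] = old_weight.data
    return new_weight

old = torch.randn(8, 1024, 512)
new = torch.randn(8, 1031, 512)  # prompt vocab grew by 7 task tokens
patched = extend_prom_emb(new, old)
assert torch.equal(patched[:, :1024, :], old)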