preparing for SpeechX extensions

mrq 2023-08-18 20:58:07 -05:00
parent ced31fd9b7
commit 2a71486cb6
8 changed files with 77 additions and 73 deletions

View File

@ -137,6 +137,8 @@ class Model:
name: str = ""
size: str = "full"
resp_levels: int = 1
prom_levels: int = 8
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
arch_type: str = "transformer"
@property
@ -157,7 +159,7 @@ class Model:
if self.arch_type != "transformer":
name.append(self.arch_type.replace("/", "-"))
name.append(f'{cfg.models.levels}')
name.append(f'{cfg.models.prom_levels}')
return "-".join(name)
@ -192,8 +194,8 @@ class Model:
@dataclass()
class Models:
_models: list[Model] = field(default_factory=lambda: [
Model(name="ar", resp_levels=1),
Model(name="nar", resp_levels=7),
Model(name="ar", resp_levels=1, prom_levels=8, tasks=1),
Model(name="nar", resp_levels=7, prom_levels=8, tasks=1),
])
def get(self, name=None):
@ -215,11 +217,19 @@ class Models:
return self.get("nar")
@property
def levels(self):
return self.prom_levels
prom_levels: int = 8
def prom_levels(self):
prom_levels = 1
for model in self._models:
prom_levels = max(prom_levels, model.prom_levels)
return prom_levels
@property
def tasks(self):
tasks = 1
for model in self._models:
tasks = max(tasks, model.tasks)
return tasks
@dataclass()
class Hyperparameters:
batch_size: int = 8
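A minimal, hedged sketch of how the two aggregate properties added above behave: the Models container reports the maximum prom_levels / tasks across its defined models, so the default AR/NAR pair yields prom_levels=8 and tasks=1. This is a simplified stand-in, not the repo's actual config dataclasses.

from dataclasses import dataclass, field

@dataclass
class Model:
    name: str = ""
    resp_levels: int = 1
    prom_levels: int = 8
    tasks: int = 8

@dataclass
class Models:
    _models: list = field(default_factory=lambda: [
        Model(name="ar", resp_levels=1, prom_levels=8, tasks=1),
        Model(name="nar", resp_levels=7, prom_levels=8, tasks=1),
    ])

    @property
    def prom_levels(self) -> int:
        # report the widest prompt depth any defined model expects
        return max((m.prom_levels for m in self._models), default=1)

    @property
    def tasks(self) -> int:
        # likewise for the number of supported tasks
        return max((m.tasks for m in self._models), default=1)

models = Models()
assert models.prom_levels == 8 and models.tasks == 1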
@ -246,11 +256,9 @@ class Evaluation:
class DeepSpeed:
zero_optimization_level: int = 0
use_compression_training: bool = False
compression_bits: int = 8
def get_ds_cfg(self, model):
weights = [ name[0] for name in model.named_parameters() ]
bits = 8
scheduler_params = {}
for k in cfg.hyperparameters.scheduler_params:
scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
@ -298,30 +306,17 @@ class DeepSpeed:
"different_groups": {
"wq1": {
"params": {
"start_bits": bits,
"target_bits": bits,
"start_bits": self.compression_bits,
"target_bits": self.compression_bits,
"quantization_period": 0
},
"modules": weights
"modules": [
"blocks",
"retnet",
]
}
}
},
"activation_quantization": {
"shared_parameters":{
"enabled": True,
"quantization_type": "symmetric",
"range_calibration": "dynamic",
"schedule_offset": 0
},
"different_groups": {
"aq1": {
"params": {
"bits": bits
},
"modules": weights
}
}
}
} if self.use_compression_training else None,
"zero_optimization": {
"stage": self.zero_optimization_level,
@ -467,6 +462,8 @@ try:
# cached_property stopped working...
if cfg.dataset.use_hdf5:
if cfg.distributed:
cfg.dataset.hdf5_flag = "r"
try:
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', cfg.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
except Exception as e:
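The hunk above forces the HDF5 dataset to be opened read-only under distributed training, so multiple ranks never hold a writable handle to the same file. A minimal sketch of the idea; the helper and its default flag are assumptions, while the h5py usage mirrors the line above.

import h5py

def open_dataset_hdf5(path: str, distributed: bool, flag: str = "a") -> h5py.File:
    # Every rank opens the file when training is distributed, so fall back to
    # read-only ("r") to avoid concurrent writers locking the HDF5 file.
    if distributed:
        flag = "r"
    return h5py.File(path, flag)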

View File

@ -51,7 +51,7 @@ def _get_phone_path(path):
def _load_quants(path) -> Tensor:
path = _get_quant_path(path)
return torch.load(path)[0][:cfg.models.levels, :].t().to(torch.int16)
return torch.load(path)[0][:cfg.models.prom_levels, :].t().to(torch.int16)
@cache
@ -118,7 +118,6 @@ class Dataset(_Dataset):
max_duration=cfg.dataset.duration_range[1],
training=False,
extra_paths_by_spkr_name: dict[str, list] = {},
sample_type=cfg.dataset.sample_type # path | speaker
):
super().__init__()
self._head = None
@ -126,7 +125,7 @@ class Dataset(_Dataset):
self.max_phones = max_phones
self.min_duration = min_duration
self.max_duration = max_duration
self.sample_type = sample_type
self.sampler = None
if cfg.dataset.validate:
self.paths = [
@ -149,6 +148,9 @@ class Dataset(_Dataset):
p for p in self.paths if len(self.paths_by_spkr_name[cfg.get_spkr(p)]) > 1
]
if cfg.dataset.sample_type == "path":
self.paths = [*_interleaved_reorder(self.paths, cfg.get_spkr)]
if len(self.paths) == 0 and training:
raise ValueError("No valid path is found for training.")
@ -227,7 +229,7 @@ class Dataset(_Dataset):
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
#qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16)
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.levels]).to(torch.int16)
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
else:
qnt = _load_quants(path)
@ -255,7 +257,7 @@ class Dataset(_Dataset):
return ["tts"] # "ns", "sr", "tse", "cse", "nse"
def __getitem__(self, index):
if hasattr(self, "sample_type") and self.sample_type == "speaker":
if cfg.dataset.sample_type == "speaker":
spkr_name = self.spkrs[index]
spkr_id = self.spkr_symmap[spkr_name]
path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
@ -267,7 +269,7 @@ class Dataset(_Dataset):
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(self.text_dtype)
resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.levels]).to(torch.int16)
resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
else:
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
resps = _load_quants(path)
@ -321,11 +323,8 @@ class Dataset(_Dataset):
def training_(self, value):
self.training = value
def interleaved_reorder_(self, fn):
self.paths = [*_interleaved_reorder(self.paths, fn)]
def __len__(self):
if hasattr(self, "sample_type") and self.sample_type == "speaker":
if cfg.dataset.sample_type == "speaker":
return min(len(self.spkrs), self._head or len(self.spkrs))
return min(len(self.paths), self._head or len(self.paths))
@ -472,7 +471,6 @@ def create_datasets():
#extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
)
val_dataset.interleaved_reorder_(cfg.get_spkr)
val_dataset.head_(cfg.evaluation.size)
return train_dataset, val_dataset
@ -480,12 +478,10 @@ def create_datasets():
def create_train_val_dataloader():
train_dataset, val_dataset = create_datasets()
train_dataset.sample_type = cfg.dataset.sample_type #"speaker"
subtrain_dataset = copy.deepcopy(train_dataset)
if subtrain_dataset.sample_type == "path":
if cfg.dataset.sample_type == "path":
subtrain_dataset.head_(cfg.evaluation.size)
subtrain_dataset.interleaved_reorder_(cfg.get_spkr)
train_dl = _create_dataloader(train_dataset, training=True)
val_dl = _create_dataloader(val_dataset, training=False)
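With the per-instance sample_type attribute gone, dataset length and indexing now pivot directly on cfg.dataset.sample_type. A rough sketch of that behaviour, heavily simplified, with the actual text/audio loading elided:

import random

class DatasetSketch:
    # "speaker" indexes speakers and picks a random utterance from that speaker;
    # "path" indexes utterance paths directly.
    def __init__(self, cfg, paths, spkrs, paths_by_spkr_name):
        self.cfg = cfg
        self.paths = paths
        self.spkrs = spkrs
        self.paths_by_spkr_name = paths_by_spkr_name
        self._head = None

    def __len__(self):
        if self.cfg.dataset.sample_type == "speaker":
            return min(len(self.spkrs), self._head or len(self.spkrs))
        return min(len(self.paths), self._head or len(self.paths))

    def __getitem__(self, index):
        if self.cfg.dataset.sample_type == "speaker":
            spkr_name = self.spkrs[index]
            path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
        else:
            path = self.paths[index]
        return path  # the real dataset goes on to load phones and quantized audio here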

View File

@ -32,6 +32,10 @@ class AR(Base):
def n_prom_levels(self) -> int:
return cfg.models.prom_levels
@property
def n_tasks(self) -> int:
return cfg.models.tasks
@property
def resp_loss_only(self):
return False

View File

@ -113,6 +113,10 @@ class Base(nn.Module):
@property
def n_prom_levels(self) -> int:
raise NotImplementedError
@property
def n_tasks(self) -> int:
raise NotImplementedError
@property
def resp_loss_only(self):
@ -120,7 +124,7 @@ class Base(nn.Module):
def __init__(
self,
n_tokens: int,
n_tokens: int = 1024,
d_model: int = 512,
n_heads: int = 8,
n_layers: int = 12,
@ -132,16 +136,12 @@ class Base(nn.Module):
self.n_heads = n_heads
self.n_layers = n_layers
causal = self.causal
# +1 to include the stop token
n_stop_tokens = 1 if self.use_stop_token else 0
n_resp_tokens = n_tokens + n_stop_tokens
n_prom_tokens = n_tokens + (self.n_tasks - 1) # - 1 because tts is an inherent task
n_resp_tokens = n_tokens + (1 if self.use_stop_token else 0) # AR requires a stop token to... know when to stop
self.text_emb = Embedding(n_tokens, d_model)
# Here I simply use all prom levels
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_tokens, d_model)
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
self.sep = nn.Parameter(torch.randn(d_model))
@ -152,14 +152,13 @@ class Base(nn.Module):
d_model=d_model,
n_heads=n_heads,
p_dropout=p_dropout,
causal=causal,
causal=self.causal,
norm_type=self.norm_type,
n_levels=self.n_resp_levels,
#tention="retention" if self.use_retnet else "attention"
) for _ in range(n_layers) ])
elif self.arch_type == "retnet":
self.retnet_config = RetNetConfig(
self.retnet = RetNetDecoder(RetNetConfig(
vocab_size=n_tokens,
decoder_embed_dim=d_model,
decoder_retention_heads=n_heads,
@ -169,13 +168,10 @@ class Base(nn.Module):
checkpoint_activations=True,
chunkwise_recurrent=self.causal,
recurrent_chunkwise_size=128,
recurrent_chunkwise_size=64,
no_output_layer=True,
decoder_normalize_before=True,
)
self.retnet = RetNetDecoder(
self.retnet_config
)
))
self.classifier = nn.Linear(d_model, n_resp_tokens)
@ -281,13 +277,14 @@ class Base(nn.Module):
)
x, m = list_to_tensor(x_list)
device = x.device
if self.arch_type == "transformer":
x = self.sin_emb.add_pe(x)
for block in self.blocks:
x = block(x, m, quant_levels)
elif self.arch_type == "retnet":
# to-do: actually make this work and verify it works with recurrent_forward / chunkwise_forward
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
state = self.retnet.get_incremental_state( state, 'prev_state' )
@ -301,7 +298,7 @@ class Base(nn.Module):
if any([l == 0 for l in map(len, targ_list)]):
raise ValueError("Cannot compute loss given empty targ_list.")
ignore_sep = torch.tensor(self.ignore_index, device=x.device)
ignore_sep = torch.tensor(self.ignore_index, device=device)
# ignore the prompt when computing loss
prom_list = [
@ -348,11 +345,6 @@ class Base(nn.Module):
acc = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) ),
precision = self.precision_metric( torch.cat(h_list), torch.cat(y_list) ),
)
del targ_list
del prom_list
del text_prom_list
del y_list
# return the entire generated token string
if return_all:
@ -366,9 +358,6 @@ class Base(nn.Module):
else:
logits = torch.stack([hi[-1] for hi in h_list])
ret = Categorical(logits=logits / sampling_temperature).sample()
del x_list
del h_list
return ret, state
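The base model now sizes the prompt embedding for task tokens: n_prom_tokens = n_tokens + (n_tasks - 1) reserves one extra id per non-TTS task (TTS being the implicit default), while the response embedding only gains the optional stop token. A quick sketch of the arithmetic, with a plain per-level embedding list standing in for the repo's MultiEmbedding:

import torch
import torch.nn as nn

n_tokens = 1024        # audio codebook size (new default in Base.__init__)
n_tasks = 8            # ["tts", "ns", "sr", "tse", "cse", "nse"] plus two spare slots
use_stop_token = True  # AR only

n_prom_tokens = n_tokens + (n_tasks - 1)                  # 1031: extra ids for non-tts tasks
n_resp_tokens = n_tokens + (1 if use_stop_token else 0)   # 1025: stop token for the AR

prom_levels, resp_levels, d_model = 8, 1, 512
proms_emb = nn.ModuleList(nn.Embedding(n_prom_tokens, d_model) for _ in range(prom_levels))
resps_emb = nn.ModuleList(nn.Embedding(n_resp_tokens, d_model) for _ in range(resp_levels))
assert proms_emb[0].weight.shape == (n_prom_tokens, d_model)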

View File

@ -31,6 +31,10 @@ class NAR(Base):
def n_prom_levels(self) -> int:
return cfg.models.prom_levels
@property
def n_tasks(self) -> int:
return cfg.models.tasks
@property
def resp_loss_only(self):
return True

View File

@ -87,8 +87,8 @@ def run_eval(engines, eval_name, dl):
# pseudo loss calculation since we don't get the logits during eval
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
ref_audio = ref_audio[..., 0:min_length]
hyp_audio = hyp_audio[..., 0:min_length]
ref_audio = center_crop(ref_audio, min_length) #ref_audio[..., 0:min_length]
hyp_audio = center_crop(hyp_audio, min_length) #hyp_audio[..., 0:min_length]
try:
stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
except Exception as e:
@ -141,7 +141,7 @@ def run_eval(engines, eval_name, dl):
iteration = engines.global_step
engines_stats['it'] = iteration
engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")

vall_e/utils/sampler.py (new file, 2 lines added)
View File

@ -0,0 +1,2 @@
class Sampler():
...

View File

@ -81,9 +81,21 @@ def load_engines():
if cfg.trainer.load_state_dict:
load_path = cfg.ckpt_dir / name / "fp32.pth"
state = torch.load(load_path)
# exporting the model from the zero_to_fp32.py exports the actual module's dict
# exporting with vall_e.export exports the state dict under .module
if "module" in state:
state = state["module"]
model.load_state_dict(state)
print(model.proms_emb.weight.shape, state['proms_emb.weight'].shape)
# extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
n_prom_levels, n_prom_tokens, d_model = state['proms_emb.weight'].shape
model.proms_emb.weight.data[:n_prom_levels, :n_prom_tokens, :] = state['proms_emb.weight'].data[:n_prom_levels, :n_prom_tokens, :]
state['proms_emb.weight'] = model.proms_emb.weight
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
engines[name] = Engine(
model=model,
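Finally, load_engines patches older checkpoints whose proms_emb is smaller than the freshly built one (fewer prom levels or prompt tokens): the old weights are copied into the corresponding slice of the new table and the state-dict entry is swapped for the enlarged tensor before loading. A condensed sketch, assuming the embedding weight is shaped [prom_levels, n_prom_tokens, d_model] as the slicing above implies:

import torch

def extend_proms_emb(new_weight: torch.Tensor, old_weight: torch.Tensor) -> torch.Tensor:
    # Copy the old, smaller embedding into the top-left slice of the new table;
    # rows and columns added for new prom levels or task tokens keep their fresh init.
    levels, tokens, _ = old_weight.shape
    new_weight.data[:levels, :tokens, :] = old_weight.data
    return new_weight

old = torch.randn(8, 1024, 512)   # checkpoint trained before the task tokens existed
new = torch.randn(8, 1031, 512)   # current model: n_prom_tokens = 1024 + (8 - 1)
merged = extend_proms_emb(new, old)
assert torch.equal(merged[:, :1024, :], old)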