diff --git a/.gitignore b/.gitignore index dae9bf1..d806633 100755 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ __pycache__ /venv /*.egg-info /vall_e/version.py -/.cache \ No newline at end of file +/.cache +/voices diff --git a/vall_e/config.py b/vall_e/config.py index 629d9fd..90bf6b8 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -213,7 +213,7 @@ class Model: attention: str = "auto" audio_embedding_sums: bool = True dropout: float = 0.1 # adjustable dropout value - loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.1, "resp": 1.0 }) + loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 }) def get(self, name=None): return [ self ] if not name or self.name == name else [] diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 4d83976..f02db55 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -845,10 +845,10 @@ class Base(nn.Module): quant_levels: Tensor | None = None ): x_list = [] - for b_i in range(len(inputs)): + for batch_index, batch_input in enumerate(inputs): batch = [] - for i in range(len(inputs[b_i])): - name, input = inputs[b_i][i] + quant_level = quant_levels[batch_index] if quant_levels is not None else None + for name, input in batch_input: embedding = None if name == "text": embedding = self.text_emb( input ) @@ -859,7 +859,7 @@ class Base(nn.Module): elif name == "tone": embedding = self.tones_emb( input ) elif name == "resp": - embedding = self.resps_emb( input, quant_levels[b_i] if quant_levels is not None else None ) + embedding = self.resps_emb( input, quant_level ) else: continue @@ -869,61 +869,101 @@ class Base(nn.Module): return x_list - def training_targets( + def calc_loss( self, inputs: list, + logits, + + quant_levels: Tensor | None = None, ): - x_list = [] - for bi in range(len(inputs)): - batch = [] - for i in range(len(inputs[bi])): - name, input = inputs[bi][i] - device = input.device + # old, "naive" way, no loss factoring + if not self.config.loss_factors: + target_list = [] + for batch in inputs: + target = [] + for name, input in batch: + if name == "prom": + target.append( torch.full_like(input[..., 0], self.ignore_index) ) + elif name in ["text", "lang", "tone", "targ"]: + target.append( input ) - if name == "prom": - batch.append( torch.full_like(input[..., 0], self.ignore_index) ) - elif name in ["text", "lang", "tone", "targ"]: - batch.append( input ) + target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) ) - x_list.append( _join( batch, torch.tensor(self.ignore_index, device=device) ) ) + # modify only for the AR so it can properly behave like a transformer + for i in range(len(target_list)): + if quant_levels is not None and quant_levels[i] > 0: + continue - return x_list + logits[i] = logits[i][..., :-1, :] # shift the target so that token n... + target_list[i] = target_list[i][..., 1:] # predicts token n + 1 - def training_targets_split( - self, - inputs: list, - quant_levels: Tensor | None = None - ): - text_lists = [] - prom_lists = [] - resp_lists = [] + target = torch.cat( target_list ) + inputs = torch.cat( logits ) - for bi in range(len(inputs)): - text_batch = [] - prom_batch = [] - resp_batch = [] + self.loss = dict( + # "nll" was in the original implementation and should actually just be called something else + nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index ) + ) + self.stats = dict( + acc = self.accuracy_metric( inputs, target ), + # precision = self.precision_metric( inputs, target ), + ) + return - for i in range(len(inputs[bi])): - name, input = inputs[bi][i] - device = input.device + self.loss = dict() + self.stats = dict(acc = dict()) - quant_level = quant_levels[bi] if quant_levels is not None else None + info = {} + for i, batch in enumerate( inputs ): + quant_level = quant_levels[i] if quant_levels is not None else None - if name == "text": - text_batch.append( input ) - elif name == "prom": - prom_batch.append( input[:, quant_level] if quant_level is not None else input ) - elif name == "targ": - resp_batch.append( input ) + it = 0 + for name, input in batch: + # do not use resp + if name == "resp": + continue + # rename to resp + if name == "targ": + name = "resp" + # select prom level + elif name == "prom" and quant_level is not None: + input = input[:, quant_level] - if text_batch: - text_lists.append( _join( text_batch, torch.tensor(self.ignore_index, device=device) ) ) - if prom_batch: - prom_lists.append( _join( prom_batch, torch.tensor(self.ignore_index, device=device) ) ) - if resp_batch: - resp_lists.append( _join( resp_batch, torch.tensor(self.ignore_index, device=device) ) ) + seq_len = input.shape[0] + logit = logits[i][it:it+seq_len] + it += seq_len + 1 # +1 to incorporate the separator + + # for the AR, shift sequence so that it predicts the next token + if quant_level is None or quant_level == 0: + logit = logit[..., :-1, :] # get all but the final logit + input = input[..., 1:] # shift sequence to the right by one - return text_lists, prom_lists, resp_lists + if name not in info: + info[name] = { + "targets": [], + "logits": [], + } + + info[name]["targets"].append( input ) + info[name]["logits"].append( logit ) + + for name, batch in info.items(): + loss_factor = self.loss_factor(name) + if loss_factor == 0.0: + continue + + targets = torch.cat( batch["targets"] ).long() + inputs = torch.cat( batch["logits"] ) + + self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor + self.stats["acc"][name] = self.accuracy_metric( inputs, targets ) + + # to-do: compute loss per individual batch to scale per RVQ level + """ + rvq_loss_factor = self.loss_factor("quant") + if isinstance( rvq_loss_factor, list ): + ... + """ def forward( self, @@ -974,93 +1014,7 @@ class Base(nn.Module): # compute loss if the target is given if training: - if not self.config.loss_factors: - target_list = self.training_targets( inputs ) - - # modify only for the AR so it can properly behave like a transformer - for i in range(len(target_list)): - if quant_levels is not None and quant_levels[i] > 0: - continue - - logits[i] = logits[i][..., :-1, :] # shift the target so that token n... - target_list[i] = target_list[i][..., 1:] # predicts token n + 1 - - target = torch.cat( target_list ) - inputs = torch.cat( logits ) - - self.loss = dict( - # "nll" was in the original implementation and should actually just be called something else - nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index ) - ) - self.stats = dict( - acc = self.accuracy_metric( inputs, target ), - # precision = self.precision_metric( inputs, target ), - ) - # split our loss - # to-do: clean this up - else: - target_text_list, target_prom_list, target_resp_list = self.training_targets_split( inputs, quant_levels ) - - logits_text = [] - logits_prom = [] - logits_resp = [] - - # trim logits to each section - for i, logit in enumerate(logits): - text_len = target_text_list[i].shape[0] - prom_len = target_prom_list[i].shape[0] - resp_len = target_resp_list[i].shape[0] - - logits_text.append( logit[:text_len] ) - logits_prom.append( logit[text_len+1:text_len+1+prom_len] ) # + 1 to include separator - logits_resp.append( logit[-resp_len:] ) - - - # modify only for the AR so it can properly behave like a transformer - for i in range(len(target_text_list)): - if quant_levels is not None and quant_levels[i] > 0: - continue - - # shift the target so that token n... - logits_text[i] = logits_text[i][..., :-1, :] - logits_prom[i] = logits_prom[i][..., :-1, :] - logits_resp[i] = logits_resp[i][..., :-1, :] - - # predicts token n + 1 - target_text_list[i] = target_text_list[i][..., 1:] - target_prom_list[i] = target_prom_list[i][..., 1:] - target_resp_list[i] = target_resp_list[i][..., 1:] - - self.loss = dict() - self.stats = dict(acc = dict()) - - loss_factor_text = self.loss_factor("text") - if loss_factor_text > 0.0 and target_text_list: - target_text = torch.cat( target_text_list ).long() - inputs_text = torch.cat( logits_text ) - self.loss["text"] = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ) * loss_factor_text - self.stats["acc"]["text"] = self.accuracy_metric( inputs_text, target_text ) - - loss_factor_prom = self.loss_factor("prom") - if loss_factor_prom > 0.0 and target_prom_list: - target_prom = torch.cat( target_prom_list ).long() - inputs_prom = torch.cat( logits_prom ) - self.loss["prom"] = F.cross_entropy( inputs_prom, target_prom, ignore_index=self.ignore_index ) * loss_factor_prom - self.stats["acc"]["prom"] = self.accuracy_metric( inputs_prom, target_prom ) - - loss_factor_resp = self.loss_factor("resp") - if loss_factor_resp > 0.0 and target_resp_list: - target_resp = torch.cat( target_resp_list ).long() - inputs_resp = torch.cat( logits_resp ) - self.loss["resp"] = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ) * loss_factor_resp - self.stats["acc"]["resp"] = self.accuracy_metric( inputs_resp, target_resp ) - - # to-do: compute loss per individual batch to scale per RVQ level - """ - rvq_loss_factor = self.loss_factor("quant") - if isinstance( rvq_loss_factor, list ): - ... - """ + self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels ) # include any additional losses (for example: MoE router) if aux_loss is not None: