better way to compute per-segment losses

parent 6c49ad06a3
commit da473295b7

3 changed files

.gitignore (vendored)
@@ -4,4 +4,5 @@ __pycache__
 /venv
 /*.egg-info
 /vall_e/version.py
 /.cache
+/voices
@@ -213,7 +213,7 @@ class Model:
 	attention: str = "auto"
 	audio_embedding_sums: bool = True
 	dropout: float = 0.1 # adjustable dropout value
-	loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.1, "resp": 1.0 })
+	loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 })
 
 	def get(self, name=None):
 		return [ self ] if not name or self.name == name else []
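The default prom factor drops from 0.1 to 0.0, so the prompt segment no longer contributes to the training loss unless re-enabled. A minimal sketch of how a per-segment weight lookup over this dict might behave (the actual loss_factor helper is defined elsewhere in the model and is only assumed here):

loss_factors = { "text": 0.1, "prom": 0.0, "resp": 1.0 }

def loss_factor(name: str) -> float:
	# hypothetical lookup: unknown segment names fall back to a weight of zero
	return loss_factors.get(name, 0.0)

assert loss_factor("resp") == 1.0
assert loss_factor("prom") == 0.0  # prompt tokens are skipped entirely with the new default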
@@ -845,10 +845,10 @@ class Base(nn.Module):
 		quant_levels: Tensor | None = None
 	):
 		x_list = []
-		for b_i in range(len(inputs)):
+		for batch_index, batch_input in enumerate(inputs):
 			batch = []
-			for i in range(len(inputs[b_i])):
-				name, input = inputs[b_i][i]
+			quant_level = quant_levels[batch_index] if quant_levels is not None else None
+			for name, input in batch_input:
 				embedding = None
 				if name == "text":
 					embedding = self.text_emb( input )
@@ -859,7 +859,7 @@ class Base(nn.Module):
 			elif name == "tone":
 				embedding = self.tones_emb( input )
 			elif name == "resp":
-				embedding = self.resps_emb( input, quant_levels[b_i] if quant_levels is not None else None )
+				embedding = self.resps_emb( input, quant_level )
 			else:
 				continue
 
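Both hunks assume the layout built earlier in the class: inputs is a list of batch entries, each entry a list of (name, tensor) pairs, and the RVQ quantization level is now resolved once per batch entry instead of being re-derived per item. A rough sketch of that layout with made-up shapes (the real tensors come from the text tokenizer and audio codec):

import torch

batch_0 = [
	("text", torch.randint(0, 256, (32,))),      # text/phoneme token ids (hypothetical sizes)
	("prom", torch.randint(0, 1024, (150, 8))),  # prompt codes across RVQ levels
	("resp", torch.randint(0, 1024, (200,))),    # response codes for the current level
]
inputs = [ batch_0 ]
quant_levels = torch.tensor([ 0 ])  # one level per batch entry

quant_level = quant_levels[0] if quant_levels is not None else None
for name, input in inputs[0]:
	print(name, tuple(input.shape), "quant_level:", int(quant_level))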
@@ -869,61 +869,101 @@ class Base(nn.Module):
 
 		return x_list
 
-	def training_targets(
-		self,
-		inputs: list,
-	):
-		x_list = []
-		for bi in range(len(inputs)):
-			batch = []
-			for i in range(len(inputs[bi])):
-				name, input = inputs[bi][i]
-				device = input.device
-
-				if name == "prom":
-					batch.append( torch.full_like(input[..., 0], self.ignore_index) )
-				elif name in ["text", "lang", "tone", "targ"]:
-					batch.append( input )
-
-			x_list.append( _join( batch, torch.tensor(self.ignore_index, device=device) ) )
-
-		return x_list
-
-	def training_targets_split(
-		self,
-		inputs: list,
-		quant_levels: Tensor | None = None
-	):
-		text_lists = []
-		prom_lists = []
-		resp_lists = []
-
-		for bi in range(len(inputs)):
-			text_batch = []
-			prom_batch = []
-			resp_batch = []
-
-			for i in range(len(inputs[bi])):
-				name, input = inputs[bi][i]
-				device = input.device
-
-				quant_level = quant_levels[bi] if quant_levels is not None else None
-
-				if name == "text":
-					text_batch.append( input )
-				elif name == "prom":
-					prom_batch.append( input[:, quant_level] if quant_level is not None else input )
-				elif name == "targ":
-					resp_batch.append( input )
-
-			if text_batch:
-				text_lists.append( _join( text_batch, torch.tensor(self.ignore_index, device=device) ) )
-			if prom_batch:
-				prom_lists.append( _join( prom_batch, torch.tensor(self.ignore_index, device=device) ) )
-			if resp_batch:
-				resp_lists.append( _join( resp_batch, torch.tensor(self.ignore_index, device=device) ) )
-
-		return text_lists, prom_lists, resp_lists
+	def calc_loss(
+		self,
+		inputs: list,
+		logits,
+
+		quant_levels: Tensor | None = None,
+	):
+		# old, "naive" way, no loss factoring
+		if not self.config.loss_factors:
+			target_list = []
+			for batch in inputs:
+				target = []
+				for name, input in batch:
+					if name == "prom":
+						target.append( torch.full_like(input[..., 0], self.ignore_index) )
+					elif name in ["text", "lang", "tone", "targ"]:
+						target.append( input )
+
+				target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
+
+			# modify only for the AR so it can properly behave like a transformer
+			for i in range(len(target_list)):
+				if quant_levels is not None and quant_levels[i] > 0:
+					continue
+
+				logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
+				target_list[i] = target_list[i][..., 1:] # predicts token n + 1
+
+			target = torch.cat( target_list )
+			inputs = torch.cat( logits )
+
+			self.loss = dict(
+				# "nll" was in the original implementation and should actually just be called something else
+				nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
+			)
+			self.stats = dict(
+				acc = self.accuracy_metric( inputs, target ),
+				# precision = self.precision_metric( inputs, target ),
+			)
+			return
+
+		self.loss = dict()
+		self.stats = dict(acc = dict())
+
+		info = {}
+		for i, batch in enumerate( inputs ):
+			quant_level = quant_levels[i] if quant_levels is not None else None
+
+			it = 0
+			for name, input in batch:
+				# do not use resp
+				if name == "resp":
+					continue
+				# rename to resp
+				if name == "targ":
+					name = "resp"
+				# select prom level
+				elif name == "prom" and quant_level is not None:
+					input = input[:, quant_level]
+
+				seq_len = input.shape[0]
+				logit = logits[i][it:it+seq_len]
+				it += seq_len + 1 # +1 to incorporate the separator
+
+				# for the AR, shift sequence so that it predicts the next token
+				if quant_level is None or quant_level == 0:
+					logit = logit[..., :-1, :] # get all but the final logit
+					input = input[..., 1:] # shift sequence to the right by one
+
+				if name not in info:
+					info[name] = {
+						"targets": [],
+						"logits": [],
+					}
+
+				info[name]["targets"].append( input )
+				info[name]["logits"].append( logit )
+
+		for name, batch in info.items():
+			loss_factor = self.loss_factor(name)
+			if loss_factor == 0.0:
+				continue
+
+			targets = torch.cat( batch["targets"] ).long()
+			inputs = torch.cat( batch["logits"] )
+
+			self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
+			self.stats["acc"][name] = self.accuracy_metric( inputs, targets )
+
+		# to-do: compute loss per individual batch to scale per RVQ level
+		"""
+		rvq_loss_factor = self.loss_factor("quant")
+		if isinstance( rvq_loss_factor, list ):
+			...
+		"""
 
 	def forward(
 		self,
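The replacement calc_loss walks each batch's (name, target) pairs in order, slices the matching span of logits with a running offset (advancing one extra position for the separator between segments), shifts logits and targets for the AR level so position n predicts token n + 1, then pools everything per segment name and applies one weighted cross-entropy per name. A self-contained sketch of that bookkeeping for a single AR batch, with toy tensors and the new default weights standing in for the model config:

import torch
import torch.nn.functional as F

ignore_index = -100
loss_factors = { "text": 0.1, "prom": 0.0, "resp": 1.0 }  # mirrors the new defaults

vocab = 32
# one batch: text (5 tokens), prom (7 tokens), targ (6 tokens), with one separator
# token between segments, which the offset arithmetic assumes
segments = [
	("text", torch.randint(0, vocab, (5,))),
	("prom", torch.randint(0, vocab, (7,))),
	("targ", torch.randint(0, vocab, (6,))),
]
total_len = sum(t.shape[0] for _, t in segments) + (len(segments) - 1)  # + separators
logits = torch.randn(total_len, vocab)

info = {}
it = 0
for name, target in segments:
	if name == "targ":
		name = "resp"  # the training target is scored under the "resp" factor
	seq_len = target.shape[0]
	logit = logits[it:it + seq_len]
	it += seq_len + 1  # skip the separator token

	# AR case (quant level 0): shift so position n predicts token n + 1
	logit = logit[:-1, :]
	target = target[1:]

	info.setdefault(name, {"targets": [], "logits": []})
	info[name]["targets"].append(target)
	info[name]["logits"].append(logit)

loss = {}
for name, group in info.items():
	factor = loss_factors.get(name, 0.0)
	if factor == 0.0:
		continue  # e.g. "prom" under the new default weight
	targets = torch.cat(group["targets"]).long()
	preds = torch.cat(group["logits"])
	loss[name] = F.cross_entropy(preds, targets, ignore_index=ignore_index) * factor

print({k: float(v) for k, v in loss.items()})  # e.g. {"text": ..., "resp": ...}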
@@ -974,93 +1014,7 @@ class Base(nn.Module):
 
 		# compute loss if the target is given
 		if training:
-			if not self.config.loss_factors:
-				target_list = self.training_targets( inputs )
-
-				# modify only for the AR so it can properly behave like a transformer
-				for i in range(len(target_list)):
-					if quant_levels is not None and quant_levels[i] > 0:
-						continue
-
-					logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
-					target_list[i] = target_list[i][..., 1:] # predicts token n + 1
-
-				target = torch.cat( target_list )
-				inputs = torch.cat( logits )
-
-				self.loss = dict(
-					# "nll" was in the original implementation and should actually just be called something else
-					nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
-				)
-				self.stats = dict(
-					acc = self.accuracy_metric( inputs, target ),
-					# precision = self.precision_metric( inputs, target ),
-				)
-			# split our loss
-			# to-do: clean this up
-			else:
-				target_text_list, target_prom_list, target_resp_list = self.training_targets_split( inputs, quant_levels )
-
-				logits_text = []
-				logits_prom = []
-				logits_resp = []
-
-				# trim logits to each section
-				for i, logit in enumerate(logits):
-					text_len = target_text_list[i].shape[0]
-					prom_len = target_prom_list[i].shape[0]
-					resp_len = target_resp_list[i].shape[0]
-
-					logits_text.append( logit[:text_len] )
-					logits_prom.append( logit[text_len+1:text_len+1+prom_len] ) # + 1 to include separator
-					logits_resp.append( logit[-resp_len:] )
-
-
-				# modify only for the AR so it can properly behave like a transformer
-				for i in range(len(target_text_list)):
-					if quant_levels is not None and quant_levels[i] > 0:
-						continue
-
-					# shift the target so that token n...
-					logits_text[i] = logits_text[i][..., :-1, :]
-					logits_prom[i] = logits_prom[i][..., :-1, :]
-					logits_resp[i] = logits_resp[i][..., :-1, :]
-
-					# predicts token n + 1
-					target_text_list[i] = target_text_list[i][..., 1:]
-					target_prom_list[i] = target_prom_list[i][..., 1:]
-					target_resp_list[i] = target_resp_list[i][..., 1:]
-
-				self.loss = dict()
-				self.stats = dict(acc = dict())
-
-				loss_factor_text = self.loss_factor("text")
-				if loss_factor_text > 0.0 and target_text_list:
-					target_text = torch.cat( target_text_list ).long()
-					inputs_text = torch.cat( logits_text )
-					self.loss["text"] = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ) * loss_factor_text
-					self.stats["acc"]["text"] = self.accuracy_metric( inputs_text, target_text )
-
-				loss_factor_prom = self.loss_factor("prom")
-				if loss_factor_prom > 0.0 and target_prom_list:
-					target_prom = torch.cat( target_prom_list ).long()
-					inputs_prom = torch.cat( logits_prom )
-					self.loss["prom"] = F.cross_entropy( inputs_prom, target_prom, ignore_index=self.ignore_index ) * loss_factor_prom
-					self.stats["acc"]["prom"] = self.accuracy_metric( inputs_prom, target_prom )
-
-				loss_factor_resp = self.loss_factor("resp")
-				if loss_factor_resp > 0.0 and target_resp_list:
-					target_resp = torch.cat( target_resp_list ).long()
-					inputs_resp = torch.cat( logits_resp )
-					self.loss["resp"] = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ) * loss_factor_resp
-					self.stats["acc"]["resp"] = self.accuracy_metric( inputs_resp, target_resp )
-
-				# to-do: compute loss per individual batch to scale per RVQ level
-				"""
-				rvq_loss_factor = self.loss_factor("quant")
-				if isinstance( rvq_loss_factor, list ):
-					...
-				"""
+			self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
 
 		# include any additional losses (for example: MoE router)
 		if aux_loss is not None:
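With the per-segment branches gone, forward only delegates to calc_loss during training; reducing self.loss to a single scalar is left to whatever training loop consumes it, which is outside this diff. One plausible reduction, stated purely as an assumption:

import torch

# assumed reduction; the engine that actually consumes self.loss is not part of this commit
def total_loss(loss_dict: dict) -> torch.Tensor:
	# the per-segment factors are already baked in, so the overall loss is just the sum
	return sum(loss_dict.values(), torch.tensor(0.0))

print(total_loss({ "text": torch.tensor(0.12), "resp": torch.tensor(1.85) }))  # tensor(1.9700)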