added loss calculation against the prompt (prom); requires the right settings for decent results, and is disabled by default

mrq 2024-05-27 08:43:00 -05:00
parent 05cd8b797e
commit 5af6f41c94
2 changed files with 63 additions and 32 deletions


@@ -213,7 +213,7 @@ class Model:
     attention: str = "auto"
     audio_embedding_sums: bool = True
     dropout: float = 0.1 # adjustable dropout value
-    loss_factors: dict = field(default_factory=lambda: {})
+    loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.1, "resp": 1.0 })
 
     def get(self, name=None):
         return [ self ] if not name or self.name == name else []
@@ -495,11 +495,6 @@ class DeepSpeed:
         return ds_cfg
 
-@dataclass()
-class LossFactor:
-    text: float = 1.0
-    resp: float = 1.0
-
 @dataclass()
 class Trainer:
     iterations: int = 100_000
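The config side of this change swaps the dedicated LossFactor dataclass for a plain loss_factors dict on Model, with per-segment weights for "text", "prom", and "resp". The self.loss_factor(...) calls later in the diff presumably resolve names against that dict; the snippet below is a minimal sketch of such a lookup (the helper name matches the call sites, but its body and where it lives are assumptions, not the repo's actual implementation):

# Hypothetical sketch of the lookup the `self.loss_factor(...)` calls imply:
# resolve a per-segment weight from the configured dict, treating a missing
# key as 0.0 so that segment's loss is skipped (matching the `> 0.0` guards
# further down in the diff).
loss_factors = { "text": 0.1, "prom": 0.1, "resp": 1.0 }  # the new default

def loss_factor(k: str) -> float:
    return loss_factors.get(k, 0.0)

assert loss_factor("resp") == 1.0
assert loss_factor("lang") == 0.0  # anything unconfigured contributes nothing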


@@ -888,27 +888,38 @@ class Base(nn.Module):
     def training_targets_split(
         self,
         inputs: list,
+        quant_levels: Tensor | None = None
     ):
         text_lists = []
+        prom_lists = []
         resp_lists = []
 
         for bi in range(len(inputs)):
             text_batch = []
+            prom_batch = []
             resp_batch = []
 
             for i in range(len(inputs[bi])):
                 name, input = inputs[bi][i]
                 device = input.device
 
-                if name in ["text", "lang" ]:
+                quant_level = quant_levels[bi] if quant_levels is not None else None
+
+                if name in ["text" ]:
                     text_batch.append( input )
+                elif name == "prom" and (quant_level is None or quant_level == 0 or not self.config.audio_embedding_sums):
+                    prom_batch.append( input[:, quant_level] if quant_level is not None else input )
                 elif name == "targ":
                     resp_batch.append( input )
 
-            text_lists.append( _join( text_batch, torch.tensor(self.ignore_index, device=device) ) )
-            resp_lists.append( _join( resp_batch, torch.tensor(self.ignore_index, device=device) ) )
+            if text_batch:
+                text_lists.append( _join( text_batch, torch.tensor(self.ignore_index, device=device) ) )
+            if prom_batch:
+                prom_lists.append( _join( prom_batch, torch.tensor(self.ignore_index, device=device) ) )
+            if resp_batch:
+                resp_lists.append( _join( resp_batch, torch.tensor(self.ignore_index, device=device) ) )
 
-        return text_lists, resp_lists
+        return text_lists, prom_lists, resp_lists
 
     def forward(
         self,
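training_targets_split now also collects prompt targets, selecting the requested RVQ level from each prompt when a quant_level is given, and joins every batch's pieces with _join(..., torch.tensor(self.ignore_index, device=device)). The _join helper itself is not part of this diff; the snippet below is a rough sketch of what such a separator-join presumably does, inferred from the call sites (the signature and the -100 ignore index are assumptions):

import torch
from torch import Tensor

# Assumed behavior of the `_join` referenced above: concatenate a list of
# 1-D token tensors, inserting the separator token between consecutive pieces.
def _join(x: list[Tensor], sep: Tensor) -> Tensor:
    ret = x[0]
    for t in x[1:]:
        ret = torch.cat([ret, sep[None], t], dim=0)
    return ret

ignore_index = -100  # assumed value; only the inserted separators carry it
parts = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
print(_join(parts, torch.tensor(ignore_index)))
# tensor([   5,    6,    7, -100,    8,    9])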
@@ -982,12 +993,32 @@ class Base(nn.Module):
                     # precision = self.precision_metric( inputs, target ),
                 )
             # split our loss
+            # to-do: clean this up
             else:
-                target_text_list, target_resp_list = self.training_targets_split( inputs )
+                target_text_list, target_prom_list, target_resp_list = self.training_targets_split( inputs, quant_levels )
+
+                logits_text = []
+                logits_prom = []
+                logits_resp = []
+
+                for i, logit in enumerate(logits):
+                    text_len = target_text_list[i].shape[0] if target_text_list else 0
+                    prom_len = target_prom_list[i].shape[0] if target_prom_list else 0
+                    resp_len = target_resp_list[i].shape[0] if target_resp_list else 0
+
+                    if target_text_list:
+                        logit_text = logit[:text_len]
+                        logits_text.append( logit_text )
+
+                    # + 1 to include separator
+                    if target_prom_list:
+                        logit_prom = logit[text_len+1:text_len+1+prom_len]
+                        logits_prom.append( logit_prom )
+
+                    if target_resp_list:
+                        logit_resp = logit[-resp_len:]
+                        logits_resp.append( logit_resp )
 
-                # grab respective slice of logits
-                logits_text = [ hi[:li.shape[0]] for hi, li in zip(logits, target_text_list) ]
-                logits_resp = [ hi[-li.shape[0]:] for hi, li in zip(logits, target_resp_list) ]
-
                 # modify only for the AR so it can properly behave like a transformer
                 for i in range(len(target_text_list)):
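The new loop recovers per-segment logits purely from the target lengths: the first text_len positions are the text, the prompt starts one position later (offset past a separator), and the response is read from the tail. A toy check of that indexing, with invented lengths and an assumed per-sample layout of text, separator, prompt, then response:

import torch

# Invented lengths; the layout [ text | sep | prom | ... | resp ] is an
# assumption that matches the slicing arithmetic above.
text_len, prom_len, resp_len, n_tokens = 4, 6, 5, 32
seq_len = text_len + 1 + prom_len + 1 + resp_len
logit = torch.randn(seq_len, n_tokens)  # one sample's logits

logit_text = logit[:text_len]                       # leading text positions
logit_prom = logit[text_len+1:text_len+1+prom_len]  # offset by 1 for the text/prom separator
logit_resp = logit[-resp_len:]                      # response always sits at the tail

assert logit_text.shape[0] == text_len
assert logit_prom.shape[0] == prom_len
assert logit_resp.shape[0] == resp_len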
@@ -996,25 +1027,37 @@ class Base(nn.Module):
                     # shift the target so that token n...
                     logits_text[i] = logits_text[i][..., :-1, :]
+                    logits_prom[i] = logits_prom[i][..., :-1, :]
                     logits_resp[i] = logits_resp[i][..., :-1, :]
 
                     # predicts token n + 1
                     target_text_list[i] = target_text_list[i][..., 1:]
+                    target_prom_list[i] = target_prom_list[i][..., 1:]
                     target_resp_list[i] = target_resp_list[i][..., 1:]
 
+                self.loss = dict()
+                self.stats = dict(acc = dict())
+
-                target_text = torch.cat( target_text_list ).long()
-                target_resp = torch.cat( target_resp_list ).long()
+                loss_factor_text = self.loss_factor("text")
+                if loss_factor_text > 0.0 and target_text_list:
+                    target_text = torch.cat( target_text_list ).long()
+                    inputs_text = torch.cat( logits_text )
+                    self.loss["text"] = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index )
+                    self.stats["acc"]["text"] = self.accuracy_metric( inputs_text, target_text )
 
-                inputs_text = torch.cat( logits_text )
-                inputs_resp = torch.cat( logits_resp )
+                loss_factor_prom = self.loss_factor("prom")
+                if loss_factor_prom > 0.0 and target_prom_list:
+                    target_prom = torch.cat( target_prom_list ).long()
+                    inputs_prom = torch.cat( logits_prom )
+                    self.loss["prom"] = F.cross_entropy( inputs_prom, target_prom, ignore_index=self.ignore_index )
+                    self.stats["acc"]["prom"] = self.accuracy_metric( inputs_prom, target_prom )
 
-                self.loss = dict(
-                    text = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ),
-                    resp = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ),
-                )
-                for k in self.loss:
-                    self.loss[k] *= self.loss_factor(k)
+                loss_factor_resp = self.loss_factor("resp")
+                if loss_factor_resp > 0.0 and target_resp_list:
+                    target_resp = torch.cat( target_resp_list ).long()
+                    inputs_resp = torch.cat( logits_resp )
+                    self.loss["resp"] = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index )
+                    self.stats["acc"]["resp"] = self.accuracy_metric( inputs_resp, target_resp )
 
                 # to-do: compute loss per individual batch to scale per RVQ level
                 """
@@ -1023,13 +1066,6 @@ class Base(nn.Module):
                     ...
                 """
 
-                self.stats = dict(
-                    acc = dict(
-                        text = self.accuracy_metric( inputs_text, target_text ),
-                        resp = self.accuracy_metric( inputs_resp, target_resp ),
-                    ),
-                )
-
             # include any additional losses (for example: MoE router)
             if aux_loss is not None:
                 self.loss["aux_loss"] = aux_loss