added loss calcs against prom (requires the right settings for not shit results, disabled by default)
This commit is contained in:
parent
05cd8b797e
commit
5af6f41c94
|
@ -213,7 +213,7 @@ class Model:
|
||||||
attention: str = "auto"
|
attention: str = "auto"
|
||||||
audio_embedding_sums: bool = True
|
audio_embedding_sums: bool = True
|
||||||
dropout: float = 0.1 # adjustable dropout value
|
dropout: float = 0.1 # adjustable dropout value
|
||||||
loss_factors: dict = field(default_factory=lambda: {})
|
loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.1, "resp": 1.0 })
|
||||||
|
|
||||||
def get(self, name=None):
|
def get(self, name=None):
|
||||||
return [ self ] if not name or self.name == name else []
|
return [ self ] if not name or self.name == name else []
|
||||||
|
@ -495,11 +495,6 @@ class DeepSpeed:
|
||||||
|
|
||||||
return ds_cfg
|
return ds_cfg
|
||||||
|
|
||||||
@dataclass()
|
|
||||||
class LossFactor:
|
|
||||||
text: float = 1.0
|
|
||||||
resp: float = 1.0
|
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class Trainer:
|
class Trainer:
|
||||||
iterations: int = 100_000
|
iterations: int = 100_000
|
||||||
|
|
|
@ -888,27 +888,38 @@ class Base(nn.Module):
|
||||||
def training_targets_split(
|
def training_targets_split(
|
||||||
self,
|
self,
|
||||||
inputs: list,
|
inputs: list,
|
||||||
|
quant_levels: Tensor | None = None
|
||||||
):
|
):
|
||||||
text_lists = []
|
text_lists = []
|
||||||
|
prom_lists = []
|
||||||
resp_lists = []
|
resp_lists = []
|
||||||
|
|
||||||
for bi in range(len(inputs)):
|
for bi in range(len(inputs)):
|
||||||
text_batch = []
|
text_batch = []
|
||||||
|
prom_batch = []
|
||||||
resp_batch = []
|
resp_batch = []
|
||||||
|
|
||||||
for i in range(len(inputs[bi])):
|
for i in range(len(inputs[bi])):
|
||||||
name, input = inputs[bi][i]
|
name, input = inputs[bi][i]
|
||||||
device = input.device
|
device = input.device
|
||||||
|
|
||||||
if name in ["text", "lang" ]:
|
quant_level = quant_levels[bi] if quant_levels is not None else None
|
||||||
|
|
||||||
|
if name in ["text" ]:
|
||||||
text_batch.append( input )
|
text_batch.append( input )
|
||||||
|
elif name == "prom" and (quant_level is None or quant_level == 0 or not self.config.audio_embedding_sums):
|
||||||
|
prom_batch.append( input[:, quant_level] if quant_level is not None else input )
|
||||||
elif name == "targ":
|
elif name == "targ":
|
||||||
resp_batch.append( input )
|
resp_batch.append( input )
|
||||||
|
|
||||||
text_lists.append( _join( text_batch, torch.tensor(self.ignore_index, device=device) ) )
|
if text_batch:
|
||||||
resp_lists.append( _join( resp_batch, torch.tensor(self.ignore_index, device=device) ) )
|
text_lists.append( _join( text_batch, torch.tensor(self.ignore_index, device=device) ) )
|
||||||
|
if prom_batch:
|
||||||
|
prom_lists.append( _join( prom_batch, torch.tensor(self.ignore_index, device=device) ) )
|
||||||
|
if resp_batch:
|
||||||
|
resp_lists.append( _join( resp_batch, torch.tensor(self.ignore_index, device=device) ) )
|
||||||
|
|
||||||
return text_lists, resp_lists
|
return text_lists, prom_lists, resp_lists
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -982,12 +993,32 @@ class Base(nn.Module):
|
||||||
# precision = self.precision_metric( inputs, target ),
|
# precision = self.precision_metric( inputs, target ),
|
||||||
)
|
)
|
||||||
# split our loss
|
# split our loss
|
||||||
|
# to-do: clean this up
|
||||||
else:
|
else:
|
||||||
target_text_list, target_resp_list = self.training_targets_split( inputs )
|
target_text_list, target_prom_list, target_resp_list = self.training_targets_split( inputs, quant_levels )
|
||||||
|
|
||||||
|
logits_text = []
|
||||||
|
logits_prom = []
|
||||||
|
logits_resp = []
|
||||||
|
|
||||||
|
for i, logit in enumerate(logits):
|
||||||
|
text_len = target_text_list[i].shape[0] if target_text_list else 0
|
||||||
|
prom_len = target_prom_list[i].shape[0] if target_prom_list else 0
|
||||||
|
resp_len = target_resp_list[i].shape[0] if target_resp_list else 0
|
||||||
|
|
||||||
|
if target_text_list:
|
||||||
|
logit_text = logit[:text_len]
|
||||||
|
logits_text.append( logit_text )
|
||||||
|
|
||||||
|
# + 1 to include separator
|
||||||
|
if target_prom_list:
|
||||||
|
logit_prom = logit[text_len+1:text_len+1+prom_len]
|
||||||
|
logits_prom.append( logit_prom )
|
||||||
|
|
||||||
|
if target_resp_list:
|
||||||
|
logit_resp = logit[-resp_len:]
|
||||||
|
logits_resp.append( logit_resp )
|
||||||
|
|
||||||
# grab respective slice of logits
|
|
||||||
logits_text = [ hi[:li.shape[0]] for hi, li in zip(logits, target_text_list) ]
|
|
||||||
logits_resp = [ hi[-li.shape[0]:] for hi, li in zip(logits, target_resp_list) ]
|
|
||||||
|
|
||||||
# modify only for the AR so it can properly behave like a transformer
|
# modify only for the AR so it can properly behave like a transformer
|
||||||
for i in range(len(target_text_list)):
|
for i in range(len(target_text_list)):
|
||||||
|
@ -996,25 +1027,37 @@ class Base(nn.Module):
|
||||||
|
|
||||||
# shift the target so that token n...
|
# shift the target so that token n...
|
||||||
logits_text[i] = logits_text[i][..., :-1, :]
|
logits_text[i] = logits_text[i][..., :-1, :]
|
||||||
|
logits_prom[i] = logits_prom[i][..., :-1, :]
|
||||||
logits_resp[i] = logits_resp[i][..., :-1, :]
|
logits_resp[i] = logits_resp[i][..., :-1, :]
|
||||||
|
|
||||||
# predicts token n + 1
|
# predicts token n + 1
|
||||||
target_text_list[i] = target_text_list[i][..., 1:]
|
target_text_list[i] = target_text_list[i][..., 1:]
|
||||||
|
target_prom_list[i] = target_prom_list[i][..., 1:]
|
||||||
target_resp_list[i] = target_resp_list[i][..., 1:]
|
target_resp_list[i] = target_resp_list[i][..., 1:]
|
||||||
|
|
||||||
|
self.loss = dict()
|
||||||
|
self.stats = dict(acc = dict())
|
||||||
|
|
||||||
target_text = torch.cat( target_text_list ).long()
|
loss_factor_text = self.loss_factor("text")
|
||||||
target_resp = torch.cat( target_resp_list ).long()
|
if loss_factor_text > 0.0 and target_text_list:
|
||||||
|
target_text = torch.cat( target_text_list ).long()
|
||||||
|
inputs_text = torch.cat( logits_text )
|
||||||
|
self.loss["text"] = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index )
|
||||||
|
self.stats["acc"]["text"] = self.accuracy_metric( inputs_text, target_text )
|
||||||
|
|
||||||
inputs_text = torch.cat( logits_text )
|
loss_factor_prom = self.loss_factor("prom")
|
||||||
inputs_resp = torch.cat( logits_resp )
|
if loss_factor_prom > 0.0 and target_prom_list:
|
||||||
|
target_prom = torch.cat( target_prom_list ).long()
|
||||||
|
inputs_prom = torch.cat( logits_prom )
|
||||||
|
self.loss["prom"] = F.cross_entropy( inputs_prom, target_prom, ignore_index=self.ignore_index )
|
||||||
|
self.stats["acc"]["prom"] = self.accuracy_metric( inputs_prom, target_prom )
|
||||||
|
|
||||||
self.loss = dict(
|
loss_factor_resp = self.loss_factor("resp")
|
||||||
text = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ),
|
if loss_factor_resp > 0.0 and target_resp_list:
|
||||||
resp = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ),
|
target_resp = torch.cat( target_resp_list ).long()
|
||||||
)
|
inputs_resp = torch.cat( logits_resp )
|
||||||
|
self.loss["resp"] = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index )
|
||||||
for k in self.loss:
|
self.stats["acc"]["resp"] = self.accuracy_metric( inputs_resp, target_resp )
|
||||||
self.loss[k] *= self.loss_factor(k)
|
|
||||||
|
|
||||||
# to-do: compute loss per individual batch to scale per RVQ level
|
# to-do: compute loss per individual batch to scale per RVQ level
|
||||||
"""
|
"""
|
||||||
|
@ -1023,13 +1066,6 @@ class Base(nn.Module):
|
||||||
...
|
...
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.stats = dict(
|
|
||||||
acc = dict(
|
|
||||||
text = self.accuracy_metric( inputs_text, target_text ),
|
|
||||||
resp = self.accuracy_metric( inputs_resp, target_resp ),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# include any additional losses (for example: MoE router)
|
# include any additional losses (for example: MoE router)
|
||||||
if aux_loss is not None:
|
if aux_loss is not None:
|
||||||
self.loss["aux_loss"] = aux_loss
|
self.loss["aux_loss"] = aux_loss
|
||||||
|
|
Loading…
Reference in New Issue
Block a user