ugh (ROCm seems to silently clamp any token value >= logits.shape[-1] during the loss calculation, while CUDA throws an assert, which made this dumb mistake hard to find)
This commit is contained in:
parent
943fe70c10
commit
ad7e290a5e
|
@ -1050,16 +1050,16 @@ class Base(nn.Module):
|
|||
inputs[i].append( ( "tone", tone_list[i] ) )
|
||||
# insert timestep token
|
||||
if "len" in self.capabilities and quant_level == 0:
|
||||
# cosine schedule
|
||||
dropout_mask = _dropout_mask( resps_list[i], p=math.cos(timestep * math.pi * 0.5) )
|
||||
|
||||
# store timestep information
|
||||
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
|
||||
# store dropout mask
|
||||
inputs[i].append( ("dropout_mask", dropout_mask ) )
|
||||
# insert the current output response
|
||||
if resps_list is not None and resps_list[i] is not None:
|
||||
inputs[i].append( ( "resp", resps_list[i] ) )
|
||||
|
||||
# store dropout mask
|
||||
if "len" in self.capabilities and quant_level == 0:
|
||||
dropout_mask = _dropout_mask( resps_list[i], p=math.cos(timestep * math.pi * 0.5) )
|
||||
inputs[i].append( ("dropout_mask", dropout_mask ) )
|
||||
|
||||
# Audio length prediction task
|
||||
# Sequence: <text><sep><rvq lvl><prom><sep><len>
|
||||
|
@ -1269,7 +1269,7 @@ class Base(nn.Module):
|
|||
at=None,
|
||||
):
|
||||
for batch_index, batch_input in enumerate(inputs):
|
||||
if at is not None and batch_index != batch_index:
|
||||
if at is not None and batch_index != at:
|
||||
continue
|
||||
|
||||
for n, input in batch_input:
|
||||
|
@ -1420,6 +1420,16 @@ class Base(nn.Module):
|
|||
logits[i] = logits[i][..., :-l, :] # shift the target so that token n...
|
||||
target_list[i] = target_list[i][..., l:] # predicts token n + 1
|
||||
|
||||
for batch_index, target in enumerate( target_list ):
|
||||
logit = logits[batch_index]
|
||||
|
||||
max_classes = logit.shape[-1]
|
||||
max_token = torch.max( target ).item()
|
||||
|
||||
if max_token > max_classes:
|
||||
task = self.get_input(inputs, "task", at=batch_index)
|
||||
print( batch_index, task, target, max_token, max_classes, inputs[batch_index] )
|
||||
|
||||
# see comments for the split-loss calc cross_entropy call
|
||||
if False:
|
||||
target = torch.cat( target_list )
|
||||
|
|
Loading…
Reference in New Issue
Block a user