ugh (ROCm seems to silently clamp any token value >= logits.shape[-1] for loss calculation, while cuda will throw an assert, making it hard to find this dumb fuckup)

This commit is contained in:
mrq 2024-11-09 19:40:02 -06:00
parent 943fe70c10
commit ad7e290a5e

View File

@ -1050,16 +1050,16 @@ class Base(nn.Module):
inputs[i].append( ( "tone", tone_list[i] ) )
# insert timestep token
if "len" in self.capabilities and quant_level == 0:
# cosine schedule
dropout_mask = _dropout_mask( resps_list[i], p=math.cos(timestep * math.pi * 0.5) )
# store timestep information
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
# store dropout mask
inputs[i].append( ("dropout_mask", dropout_mask ) )
# insert the current output response
if resps_list is not None and resps_list[i] is not None:
inputs[i].append( ( "resp", resps_list[i] ) )
# store dropout mask
if "len" in self.capabilities and quant_level == 0:
dropout_mask = _dropout_mask( resps_list[i], p=math.cos(timestep * math.pi * 0.5) )
inputs[i].append( ("dropout_mask", dropout_mask ) )
# Audio length prediction task
# Sequence: <text><sep><rvq lvl><prom><sep><len>
@ -1269,7 +1269,7 @@ class Base(nn.Module):
at=None,
):
for batch_index, batch_input in enumerate(inputs):
if at is not None and batch_index != batch_index:
if at is not None and batch_index != at:
continue
for n, input in batch_input:
@ -1420,6 +1420,16 @@ class Base(nn.Module):
logits[i] = logits[i][..., :-l, :] # shift the target so that token n...
target_list[i] = target_list[i][..., l:] # predicts token n + 1
for batch_index, target in enumerate( target_list ):
logit = logits[batch_index]
max_classes = logit.shape[-1]
max_token = torch.max( target ).item()
if max_token > max_classes:
task = self.get_input(inputs, "task", at=batch_index)
print( batch_index, task, target, max_token, max_classes, inputs[batch_index] )
# see comments for the split-loss calc cross_entropy call
if False:
target = torch.cat( target_list )