This commit is contained in:
mrq 2024-12-22 15:19:41 -06:00
parent 0d4329d2e3
commit ddabcb65f5
3 changed files with 25 additions and 20 deletions

View File

@ -14,7 +14,7 @@
#include <unordered_map>
#include <iostream>
#define LLAMA_CPP_EXTENDED 1 // whether the underlying llama.cpp has some extra functions
#define LLAMA_CPP_EXTENDED 0 // whether the underlying llama.cpp has some extra functions
#define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch
#if !LLAMA_CPP_EXTENDED

File diff suppressed because one or more lines are too long

View File

@ -1436,6 +1436,9 @@ class Base(nn.Module):
continue
# offset to flattened vocab ranges
if self.classifier is not None:
compute_acc = False
"""
if self.classifier is not None:
offsets = _get_offsets()
@ -1454,6 +1457,7 @@ class Base(nn.Module):
if t == self.ignore_index:
continue
token[i] += start
"""
if token.is_floating_point():
ignored = True
@ -1679,31 +1683,30 @@ class Base(nn.Module):
for i, state in enumerate( hidden_states ):
hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
# de-offset if needed
if self.classifier is not None:
offsets = _get_offsets()
for batch_index, classifier_level in enumerate( classifier_levels ):
if classifier_level == "stt":
k = "text"
elif classifier_level == "len":
k = "len"
else:
k = f'resps|{classifier_level}'
if k not in offsets:
continue
start, end = offsets[k]
logits[batch_index] = logits[batch_index][:, start:start+end]
if not training:
loss = None
stats = None
self.loss = None
self.stats = None
# de-offset if needed
if self.classifier is not None:
offsets = _get_offsets()
for batch_index, classifier_level in enumerate( classifier_levels ):
if classifier_level == "stt":
k = "text"
elif classifier_level == "len":
k = "len"
else:
k = f'resps|{classifier_level}'
if k not in offsets:
continue
start, end = offsets[k]
logits[batch_index] = logits[batch_index][:, start:start+end]
# compute loss if the target is given
else:
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )