This commit is contained in:
mrq 2024-12-22 15:19:41 -06:00
parent 0d4329d2e3
commit ddabcb65f5
3 changed files with 25 additions and 20 deletions

View File

@ -14,7 +14,7 @@
#include <unordered_map> #include <unordered_map>
#include <iostream> #include <iostream>
#define LLAMA_CPP_EXTENDED 1 // whether the underlying llama.cpp has some extra functions #define LLAMA_CPP_EXTENDED 0 // whether the underlying llama.cpp has some extra functions
#define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch #define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch
#if !LLAMA_CPP_EXTENDED #if !LLAMA_CPP_EXTENDED

File diff suppressed because one or more lines are too long

View File

@ -1436,6 +1436,9 @@ class Base(nn.Module):
continue continue
# offset to flattened vocab ranges # offset to flattened vocab ranges
if self.classifier is not None:
compute_acc = False
"""
if self.classifier is not None: if self.classifier is not None:
offsets = _get_offsets() offsets = _get_offsets()
@ -1454,6 +1457,7 @@ class Base(nn.Module):
if t == self.ignore_index: if t == self.ignore_index:
continue continue
token[i] += start token[i] += start
"""
if token.is_floating_point(): if token.is_floating_point():
ignored = True ignored = True
@ -1679,13 +1683,6 @@ class Base(nn.Module):
for i, state in enumerate( hidden_states ): for i, state in enumerate( hidden_states ):
hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ] hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
if not training:
loss = None
stats = None
self.loss = None
self.stats = None
# de-offset if needed # de-offset if needed
if self.classifier is not None: if self.classifier is not None:
offsets = _get_offsets() offsets = _get_offsets()
@ -1704,6 +1701,12 @@ class Base(nn.Module):
logits[batch_index] = logits[batch_index][:, start:start+end] logits[batch_index] = logits[batch_index][:, start:start+end]
if not training:
loss = None
stats = None
self.loss = None
self.stats = None
# compute loss if the target is given # compute loss if the target is given
else: else:
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels ) loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )