ugh
parent 0d4329d2e3
commit 5f289db275
@@ -14,7 +14,7 @@
 #include <unordered_map>
 #include <iostream>
 
-#define LLAMA_CPP_EXTENDED 1 // whether the underlying llama.cpp has some extra functions
+#define LLAMA_CPP_EXTENDED 0 // whether the underlying llama.cpp has some extra functions
 #define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch
 
 #if !LLAMA_CPP_EXTENDED
File diff suppressed because one or more lines are too long
@@ -105,7 +105,8 @@ def convert_to_hf_llama( state_dict, config = None, save_path = None ):
 	classifier = torch.nn.Linear( model_dim, n_tokens, bias=classifier_bias )
 
 	if not split_classifiers:
-		classifier.weight[:] = state_dict['module']['classifier.weight'][:]
+		src = state_dict['module']['classifier.weight'][:]
+		classifier.weight[:src.shape[0], ] = src
 
 		# update ranges
 		start = 0
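A note on the export change above: the checkpoint's classifier weight can have fewer rows than the freshly built torch.nn.Linear (for example when n_tokens is bumped), so only the first src.shape[0] rows are overwritten and any extra rows keep their initialization. A minimal sketch of the idea, with hypothetical sizes:

import torch

# hypothetical sizes: the checkpoint provides 17701 rows, the exported
# classifier is built with n_tokens = 17702 rows
model_dim, n_tokens = 1024, 17702
classifier = torch.nn.Linear( model_dim, n_tokens, bias=False )
src = torch.randn( 17701, model_dim )  # stand-in for state_dict['module']['classifier.weight']

with torch.no_grad():
	# copy only the rows the checkpoint provides; rows beyond src.shape[0] are untouched
	classifier.weight[:src.shape[0]] = src

assert classifier.weight.shape == (n_tokens, model_dim)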
@@ -523,7 +523,7 @@ class Base(nn.Module):
 			classifier_l_tokens += [ 11 ]
 			classifier_l_names += ["len"]
 
-		n_vocab = 17701 if not split_classifiers else n_resp_tokens + 1
+		n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1
 
 		self.n_vocab = n_vocab
 		self.unified_position_ids = unified_position_ids
@@ -1436,6 +1436,9 @@ class Base(nn.Module):
 				continue
 
 			# offset to flattened vocab ranges
+			if self.classifier is not None:
+				compute_acc = False
+			"""
 			if self.classifier is not None:
 				offsets = _get_offsets()
 
@@ -1454,6 +1457,7 @@
 					if t == self.ignore_index:
 						continue
 					token[i] += start
+			"""
 
 			if token.is_floating_point():
 				ignored = True
@@ -1511,7 +1515,7 @@
 
 			# perofrm loss calculation on the entire sequence
 			if not self.config.loss_factors:
-				target = _join( target, torch.tensor(self.ignore_index if self.classifier is None else 17685, device=target[-1].device) )
+				target = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
 				logit = logits[batch_index]
 
 				# shift if causal
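The replaced line drops the hard-coded 17685 sentinel: joined positions now always carry self.ignore_index, which cross-entropy skips regardless of whether a separate classifier is attached. A small self-contained sketch of that behavior (the sentinel value and shapes here are hypothetical):

import torch
import torch.nn.functional as F

ignore_index = -100  # hypothetical sentinel standing in for self.ignore_index
logits = torch.randn( 5, 8 )  # 5 positions over a toy 8-token vocab
target = torch.tensor( [ 1, 3, ignore_index, 2, ignore_index ] )

# positions holding the sentinel contribute nothing to the loss
loss = F.cross_entropy( logits, target, ignore_index=ignore_index )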
@@ -1679,31 +1683,30 @@
 			for i, state in enumerate( hidden_states ):
 				hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
 
+		# de-offset if needed
+		if self.classifier is not None:
+			offsets = _get_offsets()
+			for batch_index, classifier_level in enumerate( classifier_levels ):
+				if classifier_level == "stt":
+					k = "text"
+				elif classifier_level == "len":
+					k = "len"
+				else:
+					k = f'resps|{classifier_level}'
+
+				if k not in offsets:
+					continue
+
+				start, end = offsets[k]
+
+				logits[batch_index] = logits[batch_index][:, start:end]
 		if not training:
 			loss = None
 			stats = None
 
 			self.loss = None
 			self.stats = None
 
-			# de-offset if needed
-			if self.classifier is not None:
-				offsets = _get_offsets()
-				for batch_index, classifier_level in enumerate( classifier_levels ):
-					if classifier_level == "stt":
-						k = "text"
-					elif classifier_level == "len":
-						k = "len"
-					else:
-						k = f'resps|{classifier_level}'
-
-					if k not in offsets:
-						continue
-
-					start, end = offsets[k]
-
-					logits[batch_index] = logits[batch_index][:, start:start+end]
-
 		# compute loss if the target is given
 		else:
 			loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
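Two things change in the hunk above: the de-offset pass moves out of the "if not training:" branch so logits are sliced back to the per-level vocabulary range whether or not a loss is computed, and the slice goes from start:start+end to start:end. The latter only reads as a fix if _get_offsets() yields absolute (start, end) positions rather than (start, length); a toy sketch under that assumption, with hypothetical offset values:

import torch

# hypothetical flattened-vocab offsets: name -> (start, end), absolute positions
offsets = { "text": (0, 256), "len": (17690, 17702) }

logits = [ torch.randn( 4, 17702 ) ]  # one batch entry over the flattened vocab
classifier_levels = [ "len" ]

for batch_index, classifier_level in enumerate( classifier_levels ):
	k = "len" if classifier_level == "len" else "text"  # simplified mapping
	if k not in offsets:
		continue
	start, end = offsets[k]
	# start:end keeps exactly the range's rows; start:start+end would reach past it
	logits[batch_index] = logits[batch_index][:, start:end]

assert logits[0].shape[-1] == 17702 - 17690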