experimental implementation of STT (still needs to actually be tested on a model; the test trainer seems to work)
This commit is contained in:
parent d319d33368
commit 54547b74d8
@@ -960,6 +960,19 @@ class NaiveTokenizer:
         # tokenize
         return [*map(symmap.get, phones)]
 
+    def decode( self, t ):
+        s = ""
+        symmap = self.get_vocab()
+        reverse_symmap = {}
+        for k, v in symmap.items():
+            reverse_symmap[v] = k
+
+        for i, token in enumerate( t ):
+            s += reverse_symmap[token]
+
+        return s
+
+
 _logger = logging.getLogger(__name__)
 
 cfg = Config.from_cli()
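For orientation, the decode() added here just inverts the symbol map built by get_vocab(). A minimal, self-contained sketch of the same round-trip with a hypothetical vocabulary (not the project's actual symmap):

import string

# hypothetical symbol map; the real one comes from get_vocab()
symmap = {c: i + 1 for i, c in enumerate(string.ascii_lowercase + " ")}
reverse_symmap = {v: k for k, v in symmap.items()}

def encode(phones):
    return [symmap[p] for p in phones]

def decode(tokens):
    return "".join(reverse_symmap[t] for t in tokens)

assert decode(encode(list("stt test"))) == "stt test"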
@@ -428,6 +428,7 @@ def get_task_symmap():
         "<soe>": 5,
         "<mask>": 6,
         "<eoe>": 7,
+        "<stt>": 8,
 
         "<nse>": 6, # fake
         "<cse>": 6, # fake
@@ -1052,6 +1053,12 @@ class Dataset(_Dataset):
                 task,
             ]
 
+        # Base TTS (<resp> => <text>)
+        elif task == "stt":
+            # easier to just keep it instead of wrangling around trying to remove it
+            # it might also help to provide a guidance prompt but who knows right now
+            proms = self.sample_prompts(spkr_name, ignore=path)
+
         # noise suppression (<text>? <resp+noise> => <resp>)
         # speech removal (<text>?<resp+noise> => <noise>)
         elif task == "ns" or task == "sr":
@@ -200,7 +200,7 @@ class AR(Base):
 
             r = super().sample(
                 logits=logits,
-                resps_list=resps_list,
+                prev_list=resps_list,
 
                 temperature=sampling_temperature,
                 min_temperature=sampling_min_temperature,
@@ -61,15 +61,23 @@ class AR_NAR(Base):
 
         disable_tqdm=False,
     ):
-        device = text_list[0].device
-        batch_size = len(text_list)
 
+        text_task = [ "stt" ]
+
+        if text_list is not None:
+            default_task = "tts"
+            device = text_list[0].device
+            batch_size = len(text_list)
+        else:
+            default_task = "stt"
+            device = resps_list[0].device
+            batch_size = len(resps_list)
 
         # generate task list if not provided
         if task_list is None:
-            task_list = [ "tts" for _ in range(batch_size) ]
+            task_list = [ default_task for _ in range(batch_size) ]
 
         # is training or NAR
-        if resps_list is not None:
+        if resps_list is not None and text_list is not None:
             n_levels_set = {r.shape[-1] for r in resps_list}
             n_levels = next(iter(n_levels_set))
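The block above decides the default task from which modality was supplied: text present means TTS, audio responses only means STT. A rough sketch of that dispatch in isolation (the helper name and return shape are illustrative, not the actual API):

def infer_default_task(text_list, resps_list):
    # text given -> assume TTS; only audio responses given -> assume STT
    if text_list is not None:
        return "tts", text_list[0].device, len(text_list)
    return "stt", resps_list[0].device, len(resps_list)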
@@ -102,12 +110,18 @@ class AR_NAR(Base):
 
             # input RVQ levels
             quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
+            for i, task in enumerate( task_list ):
+                if task in text_task:
+                    quant_levels[i] = 0 # self.n_resp_levels - 1
+
             # trim resps to only contain all levels below the target level
-            resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
+            resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
+
             # tensor to cat for RVQ level 0
-            stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
+            text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16)
+            audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
             # I hate python's value/reference semantics so much
-            for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
+            for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
                 # cap quant_level if it exceeds its corresponding resp/prom
                 if quant_level >= resps.shape[-1]:
                     quant_levels[i] = resps.shape[-1] - 1
@@ -139,7 +153,11 @@ class AR_NAR(Base):
                 # only apply stop token for RVQ level 0
                 if quant_level <= 0:
                     # append stop tokens for AR
-                    resps_list[i] = torch.cat([ resps, stop_sequence ])
+                    if task in text_task:
+                        #text_list[i] = torch.cat([ resps, text_stop_sequence ])
+                        ...
+                    else:
+                        resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
 
 
             inputs = self.inputs(
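So at RVQ level 0 a stop token is still appended for the AR, but only to audio responses; text targets are left untouched here (their end-of-text token is assumed to come from the tokenizer itself). A hedged sketch of that branch on its own:

import torch

def append_stop(resp, task, audio_stop_token, text_tasks=("stt",)):
    # assumption: text targets already carry their own EOS, so only
    # audio responses get an explicit stop row concatenated on
    if task in text_tasks:
        return resp
    stop = torch.full((1, resp.shape[-1]), audio_stop_token, dtype=resp.dtype)
    return torch.cat([resp, stop])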
@@ -195,7 +213,7 @@ class AR_NAR(Base):
 
             resps_list = super().sample(
                 logits=logits,
-                resps_list=prev_list,
+                prev_list=prev_list,
                 quant_levels=quant_levels,
 
                 temperature=sampling_temperature,
@@ -220,11 +238,11 @@ class AR_NAR(Base):
         if cfg.lora is not None:
             enable_lora( self, cfg.lora.active_level( 0 ) )
 
+        # STT
         sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
         stopped = torch.zeros(batch_size, device=device).bool()
 
-        stop_token = self.stop_token
-
+        stop_token = self.stop_token if task_list[0] != "stt" else 2 # to-do: derive from tokenizer
 
         state = None
         mirostat = [
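The hard-coded 2 above is flagged with a to-do; a hedged sketch of deriving the text stop token from the tokenizer instead (the attribute names are assumptions, not verified against the project's tokenizer):

def get_text_stop_token(tokenizer, fallback=2):
    # look for a conventional EOS id; otherwise fall back to the
    # hard-coded value used in the commit
    for attr in ("eos_token_id", "eos_id"):
        token = getattr(tokenizer, attr, None)
        if token is not None:
            return token
    return fallback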
@@ -233,9 +251,17 @@ class AR_NAR(Base):
 
         scores = [ 1.0 ] * sampling_beam_width
 
+        # add <bos> to text for STT
+        for i, sequence in enumerate( sequence_list ):
+            if task_list[i] in text_task:
+                sequence_list[i] = torch.cat([sequence_list[i], torch.tensor([1], dtype=torch.int16, device=device)])
+
         # get next in sequence
         for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
-            resps_list = [x.unsqueeze(dim=-1) for x in sequence_list]
+            if task_list[0] in text_task:
+                text_list = [x for x in sequence_list]
+            else:
+                resps_list = [x.unsqueeze(dim=-1) for x in sequence_list]
 
             inputs = self.inputs(
                 text_list=text_list,
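In other words, during AR decoding the growing sequence is fed back as text_list for STT and as a single-level resps_list for everything else. A stripped-down sketch of that loop shape, with a stand-in step function instead of the model forward and sampler:

import torch

def greedy_loop(step_fn, batch_size, task_list, stop_token, max_steps, device="cpu"):
    # step_fn(text_list, resps_list) -> one next-token id per batch item;
    # purely illustrative, it stands in for inputs()/forward()/sample()
    sequences = [torch.zeros(0, dtype=torch.int16, device=device) for _ in range(batch_size)]
    for _ in range(max_steps):
        if task_list[0] == "stt":
            text_list, resps_list = [s for s in sequences], None
        else:
            text_list, resps_list = None, [s.unsqueeze(-1) for s in sequences]
        next_tokens = step_fn(text_list, resps_list)
        sequences = [torch.cat([s, torch.tensor([t], dtype=torch.int16, device=device)]) for s, t in zip(sequences, next_tokens)]
        if all(len(s) and int(s[-1]) == stop_token for s in sequences):
            break
    return sequences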
@@ -261,7 +287,7 @@ class AR_NAR(Base):
 
             r = super().sample(
                 logits=logits,
-                resps_list=resps_list,
+                prev_list=resps_list,
 
                 temperature=sampling_temperature,
                 min_temperature=sampling_min_temperature,
@@ -398,10 +424,10 @@ def example_usage():
     """
 
     bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
-    tasks = cfg.dataset.tasks_list
+    available_tasks = ["tts", "stt"] # cfg.dataset.tasks_list
 
     model = AR_NAR(**kwargs).to(device)
-    steps = 150 * len(tasks) # * cfg.model.experimental.causal_size
+    steps = 150 * len(available_tasks) # * cfg.model.experimental.causal_size
 
     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -486,14 +512,14 @@ def example_usage():
     _logger.info(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
     @torch.no_grad()
-    def sample_data(task=None):
+    def sample_data(t=None):
         texts = []
         proms = []
         resps = []
+        tasks = []
 
         for i in range(batch_size):
-            if task is None:
-                task = random.choice(tasks)
+            task = random.choice(available_tasks) if t is None else t
 
             text = text_list[i]
             prom = proms_list[i]
@@ -502,6 +528,8 @@ def example_usage():
             # do nothing
             if task == "tts":
                 ...
+            elif task == "stt":
+                ...
             elif task == "tts-c":
                 trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
 
@@ -523,25 +551,35 @@ def example_usage():
             texts.append( text.to(device) )
             proms.append( prom.to(device) )
             resps.append( resp.to(device) )
+            tasks.append( task )
 
-        return texts, proms, resps
+        return texts, proms, resps, tasks
 
     @torch.inference_mode()
-    def sample( name, steps=1000, task=None ):
+    def sample( name, steps=500, task=None ):
         engine.eval()
 
-        texts, proms, resps = sample_data( task )
+        texts, proms, resps, tasks = sample_data( task )
 
-        if "ar" in cfg.model.capabilities:
-            resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
+        if tasks[0] == "stt":
+            text = engine( None, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
+            """
+            # to-do: STT for NAR
+            text = engine( text, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
+            """
+            text = [ cfg.tokenizer.decode( t ) for t in text ]
+
+            print( text )
         else:
-            resps = [ resp[:, 0] for resp in resps ]
+            if "ar" in cfg.model.capabilities:
+                resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
+            else:
+                resps = [ resp[:, 0] for resp in resps ]
 
-        if "nar" in cfg.model.capabilities:
-            resps = engine( texts, proms, resps, sampling_temperature=0.2 )
-
-        for i, o in enumerate(resps):
-            _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
+            if "nar" in cfg.model.capabilities:
+                resps = engine( texts, proms, resps, sampling_temperature=0.2 )
+            for i, o in enumerate(resps):
+                _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
 
         unload_model()
 
@@ -549,10 +587,10 @@ def example_usage():
     engine.train()
     t = trange(steps)
     for i in t:
-        texts, proms, resps = sample_data()
+        texts, proms, resps, tasks = sample_data()
 
         stats = {"step": i}
-        stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps)
+        stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps, task_list=tasks)
         stats |= {"grad_norm": engine.get_global_grad_norm()}
 
         tqdm.write(f"{stats}")
@@ -571,7 +609,7 @@ def example_usage():
     model = ml.compile_model(model, backend=cfg.optimizations.compile)
     """
 
-    for task in tasks:
+    for task in available_tasks:
         sample("final", task=task)
 
     engines.quit()
@@ -259,7 +259,8 @@ class AudioEmbedding(nn.Module):
         return x
 
 # per-level classification
-class AudioClassifier(nn.Module):
+# it might actually be "better" in the long run to only have one output head like a traditional LM, and just de-stitch it here instead of doing modulus math and whatever like the HF/experimental impl
+class Classifiers(nn.Module):
     def __init__(
         self,
         l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
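The hunk only shows the rename to Classifiers; as a rough sketch of the shape such a per-level head module takes (one linear head per token space, selected by level; the repo's actual implementation may pad or stitch outputs differently):

import torch
from torch import nn

class LevelHeads(nn.Module):
    # one projection per output space; an extra head appended at the end
    # (reachable via index -1) can serve the text vocabulary
    def __init__(self, l_tokens, d_model):
        super().__init__()
        self.heads = nn.ModuleList([nn.Linear(d_model, n) for n in l_tokens])

    def forward(self, xs, levels):
        # xs: list of (seq_len, d_model) tensors, levels: one head index per item
        return [self.heads[l](x) for x, l in zip(xs, levels)]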
@@ -783,10 +784,10 @@ class Base(nn.Module):
             self.metrics = None
         else:
             self.classifier = None
-            self.classifiers = AudioClassifier( l_tokens, d_model )
+            self.classifiers = Classifiers( l_tokens + [ n_text_tokens ], d_model )
             self.accuracy_metric = None
             self.precision_metric = None
-            self.metrics = Metrics( l_tokens )
+            self.metrics = Metrics( l_tokens + [ n_text_tokens ] )
 
         """
         if tie_classifier_to_embedding:
@@ -907,6 +908,8 @@ class Base(nn.Module):
         device = text_list[0].device
         batch_size = len(text_list)
 
+        special_tasks = ["stt", "len"]
+
         inputs = [ [] for _ in range(batch_size) ]
         for i in range(batch_size):
             quant_level = quant_levels[i] if quant_levels is not None else 0
@@ -921,7 +924,7 @@ class Base(nn.Module):
             # Base-line TTS task
             # Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
             # prom /may/ include <task> tokens inside to help guide things, per SpeechX
-            if f'<{task_type}>' in get_task_symmap():
+            if f'<{task_type}>' in get_task_symmap() and task_type not in special_tasks:
                 # insert the text prompt
                 if text_list is not None and text_list[i] is not None:
                     inputs[i].append( ( "text", text_list[i] ) )
@@ -973,6 +976,21 @@ class Base(nn.Module):
                 elif resps_list is not None and resps_list[i] is not None:
                     # yes this could be encoded better
                     inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) )
+            # Speech-to-Text prediction task
+            # Sequence: <resp><sep><rvq lvl><sep><text>
+            elif task_type == "stt":
+                # insert the input response
+                if resps_list is not None and resps_list[i] is not None:
+                    inputs[i].append( ( "resp", resps_list[i] ) )
+                # insert lang token if we're trained for it
+                if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
+                    inputs[i].append( ( "lang", lang_list[i] ) )
+                # insert RVQ level guidance token if the model is versioned for it
+                if self.rvq_l_emb is not None and not self.interleave:
+                    inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
+                # insert the output text prompt
+                if text_list is not None and text_list[i] is not None:
+                    inputs[i].append( ( "text", text_list[i] ) )
             else:
                 raise Exception(f'Unrecognized task: {task_type}')
 
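Laid out, an STT batch item mirrors the TTS layout with the roles of audio and text swapped: <resp><sep><rvq lvl><sep><text>, plus an optional lang token. A hedged sketch of assembling that per-sample input list (the tuple layout follows the pattern above; the helper itself is illustrative):

import torch

def build_stt_inputs(resp, text, quant_level=0, lang=None):
    # response audio first, then optional lang, then the RVQ level token,
    # then the target text the model is asked to predict
    inputs = [("task", "stt"), ("resp", resp)]
    if lang is not None:
        inputs.append(("lang", lang))
    inputs.append(("quant_level", torch.tensor([quant_level], dtype=torch.int16)))
    inputs.append(("text", text))
    return inputs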
@@ -1010,6 +1028,8 @@ class Base(nn.Module):
         if not token_dropout_rvq_levels:
             token_dropout_rvq_levels = [1, self.resp_levels]
 
+        summed_embeddings_task = [ "stt" ]
+
         x_list = []
         for batch_index, batch_input in enumerate(inputs):
             batch = []
@@ -1071,7 +1091,13 @@ class Base(nn.Module):
                         offset = 0,
                         quant_level = 0,
                     )
-
+                # cheat-y way to handle performing STT across all levels
+                elif task_type in summed_embeddings_task:
+                    embedding = sum([ self.resps_emb(
+                        input[:, :l+1],
+                        offset = 0 if l == 0 else 1, # or maybe set to 1
+                        quant_level = l
+                    ) for l in range( input.shape[-1] - 1 ) ])
                 else:
                     # get RVQ level 0, or up to targetted RVQ level inference
                     if self.version <= 4:
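The "cheat-y" path embeds every RVQ level of the input response and sums the results, so the STT input sees the whole codebook stack instead of a single level. A simplified sketch of that sum with plain per-level embedding tables (shapes only, not the repo's AudioEmbedding):

import torch
from torch import nn

def summed_resp_embedding(resp, level_embs):
    # resp: (seq_len, n_levels) codes; level_embs: one nn.Embedding per RVQ level.
    # Each level contributes its own embedding to a single vector per frame.
    return sum(level_embs[l](resp[:, l]) for l in range(resp.shape[-1]))

embs = nn.ModuleList([nn.Embedding(1024, 16) for _ in range(4)])
codes = torch.randint(0, 1024, (10, 4))
assert summed_resp_embedding(codes, embs).shape == (10, 16)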
@@ -1171,7 +1197,9 @@ class Base(nn.Module):
         quant_levels: int | list[int] | Tensor | None = None,
     ):
         device = logits[0].device
-        classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels ) ]
+        special_tasks = [ "len", "stt" ]
+        summed_embeddings_task = [ "stt" ]
+        classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
 
         # handles tasks where the prompt has task tokens injected in the middle
         def prompt_input_to_token( input, quant_level ):
@@ -1192,8 +1220,10 @@ class Base(nn.Module):
         for batch_index, batch in enumerate(inputs):
             quant_level = quant_levels[batch_index]
             target = []
+            task_type = "tts"
             for name, input in batch:
                 if name == "task":
+                    task_type = input
                     task_list.append( input )
                 elif name == "prom":
                     proms = [ input ] if isinstance(input, torch.Tensor) else input
@@ -1201,6 +1231,8 @@ class Base(nn.Module):
                 elif name == "resp":
                     if self.interleave:
                         target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
+                    elif task_type in summed_embeddings_task:
+                        target.append( torch.full_like(input[..., 0], self.ignore_index) )
                     else:
                         target.append( input if input.dim() == 1 else input[:, quant_level] )
                 elif name in ["text", "quant_level", "lang", "tone", "len"]:
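Filling the resp span of the target with ignore_index means cross-entropy simply skips those positions for STT, so only the text side of the sequence is scored. A small sketch of the effect, using the conventional -100 for the ignore value (the project's own ignore_index is not shown in this hunk):

import torch
import torch.nn.functional as F

ignore_index = -100
logits = torch.randn(6, 32)      # 6 positions over a 32-token vocabulary
target = torch.randint(0, 32, (6,))
target[:4] = ignore_index        # pretend the first 4 positions are the resp span

# the masked positions contribute nothing to the loss
loss = F.cross_entropy(logits, target, ignore_index=ignore_index)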
@@ -1273,7 +1305,12 @@ class Base(nn.Module):
             for name, input in batch:
                 # do not use resp
                 if name == "resp":
-                    input = input if input.dim() == 1 else input[:, quant_level]
+                    if self.interleave:
+                        input = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
+                    elif task_type in summed_embeddings_task:
+                        input = torch.full_like(input[..., 0], self.ignore_index)
+                    else:
+                        input = input if input.dim() == 1 else input[:, quant_level]
                 # select prom level
                 elif name == "prom":
                     proms = [ input ] if isinstance(input, torch.Tensor) else input
@@ -1383,7 +1420,8 @@ class Base(nn.Module):
         )
 
         if self.classifiers is not None:
-            classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels ) ]
+            special_tasks = [ "len", "stt" ]
+            classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
             x = self.classifiers(x, levels = classifier_quant_levels) * m
 
         # Remove padding
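Because the text head was appended last when the classifiers were built (l_tokens + [ n_text_tokens ]), routing the special tasks to level -1 selects that final head. A one-line sketch of the selection rule, assuming that head ordering:

def classifier_level(task_type, quant_level, special_tasks=("len", "stt")):
    # -1 indexes the last head in the list, i.e. the one appended after the audio levels
    return -1 if task_type in special_tasks else quant_level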
@@ -1402,7 +1440,7 @@ class Base(nn.Module):
     def sample(
         self,
         logits: list[Tensor], # logit scores
-        resps_list: list[Tensor], # previous tokens
+        prev_list: list[Tensor], # previous tokens
         quant_levels: int | list[int] | Tensor | None = None,
         # base sampling parameters
         temperature: float = 1.0,
@@ -1429,7 +1467,7 @@ class Base(nn.Module):
         # (NAR) return the entire generated response
         # Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
         if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
-            logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
+            logits = [ logit[-l:] for logit, l in zip(logits, map(len, prev_list)) ]
         # (AR chunkwise) return the last chunkwise piece
         elif self.causal:
             logits = [ logit[-self.causal_size:] for logit in logits ]
@@ -1439,22 +1477,22 @@ class Base(nn.Module):
 
         # (NAR) disable stop token
         if quant_levels is not None and "ar" in self.capabilities:
-            logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, resps_list) ) ]
+            logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, prev_list) ) ]
         # (AR-len) disable extraneous tokens
         if quant_levels is None and "len" in self.capabilities:
-            logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, resps_list) ) ]
+            logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, prev_list) ) ]
 
         # argmax instead
         if temperature <= 0.0:
             return [ logit.argmax(dim=1) for logit in logits ]
 
         # perform repetition penalizing
-        if "len" not in self.capabilities:
-            logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
+        if "len" not in self.capabilities and repetition_penalty != 1.0:
+            logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, prev_list ) ]
 
         # (AR) perform length penalizing
         if quant_levels is None and self.causal:
-            logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
+            logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
 
         # perform top_k/top_p filtering of our logits
         if top_k > 0 or top_p < 1.0:
@@ -1469,7 +1507,7 @@ class Base(nn.Module):
 
         # do DRY sampling
         if dry_multiplier > 0.0:
-            logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
+            logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, prev_list ) ]
 
         # do mirostat sampling
         # currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
@@ -183,7 +183,7 @@ class NAR(Base):
 
         resps_list = super().sample(
             logits=logits,
-            resps_list=prev_list,
+            prev_list=prev_list,
             quant_levels=quant_levels,
 
             temperature=1.0 if n == 0 else sampling_temperature,