experimental implementation of STT (need to actually test on a model, test trainer seems to work)
parent d319d33368
commit 54547b74d8
@@ -960,6 +960,19 @@ class NaiveTokenizer:
 		# tokenize
 		return [*map(symmap.get, phones)]
 
+	def decode( self, t ):
+		s = ""
+		symmap = self.get_vocab()
+		reverse_symmap = {}
+		for k, v in symmap.items():
+			reverse_symmap[v] = k
+
+		for i, token in enumerate( t ):
+			s += reverse_symmap[token]
+
+		return s
+
+
 _logger = logging.getLogger(__name__)
 
 cfg = Config.from_cli()
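
The added decode() simply inverts the vocab lookup used when tokenizing. A minimal standalone sketch of the same logic, using an illustrative symmap rather than the real vocabulary:

    symmap = { "<s>": 1, "</s>": 2, "a": 3, "b": 4 }   # illustrative ids, not the real phoneme vocab
    reverse_symmap = { v: k for k, v in symmap.items() }
    tokens = [1, 3, 4, 2]
    print("".join(reverse_symmap[t] for t in tokens))  # "<s>ab</s>"
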
@@ -428,6 +428,7 @@ def get_task_symmap():
 		"<soe>": 5,
 		"<mask>": 6,
 		"<eoe>": 7,
+		"<stt>": 8,
 
 		"<nse>": 6, # fake
 		"<cse>": 6, # fake
@@ -1052,6 +1053,12 @@ class Dataset(_Dataset):
 				task,
 			]
 
+		# Base TTS (<resp> => <text>)
+		elif task == "stt":
+			# easier to just keep it instead of wrangling around trying to remove it
+			# it might also help to provide a guidance prompt but who knows right now
+			proms = self.sample_prompts(spkr_name, ignore=path)
+
 		# noise suppression (<text>? <resp+noise> => <resp>)
 		# speech removal (<text>?<resp+noise> => <noise>)
 		elif task == "ns" or task == "sr":
@@ -200,7 +200,7 @@ class AR(Base):
 
 			r = super().sample(
 				logits=logits,
-				resps_list=resps_list,
+				prev_list=resps_list,
 
 				temperature=sampling_temperature,
 				min_temperature=sampling_min_temperature,
@@ -61,15 +61,23 @@ class AR_NAR(Base):
 
 		disable_tqdm=False,
 	):
-		device = text_list[0].device
-		batch_size = len(text_list)
+		text_task = [ "stt" ]
+
+		if text_list is not None:
+			default_task = "tts"
+			device = text_list[0].device
+			batch_size = len(text_list)
+		else:
+			default_task = "stt"
+			device = resps_list[0].device
+			batch_size = len(resps_list)
 
 		# generate task list if not provided
 		if task_list is None:
-			task_list = [ "tts" for _ in range(batch_size) ]
+			task_list = [ default_task for _ in range(batch_size) ]
 
 		# is training or NAR
-		if resps_list is not None:
+		if resps_list is not None and text_list is not None:
 			n_levels_set = {r.shape[-1] for r in resps_list}
 			n_levels = next(iter(n_levels_set))
 
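
With this change the same forward entry point serves both directions, and the mode is inferred from whether text is supplied. A rough sketch of the two call shapes, using the keyword names from this file (untested, per the commit message):

    # TTS (default): text is provided, audio resps are generated / trained against
    model( text_list=texts, proms_list=proms, resps_list=resps )
    # STT: no text given, so default_task falls back to "stt" and device/batch size come from resps_list
    model( text_list=None, proms_list=proms, resps_list=resps, task_list=["stt"] * len(resps) )
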
@@ -102,12 +110,18 @@ class AR_NAR(Base):
 
 			# input RVQ levels
 			quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
+			for i, task in enumerate( task_list ):
+				if task in text_task:
+					quant_levels[i] = 0 # self.n_resp_levels - 1
+
 			# trim resps to only contain all levels below the target level
-			resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
+			resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
 
 			# tensor to cat for RVQ level 0
-			stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
+			text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16)
+			audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
 			# I hate python's value/reference semantics so much
-			for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
+			for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
 				# cap quant_level if it exceeds its corresponding resp/prom
 				if quant_level >= resps.shape[-1]:
 					quant_levels[i] = resps.shape[-1] - 1
@@ -139,7 +153,11 @@ class AR_NAR(Base):
 				# only apply stop token for RVQ level 0
 				if quant_level <= 0:
 					# append stop tokens for AR
-					resps_list[i] = torch.cat([ resps, stop_sequence ])
+					if task in text_task:
+						#text_list[i] = torch.cat([ resps, text_stop_sequence ])
+						...
+					else:
+						resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
 
 
 			inputs = self.inputs(
@@ -195,7 +213,7 @@ class AR_NAR(Base):
 
 			resps_list = super().sample(
 				logits=logits,
-				resps_list=prev_list,
+				prev_list=prev_list,
 				quant_levels=quant_levels,
 
 				temperature=sampling_temperature,
@@ -220,11 +238,11 @@ class AR_NAR(Base):
 		if cfg.lora is not None:
 			enable_lora( self, cfg.lora.active_level( 0 ) )
 
+		# STT
 		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
 		stopped = torch.zeros(batch_size, device=device).bool()
 
-		stop_token = self.stop_token
+		stop_token = self.stop_token if task_list[0] != "stt" else 2 # to-do: derive from tokenizer
 
-
 		state = None
 		mirostat = [
@@ -233,9 +251,17 @@ class AR_NAR(Base):
 
 		scores = [ 1.0 ] * sampling_beam_width
 
+		# add <bos> to text for STT
+		for i, sequence in enumerate( sequence_list ):
+			if task_list[i] in text_task:
+				sequence_list[i] = torch.cat([sequence_list[i], torch.tensor([1], dtype=torch.int16, device=device)])
+
 		# get next in sequence
 		for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
-			resps_list = [x.unsqueeze(dim=-1) for x in sequence_list]
+			if task_list[0] in text_task:
+				text_list = [x for x in sequence_list]
+			else:
+				resps_list = [x.unsqueeze(dim=-1) for x in sequence_list]
 
 			inputs = self.inputs(
 				text_list=text_list,
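
Two magic text-token ids are relied on in this decoding path (the to-do above notes they should eventually be derived from the tokenizer). A hedged summary of what the loop assumes, as read from this diff:

    # assumed text-token ids; not derived from the tokenizer yet
    TEXT_BOS = 1  # prepended to each sequence before AR decoding when the task is "stt"
    TEXT_EOS = 2  # used as stop_token (and as text_stop_sequence) for the "stt" task
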
@@ -261,7 +287,7 @@ class AR_NAR(Base):
 
 			r = super().sample(
 				logits=logits,
-				resps_list=resps_list,
+				prev_list=resps_list,
 
 				temperature=sampling_temperature,
 				min_temperature=sampling_min_temperature,
@@ -398,10 +424,10 @@ def example_usage():
 	"""
 
 	bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
-	tasks = cfg.dataset.tasks_list
+	available_tasks = ["tts", "stt"] # cfg.dataset.tasks_list
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 150 * len(tasks) # * cfg.model.experimental.causal_size
+	steps = 150 * len(available_tasks) # * cfg.model.experimental.causal_size
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -486,14 +512,14 @@ def example_usage():
 	_logger.info(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
 	@torch.no_grad()
-	def sample_data(task=None):
+	def sample_data(t=None):
 		texts = []
 		proms = []
 		resps = []
+		tasks = []
 
 		for i in range(batch_size):
-			if task is None:
-				task = random.choice(tasks)
+			task = random.choice(available_tasks) if t is None else t
 
 			text = text_list[i]
 			prom = proms_list[i]
@@ -502,6 +528,8 @@ def example_usage():
 			# do nothing
 			if task == "tts":
 				...
+			elif task == "stt":
+				...
 			elif task == "tts-c":
 				trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
 
@@ -523,25 +551,35 @@ def example_usage():
 			texts.append( text.to(device) )
 			proms.append( prom.to(device) )
 			resps.append( resp.to(device) )
+			tasks.append( task )
 
-		return texts, proms, resps
+		return texts, proms, resps, tasks
 
 	@torch.inference_mode()
-	def sample( name, steps=1000, task=None ):
+	def sample( name, steps=500, task=None ):
 		engine.eval()
 
-		texts, proms, resps = sample_data( task )
+		texts, proms, resps, tasks = sample_data( task )
 
-		if "ar" in cfg.model.capabilities:
-			resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
+		if tasks[0] == "stt":
+			text = engine( None, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
+			"""
+			# to-do: STT for NAR
+			text = engine( text, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
+			"""
+			text = [ cfg.tokenizer.decode( t ) for t in text ]
+
+			print( text )
 		else:
-			resps = [ resp[:, 0] for resp in resps ]
+			if "ar" in cfg.model.capabilities:
+				resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
+			else:
+				resps = [ resp[:, 0] for resp in resps ]
 
-		if "nar" in cfg.model.capabilities:
-			resps = engine( texts, proms, resps, sampling_temperature=0.2 )
+			if "nar" in cfg.model.capabilities:
+				resps = engine( texts, proms, resps, sampling_temperature=0.2 )
 
-		for i, o in enumerate(resps):
-			_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
+			for i, o in enumerate(resps):
+				_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
 
 		unload_model()
 
@@ -549,10 +587,10 @@ def example_usage():
 		engine.train()
 		t = trange(steps)
 		for i in t:
-			texts, proms, resps = sample_data()
+			texts, proms, resps, tasks = sample_data()
 
 			stats = {"step": i}
-			stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps)
+			stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps, task_list=tasks)
 			stats |= {"grad_norm": engine.get_global_grad_norm()}
 
 			tqdm.write(f"{stats}")
@@ -571,7 +609,7 @@ def example_usage():
 		model = ml.compile_model(model, backend=cfg.optimizations.compile)
 	"""
 
-	for task in tasks:
+	for task in available_tasks:
 		sample("final", task=task)
 
 	engines.quit()
@@ -259,7 +259,8 @@ class AudioEmbedding(nn.Module):
 		return x
 
 # per-level classification
-class AudioClassifier(nn.Module):
+# it might actually be "better" in the long run to only have one output head like a traditional LM, and just de-stitch it here instead of doing modulus math and whatever like the HF/experimental impl
+class Classifiers(nn.Module):
 	def __init__(
 		self,
 		l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
@@ -783,10 +784,10 @@ class Base(nn.Module):
 			self.metrics = None
 		else:
 			self.classifier = None
-			self.classifiers = AudioClassifier( l_tokens, d_model )
+			self.classifiers = Classifiers( l_tokens + [ n_text_tokens ], d_model )
 			self.accuracy_metric = None
 			self.precision_metric = None
-			self.metrics = Metrics( l_tokens )
+			self.metrics = Metrics( l_tokens + [ n_text_tokens ] )
 
 		"""
 		if tie_classifier_to_embedding:
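
The renamed Classifiers now carries one extra head for text, appended via l_tokens + [ n_text_tokens ]. A simplified, self-contained sketch of the shape of such a module (names and behaviour here are illustrative; the real Classifiers in this file also handles stacking/padding):

    import torch
    from torch import nn

    class ClassifiersSketch(nn.Module):
        def __init__(self, l_tokens: list[int], d_model: int):
            super().__init__()
            # one linear head per RVQ level, plus the text head appended last
            self.proj = nn.ModuleList([ nn.Linear(d_model, n_tokens) for n_tokens in l_tokens ])

        def forward(self, xi: list[torch.Tensor], levels: list[int]) -> list[torch.Tensor]:
            # a level of -1 indexes the final head; in this sketch that is how the "stt"/"len"
            # special tasks are routed away from the audio heads (see classifier_quant_levels below)
            return [ self.proj[l](x) for x, l in zip(xi, levels) ]
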
@@ -907,6 +908,8 @@ class Base(nn.Module):
 		device = text_list[0].device
 		batch_size = len(text_list)
 
+		special_tasks = ["stt", "len"]
+
 		inputs = [ [] for _ in range(batch_size) ]
 		for i in range(batch_size):
 			quant_level = quant_levels[i] if quant_levels is not None else 0
@@ -921,7 +924,7 @@ class Base(nn.Module):
 			# Base-line TTS task
 			# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
 			# prom /may/ include <task> tokens inside to help guide things, per SpeechX
-			if f'<{task_type}>' in get_task_symmap():
+			if f'<{task_type}>' in get_task_symmap() and task_type not in special_tasks:
 				# insert the text prompt
 				if text_list is not None and text_list[i] is not None:
 					inputs[i].append( ( "text", text_list[i] ) )
@@ -973,6 +976,21 @@ class Base(nn.Module):
 				elif resps_list is not None and resps_list[i] is not None:
 					# yes this could be encoded better
 					inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) )
+			# Speech-to-Text prediction task
+			# Sequence: <resp><sep><rvq lvl><sep><text>
+			elif task_type == "stt":
+				# insert the input response
+				if resps_list is not None and resps_list[i] is not None:
+					inputs[i].append( ( "resp", resps_list[i] ) )
+				# insert lang token if we're trained for it
+				if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
+					inputs[i].append( ( "lang", lang_list[i] ) )
+				# insert RVQ level guidance token if the model is versioned for it
+				if self.rvq_l_emb is not None and not self.interleave:
+					inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
+				# insert the output text prompt
+				if text_list is not None and text_list[i] is not None:
+					inputs[i].append( ( "text", text_list[i] ) )
 			else:
 				raise Exception(f'Unrecognized task: {task_type}')
 
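
A rough picture of the per-item input list this builds for an "stt" batch entry (a sketch only; it assumes, as with the other tasks, that a ("task", task_type) entry leads the list, since the classifier routing later indexes inputs[i][0][-1]):

    inputs[i] = [
        ( "task",        "stt" ),
        ( "resp",        resps_list[i] ),                 # input audio codes, all RVQ levels kept
        ( "lang",        lang_list[i] ),                  # only if the model has the "lang" capability
        ( "quant_level", torch.tensor([ quant_level ]) ), # RVQ guidance token, forced to 0 for stt upstream
        ( "text",        text_list[i] ),                  # the transcript being predicted
    ]
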
@@ -1010,6 +1028,8 @@ class Base(nn.Module):
 		if not token_dropout_rvq_levels:
 			token_dropout_rvq_levels = [1, self.resp_levels]
 
+		summed_embeddings_task = [ "stt" ]
+
 		x_list = []
 		for batch_index, batch_input in enumerate(inputs):
 			batch = []
|
||||||
offset = 0,
|
offset = 0,
|
||||||
quant_level = 0,
|
quant_level = 0,
|
||||||
)
|
)
|
||||||
|
# cheat-y way to handle performing STT across all levels
|
||||||
|
elif task_type in summed_embeddings_task:
|
||||||
|
embedding = sum([ self.resps_emb(
|
||||||
|
input[:, :l+1],
|
||||||
|
offset = 0 if l == 0 else 1, # or maybe set to 1
|
||||||
|
quant_level = l
|
||||||
|
) for l in range( input.shape[-1] - 1 ) ])
|
||||||
else:
|
else:
|
||||||
# get RVQ level 0, or up to targetted RVQ level inference
|
# get RVQ level 0, or up to targetted RVQ level inference
|
||||||
if self.version <= 4:
|
if self.version <= 4:
|
||||||
|
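
For STT the response is not trimmed to a single RVQ level, so its embedding is formed by summing the per-level embeddings into one sequence. A hedged note on the shapes involved (as read from this hunk, not verified against a trained model):

    # input: [T, L] audio codes for one utterance (T frames, L RVQ levels)
    # each self.resps_emb(...) call yields a [T, d_model] embedding for levels up to l,
    # and sum() collapses them into a single [T, d_model] sequence fed to the transformer
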
@@ -1171,7 +1197,9 @@ class Base(nn.Module):
 		quant_levels: int | list[int] | Tensor | None = None,
 	):
 		device = logits[0].device
-		classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels ) ]
+		special_tasks = [ "len", "stt" ]
+		summed_embeddings_task = [ "stt" ]
+		classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
 
 		# handles tasks where the prompt has task tokens injected in the middle
 		def prompt_input_to_token( input, quant_level ):
@@ -1192,8 +1220,10 @@ class Base(nn.Module):
 		for batch_index, batch in enumerate(inputs):
 			quant_level = quant_levels[batch_index]
 			target = []
+			task_type = "tts"
 			for name, input in batch:
 				if name == "task":
+					task_type = input
 					task_list.append( input )
 				elif name == "prom":
 					proms = [ input ] if isinstance(input, torch.Tensor) else input
@@ -1201,6 +1231,8 @@ class Base(nn.Module):
 				elif name == "resp":
 					if self.interleave:
 						target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
+					elif task_type in summed_embeddings_task:
+						target.append( torch.full_like(input[..., 0], self.ignore_index) )
 					else:
 						target.append( input if input.dim() == 1 else input[:, quant_level] )
 				elif name in ["text", "quant_level", "lang", "tone", "len"]:
@@ -1273,7 +1305,12 @@ class Base(nn.Module):
 			for name, input in batch:
 				# do not use resp
 				if name == "resp":
-					input = input if input.dim() == 1 else input[:, quant_level]
+					if self.interleave:
+						input = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
+					elif task_type in summed_embeddings_task:
+						input = torch.full_like(input[..., 0], self.ignore_index)
+					else:
+						input = input if input.dim() == 1 else input[:, quant_level]
 				# select prom level
 				elif name == "prom":
 					proms = [ input ] if isinstance(input, torch.Tensor) else input
@@ -1383,7 +1420,8 @@ class Base(nn.Module):
 		)
 
 		if self.classifiers is not None:
-			classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels ) ]
+			special_tasks = [ "len", "stt" ]
+			classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
 			x = self.classifiers(x, levels = classifier_quant_levels) * m
 
 		# Remove padding
@@ -1402,7 +1440,7 @@ class Base(nn.Module):
 	def sample(
 		self,
 		logits: list[Tensor], # logit scores
-		resps_list: list[Tensor], # previous tokens
+		prev_list: list[Tensor], # previous tokens
 		quant_levels: int | list[int] | Tensor | None = None,
 		# base sampling parameters
 		temperature: float = 1.0,
@@ -1429,7 +1467,7 @@ class Base(nn.Module):
 		# (NAR) return the entire generated response
 		# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
 		if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
-			logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
+			logits = [ logit[-l:] for logit, l in zip(logits, map(len, prev_list)) ]
 		# (AR chunkwise) return the last chunkwise piece
 		elif self.causal:
 			logits = [ logit[-self.causal_size:] for logit in logits ]
@@ -1439,22 +1477,22 @@ class Base(nn.Module):
 
 		# (NAR) disable stop token
 		if quant_levels is not None and "ar" in self.capabilities:
-			logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, resps_list) ) ]
+			logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, prev_list) ) ]
 		# (AR-len) disable extraneous tokens
 		if quant_levels is None and "len" in self.capabilities:
-			logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, resps_list) ) ]
+			logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, prev_list) ) ]
 
 		# argmax instead
 		if temperature <= 0.0:
 			return [ logit.argmax(dim=1) for logit in logits ]
 
 		# perform repetition penalizing
-		if "len" not in self.capabilities:
-			logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
+		if "len" not in self.capabilities and repetition_penalty != 1.0:
+			logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, prev_list ) ]
 
 		# (AR) perform length penalizing
 		if quant_levels is None and self.causal:
-			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
+			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
 
 		# perform top_k/top_p filtering of our logits
 		if top_k > 0 or top_p < 1.0:
@@ -1469,7 +1507,7 @@ class Base(nn.Module):
 
 		# do DRY sampling
 		if dry_multiplier > 0.0:
-			logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
+			logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, prev_list ) ]
 
 		# do mirostat sampling
 		# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
@@ -183,7 +183,7 @@ class NAR(Base):
 
 			resps_list = super().sample(
 				logits=logits,
-				resps_list=prev_list,
+				prev_list=prev_list,
 				quant_levels=quant_levels,
 
 				temperature=1.0 if n == 0 else sampling_temperature,