mrq 2025-02-12 00:18:24 -06:00
parent 1c0ed6abac
commit 04fef5dad5
8 changed files with 797 additions and 425 deletions

View File

@ -90,9 +90,11 @@ However, because this codec relies on FSQ (Finite Scalar Quantization) rather th
Proposed architectures may include:
* independent NAR-demasking for *all* levels, rather than FSQ level 0.
* little additional code is required, as existing NAR-demasking training/inference code can be repurposed for additional levels.
* this also offers the best backwards compatibility with vall_e.cpp, as no extra model code is required.
* parallel decoding for *all* levels in one pass, rather than separate passes for each level.
* some extra code would be required for orchestrating the additional decoding heads in parallel.
* the decoding heads may simply be a single `nn.Linear` classifier, or additional transformer blocks.
* the former yields poor results even when overfitting, while the latter (without an output projection head) is at least able to overfit; see the sketch below.
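A rough sketch of the second option (illustrative only: `ParallelHeads`, `d_model`, `n_levels`, and `n_tokens` are assumed names, not this repo's actual classes). Each level gets its own head over the shared last hidden state, either a bare `nn.Linear` classifier or a small transformer block whose leading hidden dimensions are reused as logits:

```python
import torch
from torch import nn

class ParallelHeads(nn.Module):
	# hypothetical per-level decoding heads fed the trunk's last hidden state
	def __init__(self, d_model: int, n_levels: int, n_tokens: int, use_blocks: bool = False):
		super().__init__()
		self.n_tokens = n_tokens
		if use_blocks:
			# small per-level transformer blocks (assumes n_tokens <= d_model, since
			# the leading hidden dims are treated as the logits)
			self.heads = nn.ModuleList([
				nn.TransformerEncoderLayer(d_model, nhead=8, dim_feedforward=d_model * 2, batch_first=True)
				for _ in range(n_levels)
			])
		else:
			# bare linear classifiers, one per level
			self.heads = nn.ModuleList([nn.Linear(d_model, n_tokens) for _ in range(n_levels)])

	def forward(self, x: torch.Tensor) -> torch.Tensor:
		# x: (batch, seq_len, d_model) -> logits: (batch, n_levels, seq_len, n_tokens)
		outs = []
		for head in self.heads:
			y = head(x)
			# linear heads already emit n_tokens features; transformer blocks emit d_model, so truncate
			outs.append(y[..., : self.n_tokens])
		return torch.stack(outs, dim=1)
```

Either variant decodes every level in a single pass; the trade-off noted above is whether the extra capacity of the block variant is worth it over a plain classifier.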
## `transcribe.py`

View File

@ -18,55 +18,32 @@ from vall_e.config import cfg
from vall_e.emb.g2p import encode as phonemize
from vall_e.emb.qnt import encode as quantize, _replace_file_extension, convert_audio
def pad(num, zeroes):
return str(num).zfill(zeroes+1)
def process_items( items, stride=0, stride_offset=0 ):
items = sorted( items )
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
def load_audio( path, device="cuda" ):
waveform, sample_rate = torchaudio.load(path)
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
waveform = convert_audio(waveform, sample_rate, cfg.sample_rate, 1)
return waveform.to(device=device), cfg.sample_rate
from vall_e.emb.process import pad, load_audio, process_items, process_jobs
def process(
audio_backend="encodec",
input_audio="Emilia",
output_dataset="training",
raise_exceptions=False,
stride=0,
stride_offset=0,
slice="auto",
device="cuda",
dtype="float16",
amp=False,
):
# encodec / vocos
if audio_backend in ["encodec", "vocos"]:
audio_extension = ".enc"
cfg.sample_rate = 24_000
cfg.model.resp_levels = 8
elif audio_backend == "dac":
audio_extension = ".dac"
cfg.sample_rate = 44_100
cfg.model.resp_levels = 9
elif cfg.audio_backend == "audiodec":
sample_rate = 48_000
audio_extension = ".dec"
cfg.model.resp_levels = 8 # ?
else:
raise Exception(f"Unknown audio backend: {audio_backend}")
audio_backend="encodec",
input_audio="Emilia",
output_dataset="training",
raise_exceptions=False,
stride=0,
stride_offset=0,
slice="auto",
batch_size=1,
low_memory=False,
device="cuda",
dtype="float16",
amp=False,
):
# prepare from args
cfg.audio_backend = audio_backend # "encodec"
cfg.device = device
cfg.set_audio_backend(audio_backend)
audio_extension = cfg.audio_backend_extension
cfg.inference.weight_dtype = dtype # "bfloat16"
cfg.inference.amp = amp # False
dtype = cfg.inference.dtype if not amp else None
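# the output folder gets a suffix derived from the sample rate and backend, e.g. "24KHz-encodec", "44KHz-dac", or "48KHz-audiodec"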
output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
@ -145,58 +122,11 @@ def process(
if waveform is None:
waveform, sample_rate = load_audio(inpath)
wavs.append((
outpath,
text,
language,
waveform,
sample_rate
))
jobs.append(( outpath, waveform, sample_rate, text, language.lower() ))
if len(wavs) > 0:
for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
try:
outpath, text, language, waveform, sample_rate = job
phones = phonemize(text, language=f'{language}'.lower())
qnt = quantize(waveform, sr=sample_rate, device=device)
if cfg.audio_backend == "dac":
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": qnt.original_length,
"sample_rate": qnt.sample_rate,
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
"chunk_length": qnt.chunk_length,
"channels": qnt.channels,
"padding": qnt.padding,
"dac_version": "1.0.0",
"text": text.strip(),
"phonemes": "".join(phones),
"language": language,
},
})
else:
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": waveform.shape[-1],
"sample_rate": sample_rate,
"text": text.strip(),
"phonemes": "".join(phones),
"language": language,
},
})
except Exception as e:
print(f"Failed to quantize: {outpath}:", e)
if raise_exceptions:
raise e
continue
# processes audio files one at a time
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
jobs = []
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset))
@ -214,6 +144,8 @@ def main():
parser.add_argument("--stride", type=int, default=0)
parser.add_argument("--stride-offset", type=int, default=0)
parser.add_argument("--slice", type=str, default="auto")
parser.add_argument("--low-memory", action="store_true")
parser.add_argument("--batch-size", type=int, default=0)
args = parser.parse_args()
@ -232,6 +164,8 @@ def main():
stride=args.stride,
stride_offset=args.stride_offset,
slice=args.slice,
batch_size=args.batch_size,
low_memory=args.low_memory,
device=args.device,
dtype=args.dtype,

View File

@ -15,52 +15,36 @@ from pathlib import Path
from vall_e.config import cfg
def pad(num, zeroes):
return str(num).zfill(zeroes+1)
from vall_e.emb.g2p import encode as phonemize
from vall_e.emb.qnt import encode as quantize, _replace_file_extension, convert_audio
from vall_e.emb.process import pad, load_audio, process_items, process_jobs
def process_items( items, stride=0, stride_offset=0 ):
items = sorted( items )
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
def process(
audio_backend="encodec",
input_audio="LibriTTS_R",
output_dataset="training",
raise_exceptions=False,
stride=0,
stride_offset=0,
slice="auto",
device="cuda",
dtype="float16",
amp=False,
):
# encodec / vocos
if audio_backend in ["encodec", "vocos"]:
audio_extension = ".enc"
cfg.sample_rate = 24_000
cfg.model.resp_levels = 8
elif audio_backend == "dac":
audio_extension = ".dac"
cfg.sample_rate = 44_100
cfg.model.resp_levels = 9
elif cfg.audio_backend == "audiodec":
sample_rate = 48_000
audio_extension = ".dec"
cfg.model.resp_levels = 8 # ?
else:
raise Exception(f"Unknown audio backend: {audio_backend}")
audio_backend="encodec",
input_audio="LibriTTS_R",
output_dataset="training",
raise_exceptions=False,
stride=0,
stride_offset=0,
slice="auto",
batch_size=1,
low_memory=False,
device="cuda",
dtype="float16",
amp=False,
):
# prepare from args
cfg.audio_backend = audio_backend # "encodec"
cfg.device = device
cfg.set_audio_backend(audio_backend)
audio_extension = cfg.audio_backend_extension
cfg.inference.weight_dtype = dtype # "bfloat16"
cfg.inference.amp = amp # False
# import after because we've overridden the config above
# need to validate if this is even necessary anymore
from vall_e.emb.g2p import encode as phonemize
from vall_e.emb.qnt import encode as quantize, _replace_file_extension
dtype = cfg.inference.dtype if not amp else None
output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
@ -140,62 +124,19 @@ def process(
continue
if waveform is None:
waveform, sample_rate = torchaudio.load(inpath)
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
waveform, sample_rate = load_audio(inpath)
wavs.append((
outpath,
text,
language,
waveform,
sample_rate
))
jobs.append(( outpath, waveform, sample_rate, text, language ))
if len(wavs) > 0:
for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
try:
outpath, text, language, waveform, sample_rate = job
phones = phonemize(text, language=language)
qnt = quantize(waveform, sr=sample_rate, device=device)
if cfg.audio_backend == "dac":
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": qnt.original_length,
"sample_rate": qnt.sample_rate,
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
"chunk_length": qnt.chunk_length,
"channels": qnt.channels,
"padding": qnt.padding,
"dac_version": "1.0.0",
"text": text.strip(),
"phonemes": "".join(phones),
"language": language,
},
})
else:
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": waveform.shape[-1],
"sample_rate": sample_rate,
"text": text.strip(),
"phonemes": "".join(phones),
"language": language,
},
})
except Exception as e:
print(f"Failed to quantize: {outpath}:", e)
if raise_exceptions:
raise e
continue
# processes audio files one at a time
if low_memory:
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
jobs = []
# processes all audio files for a given speaker
if not low_memory:
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
jobs = []
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset))
@ -213,6 +154,8 @@ def main():
parser.add_argument("--stride", type=int, default=0)
parser.add_argument("--stride-offset", type=int, default=0)
parser.add_argument("--slice", type=str, default="auto")
parser.add_argument("--low-memory", action="store_true")
parser.add_argument("--batch-size", type=int, default=0)
args = parser.parse_args()
@ -231,6 +174,8 @@ def main():
stride=args.stride,
stride_offset=args.stride_offset,
slice=args.slice,
batch_size=args.batch_size,
low_memory=args.low_memory,
device=args.device,
dtype=args.dtype,

View File

@ -259,6 +259,9 @@ class ModelExperimentalSettings:
# it just seems like a bitch to try and train something worthwhile with it, since there's crackles every other token
# RetNet's chunked inferencing might be a better place for this
parallel_decoding: bool = False # enables some settings to decode ALL RVQ levels in one pass
# this is a bit of a pain to get working in the test trainer
masking_train_p: float = 0.0 # odds of training with masking
masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
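# e.g. the default [0,0] presumably limits mask training to RVQ level 0, while version >= 7 models force [0, n_resp_levels] to cover every level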

View File

@ -75,6 +75,11 @@ class AR_NAR(Base):
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
# RVQ levels to apply masking training on
masking_train_rvq_levels = self.config.experimental.masking_train_rvq_levels
if self.version >= 7:
masking_train_rvq_levels = [0,self.n_resp_levels]
rvq_levels_p = [ i for i in range( quant_level_range[0], quant_level_range[1] + 1 ) ]
# CFG
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
@ -127,7 +132,10 @@ class AR_NAR(Base):
timesteps[i] = (timesteps[i] * 0.6) + 0.2
# trim resps to only contain all levels below the target level
resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
if self.version < 7:
resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
elif not self.parallel_decoding:
resps_list = [r if t in text_task else r[..., l] for r, l, t in zip(resps_list, quant_levels, task_list)]
# tensor to cat for RVQ level 0
text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
@ -229,6 +237,7 @@ class AR_NAR(Base):
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
quant_levels: list[int] | None = None,
disable_tqdm=False,
use_lora=None,
@ -237,7 +246,11 @@ class AR_NAR(Base):
device = text_list[0].device
batch_size = len(text_list)
level = 0
if quant_levels is None:
level = 0
else:
level = quant_levels[0] # ugh
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
@ -306,11 +319,12 @@ class AR_NAR(Base):
# fill scores
scores = [ torch.ones((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
quant_levels = [ level for _ in range(batch_size) ]
if quant_levels is None:
quant_levels = [ level for _ in range(batch_size) ]
null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
null_prom = [ None for _ in range(batch_size) ]
iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm)
iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc=f"NAR Masked Level {level}", disable=disable_tqdm)
for timestep in iterator:
# update previous list of tokens
prev_list = resps_list
@ -430,6 +444,183 @@ class AR_NAR(Base):
return resps_list
# handles demasking inferencing in parallel to infer tokens for all RVQ levels at once
# it works if the underlying model is trained properly (which is a pain)
def forward_nar_masked_parallel(
self,
task_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
disable_tqdm=False,
use_lora=None,
**sampling_kwargs,
):
device = text_list[0].device
batch_size = len(text_list)
level = 0
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
# convert (N)AR specific args
sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
min_length = sampling_kwargs.pop("min_duration", 1)
max_length = sampling_kwargs.pop("max_duration", 500)
max_steps = sampling_kwargs.get("max_steps", 25)
refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
annealed_sampling = sampling_kwargs.get("annealed_sampling", True)
# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
temperature = sampling_kwargs.pop("temperature", 0.0)
minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 2.5)
# this really helps keep audio coherent so far
cfg_strength = sampling_kwargs.get("cfg_strength", minimum_cfg_strength)
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
start_noise = sampling_kwargs.get("denoise_start", 0.0)
end_noise = sampling_kwargs.get("denoise_end", 1.0)
remasking = sampling_kwargs.get("remasking", True)
max_steps = math.floor(max_steps * (end_noise - start_noise))
# to specify the initial mask used
vc_list = sampling_kwargs.pop("vc_list", None)
vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25)
vc_mask_p = sampling_kwargs.pop("vc_mask_p", 0.25)
len_list = [ clamp(l, min_length, max_length) for l in len_list ]
# force set CFG because too low / no CFG causes issues
original_cfg_strength = cfg_strength
cfg_strength = max( cfg_strength, minimum_cfg_strength )
prefix_context = sampling_kwargs.get("prefix_context", None)
# fill with masked tokens (even though they get masked anyways)
resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
# fill scores
scores = [ torch.ones((seq_len), dtype=torch.float32, device=device) for seq_len in len_list ]
quant_levels = [ level for _ in range(batch_size) ]
null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
null_prom = [ None for _ in range(batch_size) ]
iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm)
for timestep in iterator:
# update previous list of tokens
prev_list = resps_list
# ramp down over time
annealing = 1.0 - timestep
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 )
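# (starts at 1.0 when timestep=0, so effectively the whole sequence is masked, and decays to 0.0 as timestep approaches 1.0)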
# proportion of tokens to remask
remask_p = 1.0 / (max_steps * 2) if remasking else 0
# pick the worst scoring tokens to mask off
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
# normal masking
# mask off inputs
resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.stop_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
# boolean mask
is_masked = [ resps == self.stop_token for resps in resps_list ]
# timestep inputs
time_list = [ timestep for _ in range(batch_size) ]
sampling_temperature = temperature * annealing if annealed_sampling else temperature
sampling_cfg = cfg_strength * timestep if annealed_sampling else cfg_strength
input_resps_list = resps_list
# setup inputs
inputs = super().inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=input_resps_list,
lang_list=lang_list,
tone_list=tone_list,
time_list=time_list,
quant_levels=quant_levels,
)
output = super().forward(
inputs=inputs,
quant_levels=quant_levels,
)
logits = output.logits
if cfg_strength > 0:
null_inputs = super().inputs(
text_list=null_text,
proms_list=null_prom,
resps_list=input_resps_list,
lang_list=lang_list,
tone_list=tone_list,
time_list=time_list,
quant_levels=quant_levels,
)
null_output = super().forward(
inputs=null_inputs,
quant_levels=quant_levels,
)
logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ l for l in len_list ] )
l_scores = []
l_resps_list = []
# cringe hack because we're able to sample multiple levels at once
for l in range(self.n_resp_levels):
# sample with sampler settings
filtered_sampled = super().sample(
logits=[ logit[l] for logit in logits ],
prev_list=[ resp[..., l] for resp in prev_list ],
quant_levels=quant_levels,
temperature=sampling_temperature,
**sampling_kwargs,
)
# retrieves unfiltered logits
unfiltered_sampled = super().sample(
logits=[ logit[l] for logit in logits ],
prev_list=[ resp[..., l] for resp in prev_list ],
quant_levels=quant_levels,
temperature=0.0,
**sampling_kwargs,
)
# get sampled tokens
sampled_ids = filtered_sampled.ids
# keep unmasked tokens
l_resps_list.append([ torch.where( masked[..., l], input_ids, resps[..., l] ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ])
# get probability scores
l_scores.append([
# take the complement so the worse-scoring tokens get picked for topk
1.0 -
# only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
torch.where( masked[..., l], torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked[..., l].shape, device=device) )
# use unmodified logit scores for this, as it offers better stability
for scores, masked in zip( unfiltered_sampled.scores, is_masked )
])
resps_list = []
scores = []
for batch_index in range(batch_size):
score = sum([ l_scores[level][batch_index] for level in range(self.n_resp_levels) ]) / self.n_resp_levels
resp = torch.stack([ l_resps_list[level][batch_index] for level in range(self.n_resp_levels) ], dim=-1)
scores.append( score )
resps_list.append( resp )
return resps_list
def forward_nar(
self,
task_list: list[Tensor] | None = None,
@ -911,6 +1102,56 @@ class AR_NAR(Base):
# is NAR
if (len_list is not None or resps_list is not None) and text_list is not None:
if self.version >= 7:
if self.parallel_decoding:
return self.forward_nar_masked_parallel(
task_list=task_list,
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
**sampling_kwargs,
)
else:
resps_lists = [ None for _ in range(batch_size) ]
for level in range(self.n_resp_levels):
resp_list = self.forward_nar_masked(
task_list=task_list,
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
quant_levels=[ level for _ in range(batch_size) ],
**sampling_kwargs,
)
for batch_index, resp in enumerate(resp_list):
if resps_lists[batch_index] is None:
resps_lists[batch_index] = []
resps_lists[batch_index].append( resp )
for batch_index, resps in enumerate(resps_lists):
resps_lists[batch_index] = torch.stack( resps, dim=-1 )
return resps_lists
return self.forward_nar(
task_list=task_list,
@ -988,8 +1229,8 @@ def example_usage():
resps_list = [ audio[:int(cfg.dataset.frames_per_second * 4), :] ] * batch_size
kwargs = {
'n_text_tokens': 256,
'n_audio_tokens': 1024,
'n_text_tokens': cfg.model.text_tokens,
'n_audio_tokens': cfg.model.audio_tokens,
'd_model': 1024, # 256, # 1024, # 1536
'n_heads': 16, # 4, # 16, # 24
@ -1004,7 +1245,9 @@ def example_usage():
}
bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
available_tasks = [] + (["tts-ar"] if "ar" in cfg.model.capabilities else []) + (["tts-nar"] if "len" in cfg.model.capabilities else [])
#available_tasks = [] + (["tts-ar"] if "ar" in cfg.model.capabilities else []) + (["tts-nar"] if "len" in cfg.model.capabilities else [])
available_tasks = ["tts-nar"]
model = AR_NAR(**kwargs).to(cfg.device)
steps = 500 // batch_size
@ -1156,13 +1399,14 @@ def example_usage():
if task == "tts-nar":
len_list = engine( text_list=text_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
len_list = [ resp_list[0].shape[0] for l in len_list ]
len_list = [ r.shape[0] for r in resp_list ]
resps_list = engine( text_list=text_list, proms_list=proms_list, len_list=len_list )
else:
resps_list = engine( text_list=text_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
resps_list = engine( text_list=text_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 )
for i, o in enumerate(resps_list):
print( o.shape, o )
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.{task}.wav", device=cfg.device)
unload_model()
@ -1185,7 +1429,9 @@ def example_usage():
}, f"./data/{cfg.model.arch_type}.pth" )
"""
#sample("init", 5)
task = available_tasks[0]
#sample("init", task=task)
train()
"""

View File

@ -435,7 +435,7 @@ class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
class LlamaModel_Adapted(LlamaModel):
def __init__(self, config, *args, **kwargs):
self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0.1)
self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0)
self.early_exit_scale = kwargs.pop("early_exit_scale", 0.1)
self.early_exit_r = kwargs.pop("early_exit_r", 2)
@ -459,7 +459,7 @@ class LlamaModel_Adapted(LlamaModel):
self.post_init()
def dropoff_layer( self, l ):
if not self.training:
if not self.training or self.layer_dropout_p <= 0:
return False
# this could probably be a LUT but I'm not fiending for aggressive mal-optimizations

View File

@ -37,7 +37,7 @@ from ..samplers import *
from ..data import get_task_symmap
# these seem more elegant than a dict
Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states', 'exited_layer'])
Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states'])
Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
LossStats = namedtuple('LossStats', ['loss', 'stats'])
@ -245,6 +245,21 @@ class AudioEmbedding(nn.Module):
return x
class AudioEmbedding_Sums(nn.Module):
def __init__(
self,
n_tokens: int,
n_levels: int,
token_dim: int,
):
super().__init__()
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)])
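# xi is expected as (seq_len, n_levels): each level's codes go through their own embedding table and the results are summed into one (seq_len, token_dim) embedding per frame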
def forward(self, xi: Tensor ) -> Tensor:
x = sum( [ emb( xi[:, l] ) for l, emb in enumerate(self.embeddings) ] )
return x
# time-step embedding
# for the NAR-len, since it probably most likely requires encoding the timestep
class TimeEmbedding(nn.Module):
@ -272,12 +287,12 @@ class Classifiers(nn.Module):
def __init__(
self,
l_embedding_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
token_dim: int, # dimensionality of the embedding
l_embedding_names: list[str] | None = None, # list of names to map to each classifier,
l_embedding_names: list[str], # list of names to map to each classifier,
d_model: int, # dimensionality of the embedding
bias: bool = True,
):
super().__init__()
self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens, bias=bias) for n_tokens in l_embedding_tokens])
self.proj = nn.ModuleList([nn.Linear(d_model, n_tokens, bias=bias) for n_tokens in l_embedding_tokens])
self.names = l_embedding_names
def indices(
@ -288,19 +303,29 @@ class Classifiers(nn.Module):
return names
return [ self.names.index(name) for name in names ]
def forward(self, xi: Tensor, levels: list[int] | None = None, names: list[str] | None = None, stack = False ) -> Tensor:
dtype = xi.dtype
device = xi.device
def forward(
self,
xi: Tensor,
levels: list[int] | None = None,
names: list[str] | None = None,
stack = False,
) -> Tensor:
dtype = xi[0].dtype
device = xi[0].device
if levels and isinstance( levels[-1], str ):
names = levels
levels = []
# map names to levels
"""
if names and not levels:
levels = [ self.names.index(name) for name in names ]
levels = [ None if name =="NAR" else self.names.index(name) for name in names ]
"""
if names and not levels:
levels = [ None if name not in self.names else self.names.index(name) for name in names ]
xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
xi = [ x if l == None else self.proj[l]( x ) for x, l in zip(xi, levels) ]
if not stack:
return xi
@ -316,6 +341,109 @@ class Classifiers(nn.Module):
]
return torch.stack( xi )
# Pseudo-MoE by doing additional decoding from the main transformer's last hidden output
# ironically, not using a classifier to project hidden_dim => audio_tokens causes problems with fitment
class ParallelDecoder(nn.Module):
def __init__(
self,
levels,
d_model,
config_kwargs,
):
super().__init__()
training = config_kwargs.pop("training", False)
attention_backend = config_kwargs.pop("attention_backend", "default")
gradient_checkpointing = config_kwargs.pop("gradient_checkpointing", True)
hidden_size = config_kwargs.get("hidden_size")
vocab_size = config_kwargs.get("vocab_size")
#self.d_model = d_model
self.vocab_size = vocab_size
downs = []
modules = []
ups = []
for level in range(levels):
module = LlamaModel_Adapted(LlamaConfig(**config_kwargs))
module = ml.replace_attention( module, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
if hasattr( module, "embeddings" ):
del module.embeddings
if gradient_checkpointing and not module.gradient_checkpointing:
module.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
modules.append(module)
"""
downs.append(nn.Linear(d_model, hidden_size, bias=False))
ups.append(nn.Linear(hidden_size, vocab_size, bias=False))
"""
self.levels = levels
self.decoders = nn.ModuleList(modules)
"""
self.downs = nn.ModuleList(downs)
self.ups = nn.ModuleList(ups)
"""
def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
# split into levels
if level == None:
x = [ self.forward( x, l, **kwargs ) for l in range(self.levels) ]
x = torch.stack( x )
x = x.permute( 1, 0, 2, 3 ) # ( level, batch, token, classification => batch, level, token, classification )
return x
# do one level
# attention + feedforward
x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"]
# this really hates an output head, so just treat the final output as one
x = x[..., :self.vocab_size]
"""
# downscale to head's dimensionality
x = self.downs[level]( x )
# attention + feed forward
x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"]
# upscale to vocab logits
x = self.ups[level]( x )
"""
return x
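# (usage note: with level=None the stacked output is (batch, level, token, vocab_size); passing a specific level returns only that level's logits, taken from the first vocab_size hidden dims)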
"""
"""
# naively tries to extract multiple codebooks in parallel based on the last hidden state from the model
# this doesn't work well
"""
class ClassifiersParallel(nn.Module):
def __init__(
self,
n_levels: int, # codebook levels
n_tokens: int, # output token count
token_dim: int, # dimensionality of the embedding
bias: bool = False,
):
super().__init__()
self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens, bias=bias) for _ in range(n_levels)])
def forward(self, xi: Tensor, stack: bool = True ) -> Tensor:
dtype = xi.dtype
device = xi.device
xi = [ proj( xi ) for l, proj in enumerate(self.proj) ]
xi = torch.stack( xi )
xi = xi.permute( 1, 0, 2, 3 ) # ( level, batch, token, classification => batch, level, token, classification )
return xi
"""
class Metrics(nn.Module):
def __init__(
self,
@ -448,6 +576,7 @@ class Base(nn.Module):
self.causal = "ar" in self.capabilities or "len" in self.capabilities
self.version = self.config.version if self.config is not None else 5
self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
self.parallel_decoding = self.config.experimental.parallel_decoding if self.config is not None else False
self.arch_type = self.config.arch_type if self.config is not None else "llama"
@ -469,7 +598,7 @@ class Base(nn.Module):
tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
interleave = self.config.experimental.interleave if self.config is not None else False
#interleave = self.config.experimental.interleave if self.config is not None else False
noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
classifiers_bias = self.config.experimental.classifiers_bias if self.config is not None else False
max_position_embeddings = self.config.experimental.max_position_embeddings if self.config is not None else (75 * 60 * 5)
@ -477,40 +606,56 @@ class Base(nn.Module):
masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
layerskip = self.config.experimental.layerskip if self.config is not None else False
layerskip_r = self.config.experimental.layerskip_r if self.config is not None else 2
layerskip_p_max = self.config.experimental.layerskip_p_max if self.config is not None else 0.1
layerskip_e_scale = self.config.experimental.layerskip_e_scale if self.config is not None else 0.1
n_tasks = self.config.tasks if self.config is not None else 8
n_langs = self.config.langs if self.config is not None else 2
n_tones = self.config.tones if self.config is not None else 1
# pure AR
if "nar" not in self.capabilities:
n_resp_tokens = n_audio_tokens + 1
l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels
# NAR-len model
elif "len" in self.capabilities:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
if "ar" in self.capabilities:
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
if self.version < 7:
if "nar" not in self.capabilities:
n_resp_tokens = n_audio_tokens + 1
l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels
# NAR-len model
elif "len" in self.capabilities:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
if "ar" in self.capabilities:
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
else:
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
# AR+NAR model
else:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
# AR+NAR model
else:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
if self.parallel_decoding:
n_resp_tokens = n_audio_tokens + 1
l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )]
l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels
else:
"""
n_resp_tokens = n_audio_tokens + 1
l_embedding_tokens = [n_resp_tokens * self.n_resp_levels]
l_embedding_names = ["NAR"]
l_classifier_tokens = [n_audio_tokens * self.n_resp_levels]
"""
n_resp_tokens = n_audio_tokens + 1
l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
l_classifier_tokens = [n_audio_tokens] * self.n_resp_levels
l_embedding_names = [ f'NAR:{i}:{i}' for i in range( self.n_resp_levels ) ]
n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1
l_classifier_names = l_embedding_names
@ -528,23 +673,13 @@ class Base(nn.Module):
l_classifier_tokens += [ n_raw_text_tokens ]
l_classifier_names = l_embedding_names + [ "raw_text" ]
n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1
self.n_vocab = n_vocab
self.unified_position_ids = unified_position_ids
self.interleave = interleave
self.layerskip = layerskip
self.inject_timestep_embedding = False # results in bad output
self.masking_ratio = masking_ratio
self.ignore_inputs_for_loss = ignore_inputs_for_loss
self.noncausal_masks = noncausal_masks
# use internal attention mechanism for now because I don't have a better way to handle mixed causal/noncausal masks for other attention backends
"""
if noncausal_masks:
attention_backend = "default"
"""
self.text_emb = Embedding(n_text_tokens, d_model)
self.raw_text_emb = None
self.langs_emb = None
@ -572,7 +707,7 @@ class Base(nn.Module):
l_embedding_tokens, d_model,
levels=self.n_resp_levels if self.version > 3 else None,
)
else:
elif not self.parallel_decoding:
self.proms_emb = AudioEmbedding(
[n_audio_tokens] * self.n_resp_levels, d_model,
sums=audio_embedding_sums == "prom" or audio_embedding_sums == True,
@ -582,6 +717,17 @@ class Base(nn.Module):
sums=audio_embedding_sums == "resp" or audio_embedding_sums == True,
l_embedding_names=l_embedding_names,
)
else:
self.proms_emb = AudioEmbedding_Sums(
n_tokens=n_audio_tokens,
n_levels=self.n_resp_levels,
token_dim=d_model,
)
self.resps_emb = AudioEmbedding_Sums(
n_tokens=n_audio_tokens + 1,
n_levels=self.n_resp_levels,
token_dim=d_model,
)
if self.version >= 3:
self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None
@ -597,12 +743,11 @@ class Base(nn.Module):
# this ***might*** let me also unify the proms_emb and resps_embedding
if self.version >= 5:
# "len" RVQ level-0 gets an additional token
self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)
if self.version < 7 or not self.parallel_decoding:
self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)
# experimental NAR-only mode
self.len_emb = Embedding(11, d_model)
self.time_emb = None # TimeEmbedding(d_model) # if not masking_ratio else None
if self.version >= 6:
self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
@ -640,10 +785,8 @@ class Base(nn.Module):
n_levels=self.n_resp_levels,
) for _ in range(n_layers) ])
elif self.arch_type in ["llama", "mistral", "mixtral"]:
LlamaClass = LlamaModel_Adapted # if (self.layerskip or "len" in self.capabilities) else LlamaModel
if n_experts <= 1:
self.model = LlamaClass(LlamaConfig(
self.model = LlamaModel_Adapted(LlamaConfig(
vocab_size=n_vocab,
hidden_size=d_model,
max_position_embeddings=max_position_embeddings,
@ -692,11 +835,6 @@ class Base(nn.Module):
self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )
"""
if self.layerskip:
self.model.layer_dropout_p = layerskip_p_max
self.model.early_exit_scale = layerskip_e_scale
self.model.early_exit_r = layerskip_r
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
@ -766,19 +904,38 @@ class Base(nn.Module):
if not split_classifiers:
self.classifier = nn.Linear(d_model, n_vocab, bias=classifiers_bias)
self.classifiers = None
self.metrics = None
else:
self.classifier = None
self.classifiers = Classifiers( l_classifier_tokens, d_model, l_embedding_names=l_classifier_names, bias=classifiers_bias )
self.classifiers = Classifiers( l_classifier_tokens, l_classifier_names, d_model, bias=classifiers_bias )
self.metrics = Metrics( l_classifier_tokens )
"""
if tie_classifier_to_embedding:
for i, proj in enumerate( self.classifiers.proj ):
self.classifiers.proj[i].weight = self.resps_emb.embeddings[i].weight
"""
self.parallel_decoder = None
if self.parallel_decoding:
pd_model = d_model # // 2
pd_ffn = pd_model * 2
pd_heads = n_heads // 2
pd_layers = 1
config = dict(
vocab_size=n_audio_tokens,
hidden_size=pd_model,
max_position_embeddings=max_position_embeddings,
intermediate_size=pd_ffn,
num_hidden_layers=pd_layers,
num_attention_heads=pd_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=pd_heads,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation="eager",
training=self.training,
attention_backend=attention_backend,
gradient_checkpointing=self.gradient_checkpointing,
)
self.parallel_decoder = ParallelDecoder( self.n_resp_levels, d_model, config )
def _forward(
self,
@ -789,8 +946,6 @@ class Base(nn.Module):
state = None,
layer_skip_lambda = None,
output_attentions = False,
output_hidden_states = False,
):
@ -818,9 +973,6 @@ class Base(nn.Module):
if self.n_experts > 1 and self.training:
kwargs["output_router_logits"] = True
if self.layerskip and layer_skip_lambda is not None:
kwargs["layer_skip_lambda"] = layer_skip_lambda
output = self.model(**kwargs)
x = output["last_hidden_state"]
@ -885,7 +1037,7 @@ class Base(nn.Module):
# but skip the last state, as it already is normalized
hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]
return Logits(x, state, inputs, aux_loss, attentions, hidden_states, None)
return Logits(x, state, inputs, aux_loss, attentions, hidden_states)
# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
def inputs(
@ -943,7 +1095,7 @@ class Base(nn.Module):
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
# insert RVQ level guidance token if the model is versioned for it
if self.rvq_l_emb is not None and not self.interleave:
if self.rvq_l_emb is not None:
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
classifier_level = "AR:0:0" if quant_level == 0 else f'NAR:{quant_level-1}:{quant_level}'
@ -976,7 +1128,10 @@ class Base(nn.Module):
# insert the current output response
if resps_list is not None and resps_list[i] is not None:
inputs[i].append( ( "resp", resps_list[i] ) )
if self.version >= 7:
classifier_level = f"NAR:{quant_level}:{quant_level}" if not self.parallel_decoding else "NAR"
inputs[i].append( ("classifier_level", classifier_level) )
# Audio length prediction task
# Sequence: <text><sep><rvq lvl><prom><sep><len>
@ -1022,7 +1177,7 @@ class Base(nn.Module):
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
# insert RVQ level guidance token if the model is versioned for it
if self.rvq_l_emb is not None and not self.interleave:
if self.rvq_l_emb is not None:
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
# insert the output text prompt
if text_list is not None and text_list[i] is not None:
@ -1038,7 +1193,7 @@ class Base(nn.Module):
# insert lang token if we're trained for it
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
if self.rvq_l_emb is not None and not self.interleave:
if self.rvq_l_emb is not None:
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
# insert the text prompt
if text_list is not None and text_list[i] is not None:
@ -1054,7 +1209,7 @@ class Base(nn.Module):
# insert lang token if we're trained for it
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
if self.rvq_l_emb is not None and not self.interleave:
if self.rvq_l_emb is not None:
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
# insert the text prompt
if raw_text_list is not None and raw_text_list[i] is not None:
@ -1117,12 +1272,23 @@ class Base(nn.Module):
return self.proms_emb(
input if quant_level == 0 else input[:, :quant_level]
)
return self.proms_emb(
input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
offset = 0,
)
if self.version < 7 or not self.parallel_decoding:
return self.proms_emb(
input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
offset = 0,
)
"""
if not self.parallel_decoding:
return self.proms_emb(
input if input.dim() == 1 else input[:, :quant_level+1],
quant_level = quant_level,
offset = 0,
)
"""
return self.proms_emb( input )
# yuck
token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0
@ -1188,28 +1354,23 @@ class Base(nn.Module):
elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input )
elif name == "resp":
if self.interleave:
embeddings = [ self.resps_emb(
input[:, :l+1],
#offset = 0,
#quant_level = l,
name = 'AR:0:0' if l == 0 else f'NAR:{l-1}:{l}',
) for l in range( input.shape[-1] ) ]
embedding = _interleave_sequence_reshape( embeddings )
if self.parallel_decoding:
if dropout_mask is not None:
embedding = self.resps_emb( torch.where( dropout_mask, self.stop_token, input.t() ).t() )
else:
embedding = self.resps_emb( input )
# if training NAR-len RVQ level 0
elif dropout_mask is not None:
embedding = self.resps_emb(
# if masked use masked token, else original token
torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, quant_level] ),
#quant_level = 0,
name = classifier_level,
)
# NAR-len
elif classifier_level == "NAR:0:0":
elif classifier_level == f"NAR:{quant_level}:{quant_level}":
embedding = self.resps_emb(
input if input.dim() == 1 else input[:, 0],
input if input.dim() == 1 else input[:, quant_level],
#quant_level = 0,
name = classifier_level,
)
@ -1323,10 +1484,6 @@ class Base(nn.Module):
if not isinstance(input, torch.Tensor):
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] )
# interleaved model
if self.interleave and name == "resp":
return input.shape[0] * input.shape[1]
# ending input will not have a separator later
return input.shape[0]
@ -1352,6 +1509,166 @@ class Base(nn.Module):
return ids.to(device=device, dtype=torch.int32)
def calc_loss_parallel(
self,
inputs: list,
logits,
compute_hard_loss = True,
compute_acc = True,
):
loss = {}
stats = {}
device = logits[0].device
batch_size = len(logits)
classifier_levels = self.get_input( inputs, "classifier_level" )
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input ):
if isinstance(input, str):
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
return input
for batch_index, batch in enumerate(inputs):
target = []
causal = False
task_type = "tts"
dropout_mask = None
classifier_level = None
output_len = 0
for name, input in batch:
if name == "task":
task_type = input
elif name == "dropout_mask":
dropout_mask = input
elif name == "classifier_level":
classifier_level = input
it = 0
for name, input in batch:
token = None
ignored = False
# non-tokened tasks
if name in non_tokened_names:
continue
# prom can either be a tensor itself or a list of tensors and strings
if name == "prom":
# expand to list if not a list
proms = [ input ] if isinstance(input, torch.Tensor) else input
# iterate over the list to inject their tokens
token = torch.cat( [ prompt_input_to_token( input ) for input in proms if input is not None ] )
elif name == "resp":
# mask found, apply it
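# (only positions that were masked in the input keep their real token as the loss target; the rest become ignore_index and are skipped by the loss)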
if dropout_mask is not None:
token = torch.where( dropout_mask, input.t(), self.ignore_index ).t()
else:
token = input
# not a special input, inject as-is
else:
token = input
if not isinstance(token, torch.Tensor):
continue
if token.is_floating_point():
ignored = True
# grab range of our logits for later
seq_len = token.shape[0]
start, end = it, it+seq_len
it += seq_len + 1 # +1 to incorporate the separator
# deduce if a name for a task is an input or output
if name != task_outputs.get(task_type, name):
if self.ignore_inputs_for_loss:
ignored = True
else:
output_len = seq_len
if ignored:
# pruned
if self.config.loss_factors:
continue
# fill with an ignore-index tensor
token = torch.tensor( [ self.ignore_index ] * token.shape[0], device=device, dtype=torch.int16)
# perform loss calculation on the individual piece
target.append( token )
if classifier_level != "NAR":
seq = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
logit = logits[batch_index]
# shift if causal
if causal:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
seq = seq[..., l:] # ...predicts token n + 1
if compute_hard_loss:
nll = F.cross_entropy( logit, seq, ignore_index=self.ignore_index )
if 'nll' not in loss:
loss['nll'] = []
loss["nll"].append( nll )
if compute_acc and False:
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ "NAR:0" if classifier_level == "NAR" else classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 10,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
metrics = accuracy_metric( logit, seq )
if 'acc' not in stats:
stats['acc'] = []
stats["acc"].append( metrics )
else:
for level, logit in enumerate( logits[batch_index] ):
seq = _join( [ t if t.dim() <= 1 else t[:, level] for t in target ], torch.tensor(self.ignore_index, device=target[-1].device) )
# shift if causal
if causal:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
seq = seq[..., l:] # ...predicts token n + 1
if compute_hard_loss:
nll = F.cross_entropy( logit, seq, ignore_index=self.ignore_index )
if 'nll' not in loss:
loss['nll'] = []
loss["nll"].append( nll )
if compute_acc and False:
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ "NAR:0" if classifier_level == "NAR" else classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 10,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
metrics = accuracy_metric( logit, seq )
if 'acc' not in stats:
stats['acc'] = []
stats["acc"].append( metrics )
# average
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
return LossStats(loss, stats)
def calc_loss(
self,
inputs: list,
@ -1377,7 +1694,13 @@ class Base(nn.Module):
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
return torch.full_like(input[..., 0], self.ignore_index)
return input if input.dim() == 1 else input[:, quant_level]
if self.version < 7:
return input if input.dim() == 1 else input[:, quant_level]
if not self.parallel_decoding:
return input if input.dim() == 1 else input[:, quant_level]
return input
for batch_index, batch in enumerate(inputs):
quant_level = quant_levels[batch_index]
@ -1402,6 +1725,8 @@ class Base(nn.Module):
# nonautoregressive, parallel
elif classifier_level.startswith("NAR:"):
causal = False
elif classifier_level == "NAR":
causal = False
it = 0
for name, input in batch:
@ -1422,9 +1747,6 @@ class Base(nn.Module):
if dropout_mask is not None:
# if mask use original token, else ignore
token = torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index )
# flatten
elif self.interleave:
token = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
# use resps as-is
else:
token = input if input.dim() == 1 else input[:, quant_level]
@ -1435,28 +1757,6 @@ class Base(nn.Module):
if not isinstance(token, torch.Tensor):
continue
# offset to flattened vocab ranges
"""
if self.classifier is not None:
offsets = _get_offsets()
k = name
if name == "stt":
k = "text"
if name == "prom":
k = f'prom|{quant_level}'
elif name == "resp":
k = f'resps|{classifier_level}'
if k in offsets:
start, end = offsets[k]
for i, t in enumerate( token ):
if t == self.ignore_index:
continue
token[i] += start
"""
if token.is_floating_point():
ignored = True
@ -1566,51 +1866,9 @@ class Base(nn.Module):
quant_levels: list[int] | None = None,
state: dict | list | None = None,
layer_skip_variables: dict | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
):
# return early if it's "good" enough
# lambda because we need to capture the classifier_levels and mask
exited_layer = self.n_layers
def layer_skip_lambda( layer, logits ):
nonlocal exited_layer
kwargs = {
"entropy_threshold": 0.05,
"varentropy_threshold": 0.05,
"min_layer": self.n_layers // 2,
"max_layer": self.n_layers,
}
kwargs.update( layer_skip_variables )
# don't bother on early layers
if layer < kwargs["min_layer"]:
return False
# bail if we want to force early layers
if kwargs["max_layer"] < layer:
return True
# hidden states aren't normalized
x = self.model.norm( logits )
# output projection layer with masking
if self.classifier is not None:
x = self.classifier(x) # * m
elif self.classifiers is not None:
logits = self.classifiers(logits, levels = classifier_levels) # * m
# calculate metrics
metrics = calculate_entropix_metrics( logits )
# exit early if "good enough"
early = metrics["logits_entropy"] <= kwargs["entropy_threshold"] and metrics["logits_varentropy"] <= kwargs["varentropy_threshold"]
if early:
exited_layer = layer
return early
):
# derive quant levels from inputs if not provided
if quant_levels is None:
quant_levels = [ x.item() for x in self.get_input( inputs, "quant_level" ) ]
@ -1628,10 +1886,6 @@ class Base(nn.Module):
device = x.device
batch_size = len(x_list)
# we only need hidden states if we're training with layerskip
if self.layerskip and training:
output_hidden_states = True
# pad our input and mask, but retain the original length by doing it after
if self.l_padding and x.shape[1] % self.l_padding != 0:
# pad input
@ -1663,55 +1917,55 @@ class Base(nn.Module):
is_causal=is_causal,
position_ids=position_ids,
output_attentions = output_attentions,
output_hidden_states = output_hidden_states,
layer_skip_lambda = layer_skip_lambda if self.layerskip and layer_skip_variables else None,
)
logits = output.logits
hidden_states = output.hidden_states
logits = [ logit for logit in logits ]
if self.version >= 7 and self.parallel_decoding:
p_indices = [ batch_index for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ]
if p_indices:
p_logits = torch.stack([ logits[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0)
p_mask = torch.stack([ mask[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0)
p_ids = torch.stack([ position_ids[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0)
p_causal = [ is_causal[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ]
p_logits = self.parallel_decoder( p_logits, attention_mask=p_mask, position_ids=p_ids, use_cache=False, return_dict=True, is_causal=p_causal )
for i, logit in enumerate(p_logits):
logits[p_indices[i]] = logit
"""
logits = [ self.parallel_decoder( logit.unsqueeze(0), attention_mask=mask,
position_ids=position_ids,
use_cache=False,
return_dict=True,
is_causal=is_causal )[0] if level == "NAR" else logit for logit, level in zip(logits, classifier_levels) ]
"""
# output projection layer
# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
if self.classifier is not None:
logits = self.classifier(logits) # * m
if output.hidden_states:
for i, state in enumerate( hidden_states ):
hidden_states[i] = self.classifier(hidden_states[i]) # * m
# to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead......
elif self.classifiers is not None:
logits = self.classifiers(logits, levels = classifier_levels) # * m
logits = self.classifiers(logits, levels = classifier_levels )
if hidden_states is not None:
for i, state in enumerate( hidden_states ):
hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_levels) # * m
# Reshape
"""
if self.version >= 7 and not self.parallel_decoding:
for batch_index, logit in enumerate( logits ):
if classifier_levels[batch_index] != "NAR":
continue
logits[batch_index] = logit.reshape( logit.shape[0], 8, 1000 ).permute( 1, 0, 2 )
"""
# Remove padding
logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ]
if hidden_states is not None:
for i, state in enumerate( hidden_states ):
hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]
# de-offset if needed
if self.classifier is not None:
offsets = _get_offsets()
for batch_index, classifier_level in enumerate( classifier_levels ):
if classifier_level == "stt":
k = "text"
elif classifier_level == "len":
k = "len"
else:
k = f'resps|{classifier_level}'
if k not in offsets:
continue
start, end = offsets[k]
logits[batch_index] = logits[batch_index][:, start:end]
if not training:
loss = None
stats = None
@ -1719,30 +1973,18 @@ class Base(nn.Module):
self.loss = None
self.stats = None
# compute loss if the target is given
elif self.version >= 7 and self.parallel_decoding:
loss, stats = self.calc_loss_parallel( inputs=inputs, logits=logits )
# include any additional losses (for example: MoE router)
if output.loss is not None:
loss["aux_loss"] = output.loss
self.loss = loss
self.stats = stats
else:
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
# compute it as an aux-loss
if self.layerskip:
early_exit_loss = {}
if not hasattr( self, "training_steps" ):
self.training_steps = 0
for i, state in enumerate( hidden_states ):
loss, stats = self.calc_loss( inputs=inputs, logits=hidden_states[i], quant_levels=quant_levels )
for k, v in loss.items():
K = f'early_exit.{k}'
if K not in early_exit_loss:
early_exit_loss[K] = []
early_exit_loss[K].append( v )
for k, v in early_exit_loss.items():
loss[k] = self.model.early_exit_loss( losses=v, t=self.training_steps )
# to-do: instead make the curriculum rely on samples processed instead of steps
self.training_steps += 1 # batch_size
# include any additional losses (for example: MoE router)
if output.loss is not None:
loss["aux_loss"] = output.loss
@ -1751,7 +1993,7 @@ class Base(nn.Module):
self.stats = stats
# rewrap, because we're modifying the logits here
return Logits(logits, output.state, inputs, loss, output.attentions, hidden_states, exited_layer)
return Logits(logits, output.state, inputs, loss, output.attentions, hidden_states)
def sample(
self,

View File

@ -176,17 +176,17 @@ def top_no_logits_processing( logits, n = 1.0 ):
# (and because the null logits have a shorter input sequence compared to the positive logits)
def cfg_logits( logits, null, strength, lens, rescale=0.0 ):
for i, seq_len in enumerate( lens ):
pos = logits[i][-seq_len:]
neg = null[i][-seq_len:]
pos = logits[i][..., -seq_len:, :]
neg = null[i][..., -seq_len:, :]
summed = neg + (pos - neg) * strength
if rescale <= 0:
logits[i][-seq_len:] = summed
logits[i][..., -seq_len:, :] = summed
else:
dims = tuple(range(1, summed.ndim - 1))
factor = rescale * (pos.std(dims, keepdim=True) / summed.std(dims, keepdim=True)) + (1 - rescale)
logits[i][-seq_len:] = summed * factor
logits[i][..., -seq_len:, :] = summed * factor
return logits
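# (for reference: the interpolation above is standard classifier-free guidance, guided = null + strength * (cond - null);
# the rescale factor blends toward the ratio of the conditional logits' std over the guided logits' std,
# pulling the guided logits back toward the conditional scale so large strengths don't blow up their magnitude)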