commit 04fef5dad5
parent 1c0ed6abac

    agony
@@ -90,9 +90,11 @@ However, because this codec relies on FSQ (Finite Scalar Quantization) rather th
Proposed architectures may include:
* independent NAR-demasking for *all* levels, rather than just FSQ level 0.
  * little additional code is required, as the existing NAR-demasking training/inference code can be repurposed for the additional levels.
  * this also has the best backwards compatibility with vall_e.cpp, as no extra model code is required.
* parallel decoding of *all* levels in one pass, rather than separate passes for each level (see the sketch below).
  * some extra code is required to orchestrate the additional decoding heads in parallel.
  * each decoding head may simply be a single `nn.Linear` classifier, or additional transformer blocks.
    * the former yields bad results even when overfitting, while the latter (without an output projection head) at least allows the model to overfit.
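A minimal sketch of the parallel-decoding idea above (an illustration, not the repository's implementation): every level is predicted in a single pass from the shared trunk's hidden state, with the per-level head being either a plain `nn.Linear` or a small transformer block — the trade-off noted in the list. All names and shapes here are assumptions for illustration only.

```python
import torch
import torch.nn as nn

class ParallelLevelHeads(nn.Module):
    """Hypothetical: one decoding head per codebook level, all fed the same trunk output."""
    def __init__(self, d_model: int, n_tokens: int, n_levels: int, use_blocks: bool = False):
        super().__init__()
        # optional small per-level transformer block before the projection
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, nhead=8, dim_feedforward=d_model * 2, batch_first=True)
            for _ in range(n_levels)
        ]) if use_blocks else None
        self.heads = nn.ModuleList([nn.Linear(d_model, n_tokens) for _ in range(n_levels)])

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        # hidden: (batch, seq_len, d_model) from the shared trunk
        outs = []
        for level, head in enumerate(self.heads):
            x = hidden if self.blocks is None else self.blocks[level](hidden)
            outs.append(head(x))            # (batch, seq_len, n_tokens)
        return torch.stack(outs, dim=1)     # (batch, n_levels, seq_len, n_tokens)
```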

## `transcribe.py`

@@ -18,55 +18,32 @@ from vall_e.config import cfg
from vall_e.emb.g2p import encode as phonemize
from vall_e.emb.qnt import encode as quantize, _replace_file_extension, convert_audio

def pad(num, zeroes):
    return str(num).zfill(zeroes+1)

def process_items( items, stride=0, stride_offset=0 ):
    items = sorted( items )
    return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]

def load_audio( path, device="cuda" ):
    waveform, sample_rate = torchaudio.load(path)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    waveform = convert_audio(waveform, sample_rate, cfg.sample_rate, 1)
    return waveform.to(device=device), cfg.sample_rate
from vall_e.emb.process import pad, load_audio, process_items, process_jobs

def process(
    audio_backend="encodec",
    input_audio="Emilia",
    output_dataset="training",
    raise_exceptions=False,
    stride=0,
    stride_offset=0,
    slice="auto",

    device="cuda",
    dtype="float16",
    amp=False,
):
    # encodec / vocos
    if audio_backend in ["encodec", "vocos"]:
        audio_extension = ".enc"
        cfg.sample_rate = 24_000
        cfg.model.resp_levels = 8
    elif audio_backend == "dac":
        audio_extension = ".dac"
        cfg.sample_rate = 44_100
        cfg.model.resp_levels = 9
    elif cfg.audio_backend == "audiodec":
        sample_rate = 48_000
        audio_extension = ".dec"
        cfg.model.resp_levels = 8 # ?
    else:
        raise Exception(f"Unknown audio backend: {audio_backend}")
    audio_backend="encodec",
    input_audio="Emilia",
    output_dataset="training",
    raise_exceptions=False,
    stride=0,
    stride_offset=0,
    slice="auto",
    batch_size=1,
    low_memory=False,

    device="cuda",
    dtype="float16",
    amp=False,
):
    # prepare from args
    cfg.audio_backend = audio_backend # "encodec"
    cfg.device = device
    cfg.set_audio_backend(audio_backend)
    audio_extension = cfg.audio_backend_extension

    cfg.inference.weight_dtype = dtype # "bfloat16"
    cfg.inference.amp = amp # False

    dtype = cfg.inference.dtype if not amp else None

    output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
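For reference (worked out here, not stated in the diff), the nested conditionals in the line above expand the dataset folder name like so:

```python
# cfg.sample_rate == 24_000, cfg.audio_backend == "encodec"  -> "training/24KHz-encodec"
# cfg.sample_rate == 44_100, cfg.audio_backend == "dac"      -> "training/44KHz-dac"
# cfg.sample_rate == 48_000, cfg.audio_backend == "audiodec" -> "training/48KHz-audiodec"
```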
@@ -145,58 +122,11 @@ def process(
            if waveform is None:
                waveform, sample_rate = load_audio(inpath)

            wavs.append((
                outpath,
                text,
                language,
                waveform,
                sample_rate
            ))
            jobs.append(( outpath, waveform, sample_rate, text, language.lower() ))

        if len(wavs) > 0:
            for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
                try:
                    outpath, text, language, waveform, sample_rate = job

                    phones = phonemize(text, language=f'{language}'.lower())
                    qnt = quantize(waveform, sr=sample_rate, device=device)

                    if cfg.audio_backend == "dac":
                        np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
                            "codes": qnt.codes.cpu().numpy().astype(np.uint16),
                            "metadata": {
                                "original_length": qnt.original_length,
                                "sample_rate": qnt.sample_rate,

                                "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
                                "chunk_length": qnt.chunk_length,
                                "channels": qnt.channels,
                                "padding": qnt.padding,
                                "dac_version": "1.0.0",

                                "text": text.strip(),
                                "phonemes": "".join(phones),
                                "language": language,
                            },
                        })
                    else:
                        np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
                            "codes": qnt.cpu().numpy().astype(np.uint16),
                            "metadata": {
                                "original_length": waveform.shape[-1],
                                "sample_rate": sample_rate,

                                "text": text.strip(),
                                "phonemes": "".join(phones),
                                "language": language,
                            },
                        })
                except Exception as e:
                    print(f"Failed to quantize: {outpath}:", e)
                    if raise_exceptions:
                        raise e
                    continue
        # processes audio files one at a time
        process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
        jobs = []

    open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
    open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset))
@@ -214,6 +144,8 @@ def main():
    parser.add_argument("--stride", type=int, default=0)
    parser.add_argument("--stride-offset", type=int, default=0)
    parser.add_argument("--slice", type=str, default="auto")
    parser.add_argument("--low-memory", action="store_true")
    parser.add_argument("--batch-size", type=int, default=0)

    args = parser.parse_args()

@@ -232,6 +164,8 @@ def main():
        stride=args.stride,
        stride_offset=args.stride_offset,
        slice=args.slice,
        batch_size=args.batch_size,
        low_memory=args.low_memory,

        device=args.device,
        dtype=args.dtype,
@@ -15,52 +15,36 @@ from pathlib import Path

from vall_e.config import cfg

def pad(num, zeroes):
    return str(num).zfill(zeroes+1)
from vall_e.emb.g2p import encode as phonemize
from vall_e.emb.qnt import encode as quantize, _replace_file_extension, convert_audio

from vall_e.emb.process import pad, load_audio, process_items, process_jobs

def process_items( items, stride=0, stride_offset=0 ):
    items = sorted( items )
    return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]

def process(
    audio_backend="encodec",
    input_audio="LibriTTS_R",
    output_dataset="training",
    raise_exceptions=False,
    stride=0,
    stride_offset=0,
    slice="auto",

    device="cuda",
    dtype="float16",
    amp=False,
):
    # encodec / vocos
    if audio_backend in ["encodec", "vocos"]:
        audio_extension = ".enc"
        cfg.sample_rate = 24_000
        cfg.model.resp_levels = 8
    elif audio_backend == "dac":
        audio_extension = ".dac"
        cfg.sample_rate = 44_100
        cfg.model.resp_levels = 9
    elif cfg.audio_backend == "audiodec":
        sample_rate = 48_000
        audio_extension = ".dec"
        cfg.model.resp_levels = 8 # ?
    else:
        raise Exception(f"Unknown audio backend: {audio_backend}")
    audio_backend="encodec",
    input_audio="LibriTTS_R",
    output_dataset="training",
    raise_exceptions=False,
    stride=0,
    stride_offset=0,
    slice="auto",
    batch_size=1,
    low_memory=False,

    device="cuda",
    dtype="float16",
    amp=False,
):
    # prepare from args
    cfg.audio_backend = audio_backend # "encodec"
    cfg.device = device
    cfg.set_audio_backend(audio_backend)
    audio_extension = cfg.audio_backend_extension

    cfg.inference.weight_dtype = dtype # "bfloat16"
    cfg.inference.amp = amp # False

    # import after because we've overriden the config above
    # need to validate if this is even necessary anymore
    from vall_e.emb.g2p import encode as phonemize
    from vall_e.emb.qnt import encode as quantize, _replace_file_extension
    dtype = cfg.inference.dtype if not amp else None

    output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
@@ -140,62 +124,19 @@ def process(
                    continue

                if waveform is None:
                    waveform, sample_rate = torchaudio.load(inpath)
                    if waveform.shape[0] > 1:
                        waveform = torch.mean(waveform, dim=0, keepdim=True)
                    waveform, sample_rate = load_audio(inpath)

                wavs.append((
                    outpath,
                    text,
                    language,
                    waveform,
                    sample_rate
                ))
                jobs.append(( outpath, waveform, sample_rate, text, language ))

            if len(wavs) > 0:
                for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
                    try:
                        outpath, text, language, waveform, sample_rate = job

                        phones = phonemize(text, language=language)
                        qnt = quantize(waveform, sr=sample_rate, device=device)

                        if cfg.audio_backend == "dac":
                            np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
                                "codes": qnt.codes.cpu().numpy().astype(np.uint16),
                                "metadata": {
                                    "original_length": qnt.original_length,
                                    "sample_rate": qnt.sample_rate,

                                    "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
                                    "chunk_length": qnt.chunk_length,
                                    "channels": qnt.channels,
                                    "padding": qnt.padding,
                                    "dac_version": "1.0.0",

                                    "text": text.strip(),
                                    "phonemes": "".join(phones),
                                    "language": language,
                                },
                            })
                        else:
                            np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
                                "codes": qnt.cpu().numpy().astype(np.uint16),
                                "metadata": {
                                    "original_length": waveform.shape[-1],
                                    "sample_rate": sample_rate,

                                    "text": text.strip(),
                                    "phonemes": "".join(phones),
                                    "language": language,
                                },
                            })
                    except Exception as e:
                        print(f"Failed to quantize: {outpath}:", e)
                        if raise_exceptions:
                            raise e
                        continue
            # processes audio files one at a time
            if low_memory:
                process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
                jobs = []

        # processes all audio files for a given speaker
        if not low_memory:
            process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
            jobs = []

    open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
    open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset))
@@ -213,6 +154,8 @@ def main():
    parser.add_argument("--stride", type=int, default=0)
    parser.add_argument("--stride-offset", type=int, default=0)
    parser.add_argument("--slice", type=str, default="auto")
    parser.add_argument("--low-memory", action="store_true")
    parser.add_argument("--batch-size", type=int, default=0)

    args = parser.parse_args()

@@ -231,6 +174,8 @@ def main():
        stride=args.stride,
        stride_offset=args.stride_offset,
        slice=args.slice,
        batch_size=args.batch_size,
        low_memory=args.low_memory,

        device=args.device,
        dtype=args.dtype,
@@ -259,6 +259,9 @@ class ModelExperimentalSettings:
    # it just seems like a bitch to try and train something worthwhile with it, since there's crackles every other token
    # RetNet's chunked inferencing might be a better place for this

    parallel_decoding: bool = False # enables some settings to decode ALL RVQ levels in one pass
    # this is a bit of a pain to get working in the test trainer

    masking_train_p: float = 0.0 # odds of training with masking
    masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
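As a usage note (a hypothetical illustration — the exact config plumbing is not shown in this diff), the new fields would be toggled on the model's experimental settings, roughly like:

```python
# hypothetical values, mirroring the dataclass fields added above
experimental = ModelExperimentalSettings(
    parallel_decoding=True,           # decode ALL RVQ/FSQ levels in one pass
    masking_train_p=1.0,              # always train with the masking objective
    masking_train_rvq_levels=[0, 8],  # apply mask training across all levels
)
```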
@@ -75,6 +75,11 @@ class AR_NAR(Base):
        token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
        # RVQ levels to apply masking training on
        masking_train_rvq_levels = self.config.experimental.masking_train_rvq_levels

        if self.version >= 7:
            masking_train_rvq_levels = [0,self.n_resp_levels]
            rvq_levels_p = [ i for i in range( quant_level_range[0], quant_level_range[1] + 1 ) ]

        # CFG
        cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
        cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
@@ -127,7 +132,10 @@ class AR_NAR(Base):
                timesteps[i] = (timesteps[i] * 0.6) + 0.2

        # trim resps to only contain all levels below the target level
        resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
        if self.version < 7:
            resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
        elif not self.parallel_decoding:
            resps_list = [r if t in text_task else r[..., l] for r, l, t in zip(resps_list, quant_levels, task_list)]
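A quick shape check of the two trimming behaviours above (illustrative values, not from the source): with responses shaped (seq_len, n_levels) and target quant level `l`, the pre-version-7 path keeps every level up to and including `l`, while the per-level demasking path keeps only level `l` itself.

```python
import torch
r = torch.zeros(4, 8)      # toy response codes: (seq_len=4, n_levels=8)
l = 2
print(r[..., :l+1].shape)  # torch.Size([4, 3]) -> levels 0..l (version < 7)
print(r[..., l].shape)     # torch.Size([4])    -> level l only (version >= 7, sequential demasking)
```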

        # tensor to cat for RVQ level 0
        text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
@@ -229,6 +237,7 @@ class AR_NAR(Base):
        tone_list: list[Tensor] | None = None,
        len_list: list[Tensor] | None = None,
        raw_text_list: list[Tensor] | None = None,
        quant_levels: list[int] | None = None,

        disable_tqdm=False,
        use_lora=None,

@@ -237,7 +246,11 @@ class AR_NAR(Base):
        device = text_list[0].device
        batch_size = len(text_list)

        level = 0
        if quant_levels is None:
            level = 0
        else:
            level = quant_levels[0] # ugh

        if cfg.lora is not None:
            enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )

@@ -306,11 +319,12 @@ class AR_NAR(Base):
        # fill scores
        scores = [ torch.ones((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]

        quant_levels = [ level for _ in range(batch_size) ]
        if quant_levels is None:
            quant_levels = [ level for _ in range(batch_size) ]
        null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
        null_prom = [ None for _ in range(batch_size) ]

        iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm)
        iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc=f"NAR Masked Level {level}", disable=disable_tqdm)
        for timestep in iterator:
            # update previous list of tokens
            prev_list = resps_list
@@ -430,6 +444,183 @@ class AR_NAR(Base):

        return resps_list

    # handles doing demasking inferencing in parallel to inference all tokens
    # it works if the underlying model is trained properly (which is a pain)
    def forward_nar_masked_parallel(
        self,

        task_list: list[Tensor] | None = None,

        text_list: list[Tensor] | None = None,
        proms_list: list[Tensor] | None = None,
        resps_list: list[Tensor] | None = None,

        lang_list: list[Tensor] | None = None,
        tone_list: list[Tensor] | None = None,
        len_list: list[Tensor] | None = None,
        raw_text_list: list[Tensor] | None = None,

        disable_tqdm=False,
        use_lora=None,
        **sampling_kwargs,
    ):
        device = text_list[0].device
        batch_size = len(text_list)

        level = 0
        if cfg.lora is not None:
            enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )

        # convert (N)AR specific args
        sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )

        min_length = sampling_kwargs.pop("min_duration", 1)
        max_length = sampling_kwargs.pop("max_duration", 500)
        max_steps = sampling_kwargs.get("max_steps", 25)
        refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
        entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
        annealed_sampling = sampling_kwargs.get("annealed_sampling", True)

        # greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
        temperature = sampling_kwargs.pop("temperature", 0.0)
        minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 2.5)
        # this really helps keep audio coherent so far
        cfg_strength = sampling_kwargs.get("cfg_strength", minimum_cfg_strength)
        cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
        start_noise = sampling_kwargs.get("denoise_start", 0.0)
        end_noise = sampling_kwargs.get("denoise_end", 1.0)
        remasking = sampling_kwargs.get("remasking", True)
        max_steps = math.floor(max_steps * (end_noise - start_noise))

        # to specify the initial mask used
        vc_list = sampling_kwargs.pop("vc_list", None)
        vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25)
        vc_mask_p = sampling_kwargs.pop("vc_mask_p", 0.25)

        len_list = [ clamp(l, min_length, max_length) for l in len_list ]

        # force set CFG because too low / no CFG causes issues
        original_cfg_strength = cfg_strength
        cfg_strength = max( cfg_strength, minimum_cfg_strength )

        prefix_context = sampling_kwargs.get("prefix_context", None)
        # fill with masked tokens (even though they get masked anyways)
        resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
        # fill scores
        scores = [ torch.ones((seq_len), dtype=torch.float32, device=device) for seq_len in len_list ]

        quant_levels = [ level for _ in range(batch_size) ]
        null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
        null_prom = [ None for _ in range(batch_size) ]

        iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm)
        for timestep in iterator:
            # update previous list of tokens
            prev_list = resps_list
            # ramp down over time
            annealing = 1.0 - timestep
            # get noise level, per cosine scheduling
            noise_p = math.cos( timestep * math.pi * 0.5 )
            # proportion of tokens to remask
            remask_p = 1.0 / (max_steps * 2) if remasking else 0
            # pick the worst scoring tokens to mask off
            masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
            # normal masking
            # mask off inputs
            resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.stop_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
            # boolean mask
            is_masked = [ resps == self.stop_token for resps in resps_list ]
            # timestep inputs
            time_list = [ timestep for _ in range(batch_size) ]

            sampling_temperature = temperature * annealing if annealed_sampling else temperature
            sampling_cfg = cfg_strength * timestep if annealed_sampling else cfg_strength

            input_resps_list = resps_list

            # setup inputs
            inputs = super().inputs(
                text_list=text_list,
                proms_list=proms_list,
                resps_list=input_resps_list,
                lang_list=lang_list,
                tone_list=tone_list,
                time_list=time_list,
                quant_levels=quant_levels,
            )
            output = super().forward(
                inputs=inputs,
                quant_levels=quant_levels,
            )

            logits = output.logits
            if cfg_strength > 0:
                null_inputs = super().inputs(
                    text_list=null_text,
                    proms_list=null_prom,
                    resps_list=input_resps_list,
                    lang_list=lang_list,
                    tone_list=tone_list,
                    time_list=time_list,
                    quant_levels=quant_levels,
                )
                null_output = super().forward(
                    inputs=null_inputs,
                    quant_levels=quant_levels,
                )

                logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ l for l in len_list ] )

            l_scores = []
            l_resps_list = []
            # cringe hack because we're able to sample multiple levels at once
            for l in range(self.n_resp_levels):
                # sample with sampler settings
                filtered_sampled = super().sample(
                    logits=[ logit[l] for logit in logits ],
                    prev_list=[ resp[..., l] for resp in prev_list ],
                    quant_levels=quant_levels,

                    temperature=sampling_temperature,
                    **sampling_kwargs,
                )

                # retrieves unfiltered logits
                unfiltered_sampled = super().sample(
                    logits=[ logit[l] for logit in logits ],
                    prev_list=[ resp[..., l] for resp in prev_list ],
                    quant_levels=quant_levels,

                    temperature=0.0,
                    **sampling_kwargs,
                )

                # get sampled tokens
                sampled_ids = filtered_sampled.ids
                # keep unmasked tokens
                l_resps_list.append([ torch.where( masked[..., l], input_ids, resps[..., l] ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ])
                # get probability scores
                l_scores.append([
                    # conjugate to have worse scoring tokens picked for topk
                    1.0 -
                    # only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
                    torch.where( masked[..., l], torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked[..., l].shape, device=device) )
                    # use unmodified logit scores for this, as it offers better stability
                    for scores, masked in zip( unfiltered_sampled.scores, is_masked )
                ])

            resps_list = []
            scores = []

            for batch_index in range(batch_size):
                score = sum([ l_scores[level][batch_index] for level in range(self.n_resp_levels) ]) / self.n_resp_levels
                resp = torch.stack([ l_resps_list[level][batch_index] for level in range(self.n_resp_levels) ], dim=-1)

                scores.append( score )
                resps_list.append( resp )

        return resps_list
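For intuition about the demasking loop above (a worked check, not taken from the source): the cosine schedule `noise_p = cos(t * pi / 2)` remasks nearly every position at the first step and tapers towards none as the timestep approaches 1.

```python
import math
for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"t={t:.2f}  masked fraction ≈ {math.cos(t * math.pi * 0.5):.3f}")
# t=0.00 -> 1.000, t=0.25 -> 0.924, t=0.50 -> 0.707, t=0.75 -> 0.383, t=1.00 -> 0.000
```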

    def forward_nar(
        self,
        task_list: list[Tensor] | None = None,
@@ -911,6 +1102,56 @@ class AR_NAR(Base):

        # is NAR
        if (len_list is not None or resps_list is not None) and text_list is not None:
            if self.version >= 7:
                if self.parallel_decoding:
                    return self.forward_nar_masked_parallel(
                        task_list=task_list,

                        text_list=text_list,
                        proms_list=proms_list,
                        resps_list=resps_list,

                        lang_list=lang_list,
                        tone_list=tone_list,
                        len_list=len_list,
                        raw_text_list=raw_text_list,

                        disable_tqdm=disable_tqdm,
                        use_lora=use_lora,
                        **sampling_kwargs,
                    )
                else:
                    resps_lists = [ None for _ in range(batch_size) ]
                    for level in range(self.n_resp_levels):
                        resp_list = self.forward_nar_masked(
                            task_list=task_list,

                            text_list=text_list,
                            proms_list=proms_list,
                            resps_list=resps_list,

                            lang_list=lang_list,
                            tone_list=tone_list,
                            len_list=len_list,
                            raw_text_list=raw_text_list,

                            disable_tqdm=disable_tqdm,
                            use_lora=use_lora,
                            quant_levels=[ level for _ in range(batch_size) ],
                            **sampling_kwargs,
                        )

                        for batch_index, resp in enumerate(resp_list):
                            if resps_lists[batch_index] is None:
                                resps_lists[batch_index] = []

                            resps_lists[batch_index].append( resp )

                    for batch_index, resps in enumerate(resps_lists):
                        resps_lists[batch_index] = torch.stack( resps, dim=-1 )

                    return resps_lists

            return self.forward_nar(
                task_list=task_list,
@@ -988,8 +1229,8 @@ def example_usage():
    resps_list = [ audio[:int(cfg.dataset.frames_per_second * 4), :] ] * batch_size

    kwargs = {
        'n_text_tokens': 256,
        'n_audio_tokens': 1024,
        'n_text_tokens': cfg.model.text_tokens,
        'n_audio_tokens': cfg.model.audio_tokens,

        'd_model': 1024, # 256, # 1024, # 1536
        'n_heads': 16, # 4, # 16, # 24

@@ -1004,7 +1245,9 @@ def example_usage():
    }

    bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
    available_tasks = [] + (["tts-ar"] if "ar" in cfg.model.capabilities else []) + (["tts-nar"] if "len" in cfg.model.capabilities else [])

    #available_tasks = [] + (["tts-ar"] if "ar" in cfg.model.capabilities else []) + (["tts-nar"] if "len" in cfg.model.capabilities else [])
    available_tasks = ["tts-nar"]

    model = AR_NAR(**kwargs).to(cfg.device)
    steps = 500 // batch_size

@@ -1156,13 +1399,14 @@ def example_usage():

        if task == "tts-nar":
            len_list = engine( text_list=text_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
            len_list = [ resp_list[0].shape[0] for l in len_list ]
            len_list = [ r.shape[0] for r in resp_list ]
            resps_list = engine( text_list=text_list, proms_list=proms_list, len_list=len_list )
        else:
            resps_list = engine( text_list=text_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
            resps_list = engine( text_list=text_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 )

        for i, o in enumerate(resps_list):
            print( o.shape, o )
            _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.{task}.wav", device=cfg.device)

        unload_model()

@@ -1185,7 +1429,9 @@ def example_usage():
    }, f"./data/{cfg.model.arch_type}.pth" )
    """

    #sample("init", 5)
    task = available_tasks[0]
    #sample("init", task=task)

    train()

    """
@@ -435,7 +435,7 @@ class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):

class LlamaModel_Adapted(LlamaModel):
    def __init__(self, config, *args, **kwargs):
        self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0.1)
        self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0)
        self.early_exit_scale = kwargs.pop("early_exit_scale", 0.1)
        self.early_exit_r = kwargs.pop("early_exit_r", 2)

@@ -459,7 +459,7 @@ class LlamaModel_Adapted(LlamaModel):
        self.post_init()

    def dropoff_layer( self, l ):
        if not self.training:
        if not self.training or self.layer_dropout_p <= 0:
            return False

        # this could probably a LUT but I'm not fiending for aggressive mal-optimizations
@@ -37,7 +37,7 @@ from ..samplers import *
from ..data import get_task_symmap

# these seem more elegant than a dict
Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states', 'exited_layer'])
Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states'])
Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
LossStats = namedtuple('LossStats', ['loss', 'stats'])
@@ -245,6 +245,21 @@ class AudioEmbedding(nn.Module):

        return x

class AudioEmbedding_Sums(nn.Module):
    def __init__(
        self,
        n_tokens: int,
        n_levels: int,
        token_dim: int,
    ):
        super().__init__()
        self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)])

    def forward(self, xi: Tensor ) -> Tensor:
        x = sum( [ emb( xi[:, l] ) for l, emb in enumerate(self.embeddings) ] )

        return x
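A quick usage sketch of the summed embedding above (illustrative shapes only): codes of shape (seq_len, n_levels) yield one summed embedding per frame.

```python
import torch
emb = AudioEmbedding_Sums(n_tokens=1024, n_levels=8, token_dim=1024)
codes = torch.randint(0, 1024, (75, 8))  # roughly one second of 8-level codes
h = emb(codes)
print(h.shape)  # torch.Size([75, 1024])
```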

# time-step embedding
# for the NAR-len, since it probably most likely requires encoding the timestep
class TimeEmbedding(nn.Module):
@@ -272,12 +287,12 @@ class Classifiers(nn.Module):
    def __init__(
        self,
        l_embedding_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
        token_dim: int, # dimensionality of the embedding
        l_embedding_names: list[str] | None = None, # list of names to map to each classifier,
        l_embedding_names: list[str], # list of names to map to each classifier,
        d_model: int, # dimensionality of the embedding
        bias: bool = True,
    ):
        super().__init__()
        self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens, bias=bias) for n_tokens in l_embedding_tokens])
        self.proj = nn.ModuleList([nn.Linear(d_model, n_tokens, bias=bias) for n_tokens in l_embedding_tokens])
        self.names = l_embedding_names

    def indices(
|
|||
return names
|
||||
return [ self.names.index(name) for name in names ]
|
||||
|
||||
def forward(self, xi: Tensor, levels: list[int] | None = None, names: list[str] | None = None, stack = False ) -> Tensor:
|
||||
dtype = xi.dtype
|
||||
device = xi.device
|
||||
def forward(
|
||||
self,
|
||||
xi: Tensor,
|
||||
levels: list[int] | None = None,
|
||||
names: list[str] | None = None,
|
||||
stack = False,
|
||||
) -> Tensor:
|
||||
dtype = xi[0].dtype
|
||||
device = xi[0].device
|
||||
|
||||
if levels and isinstance( levels[-1], str ):
|
||||
names = levels
|
||||
levels = []
|
||||
|
||||
# map names to levels
|
||||
"""
|
||||
if names and not levels:
|
||||
levels = [ self.names.index(name) for name in names ]
|
||||
levels = [ None if name =="NAR" else self.names.index(name) for name in names ]
|
||||
"""
|
||||
if names and not levels:
|
||||
levels = [ None if name not in self.names else self.names.index(name) for name in names ]
|
||||
|
||||
xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
|
||||
xi = [ x if l == None else self.proj[l]( x ) for x, l in zip(xi, levels) ]
|
||||
if not stack:
|
||||
return xi
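A small illustration of the name-to-level mapping used above (hypothetical names): any name not registered in the classifier set — such as the parallel-decoding "NAR" level — maps to None, and its hidden state is passed through without a per-level projection.

```python
known = ["AR:0:0", "NAR:0:1"]          # registered classifier names (example)
names = ["AR:0:0", "NAR"]              # names requested for this batch
levels = [ None if name not in known else known.index(name) for name in names ]
print(levels)  # [0, None] -> the second entry skips the per-level projection
```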
@@ -316,6 +341,109 @@ class Classifiers(nn.Module):
        ]
        return torch.stack( xi )

# Pseudo-MoE by doing additional decoding from the main transformer's last hidden output
# ironically, not using a classifier to hidden_dim => audio_tokens causes problems with fitment
class ParallelDecoder(nn.Module):
    def __init__(
        self,
        levels,
        d_model,
        config_kwargs,
    ):
        super().__init__()

        training = config_kwargs.pop("training", False)
        attention_backend = config_kwargs.pop("attention_backend", "default")
        gradient_checkpointing = config_kwargs.pop("gradient_checkpointing", True)

        hidden_size = config_kwargs.get("hidden_size")
        vocab_size = config_kwargs.get("vocab_size")

        #self.d_model = d_model
        self.vocab_size = vocab_size

        downs = []
        modules = []
        ups = []
        for level in range(levels):
            module = LlamaModel_Adapted(LlamaConfig(**config_kwargs))

            module = ml.replace_attention( module, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )

            if hasattr( module, "embeddings" ):
                del module.embeddings

            if gradient_checkpointing and not module.gradient_checkpointing:
                module.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
                    use_reentrant=False
                ))

            modules.append(module)
            """
            downs.append(nn.Linear(d_model, hidden_size, bias=False))
            ups.append(nn.Linear(hidden_size, vocab_size, bias=False))
            """

        self.levels = levels
        self.decoders = nn.ModuleList(modules)
        """
        self.downs = nn.ModuleList(downs)
        self.ups = nn.ModuleList(ups)
        """

    def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
        # split into levels
        if level == None:
            x = [ self.forward( x, l, **kwargs ) for l in range(self.levels) ]
            x = torch.stack( x )
            x = x.permute( 1, 0, 2, 3 ) # ( level, batch, token, classification => batch, level, token, classification )
            return x

        # do one level

        # attention + feedforward
        x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"]
        # this really hates an output head, so just treat the final output as one
        x = x[..., :self.vocab_size]

        """
        # downscale to head's dimensionality
        x = self.downs[level]( x )
        # attention + feed forward
        x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"]
        # upscale to vocab logits
        x = self.ups[level]( x )
        """

        return x
"""

"""
# naively tries to extract multiple codebooks in parallel based on the last hidden state from the model
# this doesn't work well
"""
class ClassifiersParallel(nn.Module):
    def __init__(
        self,
        n_levels: int, # codebook levels
        n_tokens: int, # output token count
        token_dim: int, # dimensionality of the embedding
        bias: bool = False,
    ):
        super().__init__()
        self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens, bias=bias) for _ in range(n_levels)])

    def forward(self, xi: Tensor, stack: bool = True ) -> Tensor:
        dtype = xi.dtype
        device = xi.device

        xi = [ proj( xi ) for l, proj in enumerate(self.proj) ]
        xi = torch.stack( xi )
        xi = xi.permute( 1, 0, 2, 3 ) # ( level, batch, token, classification => batch, level, token, classification )

        return xi
"""

class Metrics(nn.Module):
    def __init__(
        self,
@@ -448,6 +576,7 @@ class Base(nn.Module):
        self.causal = "ar" in self.capabilities or "len" in self.capabilities
        self.version = self.config.version if self.config is not None else 5
        self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
        self.parallel_decoding = self.config.experimental.parallel_decoding if self.config is not None else False

        self.arch_type = self.config.arch_type if self.config is not None else "llama"

@@ -469,7 +598,7 @@ class Base(nn.Module):
        tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
        audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
        unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
        interleave = self.config.experimental.interleave if self.config is not None else False
        #interleave = self.config.experimental.interleave if self.config is not None else False
        noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
        classifiers_bias = self.config.experimental.classifiers_bias if self.config is not None else False
        max_position_embeddings = self.config.experimental.max_position_embeddings if self.config is not None else (75 * 60 * 5)
@@ -477,40 +606,56 @@ class Base(nn.Module):
        masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
        ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False

        layerskip = self.config.experimental.layerskip if self.config is not None else False
        layerskip_r = self.config.experimental.layerskip_r if self.config is not None else 2
        layerskip_p_max = self.config.experimental.layerskip_p_max if self.config is not None else 0.1
        layerskip_e_scale = self.config.experimental.layerskip_e_scale if self.config is not None else 0.1

        n_tasks = self.config.tasks if self.config is not None else 8
        n_langs = self.config.langs if self.config is not None else 2
        n_tones = self.config.tones if self.config is not None else 1

        # pure AR
        if "nar" not in self.capabilities:
            n_resp_tokens = n_audio_tokens + 1
            l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
            l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
            l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels
        # NAR-len model
        elif "len" in self.capabilities:
            # +1 to include the stop or mask token
            n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
            if "ar" in self.capabilities:
                l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
                l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
                l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
        if self.version < 7:
            if "nar" not in self.capabilities:
                n_resp_tokens = n_audio_tokens + 1
                l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
                l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
                l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels
            # NAR-len model
            elif "len" in self.capabilities:
                # +1 to include the stop or mask token
                n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
                if "ar" in self.capabilities:
                    l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
                    l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
                    l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
                else:
                    l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
                    l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
                    l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
            # AR+NAR model
            else:
                # +1 to include the stop or mask token
                n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
                l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
                l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
                l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
                l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
        # AR+NAR model
        else:
            # +1 to include the stop or mask token
            n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
            l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
            l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
            l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
            if self.parallel_decoding:
                n_resp_tokens = n_audio_tokens + 1
                l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
                l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )]
                l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels
            else:
                """
                n_resp_tokens = n_audio_tokens + 1
                l_embedding_tokens = [n_resp_tokens * self.n_resp_levels]
                l_embedding_names = ["NAR"]
                l_classifier_tokens = [n_audio_tokens * self.n_resp_levels]
                """
                n_resp_tokens = n_audio_tokens + 1
                l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
                l_classifier_tokens = [n_audio_tokens] * self.n_resp_levels
                l_embedding_names = [ f'NAR:{i}:{i}' for i in range( self.n_resp_levels ) ]

        n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1

        l_classifier_names = l_embedding_names
@@ -528,23 +673,13 @@ class Base(nn.Module):
            l_classifier_tokens += [ n_raw_text_tokens ]
            l_classifier_names = l_embedding_names + [ "raw_text" ]

        n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1

        self.n_vocab = n_vocab
        self.unified_position_ids = unified_position_ids
        self.interleave = interleave
        self.layerskip = layerskip
        self.inject_timestep_embedding = False # results in bad output
        self.masking_ratio = masking_ratio
        self.ignore_inputs_for_loss = ignore_inputs_for_loss
        self.noncausal_masks = noncausal_masks

        # use internal attention mechanism for now because I dont have a better way to handle mixed causal/noncausal masks for other attention backends
        """
        if noncausal_masks:
            attention_backend = "default"
        """

        self.text_emb = Embedding(n_text_tokens, d_model)
        self.raw_text_emb = None
        self.langs_emb = None

@@ -572,7 +707,7 @@ class Base(nn.Module):
                l_embedding_tokens, d_model,
                levels=self.n_resp_levels if self.version > 3 else None,
            )
        else:
        elif not self.parallel_decoding:
            self.proms_emb = AudioEmbedding(
                [n_audio_tokens] * self.n_resp_levels, d_model,
                sums=audio_embedding_sums == "prom" or audio_embedding_sums == True,

@@ -582,6 +717,17 @@ class Base(nn.Module):
                sums=audio_embedding_sums == "resp" or audio_embedding_sums == True,
                l_embedding_names=l_embedding_names,
            )
        else:
            self.proms_emb = AudioEmbedding_Sums(
                n_tokens=n_audio_tokens,
                n_levels=self.n_resp_levels,
                token_dim=d_model,
            )
            self.resps_emb = AudioEmbedding_Sums(
                n_tokens=n_audio_tokens + 1,
                n_levels=self.n_resp_levels,
                token_dim=d_model,
            )

        if self.version >= 3:
            self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None
@@ -597,12 +743,11 @@ class Base(nn.Module):
        # this ***might*** let me also unify the proms_emb and resps_embedding
        if self.version >= 5:
            # "len" RVQ level-0 gets an additional token
            self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)
            if self.version < 7 or not self.parallel_decoding:
                self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)

            # experimental NAR-only mode
            self.len_emb = Embedding(11, d_model)
            self.time_emb = None # TimeEmbedding(d_model) # if not masking_ratio else None

        if self.version >= 6:
            self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)

@@ -640,10 +785,8 @@ class Base(nn.Module):
                n_levels=self.n_resp_levels,
            ) for _ in range(n_layers) ])
        elif self.arch_type in ["llama", "mistral", "mixtral"]:
            LlamaClass = LlamaModel_Adapted # if (self.layerskip or "len" in self.capabilities) else LlamaModel

            if n_experts <= 1:
                self.model = LlamaClass(LlamaConfig(
                self.model = LlamaModel_Adapted(LlamaConfig(
                    vocab_size=n_vocab,
                    hidden_size=d_model,
                    max_position_embeddings=max_position_embeddings,

@@ -692,11 +835,6 @@ class Base(nn.Module):
            self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )
            """

        if self.layerskip:
            self.model.layer_dropout_p = layerskip_p_max
            self.model.early_exit_scale = layerskip_e_scale
            self.model.early_exit_r = layerskip_r

        if self.gradient_checkpointing and not self.model.gradient_checkpointing:
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
                use_reentrant=False
@@ -766,19 +904,38 @@ class Base(nn.Module):
        if not split_classifiers:
            self.classifier = nn.Linear(d_model, n_vocab, bias=classifiers_bias)
            self.classifiers = None

            self.metrics = None
        else:
            self.classifier = None
            self.classifiers = Classifiers( l_classifier_tokens, d_model, l_embedding_names=l_classifier_names, bias=classifiers_bias )
            self.classifiers = Classifiers( l_classifier_tokens, l_classifier_names, d_model, bias=classifiers_bias )
            self.metrics = Metrics( l_classifier_tokens )

        """
        if tie_classifier_to_embedding:
            for i, proj in enumerate( self.classifiers.proj ):
                self.classifiers.proj[i].weight = self.resps_emb.embeddings[i].weight
        """
        self.parallel_decoder = None
        if self.parallel_decoding:
            pd_model = d_model # // 2
            pd_ffn = pd_model * 2
            pd_heads = n_heads // 2
            pd_layers = 1

            config = dict(
                vocab_size=n_audio_tokens,
                hidden_size=pd_model,
                max_position_embeddings=max_position_embeddings,
                intermediate_size=pd_ffn,
                num_hidden_layers=pd_layers,
                num_attention_heads=pd_heads,
                attention_dropout=p_dropout if training else 0.0,
                num_key_value_heads=pd_heads,
                hidden_act="gelu",
                is_encoder_decoder=False,
                is_decoder=True,
                attn_implementation="eager",

                training=self.training,
                attention_backend=attention_backend,
                gradient_checkpointing=self.gradient_checkpointing,
            )
            self.parallel_decoder = ParallelDecoder( self.n_resp_levels, d_model, config )

    def _forward(
        self,
@@ -789,8 +946,6 @@ class Base(nn.Module):

        state = None,

        layer_skip_lambda = None,

        output_attentions = False,
        output_hidden_states = False,
    ):

@@ -818,9 +973,6 @@ class Base(nn.Module):
        if self.n_experts > 1 and self.training:
            kwargs["output_router_logits"] = True

        if self.layerskip and layer_skip_lambda is not None:
            kwargs["layer_skip_lambda"] = layer_skip_lambda

        output = self.model(**kwargs)
        x = output["last_hidden_state"]

@@ -885,7 +1037,7 @@ class Base(nn.Module):
            # but skip the last state, as it already is normalized
            hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]

        return Logits(x, state, inputs, aux_loss, attentions, hidden_states, None)
        return Logits(x, state, inputs, aux_loss, attentions, hidden_states)

    # takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
    def inputs(
@@ -943,7 +1095,7 @@ class Base(nn.Module):
                if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
                    inputs[i].append( ( "lang", lang_list[i] ) )
                # insert RVQ level guidance token if the model is versioned for it
                if self.rvq_l_emb is not None and not self.interleave:
                if self.rvq_l_emb is not None:
                    inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )

                classifier_level = "AR:0:0" if quant_level == 0 else f'NAR:{quant_level-1}:{quant_level}'

@@ -976,7 +1128,10 @@ class Base(nn.Module):
                # insert the current output response
                if resps_list is not None and resps_list[i] is not None:
                    inputs[i].append( ( "resp", resps_list[i] ) )

                if self.version >= 7:
                    classifier_level = f"NAR:{quant_level}:{quant_level}" if not self.parallel_decoding else "NAR"

                inputs[i].append( ("classifier_level", classifier_level) )
            # Audio length prediction task
            # Sequence: <text><sep><rvq lvl><prom><sep><len>

@@ -1022,7 +1177,7 @@ class Base(nn.Module):
                if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
                    inputs[i].append( ( "lang", lang_list[i] ) )
                # insert RVQ level guidance token if the model is versioned for it
                if self.rvq_l_emb is not None and not self.interleave:
                if self.rvq_l_emb is not None:
                    inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
                # insert the output text prompt
                if text_list is not None and text_list[i] is not None:

@@ -1038,7 +1193,7 @@ class Base(nn.Module):
                # insert lang token if we're trained for it
                if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
                    inputs[i].append( ( "lang", lang_list[i] ) )
                if self.rvq_l_emb is not None and not self.interleave:
                if self.rvq_l_emb is not None:
                    inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
                # insert the text prompt
                if text_list is not None and text_list[i] is not None:

@@ -1054,7 +1209,7 @@ class Base(nn.Module):
                # insert lang token if we're trained for it
                if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
                    inputs[i].append( ( "lang", lang_list[i] ) )
                if self.rvq_l_emb is not None and not self.interleave:
                if self.rvq_l_emb is not None:
                    inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
                # insert the text prompt
                if raw_text_list is not None and raw_text_list[i] is not None:
@@ -1117,12 +1272,23 @@ class Base(nn.Module):
                return self.proms_emb(
                    input if quant_level == 0 else input[:, :quant_level]
                )

            return self.proms_emb(
                input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
                quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
                offset = 0,
            )

            if self.version < 7 or not self.parallel_decoding:
                return self.proms_emb(
                    input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
                    quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
                    offset = 0,
                )
            """
            if not self.parallel_decoding:
                return self.proms_emb(
                    input if input.dim() == 1 else input[:, :quant_level+1],
                    quant_level = quant_level,
                    offset = 0,
                )
            """

            return self.proms_emb( input )

        # yuck
        token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0
@@ -1188,28 +1354,23 @@ class Base(nn.Module):
            elif name == "tone" and self.tones_emb is not None:
                embedding = self.tones_emb( input )
            elif name == "resp":
                if self.interleave:
                    embeddings = [ self.resps_emb(
                        input[:, :l+1],
                        #offset = 0,
                        #quant_level = l,
                        name = 'AR:0:0' if l == 0 else f'NAR:{l-1}:{l}',
                    ) for l in range( input.shape[-1] ) ]

                    embedding = _interleave_sequence_reshape( embeddings )

                if self.parallel_decoding:
                    if dropout_mask is not None:
                        embedding = self.resps_emb( torch.where( dropout_mask, self.stop_token, input.t() ).t() )
                    else:
                        embedding = self.resps_emb( input )
                # if training NAR-len RVQ level 0
                elif dropout_mask is not None:
                    embedding = self.resps_emb(
                        # if masked use masked token, else original token
                        torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
                        torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, quant_level] ),
                        #quant_level = 0,
                        name = classifier_level,
                    )
                # NAR-len
                elif classifier_level == "NAR:0:0":
                elif classifier_level == f"NAR:{quant_level}:{quant_level}":
                    embedding = self.resps_emb(
                        input if input.dim() == 1 else input[:, 0],
                        input if input.dim() == 1 else input[:, quant_level],
                        #quant_level = 0,
                        name = classifier_level,
                    )

@@ -1323,10 +1484,6 @@ class Base(nn.Module):
            if not isinstance(input, torch.Tensor):
                return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] )

            # interleaved model
            if self.interleave and name == "resp":
                return input.shape[0] * input.shape[1]

            # ending input will not have a separator later
            return input.shape[0]
@ -1352,6 +1509,166 @@ class Base(nn.Module):
|
|||
|
||||
return ids.to(device=device, dtype=torch.int32)
|
||||
|
||||
def calc_loss_parallel(
|
||||
self,
|
||||
inputs: list,
|
||||
logits,
|
||||
|
||||
compute_hard_loss = True,
|
||||
compute_acc = True,
|
||||
):
|
||||
loss = {}
|
||||
stats = {}
|
||||
|
||||
device = logits[0].device
|
||||
batch_size = len(logits)
|
||||
classifier_levels = self.get_input( inputs, "classifier_level" )
|
||||
|
||||
# handles tasks where the prompt has task tokens injected in the middle
|
||||
def prompt_input_to_token( input ):
|
||||
if isinstance(input, str):
|
||||
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
|
||||
return input
for batch_index, batch in enumerate(inputs):
target = []
causal = False
task_type = "tts"
dropout_mask = None
classifier_level = None
output_len = 0
for name, input in batch:
if name == "task":
task_type = input
elif name == "dropout_mask":
dropout_mask = input
elif name == "classifier_level":
classifier_level = input
it = 0
for name, input in batch:
token = None
ignored = False
# non-tokened tasks
if name in non_tokened_names:
continue
# prom can either be a tensor itself or a list of tensors and strings
if name == "prom":
# expand to list if not a list
proms = [ input ] if isinstance(input, torch.Tensor) else input
# iterate over the list to inject their tokens
token = torch.cat( [ prompt_input_to_token( input ) for input in proms if input is not None ] )
elif name == "resp":
# mask found, apply it
if dropout_mask is not None:
token = torch.where( dropout_mask, input.t(), self.ignore_index ).t()
else:
token = input
# not a special input, inject as-is
else:
token = input
if not isinstance(token, torch.Tensor):
continue
if token.is_floating_point():
ignored = True
# grab range of our logits for later
seq_len = token.shape[0]
start, end = it, it+seq_len
it += seq_len + 1 # +1 to incorporate the separator
# deduce if a name for a task is an input or output
if name != task_outputs.get(task_type, name):
if self.ignore_inputs_for_loss:
ignored = True
else:
output_len = seq_len
if ignored:
# pruned
if self.config.loss_factors:
continue
# fill with ignored out tensor
token = torch.tensor( [ self.ignore_index ] * token.shape[0], device=device, dtype=torch.int16)
# perform loss calculation on the individual piece
target.append( token )
if classifier_level != "NAR":
seq = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
logit = logits[batch_index]
# shift if causal
if causal:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
seq = seq[..., l:] # ...predicts token n + 1
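# (descriptive note, not part of the patch) e.g. with causal_size l = 1, the logit at position n is
# scored against the target at position n + 1, so the last l logits and the first l targets are
# dropped to keep the two sequences aligned.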
if compute_hard_loss:
nll = F.cross_entropy( logit, seq, ignore_index=self.ignore_index )
if 'nll' not in loss:
loss['nll'] = []
loss["nll"].append( nll )
if compute_acc and False:
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ "NAR:0" if classifier_level == "NAR" else classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 10,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
metrics = accuracy_metric( logit, seq )
if 'acc' not in stats:
stats['acc'] = []
stats["acc"].append( metrics )
else:
for level, logit in enumerate( logits[batch_index] ):
seq = _join( [ t if t.dim() <= 1 else t[:, level] for t in target ], torch.tensor(self.ignore_index, device=target[-1].device) )
# shift if causal
if causal:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
seq = seq[..., l:] # ...predicts token n + 1
if compute_hard_loss:
nll = F.cross_entropy( logit, seq, ignore_index=self.ignore_index )
if 'nll' not in loss:
loss['nll'] = []
loss["nll"].append( nll )
if compute_acc and False:
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ "NAR:0" if classifier_level == "NAR" else classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 10,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
metrics = accuracy_metric( logit, seq )
if 'acc' not in stats:
stats['acc'] = []
stats["acc"].append( metrics )
# average
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
return LossStats(loss, stats)
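# (descriptive note, not part of the patch) both branches append into the same 'nll' / 'acc' lists, so
# the averaging above weights a whole-sequence loss and each per-level loss equally; LossStats is
# assumed here to be the small (loss, stats) container used elsewhere in this file.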
def calc_loss(
self,
inputs: list,
@ -1377,7 +1694,13 @@ class Base(nn.Module):
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
return torch.full_like(input[..., 0], self.ignore_index)
return input if input.dim() == 1 else input[:, quant_level]
if self.version < 7:
return input if input.dim() == 1 else input[:, quant_level]
if not self.parallel_decoding:
return input if input.dim() == 1 else input[:, quant_level]
return input
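# (descriptive note, not part of the patch) older model versions and the non-parallel path slice out a
# single quant level for the target, while the parallel-decoding path keeps the full [timestep, level]
# tensor so calc_loss_parallel can score every codebook level at once.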
for batch_index, batch in enumerate(inputs):
quant_level = quant_levels[batch_index]
@ -1402,6 +1725,8 @@ class Base(nn.Module):
# nonautoregressive, parallel
elif classifier_level.startswith("NAR:"):
causal = False
elif classifier_level == "NAR":
causal = False
it = 0
for name, input in batch:
@ -1422,9 +1747,6 @@ class Base(nn.Module):
if dropout_mask is not None:
# if mask use original token, else ignore
token = torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index )
# flatten
elif self.interleave:
token = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
# use resps as-is
else:
token = input if input.dim() == 1 else input[:, quant_level]
@ -1435,28 +1757,6 @@ class Base(nn.Module):
if not isinstance(token, torch.Tensor):
continue
# offset to flattened vocab ranges
"""
if self.classifier is not None:
offsets = _get_offsets()
k = name
if name == "stt":
k = "text"
if name == "prom":
k = f'prom|{quant_level}'
elif name == "resp":
k = f'resps|{classifier_level}'
if k in offsets:
start, end = offsets[k]
for i, t in enumerate( token ):
if t == self.ignore_index:
continue
token[i] += start
"""
if token.is_floating_point():
ignored = True
@ -1566,51 +1866,9 @@ class Base(nn.Module):
quant_levels: list[int] | None = None,
state: dict | list | None = None,
layer_skip_variables: dict | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
):
# return early if it's "good enough"
# lambda because we need to capture the classifier_levels and mask
exited_layer = self.n_layers
def layer_skip_lambda( layer, logits ):
nonlocal exited_layer
kwargs = {
"entropy_threshold": 0.05,
"varentropy_threshold": 0.05,
"min_layer": self.n_layers // 2,
"max_layer": self.n_layers,
}
kwargs.update( layer_skip_variables )
# don't bother on early layers
if layer < kwargs["min_layer"]:
return False
# bail if we want to force early layers
if kwargs["max_layer"] < layer:
return True
# hidden states aren't normalized
x = self.model.norm( logits )
# output projection layer with masking
if self.classifier is not None:
x = self.classifier(x) # * m
elif self.classifiers is not None:
logits = self.classifiers(logits, levels = classifier_levels) # * m
# calculate metrics
metrics = calculate_entropix_metrics( logits )
# exit early if "good enough"
early = metrics["logits_entropy"] <= kwargs["entropy_threshold"] and metrics["logits_varentropy"] <= kwargs["varentropy_threshold"]
if early:
exited_layer = layer
return early
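# (descriptive note, not part of the patch) the layer-skip hook projects the intermediate hidden state
# through the output head and bails out of the remaining layers once both the entropy and varentropy
# of those logits fall under the configured thresholds, recording which layer it exited at.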
):
# derive quant levels from inputs if not provided
if quant_levels is None:
quant_levels = [ x.item() for x in self.get_input( inputs, "quant_level" ) ]
@ -1628,10 +1886,6 @@ class Base(nn.Module):
device = x.device
batch_size = len(x_list)
# we only need hidden states if we're training with layerskip
if self.layerskip and training:
output_hidden_states = True
# pad our input and mask, but retain the original length by doing it after
if self.l_padding and x.shape[1] % self.l_padding != 0:
# pad input
@ -1663,55 +1917,55 @@ class Base(nn.Module):
is_causal=is_causal,
position_ids=position_ids,
output_attentions = output_attentions,
output_hidden_states = output_hidden_states,
layer_skip_lambda = layer_skip_lambda if self.layerskip and layer_skip_variables else None,
)
logits = output.logits
hidden_states = output.hidden_states
logits = [ logit for logit in logits ]
if self.version >= 7 and self.parallel_decoding:
p_indices = [ batch_index for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ]
if p_indices:
p_logits = torch.stack([ logits[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0)
p_mask = torch.stack([ mask[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0)
p_ids = torch.stack([ position_ids[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0)
p_causal = [ is_causal[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ]
p_logits = self.parallel_decoder( p_logits, attention_mask=p_mask, position_ids=p_ids, use_cache=False, return_dict=True, is_causal=p_causal )
for i, logit in enumerate(p_logits):
logits[p_indices[i]] = logit
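# (descriptive note, not part of the patch) batch items routed to the parallel "NAR" head are stacked
# into one tensor, run through the extra parallel_decoder blocks in a single pass, then scattered back
# into the per-item logits list; the commented-out variant below did the same thing one item at a time.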
"""
logits = [ self.parallel_decoder( logit.unsqueeze(0), attention_mask=mask,
position_ids=position_ids,
use_cache=False,
return_dict=True,
is_causal=is_causal )[0] if level == "NAR" else logit for logit, level in zip(logits, classifier_levels) ]
"""
# output projection layer
# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
if self.classifier is not None:
logits = self.classifier(logits) # * m
if output.hidden_states:
for i, state in enumerate( hidden_states ):
hidden_states[i] = self.classifier(hidden_states[i]) # * m
# to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead......
elif self.classifiers is not None:
logits = self.classifiers(logits, levels = classifier_levels) # * m
logits = self.classifiers(logits, levels = classifier_levels )
if hidden_states is not None:
for i, state in enumerate( hidden_states ):
hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_levels) # * m
# Reshape
"""
if self.version >= 7 and not self.parallel_decoding:
for batch_index, logit in enumerate( logits ):
if classifier_levels[batch_index] != "NAR":
continue
logits[batch_index] = logit.reshape( logit.shape[0], 8, 1000 ).permute( 1, 0, 2 )
"""
# Remove padding
|
||||
logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ]
|
||||
|
||||
if hidden_states is not None:
|
||||
for i, state in enumerate( hidden_states ):
|
||||
hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
|
||||
logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]
|
||||
|
||||
# de-offset if needed
|
||||
if self.classifier is not None:
|
||||
offsets = _get_offsets()
|
||||
for batch_index, classifier_level in enumerate( classifier_levels ):
|
||||
if classifier_level == "stt":
|
||||
k = "text"
|
||||
elif classifier_level == "len":
|
||||
k = "len"
|
||||
else:
|
||||
k = f'resps|{classifier_level}'
|
||||
|
||||
if k not in offsets:
|
||||
continue
|
||||
|
||||
start, end = offsets[k]
|
||||
|
||||
logits[batch_index] = logits[batch_index][:, start:end]
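# (descriptive note, not part of the patch) with the single monolithic classifier every task shares one
# flattened vocab, so the slice above carves each item's logits back down to the vocab range that
# belongs to its classifier level before sampling or loss computation.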
if not training:
loss = None
stats = None
@ -1719,30 +1973,18 @@ class Base(nn.Module):
self.loss = None
self.stats = None
# compute loss if the target is given
elif self.version >= 7 and self.parallel_decoding:
loss, stats = self.calc_loss_parallel( inputs=inputs, logits=logits )
# include any additional losses (for example: MoE router)
if output.loss is not None:
loss["aux_loss"] = output.loss
self.loss = loss
self.stats = stats
else:
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
# compute it as an aux-loss
if self.layerskip:
early_exit_loss = {}
if not hasattr( self, "training_steps" ):
self.training_steps = 0
for i, state in enumerate( hidden_states ):
loss, stats = self.calc_loss( inputs=inputs, logits=hidden_states[i], quant_levels=quant_levels )
for k, v in loss.items():
K = f'early_exit.{k}'
if K not in early_exit_loss:
early_exit_loss[K] = []
early_exit_loss[K].append( v )
for k, v in early_exit_loss.items():
loss[k] = self.model.early_exit_loss( losses=v, t=self.training_steps )
# to-do: instead make the curriculum rely on samples processed instead of steps
self.training_steps += 1 # batch_size
# include any additional losses (for example: MoE router)
if output.loss is not None:
loss["aux_loss"] = output.loss
@ -1751,7 +1993,7 @@ class Base(nn.Module):
self.stats = stats
# rewrap, because we're modifying the logits here
return Logits(logits, output.state, inputs, loss, output.attentions, hidden_states, exited_layer)
return Logits(logits, output.state, inputs, loss, output.attentions, hidden_states)
def sample(
self,
@ -176,17 +176,17 @@ def top_no_logits_processing( logits, n = 1.0 ):
# (and because the null logits have a shorter input sequence compared to the positive logits)
def cfg_logits( logits, null, strength, lens, rescale=0.0 ):
for i, seq_len in enumerate( lens ):
pos = logits[i][-seq_len:]
neg = null[i][-seq_len:]
pos = logits[i][..., -seq_len:, :]
neg = null[i][..., -seq_len:, :]
summed = neg + (pos - neg) * strength
if rescale <= 0:
logits[i][-seq_len:] = summed
logits[i][..., -seq_len:, :] = summed
else:
dims = tuple(range(1, summed.ndim - 1))
factor = rescale * (pos.std(dims, keepdim=True) / summed.std(dims, keepdim=True)) + (1 - rescale)
logits[i][-seq_len:] = summed * factor
logits[i][..., -seq_len:, :] = summed * factor
return logits
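# (descriptive note, not part of the patch) this is classifier-free guidance applied to the logits: the
# null (unconditioned) logits are pushed toward the conditioned ones, guided = null + (cond - null) * strength,
# and when rescale > 0 the result is scaled back toward the conditioned logits' standard deviation,
# factor = rescale * std(cond) / std(guided) + (1 - rescale), to tame the variance blow-up that large
# guidance strengths cause; the `[..., -seq_len:, :]` slicing keeps an extra leading level axis intact
# for the parallel-decoding logits while still slicing the sequence dimension.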