re-implemented config.model.interleave for the HF-compat experimental method

This commit is contained in:
mrq 2024-06-04 14:19:52 -05:00
parent c93d5863fd
commit 406ff7bbe1
3 changed files with 121 additions and 46 deletions

View File

@ -216,7 +216,7 @@ class Model:
dropout: float = 0.1 # adjustable dropout value dropout: float = 0.1 # adjustable dropout value
loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 }) loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 })
kv_heads: int = 0 kv_heads: int = 0
experimental: bool = False experimental: bool = False # for now it sets things to be HF compatible
def get(self, name=None): def get(self, name=None):
return [ self ] if not name or self.name == name else [] return [ self ] if not name or self.name == name else []

View File

@ -44,7 +44,8 @@ def fold_inputs(
text_tokens = 256, text_tokens = 256,
audio_tokens = 1024, audio_tokens = 1024,
audio_rvq_levels = cfg.model.max_levels audio_rvq_levels = cfg.model.max_levels,
quant_levels = None,
): ):
def _create_mask(l, device): def _create_mask(l, device):
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t) seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@ -75,23 +76,43 @@ def fold_inputs(
offset = text_tokens offset = text_tokens
for i, prom in enumerate(prom_list): for i, prom in enumerate(prom_list):
if ignore_index is not None: if quant_levels is not None:
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64) quant_level = quant_levels[i]
if ignore_index is not None:
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to("cpu", dtype=torch.int64)
else:
seq = prom[:, quant_level].to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * quant_level )
else: else:
seq = prom.flatten().to("cpu", dtype=torch.int64) if ignore_index is not None:
for idx, token in enumerate( seq ): seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) ) else:
seq = prom.flatten().to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
input_ids[i].append( seq ) input_ids[i].append( seq )
input_ids[i].append( sep ) input_ids[i].append( sep )
offset = text_tokens + (audio_tokens * audio_rvq_levels) offset = text_tokens + (audio_tokens * audio_rvq_levels)
for i, resp in enumerate(resp_list): for i, resp in enumerate(resp_list):
seq = resp.flatten().to("cpu", dtype=torch.int64) if quant_levels is not None:
for idx, token in enumerate( seq ): quant_level = quant_levels[i]
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) ) seq = resp[:, quant_level].to("cpu", dtype=torch.int64)
input_ids[i].append( seq ) for idx, token in enumerate( seq ):
input_ids[i].append( stop ) token += offset + ( audio_tokens * quant_level )
input_ids[i].append( seq )
if quant_level == 0:
input_ids[i].append( stop )
else:
seq = resp.flatten().to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
input_ids[i].append( seq )
input_ids[i].append( stop )
for i, batch in enumerate(input_ids): for i, batch in enumerate(input_ids):
input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64) input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64)
@ -99,6 +120,7 @@ def fold_inputs(
return list_to_tensor(input_ids) return list_to_tensor(input_ids)
# unfold from one unified token ID space to separate token spaces # unfold from one unified token ID space to separate token spaces
# to-do: unfold at a specific RVQ level instead if requested
def unfold_outputs( def unfold_outputs(
output_ids, output_ids,
@ -107,7 +129,8 @@ def unfold_outputs(
text_tokens = 256, text_tokens = 256,
audio_tokens = 1024, audio_tokens = 1024,
audio_rvq_levels = cfg.model.max_levels audio_rvq_levels = cfg.model.max_levels,
quant_levels = None,
): ):
device = output_ids.device device = output_ids.device
batch_size = output_ids.shape[0] batch_size = output_ids.shape[0]
@ -129,30 +152,33 @@ def unfold_outputs(
elif text_tokens + (audio_tokens * audio_rvq_levels) <= id: elif text_tokens + (audio_tokens * audio_rvq_levels) <= id:
resp_list[i].append( (id - text_tokens) % audio_tokens ) resp_list[i].append( (id - text_tokens) % audio_tokens )
prom_len = len(prom_list[i]) if quant_levels is not None:
if prom_len % audio_rvq_levels == 0 and False: prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=torch.int64)
prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t() resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=torch.int64)
else: else:
bins = [ [] for _ in range(audio_rvq_levels) ] prom_len = len(prom_list[i])
for pos in range( prom_len ): if prom_len % audio_rvq_levels == 0 and False:
rvq = pos % audio_rvq_levels prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t()
bins[rvq].append( prom_list[i][pos] ) else:
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels bins = [ [] for _ in range(audio_rvq_levels) ]
bins = bins[:nearest] for pos in range( prom_len ):
prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64) rvq = pos % audio_rvq_levels
bins[rvq].append( prom_list[i][pos] )
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
bins = bins[:nearest]
prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
resp_len = len(resp_list[i])
resp_len = len(resp_list[i]) if len(resp_list[i]) % audio_rvq_levels == 0 and False:
if len(resp_list[i]) % audio_rvq_levels == 0 and False: resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t()
resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t() else:
else: bins = [ [] for _ in range(audio_rvq_levels) ]
bins = [ [] for _ in range(audio_rvq_levels) ] for pos in range( resp_len ):
for pos in range( resp_len ): rvq = pos % audio_rvq_levels
rvq = pos % audio_rvq_levels bins[rvq].append( resp_list[i][pos] )
bins[rvq].append( resp_list[i][pos] ) nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels bins = bins[:nearest]
bins = bins[:nearest] resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=torch.int64) text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=torch.int64)

View File

@ -158,7 +158,22 @@ class Model(LlmArchClass):
self.backbone.gradient_checkpointing = gradient_checkpointing self.backbone.gradient_checkpointing = gradient_checkpointing
def generate(
self,
*args,
**kwargs
):
if SELECTED_ARCH == "mamba":
kwargs["cg"] = True
if "attention_mask" in kwargs:
kwargs.pop("attention_mask")
if "do_sample" in kwargs:
kwargs.pop("do_sample")
return super().forward(*args, **kwargs)
def forward( def forward(
self, self,
*args, *args,
@ -239,13 +254,9 @@ def example_usage():
proms_list = proms_list[:1] proms_list = proms_list[:1]
resps_list = resps_list[:1] resps_list = resps_list[:1]
input_ids, attention_mask = fold_inputs(text_list, proms_list, resps_list)
target_ids, target_attention_mask = fold_inputs(text_list, proms_list, resps_list, ignore_index=-100)
prefix_input_ids, prefix_attention_mask = fold_inputs(text_list, proms_list)
kwargs = {} kwargs = {}
model = Model(**kwargs).to(device) model = Model(**kwargs).to(device)
steps = 50 steps = 50 if cfg.model.interleave else 250
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy" optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else "" scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
@ -312,15 +323,46 @@ def example_usage():
print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.inference_mode() @torch.inference_mode()
def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*60 ): def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):
engine.eval() engine.eval()
if SELECTED_ARCH == "mamba": target_length = 0
output = model.generate(input_ids=prefix_input_ids, cg=True, max_length=steps, eos_token_id=3) resp_list = None
if cfg.model.interleave:
input_ids, attention_mask = fold_inputs(text_list, proms_list)
output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
unfolded = unfold_outputs( output )
resp_list = unfolded["resp_list"]
else: else:
output = model.generate(input_ids=prefix_input_ids, attention_mask=prefix_attention_mask, max_length=steps, eos_token_id=3, do_sample=False) resp_list = [ [] for _ in range(len(text_list)) ]
for l in range(cfg.model.max_levels):
quant_levels = [ l ]
input_ids, attention_mask = fold_inputs(text_list, proms_list, quant_levels=quant_levels)
min_length = len(input_ids[0])
unfolded = unfold_outputs( output ) output = model.generate(
for i, batch in enumerate(unfolded["resp_list"]): input_ids=input_ids,
attention_mask=attention_mask,
min_length=min_length+(steps if l > 0 else 0),
max_length=min_length+steps,
eos_token_id=3 if l == 0 else None ,
do_sample=False
)
unfolded = unfold_outputs( output, quant_levels=quant_levels )
if l == 0:
steps = 0
for batch, resp in enumerate(unfolded["resp_list"]):
if l == 0:
steps = max( steps, resp.shape[0] )
resp_list[batch].append( resp )
for i, resp in enumerate( resp_list ):
resp_list[i] = torch.stack( resp ).t()
for i, batch in enumerate(resp_list):
_ = decode_to_file(batch.to(device=device), f"data/{SELECTED_ARCH}.{cfg.audio_backend}.{i}.{name}.wav", device=device) _ = decode_to_file(batch.to(device=device), f"data/{SELECTED_ARCH}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
unload_model() unload_model()
@ -330,6 +372,13 @@ def example_usage():
t = trange(steps) t = trange(steps)
for i in t: for i in t:
stats = {"step": i} stats = {"step": i}
batch_size = len(text_list)
quant_levels = None if cfg.model.interleave else torch.randint(0, cfg.model.max_levels, (batch_size,))
input_ids, attention_mask = fold_inputs(text_list, proms_list, resps_list, quant_levels=quant_levels)
target_ids, target_attention_mask = fold_inputs(text_list, proms_list, resps_list, ignore_index=-100, quant_levels=quant_levels)
if SELECTED_ARCH == "mamba": if SELECTED_ARCH == "mamba":
stats |= engine.traverse(input_ids=input_ids, labels=target_ids) stats |= engine.traverse(input_ids=input_ids, labels=target_ids)
else: else: