re-implemented config.model.interleave for the HF-compat experimental method
commit 406ff7bbe1
parent c93d5863fd
@@ -216,7 +216,7 @@ class Model:
 	dropout: float = 0.1 # adjustable dropout value
 	loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 })
 	kv_heads: int = 0
-	experimental: bool = False
+	experimental: bool = False # for now it sets things to be HF compatible

 	def get(self, name=None):
 		return [ self ] if not name or self.name == name else []
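For orientation, a trimmed, illustrative stand-in for the two config flags this commit leans on: `experimental` switches the model over to the HF-compatible path (per the comment above), while `interleave` (read as `cfg.model.interleave` in the hunks below) picks between folding every RVQ level into one interleaved stream or handling one quantization level per pass. This is a sketch only; the real dataclass has many more fields and its defaults may differ.

    from dataclasses import dataclass

    # Illustrative sketch, not the actual config class.
    @dataclass
    class ModelConfigSketch:
        experimental: bool = False  # for now it sets things to be HF compatible
        interleave: bool = False    # True: one interleaved sequence; False: one RVQ level per pass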
@@ -44,7 +44,8 @@ def fold_inputs(
 	text_tokens = 256,
 	audio_tokens = 1024,
-	audio_rvq_levels = cfg.model.max_levels
+	audio_rvq_levels = cfg.model.max_levels,
+	quant_levels = None,
 ):
 	def _create_mask(l, device):
 		seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
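Call-site sketch for the new keyword, mirroring how example_usage() below drives it (the text/prom/resp lists are assumed to come from the surrounding example code):

    # Interleaved mode: omit quant_levels and every RVQ level is folded into one sequence.
    input_ids, attention_mask = fold_inputs(text_list, proms_list, resps_list)

    # Per-level mode: pass one quantization level per batch entry.
    quant_levels = [ 0 ] * len(text_list)
    input_ids, attention_mask = fold_inputs(text_list, proms_list, resps_list, quant_levels=quant_levels)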
@@ -75,23 +76,43 @@ def fold_inputs(
 	offset = text_tokens
 	for i, prom in enumerate(prom_list):
-		if ignore_index is not None:
-			seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
+		if quant_levels is not None:
+			quant_level = quant_levels[i]
+			if ignore_index is not None:
+				seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to("cpu", dtype=torch.int64)
+			else:
+				seq = prom[:, quant_level].to("cpu", dtype=torch.int64)
+				for idx, token in enumerate( seq ):
+					token += offset + ( audio_tokens * quant_level )
 		else:
-			seq = prom.flatten().to("cpu", dtype=torch.int64)
-			for idx, token in enumerate( seq ):
-				token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
+			if ignore_index is not None:
+				seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
+			else:
+				seq = prom.flatten().to("cpu", dtype=torch.int64)
+				for idx, token in enumerate( seq ):
+					token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
+
 		input_ids[i].append( seq )
 		input_ids[i].append( sep )

 	offset = text_tokens + (audio_tokens * audio_rvq_levels)
 	for i, resp in enumerate(resp_list):
-		seq = resp.flatten().to("cpu", dtype=torch.int64)
-		for idx, token in enumerate( seq ):
-			token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
-		input_ids[i].append( seq )
-		input_ids[i].append( stop )
+		if quant_levels is not None:
+			quant_level = quant_levels[i]
+			seq = resp[:, quant_level].to("cpu", dtype=torch.int64)
+			for idx, token in enumerate( seq ):
+				token += offset + ( audio_tokens * quant_level )
+
+			input_ids[i].append( seq )
+			if quant_level == 0:
+				input_ids[i].append( stop )
+		else:
+			seq = resp.flatten().to("cpu", dtype=torch.int64)
+			for idx, token in enumerate( seq ):
+				token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
+
+			input_ids[i].append( seq )
+			input_ids[i].append( stop )

 	for i, batch in enumerate(input_ids):
 		input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64)
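A standalone sketch of the ID arithmetic this hunk implements, using the defaults from the signature above (the helper names are made up for illustration; prompt codes start after the text vocabulary, response codes after all prompt sub-vocabularies):

    TEXT_TOKENS = 256
    AUDIO_TOKENS = 1024

    def folded_prom_id_interleaved(code, position, rvq_levels):
        # interleaved mode: the position inside the flattened prompt picks the level's sub-vocabulary
        return code + TEXT_TOKENS + AUDIO_TOKENS * (position % rvq_levels)

    def folded_prom_id_per_level(code, quant_level):
        # per-level mode: the whole sequence lives in one quant level's sub-vocabulary
        return code + TEXT_TOKENS + AUDIO_TOKENS * quant_level

    print(folded_prom_id_interleaved(code=17, position=5, rvq_levels=4))  # 17 + 256 + 1024*1 = 1297
    print(folded_prom_id_per_level(code=17, quant_level=2))               # 17 + 256 + 1024*2 = 2321
    # response codes use offset = TEXT_TOKENS + AUDIO_TOKENS * rvq_levels instead of TEXT_TOKENS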
@@ -99,6 +120,7 @@ def fold_inputs(
 	return list_to_tensor(input_ids)

 # unfold from one unified token ID space to separate token spaces
+# to-do: unfold at a specific RVQ level instead if requested
 def unfold_outputs(
 	output_ids,

@@ -107,7 +129,8 @@ def unfold_outputs(
 	text_tokens = 256,
 	audio_tokens = 1024,
-	audio_rvq_levels = cfg.model.max_levels
+	audio_rvq_levels = cfg.model.max_levels,
+	quant_levels = None,
 ):
 	device = output_ids.device
 	batch_size = output_ids.shape[0]
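Call-site sketch for the matching keyword on the decode side, as used by sample() further down (output being whatever model.generate() returned):

    # Interleaved mode: recover (timesteps, levels) codes from one interleaved sequence.
    resp_list = unfold_outputs( output )["resp_list"]

    # Per-level mode: tell the unfolder which quant level each batch entry was folded at.
    resp_list = unfold_outputs( output, quant_levels=[ 0 ] )["resp_list"]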
@@ -129,30 +152,33 @@ def unfold_outputs(
 			elif text_tokens + (audio_tokens * audio_rvq_levels) <= id:
 				resp_list[i].append( (id - text_tokens) % audio_tokens )

-		prom_len = len(prom_list[i])
-		if prom_len % audio_rvq_levels == 0 and False:
-			prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t()
+		if quant_levels is not None:
+			prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=torch.int64)
+			resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=torch.int64)
 		else:
-			bins = [ [] for _ in range(audio_rvq_levels) ]
-			for pos in range( prom_len ):
-				rvq = pos % audio_rvq_levels
-				bins[rvq].append( prom_list[i][pos] )
-			nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
-			bins = bins[:nearest]
-			prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
+			prom_len = len(prom_list[i])
+			if prom_len % audio_rvq_levels == 0 and False:
+				prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t()
+			else:
+				bins = [ [] for _ in range(audio_rvq_levels) ]
+				for pos in range( prom_len ):
+					rvq = pos % audio_rvq_levels
+					bins[rvq].append( prom_list[i][pos] )
+				nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
+				bins = bins[:nearest]
+				prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)

-		resp_len = len(resp_list[i])
-		if len(resp_list[i]) % audio_rvq_levels == 0 and False:
-			resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t()
-		else:
-			bins = [ [] for _ in range(audio_rvq_levels) ]
-			for pos in range( resp_len ):
-				rvq = pos % audio_rvq_levels
-				bins[rvq].append( resp_list[i][pos] )
-			nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
-			bins = bins[:nearest]
-			resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
+			resp_len = len(resp_list[i])
+			if len(resp_list[i]) % audio_rvq_levels == 0 and False:
+				resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t()
+			else:
+				bins = [ [] for _ in range(audio_rvq_levels) ]
+				for pos in range( resp_len ):
+					rvq = pos % audio_rvq_levels
+					bins[rvq].append( resp_list[i][pos] )
+				nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
+				bins = bins[:nearest]
+				resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)

 		text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=torch.int64)
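What the binning in the `quant_levels is None` branch does, as a standalone toy (the real code runs this per batch over the accumulated prom/resp id lists):

    import torch

    audio_rvq_levels = 4
    flat = list(range(12))  # interleaved stream: level 0, 1, 2, 3, 0, 1, 2, 3, ...

    bins = [ [] for _ in range(audio_rvq_levels) ]
    for pos, tok in enumerate(flat):
        bins[pos % audio_rvq_levels].append(tok)  # bucket each token by its RVQ level

    codes = torch.tensor(bins).t()  # transpose to (timesteps, levels)
    print(codes.shape)  # torch.Size([3, 4])
    print(codes[0])     # tensor([0, 1, 2, 3]) -> the four levels of the first frame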
@@ -158,7 +158,22 @@ class Model(LlmArchClass):

 			self.backbone.gradient_checkpointing = gradient_checkpointing

+	def generate(
+		self,
+		*args,
+		**kwargs
+	):
+		if SELECTED_ARCH == "mamba":
+			kwargs["cg"] = True
+
+			if "attention_mask" in kwargs:
+				kwargs.pop("attention_mask")
+
+			if "do_sample" in kwargs:
+				kwargs.pop("do_sample")
+
+		return super().forward(*args, **kwargs)

 	def forward(
 		self,
 		*args,
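The generate() override is essentially defensive kwarg scrubbing before deferring to the parent class; a generic, standalone illustration of the pattern (names invented, not the module's API):

    def scrub_kwargs(kwargs, unsupported=("attention_mask", "do_sample")):
        # drop arguments the selected backend does not accept before delegating
        for key in unsupported:
            kwargs.pop(key, None)
        return kwargs

    print(scrub_kwargs({"input_ids": [1, 2, 3], "do_sample": False}))  # {'input_ids': [1, 2, 3]}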
@@ -239,13 +254,9 @@ def example_usage():
 	proms_list = proms_list[:1]
 	resps_list = resps_list[:1]

-	input_ids, attention_mask = fold_inputs(text_list, proms_list, resps_list)
-	target_ids, target_attention_mask = fold_inputs(text_list, proms_list, resps_list, ignore_index=-100)
-	prefix_input_ids, prefix_attention_mask = fold_inputs(text_list, proms_list)
-
 	kwargs = {}
 	model = Model(**kwargs).to(device)
-	steps = 50
+	steps = 50 if cfg.model.interleave else 250

 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
@@ -312,15 +323,46 @@ def example_usage():
 	print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

 	@torch.inference_mode()
-	def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*60 ):
+	def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):
 		engine.eval()
-		if SELECTED_ARCH == "mamba":
-			output = model.generate(input_ids=prefix_input_ids, cg=True, max_length=steps, eos_token_id=3)
+
+		target_length = 0
+		resp_list = None
+		if cfg.model.interleave:
+			input_ids, attention_mask = fold_inputs(text_list, proms_list)
+			output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
+
+			unfolded = unfold_outputs( output )
+			resp_list = unfolded["resp_list"]
 		else:
-			output = model.generate(input_ids=prefix_input_ids, attention_mask=prefix_attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
+			resp_list = [ [] for _ in range(len(text_list)) ]
+			for l in range(cfg.model.max_levels):
+				quant_levels = [ l ]
+				input_ids, attention_mask = fold_inputs(text_list, proms_list, quant_levels=quant_levels)
+				min_length = len(input_ids[0])

-		unfolded = unfold_outputs( output )
-		for i, batch in enumerate(unfolded["resp_list"]):
+				output = model.generate(
+					input_ids=input_ids,
+					attention_mask=attention_mask,
+					min_length=min_length+(steps if l > 0 else 0),
+					max_length=min_length+steps,
+					eos_token_id=3 if l == 0 else None ,
+					do_sample=False
+				)
+
+				unfolded = unfold_outputs( output, quant_levels=quant_levels )
+
+				if l == 0:
+					steps = 0
+
+				for batch, resp in enumerate(unfolded["resp_list"]):
+					if l == 0:
+						steps = max( steps, resp.shape[0] )
+					resp_list[batch].append( resp )
+
+			for i, resp in enumerate( resp_list ):
+				resp_list[i] = torch.stack( resp ).t()
+
+		for i, batch in enumerate(resp_list):
 			_ = decode_to_file(batch.to(device=device), f"data/{SELECTED_ARCH}.{cfg.audio_backend}.{i}.{name}.wav", device=device)

 	unload_model()
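The shape of the non-interleaved decode loop above, with the model stubbed out (purely illustrative; the real loop folds inputs per level and calls model.generate() / unfold_outputs()):

    import torch

    MAX_LEVELS = 4

    def fake_generate(level, length=None):
        # stand-in for generate + unfold: level 0 picks its own length, later levels must match it
        n = 5 if length is None else length
        return torch.full((n,), level)

    steps = 0
    per_level = []
    for l in range(MAX_LEVELS):
        resp = fake_generate(l, length=None if l == 0 else steps)
        if l == 0:
            steps = resp.shape[0]  # level 0 decides how many frames the clip has
        per_level.append(resp)

    codes = torch.stack(per_level).t()  # (timesteps, rvq levels)
    print(codes.shape)  # torch.Size([5, 4])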
@@ -330,6 +372,13 @@ def example_usage():
 	t = trange(steps)
 	for i in t:
 		stats = {"step": i}
+
+		batch_size = len(text_list)
+		quant_levels = None if cfg.model.interleave else torch.randint(0, cfg.model.max_levels, (batch_size,))
+
+		input_ids, attention_mask = fold_inputs(text_list, proms_list, resps_list, quant_levels=quant_levels)
+		target_ids, target_attention_mask = fold_inputs(text_list, proms_list, resps_list, ignore_index=-100, quant_levels=quant_levels)
+
 		if SELECTED_ARCH == "mamba":
 			stats |= engine.traverse(input_ids=input_ids, labels=target_ids)
 		else:
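Why the targets are folded a second time with ignore_index=-100: -100 is the default ignore_index for cross entropy, so positions filled with it (the prompt span above) contribute nothing to the loss while the response span is trained normally. A toy illustration:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(5, 8)                    # 5 positions, vocab of 8
    labels = torch.tensor([-100, -100, 3, 1, 2])  # prompt positions masked, response positions kept

    loss = F.cross_entropy(logits, labels, ignore_index=-100)  # only the last three positions count
    print(loss)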