2024-12-23 02:11:31 +00:00
# this is a VERY rudimentary script to test if a HF-ified model works (it sort of does)
2024-12-22 04:52:10 +00:00
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
2024-12-23 02:11:31 +00:00
from torch.distributions import Categorical
2024-12-22 04:52:10 +00:00
# Directory holding the HF-format export of the model under test.
checkpoint_dir = "./training/llama-encodec-ar+nar-len/hf/"

# The tokenizer is not needed here: the script feeds hand-written token ids.
# tokenizer = LlamaTokenizer.from_pretrained(checkpoint_dir)
model = LlamaForCausalLM.from_pretrained(checkpoint_dir)
# Run on the GPU in bfloat16.
model = model.to(device="cuda", dtype=torch.bfloat16)

# Selects which branch of the script runs; the branches below handle
# "len", "ar" and "nar".
mode = "nar"
# Hard-coded input token ids for the first prompt segment
# (presumably phoneme ids, going by the name -- confirm).
phn = [1, 22, 111, 100, 4, 37, 115, 169, 11, 2]
2024-12-22 04:52:10 +00:00
2024-12-23 02:11:31 +00:00
prom = [
resp = [
2024-12-22 04:52:10 +00:00
2024-12-23 02:11:31 +00:00
# One-element token-id lists used when the prompt is assembled:
#   sep     -- appended after every prompt segment
#   rvq_lvl -- forms its own prompt segment (presumably the RVQ-level marker -- confirm)
#   lang    -- forms its own prompt segment (presumably the language marker -- confirm)
sep, rvq_lvl, lang = [291], [256], [264]
2024-12-22 04:52:10 +00:00
2024-12-23 02:11:31 +00:00
# Shift every prompt code in place: a base of 292 plus 1024 per row index
# (presumably mapping raw codebook indices into the model's token-id range,
# one 1024-wide band per level -- confirm).
for level, row in enumerate(prom):
    shift = 292 + (1024 * level)
    row[:] = [code + shift for code in row]
2024-12-22 04:52:10 +00:00
2024-12-23 02:11:31 +00:00
# Shift every response code in place: a base of 9509 plus 1024 per row index
# (presumably mapping raw codebook indices into the model's token-id range,
# one 1024-wide band per level -- confirm).
for level, row in enumerate(resp):
    shift = 9509 + (1024 * level)
    row[:] = [code + shift for code in row]
2024-12-22 04:52:10 +00:00
2024-12-23 02:11:31 +00:00
# Assemble the prompt as two parallel 1-D tensors:
#   ids     -- token ids: each segment followed by the separator token(s)
#   pos_ids -- position ids, restarting at 0 for every segment; the separator
#              takes the position right after the segment's last token
# The accumulators are integer-typed from the start so token ids are never
# promoted to float while the pieces are concatenated.
ids = torch.tensor([], dtype=torch.long)
pos_ids = torch.tensor([], dtype=torch.long)

# Segment order: phn, lang, rvq_lvl, then the first row of prom.
for segment in (phn, lang, rvq_lvl, prom[0]):
    ids = torch.concat([ids, torch.tensor(segment + sep, dtype=torch.long)])
    pos_ids = torch.concat([pos_ids, torch.arange(len(segment) + len(sep))])
# Per-mode configuration consumed by the generation loop:
#   start, end -- column window applied to the logits; `start` is also added
#                 back to the predicted indices to recover token ids
#   stop       -- predicted index that ends generation (None: never checked)
#   max_n      -- maximum number of forward passes
#   outputs    -- how many trailing sequence positions are read from the logits
start, end, stop = (None, None, None)
if mode == "len":
    # Append one extra token (id 279) at position 0, then predict inside an
    # 11-wide window starting at that id; index 10 is the stop value.
    len_seq = [279]
    ids = torch.concat([ids, torch.tensor(len_seq)])
    pos_ids = torch.concat([pos_ids, torch.arange(len(len_seq))])
    start, end, stop = (279, 279 + 11, 10)
    max_n = 10
    outputs = 1
elif mode == "ar":
    # Nothing is appended; predict one token per step inside a 1025-wide
    # window, where index 1024 is the stop value.
    start, end, stop = (8484, 8484 + 1025, 1024)
    max_n = 350
    outputs = 1
elif mode == "nar":
    # Append the first row of resp (positions restart at 0) and read one
    # prediction per appended token in a single forward pass.
    ids = torch.concat([ids, torch.tensor(resp[0])])
    pos_ids = torch.concat([pos_ids, torch.arange(len(resp[0]))])
    start, end, stop = (9509, 9509 + 1024, None)
    max_n = 1
    outputs = len(resp[0])
else:
    # Fail here rather than with a NameError on max_n / outputs further down.
    raise ValueError(f"unsupported mode: {mode!r}")
# Move the assembled sequence to the GPU as int32.
ids = ids.to(device="cuda", dtype=torch.int32)
pos_ids = pos_ids.to(device="cuda", dtype=torch.int32)
# All-True mask with one entry per token, created on the same device as `ids`
# so every tensor passed to the model lives on one device.
attention_mask = torch.ones(ids.shape[0], dtype=torch.bool, device=ids.device)
# Generation loop: run up to `max_n` forward passes. In single-output modes
# each predicted token is appended to the sequence before the next pass.
# NOTE(review): the nesting below is reconstructed from the statements'
# dependencies; confirm it against the intended control flow.
n = 0
with torch.no_grad():
    while n < max_n:
        if n == 0:
            # One-off debug dump of the initial sequence:
            # index, token id, sum of its embedding vector, position id.
            embs = model.model.embed_tokens(ids)
            for i, emb in enumerate(embs):
                print(i, ids[i].item(), sum(emb).item(), pos_ids[i].item())

        # The model takes a batch dimension, so every tensor is unsqueezed.
        out = model(
            input_ids=ids.unsqueeze(0),
            position_ids=pos_ids.unsqueeze(0),
            attention_mask=attention_mask.unsqueeze(0),
        )
        # Keep the last `outputs` positions, restricted to the mode's window.
        logits = out.logits[0, -outputs:, start:end]

        # Sample in "ar" mode; take the most likely index otherwise.
        if mode == "ar":
            tokens = Categorical(logits=logits).sample()
        else:
            tokens = logits.argmax(dim=-1)

        n += 1
        print(n, tokens)

        if outputs == 1:
            # NOTE(review): `break` is the presumed action on the stop value;
            # confirm it matches the intended behavior.
            if stop in tokens:
                break

            # `tokens` are window-relative indices; adding `start` restores
            # the token id before it is appended.
            ids = torch.concat([ids, tokens + start])
            pos_ids = torch.concat([pos_ids, torch.tensor([n]).to(pos_ids)])
            attention_mask = torch.concat([attention_mask, torch.tensor([True]).to(attention_mask)])

print(out)
print(ids)
print(pos_ids)