refactor cleanup, had a revelation on how I can handle a batch of varying tasks

This commit is contained in:
mrq 2024-04-16 21:04:48 -05:00
parent 467fa1c5ee
commit b0bd88833c
3 changed files with 126 additions and 84 deletions
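
The gist of the refactor: rather than feeding `text_list` / `proms_list` / `resps_list` and friends straight into `Base.forward`, each sample is first assembled into an ordered list of `(name, tensor)` entries by a new `inputs()` method, and downstream code dispatches on the name. A batch of varying tasks then just becomes a batch of varying entry lists. A rough sketch of the structure (placeholder tensors; the entry names are the ones the diff uses):

import torch

# hypothetical batch of two samples: sample 0 carries a training target ("targ"),
# sample 1 does not, so each sample declares its own sequence of tagged segments
text, prom, resp, targ = (torch.zeros(3, dtype=torch.long),) * 4  # placeholders
inputs = [
    [("text", text), ("prom", prom), ("resp", resp), ("targ", targ)],
    [("text", text), ("prom", prom), ("resp", resp)],
]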

View File

@@ -9,8 +9,7 @@ import time
import torch
from dataclasses import asdict, dataclass
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from functools import cached_property
from pathlib import Path

View File

@@ -161,13 +161,17 @@ class AR_NAR(Base):
resps_list[i] = torch.cat([resps_list[i], torch.Tensor([[self.stop_token] * n_levels]).to(device=device, dtype=torch.int16) ])
targ_list[i] = torch.cat([targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
return super().forward(
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
targ_list=targ_list,
lang_list=lang_list,
tone_list=tone_list,
tone_list=tone_list
)
return super().forward(
inputs=inputs,
quant_levels=quant_levels,
)
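
For training, each sample's response and target first get the stop token appended (per the lines above), after which the tagged inputs are built and handed to `Base.forward`. A minimal runnable sketch of just the stop-token padding, with placeholder shapes and values:

import torch

stop_token, n_levels = 1024, 8  # placeholder values
resp = torch.zeros((4, n_levels), dtype=torch.int16)  # (t, RVQ levels)
targ = torch.zeros((4,), dtype=torch.int16)           # (t,)
# append one stop row across all levels to the response, one stop token to the target
resp = torch.cat([resp, torch.full((1, n_levels), stop_token, dtype=torch.int16)])
targ = torch.cat([targ, torch.tensor([stop_token], dtype=torch.int16)])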
# is NAR
@@ -183,12 +187,16 @@ class AR_NAR(Base):
quant_levels = torch.full((len(text_list),), level)
logits = super().forward(
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=prev_list,
lang_list=lang_list,
tone_list=tone_list,
)
logits = super().forward(
inputs=inputs,
quant_levels=quant_levels,
)
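
Worth noting: a NAR pass handles exactly one RVQ level for the entire batch, so `quant_levels` is just a filled tensor, e.g.:

import torch

batch_size, level = 4, 2
quant_levels = torch.full((batch_size,), level)  # every sample gets level 2 this pass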
@@ -235,23 +243,23 @@ class AR_NAR(Base):
else:
resps_list = self._unsqueeze_list(sequence_list)
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
)
if recurrent_state is not None:
logits, recurrent_state = super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
state=recurrent_state
inputs=inputs,
state=recurrent_state,
)
else:
logits = super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
state=recurrent_state
inputs=inputs,
state=recurrent_state,
)
r = super().sample(
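
Both inference branches now share one `inputs` construction and differ only in whether a recurrent state is captured back out; a sketch of the resulting call pattern (the surrounding names come from the hunk, the model instance is hypothetical):

inputs = model.inputs(
    text_list=text_list,
    proms_list=proms_list,
    resps_list=resps_list,
    lang_list=lang_list,
    tone_list=tone_list,
)
if recurrent_state is not None:
    logits, recurrent_state = model.forward(inputs=inputs, state=recurrent_state)
else:
    logits = model.forward(inputs=inputs, state=recurrent_state)  # state is None here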

View File

@@ -10,6 +10,7 @@ from functools import partial
from einops import rearrange
from torch import Tensor, einsum, nn
from torch.nn import Embedding
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
@@ -165,11 +166,13 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
return x, m
# automagically parses a batch-list and returns it as a list
"""
class Embedding(nn.Embedding):
def forward(self, x_list: list[Tensor]) -> list[Tensor]:
if len(x_list) == 0:
return []
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
"""
class MultiEmbedding(nn.Module):
"""
@@ -218,22 +221,18 @@ class AudioEmbedding(nn.Module):
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
res_list = []
for i, xi in enumerate(x_list):
# prom
if quant_levels is None and xi.shape[-1] > 1:
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
# AR resp
elif quant_levels is None or quant_levels[i] == 0:
x = self.embeddings[0]( xi[:, 0] )
# NAR resp
else:
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
res_list.append(x)
return res_list
def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
# prom
if quant_levels is None and xi.shape[-1] > 1:
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
# AR resp
elif quant_levels is None or quant_levels == 0:
x = self.embeddings[0]( xi[:, 0] )
# NAR resp
else:
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
return x
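
`AudioEmbedding.forward` now embeds one sample at a time (a `(t, levels)` code tensor plus that sample's quantizer level) instead of looping over the whole batch; the batch loop moved up into `inputs_to_embeddings`. Hypothetical usage under the new signature, with `audio_emb` an `AudioEmbedding` instance:

prom_emb = audio_emb(prom)                                    # prom (t, 8): summed across all levels
ar_emb   = audio_emb(resp_ar, quant_levels=None)              # resp_ar (t, 1): RVQ level 0 only
nar_emb  = audio_emb(resp_nar, quant_levels=torch.tensor(2))  # NAR: level-offset tables summed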
class Base(nn.Module):
@property
@@ -302,17 +301,6 @@ class Base(nn.Module):
def ignore_index(self):
return -100
@staticmethod
def _samplewise_merge_tensors(*l, sep: Tensor | None):
if sep is None:
cat = torch.cat
else:
cat = partial(_join, sep=sep)
l = [ x for x in l if x is not None ]
return [*map(cat, zip(*l))]
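
The removed `_samplewise_merge_tensors` zipped the per-component lists together and joined each sample's pieces with a separator; that join now happens per sample via `_join` inside `inputs_to_embeddings`. Its effect, restated (a sketch, assuming `_join(tensors, sep)` concatenates with `sep` between entries, as the old `partial(_join, sep=sep)` usage implies):

def samplewise_merge(*lists, sep):
    # drop absent components, then join each sample's tensors with the separator
    lists = [l for l in lists if l is not None]
    return [_join(xs, sep) for xs in zip(*lists)]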
def __init__(
self,
n_tokens: int = 1024,
@@ -638,51 +626,104 @@ class Base(nn.Module):
return x, state, aux_loss
def forward(
def inputs(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor],
targ_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
quant_levels: Tensor | None = None,
state: dict | list | None = None,
):
device = text_list[0].device
batch_size = len(text_list)
# silently ignore languages if the model does not have the embedding
if self.langs_emb is None:
lang_list = None
# inject default language
elif lang_list is None:
lang_list = [ torch.Tensor([ 0 ]).to(dtype=torch.uint8, device=device) for _ in range(batch_size) ]
# silently ignore tones if the model does not have the embedding
if self.tones_emb is None:
tone_list = None
# inject default tone
elif tone_list is None:
tone_list = [ torch.Tensor([ 0 ]).to(dtype=torch.uint8, device=device) for _ in range(batch_size) ]
inputs = [ [] for _ in range(batch_size) ]
for i in range(batch_size):
if text_list is not None:
inputs[i].append( ( "text", text_list[i] ) )
if proms_list is not None:
inputs[i].append( ( "prom", proms_list[i] ) )
if resps_list is not None:
inputs[i].append( ( "resp", resps_list[i] ) )
if targ_list is not None:
inputs[i].append( ( "targ", targ_list[i] ) )
"""
# Typical sequence format
# To-do: integrate tasks again
<s><text></s><sep><lang><sep><prom><sep><tone><sep><resp><stop>
"""
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.langs_emb(lang_list) if lang_list is not None else None,
self.proms_emb(proms_list),
self.tones_emb(tone_list) if tone_list is not None else None,
self.resps_emb(resps_list, quant_levels),
sep=self.sep,
)
return inputs
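
One wrinkle as of this commit: `inputs()` resolves default `lang_list` / `tone_list` values but never appends "lang" or "tone" entries, even though `inputs_to_embeddings` already dispatches on those names. Presumably they get wired in like the other components, along these (hypothetical) lines:

# hypothetical continuation of the per-sample loop, mirroring text/prom/resp
if lang_list is not None:
    inputs[i].append( ( "lang", lang_list[i] ) )
if tone_list is not None:
    inputs[i].append( ( "tone", tone_list[i] ) )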
def inputs_to_embeddings(
self,
inputs: list,
quant_levels: Tensor | None = None
):
x_list = []
for b_i in range(len(inputs)):
batch = []
for i in range(len(inputs[b_i])):
name, input = inputs[b_i][i]
embedding = None
if name == "text":
embedding = self.text_emb( input )
elif name == "lang":
embedding = self.langs_emb( input )
elif name == "prom":
embedding = self.proms_emb( input )
elif name == "tone":
embedding = self.tones_emb( input )
elif name == "resp":
embedding = self.resps_emb( input, quant_levels[b_i] if quant_levels is not None else None )
else:
continue
batch.append(embedding)
x_list.append( _join( batch, self.sep ) )
return x_list
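
`inputs_to_embeddings` routes each tagged entry to its embedding and then joins the sample's pieces with the learned separator; the if/elif chain could equally be written as a lookup, e.g. (a hypothetical, behavior-equivalent restructuring):

def embed_entry(self, name, input, quant_level):
    # name → embedding dispatch; unknown names embed to None and are skipped
    table = {
        "text": lambda: self.text_emb(input),
        "lang": lambda: self.langs_emb(input),
        "prom": lambda: self.proms_emb(input),
        "tone": lambda: self.tones_emb(input),
        "resp": lambda: self.resps_emb(input, quant_level),
    }
    return table[name]() if name in table else None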
def training_targets(
self,
inputs: list,
):
x_list = []
for bi in range(len(inputs)):
batch = []
for i in range(len(inputs[bi])):
name, input = inputs[bi][i]
device = input.device
if name == "prom":
batch.append( torch.full_like(input[..., 0], self.ignore_index) )
elif name in ["text", "lang", "tone", "targ"]:
batch.append( input )
x_list.append( _join( batch, torch.tensor(self.ignore_index, device=device) ) )
return x_list
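
`training_targets` mirrors the input layout: prompt positions (and the separators between components) are filled with `ignore_index`, so the loss only sees text/lang/tone/targ tokens. A small worked example of the resulting target for one sample:

import torch

ignore = -100  # Base.ignore_index
text = torch.tensor([5, 6, 7])
prom = torch.zeros((4, 8), dtype=torch.long)      # (t, n_levels) prompt, fully masked
targ = torch.tensor([9, 9, 2])                    # response tokens, stop appended
sep  = torch.tensor([ignore])
target = torch.cat([text, sep, torch.full_like(prom[..., 0], ignore), sep, targ])
# target: [5, 6, 7, -100, -100, -100, -100, -100, -100, 9, 9, 2]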
def forward(
self,
inputs: list,
quant_levels: Tensor | None = None,
state: dict | list | None = None,
):
x_list = self.inputs_to_embeddings( inputs, quant_levels )
x, m = list_to_tensor(x_list)
# yes, there's a better way.
training = False
for b_i in range(len(inputs)):
for i in range(len(inputs[b_i])):
name, input = inputs[b_i][i]
if name == "targ":
training = True
device = x.device
batch_size = len(x_list)
# pad our input and mask, but retain the original length by doing it after
if self.l_padding and x.shape[1] % self.l_padding != 0:
@@ -709,15 +750,9 @@ class Base(nn.Module):
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
# compute loss if the target is given
if targ_list is not None:
target_list = self._samplewise_merge_tensors(
text_list,
lang_list,
[ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ], # create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not needed when computing the loss
targ_list,
sep=torch.tensor(self.ignore_index, device=device)
)
if training:
target_list = self.training_targets( inputs )
# modify only for the AR so it can properly behave like a transformer
for i in range(len(target_list)):
if quant_levels is not None and quant_levels[i] > 0:
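
Past the hunk's cutoff, the masked targets presumably feed a standard cross-entropy, where every `ignore_index` position contributes nothing; a minimal sketch of that final step (assumption: plain `F.cross_entropy`, matching the ignore_index convention above):

import torch
import torch.nn.functional as F

logits = torch.randn(12, 1024)            # (seq_len, n_tokens) for one sample
target = torch.full((12,), -100)
target[-3:] = torch.tensor([9, 9, 2])     # only the response span is scored
loss = F.cross_entropy(logits, target, ignore_index=-100)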