refactor cleanup, had a revelation on how I can handle a batch of varying tasks

This commit is contained in:
mrq 2024-04-16 21:04:48 -05:00
parent 467fa1c5ee
commit b0bd88833c
3 changed files with 126 additions and 84 deletions
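
The gist of the change: rather than threading parallel text/prom/resp keyword lists through Base.forward, each sample in a batch becomes an ordered list of (name, tensor) tuples, so samples performing different tasks can carry different segments through the same batch. A minimal standalone sketch of that representation (toy tensors and sizes are hypothetical, not from the repo):

import torch

# one training sample (carries a "targ" segment) and one inference-only sample
text = torch.tensor([1, 2, 3])              # phoneme tokens
prom = torch.randint(0, 1024, (4, 8))       # prompt: 4 frames x 8 RVQ levels
resp = torch.randint(0, 1024, (4, 1))       # AR response, 1 level
targ = torch.tensor([9, 9, 9, 9])           # loss targets

inputs = [
	[("text", text), ("prom", prom), ("resp", resp), ("targ", targ)],
	[("text", text), ("prom", prom), ("resp", resp)],
]

# consumers dispatch on the segment name, per sample
for sample in inputs:
	print([name for name, _ in sample])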

View File

@@ -9,8 +9,7 @@ import time
 import torch
 
-from dataclasses import asdict, dataclass
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from functools import cached_property
 from pathlib import Path

View File

@@ -161,13 +161,17 @@ class AR_NAR(Base):
 					resps_list[i] = torch.cat([resps_list[i], torch.Tensor([[self.stop_token] * n_levels]).to(device=device, dtype=torch.int16) ])
 					targ_list[i] = torch.cat([targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
 
-			return super().forward(
+			inputs = self.inputs(
 				text_list=text_list,
 				proms_list=proms_list,
 				resps_list=resps_list,
 				targ_list=targ_list,
 				lang_list=lang_list,
-				tone_list=tone_list,
+				tone_list=tone_list
+			)
+
+			return super().forward(
+				inputs=inputs,
 				quant_levels=quant_levels,
 			)
 
 		# is NAR
@@ -183,12 +187,16 @@ class AR_NAR(Base):
 				quant_levels = torch.full((len(text_list),), level)
 
-				logits = super().forward(
+				inputs = self.inputs(
 					text_list=text_list,
 					proms_list=proms_list,
 					resps_list=prev_list,
 					lang_list=lang_list,
 					tone_list=tone_list,
+				)
+
+				logits = super().forward(
+					inputs=inputs,
 					quant_levels=quant_levels,
 				)
@@ -235,23 +243,23 @@ class AR_NAR(Base):
 			else:
 				resps_list = self._unsqueeze_list(sequence_list)
 
+			inputs = self.inputs(
+				text_list=text_list,
+				proms_list=proms_list,
+				resps_list=resps_list,
+				lang_list=lang_list,
+				tone_list=tone_list,
+			)
+
 			if recurrent_state is not None:
 				logits, recurrent_state = super().forward(
-					text_list=text_list,
-					proms_list=proms_list,
-					resps_list=resps_list,
-					lang_list=lang_list,
-					tone_list=tone_list,
-					state=recurrent_state
+					inputs=inputs,
+					state=recurrent_state,
 				)
 			else:
 				logits = super().forward(
-					text_list=text_list,
-					proms_list=proms_list,
-					resps_list=resps_list,
-					lang_list=lang_list,
-					tone_list=tone_list,
-					state=recurrent_state
+					inputs=inputs,
+					state=recurrent_state,
 				)
 
 			r = super().sample(
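
With this, the sampling loop builds `inputs` once and reuses it for both the stateful (recurrent) and stateless calls, instead of repeating six keyword lists per branch. A rough sketch of the pattern, with a hypothetical stub standing in for the real model:

import torch

class StubModel:
	# hypothetical stand-ins mirroring the shapes of Base.inputs/forward
	def inputs(self, text_list, proms_list, resps_list):
		return [
			[("text", t), ("prom", p), ("resp", r)]
			for t, p, r in zip(text_list, proms_list, resps_list)
		]
	def forward(self, inputs, state=None):
		return len(inputs), state  # placeholder "logits" and state

model = StubModel()
inputs = model.inputs(
	text_list=[torch.tensor([1, 2])],
	proms_list=[torch.zeros(3, 8, dtype=torch.long)],
	resps_list=[torch.zeros(3, 1, dtype=torch.long)],
)
# the same `inputs` serves both the stateful and the stateless call
logits, state = model.forward(inputs=inputs, state={})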

View File

@@ -10,6 +10,7 @@ from functools import partial
 
 from einops import rearrange
 from torch import Tensor, einsum, nn
+from torch.nn import Embedding
 from torch.distributions import Categorical
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.checkpoint import checkpoint
@@ -165,11 +166,13 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
 	return x, m
 
 # automagically parses a batch-list and returns it as a list
+"""
 class Embedding(nn.Embedding):
 	def forward(self, x_list: list[Tensor]) -> list[Tensor]:
 		if len(x_list) == 0:
 			return []
 
 		return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
+"""
 
 class MultiEmbedding(nn.Module):
 	"""
@@ -218,22 +221,18 @@ class AudioEmbedding(nn.Module):
 		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
 		self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
 
-	def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
-		res_list = []
-
-		for i, xi in enumerate(x_list):
-			# prom
-			if quant_levels is None and xi.shape[-1] > 1:
-				x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
-			# AR resp
-			elif quant_levels is None or quant_levels[i] == 0:
-				x = self.embeddings[0]( xi[:, 0] )
-			# NAR resp
-			else:
-				x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
-
-			res_list.append(x)
-
-		return res_list
+	def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
+		# prom
+		if quant_levels is None and xi.shape[-1] > 1:
+			x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
+		# AR resp
+		elif quant_levels is None or quant_levels == 0:
+			x = self.embeddings[0]( xi[:, 0] )
+		# NAR resp
+		else:
+			x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
+
+		return x
 
 class Base(nn.Module):
 	@property
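
AudioEmbedding.forward now embeds one sample at a time: a [t, levels] code tensor becomes a [t, dim] embedding by summing one nn.Embedding per RVQ level (optionally scaled by a learned per-level weight). The prompt path in isolation, with toy sizes:

import torch
from torch import nn

n_tokens, token_dim, levels = 1024, 16, 8
embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(levels)])

xi = torch.randint(0, n_tokens, (4, levels))  # one sample: 4 frames x 8 RVQ levels
x = sum(embeddings[k](xi[:, k]) for k in range(xi.shape[-1]))
print(x.shape)  # torch.Size([4, 16])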
@@ -302,17 +301,6 @@ class Base(nn.Module):
 	def ignore_index(self):
 		return -100
 
-	@staticmethod
-	def _samplewise_merge_tensors(*l, sep: Tensor | None):
-		if sep is None:
-			cat = torch.cat
-		else:
-			cat = partial(_join, sep=sep)
-		l = [ x for x in l if x is not None ]
-		return [*map(cat, zip(*l))]
-
 	def __init__(
 		self,
 		n_tokens: int = 1024,
@@ -638,51 +626,104 @@ class Base(nn.Module):
 		return x, state, aux_loss
 
-	def forward(
+	def inputs(
 		self,
 		text_list: list[Tensor],
 		proms_list: list[Tensor],
 		resps_list: list[Tensor],
 		targ_list: list[Tensor] | None = None,
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
-		quant_levels: Tensor | None = None,
-		state: dict | list | None = None,
 	):
 		device = text_list[0].device
 		batch_size = len(text_list)
 
-		# silently ignore languages if model does not have it
-		if self.langs_emb is None:
-			lang_list = None
-		# inject default language
-		elif lang_list is None:
-			lang_list = [ torch.Tensor([ 0 ]).to(dtype=torch.uint8, device=device) for _ in range(batch_size) ]
-
-		# silently ignore tones if model does not have it
-		if self.tones_emb is None:
-			tone_list = None
-		# inject default tone
-		elif tone_list is None:
-			tone_list = [ torch.Tensor([ 0 ]).to(dtype=torch.uint8, device=device) for _ in range(batch_size) ]
-
-		"""
-		# Typical sequence format
-		# To-do: integrate tasks again
-		<s><text></s><sep><lang><sep><prom><sep><tone><sep><resp><stop>
-		"""
-		x_list = self._samplewise_merge_tensors(
-			self.text_emb(text_list),
-			self.langs_emb(lang_list) if lang_list is not None else None,
-			self.proms_emb(proms_list),
-			self.tones_emb(tone_list) if tone_list is not None else None,
-			self.resps_emb(resps_list, quant_levels),
-			sep=self.sep,
-		)
+		inputs = [ [] for _ in range(batch_size) ]
+		for i in range(batch_size):
+			if text_list is not None:
+				inputs[i].append( ( "text", text_list[i] ) )
+			if proms_list is not None:
+				inputs[i].append( ( "prom", proms_list[i] ) )
+			if resps_list is not None:
+				inputs[i].append( ( "resp", resps_list[i] ) )
+			if targ_list is not None:
+				inputs[i].append( ( "targ", targ_list[i] ) )
+
+		return inputs
+
+	def inputs_to_embeddings(
+		self,
+		inputs: list,
+		quant_levels: Tensor | None = None
+	):
+		x_list = []
+		for b_i in range(len(inputs)):
+			batch = []
+			for i in range(len(inputs[b_i])):
+				name, input = inputs[b_i][i]
+				embedding = None
+				if name == "text":
+					embedding = self.text_emb( input )
+				elif name == "lang":
+					embedding = self.langs_emb( input )
+				elif name == "prom":
+					embedding = self.proms_emb( input )
+				elif name == "tone":
+					embedding = self.tones_emb( input )
+				elif name == "resp":
+					embedding = self.resps_emb( input, quant_levels[b_i] if quant_levels is not None else None )
+				else:
+					continue
+
+				batch.append(embedding)
+
+			x_list.append( _join( batch, self.sep ) )
+
+		return x_list
+
+	def training_targets(
+		self,
+		inputs: list,
+	):
+		x_list = []
+		for bi in range(len(inputs)):
+			batch = []
+			for i in range(len(inputs[bi])):
+				name, input = inputs[bi][i]
+				device = input.device
+
+				if name == "prom":
+					batch.append( torch.full_like(input[..., 0], self.ignore_index) )
+				elif name in ["text", "lang", "tone", "targ"]:
+					batch.append( input )
+
+			x_list.append( _join( batch, torch.tensor(self.ignore_index, device=device) ) )
+
+		return x_list
+
+	def forward(
+		self,
+		inputs: list,
+		quant_levels: Tensor | None = None,
+		state: dict | list | None = None,
+	):
+		x_list = self.inputs_to_embeddings( inputs, quant_levels )
 		x, m = list_to_tensor(x_list)
 
+		# yes, there's a better way.
+		training = False
+		for b_i in range(len(inputs)):
+			for i in range(len(inputs[b_i])):
+				name, input = inputs[b_i][i]
+				if name == "targ":
+					training = True
+
+		device = x.device
+		batch_size = len(x_list)
+
 		# pad our input and mask, but retain the original length by doing it after
 		if self.l_padding and x.shape[1] % self.l_padding != 0:
@@ -709,15 +750,9 @@ class Base(nn.Module):
 		logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
 
 		# compute loss if the target is given
-		if targ_list is not None:
-			target_list = self._samplewise_merge_tensors(
-				text_list,
-				lang_list,
-				[ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ], # create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not needed for computing the loss against
-				targ_list,
-				sep=torch.tensor(self.ignore_index, device=device)
-			)
+		if training:
+			target_list = self.training_targets( inputs )
 
 			# modify only for the AR so it can properly behave like a transformer
 			for i in range(len(target_list)):
 				if quant_levels is not None and quant_levels[i] > 0:
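
For the loss, training_targets plays the role the removed _samplewise_merge_tensors played: prompt positions are replaced with ignore_index so the cross-entropy skips them, and segments are joined with an ignore_index separator. A sketch of that masking, assuming _join concatenates segments with a separator token between them:

import torch

IGNORE_INDEX = -100

def join(tensors, sep):
	# assumed _join semantics: concatenate with a 0-dim separator between segments
	out = [tensors[0]]
	for t in tensors[1:]:
		out += [sep[None], t]
	return torch.cat(out)

text = torch.tensor([5, 6, 7])
prom = torch.randint(0, 1024, (4, 8))   # prompt: 4 frames x 8 RVQ levels
targ = torch.tensor([9, 9, 9, 9])       # response targets

sep = torch.tensor(IGNORE_INDEX)
target = join([text, torch.full_like(prom[..., 0], IGNORE_INDEX), targ], sep)
print(target)  # text tokens, -100 over the prompt span, then the targets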