refractor cleanup, had a revelation on how I can handle a batch of varying tasks
This commit is contained in:
parent
467fa1c5ee
commit
b0bd88833c
|
@ -9,8 +9,7 @@ import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass, field
|
||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
|
@ -161,13 +161,17 @@ class AR_NAR(Base):
|
||||||
resps_list[i] = torch.cat([resps_list[i], torch.Tensor([[self.stop_token] * n_levels]).to(device=device, dtype=torch.int16) ])
|
resps_list[i] = torch.cat([resps_list[i], torch.Tensor([[self.stop_token] * n_levels]).to(device=device, dtype=torch.int16) ])
|
||||||
targ_list[i] = torch.cat([targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
|
targ_list[i] = torch.cat([targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
|
||||||
|
|
||||||
return super().forward(
|
inputs = self.inputs(
|
||||||
text_list=text_list,
|
text_list=text_list,
|
||||||
proms_list=proms_list,
|
proms_list=proms_list,
|
||||||
resps_list=resps_list,
|
resps_list=resps_list,
|
||||||
targ_list=targ_list,
|
targ_list=targ_list,
|
||||||
lang_list=lang_list,
|
lang_list=lang_list,
|
||||||
tone_list=tone_list,
|
tone_list=tone_list
|
||||||
|
)
|
||||||
|
|
||||||
|
return super().forward(
|
||||||
|
inputs=inputs,
|
||||||
quant_levels=quant_levels,
|
quant_levels=quant_levels,
|
||||||
)
|
)
|
||||||
# is NAR
|
# is NAR
|
||||||
|
@ -183,12 +187,16 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
quant_levels = torch.full((len(text_list),), level)
|
quant_levels = torch.full((len(text_list),), level)
|
||||||
|
|
||||||
logits = super().forward(
|
inputs = self.inputs(
|
||||||
text_list=text_list,
|
text_list=text_list,
|
||||||
proms_list=proms_list,
|
proms_list=proms_list,
|
||||||
resps_list=prev_list,
|
resps_list=prev_list,
|
||||||
lang_list=lang_list,
|
lang_list=lang_list,
|
||||||
tone_list=tone_list,
|
tone_list=tone_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = super().forward(
|
||||||
|
inputs=inputs,
|
||||||
quant_levels=quant_levels,
|
quant_levels=quant_levels,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -235,23 +243,23 @@ class AR_NAR(Base):
|
||||||
else:
|
else:
|
||||||
resps_list = self._unsqueeze_list(sequence_list)
|
resps_list = self._unsqueeze_list(sequence_list)
|
||||||
|
|
||||||
|
inputs = self.inputs(
|
||||||
|
text_list=text_list,
|
||||||
|
proms_list=proms_list,
|
||||||
|
resps_list=resps_list,
|
||||||
|
lang_list=lang_list,
|
||||||
|
tone_list=tone_list,
|
||||||
|
)
|
||||||
|
|
||||||
if recurrent_state is not None:
|
if recurrent_state is not None:
|
||||||
logits, recurrent_state = super().forward(
|
logits, recurrent_state = super().forward(
|
||||||
text_list=text_list,
|
inputs=inputs,
|
||||||
proms_list=proms_list,
|
state=recurrent_state,
|
||||||
resps_list=resps_list,
|
|
||||||
lang_list=lang_list,
|
|
||||||
tone_list=tone_list,
|
|
||||||
state=recurrent_state
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logits = super().forward(
|
logits = super().forward(
|
||||||
text_list=text_list,
|
inputs=inputs,
|
||||||
proms_list=proms_list,
|
state=recurrent_state,
|
||||||
resps_list=resps_list,
|
|
||||||
lang_list=lang_list,
|
|
||||||
tone_list=tone_list,
|
|
||||||
state=recurrent_state
|
|
||||||
)
|
)
|
||||||
|
|
||||||
r = super().sample(
|
r = super().sample(
|
||||||
|
|
|
@ -10,6 +10,7 @@ from functools import partial
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from torch import Tensor, einsum, nn
|
from torch import Tensor, einsum, nn
|
||||||
|
from torch.nn import Embedding
|
||||||
from torch.distributions import Categorical
|
from torch.distributions import Categorical
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
@ -165,11 +166,13 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
|
||||||
return x, m
|
return x, m
|
||||||
|
|
||||||
# automagically parses a batch-list and returns it as a list
|
# automagically parses a batch-list and returns it as a list
|
||||||
|
"""
|
||||||
class Embedding(nn.Embedding):
|
class Embedding(nn.Embedding):
|
||||||
def forward(self, x_list: list[Tensor]) -> list[Tensor]:
|
def forward(self, x_list: list[Tensor]) -> list[Tensor]:
|
||||||
if len(x_list) == 0:
|
if len(x_list) == 0:
|
||||||
return []
|
return []
|
||||||
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
|
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
|
||||||
|
"""
|
||||||
|
|
||||||
class MultiEmbedding(nn.Module):
|
class MultiEmbedding(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
@ -218,22 +221,18 @@ class AudioEmbedding(nn.Module):
|
||||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
|
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
|
||||||
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
|
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
|
||||||
|
|
||||||
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
|
def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
|
||||||
res_list = []
|
# prom
|
||||||
|
if quant_levels is None and xi.shape[-1] > 1:
|
||||||
for i, xi in enumerate(x_list):
|
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
||||||
# prom
|
# AR resp
|
||||||
if quant_levels is None and xi.shape[-1] > 1:
|
elif quant_levels is None or quant_levels == 0:
|
||||||
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
x = self.embeddings[0]( xi[:, 0] )
|
||||||
# AR resp
|
# NAR resp
|
||||||
elif quant_levels is None or quant_levels[i] == 0:
|
else:
|
||||||
x = self.embeddings[0]( xi[:, 0] )
|
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
||||||
# NAR resp
|
|
||||||
else:
|
return x
|
||||||
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
|
||||||
res_list.append(x)
|
|
||||||
|
|
||||||
return res_list
|
|
||||||
|
|
||||||
class Base(nn.Module):
|
class Base(nn.Module):
|
||||||
@property
|
@property
|
||||||
|
@ -302,17 +301,6 @@ class Base(nn.Module):
|
||||||
def ignore_index(self):
|
def ignore_index(self):
|
||||||
return -100
|
return -100
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _samplewise_merge_tensors(*l, sep: Tensor | None):
|
|
||||||
if sep is None:
|
|
||||||
cat = torch.cat
|
|
||||||
else:
|
|
||||||
cat = partial(_join, sep=sep)
|
|
||||||
|
|
||||||
l = [ x for x in l if x is not None ]
|
|
||||||
|
|
||||||
return [*map(cat, zip(*l))]
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
n_tokens: int = 1024,
|
n_tokens: int = 1024,
|
||||||
|
@ -638,51 +626,104 @@ class Base(nn.Module):
|
||||||
|
|
||||||
return x, state, aux_loss
|
return x, state, aux_loss
|
||||||
|
|
||||||
def forward(
|
def inputs(
|
||||||
self,
|
self,
|
||||||
text_list: list[Tensor],
|
text_list: list[Tensor],
|
||||||
proms_list: list[Tensor],
|
proms_list: list[Tensor],
|
||||||
resps_list: list[Tensor],
|
resps_list: list[Tensor],
|
||||||
targ_list: list[Tensor] | None = None,
|
targ_list: list[Tensor] | None = None,
|
||||||
|
|
||||||
lang_list: list[Tensor] | None = None,
|
lang_list: list[Tensor] | None = None,
|
||||||
tone_list: list[Tensor] | None = None,
|
tone_list: list[Tensor] | None = None,
|
||||||
|
|
||||||
quant_levels: Tensor | None = None,
|
|
||||||
state: dict | list | None = None,
|
|
||||||
):
|
):
|
||||||
device = text_list[0].device
|
device = text_list[0].device
|
||||||
batch_size = len(text_list)
|
batch_size = len(text_list)
|
||||||
|
|
||||||
# silently ignore languages if model does not have it
|
inputs = [ [] for _ in range(batch_size) ]
|
||||||
if self.langs_emb is None:
|
for i in range(batch_size):
|
||||||
lang_list = None
|
if text_list is not None:
|
||||||
# inject default language
|
inputs[i].append( ( "text", text_list[i] ) )
|
||||||
elif lang_list is None:
|
if proms_list is not None:
|
||||||
lang_list = [ torch.Tensor([ 0 ]).to(dtype=torch.uint8, device=device) for _ in range(batch_size) ]
|
inputs[i].append( ( "prom", proms_list[i] ) )
|
||||||
|
if resps_list is not None:
|
||||||
# silently ignore tones if model does not have it
|
inputs[i].append( ( "resp", resps_list[i] ) )
|
||||||
if self.tones_emb is None:
|
if targ_list is not None:
|
||||||
tone_list = None
|
inputs[i].append( ( "targ", targ_list[i] ) )
|
||||||
# inject default tone
|
|
||||||
elif tone_list is None:
|
|
||||||
tone_list = [ torch.Tensor([ 0 ]).to(dtype=torch.uint8, device=device) for _ in range(batch_size) ]
|
|
||||||
|
|
||||||
"""
|
return inputs
|
||||||
# Typical sequence format
|
|
||||||
# To-do: integrate tasks again
|
|
||||||
<s><text></s><sep><lang><sep><prom><sep><tone><sep><resp><stop>
|
|
||||||
"""
|
|
||||||
x_list = self._samplewise_merge_tensors(
|
|
||||||
self.text_emb(text_list),
|
|
||||||
self.langs_emb(lang_list) if lang_list is not None else None,
|
|
||||||
self.proms_emb(proms_list),
|
|
||||||
self.tones_emb(tone_list) if tone_list is not None else None,
|
|
||||||
self.resps_emb(resps_list, quant_levels),
|
|
||||||
sep=self.sep,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
def inputs_to_embeddings(
|
||||||
|
self,
|
||||||
|
inputs: list,
|
||||||
|
quant_levels: Tensor | None = None
|
||||||
|
):
|
||||||
|
x_list = []
|
||||||
|
for b_i in range(len(inputs)):
|
||||||
|
batch = []
|
||||||
|
for i in range(len(inputs[b_i])):
|
||||||
|
name, input = inputs[b_i][i]
|
||||||
|
embedding = None
|
||||||
|
if name == "text":
|
||||||
|
embedding = self.text_emb( input )
|
||||||
|
elif name == "lang":
|
||||||
|
embedding = self.langs_emb( input )
|
||||||
|
elif name == "prom":
|
||||||
|
embedding = self.proms_emb( input )
|
||||||
|
elif name == "tone":
|
||||||
|
embedding = self.tones_emb( input )
|
||||||
|
elif name == "resp":
|
||||||
|
embedding = self.resps_emb( input, quant_levels[b_i] if quant_levels is not None else None )
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
batch.append(embedding)
|
||||||
|
|
||||||
|
x_list.append( _join( batch, self.sep ) )
|
||||||
|
|
||||||
|
return x_list
|
||||||
|
|
||||||
|
def training_targets(
|
||||||
|
self,
|
||||||
|
inputs: list,
|
||||||
|
):
|
||||||
|
x_list = []
|
||||||
|
for bi in range(len(inputs)):
|
||||||
|
batch = []
|
||||||
|
for i in range(len(inputs[bi])):
|
||||||
|
name, input = inputs[bi][i]
|
||||||
|
device = input.device
|
||||||
|
|
||||||
|
if name == "prom":
|
||||||
|
batch.append( torch.full_like(input[..., 0], self.ignore_index) )
|
||||||
|
elif name in ["text", "lang", "tone", "targ"]:
|
||||||
|
batch.append( input )
|
||||||
|
|
||||||
|
x_list.append( _join( batch, torch.tensor(self.ignore_index, device=device) ) )
|
||||||
|
|
||||||
|
return x_list
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
inputs: list,
|
||||||
|
|
||||||
|
quant_levels: Tensor | None = None,
|
||||||
|
state: dict | list | None = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
x_list = self.inputs_to_embeddings( inputs, quant_levels )
|
||||||
x, m = list_to_tensor(x_list)
|
x, m = list_to_tensor(x_list)
|
||||||
|
|
||||||
|
# yes, there's a better way.
|
||||||
|
training = False
|
||||||
|
for b_i in range(len(inputs)):
|
||||||
|
for i in range(len(inputs[b_i])):
|
||||||
|
name, input = inputs[b_i][i]
|
||||||
|
if name == "targ":
|
||||||
|
training = True
|
||||||
|
|
||||||
|
|
||||||
|
device = x.device
|
||||||
|
batch_size = len(x_list)
|
||||||
|
|
||||||
# pad our input and mask, but retain the original length by doing it after
|
# pad our input and mask, but retain the original length by doing it after
|
||||||
if self.l_padding and x.shape[1] % self.l_padding != 0:
|
if self.l_padding and x.shape[1] % self.l_padding != 0:
|
||||||
|
@ -709,15 +750,9 @@ class Base(nn.Module):
|
||||||
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
|
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
|
||||||
|
|
||||||
# compute loss if the target is given
|
# compute loss if the target is given
|
||||||
if targ_list is not None:
|
if training:
|
||||||
target_list = self._samplewise_merge_tensors(
|
target_list = self.training_targets( inputs )
|
||||||
text_list,
|
|
||||||
lang_list,
|
|
||||||
[ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ], # create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not neeeded for computing the loss against
|
|
||||||
targ_list,
|
|
||||||
sep=torch.tensor(self.ignore_index, device=device)
|
|
||||||
)
|
|
||||||
|
|
||||||
# modify only for the AR so it can properly behave like a transformer
|
# modify only for the AR so it can properly behave like a transformer
|
||||||
for i in range(len(target_list)):
|
for i in range(len(target_list)):
|
||||||
if quant_levels is not None and quant_levels[i] > 0:
|
if quant_levels is not None and quant_levels[i] > 0:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user