added option to split between text loss and audio loss (to-do: document this better), because it may or may not be a problem with LLaMA-backed models because my loss hovers around 3.9 / 56% accuracy despite sounding decent at the moment

This commit is contained in:
mrq 2024-05-19 11:23:56 -05:00
parent 74e531d391
commit 458b95d196
4 changed files with 157 additions and 71 deletions

View File

@ -6,41 +6,38 @@ models:
tasks: 8
langs: 2
tones: 1
arch_type: "retnet"
arch_type: llama
training: True
version: 3
version: 4
attention: flash_attention_2
dropout: 0.1
loss_factors:
text: 0.1
resp: 1.0
hyperparameters:
batch_size: 4
autotune: False
autotune_params:
start_profile_step: 1
end_profile_step: 50
num_tuning_micro_batch_sizes: 8
batch_size: 16
gradient_accumulation_steps: 4
gradient_clipping: 10
optimizer: Adagrad
gradient_clipping: 1.0
warmup_steps: 100
optimizer: Prodigy
learning_rate: 1.0
torch_optimizer: True
learning_rate: 1.0e-2
scheduler_type: ""
#scheduler_type: OneCycle
#scheduler_params:
# cycle_first_step_size: 10_000
# cycle_first_stair_count: 10_000
# cycle_second_step_size: 15_000
# cycle_second_stair_count: 15_000
# decay_step_size: 5_000
# cycle_min_lr: 2.5e-4 # 1.0e-5
# cycle_max_lr: 2.5e-4 # 1.0e-4
# decay_lr_rate: 0.0
# cycle_min_mom: 0.90
# cycle_max_mom: 0.99
# decay_mom_rate: 0.0
scheduler: "" # ScheduleFree
torch_scheduler: True
evaluation:
batch_size: 8
frequency: 10000
frequency: 5000
size: 8
steps: 500
@ -49,8 +46,9 @@ evaluation:
load_disabled_engines: True
trainer:
no_logger: True
#no_logger: True
ddp: False
#check_for_oom: False
iterations: 1_000_000
save_tag: step
@ -72,7 +70,7 @@ trainer:
gc_mode: None # "global_step"
weight_dtype: float32
weight_dtype: float32 # float16 or bfloat16
amp: False
backend: deepspeed
@ -81,34 +79,34 @@ trainer:
zero_optimization_level: 0
use_compression_training: False
amp: False
activation_checkpointing: True
load_webui: True
load_webui: False
inference:
backend: deepspeed
audio_backend: "dac"
normalize: False
weight_dtype: float32
weight_dtype: float32 # float16 or bfloat16
amp: False
bitsandbytes:
enabled: False
optimizations:
injects: False
replace: False
replace: True
linear: False
embedding: False
bitnet: False
optimizers: True
fp8:
enabled: False
backend: "te"
experimental: True
bitsandbytes: False
dadaptation: False
bitnet: False
fp8: False
experimental: True # practically required now it seems
dataset:
speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
@ -121,23 +119,19 @@ dataset:
hdf5_flag: r
validate: True
workers: 8
workers: 2
cache: True
#phones_range: [4, 512]
#duration_range: [1.0, 32.0]
phones_range: [0, 512]
duration_range: [0.0, 64.0]
duration_range: [3.0, 5.0]
random_utterance: 1.0
max_prompts: 3
prompt_duration: 6.0
max_prompts: 1
prompt_duration: 3.0
max_resps: 1
p_resp_append: 0.25
sample_type: speaker
sample_type: path # speaker
tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"]

View File

@ -213,9 +213,13 @@ class Model:
attention: str = "auto"
audio_embedding_sums: bool = True
dropout: float = 0.1 # adjustable dropout value
loss_factors: dict = field(default_factory=lambda: {})
def get(self, name=None):
return [ self ] if not name or self.name == name else []
def loss_factor(self, k):
return self.loss_factors[k] if k in self.loss_factors else 1.0
@property
def max_levels(self):
@ -491,6 +495,11 @@ class DeepSpeed:
return ds_cfg
@dataclass()
class LossFactor:
text: float = 1.0
resp: float = 1.0
@dataclass()
class Trainer:
iterations: int = 100_000

View File

@ -348,12 +348,15 @@ def example_usage():
text_list = [
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
]
proms_list = [
qnt[:cfg.dataset.frames_per_second, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
resps_list = [
qnt.to(device),
qnt[:, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
text_list = text_list[:1]

View File

@ -426,6 +426,11 @@ class Base(nn.Module):
def ignore_index(self):
return -100
def loss_factor(self, k):
if self.config is None:
return 1.0
return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0
def __init__(
self,
n_tokens: int = 1024,
@ -880,6 +885,31 @@ class Base(nn.Module):
return x_list
def training_targets_split(
self,
inputs: list,
):
text_lists = []
resp_lists = []
for bi in range(len(inputs)):
text_batch = []
resp_batch = []
for i in range(len(inputs[bi])):
name, input = inputs[bi][i]
device = input.device
if name in ["text", "lang" ]:
text_batch.append( input )
elif name == "targ":
resp_batch.append( input )
text_lists.append( _join( text_batch, torch.tensor(self.ignore_index, device=device) ) )
resp_lists.append( _join( resp_batch, torch.tensor(self.ignore_index, device=device) ) )
return text_lists, resp_lists
def forward(
self,
inputs: list,
@ -929,30 +959,80 @@ class Base(nn.Module):
# compute loss if the target is given
if training:
target_list = self.training_targets( inputs )
# modify only for the AR so it can properly behave like a transformer
for i in range(len(target_list)):
if quant_levels is not None and quant_levels[i] > 0:
continue
if not self.config.loss_factors:
target_list = self.training_targets( inputs )
# modify only for the AR so it can properly behave like a transformer
for i in range(len(target_list)):
if quant_levels is not None and quant_levels[i] > 0:
continue
logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
target_list[i] = target_list[i][..., 1:] # predicts token n + 1
logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
target_list[i] = target_list[i][..., 1:] # predicts token n + 1
target = torch.cat( target_list )
inputs = torch.cat( logits )
target = torch.cat( target_list )
inputs = torch.cat( logits )
self.loss = dict(
# "nll" was in the original implementation and should actually just be called something else
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
)
self.stats = dict(
acc = self.accuracy_metric( inputs, target ),
# precision = self.precision_metric( inputs, target ),
)
self.loss = dict(
# "nll" was in the original implementation and should actually just be called something else
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
)
self.stats = dict(
acc = self.accuracy_metric( inputs, target ),
# precision = self.precision_metric( inputs, target ),
)
# split our loss
else:
target_text_list, target_resp_list = self.training_targets_split( inputs )
# grab respective slice of logits
logits_text = [ hi[:li.shape[0]] for hi, li in zip(logits, target_text_list) ]
logits_resp = [ hi[-li.shape[0]:] for hi, li in zip(logits, target_resp_list) ]
# modify only for the AR so it can properly behave like a transformer
for i in range(len(target_text_list)):
if quant_levels is not None and quant_levels[i] > 0:
continue
# shift the target so that token n...
logits_text[i] = logits_text[i][..., :-1, :]
logits_resp[i] = logits_resp[i][..., :-1, :]
# predicts token n + 1
target_text_list[i] = target_text_list[i][..., 1:]
target_resp_list[i] = target_resp_list[i][..., 1:]
target_text = torch.cat( target_text_list ).long()
target_resp = torch.cat( target_resp_list ).long()
inputs_text = torch.cat( logits_text )
inputs_resp = torch.cat( logits_resp )
self.loss = dict(
text = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ),
resp = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ),
)
for k in self.loss:
self.loss[k] *= self.loss_factor(k)
# to-do: compute loss per individual batch to scale per RVQ level
"""
rvq_loss_factor = self.loss_factor("quant")
if isinstance( rvq_loss_factor, list ):
...
"""
self.stats = dict(
acc = dict(
text = self.accuracy_metric( inputs_text, target_text ),
resp = self.accuracy_metric( inputs_resp, target_resp ),
),
)
# include any additional losses (for example: MoE router)
if aux_loss is not None:
self.loss["nll"] += aux_loss
self.loss["aux_loss"] = aux_loss
return (logits, state) if state is not None else logits