added option to split between text loss and audio loss (to-do: document this better), because it may or may not be a problem with LLaMA-backed models because my loss hovers around 3.9 / 56% accuracy despite sounding decent at the moment

This commit is contained in:
mrq 2024-05-19 11:23:56 -05:00
parent 74e531d391
commit 458b95d196
4 changed files with 157 additions and 71 deletions

View File

@ -6,41 +6,38 @@ models:
tasks: 8 tasks: 8
langs: 2 langs: 2
tones: 1 tones: 1
arch_type: "retnet" arch_type: llama
training: True training: True
version: 3 version: 4
attention: flash_attention_2
dropout: 0.1
loss_factors:
text: 0.1
resp: 1.0
hyperparameters: hyperparameters:
batch_size: 4 autotune: False
autotune_params:
start_profile_step: 1
end_profile_step: 50
num_tuning_micro_batch_sizes: 8
batch_size: 16
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
gradient_clipping: 10 gradient_clipping: 1.0
warmup_steps: 100
optimizer: Adagrad optimizer: Prodigy
learning_rate: 1.0
torch_optimizer: True torch_optimizer: True
learning_rate: 1.0e-2
scheduler_type: "" scheduler: "" # ScheduleFree
#scheduler_type: OneCycle torch_scheduler: True
#scheduler_params:
# cycle_first_step_size: 10_000
# cycle_first_stair_count: 10_000
# cycle_second_step_size: 15_000
# cycle_second_stair_count: 15_000
# decay_step_size: 5_000
# cycle_min_lr: 2.5e-4 # 1.0e-5
# cycle_max_lr: 2.5e-4 # 1.0e-4
# decay_lr_rate: 0.0
# cycle_min_mom: 0.90
# cycle_max_mom: 0.99
# decay_mom_rate: 0.0
evaluation: evaluation:
batch_size: 8 batch_size: 8
frequency: 10000 frequency: 5000
size: 8 size: 8
steps: 500 steps: 500
@ -49,8 +46,9 @@ evaluation:
load_disabled_engines: True load_disabled_engines: True
trainer: trainer:
no_logger: True #no_logger: True
ddp: False
#check_for_oom: False
iterations: 1_000_000 iterations: 1_000_000
save_tag: step save_tag: step
@ -72,7 +70,7 @@ trainer:
gc_mode: None # "global_step" gc_mode: None # "global_step"
weight_dtype: float32 weight_dtype: float32 # float16 or bfloat16
amp: False amp: False
backend: deepspeed backend: deepspeed
@ -81,34 +79,34 @@ trainer:
zero_optimization_level: 0 zero_optimization_level: 0
use_compression_training: False use_compression_training: False
amp: False
activation_checkpointing: True activation_checkpointing: True
load_webui: True load_webui: False
inference: inference:
backend: deepspeed backend: deepspeed
audio_backend: "dac" audio_backend: "dac"
normalize: False normalize: False
weight_dtype: float32 weight_dtype: float32 # float16 or bfloat16
amp: False amp: False
bitsandbytes: optimizations:
enabled: False
injects: False injects: False
replace: False replace: True
linear: False linear: False
embedding: False embedding: False
optimizers: True
bitsandbytes: False
dadaptation: False
bitnet: False bitnet: False
fp8: False
fp8: experimental: True # practically required now it seems
enabled: False
backend: "te"
experimental: True
dataset: dataset:
speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'" speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
@ -121,23 +119,19 @@ dataset:
hdf5_flag: r hdf5_flag: r
validate: True validate: True
workers: 8 workers: 2
cache: True cache: True
#phones_range: [4, 512] duration_range: [3.0, 5.0]
#duration_range: [1.0, 32.0]
phones_range: [0, 512]
duration_range: [0.0, 64.0]
random_utterance: 1.0 random_utterance: 1.0
max_prompts: 3 max_prompts: 1
prompt_duration: 6.0 prompt_duration: 3.0
max_resps: 1 max_resps: 1
p_resp_append: 0.25 p_resp_append: 0.25
sample_type: speaker sample_type: path # speaker
tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"] tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"]

View File

@ -213,10 +213,14 @@ class Model:
attention: str = "auto" attention: str = "auto"
audio_embedding_sums: bool = True audio_embedding_sums: bool = True
dropout: float = 0.1 # adjustable dropout value dropout: float = 0.1 # adjustable dropout value
loss_factors: dict = field(default_factory=lambda: {})
def get(self, name=None): def get(self, name=None):
return [ self ] if not name or self.name == name else [] return [ self ] if not name or self.name == name else []
def loss_factor(self, k):
return self.loss_factors[k] if k in self.loss_factors else 1.0
@property @property
def max_levels(self): def max_levels(self):
return self._max_levels if self._max_levels > 0 else self.prom_levels return self._max_levels if self._max_levels > 0 else self.prom_levels
@ -491,6 +495,11 @@ class DeepSpeed:
return ds_cfg return ds_cfg
@dataclass()
class LossFactor:
text: float = 1.0
resp: float = 1.0
@dataclass() @dataclass()
class Trainer: class Trainer:
iterations: int = 100_000 iterations: int = 100_000

View File

@ -348,12 +348,15 @@ def example_usage():
text_list = [ text_list = [
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device), tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
] ]
proms_list = [ proms_list = [
qnt[:cfg.dataset.frames_per_second, :].to(device), qnt[:cfg.dataset.frames_per_second, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
] ]
resps_list = [ resps_list = [
qnt.to(device), qnt[:, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
] ]
text_list = text_list[:1] text_list = text_list[:1]

View File

@ -426,6 +426,11 @@ class Base(nn.Module):
def ignore_index(self): def ignore_index(self):
return -100 return -100
def loss_factor(self, k):
if self.config is None:
return 1.0
return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0
def __init__( def __init__(
self, self,
n_tokens: int = 1024, n_tokens: int = 1024,
@ -880,6 +885,31 @@ class Base(nn.Module):
return x_list return x_list
def training_targets_split(
self,
inputs: list,
):
text_lists = []
resp_lists = []
for bi in range(len(inputs)):
text_batch = []
resp_batch = []
for i in range(len(inputs[bi])):
name, input = inputs[bi][i]
device = input.device
if name in ["text", "lang" ]:
text_batch.append( input )
elif name == "targ":
resp_batch.append( input )
text_lists.append( _join( text_batch, torch.tensor(self.ignore_index, device=device) ) )
resp_lists.append( _join( resp_batch, torch.tensor(self.ignore_index, device=device) ) )
return text_lists, resp_lists
def forward( def forward(
self, self,
inputs: list, inputs: list,
@ -929,6 +959,7 @@ class Base(nn.Module):
# compute loss if the target is given # compute loss if the target is given
if training: if training:
if not self.config.loss_factors:
target_list = self.training_targets( inputs ) target_list = self.training_targets( inputs )
# modify only for the AR so it can properly behave like a transformer # modify only for the AR so it can properly behave like a transformer
@ -950,9 +981,58 @@ class Base(nn.Module):
acc = self.accuracy_metric( inputs, target ), acc = self.accuracy_metric( inputs, target ),
# precision = self.precision_metric( inputs, target ), # precision = self.precision_metric( inputs, target ),
) )
# split our loss
else:
target_text_list, target_resp_list = self.training_targets_split( inputs )
# grab respective slice of logits
logits_text = [ hi[:li.shape[0]] for hi, li in zip(logits, target_text_list) ]
logits_resp = [ hi[-li.shape[0]:] for hi, li in zip(logits, target_resp_list) ]
# modify only for the AR so it can properly behave like a transformer
for i in range(len(target_text_list)):
if quant_levels is not None and quant_levels[i] > 0:
continue
# shift the target so that token n...
logits_text[i] = logits_text[i][..., :-1, :]
logits_resp[i] = logits_resp[i][..., :-1, :]
# predicts token n + 1
target_text_list[i] = target_text_list[i][..., 1:]
target_resp_list[i] = target_resp_list[i][..., 1:]
target_text = torch.cat( target_text_list ).long()
target_resp = torch.cat( target_resp_list ).long()
inputs_text = torch.cat( logits_text )
inputs_resp = torch.cat( logits_resp )
self.loss = dict(
text = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ),
resp = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ),
)
for k in self.loss:
self.loss[k] *= self.loss_factor(k)
# to-do: compute loss per individual batch to scale per RVQ level
"""
rvq_loss_factor = self.loss_factor("quant")
if isinstance( rvq_loss_factor, list ):
...
"""
self.stats = dict(
acc = dict(
text = self.accuracy_metric( inputs_text, target_text ),
resp = self.accuracy_metric( inputs_resp, target_resp ),
),
)
# include any additional losses (for example: MoE router)
if aux_loss is not None: if aux_loss is not None:
self.loss["nll"] += aux_loss self.loss["aux_loss"] = aux_loss
return (logits, state) if state is not None else logits return (logits, state) if state is not None else logits