added an option to split the loss between text and audio (to-do: document this better); it may or may not address a problem with LLaMA-backed models, since my loss hovers around 3.9 / 56% accuracy despite the output sounding decent at the moment
This commit is contained in:
parent 74e531d391
commit 458b95d196
@@ -6,41 +6,38 @@ models:
   tasks: 8
   langs: 2
   tones: 1
-  arch_type: "retnet"
+  arch_type: llama
   training: True
-  version: 3
+  version: 4
+  attention: flash_attention_2
+  dropout: 0.1
+
+  loss_factors:
+    text: 0.1
+    resp: 1.0

 hyperparameters:
-  batch_size: 4
+  autotune: False
+  autotune_params:
+    start_profile_step: 1
+    end_profile_step: 50
+    num_tuning_micro_batch_sizes: 8
+
+  batch_size: 16
   gradient_accumulation_steps: 4
-  gradient_clipping: 10
+  gradient_clipping: 1.0
+  warmup_steps: 100

-  optimizer: Adagrad
+  optimizer: Prodigy
+  learning_rate: 1.0
   torch_optimizer: True
-  learning_rate: 1.0e-2

-  scheduler_type: ""
-  #scheduler_type: OneCycle
-  #scheduler_params:
-  #  cycle_first_step_size: 10_000
-  #  cycle_first_stair_count: 10_000
-  #  cycle_second_step_size: 15_000
-  #  cycle_second_stair_count: 15_000
-  #  decay_step_size: 5_000
-  #  cycle_min_lr: 2.5e-4 # 1.0e-5
-  #  cycle_max_lr: 2.5e-4 # 1.0e-4
-  #  decay_lr_rate: 0.0
-  #  cycle_min_mom: 0.90
-  #  cycle_max_mom: 0.99
-  #  decay_mom_rate: 0.0
+  scheduler: "" # ScheduleFree
+  torch_scheduler: True

 evaluation:
   batch_size: 8
-  frequency: 10000
+  frequency: 5000
   size: 8

   steps: 500
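The `learning_rate: 1.0` above is not a typo: Prodigy, which replaces Adagrad here, estimates its own step size, and a learning rate of 1.0 is the value its authors recommend. A minimal sketch of what the trainer presumably instantiates from this config, using the upstream prodigyopt package on a stand-in model:

import torch
from prodigyopt import Prodigy  # pip install prodigyopt

model = torch.nn.Linear(16, 16)  # stand-in for the real model
optimizer = Prodigy(model.parameters(), lr=1.0)  # 1.0 is the recommended setting

for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()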
@@ -49,8 +46,9 @@ evaluation:
   load_disabled_engines: True

 trainer:
-  no_logger: True
+  #no_logger: True
+  ddp: False
+  #check_for_oom: False
   iterations: 1_000_000

   save_tag: step
@@ -72,7 +70,7 @@ trainer:
   gc_mode: None # "global_step"

-  weight_dtype: float32
+  weight_dtype: float32 # float16 or bfloat16
   amp: False

   backend: deepspeed
@@ -81,34 +79,34 @@ trainer:
   zero_optimization_level: 0
   use_compression_training: False

+  amp: False

   activation_checkpointing: True

-  load_webui: True
+  load_webui: False

 inference:
   backend: deepspeed
   audio_backend: "dac"
   normalize: False

-  weight_dtype: float32
+  weight_dtype: float32 # float16 or bfloat16
   amp: False

-bitsandbytes:
-  enabled: False
+optimizations:
   injects: False
-  replace: False
+  replace: True

   linear: False
   embedding: False
+  optimizers: True

+  bitsandbytes: False
+  dadaptation: False
   bitnet: False
+  fp8: False

-fp8:
-  enabled: False
-  backend: "te"
-
-experimental: True
+experimental: True # practically required now it seems

 dataset:
   speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
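The `bitsandbytes:` section is generalized into `optimizations:`, separating which library provides the replacements (`bitsandbytes`, `dadaptation`, `bitnet`, `fp8`, all off here) from what gets replaced (`linear`, `embedding`, `optimizers`). A rough illustration of the kind of swap `optimizers: True` could gate if `bitsandbytes` were enabled; the wiring to the config is an assumption, though `bnb.optim.AdamW8bit` itself is a real API:

import torch

model = torch.nn.Linear(16, 16)
use_bnb_optimizer = False  # mirrors `bitsandbytes: False` with `optimizers: True`

if use_bnb_optimizer:
    import bitsandbytes as bnb
    optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1.0e-4)  # 8-bit optimizer states
else:
    optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-4)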
@@ -121,23 +119,19 @@ dataset:
   hdf5_flag: r
   validate: True

-  workers: 8
+  workers: 2
   cache: True

-  #phones_range: [4, 512]
-  #duration_range: [1.0, 32.0]
-
-  phones_range: [0, 512]
-  duration_range: [0.0, 64.0]
+  duration_range: [3.0, 5.0]

   random_utterance: 1.0
-  max_prompts: 3
-  prompt_duration: 6.0
+  max_prompts: 1
+  prompt_duration: 3.0

   max_resps: 1
   p_resp_append: 0.25

-  sample_type: speaker
+  sample_type: path # speaker

   tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"]
@@ -213,10 +213,14 @@ class Model:
 	attention: str = "auto"
 	audio_embedding_sums: bool = True
 	dropout: float = 0.1 # adjustable dropout value
+	loss_factors: dict = field(default_factory=lambda: {})

 	def get(self, name=None):
 		return [ self ] if not name or self.name == name else []

+	def loss_factor(self, k):
+		return self.loss_factors[k] if k in self.loss_factors else 1.0
+
 	@property
 	def max_levels(self):
 		return self._max_levels if self._max_levels > 0 else self.prom_levels
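`loss_factors` defaults to an empty dict, and `loss_factor(k)` falls back to 1.0 for any key not listed, so configs that don't set it keep the old single-loss weighting. A standalone illustration of the lookup rule, mirroring the `text: 0.1` / `resp: 1.0` values from the YAML above:

loss_factors = {"text": 0.1, "resp": 1.0}  # mirrors the YAML above

def loss_factor(k: str) -> float:
    return loss_factors[k] if k in loss_factors else 1.0

assert loss_factor("text") == 0.1   # text loss is down-weighted
assert loss_factor("resp") == 1.0   # audio (resp) loss at full weight
assert loss_factor("quant") == 1.0  # unlisted keys fall back to the default weight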
@@ -491,6 +495,11 @@ class DeepSpeed:

 		return ds_cfg

+@dataclass()
+class LossFactor:
+	text: float = 1.0
+	resp: float = 1.0
+
 @dataclass()
 class Trainer:
 	iterations: int = 100_000
@@ -348,12 +348,15 @@ def example_usage():

 	text_list = [
 		tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
+		#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
 	]
 	proms_list = [
 		qnt[:cfg.dataset.frames_per_second, :].to(device),
+		#qnt[:cfg.dataset.frames_per_second, :].to(device),
 	]
 	resps_list = [
-		qnt.to(device),
+		qnt[:, :].to(device),
+		#qnt[:cfg.dataset.frames_per_second, :].to(device),
 	]

 	text_list = text_list[:1]
@@ -426,6 +426,11 @@ class Base(nn.Module):
 	def ignore_index(self):
 		return -100

+	def loss_factor(self, k):
+		if self.config is None:
+			return 1.0
+		return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0
+
 	def __init__(
 		self,
 		n_tokens: int = 1024,
@@ -880,6 +885,31 @@ class Base(nn.Module):

 		return x_list

+	def training_targets_split(
+		self,
+		inputs: list,
+	):
+		text_lists = []
+		resp_lists = []
+
+		for bi in range(len(inputs)):
+			text_batch = []
+			resp_batch = []
+
+			for i in range(len(inputs[bi])):
+				name, input = inputs[bi][i]
+				device = input.device
+
+				if name in ["text", "lang" ]:
+					text_batch.append( input )
+				elif name == "targ":
+					resp_batch.append( input )
+
+			text_lists.append( _join( text_batch, torch.tensor(self.ignore_index, device=device) ) )
+			resp_lists.append( _join( resp_batch, torch.tensor(self.ignore_index, device=device) ) )
+
+		return text_lists, resp_lists
+
 	def forward(
 		self,
 		inputs: list,
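`training_targets_split` walks each batch item's (name, tensor) pairs, routing "text"/"lang" inputs into one target stream and "targ" into the other, each stream joined with the ignore index so the separators drop out of the loss. A toy, self-contained run of the same idea; `_join` here is a stand-in for the module-level helper, assumed to concatenate tensors with a separator between consecutive entries:

import torch

IGNORE_INDEX = -100

def _join(tensors: list, sep: torch.Tensor) -> torch.Tensor:
    # stand-in: concatenate 1-D tensors with `sep` between consecutive entries
    out = [tensors[0]]
    for t in tensors[1:]:
        out += [sep.unsqueeze(0), t]
    return torch.cat(out)

# one batch item as a list of (name, tensor) pairs, as built for the model
batch = [
    ("text", torch.tensor([5, 6, 7])),
    ("lang", torch.tensor([1])),
    ("targ", torch.tensor([9, 9, 9, 9])),
]

text = _join([t for n, t in batch if n in ("text", "lang")], torch.tensor(IGNORE_INDEX))
resp = _join([t for n, t in batch if n == "targ"], torch.tensor(IGNORE_INDEX))
print(text)  # tensor([   5,    6,    7, -100,    1]); the -100 is masked out of the loss
print(resp)  # tensor([9, 9, 9, 9])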
@@ -929,6 +959,7 @@ class Base(nn.Module):

 		# compute loss if the target is given
 		if training:
+			if not self.config.loss_factors:
 				target_list = self.training_targets( inputs )

 				# modify only for the AR so it can properly behave like a transformer
@@ -950,9 +981,58 @@ class Base(nn.Module):
 					acc = self.accuracy_metric( inputs, target ),
 					# precision = self.precision_metric( inputs, target ),
 				)
+			# split our loss
+			else:
+				target_text_list, target_resp_list = self.training_targets_split( inputs )
+
+				# grab respective slice of logits
+				logits_text = [ hi[:li.shape[0]] for hi, li in zip(logits, target_text_list) ]
+				logits_resp = [ hi[-li.shape[0]:] for hi, li in zip(logits, target_resp_list) ]
+
+				# modify only for the AR so it can properly behave like a transformer
+				for i in range(len(target_text_list)):
+					if quant_levels is not None and quant_levels[i] > 0:
+						continue
+
+					# shift the target so that token n...
+					logits_text[i] = logits_text[i][..., :-1, :]
+					logits_resp[i] = logits_resp[i][..., :-1, :]
+
+					# predicts token n + 1
+					target_text_list[i] = target_text_list[i][..., 1:]
+					target_resp_list[i] = target_resp_list[i][..., 1:]
+
+				target_text = torch.cat( target_text_list ).long()
+				target_resp = torch.cat( target_resp_list ).long()
+
+				inputs_text = torch.cat( logits_text )
+				inputs_resp = torch.cat( logits_resp )
+
+				self.loss = dict(
+					text = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ),
+					resp = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ),
+				)
+
+				for k in self.loss:
+					self.loss[k] *= self.loss_factor(k)
+
+				# to-do: compute loss per individual batch to scale per RVQ level
+				"""
+				rvq_loss_factor = self.loss_factor("quant")
+				if isinstance( rvq_loss_factor, list ):
+					...
+				"""
+
+				self.stats = dict(
+					acc = dict(
+						text = self.accuracy_metric( inputs_text, target_text ),
+						resp = self.accuracy_metric( inputs_resp, target_resp ),
+					),
+				)
+
+			# include any additional losses (for example: MoE router)
 			if aux_loss is not None:
-				self.loss["nll"] += aux_loss
+				self.loss["aux_loss"] = aux_loss

 		return (logits, state) if state is not None else logits
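Net effect of the new branch: one cross-entropy over the text slice of each sequence's logits, another over the response slice, each scaled by its configured factor, so with `text: 0.1` above a text misprediction contributes a tenth of the gradient an audio one does. A minimal sketch of the arithmetic detached from the model; shapes and names are illustrative only:

import torch
import torch.nn.functional as F

T, R, V = 12, 40, 1024                      # text length, resp length, vocab size
logits = torch.randn(T + R, V)              # one sequence's logits
target_text = torch.randint(0, V, (T,))
target_resp = torch.randint(0, V, (R,))
loss_factors = {"text": 0.1, "resp": 1.0}   # mirrors the YAML above

loss = {
    "text": F.cross_entropy(logits[:T], target_text, ignore_index=-100),
    "resp": F.cross_entropy(logits[-R:], target_resp, ignore_index=-100),
}
for k in loss:
    loss[k] = loss[k] * loss_factors.get(k, 1.0)  # unlisted keys default to 1.0

total = sum(loss.values())  # what the engine presumably backpropagates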