Support fp16

This commit is contained in:
enhuiz 2023-01-14 17:09:08 +08:00
parent 4cb958f7ff
commit be4619cddf
4 changed files with 17 additions and 6 deletions

View File

@ -44,16 +44,27 @@ class Config(ConfigBase):
min_phones: int = 10
max_phones: int = 50
use_fp16: bool = True
@cached_property
def get_spkr(self):
return eval(self.spkr_name_getter)
@property
def fp16_cfg(self):
return {
"enabled": self.use_fp16,
}
@property
def ds_cfg(self):
return {
"train_micro_batch_size_per_gpu": self.batch_size,
"gradient_accumulation_steps": 1,
"optimizer": {"type": "Adam"},
"optimizer": {
"type": "Adam",
"lr": self.warmup_min_lr,
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
@ -65,6 +76,7 @@ class Config(ConfigBase):
},
},
"gradient_clipping": self.gradient_clipping,
"fp16": self.fp16_cfg,
}
@property

@ -1 +1 @@
Subproject commit 0dbc980be3a4cb26ad5e3b16643a70a47623358a
Subproject commit 78d70f0331f844e9fe9ed253175b9e0b28082153

View File

@ -31,6 +31,7 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
m = _create_mask(l, x_list[0].device)
m = m.t().unsqueeze(-1) # (t b 1)
m = rearrange(m, pattern)
m = m.to(x)
return x, m
@ -256,7 +257,7 @@ class AdditiveMultiEmbedding(nn.Embedding):
x = torch.cat(x_list)
assert x.shape[1] == self.n_levels
w = rearrange(self.weight, "(q k) d -> q k d", q=self.n_levels)
x = F.one_hot(x, num_classes=self.n_tokens).float() # n q -> n q k
x = F.one_hot(x, num_classes=self.n_tokens).to(w) # n q -> n q k
x = einsum("q k d, n q k -> n d", w, x)
x_list = x.split([*map(len, x_list)])
return x_list
@ -285,7 +286,7 @@ class SelectiveMultiEmbedding(nn.Embedding):
w = repeat(self.weight[0], "d -> b d", b=len(x))
w = rearrange(w, "b (k d) -> b k d", k=self.n_tokens_per_level)
x = F.one_hot(x, num_classes=self.n_tokens_per_level).float() # b t k
x = F.one_hot(x, num_classes=self.n_tokens_per_level).to(w) # b t k
x = einsum("b k d, b t k -> b t d", w, x)
x_list = [xi[:li] for xi, li in zip(x, map(len, x_list))]

View File

@ -1,5 +1,3 @@
import random
import torch
from einops import rearrange
from torch import Tensor