forked from mrq/DL-Art-School
script for uploading models to the HF hub
This commit is contained in:
parent
dbc74e96b2
commit
ba155e4e2f
22
codes/scripts/hugging_face_hub_upload.py
Normal file
22
codes/scripts/hugging_face_hub_upload.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from models.asr.w2v_wrapper import Wav2VecWrapper
|
||||||
|
from models.tacotron2.text import tacotron_symbol_mapping
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
"""
|
||||||
|
Utility script for uploading model weights to the HF hub
|
||||||
|
"""
|
||||||
|
|
||||||
|
'''
|
||||||
|
model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-robust-ft-libri-960h', freeze_transformer=True, checkpointing_enabled=False)
|
||||||
|
weights = torch.load('D:\\dlas\\experiments\\train_wav2vec_mass_large2\\models\\22500_wav2vec.pth')
|
||||||
|
model.load_state_dict(weights)
|
||||||
|
model.w2v.save_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli", push_to_hub=True)
|
||||||
|
'''
|
||||||
|
|
||||||
|
# Build tokenizer vocab
|
||||||
|
mapping = tacotron_symbol_mapping()
|
||||||
|
print(json.dumps(mapping))
|
Loading…
Reference in New Issue
Block a user