-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathhubconf.py
37 lines (24 loc) · 985 Bytes
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""torch.hub configuration."""
# Packages that torch.hub checks are installed before running an entrypoint
# from this file; must be defined at module level under this exact name.
dependencies = [
    "torch",
]
import torch # pylint: disable=wrong-import-position
from xvector_jtubespeech.network.xvector import _XVector # pylint: disable=wrong-import-position
# Download locations of the pretrained checkpoints, keyed by checkpoint
# identifier (looked up by the `xvector` entrypoint).
URLS = dict(
    xvector_jtubespeech=(
        "https://raw.githubusercontent.com/sarulab-speech/xvector_jtubespeech/master/xvector.pth"
    ),
)
def xvector(progress: bool = True, pretrained: bool = True) -> _XVector:
    """Build the `x-vector JTubeSpeech` utterance embedding model.

    The returned model is put in eval mode and every parameter has
    ``requires_grad`` disabled, so it is ready for inference as returned.

    Args:
        progress: Whether to show a progress bar while downloading the
            pretrained checkpoint. Has no effect when ``pretrained`` is False.
        pretrained: Whether to download and load the checkpoint from
            ``URLS["xvector_jtubespeech"]``. When False, the model keeps the
            weights assigned by ``_XVector``'s constructor.

    Returns:
        The frozen ``_XVector`` model.

    Raises:
        RuntimeError: If the downloaded state dict does not match the model's
            parameters (``load_state_dict`` loads strictly by default).
    """
    # NOTE(review): 24 and 1233 are passed positionally to _XVector; they are
    # presumably the input feature dimension and the number of training
    # speakers of the released checkpoint -- confirm against _XVector.
    model = _XVector(24, 1233)

    if pretrained:
        # map_location="cpu" lets a checkpoint saved from GPU tensors load on
        # CPU-only machines.
        state_dict = torch.hub.load_state_dict_from_url(
            url=URLS["xvector_jtubespeech"],
            map_location="cpu",
            progress=progress,
        )
        model.load_state_dict(state_dict)

    # Inference-only model: eval mode plus no gradient tracking on weights.
    model.eval()
    model.requires_grad_(False)
    return model