๋ณธ ํ์ด์ง๋ ์ฌ์ ํ์ต๋ ๋ชจ๋ธ์ hub๋ฅผ ํตํด ๋ฐฐํฌํ๋ ๋ฐฉ๋ฒ๊ณผ ๋ฐฐํฌ๋ ๋ชจ๋ธ์ Pytorch Korea์ ๋ชจ๋ธ ์๋ด ํ์ด์ง์ ์ถ๊ฐํ๋ ๋ฐฉ๋ฒ์ ์๋ดํฉ๋๋ค.
torch.hub
๋ ์ฐ๊ตฌ ์ฌํ ๋ฐ ์ฌ์ฉ์ ์ฉ์ดํ๊ฒ ํ๊ธฐ ์ํด ์ค๊ณ๋ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ ์ ์ฅ์์
๋๋ค.
Pytorch Hub๋ hubconf.py
๋ฅผ ์ถ๊ฐํ์ฌ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ(๋ชจ๋ธ ์ ์ ๋ฐ ์ฌ์ ํ๋ จ๋ ๊ฐ์ค์น)์ ๊นํ๋ธ ์ ์ฅ์์ ๊ฒ์ํ ์ ์๋๋ก ์ง์ํฉ๋๋ค.
hubconf.py
๋ ์ฌ๋ ค ๊ฐ์ ์ํธ๋ฆฌ ํฌ์ธํธ(entry point)๋ฅผ ๊ฐ์ง ์ ์์ต๋๋ค. ๊ฐ ์ํธ๋ฆฌ ํฌ์ธํธ๋ค์ ํ์ด์ฌ ํจ์๋ก ์ ์๋ฉ๋๋ค. (์๋ฅผ ๋ค์ด, ์ฌ์ฉ์๊ฐ ๋ฑ๋กํ๊ณ ์ ํ๋ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ)
def entrypoint_name(*args, **kwargs):
# args & kwargs are optional, for models which take positional/keyword arguments.
...
๋ค์์ pytorch/vision/hubconf.py
์ ์ฐธ๊ณ ํ์ฌ ์์ฑํ resnet18
๋ชจ๋ธ์ ์ํธ๋ฆฌ ํฌ์ธํธ๋ฅผ ์ง์ ํ๋ ์ฝ๋์
๋๋ค.
๋๋ถ๋ถ์ ๊ฒฝ์ฐ hubconf.py
์ ๊ตฌํ๋ ๊ธฐ๋ฅ์ ๊ฐ์ ธ์ค๋ฉด ์ถฉ๋ถํฉ๋๋ค.
์ฌ๊ธฐ์์ ํ์ฅ ๋ฒ์ ์ ์์๋ก ์๋ ๋ฐฉ์์ ๋ณด์ฌ๋๋ฆฌ๊ฒ ์ต๋๋ค.
pytorch/vision repo ์ฌ๊ธฐ์ ๋ชจ๋ ์คํฌ๋ฆฝํธ๋ฅผ ๋ณผ ์ ์์ต๋๋ค.
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18
# resnet18์ ์ํธ๋ฆฌ ํฌ์ธํธ์ ์ด๋ฆ์
๋๋ค.
def resnet18(pretrained=False, **kwargs):
""" # ์ด ๋ฌธ์ ๋ฌธ์์ด์ hub.help() Resnet18 model ์์์ ๋ณด์ฌ์ค๋๋ค
์ฌ์ ๊ต์ก (bool) : kwwargs, load pretrained weights๋ฅผ ๋ชจ๋ธ์ ์ ์ฉํฉ๋๋ค
"""
# ๋ชจ๋ธ์ ๋ถ๋ฌ์ค๊ณ ์ฌ์ ํ์ต๋ ๊ฐ์ค์น๋ฅผ ๋ก๋
model = _resnet18(pretrained=pretrained, **kwargs)
return model
dependencies
๋ณ์๋ ๋ชจ๋ธ์ ๋ถ๋ฌ์ค๋ ๋ฐ ํ์ํ ํจํค์ง ์ด๋ฆ์ ๋ด์ ๋ชฉ๋ก์
๋๋ค. ์ด๋ ๋ชจ๋ธ ํ์ต์ ํ์ํ ์ข
์์ฑ๊ณผ๋ ์ฝ๊ฐ ๋ค๋ฅผ ์ ์์์ ์ฃผ์ํ์ธ์.
args
์ kwargs
๋ ์ค์ ํธ์ถ ๊ฐ๋ฅํ ํจ์๋ก ์ ๋ฌ๋ฉ๋๋ค.
ํจ์์ Docstring์ ๋์ ๋ฉ์์ง์ ์ญํ ์ ํฉ๋๋ค. ๋ชจ๋ธ์ ๊ธฐ๋ฅ๊ณผ ํ์ฉ๋ ์์น/ํค์๋ ์ธ์์ ๋ํด ์ค๋ช
ํฉ๋๋ค. ์ฌ๊ธฐ์ ๋ช ๊ฐ์ง ์์๋ฅผ ์ถ๊ฐํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
์ํธ๋ฆฌ ํฌ์ธํธ ํจ์๋ ๋ชจ๋ธ(nn.module)์ ๋ฆฌํดํ๊ฑฐ๋ ์์
ํ๋ฆ์ ๋ณด๋ค ๋ถ๋๋ฝ๊ฒ ๋ง๋ค๊ธฐ ์ํ ๋ณด์กฐ ๋๊ตฌ(์: tokenizer)๋ฅผ ๋ฐํํ ์ ์์ต๋๋ค.
๋ฐ์ค์ด ์์ ๋ถ์ callables๋ torch.hub.list()์ ํ์๋์ง ์๋ ๋์ฐ๋ฏธ ํจ์๋ก ๊ฐ์ฃผ๋ฉ๋๋ค.
์ฌ์ ํ๋ จ๋ ๊ฐ์ค์น๋ ๊นํ๋ธ ์ ์ฅ์์ ๋ก์ปฌ๋ก ์ ์ฅ๋๊ฑฐ๋ torch.hub.load_state_dict_from_url()
์ ์ฌ์ฉํด ๋ถ๋ฌ์ฌ ์ ์์ต๋๋ค. ํฌ๊ธฐ๊ฐ 2GB ๋ฏธ๋ง์ผ ๊ฒฝ์ฐ project release์ ์ฒจ๋ถํ๊ณ ๋ฆด๋ฆฌ์ค์ URL์ ์ฌ์ฉํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค. ์์ ์์์์๋ torchvision.models.resnet.resnet18
์ด pretrained
๋ฅผ ๋ค๋ฃจ์ง๋ง, ํด๋น ๋ก์ง์ ์ํธ๋ฆฌ ํฌ์ธํธ ์ ์์ ๋ฃ๋ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ ์๋ ์์ต๋๋ค.
if pretrained:
#์ฒดํฌ ํฌ์ธํธ๋ก ๋ก์ปฌ ๊นํ๋ธ ์ ์ฅ์, ์๋ฅผ ๋ค๋ฉด <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth์ ์ ์ฅํฉ๋๋ค
dirname = os.path.dirname(__file__)
checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
state_dict = torch.load(checkpoint)
model.load_state_dict(state_dict)
# ์ฒดํฌ ํฌ์ธํธ๊ฐ ๋ค๋ฅธ ๊ณณ์ ์ ์ฅ๋ ๊ฒฝ์ฐ
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
์ฃผ์ ์ฌํญ ๋ฐฐํฌ๋ ๋ชจ๋ธ๋ค์ ์ ์ด๋ branch/tag์ ์ํด์ผ ํฉ๋๋ค. ๋๋ค ์ปค๋ฐ์ด ๋๋ฉด ์๋ฉ๋๋ค.
2. Pytorch Korea ์ ๋ชจ๋ธ ์๋ด ํ์ด์ง ์ถ๊ฐํ๊ธฐ
- `torch.hub` ๋ฅผ ๋ฐฐํฌํ ์ดํ ํด๋น ๋ชจ๋ธ์ ๋ํ ์๋ด ํ์ด์ง๋ฅผ pytorch.kr ์ ๋ชจ๋ธ ์๋ด ํ์ด์ง ์ถ๊ฐํ๊ธฐ
touch ${project_root}/sample.md
2.1.2. template.md ๋ฅผ ์ฐธ์กฐํ์ฌ ํ์ผ ์์ฑํ๊ธฐ
2.2.1. ๋น๋ํ๊ธฐ ๋ฅผ ์ฐธ์กฐํ์ฌ ํํ์ด์ง๋ฅผ ๋น๋
2.2.2. ์ดํ์ https://127.0.0.1:4000/hub/ ์์ ์ถ๊ฐ๋ ํ์ด์ง๋ฅผ ํ์ธํ๊ธฐ
๋ค์๊ณผ ๊ฐ์ ํ์ด์ง๋ฅผ ํ์ธํ์ค ์ ์์ต๋๋ค.
์ถ๊ฐํ ํ์ด์ง๋ ๊ธฐ์ฌํ๊ธฐ ๋ฅผ ์ฐธ๊ณ ํ์ฌ ๋ณธ ๋ ํฌ์งํฐ๋ฆฌ์ ๊ธฐ์ฌํด์ฃผ์ธ์!