Skip to content

Commit

Permalink
Merge branch 'release/v0.2'
Browse files Browse the repository at this point in the history
  • Loading branch information
krikit committed Dec 27, 2018
2 parents 9694950 + 1ea0367 commit 4b46d5e
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 75 deletions.
2 changes: 0 additions & 2 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ RUN pip install cython
RUN pip install --upgrade pip
RUN pip install -r requirements.txt

RUN pip install cmake
RUN mkdir build

WORKDIR /workspace/khaiii/build

RUN cmake ..
Expand Down
2 changes: 1 addition & 1 deletion include/khaiii/khaiii_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// constants //
///////////////
#define KHAIII_VERSION_MAJOR 0
#define KHAIII_VERSION_MINOR 1
#define KHAIII_VERSION_MINOR 2
#define _MAC2STR(m) #m
#define _JOIN_VER(x,y) _MAC2STR(x) "." _MAC2STR(y) // NOLINT
#define KHAIII_VERSION _JOIN_VER(KHAIII_VERSION_MAJOR,KHAIII_VERSION_MINOR) // NOLINT
Expand Down
2 changes: 0 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
cmake>=3.10
numpy
tqdm
74 changes: 5 additions & 69 deletions rsc/lib/vocabulary.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,16 @@
###########
# imports #
###########
import re
import codecs
from collections import defaultdict
import copy
import logging
import os
import torch
from torch import nn
import numpy as np
from tqdm import tqdm


#########
# types #
#########
class Vocabulary(object):
class Vocabulary:
"""
vocabulary class
"""
Expand Down Expand Up @@ -69,6 +63,9 @@ def __getitem__(self, key):
def __len__(self):
return len(self.dic)

'''
# 리소스 빌드 시 pytorch 의존성 제거를 위해 임시로 메서드를 제거합니다.
# 추후 학습 코드를 추가할 때 이 부분을 리팩토링 합니다.
def get_embedding(self, dim, padding_idx=None):
"""
embedding을 리턴합니다.
Expand All @@ -79,6 +76,7 @@ def get_embedding(self, dim, padding_idx=None):
if padding_idx:
return nn.Embedding(len(self), dim, padding_idx=padding_idx)
return nn.Embedding(len(self), dim)
''' # pylint: disable=pointless-string-statement

def padding_idx(self):
"""
Expand Down Expand Up @@ -117,65 +115,3 @@ def _load(self, path, cutoff=1):
self.rev.append(entry)
append_num += 1
logging.info('%s: %d entries, %d cutoff', os.path.basename(path), append_num, cutoff_num)


class PreTrainedVocabulary(Vocabulary):
"""
pre-train된 word2vec를 사용하는 경우, vector에 있는 어휘로
사전을 구성하도록 합니다.
"""
def __init__(self, path): #pylint: disable=super-init-not-called
"""
Args:
path: file path
"""
# simple : 사과/N , none : 사과
# 읽어들인 glove의 키 타입을 보고 판단해놓는다.
self.glove_key_type = None
self.dic, self.vectors = self._load_glove(path)
self.rev = {val:key for key, val in self.dic.items()}
assert len(self.dic) == len(self.rev)
logging.info('%s: %d entries, %d dim - not trainable',
os.path.basename(path), len(self.dic), self.vectors.size(1))

def get_embedding(self, dim, padding_idx=None):
"""
pre-training된 벡터가 세팅된 embedding을 리턴합니다.
"""
assert dim == self.vectors.size(1)
embed = super().get_embedding(dim, padding_idx)
embed.weight = nn.Parameter(self.vectors, requires_grad=False)
return embed

def _load_glove(self, path):
"""
pre-trained GloVe (텍스트 포맷) 워드 벡터를 읽어들인다.
Args:
path: 워드 벡터 경로
"""
unk = None
vecs = []
for line in tqdm(codecs.open(path, 'r', encoding='UTF-8')):
cols = line.split(' ')
word = cols[0]
vec = np.array([float(_) for _ in cols[1:]])
if vec.size == 0: # format error
continue
if word == '<unk>':
unk = vec
continue
vecs.append((word, vec))
if self.glove_key_type is None:
if re.search('/[A-Z]$', word) is None:
self.glove_key_type = 'none'
else:
self.glove_key_type = 'simple'
if unk is None:
unk = [0] * len(vecs[0][1])
padding = [0] * len(vecs[0][1])
vecs.sort(key=lambda x: x[0])
vecs.insert(0, ('<unk>', unk))
vecs.insert(1, ('<p>', padding))
vocab = defaultdict(int)
vocab.update({word: idx for idx, (word, _) in enumerate(vecs)})
return vocab, torch.Tensor([vec for _, vec in vecs])
2 changes: 1 addition & 1 deletion src/main/python/setup.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ setup(
packages=['khaiii', ],
include_package_data=True,
install_requires=[],
setup_requires=['numpy', 'pytest-runner', 'tqdm'],
setup_requires=['cmake>=3.10', 'pytest-runner'],
tests_require=['pytest', ],
zip_safe=False,
cmdclass={'build': CustomBuild}
Expand Down

0 comments on commit 4b46d5e

Please sign in to comment.