Skip to content

Commit

Permalink
Merge pull request #255 from torchkge-team/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
armand33 authored Apr 5, 2023
2 parents 2a303d5 + 2e84476 commit b63b9fa
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 29 deletions.
4 changes: 2 additions & 2 deletions docs/history.rst
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
=======
History
=======
0.17.5 (2022-09-18)
0.17.7 (2023-04-05)
-------------------
* Fix bug in TransH implementation
* Adding additional pretrained models

0.17.6 (2023-03-31)
-------------------
Expand Down
18 changes: 18 additions & 0 deletions docs/reference/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,22 @@ TransE model

.. autofunction:: torchkge.utils.pretrained_models.load_pretrained_transe

RESCAL Model
=============
.. tabularcolumns:: p{3cm}p{3cm}p{3cm}p{3cm}

+-----------+-----------+-----------+----------+--------------------+
| Model | Dataset | Dimension | Test MRR | Filtered Test MRR |
+===========+===========+===========+==========+====================+
| RESCAL | FB15k237 | 200 | 0.180 | 0.305 |
+-----------+-----------+-----------+----------+--------------------+
| RESCAL | WN18RR | 150 | 0.273 | 0.424 |
+-----------+-----------+-----------+----------+--------------------+
| RESCAL | Yago3-10 | 200 | 0.124 | 0.308 |
+-----------+-----------+-----------+----------+--------------------+

.. autofunction:: torchkge.utils.pretrained_models.load_pretrained_rescal

ComplEx Model
=============
.. tabularcolumns:: p{3cm}p{3cm}p{3cm}p{3cm}
Expand All @@ -55,6 +71,8 @@ ComplEx Model
+-----------+-----------+-----------+----------+--------------------+
| ComplEx | WDV5 | 200 | 0.283 | 0.371 |
+-----------+-----------+-----------+----------+--------------------+
| ComplEx | Yago3-10 | 200 | 0.164 | 0.421 |
+-----------+-----------+-----------+----------+--------------------+

.. autofunction:: torchkge.utils.pretrained_models.load_pretrained_complex

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.17.6
current_version = 0.17.7
commit = True
tag = True

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@
setup_requires=setup_requirements,
tests_require=test_requirements,
test_suite='tests',
version='0.17.6',
version='0.17.7',
zip_safe=False,
)
6 changes: 3 additions & 3 deletions torchkge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

__author__ = """Armand Boschin"""
__email__ = '[email protected]'
__version__ = '0.17.6'
__version__ = '0.17.7'

from torchkge.exceptions import NotYetEvaluatedError
from torchkge.utils import MarginLoss, LogisticLoss
Expand All @@ -13,5 +13,5 @@
from .evaluation import LinkPredictionEvaluator
from .evaluation import TripletClassificationEvaluator
from .models import ConvKBModel
from .models import RESCALModel, DistMultModel
from .models import TransEModel, TransHModel, TransRModel, TransDModel
from .models import RESCALModel, DistMultModel, HolEModel, ComplExModel, AnalogyModel
from .models import TransEModel, TransHModel, TransRModel, TransDModel, TorusEModel
1 change: 1 addition & 0 deletions torchkge/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@
from .losses import MarginLoss, LogisticLoss, BinaryCrossEntropyLoss
from .modeling import init_embedding, get_true_targets, load_embeddings, filter_scores
from .operations import get_rank, get_mask, get_bernoulli_probs
from .pretrained_models import load_pretrained_transe, load_pretrained_rescal, load_pretrained_complex
from .training import Trainer, TrainDataLoader
88 changes: 66 additions & 22 deletions torchkge/utils/pretrained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
"""

from ..exceptions import NoPreTrainedVersionError
from ..models import TransEModel, ComplExModel
from ..models import TransEModel, ComplExModel, RESCALModel
from ..utils import load_embeddings


def load_pretrained_transe(dataset, emb_dim, data_home=None):
def load_pretrained_transe(dataset, emb_dim=None, data_home=None):
"""Load a pretrained version of TransE model.
Parameters
----------
dataset: str
emb_dim: int
emb_dim: int (opt, default None)
Embedding dimension
data_home: str (opt, default None)
Path to the `torchkge_data` directory (containing data folders). Useful
Expand All @@ -26,16 +26,18 @@ def load_pretrained_transe(dataset, emb_dim, data_home=None):
model: `TorchKGE.model.translation.TransEModel`
Pretrained version of TransE model.
"""
dims = {'fb15k': 100, 'wn18rr': 100, 'fb15k237': 150, 'wdv5': 150, 'yago310': 200}
try:
assert (dataset in {'fb15k', 'wn18rr'} and emb_dim == 100) \
or (dataset == 'fb15k237' and emb_dim == 150) \
or (dataset == 'wdv5' and emb_dim == 150) \
or (dataset == 'yago310' and emb_dim == 200)

except AssertionError:
raise NoPreTrainedVersionError('No pre-trained version of TransE for '
'{} in dimension {}'.format(dataset,
emb_dim))
if emb_dim is None:
emb_dim = dims[dataset]
else:
try:
assert dims[dataset] == emb_dim
except AssertionError:
raise NoPreTrainedVersionError('No pre-trained version of TransE for '
'{} in dimension {}'.format(dataset, emb_dim))
except KeyError:
raise NoPreTrainedVersionError('No pre-trained version of TransE for {}'.format(dataset))

state_dict = load_embeddings('transe', emb_dim, dataset, data_home)
model = TransEModel(emb_dim,
Expand All @@ -47,13 +49,13 @@ def load_pretrained_transe(dataset, emb_dim, data_home=None):
return model


def load_pretrained_complex(dataset, emb_dim, data_home=None):
def load_pretrained_complex(dataset, emb_dim=None, data_home=None):
"""Load a pretrained version of ComplEx model.
Parameters
----------
dataset: str
emb_dim: int
emb_dim: int (opt, default None)
Embedding dimension
data_home: str (opt, default None)
Path to the `torchkge_data` directory (containing data folders). Useful
Expand All @@ -64,15 +66,18 @@ def load_pretrained_complex(dataset, emb_dim, data_home=None):
model: `TorchKGE.model.translation.ComplExModel`
Pretrained version of ComplEx model.
"""
dims = {'wn18rr': 200, 'fb15k237': 200, 'wdv5': 200, 'yago310': 200}
try:
assert (dataset == 'wn18rr' and emb_dim == 200) \
or (dataset == 'fb15k237' and emb_dim == 200) \
or (dataset == 'wdv5' and emb_dim == 200)

except AssertionError:
raise NoPreTrainedVersionError('No pre-trained version of ComplEx for '
'{} in dimension {}'.format(dataset,
emb_dim))
if emb_dim is None:
emb_dim = dims[dataset]
else:
try:
assert dims[dataset] == emb_dim
except AssertionError:
raise NoPreTrainedVersionError('No pre-trained version of ComplEx for '
'{} in dimension {}'.format(dataset, emb_dim))
except KeyError:
raise NoPreTrainedVersionError('No pre-trained version of ComplEx for {}'.format(dataset))

state_dict = load_embeddings('complex', emb_dim, dataset, data_home)
model = ComplExModel(emb_dim,
Expand All @@ -81,3 +86,42 @@ def load_pretrained_complex(dataset, emb_dim, data_home=None):
model.load_state_dict(state_dict)

return model


def load_pretrained_rescal(dataset, emb_dim=None, data_home=None):
"""Load a pretrained version of RESCAL model.
Parameters
----------
dataset: str
emb_dim: int (opt, default None)
Embedding dimension
data_home: str (opt, default None)
Path to the `torchkge_data` directory (containing data folders). Useful
for pre-trained model loading.
Returns
-------
model: `TorchKGE.model.translation.RESCALModel`
Pretrained version of RESCAL model.
"""
dims = {'wn18rr': 200, 'fb15k237': 200, 'yago310': 200}
try:
if emb_dim is None:
emb_dim = dims[dataset]
else:
try:
assert dims[dataset] == emb_dim
except AssertionError:
raise NoPreTrainedVersionError('No pre-trained version of RESCAL for '
'{} in dimension {}'.format(dataset, emb_dim))
except KeyError:
raise NoPreTrainedVersionError('No pre-trained version of RESCAL for {}'.format(dataset))

state_dict = load_embeddings('rescal', emb_dim, dataset, data_home)
model = RESCALModel(emb_dim,
n_entities=state_dict['ent_emb.weight'].shape[0],
n_relations=state_dict['rel_mat.weight'].shape[0])
model.load_state_dict(state_dict)

return model

0 comments on commit b63b9fa

Please sign in to comment.