Skip to content

Commit

Permalink
Merge pull request #135 from dice-group/robust
Browse files Browse the repository at this point in the history
Last commit before the new release
  • Loading branch information
Demirrr authored Aug 10, 2023
2 parents 0ea3ebb + 85ca41d commit f2084df
Show file tree
Hide file tree
Showing 13 changed files with 148 additions and 279 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,19 @@ pip install dicee
```
or
```bash
pip3 install "pandas>=1.5.1"
pip3 install "torch>=2.0.0"
pip3 install "pandas>=1.5.1"
pip3 install "polars>=0.16.14"
pip3 install "scikit-learn>=1.2.2"
pip3 install "pyarrow>=11.0.0"
pip3 install "pytest>=7.2.2"
pip3 install "gradio>=3.23.0"
pip3 install "psutil>=5.9.4"
pip3 install "pytorch-lightning==1.6.4"
pip3 install "pykeen==1.10.1"
pip3 install "zstandard>=0.21.0"
pip3 install "pytest>=7.2.2"
pip3 install "psutil>=5.9.4"
pip3 install "ruff>=0.0.284"
pip3 install "gradio>=3.23.0"
pip3 install "rdflib>=7.0.0"
pip3 install "ruff>=0.0.283"
```

To test the Installation
Expand Down
2 changes: 1 addition & 1 deletion dicee/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from .knowledge_graph_embeddings import KGE # noqa
from .executer import Execute # noqa
from .dataset_classes import * # noqa
__version__ = '0.0.4'
__version__ = '0.0.5'
27 changes: 6 additions & 21 deletions dicee/dataset_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def __init__(self, train_set: np.ndarray, num_entities: int, num_relations: int,
self.neg_sample_ratio = torch.tensor(
neg_sample_ratio) # 0 Implies that we do not add negative samples. This is needed during testing and validation
self.train_set = torch.from_numpy(train_set).unsqueeze(1)
#assert num_entities >= max(self.train_set[:, 0]) and num_entities >= max(self.train_set[:, 2])
# assert num_entities >= max(self.train_set[:, 0]) and num_entities >= max(self.train_set[:, 2])
self.length = len(self.train_set)
self.num_entities = torch.tensor(num_entities)
self.num_relations = torch.tensor(num_relations)
Expand All @@ -335,27 +335,12 @@ def __getitem__(self, idx):

triple = self.train_set[idx]

y = torch.ones(1)
corr_entities = torch.randint(0, high=self.num_entities, size=(1,))
negative_triple = torch.cat((triple[:, 0], triple[:, 1], corr_entities), dim=0).unsqueeze(0)





negative_triple=torch.cat((triple[:,0],triple[:,1],triple[:,2]),dim=0)

print(triple.shape)
print(negative_triple.shape)
exit(1)

y = torch.ones(0)

print(triple.shape)
x=torch.cat((triple,negative_triple),dim=1)

print(x)
exit(1)
# Workaround to create negative triples
return triple, y
x = torch.cat((triple, negative_triple), dim=0)
y=torch.tensor([1.0, 0.0])
return x,y


class TriplePredictionDataset(torch.utils.data.Dataset):
Expand Down
4 changes: 1 addition & 3 deletions dicee/executer.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,7 @@ def start(self) -> dict:
self.trainer = DICE_Trainer(args=self.args,
is_continual_training=self.is_continual_training,
storage_path=self.storage_path,
evaluator=self.evaluator,
dataset=self.dataset # only used for Pykeen's models
)
evaluator=self.evaluator)
# (4) Start the training
self.trained_model, form_of_labelling = self.trainer.start(dataset=self.dataset)
return self.end(form_of_labelling)
Expand Down
3 changes: 2 additions & 1 deletion dicee/models/real.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def forward_k_vs_sample(self, x: torch.LongTensor, target_entity_idx: torch.Long
t = self.entity_embeddings(target_entity_idx).transpose(1, 2)
return torch.bmm(hr, t).squeeze(1)


def score(self,h,r,t):
return (self.hidden_dropout(self.hidden_normalizer(h * r)) * t).sum(dim=1)
class TransE(BaseKGE):
"""
Translating Embeddings for Modeling
Expand Down
22 changes: 5 additions & 17 deletions dicee/static_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def load_pickle(file_path=str):


# @TODO: Could these funcs can be merged?
def select_model(args: dict, is_continual_training: bool = None, storage_path: str = None, dataset=None):
def select_model(args: dict, is_continual_training: bool = None, storage_path: str = None):
isinstance(args, dict)
assert len(args) > 0
assert isinstance(is_continual_training, bool)
Expand All @@ -58,7 +58,7 @@ def select_model(args: dict, is_continual_training: bool = None, storage_path: s
print(f"{storage_path}/model.pt is not found. The model will be trained with random weights")
return model, _
else:
return intialize_model(args, dataset)
return intialize_model(args)


def load_model(path_of_experiment_folder, model_name='model.pt') -> Tuple[object, dict, dict]:
Expand Down Expand Up @@ -194,13 +194,11 @@ def save_checkpoint_model(model, path: str) -> None:
print(model.name)
print('Could not save the model correctly')
else:
# Pykeen
torch.save(model.model.state_dict(), path)


def store(trainer,
trained_model, model_name: str = 'model', full_storage_path: str = None,
dataset=None, save_embeddings_as_csv=False) -> None:
trained_model, model_name: str = 'model', full_storage_path: str = None, save_embeddings_as_csv=False) -> None:
"""
Store trained_model model and save embeddings into csv file.
:param trainer: an instance of trainer class
Expand Down Expand Up @@ -287,22 +285,12 @@ def read_or_load_kg(args, cls):
return kg


def get_pykeen_model(model_name: str, args, dataset):
if dataset is None:
# (1) Load a pretrained Pykeen Model
return PykeenKGE(args=args)
elif args['scoring_technique'] in ['KvsAll', "NegSample"]:
return PykeenKGE(args=args)
else:
raise NotImplementedError("Incorrect scoring technique")


def intialize_model(args: dict, dataset=None) -> Tuple[object, str]:
def intialize_model(args: dict) -> Tuple[object, str]:
# @TODO: Apply construct_krone as callback? or use KronE_QMult as a prefix.
# @TODO: Remove form_of_labelling
model_name = args['model']
if "pykeen" in model_name.lower():
model = get_pykeen_model(model_name, args, dataset)
model = PykeenKGE(args=args)
form_of_labelling = "EntityPrediction"
elif model_name == 'Shallom':
model = Shallom(args=args)
Expand Down
6 changes: 2 additions & 4 deletions dicee/trainer/dice_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class DICE_Trainer:
report:dict
"""

def __init__(self, args, is_continual_training, storage_path, evaluator=None, dataset=None):
def __init__(self, args, is_continual_training, storage_path, evaluator=None):
self.report = dict()
self.args = args
self.trainer = None
Expand All @@ -101,7 +101,6 @@ def __init__(self, args, is_continual_training, storage_path, evaluator=None, da
# Required for CV.
self.evaluator = evaluator
self.form_of_labelling = None
self.dataset = dataset
print(
f'# of CPUs:{os.cpu_count()} | # of GPUs:{torch.cuda.device_count()} | # of CPUs for dataloader:{self.args.num_core}')

Expand Down Expand Up @@ -144,8 +143,7 @@ def initialize_trainer(self, callbacks: List, plugins: List) -> pl.Trainer:
@timeit
def initialize_or_load_model(self):
print('Initializing Model...', end='\t')
model, form_of_labelling = select_model(vars(self.args), self.is_continual_training, self.storage_path,
self.dataset)
model, form_of_labelling = select_model(vars(self.args), self.is_continual_training, self.storage_path)
self.report['form_of_labelling'] = form_of_labelling
assert form_of_labelling in ['EntityPrediction', 'RelationPrediction']
return model, form_of_labelling
Expand Down
135 changes: 0 additions & 135 deletions environment.yml

This file was deleted.

Loading

0 comments on commit f2084df

Please sign in to comment.