Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GPU option to T-Res #275

Closed
wants to merge 4 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
implement on gpu
rwood-97 committed Jun 18, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 78bbcd305ea511d4e4591048a11cc2f7a1b41b84
6 changes: 6 additions & 0 deletions t_res/geoparser/linking.py
Original file line number Diff line number Diff line change
@@ -112,6 +112,7 @@ def __init__(
linking_resources: Optional[dict] = dict(),
overwrite_training: Optional[bool] = False,
rel_params: Optional[dict] = None,
rel_device: Optional[str] = None,
):
"""
Initialises a Linker object.
@@ -136,6 +137,7 @@ def __init__(
}

self.rel_params = rel_params
self.rel_device = rel_device

def __str__(self) -> str:
"""
@@ -455,6 +457,8 @@ def train_load_model(
"mode": "train",
"model_path": os.path.join(linker_name, "model"),
}
if self.rel_device is not None:
config_rel["device"] = self.rel_device

# Instantiate the entity disambiguation model:
model = entity_disambiguation.EntityDisambiguation(
@@ -476,6 +480,8 @@ def train_load_model(
"mode": "eval",
"model_path": os.path.join(linker_name, "model"),
}
if self.rel_device is not None:
config_rel["device"] = self.rel_device

model = entity_disambiguation.EntityDisambiguation(
self.rel_params["db_embeddings"],
2 changes: 1 addition & 1 deletion t_res/geoparser/recogniser.py
Original file line number Diff line number Diff line change
@@ -326,7 +326,7 @@ def create_pipeline(self) -> Pipeline:
model_name = os.path.join(self.model_path, f"{self.model}.model")

# Load a NER pipeline:
self.pipe = pipeline("ner", model=model_name, ignore_labels=[])
self.pipe = pipeline("ner", model=model_name, ignore_labels=[], device_map="auto")
return self.pipe

# -------------------------------------------------------------
2 changes: 1 addition & 1 deletion t_res/utils/REL/entity_disambiguation.py
Original file line number Diff line number Diff line change
@@ -67,7 +67,7 @@ def __init__(self, db_embs, user_config, reset_embeddings=False):
self.config = self.__get_config(user_config)

# Use CPU if cuda is not available:
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.device = self.config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
self.prerank_model = None
self.model = None
self.reset_embeddings = reset_embeddings