Skip to content

GML-1432: Refactor dataloaders #194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 36 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
55e7049
refactor(BaseLoader): split read_data and parse_data
billshitg Oct 7, 2023
324fea0
refactor(BaseLoader): consolidate vertex and edge parse
billshitg Oct 9, 2023
441a0b3
refactor(BaseLoader): consolidate attr_to_tensor and reindex
billshitg Oct 9, 2023
65cca46
feat(BaseLoader): handle one data point per msg
billshitg Oct 10, 2023
9eaef0e
fix(BaseLoader): update _start_request for the refactor
billshitg Oct 10, 2023
e4e407e
refactor(VertexLoader): update gsql template
billshitg Oct 16, 2023
4c42d4a
refactor(VertexLoader): update reader thread
billshitg Oct 16, 2023
41c45f9
fix(VertexLoader): fix wrong comma in gsql
billshitg Oct 17, 2023
ce95a21
fix(BaseLoader): fix missing function decorator
billshitg Oct 17, 2023
a1ff95b
fix(BaseLoader,VertexLoader): change how we shuffle
billshitg Oct 18, 2023
e7f3e94
feat(BaseLoader,VertexLoader): fix when to stop kafka
billshitg Oct 18, 2023
1dfa8bd
feat(VertexLoader): update shuffle
billshitg Oct 18, 2023
d5a97df
test(VertexLoader): update tests
billshitg Oct 18, 2023
0298bc0
feat(VertexLoader): move kafka commands into template
billshitg Oct 19, 2023
7315071
feat: move shuffle to gsql for speed and add EdgeLoader
billshitg Oct 20, 2023
d8b816b
feat(GraphLoader): update gsql and add subquery
billshitg Oct 24, 2023
7b2970c
feat(dataloaders): update for GraphLoader
billshitg Oct 24, 2023
7f5f584
feat(utilities): add function to install multiple files
billshitg Oct 24, 2023
7320237
test(GraphLoaders): update unit tests
billshitg Oct 24, 2023
ffe5b0d
feat(dataloaders): update NeighborLoader
billshitg Oct 26, 2023
a78f33d
feat: update all about EdgeNeighborLoader
billshitg Oct 31, 2023
c3434b5
feat: add new hgt loader
billshitg Nov 6, 2023
001e1ea
feat(gds): update nodepieceloader
billshitg Nov 9, 2023
2ac345b
test(BaseLoader): rm unneeded tests
billshitg Nov 9, 2023
189bc29
fix(dataloaders): fix seeds type issue
billshitg Nov 9, 2023
9a9e78c
test(GDS): fix num_batches issue
billshitg Nov 10, 2023
17390dc
fix(Trainer): rm num_batches
billshitg Nov 10, 2023
0d03fe5
test(featurizer): fix patch_ver issue
billshitg Nov 10, 2023
22a02b6
fix: bool error when num_batches is None
billshitg Nov 17, 2023
ed6a53c
fix(test_HGTLoader): fix error when edge not present
billshitg Dec 13, 2023
3f257a2
fix(dataloaders): check edge number returned from DB
billshitg Dec 14, 2023
98cdc1f
feat: update baseloader and unit test for new queries
billshitg Dec 21, 2023
7e0de75
feat(EdgeNeighborLoader): update loader and gsql
billshitg Dec 22, 2023
4d334a1
fix(parse_edge_data): error when edge type has _
billshitg Jan 2, 2024
c704e02
feat: allow kafka in distributed query
billshitg Jan 4, 2024
0ae8be8
tests(EdgeNeighborLoader): test distributed query
billshitg Feb 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,882 changes: 1,866 additions & 1,016 deletions pyTigerGraph/gds/dataloaders.py

Large diffs are not rendered by default.

12 changes: 9 additions & 3 deletions pyTigerGraph/gds/gds.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,9 @@ def edgeNeighborLoader(
timeout: int = 300000,
callback_fn: Callable = None,
reinstall_query: bool = False,
distributed_query: bool = False
distributed_query: bool = False,
num_machines: int = 1,
num_segments: int = 20
) -> EdgeNeighborLoader:
"""Returns an `EdgeNeighborLoader` instance.
An `EdgeNeighborLoader` instance performs neighbor sampling from all edges in the graph in batches in the following manner:
Expand Down Expand Up @@ -1098,7 +1100,9 @@ def edgeNeighborLoader(
"delimiter": delimiter,
"timeout": timeout,
"callback_fn": callback_fn,
"distributed_query": distributed_query
"distributed_query": distributed_query,
"num_machines": num_machines,
"num_segments": num_segments
}
if self.kafkaConfig:
params.update(self.kafkaConfig)
Expand Down Expand Up @@ -1130,7 +1134,9 @@ def edgeNeighborLoader(
"delimiter": delimiter,
"timeout": timeout,
"callback_fn": callback_fn,
"distributed_query": distributed_query
"distributed_query": distributed_query,
"num_machines": num_machines,
"num_segments": num_segments
}
if self.kafkaConfig:
params.update(self.kafkaConfig)
Expand Down
239 changes: 56 additions & 183 deletions pyTigerGraph/gds/gsql/dataloaders/edge_loader.gsql
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
CREATE QUERY edge_loader_{QUERYSUFFIX}(
INT batch_size,
INT num_batches=1,
BOOL shuffle=FALSE,
STRING filter_by,
SET<STRING> e_types,
STRING delimiter,
BOOL shuffle=FALSE,
INT num_chunks=2,
STRING kafka_address="",
STRING kafka_topic,
STRING kafka_topic="",
INT kafka_topic_partitions=1,
STRING kafka_max_size="104857600",
INT kafka_timeout=300000,
Expand Down Expand Up @@ -41,198 +40,72 @@ CREATE QUERY edge_loader_{QUERYSUFFIX}(
sasl_password : SASL password for Kafka.
ssl_ca_location: Path to CA certificate for verifying the Kafka broker key
*/
TYPEDEF TUPLE<INT tmp_id, VERTEX src, VERTEX tgt> ID_Tuple;
INT num_vertices;
INT kafka_errcode;
SumAccum<INT> @tmp_id;
SumAccum<STRING> @@kafka_error;
UINT producer;
MapAccum<INT, BOOL> @@edges_sampled;
SetAccum<VERTEX> @valid_v_out;
SetAccum<VERTEX> @valid_v_in;

# Initialize Kafka producer
IF kafka_address != "" THEN
producer = init_kafka_producer(
kafka_address, kafka_max_size, security_protocol,
sasl_mechanism, sasl_username, sasl_password, ssl_ca_location,
ssl_certificate_location, ssl_key_location, ssl_key_password,
ssl_endpoint_identification_algorithm, sasl_kerberos_service_name,
sasl_kerberos_keytab, sasl_kerberos_principal);
END;

# Shuffle vertex ID if needed
start = {ANY};
# Filter seeds if needed
seeds = SELECT s
FROM start:s -(e_types:e)- :t
WHERE filter_by is NULL OR e.getAttr(filter_by, "BOOL")
POST-ACCUM s.@tmp_id = getvid(s)
POST-ACCUM t.@tmp_id = getvid(t);
# Shuffle vertex ID if needed
IF shuffle THEN
num_vertices = start.size();
INT num_vertices = seeds.size();
res = SELECT s
FROM start:s
POST-ACCUM s.@tmp_id = floor(rand()*num_vertices);
ELSE
res = SELECT s
FROM start:s
POST-ACCUM s.@tmp_id = getvid(s);
END;

SumAccum<FLOAT> @@num_edges;
IF filter_by IS NOT NULL THEN
res = SELECT s
FROM start:s -(e_types:e)- :t WHERE e.getAttr(filter_by, "BOOL")
ACCUM
IF e.isDirected() THEN # we divide by two later to correct for undirected edges being counted twice, need to count directed edges twice to get correct count
@@num_edges += 2
ELSE
@@num_edges += 1
END;
ELSE
res = SELECT s
FROM start:s -(e_types:e)- :t
ACCUM
IF e.isDirected() THEN # we divide by two later to correct for undirected edges being counted twice, need to count directed edges twice to get correct count
@@num_edges += 2
ELSE
@@num_edges += 1
END;
END;
INT batch_s;
IF batch_size IS NULL THEN
batch_s = ceil((@@num_edges/2)/num_batches);
ELSE
batch_s = batch_size;
FROM seeds:s
POST-ACCUM s.@tmp_id = floor(rand()*num_vertices)
LIMIT 1;
END;

# Generate batches
FOREACH batch_id IN RANGE[0, num_batches-1] DO
SumAccum<STRING> @@e_batch;
SetAccum<VERTEX> @@seeds;
SetAccum<VERTEX> @@targets;
HeapAccum<ID_Tuple> (1, tmp_id ASC) @@batch_heap;
@@batch_heap.resize(batch_s);
start = {ANY};
IF filter_by IS NOT NULL THEN
res =
SELECT s
FROM start:s -(e_types:e)- :t
WHERE e.getAttr(filter_by, "BOOL")
AND
((e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR
(t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id))))
OR
(NOT e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR
(t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id)))
AND ((s.@tmp_id >= t.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id)) OR
(t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id)))))
ACCUM
IF t.@tmp_id >= s.@tmp_id THEN
@@batch_heap += ID_Tuple(((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id), s, t)
ELSE
@@batch_heap += ID_Tuple(((s.@tmp_id*s.@tmp_id)+t.@tmp_id), s, t)
END;

FOREACH elem IN @@batch_heap DO
SetAccum<VERTEX> @@src;
@@seeds += elem.src;
@@targets += elem.tgt;
@@src += elem.src;
src = {@@src};
res = SELECT s FROM src:s -(e_types:e)- :t
WHERE t == elem.tgt
ACCUM
s.@valid_v_out += elem.tgt,
t.@valid_v_in += elem.src;
END;
start = {@@seeds};
res =
SELECT s
FROM start:s -(e_types:e)- :t
WHERE t in @@targets AND s IN t.@valid_v_in AND t IN s.@valid_v_out
ACCUM
{EDGEATTRS},
IF t.@tmp_id >= s.@tmp_id THEN
@@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE),
IF NOT e.isDirected() THEN
@@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id) -> TRUE)
END
ELSE
@@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id) -> TRUE),
IF NOT e.isDirected() THEN
@@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE)
END
END
POST-ACCUM
s.@valid_v_in.clear(), s.@valid_v_out.clear()
POST-ACCUM
t.@valid_v_in.clear(), t.@valid_v_out.clear();
ELSE
res =
SELECT s
FROM start:s -(e_types:e)- :t
WHERE ((e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR
(t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id))))
OR
(NOT e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR
(t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id)))
AND ((s.@tmp_id >= t.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id)) OR
(t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id)))))
ACCUM
IF t.@tmp_id >= s.@tmp_id THEN
@@batch_heap += ID_Tuple(((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id), s, t)
ELSE
@@batch_heap += ID_Tuple(((s.@tmp_id*s.@tmp_id)+t.@tmp_id), s, t)
END;

FOREACH elem IN @@batch_heap DO
SetAccum<VERTEX> @@src;
@@seeds += elem.src;
@@targets += elem.tgt;
@@src += elem.src;
src = {@@src};
res = SELECT s FROM src:s -(e_types:e)- :t
WHERE t == elem.tgt
ACCUM
s.@valid_v_out += elem.tgt,
t.@valid_v_in += elem.src;
END;
start = {@@seeds};
res =
SELECT s
FROM start:s -(e_types:e)- :t
WHERE t in @@targets AND s IN t.@valid_v_in AND t IN s.@valid_v_out
ACCUM
{EDGEATTRS},
IF t.@tmp_id >= s.@tmp_id THEN
@@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE),
IF NOT e.isDirected() THEN
@@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id) -> TRUE)
END
ELSE
@@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id) -> TRUE),
IF NOT e.isDirected() THEN
@@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE)
END
END
POST-ACCUM
s.@valid_v_in.clear(), s.@valid_v_out.clear()
POST-ACCUM
t.@valid_v_in.clear(), t.@valid_v_out.clear();
# If using kafka to export
IF kafka_address != "" THEN
SumAccum<STRING> @@kafka_error;

# Initialize Kafka producer
UINT producer = init_kafka_producer(
kafka_address, kafka_max_size, security_protocol,
sasl_mechanism, sasl_username, sasl_password, ssl_ca_location,
ssl_certificate_location, ssl_key_location, ssl_key_password,
ssl_endpoint_identification_algorithm, sasl_kerberos_service_name,
sasl_kerberos_keytab, sasl_kerberos_principal);

FOREACH chunk IN RANGE[0, num_chunks-1] DO
res = SELECT s
FROM seeds:s -(e_types:e)- :t
WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk)
ACCUM
{EDGEATTRSKAFKA}
LIMIT 1;
END;
# Export batch
IF kafka_address != "" THEN
# Write to kafka
kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "edge_batch_" + stringify(batch_id), @@e_batch);
IF kafka_errcode != 0 THEN
@@kafka_error += ("Error sending edge batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n");

FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO
INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", "");
IF kafka_errcode!=0 THEN
@@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n");
END;
ELSE
# Add to response
PRINT @@e_batch AS edge_batch;
END;
END;
IF kafka_address != "" THEN
kafka_errcode = close_kafka_producer(producer, kafka_timeout);
IF kafka_errcode != 0 THEN

INT kafka_errcode = close_kafka_producer(producer, kafka_timeout);
IF kafka_errcode!=0 THEN
@@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n");
END;
PRINT @@kafka_error as kafkaError;
# Else return as http response
ELSE
FOREACH chunk IN RANGE[0, num_chunks-1] DO
ListAccum<STRING> @@e_batch;
res = SELECT s
FROM seeds:s -(e_types:e)- :t
WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk)
ACCUM
{EDGEATTRSHTTP}
LIMIT 1;

FOREACH i IN @@e_batch DO
PRINT i as data_batch;
END;
END;
END;
}
Loading