Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update node_representation_learning.md #105

Merged
merged 1 commit into from
Jan 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions docs/use_cases/node_representation_learning.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

## Introduction: how to represent relationships

Of the various types of information - words, pictures, and connections between things - **relationships** are especially interesting. Relationships show how things interact and create networks. But not all ways of representing relationships are the same. In machine learning, **how we do vector represention of relationships affects performance** on a wide range of tasks.
Of the various types of information - words, pictures, and connections between things - **relationships** are especially interesting. Relationships show how things interact and create networks. But not all ways of representing relationships are the same. In machine learning, **how we do vector representation of relationships affects performance** on a wide range of tasks.

Below, we evaluate several approaches to vector representation on a real-life use case: how well they classify academic articles and replicate a citation graph using Cora citation network.
Below, we evaluate several approaches to vector representation on a real-life use case: how well each approach classifies academic articles in a subset of the Cora citation network.

We look first at Bag-of-Words. Because BoW can't represent the network structure, we turn to solutions that can help BoW's performance: Node2Vec and GraphSAGE. We also look for a solution to BoW's other shortcoming - its inability to capture semantic meaning. We evaluate LLM embeddings, first on their own, then combined with Node2Vec, and, finally, LLM-trained GraphSAGE.
We look first at Bag-of-Words (BoW), a standard approach to vectorizing text data in ML. Because BoW can't represent the network structure, we turn to solutions that can help BoW's performance: Node2Vec and GraphSAGE. We also look for a solution to BoW's other shortcoming - its inability to capture semantic meaning. We evaluate LLM embeddings, first on their own, then combined with Node2Vec, and, finally, LLM-trained GraphSAGE.

## Leading our dataset, evaluating BoW

Expand Down Expand Up @@ -49,10 +49,17 @@ evaluate(ds.x, ds.y)
>>> F1 macro 0.701
```

BoW's accuracy and F1 macro scores leave a lot of room for improvement. It fails to correctly classify papers more than 25% of the time. And on average across classes BoW is inaccurate nearly 30% of the time.
BoW's accuracy and F1 macro scores are pretty good, but leave significant room for improvement. BoW falls short of correctly classify papers more than 25% of the time. And on average across classes BoW is inaccurate nearly 30% of the time.

**BoW representation of citation pair similarity**
But we also want to see whether BoW representations accurately capture the relationships between articles. Any given article will tend to cite other articles that belong to the same topic that it belongs to. Therefore, representations that embed not just textual data but also citation data of articles contained in our network will classify articles more accurately.
## Improving on BoW: taking advantage of citation graph data

Can we improve on this? Our citation network contains not only text data but also relationship data - a citation graph. Any given article will tend to cite other articles that belong to the same topic that it belongs to. Therefore, representations that embed not just textual data but also citation data of articles contained in our network will probably classify articles more accurately.

BoW features represent text data. But how well does BoW capture the relationships between articles?

**That is, do BoW features represent citation graph data?**

### Comparing citation pair similarity in BoW

To examine how well citation pairs show up in BoW features, we can make a plot comparing connected and not connected pairs of papers based on how similar their respective BoW features are.

Expand All @@ -62,9 +69,9 @@ In this plot, we define groups (shown on the y-axis) so that each group has abou

The plot demonstrates how connected nodes usually have higher cosine similarities. Papers that cite each other often use similar words. But if we ignore paper pairs with zero similarities (the 0.00-0.00 group), papers that have _not_ cited each other also seem to have a wide range of common words.

Though BoW representations embody _some_ information about article connectivity, BoW features don't contain enough citation pair information to accurately reconstruct the actual citation graph. Because BoW looks exclusively at word co-occurrence between article pairs, it misses word context data contained in the network structure - data that can be used to classify articles better.
Though BoW representations embody _some_ information about article connectivity, BoW features don't contain enough citation pair information to accurately reconstruct the actual citation graph. BoW looks exclusively at word co-occurrence between article pairs, and therefore misses word context data contained in the network structure.

Can we make up for BoW's inability to represent the citation network's structure? Are there methods that capture node data and node connectivity data better?
**Can we make up for BoW's inability to represent the citation network's structure?** Are there methods that capture node connectivity data better?

Node2Vec is built to do precisely this, for static networks. So is GraphSAGE, for dynamic ones.
Let's look at Node2Vec first.
Expand All @@ -77,7 +84,7 @@ $P(\text{context}|\text{source}) = \frac{1}{Z}\exp(w_{c}^Tw_s) $

Here, $w_c$ and $w_s$ are the embeddings of the context node $c$ and source node $s$ respectively. The variable $Z$ serves as a normalization constant, which, for computational efficiency, is never explicitly computed.

The embeddings are learned by maximizing the co-occurance probability for (source,context) pairs drawn from the true data distribution (positive pairs), and at the same time minimizing for pairs drawn from a synthetic noise distribution. This process ensures that the embedding vectors of similar nodes are close in the embedding space, while dissimilar nodes are further apart (with respect to the dot product).
The embeddings are learned by maximizing the co-occurence probability for (source,context) pairs drawn from the true data distribution (positive pairs), and at the same time minimizing for pairs drawn from a synthetic noise distribution. This process ensures that the embedding vectors of similar nodes are close in the embedding space, while dissimilar nodes are further apart (with respect to the dot product).

The random walks are sampled according to a policy, which is guided by 2 parameters: return $p$, and in-out $q$.

Expand Down Expand Up @@ -148,7 +155,7 @@ Let's also see if Node2Vec does a better job of **representing citation data** t

![N2V cosine similarity edge counts](../assets/use_cases/node_representation_learning/bins_n2v.png)

This time, using Node2Vec we can see a well defined separation; these embeddings capture the connectivity of the graph much better than BoW did.
Using Node2Vec, we can see a well defined separation; these embeddings capture the connectivity of the graph much better than BoW did.

**But can we _further_ improve classification performance?**
What if we _combine_ the two information sources - relationship (Node2Vec) embeddings and textual (BoW) features?
Expand Down Expand Up @@ -287,9 +294,9 @@ evaluate(embeddings, ds.y)
>>> F1 macro 0.820
```

The results are slightly worse than the results we got by combining Node2Vec with BoW features. But, remember, we're evaluating GraphSAGE because it can handle dynamic network data, whereas Node2Vec is better suited to static networks. GraphSAGE embeddings perform well on our classification task _and_ are able to embed completely new nodes as well. When your use case involves new nodes or nodes that evolve, an induction model like GraphSAGE may be a better choice than Node2Vec.
The results are slightly worse than the results we got by combining Node2Vec with BoW features. But, remember, we're evaluating GraphSAGE because it can handle dynamic network data, whereas Node2Vec cannot. GraphSAGE embeddings perform well on our classification task _and_ are able to embed completely new nodes as well. When your use case involves new nodes or nodes that evolve, an induction model like GraphSAGE may be a better choice than Node2Vec.

## Using better node representations than BoW: LLM
## Embedding semantics: LLM

In addition to not being able to represent network structure, BoW vectors, because they treat words as contextless occurrences, can't capture semantic meaning, and therefore don't perform as well on classification tasks as approaches that can do semantic embedding. Let's summarize the classification performance results we obtained above using BoW features.

Expand All @@ -302,6 +309,8 @@ These article classification results **can be improved further using LLM embeddi

To do this, we use the `all-mpnet-base-v2` model available on [Hugging Face](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) for embedding the title and abstract of each paper. Data loading and optimization is done in exactly the same way we did in the code snippets above, when we used BoW features. This time, we just substitute LLM features where the BoW features were.

### LLM results

The results obtained with LLM only, Node2Vec combined with LLM, and GraphSAGE trained on LLM appear in the following table, along with the _relative_ percent improvement compared to using the BoW features:

| Metric | LLM | Node2Vec+LLM | GraphSAGE(LLM-trained) |
Expand Down
Loading