Commit

Fix instructions and workflow

mkuchnik committed May 18, 2023
1 parent f727834 commit 7aec941

Showing 19 changed files with 84 additions and 119 deletions.
49 changes: 21 additions & 28 deletions README.md
@@ -38,6 +38,10 @@ After install, you will want to create the environment. To create it:

```bash
conda create -n py37 python=3.7
```

To activate the environment:
```bash
conda activate py37
```

@@ -46,23 +50,6 @@ We additionally use Rust as a backend for parts of the ReLM runtime.
Therefore, you will need to install a Rust compiler and build the corresponding
extensions.

###### Initial Setup
Install miniconda (as described above).
```bash
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh
```

Add a Python3.7 env, `py37`.
```bash
conda create python=3.7 -n py37
```

Activate the environment
```bash
conda activate py37
```

###### PyTorch
Install PyTorch (more instructions
[here](https://pytorch.org/get-started/locally/)).
@@ -95,7 +82,7 @@ apt install build-essential
Build and install the Rust bindings for ReLM.
```bash
pushd rust_regex_compiler
cargo build
cargo build --release
popd
pushd rust_regex_compiler_bindings
bash install_deps.sh
@@ -112,25 +99,31 @@ popd

###### Additional Dependencies
You can install additional dependencies from the provided requirements file.
We recommend installing from there to avoid unnecessary constraints on package
versions.
```bash
pip install -r requirements_minimal.txt
pip install -r requirements.txt
```

For reference, a full set of requirements for CPU-only use is provided in:
`requirements_full_cpu.txt`.
Note that this file contains paths to locally installed ReLM components and does
not contain CUDA variants of PyTorch.
Thus, it is recommended to use this file as a guide (and not an exact formula)
to resolve dependency issues.

## Getting Started
We recommend checking out the Jupyter Notebook
[Introduction_to_ReLM](notebook/Introduction_to_ReLM.ipynb) to get started.

To run it, you will need to install additional dependencies in the conda
environment.
```bash
conda install nb_conda
conda install -c conda-forge ipywidgets
```

Then you can do:
```bash
cd notebook
jupyter-notebook Introduction_to_ReLM.ipynb
```

## Experiments
Experiments in the paper can be found under the [Experiments](experiments)
directory.
Each experiment has its own README with instructions.
Namely, we use The Pile and LAMBADA datasets with the GPT2 and GPT2-XL models.
These experiments can take a while to run and can be made smaller by limiting
the number of samples generated.
2 changes: 2 additions & 0 deletions experiments/bias/README.md
@@ -15,6 +15,8 @@ run_bias_gpt2_inference.sh
run_bias_gpt2xl_inference.sh
```

To make this experiment run faster, consider changing `MAX_SAMPLES` in
the scripts.
Once the experiments have run, you can plot them using the plotter script.
The plotter script points to the directory containing the experiments and plots
all results in the current directory.
3 changes: 3 additions & 0 deletions experiments/language_understanding/README.md
@@ -24,6 +24,9 @@ bash run_lambada_gpt2.sh
bash run_lambada_gpt2xl.sh
```

To make this experiment run faster, consider changing `MAX_RESULTS` in
the scripts.

To print the accuracy results, run:

```bash
10 changes: 8 additions & 2 deletions experiments/language_understanding/run_eval.py
@@ -11,12 +11,11 @@
import re
import itertools
import tqdm
import pathlib

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from nltk.corpus import stopwords

import relm

SUPPORTED_MODELS = ["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"]
@@ -104,6 +103,8 @@ def sanitize_query_str_rust(x: str) -> str:
def read_data():
    """Read and process lambada."""
    f = "../../lambada_test.jsonl"
    if not pathlib.Path(f).exists():
        raise RuntimeError("Lambada dataset '{}' does not exist.".format(f))
    df = pd.read_json(f, lines=True)
    df["processed"] = df["text"].map(preprocess)
    text = df['processed']
@@ -137,6 +138,11 @@ def get_parser():
@functools.lru_cache(1)
def get_stop_words():
    """Return a set of stop words."""
    # Download stopwords data, just in case it's not downloaded
    import nltk
    nltk.download("stopwords")

    from nltk.corpus import stopwords
    stop_words = set(stopwords.words('english'))
    return stop_words

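For context on the change above: `nltk.download("stopwords")` runs once per process here thanks to `lru_cache`, but it still prints download status on each fresh process. A minimal sketch of a quieter variant that only fetches the corpus when it is actually missing; this is an illustrative alternative, not the repository's code:

```python
import functools

import nltk


@functools.lru_cache(1)
def get_stop_words():
    """Return a set of stop words, downloading the corpus only if missing."""
    try:
        from nltk.corpus import stopwords
        return set(stopwords.words("english"))
    except LookupError:
        # Corpus not present; fetch it quietly once, then retry the import.
        nltk.download("stopwords", quiet=True)
        from nltk.corpus import stopwords
        return set(stopwords.words("english"))
```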
11 changes: 3 additions & 8 deletions experiments/memorization/README.md
@@ -12,6 +12,8 @@ bash run_url_gpt2xl.sh

This will generate a results directory, which has a `results.json` file
describing the extraction attempts for each of the experiments.
To make this experiment run faster, consider changing `MAX_SAMPLES` in
`run_url_gpt2xl.sh`.

Now, before moving to the next step, install the required dependencies.

@@ -23,14 +25,7 @@ Then, we post-process these files to determine if they are valid URLs with:

```bash
pip3 install -r requirements.txt
python3 plot_results.py test_memorization_gpt2-xl/relm
python3 plot_results.py test_memorization_gpt2-xl/baseline_1
python3 plot_results.py test_memorization_gpt2-xl/baseline_2
python3 plot_results.py test_memorization_gpt2-xl/baseline_4
python3 plot_results.py test_memorization_gpt2-xl/baseline_8
python3 plot_results.py test_memorization_gpt2-xl/baseline_16
python3 plot_results.py test_memorization_gpt2-xl/baseline_32
python3 plot_results.py test_memorization_gpt2-xl/baseline_64
for f in test_memorization_gpt2-xl/*; do python3 plot_results.py "$f"; done
```

NOTE: Please be mindful of how many URLs you are validating and how quickly you
2 changes: 1 addition & 1 deletion experiments/memorization/mega_plot_results.py
@@ -38,7 +38,7 @@
df1["with_duplicates"] = True
dfs.append(df1)

for n in [1, 2, 4, 8, 16, 32, 64]:
for n in [4, 8, 16, 32, 64]:
    f = \
        "{}/baseline_{}/results_joined_gpt2-xl.csv".format(top_level_dir, n)
    df = pd.read_csv(f)
4 changes: 3 additions & 1 deletion experiments/memorization/plot_results.py
@@ -1,7 +1,7 @@
"""Check if URLs are valid."""

import argparse
import concurrent
import concurrent.futures
import json
import socket

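For context: the import fix above is needed because `import concurrent` alone does not load the `futures` submodule, so `concurrent.futures.ThreadPoolExecutor` raises `AttributeError` until the submodule is imported explicitly. A small self-contained illustration:

```python
import concurrent.futures

# With a bare `import concurrent`, the attribute lookup below would fail;
# importing the submodule explicitly makes it available.
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
    lengths = list(pool.map(len, ["alpha", "beta"]))
print(lengths)  # [5, 4]
```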
@@ -11,6 +11,7 @@
import seaborn as sns
import tqdm
import validators
import re

model = "gpt2-xl"

@@ -40,6 +41,7 @@
for u in urls:
    prefix = "None"
    clean_url = u
    clean_url = re.sub(r"<\|endoftext\|>$", "", clean_url.split("\n")[0])
    prefixes.append(prefix)
    clean_urls.append(clean_url)
df["prefixes"] = prefixes
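A note on the pattern in the line above: the pipe characters in `<|endoftext|>` are regex alternation operators, so they must be escaped (as shown); left unescaped, the pattern would instead strip every `<`, every `endoftext`, and a trailing `>`. A self-contained sketch using `re.escape`, with a made-up URL:

```python
import re

# GPT-2's end-of-sequence marker contains regex metacharacters ("|"),
# so escape it before matching it literally at the end of the string.
EOS = "<|endoftext|>"
pattern = re.escape(EOS) + "$"

url = "https://example.com/page<|endoftext|>"
clean = re.sub(pattern, "", url.split("\n")[0])
print(clean)  # https://example.com/page
```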
4 changes: 0 additions & 4 deletions experiments/memorization/run_url_gpt2xl.sh
@@ -36,10 +36,6 @@ baseline_test() {
}

relm_test
LENGTH=1 # Power of 2
baseline_test
LENGTH=2 # Power of 2
baseline_test
LENGTH=4 # Power of 2
baseline_test
LENGTH=8 # Power of 2
Expand Down
3 changes: 3 additions & 0 deletions experiments/toxicity/README.md
@@ -25,6 +25,9 @@ cancel the run after sufficient extractions have been attempted.
bash run_insults_gpt2xl.sh
```

To make this experiment run faster, consider changing `MAX_SAMPLES` in
the scripts as well as passing `--max_results`.

Once the results are generated, you can plot them with:
```bash
python3 plot_results.py
2 changes: 1 addition & 1 deletion experiments/toxicity/find_insults.py
@@ -136,7 +136,7 @@ def prompt_sampler_sentence(base_prompts, query_str, prompt_length,
    re_matcher = re.compile(query_str)
    prompt_sentences = map(lambda prompt: nltk.tokenize.sent_tokenize(prompt),
                           base_prompts)
    prompt_sentences = {xx for x in prompt_sentences for xx in x}
    prompt_sentences = dict.fromkeys(xx for x in prompt_sentences for xx in x)
    prompt_sentences = list(prompt_sentences)
    prompt_sentences = sorted(prompt_sentences, key=lambda x: len(x))

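For context on the change above: replacing the set comprehension with `dict.fromkeys` keeps the deduplication but preserves first-seen order (dicts are insertion-ordered since Python 3.7), so the sentence list is reproducible before sorting. A toy illustration with made-up data:

```python
sentences = ["b", "a", "b", "c", "a"]

# Insertion-ordered deduplication: the same result on every run.
print(list(dict.fromkeys(sentences)))  # ['b', 'a', 'c']

# Set-based deduplication: same elements, but iteration order depends on
# string hashing, which is randomized per process.
print(list(set(sentences)))  # some ordering of ['a', 'b', 'c']
```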
6 changes: 4 additions & 2 deletions experiments/toxicity/run_insults_gpt2xl.sh
@@ -19,7 +19,8 @@ standard_test() {
--num_punctuation_edits=0 \
--num_space_edits=0 \
--batch_size=1 \
--max_samples="${MAX_SAMPLES}"
--max_samples="${MAX_SAMPLES}" \
--max_results=250
popd || exit
}

@@ -35,7 +36,8 @@ baseline_test() {
--top_k=40 \
--batch_size=1 \
--static_minimize \
--max_samples="${MAX_SAMPLES}"
--max_samples="${MAX_SAMPLES}" \
--max_results=500
popd || exit
}

8 changes: 4 additions & 4 deletions notebook/Introduction_to_ReLM.ipynb
@@ -84,7 +84,7 @@
"metadata": {},
"source": [
"Let's now pretend we are testing the model.\n",
"We want to answer the question: \"Does the model know the birthdate of George Washington?\".\n",
"We want to answer the question: \"Does the model know the birth date of George Washington?\".\n",
"One way to do this test is to sample from the model given the prefix, and see what happens.\n",
"You can see that we can't glean much from this information---some of the responses aren't even dates, and they have all this extra stuff at the end.\n",
"We can see some patterns in the data (e.g., date ranges and sentence structure) which may allow us to be more or less confident that the model knows _something_.\n",
@@ -156,7 +156,7 @@
"source": [
"To demonstrate ReLM, let's do a simple knowledge verification query of the model, building off the prior example.\n",
"Specifically, let's test for George Washington's birthday.\n",
"One way to do this is to ask \"George Washington was born on <u>date</u>\".\n",
"One way to do this is to ask \"George Washington was born on <b><u>\\<date\\></u></b>\".\n",
"Let's construct the query in ReLM."
]
},
@@ -176,7 +176,7 @@
"Below, we construct a regular expression query using ReLM's abstractions.\n",
"To start, we have to define what strings we are looking for, using a `relm.QueryString`.\n",
"The first parameter, `query_str`, defines exactly what strings we are looking for.\n",
"In regular expression syntax, we are saying \"George Washington was born on <u>date</u>\", where \"<u>date</u>\" is a month, day, and year.\n",
"In regular expression syntax, we are saying \"George Washington was born on <b><u>\\<date\\></u></b>\", where \"<b><u>\\<date\\></u></b>\" is a month, day, and year.\n",
"Notice we also fill in the `prefix_str` parameter.\n",
"`prefix_str` is a substring of `query_str`, and it indicates that the user wants that prefix to be treated as if it were fixed.\n",
"If you are familiar with [prompting](https://en.wikipedia.org/wiki/Prompt_engineering), the `prefix_str` parameter is trying to capture the prompt.\n",
@@ -441,7 +441,7 @@
"metadata": {},
"source": [
"Hmm, the top answer is not what we wanted 🤔.\n",
"The expected [birthdate](https://en.wikipedia.org/wiki/George_Washington) is February 22, 1732.\n",
"The expected [birth date](https://en.wikipedia.org/wiki/George_Washington) is February 22, 1732.\n",
"However, we can see that the correct answer is 6th in this list, so the model is reasonably aware of the correct answer.\n",
"Unlike the opening example, we know for sure that we get a date in exactly the format we wanted.\n",
"\n",
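For readers skimming the notebook diff: the query described in the cells above combines `query_str` and `prefix_str`. A rough sketch of what such a call might look like; the date regex and everything beyond those two documented parameters are illustrative assumptions, not code copied from the notebook:

```python
import relm

# Hypothetical ReLM query: the prompt is fixed via prefix_str, and the
# rest of query_str constrains the completion to a month-day-year date.
query_string = relm.QueryString(
    query_str=(r"George Washington was born on "
               r"(January|February|March|April|May|June|July|August"
               r"|September|October|November|December) "
               r"([1-9]|[12][0-9]|3[01]), 1[0-9]{3}"),
    prefix_str="George Washington was born on ",
)
```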
3 changes: 2 additions & 1 deletion relm/src/relm/__init__.py
@@ -2,4 +2,5 @@
search)
from .relm_logging import get_logger as get_relm_logger
from .relm_search_query import (QuerySearchStrategy, QueryString,
QueryTokenizationStrategy, SimpleSearchQuery)
QueryTokenizationStrategy, SimpleSearchQuery,
QueryPreprocessors,)
16 changes: 0 additions & 16 deletions relm/src/relm/automata.py
@@ -757,22 +757,6 @@ def _num_strings_counter_naive_fn(max_length, s, M, f) -> int:
    return int(s.T.dot(M**max_length).dot(f).toarray()[0][0])


def _num_strings_counter_JNF_fn(max_length, s, M, f) -> int:
    """Count automata strings with combinatorics.
    Takes starting state s, transition matrix M, and final states, f, and
    returns the number of max_length strings.
    This implementation uses Jordan normal form, which reduces exponentiation
    over matrices to exponentiation of a list of numbers.
    Algorithm from:
    https://math.stackexchange.com/questions/2245373/find-the-number-of-accepted-matches-for-a-dfa-defined-by-a-regular-expression
    """
    raise NotImplementedError("Implement JNF")


def string_cardinality_generator_for_automata(automata, dtype=None):
    """Count the number of strings of length m in an automata.
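For context on the counting code kept above: `_num_strings_counter_naive_fn` computes the product s^T · M^max_length · f, where M is the automaton's transition-count matrix. A dense, self-contained sketch of the same arithmetic on a made-up two-state DFA (plain numpy for clarity; the repository code works on sparse matrices, hence the `.toarray()` call):

```python
import numpy as np

# Toy DFA over {a, b}: state 0 is the start, state 1 is accepting.
# M[i][j] counts the symbols that move state i to state j.
M = np.array([[1, 1],   # from 0: 'a' loops, 'b' goes to 1
              [0, 2]])  # from 1: both symbols stay in 1
s = np.array([1, 0])    # start vector: begin in state 0
f = np.array([0, 1])    # final vector: accept in state 1

n = 3
count = int(s @ np.linalg.matrix_power(M, n) @ f)
print(count)  # 7: every length-3 string except "aaa" ends in state 1
```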
57 changes: 25 additions & 32 deletions relm/src/relm/regex_graph_optimizations.py
@@ -1,5 +1,4 @@
"""Complicated (compiler-like) graph optimizations for regex."""
import collections
import itertools
import pprint
from typing import Optional
@@ -181,6 +180,11 @@ def _simplify_automata_symbols_DFS(
        except KeyError:
            d[s] = [a]

    # Indexed by ilabel
    _indexed_arc_cache = {s: {a.ilabel: a.nextstate for a in arcs}
                          for s, arcs
                          in _arcs_cache.items()}

    def simplify_automata_iter(arcs_cache, first_symbol_arcs_cache, substr):
        """Yield shortcut edges (from, to) for a word.
@@ -204,37 +208,26 @@ def simplify_automata_iter(arcs_cache, first_symbol_arcs_cache, substr):
        # match on many "a" "aa" "aaa" etc. in a chain.
        # In that case, we would match on all edges, and the longest edge would
        # traverse all |w| length of those edges.
        for s, arcs in d.items():
            for a in arcs:
                # Start BFS
                Q = [((a.nextstate,), list(substr[1:]))]
                Q = collections.deque(Q)
                visited = set()
                while Q:
                    past_states, to_match = Q.popleft()
                    curr = past_states[-1]
                    if not len(to_match):
                        yield s, curr, past_states
                    else:
                        try:
                            nth_symbol = \
                                inverted_input_symbols[to_match[0]]
                        except KeyError:
                            continue
                        try:
                            # lookup arcs from this state on symbol
                            next_arcs = \
                                first_symbol_arcs_cache[nth_symbol][
                                    past_states[-1]]
                        except KeyError:
                            continue
                        for next_a in next_arcs:
                            proposed_states = (past_states +
                                               (next_a.nextstate,))
                            if proposed_states not in visited:
                                Q.append((proposed_states,
                                          to_match[1:]))
                                visited.add(proposed_states)
        for s in d.keys():
            curr_state = s
            past_states = [s]
            success = True
            for nth_char in substr:
                try:
                    nth_symbol = inverted_input_symbols[nth_char]
                except KeyError:
                    success = False
                    break
                next_state = _indexed_arc_cache[curr_state].get(
                    nth_symbol, None)
                if next_state is not None:
                    curr_state = next_state
                    past_states.append(next_state)
                else:
                    success = False
                    break
            if success:
                yield s, curr_state, past_states

    one = fst.Weight.one(automata.weight_type())
    new_automata = automata.copy()
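For context on the rewrite above: the old BFS explored every matching arc per symbol, while the new loop does one linear walk per start state, which suffices when at most one arc leaves a state per symbol (the deterministic case that `_indexed_arc_cache` encodes). A toy sketch of the same walk over a hand-written transition table (not the pywrapfst API):

```python
# state -> {symbol: next_state}, analogous to _indexed_arc_cache.
arc_index = {
    0: {"c": 1},
    1: {"a": 2},
    2: {"t": 3},
}


def walk(start, word):
    """Follow word symbol-by-symbol; return the visited states or None."""
    state = start
    path = [state]
    for symbol in word:
        nxt = arc_index.get(state, {}).get(symbol)
        if nxt is None:
            return None  # dead end: no outgoing arc for this symbol
        state = nxt
        path.append(state)
    return path


print(walk(0, "cat"))  # [0, 1, 2, 3] -> candidate shortcut edge 0 -> 3
print(walk(0, "car"))  # None: no arc for 'r' out of state 2
```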
