Skip to content

Commit

Permalink
counterfact
Browse files Browse the repository at this point in the history
  • Loading branch information
derpyplops committed Jul 25, 2023
1 parent d5aaa60 commit 481850a
Showing 1 changed file with 41 additions and 1 deletion.
42 changes: 41 additions & 1 deletion elk/extraction/neel.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,46 @@ def invert_example(example):
)


def get_counterfact():
# Load the Counterfact-Tracing dataset
dataset = load_dataset("NeelNanda/counterfact-tracing")

# Get the first row
first_row = dataset["train"][0]

# Get the values for "prompt" and "answer"
first_row["prompt"]
first_row["target_true"]
first_row["target_false"]

# for row in dataset["train"]:
# example = row_to_example(row)
# print(example)

def row_to_example(row):
label = 0
template_names = ["template_null"]
prompts = [
Prompt(
choices=[
Choice(question=row["prompt"], answer=row["target_true"]),
Choice(question=row["prompt"], answer=row["target_false"]),
]
)
]

example = Example(label=label, prompts=prompts, template_names=template_names)
return example

examples = [row_to_example(row) for row in dataset["train"]]
inverted_list = [
invert_example(example) if index % 2 == 0 else example
for index, example in enumerate(examples)
]

return inverted_list


def get_dumb_nots():
# Load the Counterfact-Tracing dataset
dataset = load_dataset("NeelNanda/counterfact-tracing")
Expand Down Expand Up @@ -308,6 +348,6 @@ def get_and_save_neel_inverted_by_lm():

if __name__ == "__main__":
# get_and_save_neel_inverted_by_lm()
xs = get_dumb_nots()
xs = get_counterfact()
print(xs)
# upload_to_huggingface('filterdots.tsv')

0 comments on commit 481850a

Please sign in to comment.