Commit c17a37e
add embedding layer; fix lm_log_odds computation; add prefix space
AlexTMallen committed Nov 13, 2023
1 parent 26cf3b3 commit c17a37e
Showing 2 changed files with 23 additions and 15 deletions.
4 changes: 4 additions & 0 deletions ccs/__init__.py
@@ -1,11 +1,15 @@
 from .extraction import Extract, extract_hiddens
 from .training import EigenFitter, EigenFitterConfig
 from .training.train import Elicit
+from .evaluation import Eval
+from .truncated_eigh import truncated_eigh
 
 __all__ = [
     "EigenFitter",
     "EigenFitterConfig",
     "extract_hiddens",
     "Extract",
     "Elicit",
+    "Eval",
+    "truncated_eigh",
 ]
34 changes: 19 additions & 15 deletions ccs/extraction/extraction.py
@@ -137,8 +137,9 @@ def __post_init__(self, layer_stride: int):
             config = assert_type(
                 PretrainedConfig, AutoConfig.from_pretrained(self.model)
             )
-            layer_range = range(1, config.num_hidden_layers, layer_stride)
-            self.layers = tuple(layer_range)
+            # Note that we always include 0, which is the embedding layer
+            layer_range = range(1, config.num_hidden_layers + 1, layer_stride)
+            self.layers = (0,) + tuple(layer_range)
 
     def explode(self) -> list["Extract"]:
         """Explode this config into a list of configs, one for each layer."""
@@ -198,7 +199,7 @@ def extract_hiddens(
         seed=cfg.seed,
     )
 
-    layer_indices = cfg.layers or tuple(range(1, model.config.num_hidden_layers))
+    layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers + 1))
 
     global_max_examples = cfg.max_examples[0 if split_type == "train" else 1]
 
@@ -263,14 +264,17 @@ def extract_hiddens(
             if is_enc_dec:
                 answer = labels = assert_type(Tensor, encoding.labels)
             else:
-                encoding2 = tokenizer(
-                    choice["answer"],
-                    # Don't include [CLS] and [SEP] in the answer
-                    add_special_tokens=False,
-                    return_tensors="pt",
-                ).to(device)
-
-                answer = assert_type(Tensor, encoding2.input_ids)
+                a_id = tokenizer.encode(" " + choice["answer"], add_special_tokens=False)
+
+                # the Llama tokenizer splits off leading spaces
+                if tokenizer.decode(a_id[0]).strip() == "":
+                    a_id_without_space = tokenizer.encode(
+                        choice["answer"], add_special_tokens=False
+                    )
+                    assert a_id_without_space == a_id[1:]
+                    a_id = a_id_without_space
+
+                answer = torch.tensor([a_id], device=device)
                 labels = (
                     # -100 is the mask token
                     torch.cat([torch.full_like(ids, -100), answer], dim=-1)
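
The prefix-space branch above guards against SentencePiece-style tokenizers (notably Llama's), which can split a leading space into its own token, while byte-level BPE tokenizers such as GPT-2's fold the space into the first word piece. A standalone sketch of the same check (tokenizer and strings are illustrative, not from the diff):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
a_id = tok.encode(" Paris", add_special_tokens=False)

# If the first token decodes to pure whitespace (the Llama case), drop it
# so that only the answer tokens themselves are kept.
if tok.decode(a_id[0]).strip() == "":
    a_id_without_space = tok.encode("Paris", add_special_tokens=False)
    assert a_id_without_space == a_id[1:]
    a_id = a_id_without_space

With GPT-2 the branch never fires, since " Paris" encodes to a single token; with a Llama-style tokenizer the lone space token would be stripped and the remaining ids used as the answer.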
@@ -293,13 +297,13 @@ def extract_hiddens(
 
             # Compute the log probability of the answer tokens if available
             if has_lm_preds:
-                logprob = -assert_type(Tensor, outputs.loss)
+                logprob = -assert_type(Tensor, outputs.loss).to(torch.float32)
                 # Convert logprob to logodds to be consistent with reporters
                 # Because we went through logprobs, logodds corresponding to
                 # probs near 1 will be somewhat imprecise
                 # log(p/(1-p)) = log(p) - log(1-p) = logp - log(1 - exp(logp))
                 lm_log_odds[i, j] = logprob - torch.log1p(-logprob.exp())
 
             hiddens = (
                 outputs.get("decoder_hidden_states") or outputs["hidden_states"]
             )
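
The conversion above relies on the identity log(p / (1 - p)) = log p - log(1 - exp(log p)); torch.log1p evaluates log(1 + x) accurately for small x, and the cast to float32 matters because a half-precision loss has too little mantissa left for accurate log odds, especially as p approaches 1 (the imprecision the comment warns about). A quick numeric check, with an illustrative value:

import torch

logprob = torch.tensor(-0.1054, dtype=torch.float32)  # log p for p ≈ 0.90
log_odds = logprob - torch.log1p(-logprob.exp())  # log(p / (1 - p)) ≈ 2.20

p = logprob.exp()
assert torch.isclose(log_odds, torch.log(p / (1 - p)), atol=1e-4)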
Expand Down Expand Up @@ -339,7 +343,7 @@ def extract_hiddens(
**hidden_dict,
)
if has_lm_preds:
out_record["lm_log_odds"] = lm_log_odds.log_softmax(dim=-1)
out_record["lm_log_odds"] = lm_log_odds

assert out_record["variant_ids"] == sorted(out_record["variant_ids"])
num_yielded += 1
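
Storing the tensor unmodified is the substance of the lm_log_odds fix: the values are already log odds, and the log_softmax previously applied here treated them as logits and renormalized across answer choices, changing what the column meant. Consumers can map them back to per-choice probabilities with a sigmoid; a sketch with made-up values:

import torch

lm_log_odds = torch.tensor([[-1.2, 0.7]])  # one row per variant, one column per choice
probs = torch.sigmoid(lm_log_odds)  # per-choice P(answer); not normalized across choices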
@@ -377,7 +381,7 @@ def hidden_features(cfg: Extract) -> tuple[DatasetInfo, Features]:
     if num_dropped:
         print(f"Dropping {num_dropped} non-multiple choice templates")
 
-    layer_indices = cfg.layers or tuple(range(1, model_cfg.num_hidden_layers))
+    layer_indices = cfg.layers or tuple(range(model_cfg.num_hidden_layers + 1))
     layer_cols = {
         f"hidden_{layer}": Array3D(
             dtype="int16",