Skip to content

Commit

Permalink
fix(server): database interaction
Browse files Browse the repository at this point in the history
  • Loading branch information
dest1n1s committed Jan 12, 2025
1 parent 210951d commit b390afe
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 25 deletions.
6 changes: 4 additions & 2 deletions server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def get_image(dataset_name: str, context_idx: int, image_idx: int):

@app.get("/dictionaries/{name}/features/{feature_index}")
def get_feature(name: str, feature_index: str | int):
model = get_model(name)
if isinstance(feature_index, str) and feature_index != "random":
try:
feature_index = int(feature_index)
Expand All @@ -143,6 +142,8 @@ def get_feature(name: str, feature_index: str | int):
feature_acts = sampling.feature_acts[i]
context_idx = sampling.context_idx[i]
dataset_name = sampling.dataset_name[i]
model_name = sampling.model_name[i]
model = get_model(model_name)
data = get_dataset(dataset_name)[context_idx]
_, token_origins = model.to_tokens_with_origins(data)

Expand Down Expand Up @@ -173,8 +174,9 @@ def get_feature(name: str, feature_index: str | int):
make_serializable(
{
"feature_index": feature.index,
"dictionary_name": feature.sae_name,
"act_times": feature.analyses[0].act_times,
"max_feature_act": feature.analyses[0].max_feature_act,
"max_feature_act": feature.analyses[0].max_feature_acts,
"sample_groups": sample_groups,
}
)
Expand Down
84 changes: 64 additions & 20 deletions src/lm_saes/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,18 @@ class DatasetRecord(BaseModel):

class FeatureAnalysisSampling(BaseModel):
    """One named group of sampled contexts recorded for a feature.

    The list-valued fields are read in parallel by the server's feature
    endpoint: entry ``i`` of ``feature_acts``, ``context_idx``,
    ``dataset_name`` and ``model_name`` together describe sample ``i``.
    """

    name: str
    # One list of activation values per sample (hence the nesting).
    feature_acts: list[list[float]]
    # Dataset each sample was drawn from; used with context_idx to fetch the sample's data.
    dataset_name: list[str]
    # Optional sharding information; presumably one entry per sample -- TODO confirm against the writer.
    shard_idx: Optional[list[int]] = None
    n_shards: Optional[list[int]] = None
    # Index of each sample inside its dataset.
    context_idx: list[int]
    # Model used for each sample; the server loads it to tokenize that sample.
    # NOTE(review): field names starting with "model_" may trigger Pydantic v2's
    # protected-namespace warning -- confirm whether this is silenced elsewhere.
    model_name: list[str]


class FeatureAnalysis(BaseModel):
    """A named analysis result attached to a feature record.

    Database queries select an analysis by ``name`` and treat a feature as
    alive when ``max_feature_acts`` is greater than zero.
    """

    # Analysis identifier; the query helpers default to "default".
    name: str
    # Activation count reported to clients as "act_times".
    act_times: int
    # Largest activation observed; a float, and the only spelling the queries use.
    max_feature_acts: float
    # Sample groups collected for this analysis.
    samplings: list[FeatureAnalysisSampling]


Expand Down Expand Up @@ -189,33 +190,76 @@ def get_sae(self, sae_name: str, sae_series: str) -> Optional[SAERecord]:
return None
return SAERecord.model_validate(sae)

def get_random_alive_feature(
    self, sae_name: str, sae_series: str, name: str = "default"
) -> Optional[FeatureRecord]:
    """Get a random feature that has non-zero activation.

    Args:
        sae_name: Name of the SAE model
        sae_series: Series of the SAE model
        name: Name of the analysis

    Returns:
        A random feature record with non-zero activation, or None if no such feature exists
    """
    pipeline = [
        {
            "$match": {
                "sae_name": sae_name,
                "sae_series": sae_series,
                # $elemMatch makes a single analysis entry satisfy both conditions,
                # rather than one entry matching the name and another the threshold.
                "analyses": {"$elemMatch": {"name": name, "max_feature_acts": {"$gt": 0}}},
            }
        },
        {"$sample": {"size": 1}},
    ]
    # A default of None avoids StopIteration when the pipeline yields nothing.
    feature = next(self.feature_collection.aggregate(pipeline), None)
    if feature is None:
        return None
    # model_validate takes the raw document as one positional argument, not as keywords.
    return FeatureRecord.model_validate(feature)

def get_alive_feature_count(self, sae_name: str, sae_series: str, name: str = "default") -> int:
    """Count the analyses named ``name`` with a positive maximum activation.

    Args:
        sae_name: Name of the SAE model
        sae_series: Series of the SAE model
        name: Name of the analysis

    Returns:
        Number of matching (feature, analysis) pairs; 0 when nothing matches.
    """
    pipeline = [
        # One document per analysis entry, so the name and threshold apply to the same entry.
        {"$unwind": "$analyses"},
        {
            "$match": {
                "sae_name": sae_name,
                "sae_series": sae_series,
                "analyses.name": name,
                "analyses.max_feature_acts": {"$gt": 0},
            }
        },
        {"$count": "count"},
    ]
    # $count emits no document at all when its input is empty, so calling
    # .next() on the cursor would raise StopIteration instead of reporting zero.
    result = next(self.feature_collection.aggregate(pipeline), None)
    return 0 if result is None else result["count"]

def get_max_feature_acts(self, sae_name: str, sae_series: str, name: str = "default") -> dict[int, float]:
    """Map feature index to maximum activation for every alive feature.

    Args:
        sae_name: Name of the SAE model
        sae_series: Series of the SAE model
        name: Name of the analysis

    Returns:
        A dict from feature index to the ``max_feature_acts`` of the selected
        analysis. Only entries with a positive value are included; the dict is
        empty when nothing matches.
    """
    pipeline = [
        # One document per analysis entry, so the name and threshold apply to the same entry.
        {"$unwind": "$analyses"},
        {
            "$match": {
                "sae_name": sae_name,
                "sae_series": sae_series,
                "analyses.name": name,
                "analyses.max_feature_acts": {"$gt": 0},
            }
        },
        # Lift the nested value to a top-level key for the comprehension below.
        {"$project": {"_id": 0, "index": 1, "max_feature_acts": "$analyses.max_feature_acts"}},
    ]
    return {f["index"]: f["max_feature_acts"] for f in self.feature_collection.aggregate(pipeline)}

def get_feature_act_times(self, sae_name: str, sae_series: str, name: str = "default") -> dict[int, int]:
    """Map feature index to activation count for every feature that fired.

    Args:
        sae_name: Name of the SAE model
        sae_series: Series of the SAE model
        name: Name of the analysis

    Returns:
        A dict from feature index to the ``act_times`` of the selected
        analysis. Only entries with a positive count are included; the dict is
        empty when nothing matches.
    """
    pipeline = [
        # One document per analysis entry, so the name and threshold apply to the same entry.
        {"$unwind": "$analyses"},
        {
            "$match": {
                "sae_name": sae_name,
                "sae_series": sae_series,
                "analyses.name": name,
                "analyses.act_times": {"$gt": 0},
            }
        },
        # Lift the nested value to a top-level key for the comprehension below.
        {"$project": {"_id": 0, "index": 1, "act_times": "$analyses.act_times"}},
    ]
    return {f["index"]: f["act_times"] for f in self.feature_collection.aggregate(pipeline)}

Expand Down
6 changes: 3 additions & 3 deletions ui/src/types/feature.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ export type Interpretation = z.infer<typeof InterpretationSchema>;
export const FeatureSchema = z.object({
featureIndex: z.number(),
dictionaryName: z.string(),
featureActivationHistogram: z.any().nullable(),
featureActivationHistogram: z.any().nullish(),
actTimes: z.number(),
maxFeatureAct: z.number(),
sampleGroups: z.array(
Expand All @@ -80,8 +80,8 @@ export const FeatureSchema = z.object({
),
histogram: z.any(),
})
.nullable(),
interpretation: InterpretationSchema.nullable(),
.nullish(),
interpretation: InterpretationSchema.nullish(),
});

export type Feature = z.infer<typeof FeatureSchema>;

0 comments on commit b390afe

Please sign in to comment.