From b390afe7126587a2b759eb820b08322cba94e841 Mon Sep 17 00:00:00 2001 From: Dest1n1 Date: Mon, 13 Jan 2025 00:25:02 +0800 Subject: [PATCH] fix(server): database interaction --- server/app.py | 6 ++- src/lm_saes/database.py | 84 +++++++++++++++++++++++++++++++---------- ui/src/types/feature.ts | 6 +-- 3 files changed, 71 insertions(+), 25 deletions(-) diff --git a/server/app.py b/server/app.py index cf18fe09..a8d6a1b9 100644 --- a/server/app.py +++ b/server/app.py @@ -116,7 +116,6 @@ def get_image(dataset_name: str, context_idx: int, image_idx: int): @app.get("/dictionaries/{name}/features/{feature_index}") def get_feature(name: str, feature_index: str | int): - model = get_model(name) if isinstance(feature_index, str) and feature_index != "random": try: feature_index = int(feature_index) @@ -143,6 +142,8 @@ def get_feature(name: str, feature_index: str | int): feature_acts = sampling.feature_acts[i] context_idx = sampling.context_idx[i] dataset_name = sampling.dataset_name[i] + model_name = sampling.model_name[i] + model = get_model(model_name) data = get_dataset(dataset_name)[context_idx] _, token_origins = model.to_tokens_with_origins(data) @@ -173,8 +174,9 @@ def get_feature(name: str, feature_index: str | int): make_serializable( { "feature_index": feature.index, + "dictionary_name": feature.sae_name, "act_times": feature.analyses[0].act_times, - "max_feature_act": feature.analyses[0].max_feature_act, + "max_feature_act": feature.analyses[0].max_feature_acts, "sample_groups": sample_groups, } ) diff --git a/src/lm_saes/database.py b/src/lm_saes/database.py index 773e80c4..4f674285 100644 --- a/src/lm_saes/database.py +++ b/src/lm_saes/database.py @@ -26,17 +26,18 @@ class DatasetRecord(BaseModel): class FeatureAnalysisSampling(BaseModel): name: str - feature_acts: list[float] + feature_acts: list[list[float]] dataset_name: list[str] shard_idx: Optional[list[int]] = None n_shards: Optional[list[int]] = None context_idx: list[int] + model_name: list[str] class 
FeatureAnalysis(BaseModel): name: str act_times: int - max_feature_act: int + max_feature_acts: float samplings: list[FeatureAnalysisSampling] @@ -189,33 +190,76 @@ def get_sae(self, sae_name: str, sae_series: str) -> Optional[SAERecord]: return None return SAERecord.model_validate(sae) - def get_random_alive_feature(self, sae_name: str, sae_series: str): - feature = self.feature_collection.aggregate( - [ - {"$match": {"sae_name": sae_name, "sae_series": sae_series, "max_feature_acts": {"$gt": 0}}}, - {"$sample": {"size": 1}}, - ] - ).next() + def get_random_alive_feature( + self, sae_name: str, sae_series: str, name: str = "default" + ) -> Optional[FeatureRecord]: + """Get a random feature that has non-zero activation. + + Args: + sae_name: Name of the SAE model + sae_series: Series of the SAE model + name: Name of the analysis + + Returns: + A random feature record with non-zero activation, or None if no such feature exists + """ + pipeline = [ + { + "$match": { + "sae_name": sae_name, + "sae_series": sae_series, + "analyses": {"$elemMatch": {"name": name, "max_feature_acts": {"$gt": 0}}}, + } + }, + {"$sample": {"size": 1}}, + ] + feature = next(self.feature_collection.aggregate(pipeline), None) if feature is None: return None - return FeatureRecord.model_validate(**feature) + return FeatureRecord.model_validate(feature) - def get_alive_feature_count(self, sae_name: str, sae_series: str): - return self.feature_collection.count_documents( - {"sae_name": sae_name, "sae_series": sae_series, "max_feature_acts": {"$gt": 0}} - ) + def get_alive_feature_count(self, sae_name: str, sae_series: str, name: str = "default"): + pipeline = [ + {"$unwind": "$analyses"}, + { + "$match": { + "sae_name": sae_name, + "sae_series": sae_series, + "analyses.name": name, + "analyses.max_feature_acts": {"$gt": 0}, + } + }, + {"$count": "count"}, + ] + return self.feature_collection.aggregate(pipeline).next()["count"] - def get_max_feature_acts(self, sae_name: str, sae_series: str) -> 
dict[int, int] | None: + def get_max_feature_acts(self, sae_name: str, sae_series: str, name: str = "default") -> dict[int, int] | None: pipeline = [ - {"$match": {"sae_name": sae_name, "sae_series": sae_series, "max_feature_acts": {"$gt": 0}}}, - {"$project": {"_id": 0, "index": 1, "max_feature_acts": 1}}, + {"$unwind": "$analyses"}, + { + "$match": { + "sae_name": sae_name, + "sae_series": sae_series, + "analyses.name": name, + "analyses.max_feature_acts": {"$gt": 0}, + } + }, + {"$project": {"_id": 0, "index": 1, "max_feature_acts": "$analyses.max_feature_acts"}}, ] return {f["index"]: f["max_feature_acts"] for f in self.feature_collection.aggregate(pipeline)} - def get_feature_act_times(self, sae_name: str, sae_series: str): + def get_feature_act_times(self, sae_name: str, sae_series: str, name: str = "default"): pipeline = [ - {"$match": {"sae_name": sae_name, "sae_series": sae_series, "max_feature_acts": {"$gt": 0}}}, - {"$project": {"_id": 0, "index": 1, "act_times": 1}}, + {"$unwind": "$analyses"}, + { + "$match": { + "sae_name": sae_name, + "sae_series": sae_series, + "analyses.name": name, + "analyses.act_times": {"$gt": 0}, + } + }, + {"$project": {"_id": 0, "index": 1, "act_times": "$analyses.act_times"}}, ] return {f["index"]: f["act_times"] for f in self.feature_collection.aggregate(pipeline)} diff --git a/ui/src/types/feature.ts b/ui/src/types/feature.ts index a3bd7203..9fcbf5cd 100644 --- a/ui/src/types/feature.ts +++ b/ui/src/types/feature.ts @@ -55,7 +55,7 @@ export type Interpretation = z.infer; export const FeatureSchema = z.object({ featureIndex: z.number(), dictionaryName: z.string(), - featureActivationHistogram: z.any().nullable(), + featureActivationHistogram: z.any().nullish(), actTimes: z.number(), maxFeatureAct: z.number(), sampleGroups: z.array( @@ -80,8 +80,8 @@ export const FeatureSchema = z.object({ ), histogram: z.any(), }) - .nullable(), - interpretation: InterpretationSchema.nullable(), + .nullish(), + interpretation: 
InterpretationSchema.nullish(), }); export type Feature = z.infer<typeof FeatureSchema>;