Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: completion spark job #210

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,16 @@ services:
start_period: 5s
retries: 2

# Helper service that submits a Spark job to the cluster described by the
# mounted kubeconfig. It is only started when the "spark-test" profile is enabled.
spark-test:
profiles: ["spark-test"]
build: ./spark-test
environment:
# Name of the job to submit.
JOB_NAME: "completion"
# Must match the container-side target of the volume mount below.
KUBECONFIG_FILE_PATH: "/opt/kubeconfig/kubeconfig.yaml"
# Image used for the submitted Spark application.
SPARK_IMAGE: "radicalbit-spark-py:develop"
volumes:
# Expose the k3s kubeconfig from the host at the path given in KUBECONFIG_FILE_PATH.
- ./docker/k3s_data/kubeconfig/kubeconfig.yaml:/opt/kubeconfig/kubeconfig.yaml

dind:
image: docker:dind
privileged: true
Expand Down
11 changes: 11 additions & 0 deletions spark-test/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Image for the spark-test helper: installs the Python requirements and runs main.py.
FROM python:3.11.8-slim

# Every following relative path resolves against /spark-test.
WORKDIR /spark-test

# Copy the requirements file on its own before the rest of the sources
# (presumably so the dependency layer is cached independently of code changes).
COPY requirements.txt requirements.txt

# Install the dependencies without keeping pip's download cache in the image.
RUN pip install --no-cache-dir -r requirements.txt

# Copy the remaining build context into the working directory.
COPY . .

# Default command: run the submission script.
CMD ["python3", "main.py"]
14 changes: 14 additions & 0 deletions spark-test/conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
def create_secrets():
    """Return a fresh dict of object-store and Postgres connection settings.

    Keys and values are fixed; the S3-related entries come first, followed
    by the Postgres entries.
    """
    object_store = {
        "AWS_ACCESS_KEY_ID": "minio",
        "AWS_SECRET_ACCESS_KEY": "minio123",
        "AWS_REGION": "us-east-1",
        "S3_ENDPOINT_URL": "http://minio:9000",
    }
    database = {
        "POSTGRES_URL": "jdbc:postgresql://postgres:5432/radicalbit",
        "POSTGRES_DB": "radicalbit",
        "POSTGRES_HOST": "postgres",
        "POSTGRES_PORT": "5432",
        "POSTGRES_USER": "postgres",
        "POSTGRES_PASSWORD": "postgres",
        "POSTGRES_SCHEMA": "public",
    }
    # Merging keeps the original key order: object store first, then database.
    return {**object_store, **database}
37 changes: 37 additions & 0 deletions spark-test/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
from conf import create_secrets
from uuid import uuid4
from spark_on_k8s.k8s.sync_client import KubernetesClientManager
from spark_on_k8s.client import SparkOnK8S

# Environment variables that must be set before a job can be submitted.
envs = ["KUBECONFIG_FILE_PATH", "JOB_NAME", "SPARK_IMAGE"]

# Fail fast, naming the first required variable that is absent.
missing = next((name for name in envs if name not in os.environ), None)
if missing is not None:
    raise EnvironmentError("Failed because {} is not set.".format(missing))

# Unpacked in the same order as `envs`.
kube_conf, job_name, spark_image = (os.environ[name] for name in envs)

k8s_client_manager = KubernetesClientManager(kube_conf)
spark_k8s_client = SparkOnK8S(k8s_client_manager=k8s_client_manager)

path = "s3a://test-bucket/metrics_one.json"

# Positional arguments for the job script: the input path, a freshly
# generated UUID, and two fixed names.
job_arguments = [
    path,
    str(uuid4()),
    "completion_dataset_metrics",
    "completion_dataset",
]

# NOTE(review): the app name is built from the image reference rather than
# from JOB_NAME — confirm this is intended.
spark_k8s_client.submit_app(
    image=spark_image,
    app_path=f"local:///opt/spark/custom_jobs/{job_name}_job.py",
    app_arguments=job_arguments,
    app_name=f"{spark_image}-completion-job",
    namespace="spark",
    service_account="spark",
    app_waiter="no_wait",
    image_pull_policy="IfNotPresent",
    secret_values=create_secrets(),
)
1 change: 1 addition & 0 deletions spark-test/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
spark-on-k8s==0.10.1
10 changes: 8 additions & 2 deletions spark/jobs/completion_job.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import sys
import os
import uuid

import orjson
from pyspark.sql.types import StructField, StructType, StringType

from metrics.completion_metrics import CompletionMetrics
from utils.models import JobStatus
from utils.db import update_job_status, write_to_db

Expand All @@ -13,7 +15,11 @@

def compute_metrics(df: DataFrame) -> dict:
    """Compute the completion metrics for ``df`` and wrap them in a record.

    Args:
        df: DataFrame handed to ``CompletionMetrics.extract_metrics``.

    Returns:
        A dict with the single key ``MODEL_QUALITY`` whose value is the
        UTF-8 decoded orjson dump of the extracted metrics model.
    """
    model_quality = CompletionMetrics().extract_metrics(df)
    payload = orjson.dumps(model_quality.model_dump(serialize_as_any=True))
    return {"MODEL_QUALITY": payload.decode("utf-8")}


Expand Down
85 changes: 85 additions & 0 deletions spark/jobs/metrics/completion_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import pyspark.sql.functions as F
import numpy as np
from pyspark.sql import DataFrame
from pyspark.sql.types import FloatType

from models.completion_dataset import CompletionMetricsModel


class CompletionMetrics:
    """Token probability and perplexity metrics for completion responses.

    The input DataFrame is expected to carry an ``id`` column and a
    ``choices`` array whose elements expose ``logprobs.content`` entries
    with ``logprob`` and ``token`` fields.
    """

    def __init__(self):
        """Create the service; it holds no state."""

    @staticmethod
    @F.udf(FloatType())
    def compute_probability_udf(log_prob: float) -> float:
        """Turn a log-probability into a probability via ``exp``."""
        probability = np.exp(log_prob)
        return float(probability)

    @staticmethod
    @F.udf(FloatType())
    def compute_perplexity(log_probs: list[float]) -> float:
        """Return ``exp`` of the negated mean of ``log_probs``."""
        mean_log_prob = np.mean(log_probs)
        return float(np.exp(-mean_log_prob))

    @staticmethod
    @F.udf(FloatType())
    def compute_prob_mean_per_phrase(probs: list[float]) -> float:
        """Return the arithmetic mean of ``probs``."""
        average = np.mean(probs)
        return float(average)

    @staticmethod
    def remove_columns(df: DataFrame) -> DataFrame:
        """Drop the response columns that the metrics do not use."""
        unused_columns = (
            "model",
            "object",
            "created",
            "system_fingerprint",
            "usage",
            "service_tier",
        )
        return df.drop(*unused_columns)

    def compute_prob(self, df: DataFrame):
        """Flatten ``df`` to one row per token and add its probability.

        Returns a DataFrame with the columns ``id``, ``logprob``, ``token``
        and ``prob``.
        """
        # One row per element of `choices`.
        per_choice = df.select(F.explode("choices").alias("element"), F.col("id"))
        # One row per entry of that element's `logprobs.content`.
        per_token = per_choice.select(
            F.col("id"), F.explode("element.logprobs.content").alias("content")
        )
        flattened = per_token.select("id", "content.logprob", "content.token")
        return flattened.withColumn("prob", self.compute_probability_udf("logprob"))

    def extract_metrics(self, df: DataFrame) -> CompletionMetricsModel:
        """Build the full metrics model for ``df``.

        The model contains the per-token probabilities grouped by ``id``,
        the mean probability and perplexity per ``id``, and the mean of
        those two per-``id`` values over the whole input.
        """
        token_df = self.compute_prob(self.remove_columns(df))

        # Per id: the list of (token, prob) pairs.
        probs_by_id = (
            token_df.drop("logprob")
            .groupBy("id")
            .agg(F.collect_list(F.struct("token", "prob")).alias("probs"))
        )

        # Per id: mean probability and perplexity.
        per_phrase = token_df.groupBy("id").agg(
            self.compute_prob_mean_per_phrase(F.collect_list("prob")).alias(
                "prob_per_phrase"
            ),
            self.compute_perplexity(F.collect_list("logprob")).alias(
                "perplex_per_phrase"
            ),
        )

        # Whole input: mean of the per-id values.
        per_file = per_phrase.agg(
            F.mean("prob_per_phrase").alias("prob_tot_mean"),
            F.mean("perplex_per_phrase").alias("perplex_tot_mean"),
        )

        tokens = []
        for row in probs_by_id.toLocalIterator():
            entries = [
                {"token": item["token"], "prob": item["prob"]}
                for item in row["probs"]
            ]
            tokens.append({"id": row["id"], "probs": entries})

        return CompletionMetricsModel(
            tokens=tokens,
            mean_per_phrase=per_phrase.toPandas().to_dict(orient="records"),
            mean_per_file=per_file.toPandas().to_dict(orient="records"),
        )
35 changes: 35 additions & 0 deletions spark/jobs/models/completion_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from pydantic import BaseModel, confloat, ConfigDict
from typing import List


# A single token together with its probability.
class Prob(BaseModel):
    token: str
    # Must lie in the closed interval [0, 1].
    prob: confloat(le=1, ge=0)


# All token probabilities that belong to one id.
class Probs(BaseModel):
    # Infinite and NaN floats are written as null when serialising to JSON.
    model_config = ConfigDict(ser_json_inf_nan="null")

    id: str
    probs: list[Prob]


# Mean probability and perplexity for one id.
class MeanPerPhrase(BaseModel):
    # Infinite and NaN floats are written as null when serialising to JSON.
    model_config = ConfigDict(ser_json_inf_nan="null")

    id: str
    # Must lie in the closed interval [0, 1].
    prob_per_phrase: confloat(le=1, ge=0)
    # Must be at least 1.
    perplex_per_phrase: confloat(ge=1)


# Mean probability and perplexity over the whole input.
class MeanPerFile(BaseModel):
    # Infinite and NaN floats are written as null when serialising to JSON.
    model_config = ConfigDict(ser_json_inf_nan="null")

    # Must lie in the closed interval [0, 1].
    prob_tot_mean: confloat(le=1, ge=0)
    # Must be at least 1.
    perplex_tot_mean: confloat(ge=1)


# Top-level container: per-token, per-id and whole-input metrics.
class CompletionMetricsModel(BaseModel):
    tokens: list[Probs]
    mean_per_phrase: list[MeanPerPhrase]
    mean_per_file: list[MeanPerFile]
45 changes: 45 additions & 0 deletions spark/tests/completion_metrics_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest
import orjson
from jobs.completion_job import compute_metrics
from jobs.metrics.completion_metrics import CompletionMetrics
from jobs.models.completion_dataset import CompletionMetricsModel
from tests.results.completion_metrics_results import completion_metric_results


@pytest.fixture
def input_file(spark_fixture, test_data_dir):
    """Yield the completion fixture file loaded as a Spark DataFrame."""
    source = f"{test_data_dir}/completion/metrics.json"
    # The reader is switched to multiline mode before the file is parsed.
    reader = spark_fixture.read.option("multiline", "true")
    yield reader.json(source)


def test_remove_columns(spark_fixture, input_file):
    """After dropping the unused columns only ``id`` and ``choices`` remain."""
    remaining = CompletionMetrics().remove_columns(input_file).columns
    assert "id" in remaining
    assert "choices" in remaining
    assert len(remaining) == 2


def test_compute_prob(spark_fixture, input_file):
    """The flattened frame has exactly the four expected columns and some rows."""
    service = CompletionMetrics()
    flattened = service.compute_prob(service.remove_columns(input_file))
    expected_columns = {"id", "logprob", "token", "prob"}
    assert expected_columns == set(flattened.columns)
    assert not flattened.rdd.isEmpty()


def test_extract_metrics(spark_fixture, input_file):
    """Every section of the extracted metrics model is non-empty."""
    metrics: CompletionMetricsModel = CompletionMetrics().extract_metrics(input_file)
    assert len(metrics.tokens) > 0
    assert len(metrics.mean_per_phrase) > 0
    assert len(metrics.mean_per_file) > 0


def test_compute_metrics(spark_fixture, input_file):
    """``MODEL_QUALITY`` equals the orjson dump of the stored expected results."""
    expected = orjson.dumps(completion_metric_results).decode("utf-8")
    actual = compute_metrics(input_file).get("MODEL_QUALITY")
    assert actual == expected
4 changes: 4 additions & 0 deletions spark/tests/resources/completion/metrics.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[
{"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN", "choices": [{"finish_reason": "stop", "index": 0, "logprobs": {"content": [{"token": "Sure", "bytes": [83, 117, 114, 101], "logprob": -0.61251247, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -0.102561064, "top_logprobs": []}, {"token": " go", "bytes": [32, 103, 111], "logprob": -2.5411978, "top_logprobs": []}, {"token": " ahead", "bytes": [32, 97, 104, 101, 97, 100], "logprob": -0.0014073749, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -2.0402107, "top_logprobs": []}, {"token": " What's", "bytes": [32, 87, 104, 97, 116, 39, 115], "logprob": -0.8943377, "top_logprobs": []}, {"token": " up", "bytes": [32, 117, 112], "logprob": -0.08216706, "top_logprobs": []}, {"token": "?", "bytes": [63], "logprob": -4.978234e-05, "top_logprobs": []}], "refusal": null}, "message": {"content": "Sure, go ahead. What's up?", "refusal": null, "role": "assistant", "tool_calls": [], "parsed": null}}], "created": 1733743961, "model": "gpt-4o-2024-08-06", "object": "chat.completion", "system_fingerprint": "afnfwuinawufwa", "usage": {"completion_tokens": 8, "prompt_tokens": 45, "total_tokens": 53, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}}},
{"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG", "choices": [{"finish_reason": "stop", "index": 0, "logprobs": {"content": [{"token": "Certainly", "bytes": [67, 101, 114, 116, 97, 105, 110, 108, 121], "logprob": -3.8160203, "top_logprobs": []}, {"token": "!", "bytes": [33], "logprob": -0.11697425, "top_logprobs": []}, {"token": " Just", "bytes": [32, 74, 117, 115, 116], "logprob": -5.9011784, "top_logprobs": []}, {"token": " let", "bytes": [32, 108, 101, 116], "logprob": -0.666558, "top_logprobs": []}, {"token": " me", "bytes": [32, 109, 101], "logprob": -5.574252e-05, "top_logprobs": []}, {"token": " know", "bytes": [32, 107, 110, 111, 119], "logprob": -0.0008052219, "top_logprobs": []}, {"token": " how", "bytes": [32, 104, 111, 119], "logprob": -0.7132411, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -0.034184996, "top_logprobs": []}], "refusal": null}, "message": {"content": "Certainly! Just let me know how.", "refusal": null, "role": "assistant", "tool_calls": [], "parsed": null}}], "created": 1733751906, "model": "gpt-4o-2024-08-06", "object": "chat.completion", "system_fingerprint": "afnfwuinawufwa", "usage": {"completion_tokens": 8, "prompt_tokens": 45, "total_tokens": 53, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}}}
]
5 changes: 5 additions & 0 deletions spark/tests/resources/completion/metrics_one.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[
{"id": "chatcmpl-AcYPBK1t4QUvRG1oiS2JjoK0xdrGH", "choices": [{"finish_reason": "stop", "index": 0, "logprobs": {"content": [{"token": "Sure", "bytes": [83, 117, 114, 101], "logprob": -0.46845868, "top_logprobs": []}, {"token": "!", "bytes": [33], "logprob": -0.82938546, "top_logprobs": []}, {"token": " Please", "bytes": [32, 80, 108, 101, 97, 115, 101], "logprob": -2.1082883, "top_logprobs": []}, {"token": " ask", "bytes": [32, 97, 115, 107], "logprob": -0.67591065, "top_logprobs": []}, {"token": " your", "bytes": [32, 121, 111, 117, 114], "logprob": -0.022492433, "top_logprobs": []}, {"token": " question", "bytes": [32, 113, 117, 101, 115, 116, 105, 111, 110], "logprob": -0.13909559, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -3.3818815, "top_logprobs": []}, {"token": " I'll", "bytes": [32, 73, 39, 108, 108], "logprob": -0.4059926, "top_logprobs": []}, {"token": " be", "bytes": [32, 98, 101], "logprob": -1.799196, "top_logprobs": []}, {"token": " concise", "bytes": [32, 99, 111, 110, 99, 105, 115, 101], "logprob": -0.23712571, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -0.03838387, "top_logprobs": []}], "refusal": null}, "message": {"content": "Sure! Please ask your question. I'll be concise.", "refusal": null, "role": "assistant", "tool_calls": [], "parsed": null}}], "created": 1733752081, "model": "gpt-4o-2024-08-06", "object": "chat.completion", "system_fingerprint": "dwadawdwd", "usage": {"completion_tokens": 11, "prompt_tokens": 45, "total_tokens": 56, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}}},
{"id": "chatcmpl-AcYQfMsRAIA01ynfZAQb8Zmyc6WKp", "choices": [{"finish_reason": "stop", "index": 0, "logprobs": {"content": [{"token": "Yes", "bytes": [89, 101, 115], "logprob": -5.5577775e-06, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -5.2285613e-05, "top_logprobs": []}, {"token": " Python", "bytes": [32, 80, 121, 116, 104, 111, 110], "logprob": -0.00089550123, "top_logprobs": []}, {"token": " is", "bytes": [32, 105, 115], "logprob": -4.246537e-06, "top_logprobs": []}, {"token": " versatile", "bytes": [32, 118, 101, 114, 115, 97, 116, 105, 108, 101], "logprob": -5.1374755, "top_logprobs": []}, {"token": " and", "bytes": [32, 97, 110, 100], "logprob": -0.22549273, "top_logprobs": []}, {"token": " widely", "bytes": [32, 119, 105, 100, 101, 108, 121], "logprob": -0.06408858, "top_logprobs": []}, {"token": " used", "bytes": [32, 117, 115, 101, 100], "logprob": -0.0007200573, "top_logprobs": []}, {"token": " for", "bytes": [32, 102, 111, 114], "logprob": -0.24617708, "top_logprobs": []}, {"token": " data", "bytes": [32, 100, 97, 116, 97], "logprob": -2.132315, "top_logprobs": []}, {"token": " analysis", "bytes": [32, 97, 110, 97, 108, 121, 115, 105, 115], "logprob": -0.13818723, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -0.10035052, "top_logprobs": []}, {"token": " machine", "bytes": [32, 109, 97, 99, 104, 105, 110, 101], "logprob": -0.33921197, "top_logprobs": []}, {"token": " learning", "bytes": [32, 108, 101, 97, 114, 110, 105, 110, 103], "logprob": 0.0, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -9.138441e-05, "top_logprobs": []}, {"token": " and", "bytes": [32, 97, 110, 100], "logprob": -0.005160108, "top_logprobs": []}, {"token": " monitoring", "bytes": [32, 109, 111, 110, 105, 116, 111, 114, 105, 110, 103], "logprob": -1.1899015, "top_logprobs": []}, {"token": " tasks", "bytes": [32, 116, 97, 115, 107, 115], "logprob": -0.50267833, "top_logprobs": []}, {"token": " due", "bytes": [32, 100, 117, 101], 
"logprob": -1.4371668, "top_logprobs": []}, {"token": " to", "bytes": [32, 116, 111], "logprob": -8.418666e-06, "top_logprobs": []}, {"token": " its", "bytes": [32, 105, 116, 115], "logprob": -6.9882217e-06, "top_logprobs": []}, {"token": " ease", "bytes": [32, 101, 97, 115, 101], "logprob": -3.047081, "top_logprobs": []}, {"token": " of", "bytes": [32, 111, 102], "logprob": -5.5265704e-05, "top_logprobs": []}, {"token": " use", "bytes": [32, 117, 115, 101], "logprob": -0.000113794704, "top_logprobs": []}, {"token": " and", "bytes": [32, 97, 110, 100], "logprob": -0.0067897392, "top_logprobs": []}, {"token": " robust", "bytes": [32, 114, 111, 98, 117, 115, 116], "logprob": -2.8805246, "top_logprobs": []}, {"token": " libraries", "bytes": [32, 108, 105, 98, 114, 97, 114, 105, 101, 115], "logprob": -0.05405761, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -0.0008492942, "top_logprobs": []}], "refusal": null}, "message": {"content": "Yes, Python is versatile and widely used for data analysis, machine learning, and monitoring tasks due to its ease of use and robust libraries.", "refusal": null, "role": "assistant", "tool_calls": [], "parsed": null}}], "created": 1733752173, "model": "gpt-4o-2024-08-06", "object": "chat.completion", "system_fingerprint": "fawfaw", "usage": {"completion_tokens": 28, "prompt_tokens": 45, "total_tokens": 73, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}}},
{"id": "chatcmpl-AcYR545NY6eo8xAdKuEbE60YgrF1a", "choices": [{"finish_reason": "stop", "index": 0, "logprobs": {"content": [{"token": "Yes", "bytes": [89, 101, 115], "logprob": -0.00012582695, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -2.8206474e-05, "top_logprobs": []}, {"token": " Rust", "bytes": [32, 82, 117, 115, 116], "logprob": -0.0001658757, "top_logprobs": []}, {"token": " is", "bytes": [32, 105, 115], "logprob": -0.0629955, "top_logprobs": []}, {"token": " known", "bytes": [32, 107, 110, 111, 119, 110], "logprob": -0.12907438, "top_logprobs": []}, {"token": " for", "bytes": [32, 102, 111, 114], "logprob": -6.704273e-07, "top_logprobs": []}, {"token": " its", "bytes": [32, 105, 116, 115], "logprob": -0.053365633, "top_logprobs": []}, {"token": " performance", "bytes": [32, 112, 101, 114, 102, 111, 114, 109, 97, 110, 99, 101], "logprob": -0.23435384, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -0.0018123905, "top_logprobs": []}, {"token": " memory", "bytes": [32, 109, 101, 109, 111, 114, 121], "logprob": -0.97972125, "top_logprobs": []}, {"token": " safety", "bytes": [32, 115, 97, 102, 101, 116, 121], "logprob": -1.8908588e-05, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -0.0049610855, "top_logprobs": []}, {"token": " and", "bytes": [32, 97, 110, 100], "logprob": -5.3954464e-05, "top_logprobs": []}, {"token": " concurrency", "bytes": [32, 99, 111, 110, 99, 117, 114, 114, 101, 110, 99, 121], "logprob": -0.08939207, "top_logprobs": []}, {"token": " capabilities", "bytes": [32, 99, 97, 112, 97, 98, 105, 108, 105, 116, 105, 101, 115], "logprob": -2.2125401, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -0.38892052, "top_logprobs": []}], "refusal": null}, "message": {"content": "Yes, Rust is known for its performance, memory safety, and concurrency capabilities.", "refusal": null, "role": "assistant", "tool_calls": [], "parsed": null}}], "created": 1733752199, "model": "gpt-4o-2024-08-06", 
"object": "chat.completion", "system_fingerprint": "wafwafwf", "usage": {"completion_tokens": 16, "prompt_tokens": 46, "total_tokens": 62, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}}}
]
Loading
Loading