Skip to content

Commit

Permalink
feat: remove cloud and login on import (#823)
Browse files Browse the repository at this point in the history
  • Loading branch information
Elliott authored Jan 5, 2024
1 parent a30ca91 commit 85109ae
Show file tree
Hide file tree
Showing 21 changed files with 57 additions and 186 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ Install the package.
pip install dataquality
```

Create an account at [Galileo](https://console.cloud.rungalileo.io/sign-up)
Create an account at [Galileo](https://{console-url}.rungalileo.io/sign-up)

Grab your [token](https://console.cloud.rungalileo.io/get-token)
Grab your [token](https://console-url.rungalileo.io/get-token)

Get your dataset and analyze it with `dq.auto`
(You will be prompted for your token here)
Expand Down
30 changes: 1 addition & 29 deletions dataquality/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,12 @@
"""


__version__ = "1.4.2"
__version__ = "1.5.0"

import sys
from typing import Any, List, Optional

import dataquality.core._config
import dataquality.integrations

# We try/catch this in case the user installed dq inside of jupyter. You need to
# restart the kernel after the install and we want to make that clear. This is because
try:
import dataquality.metrics
from dataquality.analytics import Analytics
from dataquality.clients.api import ApiClient
except (FileNotFoundError, AttributeError):
raise Exception(
"It looks like you've installed dataquality from a notebook. "
"Please restart the kernel before continuing"
) from None
from dataquality.core import configure, set_console_url
from dataquality.core._config import config
from dataquality.core.auth import login, logout
Expand Down Expand Up @@ -144,21 +131,6 @@
pass


# Logging is optional. If enabled, imports, method calls
# and exceptions can be logged by calling the logger.
# This is useful for debugging and detecting issues.
# Logging is disabled by default for enterprise users.
# To enable logging, set the environment variable
# DQ_TELEMETRICS=1
# To log initiate the Analytics class and pass in the gallileo ApiClient + dq.config
# a = Analytics(ApiClient, config)
# Once initialized you can start logging
# a.log_import("dataquality")
# a.log_method_call("dataquality.log_data_samples")
a = Analytics(ApiClient, config)
a.log_import("dataquality")


class _DataQuality:
"""This class is used to create a singleton instance of the DataQuality class.
Expand Down
5 changes: 3 additions & 2 deletions dataquality/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@check_noop
def configure(do_login: bool = True, _internal: bool = False) -> None:
"""[Not for cloud users] Update your active config with new information
"""Update your active config with new information
You can use environment variables to set the config, or wait for prompts
Available environment variables to update:
Expand All @@ -31,7 +31,8 @@ def configure(do_login: bool = True, _internal: bool = False) -> None:

if "GALILEO_API_URL" in os.environ:
del os.environ["GALILEO_API_URL"]
updated_config = dataquality.core._config.reset_config(cloud=False)

updated_config = dataquality.core._config.reset_config()
for k, v in updated_config.dict().items():
config.__setattr__(k, v)
config.token = None
Expand Down
33 changes: 17 additions & 16 deletions dataquality/core/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from dataquality.schemas.task_type import TaskType
from dataquality.utils.helpers import galileo_disabled

CLOUD_URL = "https://console.cloud.rungalileo.io"
MINIMUM_API_VERSION = "0.4.0"
GALILEO_DEFAULT_IMG_BUCKET_NAME = "galileo-images"
GALILEO_DEFAULT_RUN_BUCKET_NAME = "galileo-project-runs"
Expand Down Expand Up @@ -99,7 +98,7 @@ def update_file_config(self) -> None:

@validator("api_url", pre=True, always=True, allow_reuse=True)
def add_scheme(cls, v: str) -> str:
if not v.startswith("http"):
if v and not v.startswith("http"):
# api url needs the scheme
v = f"http://{v}"
return v
Expand Down Expand Up @@ -199,12 +198,14 @@ def _check_console_url() -> None:
set_platform_urls(console_url_str=console_url)


def set_config(cloud: bool = True) -> Config:
def set_config(initial_startup: bool = False) -> Config:
if galileo_disabled():
return Config(api_url="")
_check_console_url()

if not os.path.isdir(config_data.DEFAULT_GALILEO_CONFIG_DIR):
os.makedirs(config_data.DEFAULT_GALILEO_CONFIG_DIR, exist_ok=True)

if os.path.exists(config_data.DEFAULT_GALILEO_CONFIG_FILE):
with open(config_data.DEFAULT_GALILEO_CONFIG_FILE) as f:
try:
Expand All @@ -227,28 +228,28 @@ def set_config(cloud: bool = True) -> Config:
config = Config(**galileo_vars)

else:
name = "Galileo Cloud" if cloud else "Galileo"
print(f"Welcome to {name} {dq_version}!")
if cloud:
console_url = CLOUD_URL
else:
print(
"To skip this prompt in the future, set the following environment "
"variable: GALILEO_CONSOLE_URL"
)
console_url = input("🔭 Enter the url of your Galileo console\n")
config = Config(api_url="")

if not initial_startup and not config.api_url:
print(f"Welcome to Galileo {dq_version}!")
print(
"To skip this prompt in the future, set the following environment "
"variable: GALILEO_CONSOLE_URL"
)
console_url = input("🔭 Enter the url of your Galileo console\n")
set_platform_urls(console_url_str=console_url)
galileo_vars = GalileoConfigVars.get_config_mapping()
config = Config(**galileo_vars)

config.update_file_config()
return config


def reset_config(cloud: bool = True) -> Config:
def reset_config() -> Config:
"""Wipe the config file and reconfigure"""
if os.path.isfile(config_data.DEFAULT_GALILEO_CONFIG_FILE):
os.remove(config_data.DEFAULT_GALILEO_CONFIG_FILE)
return set_config(cloud)
return set_config()


config = set_config()
config = set_config(initial_startup=True)
9 changes: 8 additions & 1 deletion dataquality/core/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import webbrowser

from dataquality.clients.api import ApiClient
from dataquality.core._config import config, url_is_localhost
from dataquality.core._config import config, reset_config, url_is_localhost
from dataquality.schemas.route import Route
from dataquality.utils.helpers import check_noop

Expand Down Expand Up @@ -42,6 +42,13 @@ def login() -> None:
To skip the prompt for automated workflows, you can set `GALILEO_USERNAME`
(your email) and GALILEO_PASSWORD if you signed up with an email and password
"""
if not config.api_url:
updated_config = reset_config()
for k, v in updated_config.dict().items():
config.__setattr__(k, v)
config.token = None
config.update_file_config()

if api_client.valid_current_user():
print(f"✅ Already logged in as {config.current_user}!")
print("Use logout() if you want to change users")
Expand Down
4 changes: 2 additions & 2 deletions dataquality/core/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ def init(
run_name = create_run_name(project_name)
run_name = validate_name(run_name, assign_random=False)
run, run_created = _init.get_or_create_run(project_name, run_name, task_type)
dataquality.config.current_project_name = project_name
dataquality.config.current_run_name = run_name
config.current_project_name = project_name
config.current_run_name = run_name

if not run_created:
warnings.warn(
Expand Down
6 changes: 4 additions & 2 deletions dataquality/dq_auto/ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
from datasets import Dataset, DatasetDict

import dataquality as dq
from dataquality import Analytics, ApiClient
from dataquality import config
from dataquality.analytics import Analytics
from dataquality.clients.api import ApiClient
from dataquality.dq_auto.base_data_manager import BaseDatasetManager
from dataquality.dq_auto.ner_trainer import get_trainer
from dataquality.schemas.task_type import TaskType
from dataquality.utils.auto import add_val_data_if_missing, run_name_from_hf_dataset
from dataquality.utils.auto_trainer import do_train

a = Analytics(ApiClient, dq.config)
a = Analytics(ApiClient, config)
a.log_import("auto_ner")


Expand Down
3 changes: 2 additions & 1 deletion dataquality/dq_auto/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from transformers import Trainer

import dataquality as dq
from dataquality import Analytics, ApiClient
from dataquality.analytics import Analytics
from dataquality.clients.api import ApiClient
from dataquality.dq_auto.base_data_manager import BaseDatasetManager
from dataquality.dq_auto.tc_trainer import get_trainer
from dataquality.schemas.split import Split
Expand Down
6 changes: 4 additions & 2 deletions dataquality/integrations/transformers_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,13 @@ def setup_model(self, model: Module) -> None:
# Setup the model only once
if self._model_setup:
return
assert dq.config.task_type, GalileoException(

task_type = dq.config.task_type
assert task_type, GalileoException(
"dq client must be initialized. "
"For example: dq.init('text_classification')"
)
self.task = dq.config.task_type
self.task = task_type
# Attach hooks to the model
self._attach_hooks_to_model(
model, self.classifier_layer, self.last_hidden_state_layer
Expand Down
26 changes: 0 additions & 26 deletions dataquality/loggers/data_logger/base_data_logger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import gc
import glob
import os
import sys
import warnings
Expand All @@ -22,7 +21,6 @@
from dataquality.schemas.ner import TaggingSchema
from dataquality.schemas.split import Split
from dataquality.utils import tqdm
from dataquality.utils.cloud import is_galileo_cloud
from dataquality.utils.cuda import cuml_available
from dataquality.utils.emb import (
DATA_EMB_PATH,
Expand Down Expand Up @@ -66,7 +64,6 @@ class BaseGalileoDataLogger(BaseGalileoLogger):
MAX_DOC_LEN = 10_000 # Max characters in document metadata attribute
LIMIT_NUM_DOCS = 3 # Limit the number of documents logged per split
INPUT_DATA_BASE = "input_data"
MAX_DATA_SIZE_CLOUD = 300_000
# 2GB max size for arrow strings. We use 1.5GB for some buffer
# https://issues.apache.org/jira/browse/ARROW-17828
STRING_MAX_SIZE_B = 1.5e9
Expand Down Expand Up @@ -190,8 +187,6 @@ def log(self) -> None:
os.makedirs(f"{self.input_data_path}/{self.split}", exist_ok=True)

df = self._get_input_df()
# Validates cloud size limit
self.validate_data_size(df)

ids = df["id"].tolist()
self.validate_ids_for_split(ids)
Expand Down Expand Up @@ -710,24 +705,3 @@ def _get_input_df(self) -> DataFrame:
def set_tagging_schema(cls, tagging_schema: TaggingSchema) -> None:
"""Sets the tagging schema, if applicable. Must be implemented by child"""
raise GalileoException(f"Cannot set tagging schema for {cls.__logger_name__}")

def validate_data_size(self, df: DataFrame) -> None:
"""Validates that the data size is within the limits of Galileo Cloud
If the data size is too large, a warning is raised.
"""
if not is_galileo_cloud():
return
samples_logged = len(df)
path_to_logged_data = f"{self.input_data_path}/*/*arrow"
if glob.glob(path_to_logged_data):
samples_logged += len(vaex.open(f"{self.input_data_path}/*/*arrow"))
nrows = BaseGalileoDataLogger.MAX_DATA_SIZE_CLOUD
if samples_logged > nrows:
warnings.warn(
f"⚠️ Hey there! You've logged over {nrows} rows in your input data. "
f"Galileo Cloud only supports up to {nrows} rows. "
"If you are using larger datasets, you may see degraded performance. "
"Please email us at [email protected] if you have any questions.",
GalileoWarning,
)
2 changes: 0 additions & 2 deletions dataquality/loggers/data_logger/tabular_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,6 @@ def log(self) -> None:
We write the dfs to disk in the following locations:
/Users/username/.galileo/logs/proj-id/run-id/training/data/data.hdf5
/Users/username/.galileo/logs/proj-id/run-id/training/prob/prob.hdf5
NOTE: We don't restrict row or feature counts here for cloud users.
"""
self.validate_and_prepare_logger()

Expand Down
5 changes: 0 additions & 5 deletions dataquality/utils/cloud.py

This file was deleted.

11 changes: 0 additions & 11 deletions docs/notebooks/Inference-Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,6 @@
"We can log multiple inference runs with different inference names. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5af3a375-5fde-45ec-addf-657e504919c1",
"metadata": {},
"outputs": [],
"source": [
"#dq.init(task_type=\"text_classification\", project_name=\"gonzaga\", run_name=\"duke\")\n",
"dq.config"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
18 changes: 7 additions & 11 deletions docs/notebooks/NER Inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,9 @@
"metadata": {},
"outputs": [],
"source": [
"from dataquality import config\n",
"import pandas as pd\n",
"from dataquality import config\n",
"from dataquality.clients.api import ApiClient\n",
"from time import sleep\n",
"\n",
"\n",
"api_client = ApiClient()\n",
Expand All @@ -76,9 +75,11 @@
" print(\"Waiting for data to be processed\")\n",
" api_client.wait_for_run()\n",
"\n",
" task_type = dq.config.task_type\n",
" proj = api_client.get_project(config.current_project_id)[\"name\"]\n",
" run = api_client.get_project_run(config.current_project_id, config.current_run_id)[\"name\"]\n",
" task_type = \"text_ner\"\n",
" project_id = config.current_project_id\n",
" run_id = config.current_run_id\n",
" proj = api_client.get_project(project_id)[\"name\"]\n",
" run = api_client.get_project_run(project_id, run_id)[\"name\"]\n",
" api_client.export_run(proj, run, \"training\", f\"{task_type}_training.csv\")\n",
" api_client.export_run(proj, run, \"test\", f\"{task_type}_test.csv\")\n",
" api_client.export_run(proj, run, \"validation\", f\"{task_type}_validation.csv\")\n",
Expand Down Expand Up @@ -110,12 +111,7 @@
"metadata": {},
"outputs": [],
"source": [
"from dataquality.schemas.task_type import TaskType\n",
"from dataquality import config \n",
"from uuid import uuid4\n",
"import numpy as np\n",
"from time import sleep\n",
"from tqdm.notebook import tqdm\n",
"\n",
"\n",
"dq.init(\"text_ner\", \"test-ner-proj\", \"test-ner-run\")\n",
Expand All @@ -137,7 +133,7 @@
"def log_outputs():\n",
" num_classes = 28\n",
" embs = [np.random.rand(119, 768) for _ in range(5)]\n",
" logits= [np.random.rand(119, 28) for _ in range(5)] \n",
" logits= [np.random.rand(119, num_classes) for _ in range(5)] \n",
" ids= list(range(5))\n",
" for split in [\"inference\"]:\n",
" dq.log_model_outputs(\n",
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ dependencies = [
"ipywidgets>=8.1.0",
"imagededup>=0.3.1",
"pyjwt>=2.8.0",
"peft"
"peft",
# Pin opencv for linting incompatibility
"opencv-python<=4.8.1.78",
]
[[project.authors]]
name = "Galileo Technologies, Inc."
Expand Down
2 changes: 1 addition & 1 deletion tests/clients/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import pytest

import dataquality.clients.api
from dataquality import config
from dataquality.clients.api import ApiClient
from dataquality.core._config import config
from dataquality.exceptions import GalileoException
from dataquality.schemas import RequestType
from dataquality.schemas.task_type import TaskType
Expand Down
Loading

0 comments on commit 85109ae

Please sign in to comment.