Skip to content

Commit

Permalink
feat: better refresh (#641)
Browse files Browse the repository at this point in the history
* feat: better refresh

* fix: include unlabeled data
  • Loading branch information
frederik-encord authored Sep 15, 2023
1 parent dab773d commit c55d061
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 20 deletions.
57 changes: 54 additions & 3 deletions src/encord_active/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from itertools import chain
from pathlib import Path
from typing import Dict, List, Optional, Set

Expand All @@ -12,6 +13,7 @@

from encord_active.cli.project import project_cli
from encord_active.cli.utils.server import ensure_safe_project
from encord_active.lib.embeddings.embedding_index import EmbeddingIndex

load_dotenv()

Expand Down Expand Up @@ -385,7 +387,13 @@ def import_local_project(
@cli.command(name="refresh")
@ensure_project()
def refresh(
target: Path = typer.Option(Path.cwd(), "--target", "-t", help="Path to the target project.", file_okay=False)
target: Path = typer.Option(Path.cwd(), "--target", "-t", help="Path to the target project.", file_okay=False),
include_unlabeled: bool = typer.Option(
False,
"--include-unlabeled",
"-i",
help="Include unlabeled data. [blue]Note:[/blue] this will affect the results of 'encord.Project.list_label_rows()' as every label row will now have a label_hash.",
),
):
"""
[green bold]Sync[/green bold] data and labels from a remote Encord project :arrows_counterclockwise:
Expand All @@ -396,12 +404,55 @@ def refresh(
2. The hash of the remote Encord project (project_hash: remote-encord-project-hash).
3. The path to the private Encord user SSH key (ssh_key_path: private/encord/user/ssh/key/path).
"""
from encord_active.lib.db.connection import DBConnection
from encord_active.lib.metrics.types import EmbeddingType
from encord_active.lib.project import Project

try:
Project(target).refresh()
project = Project(target)
state = project.refresh(initialize_label_rows=include_unlabeled)

from encord_active.lib.metrics.execute import (
execute_metrics,
get_metrics_by_embedding_type,
)

embedding_types_to_update = list(
filter(
None,
[
EmbeddingType.IMAGE if state.data_changed else None,
EmbeddingType.OBJECT if state.labels_changed and state.has_objects else None,
EmbeddingType.CLASSIFICATION if state.labels_changed and state.has_classifications else None,
],
)
)
for et in embedding_types_to_update:
emb_file = project.file_structure.get_embeddings_file(et)
if emb_file.is_file():
emb_file.unlink()

emb_file_2d = project.file_structure.get_embeddings_file(et, reduced=True)
if emb_file_2d.is_file():
emb_file_2d.unlink()

EmbeddingIndex.remove_index(project.file_structure, et)

metrics_to_execute = list(chain(*[get_metrics_by_embedding_type(et) for et in embedding_types_to_update]))
execute_metrics(metrics_to_execute, data_dir=target, use_cache_only=True)
with DBConnection(project.file_structure) as conn:
from encord_active.lib.db.merged_metrics import (
MergedMetrics,
build_merged_metrics,
)

MergedMetrics(conn).replace_all(build_merged_metrics(project.file_structure.metrics))
ensure_safe_project(target)

except AttributeError as e:
rich.print(f"[orange1]{e}[/orange1]")
except Exception as e:
rich.print(f"[red] ERROR: The data sync failed. Log: {e}.")
rich.print(f"[red]ERROR: The data sync failed. Log: {e}.")
else:
rich.print("[green]Data and labels successfully synced from the remote project[/green]")

Expand Down
19 changes: 10 additions & 9 deletions src/encord_active/lib/embeddings/embedding_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import os
import pickle
from pathlib import Path
from typing import NamedTuple, Optional

import numpy as np
Expand Down Expand Up @@ -37,17 +36,20 @@ def _query_in_batches(index: NNDescent, embeddings: np.ndarray, k: int, batch_si
return EmbeddingSearchResult(np.concatenate(dists, axis=0), np.concatenate(idxs, axis=0))


def _get_embedding_index_file(embedding_file: Path, metric: str) -> Path:
return embedding_file.parent / f"{embedding_file.stem}_{metric}_index.pkl"


class EmbeddingIndex:
@classmethod
def index_available(
    cls, project_file_structure: ProjectFileStructure, embedding_type: EmbeddingType, metric: str = "cosine"
) -> bool:
    """Tell whether an index file exists on disk for an embedding type and metric.

    Args:
        project_file_structure: File structure of the project; it resolves the index file path.
        embedding_type: Embedding type whose index is looked up.
        metric: Metric name that is part of the index file name.

    Returns:
        ``True`` if the resolved index path is an existing file, ``False`` otherwise.
    """
    # Single lookup through the project file structure; the former second,
    # unreachable `return` duplicated this exact check.
    index_file = project_file_structure.get_embedding_index_file(embedding_type, metric)
    return index_file.is_file()

@classmethod
def remove_index(
    cls, project_file_structure: ProjectFileStructure, embedding_type: EmbeddingType, metric: str = "cosine"
):
    """Delete the index file for an embedding type and metric, if it exists.

    Args:
        project_file_structure: File structure of the project; it resolves the index file path.
        embedding_type: Embedding type whose index should be removed.
        metric: Metric name that is part of the index file name.
    """
    index_path = project_file_structure.get_embedding_index_file(embedding_type, metric)
    if not index_path.is_file():
        # Nothing to delete.
        return
    index_path.unlink()

@classmethod
def from_project(
Expand All @@ -70,8 +72,7 @@ def from_project(
An index ready for querying
"""
embedding_file = project_file_structure.get_embeddings_file(embedding_type)
index_file = _get_embedding_index_file(embedding_file, metric)
index_file = project_file_structure.get_embedding_index_file(embedding_type, metric)

if iterator is None:
iterator = DatasetIterator(cache_dir=project_file_structure.project_dir)
Expand Down
67 changes: 59 additions & 8 deletions src/encord_active/lib/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import json
import logging
import tempfile
from datetime import datetime, timedelta
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Union

if TYPE_CHECKING:
import prisma
Expand Down Expand Up @@ -46,6 +47,13 @@
encord_logger.setLevel(logging.ERROR)


class ProjectRefreshChangeResponse(NamedTuple):
    """Summary of what changed during a project refresh, as returned by `Project.refresh`."""

    # Whether the refresh downloaded any label rows / data.
    data_changed: bool
    # Whether labels are considered changed (new data, or a newer remote edit time).
    labels_changed: bool
    # Whether the refreshed ontology contains at least one object.
    has_objects: bool
    # Whether the refreshed ontology contains at least one classification.
    has_classifications: bool


class Project:
def __init__(self, project_dir: Union[str, Path]):
self.file_structure = ProjectFileStructure(project_dir)
Expand Down Expand Up @@ -122,14 +130,18 @@ def from_encord_project(self, encord_project: EncordProject) -> Project:

return self.load()

def refresh(self):
def refresh(self, initialize_label_rows: bool = False) -> ProjectRefreshChangeResponse:
"""
Refresh project data and labels using its remote project in Encord Annotate.
Args:
initialize_label_rows: if data has never been initialized, it won't have a label hash.
Setting this flag will give them a label hash.
:return: The updated project instance.
"""

self.project_meta = fetch_project_meta(self.file_structure.project_dir)
self.project_meta = self.file_structure.load_project_meta()
if not self.project_meta.get("has_remote", False):
raise AttributeError("The project does not have a remote project associated to it.")

Expand All @@ -146,10 +158,40 @@ def refresh(self):

encord_client = get_client(Path(ssh_key_path))
encord_project = encord_client.get_project(project_hash)
self.__download_and_save_label_rows(encord_project)

self.__save_project_meta(encord_project)
new_ontology = OntologyStructure.from_dict(encord_project.ontology)
self.save_ontology(new_ontology)

has_uninitialized_rows = not all(row["label_hash"] is not None for row in encord_project.label_rows)
if has_uninitialized_rows and initialize_label_rows:
untoched_data = list(filter(lambda x: x.label_hash is None, encord_project.list_label_rows_v2()))
collect_async(lambda x: x.initialise_labels(), untoched_data, desc="Preparing uninitialized label rows")
encord_project.refetch_data()

data_changed = self.__download_and_save_label_rows(encord_project)

stored_edit_times: list[datetime] = list(
filter(None, [meta.last_edited_at for meta in self.__load_label_row_meta().values()])
)
latest_stored_edit = max(stored_edit_times) if stored_edit_times else None
latest_edit_times: list[datetime] = list(
filter(None, [meta.last_edited_at for meta in encord_project.list_label_rows_v2()])
)
latest_edit = max(latest_edit_times) if latest_edit_times else None

self.__save_label_row_meta(encord_project) # Update cached metadata of the label rows (after new data sync)

return self.load()
return ProjectRefreshChangeResponse(
data_changed=data_changed,
labels_changed=data_changed
or (
latest_edit is not None
and (latest_stored_edit is None or (latest_edit - latest_stored_edit) > timedelta(seconds=1))
),
has_objects=len(new_ontology.objects) > 0,
has_classifications=len(new_ontology.classifications) > 0,
)

@property
def is_loaded(self) -> bool:
Expand Down Expand Up @@ -187,7 +229,7 @@ def __load_ontology(self):
raise FileNotFoundError(f"Expected file `ontology.json` at {ontology_file_path.parent}")
self.ontology = OntologyStructure.from_dict(json.loads(ontology_file_path.read_text(encoding="utf-8")))

def __save_label_row_meta(self, encord_project: EncordProject):
def __save_label_row_meta(self, encord_project: EncordProject) -> dict[str, Any]:
label_row_meta = {
lrm.label_hash: handle_enum_and_datetime(lrm)
for lrm in encord_project.list_label_rows()
Expand All @@ -198,6 +240,7 @@ def __save_label_row_meta(self, encord_project: EncordProject):
meta["last_edited_at"] = meta["last_edited_at"].rsplit(".", maxsplit=1)[0]

self.file_structure.label_row_meta.write_text(json.dumps(label_row_meta, indent=2), encoding="utf-8")
return label_row_meta

def __load_label_row_meta(self, subset_size: Optional[int] = None) -> dict[str, LabelRowMetadata]:
label_row_meta_file_path = self.file_structure.label_row_meta
Expand All @@ -220,11 +263,19 @@ def save_label_row(self, label_row: LabelRow):
where={"label_hash": label_row["label_hash"]},
)

def __download_and_save_label_rows(self, encord_project: EncordProject) -> bool:
    """Download label rows and data from the given project and split the videos.

    Delegates the download to `__download_label_rows_and_data` and then runs
    `split_lr_videos` on the downloaded label rows.

    Args:
        encord_project: project to download data from

    Returns:
        ``True`` if at least one label row was downloaded, ``False`` otherwise.
    """
    label_rows = self.__download_label_rows_and_data(encord_project, self.file_structure)
    split_lr_videos(label_rows, self.file_structure)
    logger.info("Data and labels successfully synced from the remote project")
    # Only one return: a bare `return` ahead of this line would make the
    # method always yield None and hide whether anything was downloaded.
    return len(label_rows) > 0

def __download_label_rows_and_data(
self,
Expand Down
4 changes: 4 additions & 0 deletions src/encord_active/lib/project/project_file_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,10 @@ def metrics_meta(self) -> Path:
def embeddings(self) -> Path:
return self.project_dir / "embeddings"

def get_embedding_index_file(self, embedding_type: EmbeddingType, metric: str) -> Path:
    """Return the path of the index file that belongs to an embedding type and metric.

    The index file lives in the same directory as the embeddings file of
    ``embedding_type`` and is named ``<embeddings file stem>_<metric>_index.pkl``.

    Args:
        embedding_type: Embedding type whose embeddings file is used as the base.
        metric: Metric name that becomes part of the index file name.

    Returns:
        The index file path. No check is made that the file exists.
    """
    source = self.get_embeddings_file(embedding_type)
    index_name = f"{source.stem}_{metric}_index.pkl"
    return source.parent / index_name

def get_embeddings_file(self, type_: EmbeddingType, reduced: bool = False) -> Path:
    """Return the path of the embeddings file for an embedding type.

    Args:
        type_: Embedding type used as the key into the file name table.
        reduced: When true, the name is taken from ``EMBEDDING_REDUCED_TO_FILENAME``
            instead of ``EMBEDDING_TYPE_TO_FILENAME``.

    Returns:
        Path of the file inside the project's ``embeddings`` directory.
    """
    if reduced:
        filenames = EMBEDDING_REDUCED_TO_FILENAME
    else:
        filenames = EMBEDDING_TYPE_TO_FILENAME
    return self.embeddings / filenames[type_]
Expand Down

0 comments on commit c55d061

Please sign in to comment.