feat: map from url #322

Open · wants to merge 10 commits into main
16 changes: 16 additions & 0 deletions examples/image/map_images_from_urls.py
@@ -0,0 +1,16 @@
from datasets import load_dataset
from nomic import AtlasDataset
from tqdm import tqdm

dataset = load_dataset('ChihHsuan-Yang/Arboretum', split='train[:100000]')
ids = list(range(len(dataset)))
dataset = dataset.add_column("id", ids)

atlas_dataset = AtlasDataset("andriy/arboretum-100k-image-url-upload", unique_id_field="id")
records = dataset.remove_columns(["photo_id"]).to_list()

records = [record for record in tqdm(records) if record["photo_url"] is not None]
image_urls = [record.pop("photo_url") for record in records]

atlas_dataset.add_data(data=records, blobs=image_urls)
atlas_dataset.create_index(embedding_model="nomic-embed-vision-v1.5", topic_model=False)
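
For reference, the pattern this example exercises (and which the `nomic/dataset.py` changes below enable) boils down to the sketch that follows. The dataset name, records, and URLs are placeholders, and `blobs` must line up one-to-one with `data` by index:

```python
from nomic import AtlasDataset

# Hypothetical records and image URLs, for illustration only.
records = [
    {"id": 0, "common_name": "red maple"},
    {"id": 1, "common_name": "white oak"},
]
image_urls = [
    "https://example.com/images/red_maple.jpg",
    "https://example.com/images/white_oak.jpg",
]

atlas_dataset = AtlasDataset("my-org/url-blob-demo", unique_id_field="id")

# blobs[i] is uploaded for records[i]; with this change, http(s) URLs are accepted
# alongside local paths, bytes, and PIL Images, but local and remote blobs cannot
# be mixed in a single add_data call.
atlas_dataset.add_data(data=records, blobs=image_urls)
atlas_dataset.create_index(embedding_model="nomic-embed-vision-v1.5", topic_model=False)
```
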
60 changes: 48 additions & 12 deletions nomic/dataset.py
@@ -1095,9 +1095,9 @@ def create_index(
modality = self.meta["modality"]

if modality == "image":
indexed_field = "_blob_hash"
if indexed_field is not None:
logger.warning("Ignoring indexed_field for image datasets. Only _blob_hash is supported.")
indexed_field = "_blob_hash"

colorable_fields = []

@@ -1170,11 +1170,14 @@ def create_index(

if modality == "image":
if topic_model.topic_label_field is None:
print(
"You did not specify the `topic_label_field` option in your topic_model, your dataset will not contain auto-labeled topics."
)
if topic_model.build_topic_model:
logger.warning(
"You did not specify the `topic_label_field` option in your topic_model, your dataset will not contain auto-labeled topics."
)
topic_model.build_topic_model = False

topic_field = None
topic_model.build_topic_model = False

else:
topic_field = (
topic_model.topic_label_field if topic_model.topic_label_field != indexed_field else None
@@ -1361,7 +1364,7 @@ def add_data(
Args:
data: A pandas DataFrame, list of dictionaries, or pyarrow Table matching the dataset schema.
embeddings: A numpy array of embeddings: each row corresponds to a row in the table. Use if you already have embeddings for your datapoints.
blobs: A list of image paths, bytes, or PIL Images. Use if you want to create an AtlasDataset using image embeddings over your images. Note: Blobs are stored locally only.
blobs: A list of image paths, bytes, PIL Images, or URLs. Use if you want to create an AtlasDataset using image embeddings over your images.
pbar: (Optional). A tqdm progress bar to update.
"""
if embeddings is not None:
@@ -1408,6 +1411,7 @@ def _add_blobs(

# TODO: add support for other modalities
images = []
urls = []
for uuid, blob in tqdm(zip(ids, blobs), total=len(ids), desc="Loading images"):
if isinstance(blob, str) and os.path.exists(blob):
# Auto resize to max 512x512
@@ -1417,6 +1421,8 @@
buffered = BytesIO()
image.save(buffered, format="JPEG")
images.append((uuid, buffered.getvalue()))
elif isinstance(blob, str) and (blob.startswith("http://") or blob.startswith("https://")):
urls.append((uuid, blob))
elif isinstance(blob, bytes):
images.append((uuid, blob))
elif isinstance(blob, Image.Image):
@@ -1428,22 +1434,40 @@
else:
raise ValueError(f"Invalid blob type for {uuid}. Must be a path to an image, bytes, or PIL Image.")

batch_size = 40
num_workers = 10
if len(images) == 0 and len(urls) == 0:
raise ValueError("No valid images found in the blobs list.")
if len(images) > 0 and len(urls) > 0:
raise ValueError("Cannot mix local and remote blobs in the same batch.")

if urls:
batch_size = 10
num_workers = 10
else:
batch_size = 40
num_workers = 10

def send_request(i):
image_batch = images[i : i + batch_size]
ids = [uuid for uuid, _ in image_batch]
blobs = [("blobs", blob) for _, blob in image_batch]
urls_batch = urls[i : i + batch_size]

if image_batch:
blobs = [("blobs", blob) for _, blob in image_batch]
ids = [uuid for uuid, _ in image_batch]
else:
blobs = []
ids = [uuid for uuid, _ in urls_batch]
urls_batch = [url for _, url in urls_batch]

response = requests.post(
self.atlas_api_path + blob_upload_endpoint,
headers=self.header,
data={"dataset_id": self.id},
data={"dataset_id": self.id, "urls": urls_batch},
files=blobs,
)
if response.status_code != 200:
raise Exception(response.text)
return {uuid: blob_hash for uuid, blob_hash in zip(ids, response.json()["hashes"])}
id2hash = {uuid: blob_hash for uuid, blob_hash in zip(ids, response.json()["hashes"])}
return id2hash

# if this method is being called internally, we pass a global progress bar
if pbar is None:
@@ -1452,6 +1476,7 @@ def send_request(i):
hash_schema = pa.schema([(self.id_field, pa.string()), ("_blob_hash", pa.string())])
returned_ids = []
returned_hashes = []
failed_ids = []

succeeded = 0
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
@@ -1461,13 +1486,24 @@
response = future.result()
# add hash to data as _blob_hash
for uuid, blob_hash in response.items():
if blob_hash is None:
failed_ids.append(uuid)
continue

returned_ids.append(uuid)
returned_hashes.append(blob_hash)

# A successful upload.
succeeded += len(response)
pbar.update(len(response))

# remove all rows that failed to upload
if len(failed_ids) > 0:
failed_ids_array = pa.array(failed_ids, type=pa.string())
logger.info(f"Failed to upload {len(failed_ids)} blobs.")
logger.info(f"Filtering out {failed_ids} from the dataset.")
data = pc.filter(data, pc.invert(pc.is_in(data[self.id_field], failed_ids_array))) # type: ignore

hash_tb = pa.Table.from_pydict({self.id_field: returned_ids, "_blob_hash": returned_hashes}, schema=hash_schema)
merged_data = data.join(right_table=hash_tb, keys=self.id_field) # type: ignore

Expand Down