Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

swap out NaNs for nones #272

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions nomic/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from collections import defaultdict
from contextlib import contextmanager
from datetime import date, datetime
from math import isnan
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union

Expand Down Expand Up @@ -1339,6 +1340,12 @@ def add_data(self, data=Union[DataFrame, List[Dict], pa.Table], embeddings: np.a
embeddings: A numpy array of embeddings: each row corresponds to a row in the table. Use if you already have embeddings for your datapoints.
pbar: (Optional). A tqdm progress bar to update.
"""

# More often than not, NaNs in this list indicate that someone is uploading from pandas, which uses NaN for null.
# Since we can't plot a NaN anyway, we replace it with nulls.
if (isinstance(data, list)):
data = swap_nans_for_nones(data)

if embeddings is not None or (isinstance(data, pa.Table) and "_embeddings" in data.column_names):
self._add_embeddings(data=data, embeddings=embeddings, pbar=pbar)
else:
Expand Down Expand Up @@ -1601,3 +1608,7 @@ def update_indices(self, rebuild_topic_models: bool = False):
)

logger.info(f"Updating maps in dataset `{self.identifier}`")


def swap_nans_for_nones(items : List[Dict]):
return [{k : (v if not isnan(v) else None) for k, v in item.items() if not isnan} for item in items]