diff --git a/nomic/dataset.py b/nomic/dataset.py index 6575eac2..4c3b21ba 100644 --- a/nomic/dataset.py +++ b/nomic/dataset.py @@ -10,6 +10,7 @@ from collections import defaultdict from contextlib import contextmanager from datetime import date, datetime +from math import isnan from pathlib import Path from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union @@ -1339,6 +1340,12 @@ def add_data(self, data=Union[DataFrame, List[Dict], pa.Table], embeddings: np.a embeddings: A numpy array of embeddings: each row corresponds to a row in the table. Use if you already have embeddings for your datapoints. pbar: (Optional). A tqdm progress bar to update. """ + + # More often than not, NaNs in this list indicate that someone is uploading from pandas, which uses NaN for null. + # Since we can't plot a NaN anyway, we replace it with nulls. + if (isinstance(data, list)): + data = swap_nans_for_nones(data) + if embeddings is not None or (isinstance(data, pa.Table) and "_embeddings" in data.column_names): self._add_embeddings(data=data, embeddings=embeddings, pbar=pbar) else: @@ -1601,3 +1608,7 @@ def update_indices(self, rebuild_topic_models: bool = False): ) logger.info(f"Updating maps in dataset `{self.identifier}`") + + +def swap_nans_for_nones(items : List[Dict]): + return [{k : (v if not isnan(v) else None) for k, v in item.items() if not isnan} for item in items] \ No newline at end of file