Skip to content

Commit

Permalink
make to_datetime more robust to bad inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
Vincent-Maladiere committed Nov 13, 2023
1 parent 05f961e commit 71a0ec3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
25 changes: 14 additions & 11 deletions skrub/_datetime_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,21 +292,21 @@ def _get_datetime_column_indices(X_split, dayfirst=True):
index_to_format = {}

for col_idx, X_col in enumerate(X_split):
X_col = X_col[pd.notnull(X_col)]

# convert pd.TimeStamp to np.datetime64
if all(isinstance(val, pd.Timestamp) for val in X_col):
X_col = X_col.astype("datetime64")
X_col = X_col[pd.notnull(X_col)] # X_col is a numpy array

if _is_column_datetime_parsable(X_col):
indices.append(col_idx)

if np.issubdtype(X_col.dtype, np.datetime64):
# _guess_datetime_format only accept string columns.
# We need to filter out columns of object dtype that
# contains e.g., datetime.datetime or pd.Timestamp.
X_col_str = X_col.astype(str)
if np.array_equal(X_col, X_col_str):
datetime_format = _guess_datetime_format(X_col)
else:
# We don't need to specify a parsing format
# for columns that are already of type datetime64.
datetime_format = None
else:
datetime_format = _guess_datetime_format(X_col)

index_to_format[col_idx] = datetime_format

Expand All @@ -332,7 +332,7 @@ def _is_column_datetime_parsable(X_col):
try:
if np.array_equal(X_col, X_col.astype(np.float64)):
return False
except ValueError:
except (ValueError, TypeError):
pass

np_dtypes_candidates = [np.object_, np.str_, np.datetime64]
Expand All @@ -349,7 +349,7 @@ def _is_column_datetime_parsable(X_col):
# At this stage, the format itself doesn't matter.
_ = pd.to_datetime(X_col, format=MIXED_FORMAT)
return True
except (pd.errors.ParserError, ValueError):
except (pd.errors.ParserError, ValueError, TypeError):
pass
return False

Expand Down Expand Up @@ -377,7 +377,10 @@ def _guess_datetime_format(X_col):
-------
datetime_format : str or None
"""
X_col = X_col.astype(np.object_)
# Passing numpy.str_ (i.e. dtype '<U10') to 'guess_datetime_format'
# raises a TypeError.
# We have to convert these to the object dtype first.
X_col = X_col.astype("object")
vfunc = np.vectorize(guess_datetime_format)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
Expand Down
2 changes: 2 additions & 0 deletions skrub/tests/test_datetime_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,8 @@ def test_datetime_encoder_invalid_params():
[1, 2],
np.array([1, 2]),
pd.Timestamp(2020, 1, 1),
np.array([pd.Timestamp(2020, 1, 1), "hello"]),
np.array(["2020-01-01", {"hello"}]),
np.array(["2020-01-01", "hello", "2020-01-02"]),
],
)
Expand Down

0 comments on commit 71a0ec3

Please sign in to comment.