Skip to content

Commit

Permalink
fix doctests
Browse files Browse the repository at this point in the history
  • Loading branch information
ArturNiederfahrenhorst committed Oct 23, 2024
1 parent 54660f3 commit bab3632
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4278,7 +4278,7 @@ def to_tf(
If your model accepts additional metadata aside from features and label, specify a single additional column or a list of additional columns.
A common use case is to include sample weights in the data samples and train a ``tf.keras.Model`` with ``tf.keras.Model.fit``.
>>> ds = ds.add_column("sample weights", lambda df: 1)
>>> ds = ds.add_column("sample weights", lambda x: [1] * x.num_rows)
>>> ds.to_tf(feature_columns="features", label_columns="target", additional_columns="sample weights")
<_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float64, name='features'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'), TensorSpec(shape=(None,), dtype=tf.int64, name='sample weights'))>
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ def to_tf(
If your model accepts additional metadata aside from features and label, specify a single additional column or a list of additional columns.
A common use case is to include sample weights in the data samples and train a ``tf.keras.Model`` with ``tf.keras.Model.fit``.
>>> ds = ds.add_column("sample weights", lambda df: 1)
>>> ds = ds.add_column("sample weights", lambda x: [1] * x.num_rows)
>>> it = ds.iterator()
>>> it.to_tf(feature_columns="sepal length (cm)", label_columns="target", additional_columns="sample weights")
<_OptionsDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'), TensorSpec(shape=(None,), dtype=tf.int64, name='sample weights'))>
Expand Down

0 comments on commit bab3632

Please sign in to comment.