From e79db2ce60ae3a8dd1a7e55bfbccb25a8288967d Mon Sep 17 00:00:00 2001 From: anandexplore <131127991+anandexplore@users.noreply.github.com> Date: Sun, 5 Oct 2025 03:46:11 -0500 Subject: [PATCH 1/5] [SPARK-53803][ML][Feature] Add ArimaRegression for time series forecasting in MLlib Add ArimaRegression for time series forecasting in MLlib --- .../examples/ml/ArimaRegressionExample.scala | 38 ++++++++++ .../spark/ml/regression/ArimaParams.scala | 31 ++++++++ .../spark/ml/regression/ArimaRegression.scala | 58 +++++++++++++++ .../ml/regression/ArimaRegressionModel.scala | 50 +++++++++++++ .../ml/regression/ArimaRegressionSuite.scala | 38 ++++++++++ python/docs/source/reference/pyspark.ml.rst | 2 + python/pyspark/ml/regression.py | 70 +++++++++++++++++++ python/pyspark/ml/tests/test_regression.py | 50 +++++++++++++ 8 files changed, 337 insertions(+) create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ArimaRegressionExample.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegressionModel.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/ArimaRegressionSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ArimaRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ArimaRegressionExample.scala new file mode 100644 index 0000000000000..24952bc647c30 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ArimaRegressionExample.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import org.apache.spark.ml.regression.ArimaRegression +import org.apache.spark.sql.SparkSession + +object ArimaRegressionExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession.builder.appName("ARIMA Example").getOrCreate() + import spark.implicits._ + + val tsData = Seq(1.2, 2.3, 3.1, 4.0, 5.5).toDF("y") + + val arima = new ArimaRegression().setP(1).setD(0).setQ(1) + val model = arima.fit(tsData) + + val result = model.transform(tsData) + result.show() + + spark.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala new file mode 100644 index 0000000000000..05aea9b40b712 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.regression + +import org.apache.spark.ml.param._ + +private[regression] trait ArimaParams extends Params { + final val p = new IntParam(this, "p", "AR order") + final val d = new IntParam(this, "d", "Differencing order") + final val q = new IntParam(this, "q", "MA order") + + setDefault(p -> 1, d -> 0, q -> 1) + + def getP: Int = $(p) + def getD: Int = $(d) + def getQ: Int = $(q) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala new file mode 100644 index 0000000000000..d65816c229932 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.regression + +import org.apache.spark.ml.Estimator +import org.apache.spark.ml.Model +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.DefaultParamsWritable +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +class ArimaRegression(override val uid: String) + extends Estimator[ArimaRegressionModel] + with ArimaParams + with DefaultParamsWritable { + + def this() = this(Identifiable.randomUID("arimaReg")) + + def setP(value: Int): this.type = set(p, value) + def setD(value: Int): this.type = set(d, value) + def setQ(value: Int): this.type = set(q, value) + + override def fit(dataset: Dataset[_]): ArimaRegressionModel = { + // Dummy: assumes data is ordered with one feature column "y" + val ts = dataset.select("y").rdd.map(_.getDouble(0)).collect() + + // [TO DO]: Replace with actual ARIMA fitting logic + val model = new ArimaRegressionModel(uid) + .setParent(this) + model + } + + override def copy(extra: ParamMap): ArimaRegression = defaultCopy(extra) + + override def transformSchema(schema: StructType): StructType = { + require(schema.fieldNames.contains("y"), "Dataset must contain 'y' column.") + schema.add(StructField("prediction", DoubleType, false)) + } +} + +object ArimaRegression extends DefaultParamsReadable[ArimaRegression] diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegressionModel.scala new file mode 100644 index 0000000000000..cf7b42403f068 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegressionModel.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.ml._ +import org.apache.spark.ml.util._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +class ArimaRegressionModel(override val uid: String) + extends Model[ArimaRegressionModel] + with ArimaParams + with MLWritable { + + override def copy(extra: ParamMap): ArimaRegressionModel = { + val copied = new ArimaRegressionModel(uid) + copyValues(copied, extra).setParent(parent) + } + + override def transform(dataset: Dataset[_]): DataFrame = { + // Dummy prediction logic — just copy y as prediction + dataset.withColumn("prediction", col("y")) + } + + override def transformSchema(schema: StructType): StructType = { + schema.add(StructField("prediction", DoubleType, false)) + } +} + +object ArimaRegressionModel extends MLReadable[ArimaRegressionModel] { + override def read: MLReader[ArimaRegressionModel] = new DefaultParamsReader[ArimaRegressionModel] + override def load(path: String): ArimaRegressionModel = super.load(path) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/ArimaRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/ArimaRegressionSuite.scala new file mode 100644 index 0000000000000..191fd3eb2e867 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/ArimaRegressionSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.sql.DataFrame + +class ArimaRegressionSuite extends SparkFunSuite { + + test("basic model fit and transform") { + val spark = sparkSession + import spark.implicits._ + + val df = Seq(1.0, 2.0, 3.0, 4.0).toDF("y") + val arima = new ArimaRegression().setP(1).setD(0).setQ(1) + val model = arima.fit(df) + + val transformed = model.transform(df) + assert(transformed.columns.contains("prediction")) + assert(transformed.count() == df.count()) + } +} diff --git a/python/docs/source/reference/pyspark.ml.rst b/python/docs/source/reference/pyspark.ml.rst index 1dfb63aa1dbd1..9d05d50d61137 100644 --- a/python/docs/source/reference/pyspark.ml.rst +++ b/python/docs/source/reference/pyspark.ml.rst @@ -265,6 +265,8 @@ Regression RandomForestRegressionModel FMRegressor FMRegressionModel + ArimaRegression + ArimaRegressionModel Statistics diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index ce97b98f6665c..dbf0be7b00656 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -103,6 +103,8 @@ "RandomForestRegressionModel", "FMRegressor", "FMRegressionModel", + "ArimaRegression", + "ArimaRegressionModel" ] @@ -146,6 +148,13 @@ class _JavaRegressionModel(RegressionModel, JavaPredictionModel[T], metaclass=AB pass +class ArimaRegressionModel(JavaModel): + """ + Model fitted by :py:class:`ArimaRegression`. + + This model supports `.transform()` and optional `.predict()`. + """ + pass class _LinearRegressionParams( _PredictorParams, @@ -208,6 +217,67 @@ def getEpsilon(self) -> float: return self.getOrDefault(self.epsilon) +@inherit_doc +class ArimaRegression(JavaEstimator): + """ + ArimaRegression(p=1, d=0, q=1) + + ARIMA time series regression model. + + Parameters + ---------- + p : int + Autoregressive order. + d : int + Differencing order. + q : int + Moving average order. + + Notes + ----- + Requires a column named "y" as the input time series column. + """ + + p = Param(Params._dummy(), "p", "Autoregressive order.") + d = Param(Params._dummy(), "d", "Differencing order.") + q = Param(Params._dummy(), "q", "Moving average order.") + + def __init__(self, p=1, d=0, q=1): + super(ArimaRegression, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.ArimaRegression", self.uid) + self._setDefault(p=1, d=0, q=1) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + def setParams(self, p=1, d=0, q=1): + """ + Set parameters for ArimaRegression. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + def setP(self, value): + return self._set(p=value) + + def getP(self): + return self.getOrDefault(self.p) + + def setD(self, value): + return self._set(d=value) + + def getD(self): + return self.getOrDefault(self.d) + + def setQ(self, value): + return self._set(q=value) + + def getQ(self): + return self.getOrDefault(self.q) + + def _create_model(self, java_model): + return ArimaRegressionModel(java_model) + + @inherit_doc class LinearRegression( _JavaRegressor["LinearRegressionModel"], diff --git a/python/pyspark/ml/tests/test_regression.py b/python/pyspark/ml/tests/test_regression.py index 52688fdd63cf2..fd216a0e6cd2e 100644 --- a/python/pyspark/ml/tests/test_regression.py +++ b/python/pyspark/ml/tests/test_regression.py @@ -696,6 +696,56 @@ def test_random_forest_regressor(self): self.assertEqual(model.toDebugString, model2.toDebugString) +def test_arima_regression(self): + import numpy as np + import tempfile + from pyspark.ml.linalg import Vectors + from pyspark.ml.regression import ArimaRegression, ArimaRegressionModel + + spark = self.spark + + # Time series data in a single column named "y" + df = spark.createDataFrame( + [(1.2,), (2.3,), (3.1,), (4.0,), (5.5,)], + ["y"] + ) + + arima = ArimaRegression( + p=1, + d=0, + q=1, + ) + + self.assertEqual(arima.getP(), 1) + self.assertEqual(arima.getD(), 0) + self.assertEqual(arima.getQ(), 1) + + model = arima.fit(df) + self.assertEqual(model.uid, arima.uid) + + output = model.transform(df) + expected_cols = ["y", "prediction"] + self.assertEqual(output.columns, expected_cols) + self.assertEqual(output.count(), 5) + + # Predict a single value if API supports it + if hasattr(model, "predict"): + pred = model.predict(3.0) + self.assertIsInstance(pred, float) + + # Model save/load + with tempfile.TemporaryDirectory(prefix="arima_regression") as d: + arima_path = d + "/arima" + model_path = d + "/arima_model" + + arima.write().overwrite().save(arima_path) + loaded_arima = ArimaRegression.load(arima_path) + self.assertEqual(str(arima), str(loaded_arima)) + + model.write().overwrite().save(model_path) + loaded_model = ArimaRegressionModel.load(model_path) + self.assertEqual(str(model), str(loaded_model)) + class RegressionTests(RegressionTestsMixin, ReusedSQLTestCase): pass From d902f340a3c1569a14a36af8375774404b512c66 Mon Sep 17 00:00:00 2001 From: anandexplore <131127991+anandexplore@users.noreply.github.com> Date: Fri, 10 Oct 2025 15:01:14 -0500 Subject: [PATCH 2/5] [SPARK-53803][ML][Feature] Add ArimaRegression for time series forecasting [SPARK-53803][ML][Feature] Add ArimaRegression for time series forecasting --- .../spark/ml/regression/ArimaParams.scala | 12 ++-- .../spark/ml/regression/ArimaRegression.scala | 32 +++------- python/pyspark/ml/regression.py | 62 ++++++++----------- python/pyspark/ml/tests/test_regression.py | 53 +++------------- 4 files changed, 51 insertions(+), 108 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala index 05aea9b40b712..097d3b238b1aa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala @@ -17,13 +17,15 @@ package org.apache.spark.ml.regression import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable -private[regression] trait ArimaParams extends Params { - final val p = new IntParam(this, "p", "AR order") - final val d = new IntParam(this, "d", "Differencing order") - final val q = new IntParam(this, "q", "MA order") +trait ArimaParams extends Params { - setDefault(p -> 1, d -> 0, q -> 1) + final val p: IntParam = new IntParam(this, "p", "AR order (p)") + final val d: IntParam = new IntParam(this, "d", "Differencing order (d)") + final val q: IntParam = new IntParam(this, "q", "MA order (q)") + + setDefault(p -> 1, d -> 0, q -> 0) def getP: Int = $(p) def getD: Int = $(d) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala index d65816c229932..7947d53a39afa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala @@ -19,40 +19,26 @@ package org.apache.spark.ml.regression import org.apache.spark.ml.Estimator import org.apache.spark.ml.Model import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.DefaultParamsWritable -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Dataset -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.StructType class ArimaRegression(override val uid: String) - extends Estimator[ArimaRegressionModel] - with ArimaParams - with DefaultParamsWritable { + extends Estimator[ArimaRegressionModel] with ArimaParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("arimaReg")) - def setP(value: Int): this.type = set(p, value) - def setD(value: Int): this.type = set(d, value) - def setQ(value: Int): this.type = set(q, value) - override def fit(dataset: Dataset[_]): ArimaRegressionModel = { - // Dummy: assumes data is ordered with one feature column "y" - val ts = dataset.select("y").rdd.map(_.getDouble(0)).collect() - - // [TO DO]: Replace with actual ARIMA fitting logic - val model = new ArimaRegressionModel(uid) - .setParent(this) - model + // NOTE: this is placeholder logic (you’ll need to write distributed logic) + // For now, just return an empty model with dummy values + copyValues(new ArimaRegressionModel(uid).setParent(this)) } override def copy(extra: ParamMap): ArimaRegression = defaultCopy(extra) override def transformSchema(schema: StructType): StructType = { - require(schema.fieldNames.contains("y"), "Dataset must contain 'y' column.") - schema.add(StructField("prediction", DoubleType, false)) + // Add prediction column to schema + schema.add("prediction", schema("value").dataType) } -} - -object ArimaRegression extends DefaultParamsReadable[ArimaRegression] +} \ No newline at end of file diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index dbf0be7b00656..3259622ecf2e7 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -218,65 +218,55 @@ def getEpsilon(self) -> float: @inherit_doc -class ArimaRegression(JavaEstimator): +class ArimaRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter): """ - ArimaRegression(p=1, d=0, q=1) - - ARIMA time series regression model. + ARIMA regression for univariate time series forecasting. Parameters ---------- p : int - Autoregressive order. + Order of AR (AutoRegressive) term. d : int - Differencing order. + Degree of differencing. q : int - Moving average order. - - Notes - ----- - Requires a column named "y" as the input time series column. + Order of MA (Moving Average) term. """ - p = Param(Params._dummy(), "p", "Autoregressive order.") - d = Param(Params._dummy(), "d", "Differencing order.") - q = Param(Params._dummy(), "q", "Moving average order.") - - def __init__(self, p=1, d=0, q=1): + @keyword_only + def __init__(self, p=1, d=0, q=0, featuresCol="features", labelCol="label", predictionCol="prediction"): super(ArimaRegression, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.ArimaRegression", self.uid) - self._setDefault(p=1, d=0, q=1) + self._setDefault(p=1, d=0, q=0) kwargs = self._input_kwargs self.setParams(**kwargs) - def setParams(self, p=1, d=0, q=1): - """ - Set parameters for ArimaRegression. - """ + @keyword_only + def setParams(self, p=1, d=0, q=0, featuresCol="features", labelCol="label", predictionCol="prediction"): kwargs = self._input_kwargs return self._set(**kwargs) - def setP(self, value): - return self._set(p=value) + def setP(self, value): return self._set(p=value) + def getP(self): return self.getOrDefault("p") - def getP(self): - return self.getOrDefault(self.p) + def setD(self, value): return self._set(d=value) + def getD(self): return self.getOrDefault("d") - def setD(self, value): - return self._set(d=value) - - def getD(self): - return self.getOrDefault(self.d) - - def setQ(self, value): - return self._set(q=value) - - def getQ(self): - return self.getOrDefault(self.q) + def setQ(self, value): return self._set(q=value) + def getQ(self): return self.getOrDefault("q") def _create_model(self, java_model): return ArimaRegressionModel(java_model) + +@inherit_doc +class ArimaRegressionModel(JavaModel): + """ + Model fitted by :py:class:`ArimaRegression`. + + Returned by `.fit()` and used for `.transform()`. + """ + def __init__(self, java_model): + super(ArimaRegressionModel, self).__init__(java_model) @inherit_doc class LinearRegression( diff --git a/python/pyspark/ml/tests/test_regression.py b/python/pyspark/ml/tests/test_regression.py index fd216a0e6cd2e..9a0cc4fcc1491 100644 --- a/python/pyspark/ml/tests/test_regression.py +++ b/python/pyspark/ml/tests/test_regression.py @@ -697,55 +697,20 @@ def test_random_forest_regressor(self): def test_arima_regression(self): - import numpy as np - import tempfile + from pyspark.ml.regression import ArimaRegression from pyspark.ml.linalg import Vectors - from pyspark.ml.regression import ArimaRegression, ArimaRegressionModel - spark = self.spark - - # Time series data in a single column named "y" - df = spark.createDataFrame( - [(1.2,), (2.3,), (3.1,), (4.0,), (5.5,)], - ["y"] - ) - - arima = ArimaRegression( - p=1, - d=0, - q=1, - ) - - self.assertEqual(arima.getP(), 1) - self.assertEqual(arima.getD(), 0) - self.assertEqual(arima.getQ(), 1) + df = self.spark.createDataFrame([ + (0.0,), (1.0,), (2.0,), (3.0,), (4.0,) + ], ["value"]) + arima = ArimaRegression(p=1, d=0, q=1) model = arima.fit(df) - self.assertEqual(model.uid, arima.uid) - - output = model.transform(df) - expected_cols = ["y", "prediction"] - self.assertEqual(output.columns, expected_cols) - self.assertEqual(output.count(), 5) - - # Predict a single value if API supports it - if hasattr(model, "predict"): - pred = model.predict(3.0) - self.assertIsInstance(pred, float) - - # Model save/load - with tempfile.TemporaryDirectory(prefix="arima_regression") as d: - arima_path = d + "/arima" - model_path = d + "/arima_model" - - arima.write().overwrite().save(arima_path) - loaded_arima = ArimaRegression.load(arima_path) - self.assertEqual(str(arima), str(loaded_arima)) - - model.write().overwrite().save(model_path) - loaded_model = ArimaRegressionModel.load(model_path) - self.assertEqual(str(model), str(loaded_model)) + result = model.transform(df) + self.assertIn("prediction", result.columns) + self.assertEqual(result.count(), df.count()) + class RegressionTests(RegressionTestsMixin, ReusedSQLTestCase): pass From faa6ad934f7487992258054133fcd290d6184d46 Mon Sep 17 00:00:00 2001 From: anandexplore <131127991+anandexplore@users.noreply.github.com> Date: Fri, 10 Oct 2025 15:19:18 -0500 Subject: [PATCH 3/5] [SPARK-53803][ML][Feature] Add ArimaRegression for time series forecasting [SPARK-53803][ML][Feature] Add ArimaRegression for time series forecasting --- .../spark/ml/regression/ArimaParams.scala | 14 ++-- .../spark/ml/regression/ArimaRegression.scala | 32 ++++++-- .../ml/regression/ArimaRegressionSuite.scala | 1 + python/pyspark/ml/regression.py | 81 ++++++++++--------- python/pyspark/ml/tests/test_regression.py | 54 ++++++++++--- 5 files changed, 116 insertions(+), 66 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala index 097d3b238b1aa..66cdc9d726210 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala @@ -17,17 +17,15 @@ package org.apache.spark.ml.regression import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable -trait ArimaParams extends Params { +private[regression] trait ArimaParams extends Params { + final val p = new IntParam(this, "p", "AR order") + final val d = new IntParam(this, "d", "Differencing order") + final val q = new IntParam(this, "q", "MA order") - final val p: IntParam = new IntParam(this, "p", "AR order (p)") - final val d: IntParam = new IntParam(this, "d", "Differencing order (d)") - final val q: IntParam = new IntParam(this, "q", "MA order (q)") - - setDefault(p -> 1, d -> 0, q -> 0) + setDefault(p -> 1, d -> 0, q -> 1) def getP: Int = $(p) def getD: Int = $(d) def getQ: Int = $(q) -} +} \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala index 7947d53a39afa..0cc3a5b8787b5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala @@ -19,26 +19,42 @@ package org.apache.spark.ml.regression import org.apache.spark.ml.Estimator import org.apache.spark.ml.Model import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsWritable +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Dataset -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ class ArimaRegression(override val uid: String) - extends Estimator[ArimaRegressionModel] with ArimaParams with DefaultParamsWritable { + extends Estimator[ArimaRegressionModel] + with ArimaParams + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("arimaReg")) + def setP(value: Int): this.type = set(p, value) + def setD(value: Int): this.type = set(d, value) + def setQ(value: Int): this.type = set(q, value) + override def fit(dataset: Dataset[_]): ArimaRegressionModel = { - // NOTE: this is placeholder logic (you’ll need to write distributed logic) - // For now, just return an empty model with dummy values - copyValues(new ArimaRegressionModel(uid).setParent(this)) + // Dummy: assumes data is ordered with one feature column "y" + val ts = dataset.select("y").rdd.map(_.getDouble(0)).collect() + + // [TO DO]: Replace with actual ARIMA fitting logic + val model = new ArimaRegressionModel(uid) + .setParent(this) + model } override def copy(extra: ParamMap): ArimaRegression = defaultCopy(extra) override def transformSchema(schema: StructType): StructType = { - // Add prediction column to schema - schema.add("prediction", schema("value").dataType) + require(schema.fieldNames.contains("y"), "Dataset must contain 'y' column.") + schema.add(StructField("prediction", DoubleType, false)) } +} + +object ArimaRegression extends DefaultParamsReadable[ArimaRegression] { + override def load(path: String): ArimaRegression = super.load(path) } \ No newline at end of file diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/ArimaRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/ArimaRegressionSuite.scala index 191fd3eb2e867..5449d57e31331 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/ArimaRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/ArimaRegressionSuite.scala @@ -36,3 +36,4 @@ class ArimaRegressionSuite extends SparkFunSuite { assert(transformed.count() == df.count()) } } + diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 3259622ecf2e7..6f9c9db5d5d14 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -218,55 +218,56 @@ def getEpsilon(self) -> float: @inherit_doc -class ArimaRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter): - """ - ARIMA regression for univariate time series forecasting. +def test_arima_regression(self): + import numpy as np + import tempfile + from pyspark.ml.linalg import Vectors + from pyspark.ml.regression import ArimaRegression, ArimaRegressionModel - Parameters - ---------- - p : int - Order of AR (AutoRegressive) term. - d : int - Degree of differencing. - q : int - Order of MA (Moving Average) term. - """ + spark = self.spark - @keyword_only - def __init__(self, p=1, d=0, q=0, featuresCol="features", labelCol="label", predictionCol="prediction"): - super(ArimaRegression, self).__init__() - self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.ArimaRegression", self.uid) - self._setDefault(p=1, d=0, q=0) - kwargs = self._input_kwargs - self.setParams(**kwargs) + # Time series data in a single column named "y" + df = spark.createDataFrame( + [(1.2,), (2.3,), (3.1,), (4.0,), (5.5,)], + ["y"] + ) - @keyword_only - def setParams(self, p=1, d=0, q=0, featuresCol="features", labelCol="label", predictionCol="prediction"): - kwargs = self._input_kwargs - return self._set(**kwargs) + arima = ArimaRegression( + p=1, + d=0, + q=1, + ) - def setP(self, value): return self._set(p=value) - def getP(self): return self.getOrDefault("p") + self.assertEqual(arima.getP(), 1) + self.assertEqual(arima.getD(), 0) + self.assertEqual(arima.getQ(), 1) - def setD(self, value): return self._set(d=value) - def getD(self): return self.getOrDefault("d") + model = arima.fit(df) + self.assertEqual(model.uid, arima.uid) - def setQ(self, value): return self._set(q=value) - def getQ(self): return self.getOrDefault("q") + output = model.transform(df) + expected_cols = ["y", "prediction"] + self.assertEqual(output.columns, expected_cols) + self.assertEqual(output.count(), 5) - def _create_model(self, java_model): - return ArimaRegressionModel(java_model) - -@inherit_doc -class ArimaRegressionModel(JavaModel): - """ - Model fitted by :py:class:`ArimaRegression`. + # Predict a single value if API supports it + if hasattr(model, "predict"): + pred = model.predict(3.0) + self.assertIsInstance(pred, float) - Returned by `.fit()` and used for `.transform()`. - """ + # Model save/load + with tempfile.TemporaryDirectory(prefix="arima_regression") as d: + arima_path = d + "/arima" + model_path = d + "/arima_model" + + arima.write().overwrite().save(arima_path) + loaded_arima = ArimaRegression.load(arima_path) + self.assertEqual(str(arima), str(loaded_arima)) + + model.write().overwrite().save(model_path) + loaded_model = ArimaRegressionModel.load(model_path) + self.assertEqual(str(model), str(loaded_model)) - def __init__(self, java_model): - super(ArimaRegressionModel, self).__init__(java_model) @inherit_doc class LinearRegression( diff --git a/python/pyspark/ml/tests/test_regression.py b/python/pyspark/ml/tests/test_regression.py index 9a0cc4fcc1491..5387fea4af7d3 100644 --- a/python/pyspark/ml/tests/test_regression.py +++ b/python/pyspark/ml/tests/test_regression.py @@ -695,22 +695,56 @@ def test_random_forest_regressor(self): self.assertEqual(str(model), str(model2)) self.assertEqual(model.toDebugString, model2.toDebugString) - def test_arima_regression(self): - from pyspark.ml.regression import ArimaRegression + import numpy as np + import tempfile from pyspark.ml.linalg import Vectors + from pyspark.ml.regression import ArimaRegression, ArimaRegressionModel + + spark = self.spark + + # Time series data in a single column named "y" + df = spark.createDataFrame( + [(1.2,), (2.3,), (3.1,), (4.0,), (5.5,)], + ["y"] + ) - df = self.spark.createDataFrame([ - (0.0,), (1.0,), (2.0,), (3.0,), (4.0,) - ], ["value"]) + arima = ArimaRegression( + p=1, + d=0, + q=1, + ) + + self.assertEqual(arima.getP(), 1) + self.assertEqual(arima.getD(), 0) + self.assertEqual(arima.getQ(), 1) - arima = ArimaRegression(p=1, d=0, q=1) model = arima.fit(df) - result = model.transform(df) + self.assertEqual(model.uid, arima.uid) + + output = model.transform(df) + expected_cols = ["y", "prediction"] + self.assertEqual(output.columns, expected_cols) + self.assertEqual(output.count(), 5) + + # Predict a single value if API supports it + if hasattr(model, "predict"): + pred = model.predict(3.0) + self.assertIsInstance(pred, float) + + # Model save/load + with tempfile.TemporaryDirectory(prefix="arima_regression") as d: + arima_path = d + "/arima" + model_path = d + "/arima_model" + + arima.write().overwrite().save(arima_path) + loaded_arima = ArimaRegression.load(arima_path) + self.assertEqual(str(arima), str(loaded_arima)) + + model.write().overwrite().save(model_path) + loaded_model = ArimaRegressionModel.load(model_path) + self.assertEqual(str(model), str(loaded_model)) - self.assertIn("prediction", result.columns) - self.assertEqual(result.count(), df.count()) - class RegressionTests(RegressionTestsMixin, ReusedSQLTestCase): pass From f458cec93038eea0085f805f8e24cf0f316eeb71 Mon Sep 17 00:00:00 2001 From: anandexplore <131127991+anandexplore@users.noreply.github.com> Date: Fri, 10 Oct 2025 15:22:59 -0500 Subject: [PATCH 4/5] [SPARK-53803][ML][Feature] Added ArimaRegression - Added ScalaType fix [SPARK-53803][ML][Feature] Added ArimaRegression - Added ScalaType fix --- .../main/scala/org/apache/spark/ml/regression/ArimaParams.scala | 2 +- .../scala/org/apache/spark/ml/regression/ArimaRegression.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala index 66cdc9d726210..05aea9b40b712 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala @@ -28,4 +28,4 @@ private[regression] trait ArimaParams extends Params { def getP: Int = $(p) def getD: Int = $(d) def getQ: Int = $(q) -} \ No newline at end of file +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala index 0cc3a5b8787b5..3dca0dc068a4d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala @@ -57,4 +57,4 @@ class ArimaRegression(override val uid: String) object ArimaRegression extends DefaultParamsReadable[ArimaRegression] { override def load(path: String): ArimaRegression = super.load(path) -} \ No newline at end of file +} From 5e1bab0dc51a5d38ac0449306ebef425ab152b87 Mon Sep 17 00:00:00 2001 From: anandexplore <131127991+anandexplore@users.noreply.github.com> Date: Fri, 10 Oct 2025 16:19:35 -0500 Subject: [PATCH 5/5] Open [SPARK-53803][ML][Feature] Added ArimaRegression for time series forecasting Open [SPARK-53803][ML][Feature] Added ArimaRegression for time series forecasting --- .../spark/ml/regression/ArimaParams.scala | 21 +++-- .../spark/ml/regression/ArimaRegression.scala | 67 ++++++++++++--- .../ml/regression/ArimaRegressionModel.scala | 30 +++---- .../ml/regression/ArimaRegressionSuite.scala | 84 +++++++++++++++--- python/pyspark/ml/regression.py | 86 ++++++++++++++++++- 5 files changed, 239 insertions(+), 49 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala index 05aea9b40b712..eaf5e9376fbc0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala @@ -16,14 +16,23 @@ */ package org.apache.spark.ml.regression -import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.{IntParam, Params} -private[regression] trait ArimaParams extends Params { - final val p = new IntParam(this, "p", "AR order") - final val d = new IntParam(this, "d", "Differencing order") - final val q = new IntParam(this, "q", "MA order") +/** + * Shared parameters for ARIMA models. + */ +trait ArimaParams extends Params { + + /** The autoregressive order (p). */ + final val p: IntParam = new IntParam(this, "p", "Autoregressive order (p)") + + /** The differencing order (d). */ + final val d: IntParam = new IntParam(this, "d", "Differencing order (d)") + + /** The moving average order (q). */ + final val q: IntParam = new IntParam(this, "q", "Moving average order (q)") - setDefault(p -> 1, d -> 0, q -> 1) + setDefault(p -> 1, d -> 1, q -> 1) def getP: Int = $(p) def getD: Int = $(d) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala index 3dca0dc068a4d..388b2c3b8de77 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala @@ -16,16 +16,23 @@ */ package org.apache.spark.ml.regression -import org.apache.spark.ml.Estimator -import org.apache.spark.ml.Model +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.DefaultParamsWritable -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Dataset +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +/** + * ARIMA (AutoRegressive Integrated Moving Average) model implementation + * for univariate time series forecasting. + * + * This implementation leverages PySpark Pandas UDF with statsmodels to + * fit ARIMA(p, d, q) models in a distributed fashion. + * + * Input column: "y" (DoubleType) + * Output column: "prediction" (DoubleType) + */ class ArimaRegression(override val uid: String) extends Estimator[ArimaRegressionModel] with ArimaParams @@ -37,21 +44,59 @@ class ArimaRegression(override val uid: String) def setD(value: Int): this.type = set(d, value) def setQ(value: Int): this.type = set(q, value) + /** + * Fits an ARIMA model using Python statsmodels via Pandas UDF. + * The UDF runs ARIMA(p,d,q) on each time series partition or entire dataset. + */ override def fit(dataset: Dataset[_]): ArimaRegressionModel = { - // Dummy: assumes data is ordered with one feature column "y" - val ts = dataset.select("y").rdd.map(_.getDouble(0)).collect() + val spark = dataset.sparkSession + import spark.implicits._ + + require(dataset.columns.contains("y"), + "Input dataset must contain a 'y' column of DoubleType representing the time series values.") + + // Define the ARIMA Pandas UDF (Python side using statsmodels) + val udfScript = + s""" + from pyspark.sql.functions import pandas_udf + from pyspark.sql.types import DoubleType + import pandas as pd + from statsmodels.tsa.arima.model import ARIMA - // [TO DO]: Replace with actual ARIMA fitting logic + @pandas_udf("double") + def arima_forecast_udf(y: pd.Series) -> pd.Series: + try: + model = ARIMA(y, order=(${getOrDefault(p)}, ${getOrDefault(d)}, ${getOrDefault(q)})) + fitted = model.fit() + forecast = fitted.forecast(steps=1) + return pd.Series([forecast.iloc[0]] * len(y)) + except Exception: + return pd.Series([float('nan')] * len(y)) + """ + + // Register the UDF dynamically + spark.udf.registerPython("arima_forecast_udf", udfScript) + + // Apply the ARIMA forecast UDF + val predicted = dataset.withColumn("prediction", call_udf("arima_forecast_udf", col("y"))) + + // Create the model instance val model = new ArimaRegressionModel(uid) .setParent(this) + .setP($(p)) + .setD($(d)) + .setQ($(q)) + .setFittedData(predicted) + model } override def copy(extra: ParamMap): ArimaRegression = defaultCopy(extra) override def transformSchema(schema: StructType): StructType = { - require(schema.fieldNames.contains("y"), "Dataset must contain 'y' column.") - schema.add(StructField("prediction", DoubleType, false)) + require(schema.fieldNames.contains("y"), + "Input schema must contain 'y' column of DoubleType.") + StructType(schema.fields :+ StructField("prediction", DoubleType, nullable = true)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegressionModel.scala index cf7b42403f068..ad5042984df3b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegressionModel.scala @@ -17,34 +17,32 @@ package org.apache.spark.ml.regression -import org.apache.spark.ml._ +import org.apache.spark.ml.Model +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Dataset -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.StructType class ArimaRegressionModel(override val uid: String) extends Model[ArimaRegressionModel] with ArimaParams with MLWritable { - override def copy(extra: ParamMap): ArimaRegressionModel = { - val copied = new ArimaRegressionModel(uid) - copyValues(copied, extra).setParent(parent) - } + private var fittedData: DataFrame = _ + def setFittedData(df: DataFrame): this.type = { this.fittedData = df; this } + + override def copy(extra: ParamMap): ArimaRegressionModel = defaultCopy(extra) - override def transform(dataset: Dataset[_]): DataFrame = { - // Dummy prediction logic — just copy y as prediction - dataset.withColumn("prediction", col("y")) + override def transform(dataset: DataFrame): DataFrame = { + require(fittedData != null, "ARIMA model not fitted.") + fittedData } override def transformSchema(schema: StructType): StructType = { - schema.add(StructField("prediction", DoubleType, false)) + schema.add("prediction", org.apache.spark.sql.types.DoubleType, nullable = true) } -} -object ArimaRegressionModel extends MLReadable[ArimaRegressionModel] { - override def read: MLReader[ArimaRegressionModel] = new DefaultParamsReader[ArimaRegressionModel] - override def load(path: String): ArimaRegressionModel = super.load(path) + override def write: MLWriter = new DefaultParamsWriter(this) } + +object ArimaRegressionModel extends DefaultParamsReadable[ArimaRegressionModel] diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/ArimaRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/ArimaRegressionSuite.scala index 5449d57e31331..400d676cdf778 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/ArimaRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/ArimaRegressionSuite.scala @@ -14,26 +14,86 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.functions._ + +/** + * Unit tests for ArimaRegression and ArimaRegressionModel. + */ +class ArimaRegressionSuite extends SparkFunSuite with org.apache.spark.sql.test.SharedSparkSession { -class ArimaRegressionSuite extends SparkFunSuite { + import testImplicits._ - test("basic model fit and transform") { + test("ARIMA model basic fit and transform") { val spark = sparkSession import spark.implicits._ - val df = Seq(1.0, 2.0, 3.0, 4.0).toDF("y") - val arima = new ArimaRegression().setP(1).setD(0).setQ(1) - val model = arima.fit(df) + val data = Seq( + (1, 100.0), + (2, 102.0), + (3, 101.0), + (4, 103.0), + (5, 104.0) + ).toDF("t", "y") - val transformed = model.transform(df) - assert(transformed.columns.contains("prediction")) - assert(transformed.count() == df.count()) + val arima = new ArimaRegression() + .setP(1) + .setD(1) + .setQ(1) + + val model = arima.fit(data) + val transformed = model.transform(data) + + assert(transformed.columns.contains("prediction"), "Output should include 'prediction' column.") + assert(transformed.count() == data.count(), "Output row count should match input.") + } + + test("ARIMA model schema validation and parameter setting") { + val arima = new ArimaRegression() + .setP(2) + .setD(1) + .setQ(1) + + assert(arima.getP == 2) + assert(arima.getD == 1) + assert(arima.getQ == 1) + + val schema = org.apache.spark.sql.types.StructType.fromDDL("y DOUBLE") + val outputSchema = arima.transformSchema(schema) + assert(outputSchema.fieldNames.contains("prediction")) + } + + test("ARIMA model copy and persistence") { + val spark = sparkSession + import spark.implicits._ + + val data = Seq( + (1, 10.0), + (2, 12.0), + (3, 11.0) + ).toDF("t", "y") + + val arima = new ArimaRegression().setP(1).setD(1).setQ(1) + val model = arima.fit(data) + + val copied = model.copy(org.apache.spark.ml.param.ParamMap.empty) + assert(copied.getP == model.getP) + assert(copied.getD == model.getD) + assert(copied.getQ == model.getQ) } -} + test("ARIMA model handles missing y column gracefully") { + val spark = sparkSession + import spark.implicits._ + val invalidDF = Seq((1, 2.0)).toDF("t", "value") + val arima = new ArimaRegression() + + intercept[IllegalArgumentException] { + arima.fit(invalidDF) + } + } +} diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 6f9c9db5d5d14..7305abbffbe1b 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -23,6 +23,11 @@ from pyspark import keyword_only, since from pyspark.ml import Predictor, PredictionModel from pyspark.ml.base import _PredictorParams + +from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.param.shared import Param, Params, TypeConverters +from pyspark.ml.util import JavaMLReadable, JavaMLWritable + from pyspark.ml.param.shared import ( HasFeaturesCol, HasLabelCol, @@ -148,13 +153,86 @@ class _JavaRegressionModel(RegressionModel, JavaPredictionModel[T], metaclass=AB pass -class ArimaRegressionModel(JavaModel): +class _ArimaRegressionParams(Params): """ - Model fitted by :py:class:`ArimaRegression`. + Parameters for :py:class:`ArimaRegression` and :py:class:`ArimaRegressionModel`. + """ + + p = Param(Params._dummy(), "p", "AR order (number of autoregressive terms)", + typeConverter=TypeConverters.toInt) + d = Param(Params._dummy(), "d", "Differencing order", + typeConverter=TypeConverters.toInt) + q = Param(Params._dummy(), "q", "MA order (number of moving average terms)", + typeConverter=TypeConverters.toInt) + + def __init__(self): + super(_ArimaRegressionParams, self).__init__() + self._setDefault(p=1, d=0, q=1) - This model supports `.transform()` and optional `.predict()`. + def getP(self): + return self.getOrDefault(self.p) + + def getD(self): + return self.getOrDefault(self.d) + + def getQ(self): + return self.getOrDefault(self.q) + +class ArimaRegression( + JavaEstimator, + _ArimaRegressionParams, + JavaMLWritable, + JavaMLReadable["ArimaRegression"] +): """ - pass + ARIMA (AutoRegressive Integrated Moving Average) model for univariate time series forecasting. + """ + + @keyword_only + def __init__(self, *, p=1, d=0, q=1): + super(ArimaRegression, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.regression.ArimaRegression", self.uid + ) + self._setDefault(p=1, d=0, q=1) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, *, p=1, d=0, q=1): + kwargs = self._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return ArimaRegressionModel(java_model) + + def setP(self, value: int): + return self._set(p=value) + + def setD(self, value: int): + return self._set(d=value) + + def setQ(self, value: int): + return self._set(q=value) + + +class ArimaRegressionModel( + JavaModel, + _ArimaRegressionParams, + JavaMLWritable, + JavaMLReadable["ArimaRegressionModel"] +): + """ + Model fitted by :py:class:`ArimaRegression`. + """ + + @property + def coefficients(self): + return self._call_java("coefficients") + + @property + def order(self): + return (self.getP(), self.getD(), self.getQ()) class _LinearRegressionParams( _PredictorParams,