From ddfae0ef5eda315346be700b20390025662a3782 Mon Sep 17 00:00:00 2001 From: Ewan Date: Tue, 10 Dec 2024 11:41:52 +0800 Subject: [PATCH 1/3] [xgb] Fix xgb intern to close replaced array --- .../java/ai/djl/ml/xgboost/XgbNDArray.java | 33 +++++++++++++++---- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java index 69eb6914b0e..62f036469e9 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java @@ -88,14 +88,35 @@ public ByteBuffer toByteBuffer(boolean tryDirect) { /** {@inheritDoc} */ @Override public void intern(NDArray replaced) { - if (handle != null && handle.get() != 0L) { - long pointer = handle.getAndSet(0L); - JniUtils.deleteDMatrix(pointer); + if (replaced == null) { + throw new IllegalArgumentException("The replaced NDArray cannot be null."); + } + if (!(replaced instanceof XgbNDArray)) { + throw new IllegalArgumentException("The replaced NDArray must be an instance of XgbNDArray."); } XgbNDArray array = (XgbNDArray) replaced; - data = array.data; - handle = array.handle; - format = array.format; + + synchronized (this) { + if (handle != null && handle.get() != 0L) { + long pointer = handle.getAndSet(0L); + JniUtils.deleteDMatrix(pointer); + } + + data = array.data; + format = array.format; + + if (array.handle != null && array.handle.get() != 0L) { + if (handle == null) { + handle = new AtomicLong(); + } + handle.set(array.handle.getAndSet(0L)); + } + } + + array.data = null; + array.handle = null; + array.format = null; + array.close(); } /** {@inheritDoc} */ From 4b9d4aeeec1f840d1dc277d1d8c331bd6b9e91db Mon Sep 17 00:00:00 2001 From: Ewan Date: Tue, 10 Dec 2024 14:50:53 +0800 Subject: [PATCH 2/3] [xgb] Fix xgb intern to close replaced array --- .../ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java index 62f036469e9..8282f2c5c60 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java @@ -92,7 +92,8 @@ public void intern(NDArray replaced) { throw new IllegalArgumentException("The replaced NDArray cannot be null."); } if (!(replaced instanceof XgbNDArray)) { - throw new IllegalArgumentException("The replaced NDArray must be an instance of XgbNDArray."); + throw new IllegalArgumentException( + "The replaced NDArray must be an instance of XgbNDArray."); } XgbNDArray array = (XgbNDArray) replaced; From 5ce5c1db1abbce3e5a3825bd85a0d83e2f1d47b1 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Fri, 3 Jan 2025 10:17:36 -0800 Subject: [PATCH 3/3] Review changes --- .../java/ai/djl/ndarray/NDArrayAdapter.java | 2 +- .../java/ai/djl/ml/xgboost/XgbNDArray.java | 35 ++++++++----------- 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index cceb5efd494..b27cd0ac1d6 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -40,7 +40,7 @@ public abstract class NDArrayAdapter implements NDArray { protected NDManager manager; protected NDManager alternativeManager; - private NDArray alternativeArray; + protected NDArray alternativeArray; protected Shape shape; protected DataType dataType; diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java index 8282f2c5c60..2d03161fcc4 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java @@ -88,35 +88,30 @@ public ByteBuffer toByteBuffer(boolean tryDirect) { /** {@inheritDoc} */ @Override public void intern(NDArray replaced) { - if (replaced == null) { - throw new IllegalArgumentException("The replaced NDArray cannot be null."); - } if (!(replaced instanceof XgbNDArray)) { throw new IllegalArgumentException( "The replaced NDArray must be an instance of XgbNDArray."); } XgbNDArray array = (XgbNDArray) replaced; + if (isReleased()) { + throw new IllegalArgumentException("This array is already closed"); + } + if (replaced.isReleased()) { + throw new IllegalArgumentException("This target array is already closed"); + } - synchronized (this) { - if (handle != null && handle.get() != 0L) { - long pointer = handle.getAndSet(0L); - JniUtils.deleteDMatrix(pointer); - } - - data = array.data; - format = array.format; - - if (array.handle != null && array.handle.get() != 0L) { - if (handle == null) { - handle = new AtomicLong(); - } - handle.set(array.handle.getAndSet(0L)); - } + long pointer = handle.getAndSet(0L); + JniUtils.deleteDMatrix(pointer); + if (alternativeArray != null) { + alternativeArray.close(); } - array.data = null; + data = array.data; + handle = array.handle; + format = array.format; + alternativeArray = array.alternativeArray; array.handle = null; - array.format = null; + array.alternativeArray = null; array.close(); }