From bccad00c5c08bc467f317117ed9605358d2b2e36 Mon Sep 17 00:00:00 2001 From: Mandimby RAVELOARINJAKA Date: Fri, 11 Oct 2024 17:09:04 -0400 Subject: [PATCH] [FIX] (neanderthal): implement PECount for NativeBlock (#117) --- neanderthal/tech/v3/libs/neanderthal_test.clj | 18 ++++++++---------- src/tech/v3/libs/neanderthal_post_48.clj | 4 +++- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/neanderthal/tech/v3/libs/neanderthal_test.clj b/neanderthal/tech/v3/libs/neanderthal_test.clj index 153d7b2b..555acacf 100644 --- a/neanderthal/tech/v3/libs/neanderthal_test.clj +++ b/neanderthal/tech/v3/libs/neanderthal_test.clj @@ -1,13 +1,11 @@ (ns tech.v3.libs.neanderthal-test - (:require [uncomplicate.neanderthal.core :as n-core] - [uncomplicate.neanderthal.native :as n-native] - [tech.v3.tensor :as dtt] + (:require [clojure.test :refer [deftest is]] + [tech.v3.datatype.argops :as dtype-ops] [tech.v3.datatype.functional :as dfn] - [tech.v3.datatype :as dtype] - [clojure.test :refer [deftest is]] [tech.v3.libs.neanderthal] - [tech.v3.datatype])) - + [tech.v3.tensor :as dtt] + [uncomplicate.neanderthal.core :as n-core] + [uncomplicate.neanderthal.native :as n-native])) (deftest basic-neanderthal-test (let [a (n-native/dge 3 3 (range 9))] @@ -18,7 +16,6 @@ (is (dfn/equals (dtt/ensure-tensor second-row) [1 4 7]))))) - (deftest basic-neanderthal-test-row-major (let [b (n-native/dge 3 3 (range 9) {:layout :row})] (is (dfn/equals (dtt/ensure-tensor b) @@ -27,7 +24,6 @@ (is (dfn/equals (dtt/ensure-tensor second-row) [3 4 5]))))) - (deftest single-col-row-matrix (is (dfn/equals (dtt/ensure-tensor (n-native/dge 1 3 (range 3) {:layout :row})) (dtt/->tensor (range 3)))) @@ -35,9 +31,11 @@ (is (dfn/equals (dtt/ensure-tensor (n-native/dge 1 3 (range 3) {:layout :column})) (dtt/->tensor (range 3)))) - (is (dfn/equals (dtt/ensure-tensor (n-native/dge 3 1 (range 3) {:layout :row})) (dtt/->tensor (range 3)))) (is (dfn/equals (dtt/ensure-tensor (n-native/dge 3 1 (range 3) {:layout :column})) (dtt/->tensor (range 3))))) + +(deftest argsort-supports-native + (is (= [2 1 0] (dtype-ops/argsort (n-native/dv [3 2 1]))))) diff --git a/src/tech/v3/libs/neanderthal_post_48.clj b/src/tech/v3/libs/neanderthal_post_48.clj index 6a7a8d2f..0648f7dd 100644 --- a/src/tech/v3/libs/neanderthal_post_48.clj +++ b/src/tech/v3/libs/neanderthal_post_48.clj @@ -73,4 +73,6 @@ (as-tensor [item] (when (dtype-proto/convertible-to-nd-buffer-desc? item) (-> (dtype-proto/->nd-buffer-descriptor item) - (dtt/nd-buffer-descriptor->tensor))))) + (dtt/nd-buffer-descriptor->tensor)))) + dtype-proto/PECount + (ecount [item] (:dim (n-core/info item))))