diff --git a/src/owl/dense/owl_dense_matrix_c.ml b/src/owl/dense/owl_dense_matrix_c.ml index cb6f9d098..b5d983f8e 100644 --- a/src/owl/dense/owl_dense_matrix_c.ml +++ b/src/owl/dense/owl_dense_matrix_c.ml @@ -47,6 +47,8 @@ let meshgrid xa xb ya yb xn yn = M.meshgrid Complex32 xa xb ya yb xn yn let bernoulli ?p d = M.bernoulli Complex32 ?p d +let unit_basis n i = M.unit_basis Complex32 n i + let hadamard n = M.hadamard Complex32 n let magic n = M.magic Complex32 n diff --git a/src/owl/dense/owl_dense_matrix_c.mli b/src/owl/dense/owl_dense_matrix_c.mli index 3d83407ac..98dc25fe4 100644 --- a/src/owl/dense/owl_dense_matrix_c.mli +++ b/src/owl/dense/owl_dense_matrix_c.mli @@ -58,6 +58,8 @@ val meshup : mat -> mat -> mat * mat val bernoulli : ?p:float -> int -> int -> mat +val unit_basis : int -> int -> mat + val diagm : ?k:int -> mat -> mat val triu : ?k:int -> mat -> mat diff --git a/src/owl/dense/owl_dense_matrix_d.ml b/src/owl/dense/owl_dense_matrix_d.ml index 55f135211..935d54dd7 100644 --- a/src/owl/dense/owl_dense_matrix_d.ml +++ b/src/owl/dense/owl_dense_matrix_d.ml @@ -49,6 +49,8 @@ let meshgrid xa xb ya yb xn yn = M.meshgrid Float64 xa xb ya yb xn yn let bernoulli ?p d = M.bernoulli Float64 ?p d +let unit_basis n i = M.unit_basis Float64 n i + let hadamard n = M.hadamard Float64 n let magic n = M.magic Float64 n diff --git a/src/owl/dense/owl_dense_matrix_d.mli b/src/owl/dense/owl_dense_matrix_d.mli index 4e23a944f..9fdd968d7 100644 --- a/src/owl/dense/owl_dense_matrix_d.mli +++ b/src/owl/dense/owl_dense_matrix_d.mli @@ -37,6 +37,8 @@ val semidef : int -> mat val bernoulli : ?p:float -> int -> int -> mat +val unit_basis : int -> int -> mat + val diagm : ?k:int -> mat -> mat val triu : ?k:int -> mat -> mat diff --git a/src/owl/dense/owl_dense_matrix_generic.ml b/src/owl/dense/owl_dense_matrix_generic.ml index 54c9af195..ed4cb60a3 100644 --- a/src/owl/dense/owl_dense_matrix_generic.ml +++ b/src/owl/dense/owl_dense_matrix_generic.ml @@ -198,6 +198,11 @@ let gaussian k ?mu ?sigma m n = Owl_dense_ndarray_generic.gaussian k ?mu ?sigma let bernoulli k ?p m n = Owl_dense_ndarray_generic.bernoulli k ?p [|m;n|] +let unit_basis k n i = + let x = Owl_dense_ndarray_generic.unit_basis k n i in + reshape x [|1; n|] + + let toeplitz ?c r = let c = match c with | Some c -> c diff --git a/src/owl/dense/owl_dense_matrix_generic.mli b/src/owl/dense/owl_dense_matrix_generic.mli index e7cb63cea..8d3f812c1 100644 --- a/src/owl/dense/owl_dense_matrix_generic.mli +++ b/src/owl/dense/owl_dense_matrix_generic.mli @@ -97,6 +97,11 @@ contains phase angles. Note that the behaviour is undefined if ``rho`` has negative elelments or ``theta`` has infinity elelments. *) +val unit_basis : ('a, 'b) kind -> int -> int -> ('a, 'b) t +(** +``unit_basis k n i`` returns a unit basis vector with ``i``th element set to 1. + *) + val sequential : ('a, 'b) kind -> ?a:'a -> ?step:'a -> int -> int -> ('a, 'b) t (** ``sequential ~a ~step m n`` creates an ``m`` by ``n`` matrix. The elements in ``x`` diff --git a/src/owl/dense/owl_dense_matrix_s.ml b/src/owl/dense/owl_dense_matrix_s.ml index a91934218..c7cf2569d 100644 --- a/src/owl/dense/owl_dense_matrix_s.ml +++ b/src/owl/dense/owl_dense_matrix_s.ml @@ -46,6 +46,8 @@ let meshgrid xa xb ya yb xn yn = M.meshgrid Float32 xa xb ya yb xn yn let bernoulli ?p d = M.bernoulli Float32 ?p d +let unit_basis n i = M.unit_basis Float32 n i + let hadamard n = M.hadamard Float32 n let magic n = M.magic Float32 n diff --git a/src/owl/dense/owl_dense_matrix_s.mli b/src/owl/dense/owl_dense_matrix_s.mli index 3f2a0e3fe..495a3b09d 100644 --- a/src/owl/dense/owl_dense_matrix_s.mli +++ b/src/owl/dense/owl_dense_matrix_s.mli @@ -37,6 +37,8 @@ val semidef : int -> mat val bernoulli : ?p:float -> int -> int -> mat +val unit_basis : int -> int -> mat + val diagm : ?k:int -> mat -> mat val triu : ?k:int -> mat -> mat diff --git a/src/owl/dense/owl_dense_matrix_z.ml b/src/owl/dense/owl_dense_matrix_z.ml index 8e979cba2..f2c4dd26c 100644 --- a/src/owl/dense/owl_dense_matrix_z.ml +++ b/src/owl/dense/owl_dense_matrix_z.ml @@ -47,6 +47,8 @@ let meshgrid xa xb ya yb xn yn = M.meshgrid Complex64 xa xb ya yb xn yn let bernoulli ?p d = M.bernoulli Complex64 ?p d +let unit_basis n i = M.unit_basis Complex64 n i + let hadamard n = M.hadamard Complex64 n let magic n = M.magic Complex64 n diff --git a/src/owl/dense/owl_dense_matrix_z.mli b/src/owl/dense/owl_dense_matrix_z.mli index 86bacd45e..b3b0cd402 100644 --- a/src/owl/dense/owl_dense_matrix_z.mli +++ b/src/owl/dense/owl_dense_matrix_z.mli @@ -58,6 +58,8 @@ val meshup : mat -> mat -> mat * mat val bernoulli : ?p:float -> int -> int -> mat +val unit_basis : int -> int -> mat + val diagm : ?k:int -> mat -> mat val triu : ?k:int -> mat -> mat diff --git a/src/owl/dense/owl_dense_ndarray_c.ml b/src/owl/dense/owl_dense_ndarray_c.ml index 821361717..5b14b0c38 100644 --- a/src/owl/dense/owl_dense_ndarray_c.ml +++ b/src/owl/dense/owl_dense_ndarray_c.ml @@ -41,6 +41,8 @@ let logspace ?base a b n = M.logspace Complex32 ?base a b n let bernoulli ?p d = M.bernoulli Complex32 ?p d +let unit_basis n i = M.unit_basis Complex32 n i + let load f = M.load Complex32 f let of_array x d = M.of_array Complex32 x d diff --git a/src/owl/dense/owl_dense_ndarray_c.mli b/src/owl/dense/owl_dense_ndarray_c.mli index b28f7585c..6e8e2153f 100644 --- a/src/owl/dense/owl_dense_ndarray_c.mli +++ b/src/owl/dense/owl_dense_ndarray_c.mli @@ -49,6 +49,8 @@ val complex : cast_arr -> cast_arr -> arr val polar : cast_arr -> cast_arr -> arr +val unit_basis : int -> int -> arr + (** {6 Obtain basic properties} *) diff --git a/src/owl/dense/owl_dense_ndarray_d.ml b/src/owl/dense/owl_dense_ndarray_d.ml index 1ed8f9016..329b09ad7 100644 --- a/src/owl/dense/owl_dense_ndarray_d.ml +++ b/src/owl/dense/owl_dense_ndarray_d.ml @@ -43,6 +43,8 @@ let logspace ?base a b n = M.logspace Float64 ?base a b n let bernoulli ?p d = M.bernoulli Float64 ?p d +let unit_basis n i = M.unit_basis Float64 n i + let load f = M.load Float64 f let of_array x d = M.of_array Float64 x d diff --git a/src/owl/dense/owl_dense_ndarray_d.mli b/src/owl/dense/owl_dense_ndarray_d.mli index bdfd2f009..7bf9ac96f 100644 --- a/src/owl/dense/owl_dense_ndarray_d.mli +++ b/src/owl/dense/owl_dense_ndarray_d.mli @@ -43,6 +43,8 @@ val logspace : ?base:float -> elt -> elt -> int -> arr val bernoulli : ?p:float -> int array -> arr +val unit_basis : int -> int -> arr + (** {6 Obtain basic properties} *) diff --git a/src/owl/dense/owl_dense_ndarray_generic.ml b/src/owl/dense/owl_dense_ndarray_generic.ml index 5d00380b1..d77948c24 100644 --- a/src/owl/dense/owl_dense_ndarray_generic.ml +++ b/src/owl/dense/owl_dense_ndarray_generic.ml @@ -1267,6 +1267,13 @@ let argsort x = y +let unit_basis k n i = + let x = zeros k [|n|] in + let a1 = Owl_const.one k in + Genarray.set x [|i|] a1; + x + + (* advanced operations *) let iteri f x = diff --git a/src/owl/dense/owl_dense_ndarray_generic.mli b/src/owl/dense/owl_dense_ndarray_generic.mli index 98c033e65..2969fd7e2 100644 --- a/src/owl/dense/owl_dense_ndarray_generic.mli +++ b/src/owl/dense/owl_dense_ndarray_generic.mli @@ -148,6 +148,11 @@ contains phase angles. Note that the behaviour is undefined if ``rho`` has negative elelments or ``theta`` has infinity elelments. *) +val unit_basis : ('a, 'b) kind -> int -> int -> ('a, 'b) t +(** +``unit_basis k n i`` returns a unit basis vector with ``i``th element set to 1. + *) + (** {6 Obtain basic properties} *) diff --git a/src/owl/dense/owl_dense_ndarray_s.ml b/src/owl/dense/owl_dense_ndarray_s.ml index 673f94b53..8d5e52248 100644 --- a/src/owl/dense/owl_dense_ndarray_s.ml +++ b/src/owl/dense/owl_dense_ndarray_s.ml @@ -40,6 +40,8 @@ let logspace ?base a b n = M.logspace Float32 ?base a b n let bernoulli ?p d = M.bernoulli Float32 ?p d +let unit_basis n i = M.unit_basis Float32 n i + let load f = M.load Float32 f let of_array x d = M.of_array Float32 x d diff --git a/src/owl/dense/owl_dense_ndarray_s.mli b/src/owl/dense/owl_dense_ndarray_s.mli index 06e3e2a3c..460d86215 100644 --- a/src/owl/dense/owl_dense_ndarray_s.mli +++ b/src/owl/dense/owl_dense_ndarray_s.mli @@ -43,6 +43,8 @@ val logspace : ?base:float -> elt -> elt -> int -> arr val bernoulli : ?p:float -> int array -> arr +val unit_basis : int -> int -> arr + (** {6 Obtain basic properties} *) diff --git a/src/owl/dense/owl_dense_ndarray_z.ml b/src/owl/dense/owl_dense_ndarray_z.ml index 0aadb8d4c..8bf24642b 100644 --- a/src/owl/dense/owl_dense_ndarray_z.ml +++ b/src/owl/dense/owl_dense_ndarray_z.ml @@ -41,6 +41,8 @@ let logspace ?base a b n = M.logspace Complex64 ?base a b n let bernoulli ?p d = M.bernoulli Complex64 ?p d +let unit_basis n i = M.unit_basis Complex64 n i + let load f = M.load Complex64 f let of_array x d = M.of_array Complex64 x d diff --git a/src/owl/dense/owl_dense_ndarray_z.mli b/src/owl/dense/owl_dense_ndarray_z.mli index 2e4de62e4..96b1b956a 100644 --- a/src/owl/dense/owl_dense_ndarray_z.mli +++ b/src/owl/dense/owl_dense_ndarray_z.mli @@ -49,6 +49,8 @@ val complex : cast_arr -> cast_arr -> arr val polar : cast_arr -> cast_arr -> arr +val unit_basis : int -> int -> arr + (** {6 Obtain basic properties} *) diff --git a/test/test_runner.ml b/test/test_runner.ml index adecfe91b..5b667891c 100644 --- a/test/test_runner.ml +++ b/test/test_runner.ml @@ -13,8 +13,9 @@ let () = "dense ndarray", Unit_dense_ndarray.test_set; "sparse matrix", Unit_sparse_matrix.test_set; "sparse ndarray", Unit_sparse_ndarray.test_set; - "ndarray primitive", Unit_ndarray_primitive.test_set; "ndarray core", Unit_ndarray_core.test_set; + "ndarray primitive", Unit_ndarray_primitive.test_set; + "ndarray operation", Unit_ndarray_operation.test_set; "lazy evaluation", Unit_lazy.test_set; "linear algebra", Unit_linalg.test_set; "slicing basic", Unit_slicing_basic.test_set; diff --git a/test/unit_ndarray_operation.ml b/test/unit_ndarray_operation.ml new file mode 100644 index 000000000..aa5e7a800 --- /dev/null +++ b/test/unit_ndarray_operation.ml @@ -0,0 +1,42 @@ +(** Unit test for Owl_maths module and special functions *) + +module M = Owl_dense_ndarray_s + + +(* a module with functions to test *) +module To_test = struct + + let fun00 () = + let x = M.unit_basis 10 0 in + M.get x [|0|] = 1. + + + let fun01 () = + let x = M.unit_basis 10 3 in + M.get x [|3|] = 1. + + + let fun02 () = + let x = M.unit_basis 10 9 in + M.get x [|9|] = 1. + +end + + +(* the tests *) + +let test_fun00 () = + Alcotest.(check bool) "basic operation 00" true (To_test.fun00 ()) + +let test_fun01 () = + Alcotest.(check bool) "basic operation 01" true (To_test.fun01 ()) + +let test_fun02 () = + Alcotest.(check bool) "basic operation 02" true (To_test.fun02 ()) + + +let test_set = [ + "test 00", `Slow, test_fun00; + "test 01", `Slow, test_fun01; + "test 02", `Slow, test_fun02; +]