Skip to content

Commit

Permalink
Merge pull request #11 from dleemiller/fix-cluster-dtype-mismatch
Browse files Browse the repository at this point in the history
Fix the float/double dtype mismatch in the Cython k-means helpers (they now use float32 throughout) and add a functional test for clustering.
  • Loading branch information
dleemiller authored Sep 15, 2024
2 parents afa0e3f + 6ebde82 commit d3b9a69
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
9 changes: 9 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import unittest
from wordllama import WordLlama


class TestFunctional(unittest.TestCase):
    """End-to-end smoke tests that exercise the public WordLlama API."""

    def test_function_clustering(self):
        """Load the default model and cluster two documents into two groups.

        The test makes no assertion on the result; it passes as long as
        loading and clustering complete without raising.
        """
        documents = ["a", "b"]
        model = WordLlama.load()
        model.cluster(documents, k=2)
3 changes: 2 additions & 1 deletion tests/test_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def setUp(self):
[0.9, 0.8, 0.7],
[0.4, 0.5, 0.6],
[0.5, 0.4, 0.7],
]
],
dtype=np.float32,
)

def test_kmeans_plusplus_initialization(self):
Expand Down
12 changes: 6 additions & 6 deletions wordllama/algorithms/kmeans_helpers.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ from libc.math cimport sqrt

ctypedef np.npy_intp DTYPE_t

cdef inline double squared_euclidean_distance(const double[:] vec1, const double[:] vec2, Py_ssize_t dim) nogil:
cdef inline float squared_euclidean_distance(const float[:] vec1, const float[:] vec2, Py_ssize_t dim) nogil:
cdef Py_ssize_t i
cdef double dist = 0.0
cdef float dist = 0.0
for i in range(dim):
dist += (vec1[i] - vec2[i]) ** 2
return dist

def compute_distances(const double[:, :] embeddings, const double[:, :] centroids):
def compute_distances(const float[:, :] embeddings, const float[:, :] centroids):
cdef Py_ssize_t num_points = embeddings.shape[0]
cdef Py_ssize_t num_centroids = centroids.shape[0]
cdef Py_ssize_t dim = embeddings.shape[1]
cdef double[:, :] distances = np.empty((num_points, num_centroids), dtype=np.float64)
cdef float[:, :] distances = np.empty((num_points, num_centroids), dtype=np.float32)
cdef Py_ssize_t i, j

for i in range(num_points):
Expand All @@ -26,8 +26,8 @@ def compute_distances(const double[:, :] embeddings, const double[:, :] centroid

return np.asarray(distances)

def update_centroids(const double[:, :] embeddings, const DTYPE_t[:] labels, Py_ssize_t num_clusters, Py_ssize_t dim):
cdef double[:, :] new_centroids = np.zeros((num_clusters, dim), dtype=np.float64)
def update_centroids(const float[:, :] embeddings, const DTYPE_t[:] labels, Py_ssize_t num_clusters, Py_ssize_t dim):
cdef float[:, :] new_centroids = np.zeros((num_clusters, dim), dtype=np.float32)
cdef DTYPE_t[:] count = np.zeros(num_clusters, dtype=np.intp)
cdef Py_ssize_t i, j
cdef DTYPE_t label
Expand Down

0 comments on commit d3b9a69

Please sign in to comment.