diff --git a/umap/umap_.py b/umap/umap_.py index b6890734..0a4a4fd1 100644 --- a/umap/umap_.py +++ b/umap/umap_.py @@ -371,6 +371,18 @@ def make_tree(data, indices, rng_state, leaf_size=30, angular=False): def get_leaves(tree): + """Return the set of leaf nodes of a random projection tree. + + Parameters + ---------- + tree: RandomProjectionTreeNode + The root node of the tree to get leaves of. + + Returns + ------- + leaves: list + A list of arrays of indices of points in each leaf node. + """ if tree.is_leaf: return [tree.indices] else: @@ -379,6 +391,27 @@ def get_leaves(tree): @numba.njit('f8[:, :, :](i8,i8)') def make_heap(n_points, size): + """Constructor for the numba enabled heap objects. The heaps are used + for approximate nearest neighbor search, maintaining a list of potential + neighbors sorted by their distance. We also flag if potential neighbors + are newly added to the list or not. Internally this is stored as + a single ndarray; the first axis determines whether we are looking at the + array of candidate indices, the array of distances, or the flag array for + whether elements are new or not. Each of these arrays is of shape + (``n_points``, ``size``). + + Parameters + ---------- + n_points: int + The number of data points to track in the heap. + + size: int + The number of items to keep on the heap for each data point. + + Returns + ------- + heap: An ndarray suitable for passing to other numba enabled heap functions. + """ result = np.zeros((3, n_points, size)) result[0] = -1 result[1] = np.infty @@ -389,6 +422,33 @@ def make_heap(n_points, size): @numba.jit('i8(f8[:,:,:],i8,f8,i8,i8)') def heap_push(heap, row, weight, index, flag): + """Push a new element onto the heap. The heap stores potential neighbors + for each data point. 
The ``row`` parameter determines which data point we + are addressing, the ``weight`` determines the distance (for heap sorting), + the ``index`` is the element to add, and the ``flag`` determines whether this + is to be considered a new addition. + + Parameters + ---------- + heap: ndarray generated by ``make_heap`` + The heap object to push into. + + row: int + Which actual heap within the heap object to push to. + + weight: float + The priority value of the element to push onto the heap. + + index: int + The actual value to be pushed. + + flag: int + Whether to flag the newly added element or not. + + Returns + ------- + success: The number of new elements successfully pushed into the heap. + """ indices = heap[0, row] weights = heap[1, row] is_new = heap[2, row] @@ -396,6 +456,7 @@ def heap_push(heap, row, weight, index, flag): if weight > weights[0]: return 0 + # return early if we already have this element. for i in range(indices.shape[0]): if index == indices[i]: return 0