-
Notifications
You must be signed in to change notification settings - Fork 0
/
knnring_sequential.c
138 lines (111 loc) · 3.83 KB
/
knnring_sequential.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
/*
* knnring_sequential.c
*
* Created on: Nov 21, 2019
* Author: Lambis
*/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <cblas.h>
#include "knnring.h"
/* Abort-on-OOM allocation helper for kNN(): zero-initialised, with the
 * count*size overflow check that calloc performs. Never returns NULL. */
static void *knn_calloc(size_t count, size_t size) {
    /* calloc(0, ...) may legitimately return NULL; ask for one element instead
     * so that NULL can only ever mean "out of memory". */
    void *p = calloc(count ? count : 1, size);
    if (p == NULL) {
        fprintf(stderr, "kNN: out of memory allocating %zu x %zu bytes\n", count, size);
        exit(EXIT_FAILURE);
    }
    return p;
}

/* Exact k-nearest-neighbour search using Euclidean distance.
 *
 * X: corpus of n points of dimension d, row-major (point j starts at X[j*d]).
 * Y: queries, m points of dimension d, row-major (point i starts at Y[i*d]).
 * k: number of neighbours to report per query.
 *
 * Returns a knnresult whose nidx / ndist arrays (m*k entries, row-major) are
 * allocated here and not freed; ownership passes to the caller. Row i holds the
 * indices into X and the distances of the k corpus points closest to Y_i, in
 * ascending distance order. If k > n, the surplus slots of each row are filled
 * with index -1 and distance INFINITY.
 * The program exits with EXIT_FAILURE if memory cannot be allocated.
 *
 * The distances use ||y - x||^2 = ||y||^2 - 2 y.x + ||x||^2, so the dominant
 * cost is a single dgemm. */
knnresult kNN(double * X, double * Y, int n, int m, int d, int k) {
    knnresult knn;
    knn.m = m;
    knn.k = k;

    /* Size everything in size_t: products such as m*n overflow int for large
     * inputs. Negative sizes are treated as empty. */
    const size_t nn = n > 0 ? (size_t)n : 0;
    const size_t mm = m > 0 ? (size_t)m : 0;
    const size_t kk = k > 0 ? (size_t)k : 0;
    const size_t dd = d > 0 ? (size_t)d : 0;

    knn.nidx  = knn_calloc(mm * kk, sizeof *knn.nidx);
    knn.ndist = knn_calloc(mm * kk, sizeof *knn.ndist);

    /* D is m x n: D[i*n + j] ends up as the distance from query Y_i to corpus X_j. */
    double *D   = knn_calloc(mm * nn, sizeof *D);
    double *xsq = knn_calloc(nn, sizeof *xsq);   /* sum(X.^2, 2) */
    double *ysq = knn_calloc(mm, sizeof *ysq);   /* sum(Y.^2, 2) */

    for (size_t j = 0; j < nn; j++) {
        double norm = cblas_dnrm2(d, &X[j * dd], 1);
        xsq[j] = norm * norm;
    }
    for (size_t i = 0; i < mm; i++) {
        double norm = cblas_dnrm2(d, &Y[i * dd], 1);
        ysq[i] = norm * norm;
    }

    /* D = -2 * Y * X.'. Multiplying in this order yields D directly in
     * query-major layout, so no transpose pass is needed. BLAS requires
     * leading dimensions >= 1, so empty problems skip the call; D is already
     * zero-filled in that case. */
    if (mm > 0 && nn > 0 && dd > 0)
        cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, d,
                    -2.0, Y, d, X, d, 0.0, D, n);

    /* Add ||y_i||^2 + ||x_j||^2 and take the square root. Cancellation can leave
     * tiny negative squared distances for (near-)identical points; clamp those
     * to 0 instead of reflecting them to positive values. */
    for (size_t i = 0; i < mm; i++) {
        double *row = &D[i * nn];
        for (size_t j = 0; j < nn; j++) {
            double sq = row[j] + ysq[i] + xsq[j];
            row[j] = sqrt(sq < 0.0 ? 0.0 : sq);
        }
    }
    free(xsq);
    free(ysq);

    /* For each query: sort its row of distances ascending while carrying the
     * corpus indices along, then keep the first k. A single index buffer of
     * length n is reused for every row. */
    int *order = knn_calloc(nn, sizeof *order);
    for (size_t i = 0; i < mm; i++) {
        double *row = &D[i * nn];
        for (size_t j = 0; j < nn; j++)
            order[j] = (int)j;
        quickSort(row, 0, n - 1, order);

        for (size_t j = 0; j < kk; j++) {
            size_t out = i * kk + j;
            if (j < nn) {
                knn.nidx[out]  = order[j];
                knn.ndist[out] = row[j];
            } else {
                /* Fewer than k corpus points: pad instead of reading out of bounds. */
                knn.nidx[out]  = -1;
                knn.ndist[out] = INFINITY;
            }
        }
    }
    free(order);
    free(D);
    return knn;
}
/* Swaps the contents of two objects of s bytes each.
 * a and b must either be the very same object (then nothing happens) or not
 * overlap at all. The swap is done byte-by-byte in place, so it needs no heap
 * scratch buffer and cannot fail. It also never calls memcpy on overlapping
 * memory, which is undefined behaviour; partition() swaps an element with
 * itself whenever i == j. */
void swap(void* a, void* b, size_t s) {
    unsigned char *p = a;
    unsigned char *q = b;
    if (p == q)
        return;
    for (size_t i = 0; i < s; i++) {
        unsigned char t = p[i];
        p[i] = q[i];
        q[i] = t;
    }
}
/* Lomuto partition of arr[low..high] around the pivot value arr[high].
 * Elements strictly smaller than the pivot are moved to the front of the range
 * (in scan order), and the pivot is then placed right after them. The pivot's
 * final position is returned: everything before it compares < pivot, and
 * everything after it does not. idx is permuted in lockstep with arr, so each
 * distance keeps its companion index. */
int partition(double *arr, int low, int high, int *idx) {
    const double pivot = arr[high];
    int store = low;   /* next slot for an element smaller than the pivot */

    for (int scan = low; scan < high; scan++) {
        if (!(arr[scan] < pivot))
            continue;

        double dtmp = arr[store];
        arr[store] = arr[scan];
        arr[scan] = dtmp;

        int itmp = idx[store];
        idx[store] = idx[scan];
        idx[scan] = itmp;

        store++;
    }

    /* Drop the pivot into its final slot. */
    double dtmp = arr[store];
    arr[store] = arr[high];
    arr[high] = dtmp;

    int itmp = idx[store];
    idx[store] = idx[high];
    idx[high] = itmp;

    return store;
}
/* Sorts arr[low..high] (inclusive bounds) ascending in place with quicksort,
 * permuting idx in lockstep so that idx[p] stays attached to arr[p].
 * Ranges with low >= high are left untouched.
 *
 * Only the smaller side of each partition is sorted recursively; the larger
 * side is handled by the loop. This bounds the recursion depth to O(log n)
 * even when partition() degenerates (already-sorted rows, or many equal
 * distances such as duplicate points), where recursing on both sides goes n
 * levels deep and can overflow the stack. The final order is unchanged, because
 * the two sub-ranges are disjoint and are sorted independently of each other. */
void quickSort(double *arr, int low, int high, int *idx) {
    while (low < high) {
        /* arr[pi] is now in its final sorted position. */
        int pi = partition(arr, low, high, idx);
        if (pi - low < high - pi) {
            quickSort(arr, low, pi - 1, idx);
            low = pi + 1;
        } else {
            quickSort(arr, pi + 1, high, idx);
            high = pi - 1;
        }
    }
}