forked from duckdb/duckdb
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from zeroRains/v1.1
IMBridge的DuckDB实现 V1.1
- Loading branch information
Showing
50 changed files
with
1,149 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
|
||
#include "duckdb.hpp" | ||
|
||
#include <iostream> | ||
#include <string> | ||
|
||
using namespace duckdb; | ||
|
||
template <typename TYPE, int NUM_INPUT> | ||
static void udf_tmp(DataChunk &input, ExpressionState &state, Vector &result) { | ||
result.SetVectorType(VectorType::FLAT_VECTOR); | ||
auto result_data = ConstantVector::GetData<TYPE>(result); | ||
input.Flatten(); | ||
auto tmp_data1 = ConstantVector::GetData<TYPE>(input.data[0]); | ||
auto tmp_data2 = ConstantVector::GetData<TYPE>(input.data[1]); | ||
memset(result_data, std::numeric_limits<TYPE>::min(), input.size() * sizeof(TYPE)); | ||
for (idx_t i = 0; i < input.size(); i++) { | ||
result_data[i] = 1 * tmp_data1[i] + 0 * tmp_data2[i]; | ||
} | ||
} | ||
|
||
void create_data(Connection con, int n = 10000) { | ||
std::stringstream ss; | ||
ss << "INSERT INTO data VALUES (1, 10)"; | ||
for (int i = 2; i <= n; i++) { | ||
ss << ", ("; | ||
ss << i; | ||
ss << ", "; | ||
ss << i * 10; | ||
ss << ")"; | ||
} | ||
con.Query(ss.str()); | ||
printf("Finish create!\n"); | ||
} | ||
|
||
/// Builds an in-memory database, loads the `data` table, registers the vectorized UDF
/// and reports the CPU time spent running a query that calls it.
int main() {
	DuckDB db(nullptr);
	Connection con(db);

	con.Query("SET threads = 1");
	con.Query("CREATE TABLE data (i INTEGER, age INTEGER)");
	create_data(con);
	// con.Query("SELECT * FROM data LIMIT 10")->Print();

	con.CreateVectorizedFunction<int, int, int>("udf_vectorized_int", &udf_tmp<int, 2>);

	// Only the query execution and the printing of its result are timed.
	const clock_t begin_ticks = clock();
	con.Query("SELECT i, udf_vectorized_int(i, age) FROM data WHERE udf_vectorized_int(i, age)%2==0")->Print();
	const clock_t end_ticks = clock();

	const double elapsed_seconds = static_cast<double>(end_ticks - begin_ticks) / CLOCKS_PER_SEC;
	printf("finished execute %lf s!\n", elapsed_seconds);
	// con.Query("SELECT i FROM data WHERE i%2==0")->Print();
	return 0;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
#include "duckdb.hpp" | ||
|
||
#include <iostream> | ||
#include <random> | ||
#include <string> | ||
|
||
using namespace duckdb; | ||
|
||
template <typename TYPE, int NUM_INPUT> | ||
static void udf_tmp(DataChunk &input, ExpressionState &state, Vector &result) { | ||
result.SetVectorType(VectorType::FLAT_VECTOR); | ||
auto result_data = ConstantVector::GetData<int>(result); | ||
input.Flatten(); | ||
auto tmp_data1 = ConstantVector::GetData<TYPE>(input.data[0]); | ||
auto tmp_data2 = ConstantVector::GetData<TYPE>(input.data[1]); | ||
auto tmp_data3 = ConstantVector::GetData<TYPE>(input.data[2]); | ||
memset(result_data, std::numeric_limits<TYPE>::min(), input.size() * sizeof(TYPE)); | ||
for (idx_t i = 0; i < input.size(); i++) { | ||
result_data[i] = tmp_data1[i]; | ||
} | ||
} | ||
|
||
void create_feature_table(Connection con, int n = 100) { | ||
con.Query("CREATE TABLE feature (i INTEGER, f1 FLOAT, f2 FLOAT, f3 FLOAT, label INT)"); | ||
|
||
std::random_device rd; // 使用随机设备作为种子 | ||
std::mt19937 gen(rd()); // 使用 Mersenne Twister 引擎 | ||
std::gamma_distribution<float> dis(2, 10); | ||
|
||
std::stringstream ss; | ||
ss << "INSERT INTO feature VALUES (1, 3.4, 5.3, 9.3, 3)"; | ||
for (int i = 2; i <= n; i++) { | ||
float x1 = dis(gen); | ||
float x2 = dis(gen); | ||
float x3 = dis(gen); | ||
int label = static_cast<int>(3 * x1 + 3.5 * x2 + 4 * x3) % 4; | ||
ss << ", ("; | ||
ss << i; | ||
ss << ", "; | ||
ss << x1; | ||
ss << ", "; | ||
ss << x2; | ||
ss << ", "; | ||
ss << x3; | ||
ss << ", "; | ||
ss << label; | ||
ss << ")"; | ||
} | ||
con.Query(ss.str()); | ||
printf("create feature finished!\n"); | ||
} | ||
|
||
void create_label_table(Connection con) { | ||
con.Query("CREATE TABLE label (id INTEGER, name VARCHAR(20))"); | ||
std::stringstream ss; | ||
ss << "INSERT INTO label VALUES (0, 'yellow'), (1, 'red'), (2, 'black'), (3, 'white')"; | ||
con.Query(ss.str()); | ||
printf("create label finished!\n"); | ||
} | ||
|
||
void create_color_table(Connection con) { | ||
con.Query("CREATE TABLE color (id INTEGER, r INTEGER, g INTEGER, b INTEGER)"); | ||
std::stringstream ss; | ||
ss << "INSERT INTO color VALUES (0, 1, 2, 3), (1, 4, 5, 6), (2, 7, 8, 9), (3, 10, 11, 12)"; | ||
con.Query(ss.str()); | ||
printf("create label finished!\n"); | ||
} | ||
|
||
/// Opens the on-disk feature/label/color database, registers the vectorized UDF and
/// prints the result of an aggregate query that calls it once per group.
int main() {
	DuckDB db("/root/db/duckdb_test/feature_label_color.db");
	Connection con(db);
	con.Query("SET threads = 1;");

	// One-off table setup, left disabled:
	// create_label_table(con);
	// create_color_table(con);
	// create_feature_table(con, 13000);

	// Alternative join query, left disabled:
	// string sql = "SELECT i, udf_vectorized_int(f1, f2, f3) as predict, feature.label as label, label.name as class "
	// "FROM feature JOIN label ON "
	// "udf_vectorized_int(f1, f2, f3) == label.id WHERE udf_vectorized_int(f1, f2, f3)%2==1";
	string sql = "SELECT udf_vectorized_int(sum(f1), count(f2), mean(f3)) as predict FROM feature GROUP BY i";

	// Template arguments: return type first, then the three argument types.
	con.CreateVectorizedFunction<double, double, int64_t, double>("udf_vectorized_int", &udf_tmp<double, 3>);

	auto predictions = con.Query(sql);
	predictions->Print();
	// con.Query("SELECT i FROM data WHERE i%2==0")->Print();
	return 0;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#include "duckdb.hpp" | ||
|
||
#include <iostream> | ||
|
||
using namespace duckdb; | ||
|
||
template <typename TYPE, int NUM_INPUT> | ||
static void udf_tmp(DataChunk &input, ExpressionState &state, Vector &result) { | ||
std::cout << input.size() << std::endl; | ||
result.SetVectorType(VectorType::FLAT_VECTOR); | ||
auto result_data = ConstantVector::GetData<int>(result); | ||
input.Flatten(); | ||
auto tmp_data1 = ConstantVector::GetData<TYPE>(input.data[0]); | ||
auto tmp_data2 = ConstantVector::GetData<TYPE>(input.data[1]); | ||
// auto tmp_data3 = ConstantVector::GetData<TYPE>(input.data[2]); | ||
// auto tmp_data4 = ConstantVector::GetData<TYPE>(input.data[3]); | ||
// auto tmp_data5 = ConstantVector::GetData<TYPE>(input.data[4]); | ||
// auto tmp_data6 = ConstantVector::GetData<TYPE>(input.data[5]); | ||
// auto tmp_data7 = ConstantVector::GetData<TYPE>(input.data[6]); | ||
memset(result_data, std::numeric_limits<TYPE>::min(), input.size() * sizeof(TYPE)); | ||
for (idx_t i = 0; i < input.size(); i++) { | ||
result_data[i] = tmp_data1[i] + tmp_data2[i]; | ||
} | ||
} | ||
|
||
/// Opens the on-disk imbridge database, registers `udf` and prints the EXPLAIN ANALYZE
/// output of a query that ranks UDF scores per user and keeps ranks up to 10.
int main() {
	DuckDB db("/root/duckdb_test/imbridge.db");
	Connection con(db);

	con.Query("SET threads TO 1;");
	// Template arguments: return type first, then the two argument types.
	con.CreateVectorizedFunction<double, int64_t, int64_t>("udf", &udf_tmp<int64_t, 2>);

	const char *ranking_sql =
	    "explain analyze select userID, productID, r, score from (select userID, productID, score, rank() OVER "
	    "(PARTITION BY userID ORDER BY score) as r from (select userID, productID, udf(userID, productID) score from "
	    "(select userID, productID from Product_Rating group by userID, productID))) where r <=10;";
	// Alternative query, left disabled:
	// "select count(serial_number) from Failures;"
	auto plan = con.Query(ranking_sql);
	plan->Print();
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#include "duckdb.hpp" | ||
|
||
#include <iostream> | ||
|
||
using namespace duckdb; | ||
|
||
template <typename TYPE, int NUM_INPUT> | ||
static void udf_tmp(DataChunk &input, ExpressionState &state, Vector &result) { | ||
// std::cout << input.size() << std::endl; | ||
result.SetVectorType(VectorType::FLAT_VECTOR); | ||
auto result_data = ConstantVector::GetData<int>(result); | ||
input.Flatten(); | ||
auto tmp_data1 = ConstantVector::GetData<TYPE>(input.data[0]); | ||
auto tmp_data2 = ConstantVector::GetData<TYPE>(input.data[1]); | ||
memset(result_data, std::numeric_limits<TYPE>::min(), input.size() * sizeof(TYPE)); | ||
for (idx_t i = 0; i < input.size(); i++) { | ||
result_data[i] = tmp_data1[i] + tmp_data2[i]; | ||
} | ||
std::cout << input.size() << std::endl; | ||
} | ||
|
||
/// Opens the on-disk imbridge2 database, registers a 29-argument `udf` and prints the
/// EXPLAIN ANALYZE output of a filtered query over Credit_Card_extension that calls it.
int main() {
	DuckDB db("/root/db/duckdb_test/imbridge2.db");
	Connection con(db);

	con.Query("SET threads TO 1;");
	// Template arguments: one return type followed by 29 argument types, all double.
	con.CreateVectorizedFunction<double, double, double, double, double, double, double, double, double, double, double,
	                             double, double, double, double, double, double, double, double, double, double, double,
	                             double, double, double, double, double, double, double, double>("udf",
	                                                                                             &udf_tmp<double, 29>);

	const char *scoring_sql =
	    "Explain analyze SELECT Time, Amount, udf(V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, V11, V12, V13, V14, V15,"
	    "V16, V17, V18, V19, V20, V21, V22, V23, V24, V25, V26, V27, V28, Amount) as class FROM Credit_Card_extension "
	    "WHERE V1 > 1 AND V2 < 0.27 AND V3 > 0.3;";
	// Alternative query, left disabled:
	// "select count(serial_number) from Failures;"
	auto plan = con.Query(scoring_sql);
	plan->Print();
}
Oops, something went wrong.