diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e93db6ef9..f27705d36 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -31,6 +31,7 @@ jobs: sudo python3 --version sudo python3 -m pip install dataclasses sudo python3 -m pip install setuptools + sudo python3 -m pip install --upgrade protobuf==3.20.0 sudo python3 -m pip install -U git+https://github.com/tensorflow/docs find docs -name '*.ipynb' | xargs python3 -m tensorflow_docs.tools.nbfmt echo "Check for failed fmt: " diff --git a/WORKSPACE b/WORKSPACE index 5eb4e72dc..f567a65cf 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -158,11 +158,12 @@ http_archive( http_archive( name = "arrow", build_file = "//third_party:arrow.BUILD", - sha256 = "57e13c62f27b710e1de54fd30faed612aefa22aa41fa2c0c3bacd204dd18a8f3", - strip_prefix = "arrow-apache-arrow-7.0.0", + patch_cmds = ["""sed -i.bak '24i\\'$'\\n#undef ARROW_WITH_OPENTELEMETRY\\n' cpp/src/arrow/util/tracing_internal.h"""], + sha256 = "19ece12de48e51ce4287d2dee00dc358fbc5ff02f41629d16076f77b8579e272", + strip_prefix = "arrow-apache-arrow-8.0.0", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/arrow/archive/apache-arrow-7.0.0.tar.gz", - "https://github.com/apache/arrow/archive/apache-arrow-7.0.0.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/arrow/archive/apache-arrow-8.0.0.tar.gz", + "https://github.com/apache/arrow/archive/apache-arrow-8.0.0.tar.gz", ], ) diff --git a/tensorflow_io/arrow.py b/tensorflow_io/arrow.py index 44de3253c..e6265af2a 100644 --- a/tensorflow_io/arrow.py +++ b/tensorflow_io/arrow.py @@ -17,6 +17,7 @@ @@ArrowDataset @@ArrowFeatherDataset @@ArrowStreamDataset +@@ArrowS3Dataset @@list_feather_columns """ @@ -26,6 +27,7 @@ from tensorflow_io.python.ops.arrow_dataset_ops import ArrowDataset from tensorflow_io.python.ops.arrow_dataset_ops import ArrowFeatherDataset from tensorflow_io.python.ops.arrow_dataset_ops import ArrowStreamDataset +from tensorflow_io.python.ops.arrow_dataset_ops import ArrowS3Dataset from tensorflow_io.python.ops.arrow_dataset_ops import list_feather_columns @@ -33,6 +35,7 @@ "ArrowDataset", "ArrowFeatherDataset", "ArrowStreamDataset", + "ArrowS3Dataset", "list_feather_columns", ] diff --git a/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc b/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc index 7716391a9..e425567f0 100644 --- a/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc +++ b/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc @@ -13,12 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "arrow/api.h" #include "arrow/io/stdio.h" #include "arrow/ipc/api.h" #include "arrow/result.h" +#include "parquet/arrow/reader.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/public/version.h" #include "tensorflow_io/core/kernels/arrow/arrow_kernels.h" #include "tensorflow_io/core/kernels/arrow/arrow_stream_client.h" @@ -101,7 +105,6 @@ class ArrowDatasetBase : public DatasetBase { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - // If in initial state, setup and read first batch if (current_batch_ == nullptr && current_row_idx_ == 0) { TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); @@ -163,6 +166,7 @@ class ArrowDatasetBase : public DatasetBase { } // Assign Tensors for each column in the current row + result_tensors->reserve(this->dataset()->columns_.size()); for (size_t i = 0; i < this->dataset()->columns_.size(); ++i) { int32 col = this->dataset()->columns_[i]; DataType output_type = this->dataset()->output_types_[i]; @@ -173,11 +177,17 @@ class ArrowDatasetBase : public DatasetBase { TF_RETURN_IF_ERROR(ArrowUtil::AssignShape( arr, current_row_idx_, batch_size, &output_shape)); + if (output_shape.dims() == 1) { + auto&& output_shape_in = this->dataset()->output_shapes_[i]; + if (output_shape_in.dim_size(output_shape_in.dims() - 1) == 1) { + output_shape.AddDim(1); + } + } + // Allocate a new tensor and assign Arrow data to it Tensor tensor(ctx->allocator({}), output_type, output_shape); TF_RETURN_IF_ERROR( ArrowUtil::AssignTensor(arr, current_row_idx_, &tensor)); - result_tensors->emplace_back(std::move(tensor)); } @@ -757,8 +767,6 @@ class ArrowFeatherDatasetOp : public ArrowOpKernelBase { std::shared_ptr<::arrow::Table> table; CHECK_ARROW(reader->Read(&table)); - int64_t num_columns = table->num_columns(); - // Convert the table to a sequence of batches arrow::TableBatchReader tr(*table.get()); std::shared_ptr batch; @@ -937,6 +945,308 @@ class ArrowStreamDatasetOp : public ArrowOpKernelBase { }; }; +class ArrowS3DatasetOp : public ArrowOpKernelBase { + public: + explicit ArrowS3DatasetOp(OpKernelConstruction* ctx) + : ArrowOpKernelBase(ctx) {} + + virtual void MakeArrowDataset( + OpKernelContext* ctx, const std::vector& columns, + const int64 batch_size, const ArrowBatchMode batch_mode, + const DataTypeVector& output_types, + const std::vector& output_shapes, + ArrowDatasetBase** output) override { + tstring aws_access_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "aws_access_key", + &aws_access_key)); + + tstring aws_secret_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "aws_secret_key", + &aws_secret_key)); + + tstring aws_endpoint_override; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "aws_endpoint_override", + &aws_endpoint_override)); + + const Tensor* parquet_files_tensor; + OP_REQUIRES_OK(ctx, ctx->input("parquet_files", &parquet_files_tensor)); + OP_REQUIRES( + ctx, parquet_files_tensor->dims() <= 1, + errors::InvalidArgument("`parquet_files` must be a scalar or vector.")); + std::vector parquet_files; + parquet_files.reserve(parquet_files_tensor->NumElements()); + for (int i = 0; i < parquet_files_tensor->NumElements(); ++i) { + parquet_files.push_back(parquet_files_tensor->flat()(i)); + } + + const Tensor* column_names_tensor; + OP_REQUIRES_OK(ctx, ctx->input("column_names", &column_names_tensor)); + 
OP_REQUIRES( + ctx, column_names_tensor->dims() <= 1, + errors::InvalidArgument("`column_names` must be a scalar or vector.")); + std::vector column_names; + column_names.reserve(column_names_tensor->NumElements()); + for (int i = 0; i < column_names_tensor->NumElements(); ++i) { + column_names.push_back(column_names_tensor->flat()(i)); + } + + std::vector column_cols(column_names.size()); + std::iota(column_cols.begin(), column_cols.end(), 0); + + tstring filter; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "filter", &filter)); + + *output = + new Dataset(ctx, aws_access_key, aws_secret_key, aws_endpoint_override, + parquet_files, column_names, filter, column_cols, + batch_size, batch_mode, output_types_, output_shapes_); + } + + private: + class Dataset : public ArrowDatasetBase { + public: + Dataset(OpKernelContext* ctx, const std::string& aws_access_key, + const std::string& aws_secret_key, + const std::string& aws_endpoint_override, + const std::vector& parquet_files, + const std::vector& column_names, + const std::string& filter, const std::vector columns, + const int64 batch_size, const ArrowBatchMode batch_mode, + const DataTypeVector& output_types, + const std::vector& output_shapes) + : ArrowDatasetBase(ctx, columns, batch_size, batch_mode, output_types, + output_shapes), + aws_access_key_(aws_access_key), + aws_secret_key_(aws_secret_key), + aws_endpoint_override_(aws_endpoint_override), + parquet_files_(parquet_files), + column_names_(column_names), + filter_(filter) {} + + string DebugString() const override { return "ArrowS3DatasetOp::Dataset"; } + Status InputDatasets(std::vector* inputs) const { + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* aws_access_key = nullptr; + tstring access_key = aws_access_key_; + TF_RETURN_IF_ERROR(b->AddScalar(access_key, &aws_access_key)); + Node* aws_secret_key = nullptr; + tstring secret_key = aws_secret_key_; + TF_RETURN_IF_ERROR(b->AddScalar(secret_key, &aws_secret_key)); + Node* aws_endpoint_override = nullptr; + tstring endpoint_override = aws_endpoint_override_; + TF_RETURN_IF_ERROR( + b->AddScalar(endpoint_override, &aws_endpoint_override)); + Node* parquet_files = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(parquet_files_, &parquet_files)); + Node* column_names = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(column_names_, &column_names)); + Node* columns = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(columns_, &columns)); + Node* filter = nullptr; + tstring filter_str = filter_; + TF_RETURN_IF_ERROR(b->AddScalar(filter_str, &filter)); + Node* batch_size = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size)); + Node* batch_mode = nullptr; + tstring batch_mode_str; + TF_RETURN_IF_ERROR(GetBatchModeStr(batch_mode_, &batch_mode_str)); + TF_RETURN_IF_ERROR(b->AddScalar(batch_mode_str, &batch_mode)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, + {aws_access_key, aws_secret_key, aws_endpoint_override, parquet_files, + column_names, filter, columns, batch_size, batch_mode}, + output)); + return Status::OK(); + } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::ArrowS3")})); + } + + private: + class Iterator : public ArrowBaseIterator { + public: + explicit Iterator(const Params& params) + : ArrowBaseIterator(params) {} + + private: + 
Status SetupStreamsLocked(Env* env) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { + if (!s3fs_) { + arrow::fs::EnsureS3Initialized(); + auto s3Options = arrow::fs::S3Options::FromAccessKey( + dataset()->aws_access_key_, dataset()->aws_secret_key_); + s3Options.endpoint_override = dataset()->aws_endpoint_override_; + s3fs_ = arrow::fs::S3FileSystem::Make(s3Options).ValueOrDie(); + } + TF_RETURN_IF_ERROR(ReadFile(current_file_idx_)); + if (!background_worker_) { + background_worker_ = + std::make_shared(env, "download_next_worker"); + } + + if (current_batch_idx_ < record_batches_.size()) { + current_batch_ = record_batches_[current_batch_idx_]; + } + + if (current_file_idx_ + 1 < dataset()->parquet_files_.size()) { + background_worker_->Schedule(std::bind(&Iterator::ReadFile, this, + current_file_idx_ + 1, true)); + } + return Status::OK(); + } + + Status NextStreamLocked(Env* env) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { + ArrowBaseIterator::NextStreamLocked(env); + if (++current_batch_idx_ < record_batches_.size()) { + current_batch_ = record_batches_[current_batch_idx_]; + } else if (++current_file_idx_ < dataset()->parquet_files_.size()) { + current_batch_idx_ = 0; + + { + mutex_lock lk(cv_mu_); + while (!background_thread_finished_) { + cv_.wait(lk); + } + } + + record_batches_.swap(next_record_batches_); + if (!record_batches_.empty()) { + current_batch_ = record_batches_[current_batch_idx_]; + } else { + current_batch_ = nullptr; + } + background_thread_finished_ = false; + if (current_file_idx_ + 1 < dataset()->parquet_files_.size()) { + background_worker_->Schedule(std::bind( + &Iterator::ReadFile, this, current_file_idx_ + 1, true)); + } + } + return Status::OK(); + } + + void ResetStreamsLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { + ArrowBaseIterator::ResetStreamsLocked(); + current_file_idx_ = 0; + current_batch_idx_ = 0; + record_batches_.clear(); + next_record_batches_.clear(); + } + + Status ReadFile(int file_index, bool background = false) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + auto access_file = + s3fs_->OpenInputFile(dataset()->parquet_files_[file_index]) + .ValueOrDie(); + + parquet::ArrowReaderProperties properties; + properties.set_use_threads(true); + properties.set_pre_buffer(true); + parquet::ReaderProperties parquet_properties = + parquet::default_reader_properties(); + + std::shared_ptr builder = + std::make_shared(); + builder->Open(access_file, parquet_properties); + + std::unique_ptr reader; + builder->properties(properties)->Build(&reader); + + if (column_indices_.empty()) { + std::shared_ptr schema; + reader->GetSchema(&schema); + // check column name exist + std::string err_column_names; + for (const auto& name : dataset()->column_names_) { + int fieldIndex = schema->GetFieldIndex(name); + column_indices_.push_back(fieldIndex); + if (-1 == fieldIndex) { + err_column_names = err_column_names + " " + name; + } + } + + if (err_column_names.length() != 0) { + return errors::InvalidArgument("these column names don't exist: ", + err_column_names); + } + } + // Read file columns and build a table + std::shared_ptr<::arrow::Table> table; + CHECK_ARROW(reader->ReadTable(column_indices_, &table)); + // Convert the table to a sequence of batches + std::shared_ptr batch_reader = + std::make_shared(table); + std::shared_ptr batch = nullptr; + + // filter + if (!dataset()->filter_.empty()) { + auto scanner_builder = + arrow::dataset::ScannerBuilder::FromRecordBatchReader( + batch_reader); + arrow::compute::Expression filter_expr; + TF_RETURN_IF_ERROR( + 
ArrowUtil::ParseExpression(dataset()->filter_, filter_expr)); + scanner_builder->Filter(filter_expr); + auto scanner = scanner_builder->Finish().ValueOrDie(); + batch_reader = scanner->ToRecordBatchReader().ValueOrDie(); + } + + CHECK_ARROW(batch_reader->ReadNext(&batch)); + TF_RETURN_IF_ERROR(CheckBatchColumnTypes(batch)); + next_record_batches_.clear(); + while (batch != nullptr) { + if (!background) { + record_batches_.emplace_back(batch); + } else { + next_record_batches_.emplace_back(batch); + } + CHECK_ARROW(batch_reader->ReadNext(&batch)); + } + + if (background) { + mutex_lock lk(cv_mu_); + background_thread_finished_ = true; + cv_.notify_all(); + } + + return Status::OK(); + } + + size_t current_file_idx_ TF_GUARDED_BY(mu_) = 0; + size_t current_batch_idx_ TF_GUARDED_BY(mu_) = 0; + std::vector> record_batches_ + TF_GUARDED_BY(mu_); + std::vector> next_record_batches_ + TF_GUARDED_BY(mu_); + std::shared_ptr s3fs_ TF_GUARDED_BY(mu_) = + nullptr; + std::vector column_indices_ TF_GUARDED_BY(mu_); + std::shared_ptr background_worker_ = nullptr; + mutex cv_mu_; + condition_variable cv_; + bool background_thread_finished_ = false; + }; + + const std::string aws_access_key_; + const std::string aws_secret_key_; + const std::string aws_endpoint_override_; + const std::vector parquet_files_; + const std::vector column_names_; + const std::string filter_; + }; +}; // class ArrowS3DatasetOp + REGISTER_KERNEL_BUILDER(Name("IO>ArrowZeroCopyDataset").Device(DEVICE_CPU), ArrowZeroCopyDatasetOp); @@ -949,5 +1259,8 @@ REGISTER_KERNEL_BUILDER(Name("IO>ArrowFeatherDataset").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("IO>ArrowStreamDataset").Device(DEVICE_CPU), ArrowStreamDatasetOp); +REGISTER_KERNEL_BUILDER(Name("IO>ArrowS3Dataset").Device(DEVICE_CPU), + ArrowS3DatasetOp); + } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/core/kernels/arrow/arrow_util.cc b/tensorflow_io/core/kernels/arrow/arrow_util.cc index b5d500883..5f2abd845 100644 --- a/tensorflow_io/core/kernels/arrow/arrow_util.cc +++ b/tensorflow_io/core/kernels/arrow/arrow_util.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ - #include "tensorflow_io/core/kernels/arrow/arrow_util.h" #include "arrow/adapters/tensorflow/convert.h" @@ -470,6 +469,342 @@ Status ParseHost(std::string host, std::string* host_address, return Status::OK(); } +enum calType { + CONSTANT, + VARIABLE, + ADD, + SUBTRACT, + MULTIPLY, + DIVIDE, + EQUAL, + NOT_EQUAL, + LESS, + LESS_EQUAL, + GREATER, + GREATER_EQUAL, + AND, + OR, + LPAREN, + RPAREN, +}; + +enum OpType { + OPERATOR, + OPERAND, +}; + +typedef struct Token { + Token(calType type, int value) : type_(type) { + if (type_ == calType::CONSTANT) { + expression_ = arrow::compute::literal(value); + } + } + Token(calType type, float value) : type_(type) { + if (type_ == calType::CONSTANT) { + expression_ = arrow::compute::literal(value); + } + } + Token(calType type, std::string func) : type_(type), func_(func) { + if (type_ == calType::VARIABLE) { + expression_ = arrow::compute::field_ref(func_); + } + } + calType type_; + std::string func_; + arrow::compute::Expression expression_; +} Token; + +typedef struct ASTNode { + ASTNode(std::shared_ptr token, std::shared_ptr left, + std::shared_ptr right) + : token_(token), left_(left), right_(right) {} + std::shared_ptr token_; + std::shared_ptr left_; + std::shared_ptr right_; +} ASTNode; + +class Lexer { + public: + Lexer(const std::string& text) + : text_(text), position_(0), cur_op_(OPERATOR){}; + void skip_space() { + while (position_ < text_.length() && text_[position_] == ' ') { + position_++; + } + } + + std::string get_constant() { + int start = position_ - 1; + while (position_ < text_.length() && std::isdigit(text_[position_]) || + '.' == text_[position_]) { + position_++; + } + return text_.substr(start, position_ - start).c_str(); + } + + calType get_comparison_type() { + // == != >= <= + char begin_char = text_[position_ - 1]; + if (position_ < text_.length() && text_[position_] == '=') { + position_++; + if (begin_char == '=') { + return calType::EQUAL; + } else if (begin_char == '!') { + return calType::NOT_EQUAL; + } else if (begin_char == '>') { + return calType::GREATER_EQUAL; + } else if (begin_char == '<') { + return calType::LESS_EQUAL; + } + } else { + if (begin_char == '>') { + return calType::GREATER; + } else if (begin_char == '<') { + return calType::LESS; + } + } + } + + std::string get_variable() { + int start = position_ - 1; + while (position_ < text_.length() && + (std::isalnum(text_[position_]) || '_' == text_[position_])) { + position_++; + } + return text_.substr(start, position_ - start); + } + + std::shared_ptr get_next_token() { + while (position_ < text_.length()) { + char current_char = text_[position_++]; + if (' ' == current_char) { + skip_space(); + } else if (std::isdigit(current_char)) { + cur_op_ = OPERAND; + std::string constant = get_constant(); + if (std::string::npos == constant.find('.')) { + return std::make_shared(calType::CONSTANT, + std::stoi(constant)); + } else { + return std::make_shared(calType::CONSTANT, + std::stof(constant)); + } + } else if (std::isalpha(current_char) || '_' == current_char) { + cur_op_ = OPERAND; + return std::make_shared(calType::VARIABLE, get_variable()); + } else if ('+' == current_char) { + cur_op_ = OPERATOR; + return std::make_shared(calType::ADD, "add"); + } else if ('-' == current_char) { + if (cur_op_ == OPERAND) { + cur_op_ = OPERATOR; + return std::make_shared(calType::SUBTRACT, "subtract"); + } else { + cur_op_ = OPERAND; + std::string constant = get_constant(); + if 
(constant.length() <= 1) { + return nullptr; + } + if (std::string::npos == constant.find('.')) { + return std::make_shared(calType::CONSTANT, + std::stoi(constant)); + } else { + return std::make_shared(calType::CONSTANT, + std::stof(constant)); + } + } + } else if ('*' == current_char) { + cur_op_ = OPERATOR; + return std::make_shared(calType::MULTIPLY, "multiply"); + } else if ('/' == current_char) { + cur_op_ = OPERATOR; + return std::make_shared(calType::DIVIDE, "divide"); + } else if ('(' == current_char) { + cur_op_ = OPERATOR; + return std::make_shared(calType::LPAREN, "("); + } else if (')' == current_char) { + cur_op_ = OPERAND; + return std::make_shared(calType::RPAREN, ")"); + } else if ('=' == current_char || '!' == current_char || + '>' == current_char || '<' == current_char) { + cur_op_ = OPERATOR; + auto type = get_comparison_type(); + if (calType::EQUAL == type) { + return std::make_shared(calType::EQUAL, "equal"); + } else if (calType::NOT_EQUAL == type) { + return std::make_shared(calType::NOT_EQUAL, "not_equal"); + } else if (calType::LESS == type) { + return std::make_shared(calType::LESS, "less"); + } else if (calType::LESS_EQUAL == type) { + return std::make_shared(calType::LESS_EQUAL, "less_equal"); + } else if (calType::GREATER == type) { + return std::make_shared(calType::GREATER, "greater"); + } else if (calType::GREATER_EQUAL == type) { + return std::make_shared(calType::GREATER_EQUAL, + "greater_equal"); + } + } else if ('&' == current_char) { + cur_op_ = OPERATOR; + if (position_ < text_.length() && '&' == text_[position_]) { + position_++; + return std::make_shared(calType::AND, "and"); + } + } else if ('|' == current_char) { + cur_op_ = OPERATOR; + if (position_ < text_.length() && '|' == text_[position_]) { + position_++; + return std::make_shared(calType::OR, "or"); + } + } + } + return nullptr; + } + + private: + OpType cur_op_; + int position_; + std::string text_; +}; + +class Parser { + public: + Parser(std::shared_ptr ptr) : lexer_ptr_(ptr) { + current_token_ = lexer_ptr_->get_next_token(); + } + + inline void update_current_token() { + current_token_ = lexer_ptr_->get_next_token(); + } + + // constant, variable, lparen + std::shared_ptr factor() { + if (!current_token_) { + return nullptr; + } + auto token = current_token_; + if (token->type_ == calType::CONSTANT) { + update_current_token(); + return std::make_shared(token, nullptr, nullptr); + } else if (token->type_ == calType::VARIABLE) { + update_current_token(); + return std::make_shared(token, nullptr, nullptr); + } else if (token->type_ == calType::LPAREN) { + update_current_token(); + auto node = logical(); + update_current_token(); + return node; + } + return nullptr; + } + + // multiply, divide + std::shared_ptr term() { + auto node = factor(); + while (current_token_ && (current_token_->type_ == calType::MULTIPLY || + current_token_->type_ == calType::DIVIDE)) { + auto token = current_token_; + update_current_token(); + node = std::make_shared(token, node, factor()); + } + return node; + } + + // add, subtract + std::shared_ptr expr() { + auto node = term(); + while (current_token_ && (current_token_->type_ == calType::ADD || + current_token_->type_ == calType::SUBTRACT)) { + auto token = current_token_; + update_current_token(); + node = std::make_shared(token, node, term()); + } + return node; + } + + // Comparison + std::shared_ptr comparison() { + auto node = expr(); + while (current_token_ && + (current_token_->type_ >= calType::EQUAL && + current_token_->type_ <= 
calType::GREATER_EQUAL)) { + auto token = current_token_; + update_current_token(); + node = std::make_shared(token, node, expr()); + } + return node; + } + + // Logical + std::shared_ptr logical() { + auto node = comparison(); + while (current_token_ && (current_token_->type_ == calType::AND || + current_token_->type_ == calType::OR)) { + auto token = current_token_; + update_current_token(); + node = std::make_shared(token, node, comparison()); + } + return node; + } + + private: + std::shared_ptr lexer_ptr_; + std::shared_ptr current_token_; +}; + +class Interpreter { + public: + Interpreter(std::shared_ptr parser) : parser_(parser) {} + arrow::compute::Expression visit(std::shared_ptr root) { + auto rt = root->token_; + auto rlt = root->left_->token_; + auto rrt = root->right_->token_; + if (rlt->type_ != calType::CONSTANT && rlt->type_ != calType::VARIABLE) { + visit(root->left_); + } + if (rrt->type_ != calType::CONSTANT && rrt->type_ != calType::VARIABLE) { + visit(root->right_); + } + + if (rt->type_ >= calType::ADD && rt->type_ <= calType::OR) { + rt->expression_ = + arrow::compute::call(rt->func_, {rlt->expression_, rrt->expression_}); + } + rt->type_ = calType::VARIABLE; + return rt->expression_; + } + + Status interpreter(std::shared_ptr& ASTree) { + auto root = parser_->logical(); + if (!root || !root->left_ || !root->right_ || + root->token_->type_ < calType::EQUAL || + root->token_->type_ > calType::OR) { + return errors::InvalidArgument( + "Your filter expression is not supported!"); + } + ASTree = root; + return Status::OK(); + } + + private: + std::shared_ptr parser_; +}; + +Status ParseExpression(const std::string& text, + arrow::compute::Expression& expr) { + auto lexer_ptr = std::make_shared(text); + auto parser_ptr = std::make_shared(lexer_ptr); + auto interpreter_ptr = std::make_shared(parser_ptr); + + std::shared_ptr ASTree; + auto status = interpreter_ptr->interpreter(ASTree); + if (!status.ok()) { + return status; + } + + expr = interpreter_ptr->visit(ASTree); + return Status::OK(); +} + } // namespace ArrowUtil } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/core/kernels/arrow/arrow_util.h b/tensorflow_io/core/kernels/arrow/arrow_util.h index 9dabc16c5..f78f98c48 100644 --- a/tensorflow_io/core/kernels/arrow/arrow_util.h +++ b/tensorflow_io/core/kernels/arrow/arrow_util.h @@ -17,6 +17,9 @@ limitations under the License. 
#define TENSORFLOW_IO_CORE_KERNELS_ARROW_UTIL_H_ #include "arrow/api.h" +#include "arrow/dataset/dataset.h" +#include "arrow/dataset/file_parquet.h" +#include "arrow/filesystem/s3fs.h" #include "arrow/ipc/api.h" #include "arrow/util/io_util.h" #include "tensorflow/core/framework/tensor.h" @@ -80,6 +83,10 @@ Status ParseEndpoint(std::string endpoint, std::string* endpoint_type, Status ParseHost(std::string host, std::string* host_address, std::string* host_port); +// Parse a filter expression string into an arrow::compute::Expression used as a scan filter +Status ParseExpression(const std::string& text, + arrow::compute::Expression& expr); + } // namespace ArrowUtil } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/core/ops/arrow_ops.cc b/tensorflow_io/core/ops/arrow_ops.cc index ceff0a7c7..dfd8f8196 100644 --- a/tensorflow_io/core/ops/arrow_ops.cc +++ b/tensorflow_io/core/ops/arrow_ops.cc @@ -188,4 +188,29 @@ REGISTER_OP("IO>ArrowReadableRead") return Status::OK(); }); +REGISTER_OP("IO>ArrowS3Dataset") + .Input("aws_access_key: string") + .Input("aws_secret_key: string") + .Input("aws_endpoint_override: string") + .Input("parquet_files: string") + .Input("column_names: string") + .Input("filter: string") + .Input("columns: int32") + .Input("batch_size: int64") + .Input("batch_mode: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset from parquet files stored on S3. + +aws_access_key: S3 access key. +aws_secret_key: S3 secret key. +aws_endpoint_override: S3 endpoint override. +parquet_files: One or more parquet file paths on S3. +column_names: Names of the columns to read. +)doc"); + } // namespace tensorflow diff --git a/tensorflow_io/python/ops/arrow_dataset_ops.py b/tensorflow_io/python/ops/arrow_dataset_ops.py index e051c0b75..6f7e59461 100644 --- a/tensorflow_io/python/ops/arrow_dataset_ops.py +++ b/tensorflow_io/python/ops/arrow_dataset_ops.py @@ -651,7 +651,84 @@ def gen_record_batches(): ) +class ArrowS3Dataset(ArrowBaseDataset): + """An Arrow Dataset for reading record batches from parquet files on S3, + with optional column selection and row filtering. + """ + + def __init__( + self, + aws_access_key, + aws_secret_key, + aws_endpoint_override, + parquet_files, + column_names, + columns, + output_types, + output_shapes=None, + batch_size=None, + batch_mode="keep_remainder", + filter="", + ): + """Create an ArrowS3Dataset from parquet files stored on S3. + + Args: + aws_access_key: S3 access key + aws_secret_key: S3 secret key + aws_endpoint_override: S3 endpoint override + parquet_files: A list of parquet file paths on S3 + column_names: A list of column names to be used in the dataset + columns: A list of column indices to be used in the Dataset + output_types: Tensor dtypes of the output tensors + output_shapes: TensorShapes of the output tensors or None to + infer partial + batch_size: Batch size of output tensors, setting a batch size here + will create batched tensors from Arrow memory and can be more + efficient than using tf.data.Dataset.batch(). 
+ NOTE: batch_size does not need to be set if batch_mode='auto' + batch_mode: Mode of batching, supported strings: + "keep_remainder" (default, keeps partial batch data), + "drop_remainder" (discard partial batch data), + "auto" (size to number of records in Arrow record batch) + filter: Filter expression used to select which rows are read + """ + aws_access_key = tf.convert_to_tensor( + aws_access_key, dtype=dtypes.string, name="aws_access_key" + ) + aws_secret_key = tf.convert_to_tensor( + aws_secret_key, dtype=dtypes.string, name="aws_secret_key" + ) + aws_endpoint_override = tf.convert_to_tensor( + aws_endpoint_override, dtype=dtypes.string, name="aws_endpoint_override" + ) + parquet_files = tf.convert_to_tensor( + parquet_files, dtype=dtypes.string, name="parquet_files" + ) + column_names = tf.convert_to_tensor( + column_names, dtype=dtypes.string, name="column_names" + ) + filter = tf.convert_to_tensor(filter, dtype=dtypes.string, name="filter") + + super().__init__( + partial( + core_ops.io_arrow_s3_dataset, + aws_access_key, + aws_secret_key, + aws_endpoint_override, + parquet_files, + column_names, + filter, + ), + columns, + output_types, + output_shapes, + batch_size, + batch_mode, + ) + + def list_feather_columns(filename, **kwargs): + """list_feather_columns""" if not tf.executing_eagerly(): raise NotImplementedError("list_feather_columns only support eager mode") diff --git a/third_party/arrow.BUILD b/third_party/arrow.BUILD index 4dbce8ede..84e3f5cd6 100644 --- a/third_party/arrow.BUILD +++ b/third_party/arrow.BUILD @@ -56,7 +56,12 @@ cc_library( [ "cpp/src/arrow/*.cc", "cpp/src/arrow/array/*.cc", + "cpp/src/arrow/compute/*.cc", + "cpp/src/arrow/compute/exec/*.cc", + "cpp/src/arrow/compute/kernels/*.cc", "cpp/src/arrow/csv/*.cc", + "cpp/src/arrow/dataset/*.cc", + "cpp/src/arrow/filesystem/*.cc", "cpp/src/arrow/io/*.cc", "cpp/src/arrow/ipc/*.cc", "cpp/src/arrow/json/*.cc", @@ -66,6 +71,10 @@ cc_library( "cpp/src/arrow/vendored/optional.hpp", "cpp/src/arrow/vendored/string_view.hpp", "cpp/src/arrow/vendored/variant.hpp", + "cpp/src/arrow/vendored/base64.cpp", + "cpp/src/arrow/vendored/datetime/tz.cpp", + "cpp/src/arrow/vendored/uriparser/*.c", + "cpp/src/arrow/vendored/pcg/*.hpp", "cpp/src/arrow/**/*.h", "cpp/src/parquet/**/*.h", "cpp/src/parquet/**/*.cc", @@ -76,10 +85,11 @@ cc_library( "cpp/src/**/*_benchmark.cc", "cpp/src/**/*_main.cc", "cpp/src/**/*_nossl.cc", - "cpp/src/**/*_test.cc", - "cpp/src/**/test_*.cc", + "cpp/src/**/*test*.h", + "cpp/src/**/*test*.cc", "cpp/src/**/*hdfs*.cc", "cpp/src/**/*fuzz*.cc", + "cpp/src/**/*gcsfs*.cc", "cpp/src/**/file_to_stream.cc", "cpp/src/**/stream_to_file.cc", "cpp/src/arrow/util/bpacking_avx2.cc", @@ -94,6 +104,12 @@ cc_library( "cpp/src/parquet/parquet_version.h", ], copts = [], + linkopts = select({ + "@bazel_tools//src/conditions:windows": [ + "-DEFAULTLIB:Ole32.lib", + ], + "//conditions:default": [], + }), defines = [ "ARROW_WITH_BROTLI", "ARROW_WITH_SNAPPY", @@ -106,16 +122,21 @@ cc_library( "PARQUET_STATIC", "PARQUET_EXPORT=", "WIN32_LEAN_AND_MEAN", + "ARROW_DS_STATIC", + "URI_STATIC_BUILD", ], includes = [ "cpp/src", "cpp/src/arrow/vendored/xxhash", + "cpp/src/generated", ], textual_hdrs = [ "cpp/src/arrow/vendored/xxhash/xxhash.c", ], deps = [ ":arrow_format", + "@aws-sdk-cpp//:identity-management", + "@aws-sdk-cpp//:s3", "@boringssl//:crypto", "@brotli", "@bzip2", diff --git a/third_party/aws-sdk-cpp.BUILD b/third_party/aws-sdk-cpp.BUILD index ba7d90bcb..16e9cc9d1 100644 --- a/third_party/aws-sdk-cpp.BUILD +++ b/third_party/aws-sdk-cpp.BUILD @@ -163,6
+163,61 @@ cc_library( ], ) +cc_library( + name = "cognito-identity", + srcs = glob([ + "aws-cpp-sdk-cognito-identity/source/*.cpp", + "aws-cpp-sdk-cognito-identity/source/model/*.cpp", + ]), + hdrs = glob([ + "aws-cpp-sdk-cognito-identity/include/aws/cognito-identity/*.h", + "aws-cpp-sdk-cognito-identity/include/aws/cognito-identity/model/*.h", + ]), + includes = [ + "aws-cpp-sdk-cognito-identity/include", + ], + deps = [ + ":core", + ], +) + +cc_library( + name = "sts", + srcs = glob([ + "aws-cpp-sdk-sts/source/*.cpp", + "aws-cpp-sdk-sts/source/model/*.cpp", + ]), + hdrs = glob([ + "aws-cpp-sdk-sts/include/aws/sts/*.h", + "aws-cpp-sdk-sts/include/aws/sts/model/*.h", + ]), + includes = [ + "aws-cpp-sdk-sts/include", + ], + deps = [ + ":core", + ], +) + +cc_library( + name = "identity-management", + srcs = glob([ + "aws-cpp-sdk-identity-management/source/auth/*.cpp", + ]), + hdrs = glob([ + "aws-cpp-sdk-identity-management/include/aws/identity-management/*.h", + "aws-cpp-sdk-identity-management/include/aws/identity-management/auth/*.h", + ]), + includes = [ + "aws-cpp-sdk-identity-management/include", + ], + deps = [ + ":cognito-identity", + ":core", + ":sts", + ], +) + genrule( name = "SDKConfig_h", outs = [