Skip to content

Commit

Permalink
TFRecord to support S3 index URIs (NVIDIA#5515)
Browse files Browse the repository at this point in the history
The TFRecord index file reader didn't use the FileStream::Open abstraction and therefore did not support S3 storage as advertised.
Added a FileStreamBuf abstraction (a std::streambuf adapter, useful together with std::istream).

Signed-off-by: Joaquin Anton <[email protected]>
  • Loading branch information
jantonguirao authored Jun 11, 2024
1 parent 9dfdeac commit 9420fb8
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
15 changes: 12 additions & 3 deletions dali/operators/reader/loader/indexed_file_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "dali/util/uri.h"
#include "dali/util/file.h"
#include "dali/util/odirect_file.h"
#include "dali/core/call_at_exit.h"

namespace dali {

Expand Down Expand Up @@ -192,13 +193,21 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>, true> {
DALI_ENFORCE(index_uris.size() == paths_.size(),
"Number of index files needs to match the number of data files");
for (size_t i = 0; i < index_uris.size(); ++i) {
std::ifstream fin(index_uris[i]);
DALI_ENFORCE(fin.good(), "Failed to open file " + index_uris[i]);
const auto& path = index_uris[i];
auto uri = URI::Parse(path);
auto index_file = FileStream::Open(path);
auto index_file_cleanup = AtScopeExit([&index_file] {
if (index_file)
index_file->Close();
});

FileStreamBuf<> stream_buf(index_file.get());
std::istream fin(&stream_buf);
DALI_ENFORCE(fin.good(), "Failed to open file " + path);
int64 pos, size;
while (fin >> pos >> size) {
indices_.emplace_back(pos, size, i);
}
fin.close();
}
}

Expand Down
28 changes: 28 additions & 0 deletions dali/util/file.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#define DALI_UTIL_FILE_H_

#include <cstdio>
#include <streambuf>
#include <memory>
#include <string>
#include <optional>
Expand Down Expand Up @@ -107,6 +108,33 @@ class DLL_PUBLIC FileStream : public InputStream {
std::string path_;
};

/**
 * @brief Custom streambuf implementation that reads from FileStream.
 * @remarks It is useful to be used together with std::istream
 *
 * The buffer is input-only: only `underflow` is overridden, so the get area is
 * refilled from the FileStream in chunks of up to `BufferSize` bytes. Seeking and
 * output are not implemented here.
 *
 * The stream buffer does not own the FileStream. The caller must pass a non-null
 * pointer, keep the FileStream alive for as long as this object is in use, and
 * close it itself.
 *
 * @tparam BufferSize size, in bytes, of the internal read buffer; must be non-zero.
 */
template <size_t BufferSize = (1 << 10)>
class FileStreamBuf : public std::streambuf {
  static_assert(BufferSize > 0, "FileStreamBuf requires a non-empty buffer");

 public:
  /**
   * @brief Creates a stream buffer reading from `reader`.
   * @param reader non-owning pointer to the stream to read from; it is dereferenced
   *               on the first read, so it must not be null.
   */
  explicit FileStreamBuf(FileStream *reader) : reader_(reader) {
    setg(buffer_, buffer_, buffer_);  // Empty get area - the first read triggers underflow
  }

  // std::streambuf's (protected) copy operations copy the get-area pointers verbatim.
  // An implicitly generated copy would therefore point into the *source* object's
  // `buffer_` and share its FileStream, so copying is disabled.
  FileStreamBuf(const FileStreamBuf &) = delete;
  FileStreamBuf &operator=(const FileStreamBuf &) = delete;

 protected:
  /**
   * @brief Refills the get area from the FileStream when it is exhausted.
   * @return the next available character (without consuming it), or
   *         `traits_type::eof()` when the FileStream returned no more data.
   */
  int_type underflow() override {
    if (gptr() == egptr()) {  // get area is exhausted
      size_t nbytes = reader_->Read(buffer_, BufferSize);
      if (nbytes == 0)  // a zero-length read is treated as end of stream
        return traits_type::eof();
      setg(buffer_, buffer_, buffer_ + nbytes);
    }
    return traits_type::to_int_type(*gptr());
  }

 private:
  FileStream *reader_;       // not owned
  char buffer_[BufferSize];  // backing storage of the get area
};

} // namespace dali

#endif // DALI_UTIL_FILE_H_

0 comments on commit 9420fb8

Please sign in to comment.