Structured zarr dtype as bytes #176

Open
wants to merge 46 commits into
base: master
4aa8cae
Implemented concept of "open as byte array" for any dtype
brian-michell Jun 20, 2024
ac6ba5c
Added open option for byte array
brian-michell Jun 21, 2024
6b43911
Removed open option for byte array
brian-michell Jun 21, 2024
c4beb11
Allow modification of ParseDTypeOptions
brian-michell Jun 21, 2024
43cc7c8
Update default behavior
brian-michell Jun 21, 2024
a76363a
Emplace instead of replace
brian-michell Jun 21, 2024
6342417
Update field selection to use emplaced field
brian-michell Jun 21, 2024
f684f17
Stopped synthetic dimension from getting serialized
brian-michell Jun 24, 2024
93e92ce
Pushback synthetic field rather than make it the only field
brian-michell Jun 24, 2024
874d581
Skip the synthetic field as part of the number of bytes
brian-michell Jun 24, 2024
c7408d4
Fix const qualifier
brian-michell Jun 24, 2024
83acfa6
Properly modify const object
brian-michell Jun 24, 2024
7ac53a4
Properly modify has_fields logic
brian-michell Jun 24, 2024
2c0bdc3
Pulled check outside loop and fixed logic error
brian-michell Jun 24, 2024
442daa0
Explicit initialization of values. Ensured has_fields for structured d…
brian-michell Jun 24, 2024
92145fc
Implemented proper field_shape for synthetic field
brian-michell Jun 24, 2024
77ce1ca
Added valid case for an empty named field
brian-michell Jun 24, 2024
1113ec7
Disabled setting has_fields to false
brian-michell Jun 24, 2024
9d9a555
Fix encoding
brian-michell Jun 24, 2024
1d597ad
Fixed chunkgrid to use the true size
brian-michell Jun 25, 2024
ea1ae5c
Temporary drop of synthetic field for metadata validation
brian-michell Jun 25, 2024
5355201
Fix single field dtype without name incorrectly passing
brian-michell Jun 25, 2024
5c9bd16
Fix multiple synthetic fields getting added
brian-michell Jun 25, 2024
c4f0178
Removed extra argument
brian-michell Jun 25, 2024
288b175
Catch case where there was no field due to being a struct array of si…
brian-michell Jun 25, 2024
de830ef
Ignore the synthetic dimension when chunking
brian-michell Jun 25, 2024
619a03a
Remove the synthetic dimension when it is not needed
brian-michell Jun 25, 2024
bb68abf
Update tests to reflect DType without going through the spec driver p…
brian-michell Jun 25, 2024
e40dfdd
Update test to reflect metadata containing synthetic dimension. (Hand…
brian-michell Jun 25, 2024
8fc4fa6
Remove check that is no longer true
brian-michell Jun 25, 2024
9912b32
Remove test that is no longer true
brian-michell Jun 25, 2024
6015c31
Only use the actual fields for fill values
brian-michell Jun 25, 2024
5ae3d36
Cleaning up convention
brian-michell Jun 25, 2024
8ca6c74
Fixed incorrect assertion
brian-michell Jun 25, 2024
211beab
Add beginning of new test for field-free opening
brian-michell Jun 26, 2024
a4d7e73
Fix num_bytes not getting set
brian-michell Jun 26, 2024
e51b3fe
Allow the metadata to be one size higher in the event of structured d…
brian-michell Jun 26, 2024
f30948c
Properly specify the chunk grid for synthetic data field
brian-michell Jun 26, 2024
824d5bc
Added fallback for zarr synthetic field
brian-michell Jun 26, 2024
364a569
Fix test to reflect num_bytes getting properly set
brian-michell Jun 26, 2024
620af12
Disabled ignoring 'true size' for chunk layout
brian-michell Jul 1, 2024
60060a7
Fixed metadata caching causing extra dimension getting added
brian-michell Jul 1, 2024
5b56a98
Fix hardcoded case for opening existing zarr without field
brian-michell Jul 5, 2024
08fbd48
Implement naive fix for fill_value assertion failure on no field str…
brian-michell Jul 5, 2024
c1c0ccd
Remove incorrect decrement of struct data
brian-michell Jul 5, 2024
96144b7
Added test for opening existing store as struct array
brian-michell Jul 9, 2024
12 changes: 11 additions & 1 deletion tensorstore/driver/driver.cc
@@ -103,7 +103,17 @@ Future<Driver::Handle> OpenDriver(TransformedDriverSpec bound_spec,
if (composed_transform.ok()) {
handle->transform = std::move(composed_transform).value();
} else {
status = composed_transform.status();
// Fallback for Zarr driver opening without field specified.
if ((handle->transform.domain().rank() + 1) == bound_spec.transform.domain().rank()) {
// TODO: Make this a safer fallback. There may be a way to do it at the Zarr driver level.
// Just use the spec's transform twice... What's the worst that could happen!?
composed_transform = tensorstore::ComposeTransforms(std::move(bound_spec.transform), std::move(bound_spec.transform));
if (composed_transform.ok()) {
handle->transform = std::move(composed_transform).value();
}
} else {
status = composed_transform.status();
}
}
}

56 changes: 46 additions & 10 deletions tensorstore/driver/zarr/driver.cc
@@ -121,8 +121,12 @@ Result<MetadataCache::MetadataPtr> MetadataCache::DecodeMetadata(

Result<absl::Cord> MetadataCache::EncodeMetadata(std::string_view entry_key,
const void* metadata) {
return absl::Cord(
::nlohmann::json(*static_cast<const ZarrMetadata*>(metadata)).dump());
auto meta = ::nlohmann::json(*static_cast<const ZarrMetadata*>(metadata));
if (meta["dtype"].is_array() && meta["dtype"].back()[0] == "") {
meta["dtype"].erase(meta["dtype"].size() - 1);
}

return absl::Cord(meta.dump());
}

absl::Status ZarrDriverSpec::ApplyOptions(SpecOptions&& options) {
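
The `EncodeMetadata` change in the hunk above drops the trailing empty-named (synthetic) field from `dtype` before the metadata is written out. A minimal standalone sketch of that idea follows, assuming a simplified list of (name, encoded dtype) pairs in place of `::nlohmann::json`. `FieldList` and `StripSyntheticField` are hypothetical names for illustration, not TensorStore APIs, and the `size() > 1` guard is an addition of this sketch rather than something the PR does at this point.

```cpp
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for the JSON "dtype" array: (name, encoded dtype) pairs.
// Assumption for illustration only; the PR operates on ::nlohmann::json.
using FieldList = std::vector<std::pair<std::string, std::string>>;

// Returns a copy of `fields` without the trailing synthetic entry, which is
// recognized by its empty name. A list with a single entry is left untouched
// (this guard is an assumption of the sketch, not taken from the PR).
FieldList StripSyntheticField(FieldList fields) {
  if (fields.size() > 1 && fields.back().first.empty()) {
    fields.pop_back();
  }
  return fields;
}
```

Calling `StripSyntheticField` on a list that has no empty-named trailing entry returns it unchanged, so the step is safe to apply unconditionally before serialization.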
@@ -264,7 +268,8 @@ void DataCache::GetChunkGridBounds(const void* metadata_ptr,
DimensionSet& implicit_lower_bounds,
DimensionSet& implicit_upper_bounds) {
const auto& metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);
assert(bounds.rank() == static_cast<DimensionIndex>(metadata.shape.size()));
assert((bounds.rank() == static_cast<DimensionIndex>(metadata.shape.size())) ||
(bounds.rank()+1 == static_cast<DimensionIndex>(metadata.shape.size())));
std::fill(bounds.origin().begin(), bounds.origin().end(), Index(0));
std::copy(metadata.shape.begin(), metadata.shape.end(),
bounds.shape().begin());
@@ -291,15 +296,32 @@ Result<std::shared_ptr<const void>> DataCache::GetResizedMetadata(

internal::ChunkGridSpecification DataCache::GetChunkGridSpecification(
const ZarrMetadata& metadata) {
bool flag = false;

if (metadata.shape.size() - 1 == metadata.rank) {
flag = true;
const_cast<ZarrMetadata&>(metadata).shape.pop_back();
const_cast<ZarrMetadata&>(metadata).chunks.pop_back();
}
size_t true_size = metadata.dtype.fields.size();
internal::ChunkGridSpecification::ComponentList components;
components.reserve(metadata.dtype.fields.size());
components.reserve(true_size);
std::vector<DimensionIndex> chunked_to_cell_dimensions(
metadata.chunks.size());
std::iota(chunked_to_cell_dimensions.begin(),
chunked_to_cell_dimensions.end(), static_cast<DimensionIndex>(0));
for (size_t field_i = 0; field_i < metadata.dtype.fields.size(); ++field_i) {
for (size_t field_i = 0; field_i < true_size; ++field_i) {
const auto& field = metadata.dtype.fields[field_i];

if (field.name.empty() && true_size > 1 && !flag) {
// Fix the synthetic field
const_cast<ZarrMetadata&>(metadata).chunk_layout.fields[field_i].decoded_chunk_layout = metadata.chunk_layout.fields[0].decoded_chunk_layout;
// We need to "add" a dimension or there will be an illegal transform
const_cast<ZarrMetadata&>(metadata).shape.push_back(metadata.dtype.fields.back().num_bytes);
const_cast<ZarrMetadata&>(metadata).chunks.push_back(0); // No chunking in the synthetic dimension
}
const auto& field_layout = metadata.chunk_layout.fields[field_i];

auto fill_value = metadata.fill_value[field_i];
const bool fill_value_specified = fill_value.valid();
if (!fill_value.valid()) {
@@ -323,7 +345,13 @@ internal::ChunkGridSpecification DataCache::GetChunkGridSpecification(
for (DimensionIndex cell_dim = fill_value_start_dim; cell_dim < cell_rank;
++cell_dim) {
const Index size = field_layout.full_chunk_shape()[cell_dim];
assert(fill_value.shape()[cell_dim - fill_value_start_dim] == size);

if(field.name.empty() && true_size > 1 && field_i+1 == true_size) {
// TODO: Figure out how this case should be properly handled
} else {
assert(fill_value.shape()[cell_dim - fill_value_start_dim] == size);
}

chunk_fill_value.shape()[cell_dim] = size;
chunk_fill_value.byte_strides()[cell_dim] =
fill_value.byte_strides()[cell_dim - fill_value_start_dim];
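
The `GetChunkGridSpecification` changes above append one extra dimension to `shape` and `chunks` when a synthetic byte-array field is present, so that each outer element exposes its raw bytes. A rough standalone sketch of that bookkeeping is below. `ArrayLayout`, `AppendByteDimension`, and `kUnchunked` are hypothetical names, and the use of `0` for the chunk entry simply mirrors the PR's `chunks.push_back(0)` ("No chunking in the synthetic dimension"); whether `0` is the right long-term value is not settled by the diff.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical, simplified view of the array metadata touched by the PR.
struct ArrayLayout {
  std::vector<std::int64_t> shape;
  std::vector<std::int64_t> chunks;
};

// Marker mirroring the PR's `chunks.push_back(0)` for the synthetic dimension.
constexpr std::int64_t kUnchunked = 0;

// Returns a copy of `layout` with a trailing dimension whose extent is the
// number of bytes in one outer (struct) element.
ArrayLayout AppendByteDimension(ArrayLayout layout,
                                std::int64_t bytes_per_outer_element) {
  assert(layout.shape.size() == layout.chunks.size());
  assert(bytes_per_outer_element > 0);
  layout.shape.push_back(bytes_per_outer_element);
  layout.chunks.push_back(kUnchunked);
  return layout;
}
```

For example, a `{100, 100}` array chunked as `{3, 2}` with 6 bytes per element would become shape `{100, 100, 6}` with chunks `{3, 2, 0}` under this sketch.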
@@ -504,13 +532,21 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase {

Result<size_t> GetComponentIndex(const void* metadata_ptr,
OpenMode open_mode) override {
const auto& metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);
const auto& const_metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);

ZarrMetadata metadata = const_metadata;
// Modify temporary metadata objects for validation. There is probably a better way!
if (!spec().selected_field.empty() && metadata.dtype.fields.back().name.empty()) {
metadata.dtype.fields.pop_back();
}
// We're only going to use our modified variables up to here.
TENSORSTORE_RETURN_IF_ERROR(
ValidateMetadata(metadata, spec().partial_metadata));
ValidateMetadata(metadata, spec().partial_metadata));

TENSORSTORE_ASSIGN_OR_RETURN(
auto field_index, GetFieldIndex(metadata.dtype, spec().selected_field));
auto field_index, GetFieldIndex(const_metadata.dtype, spec().selected_field));
TENSORSTORE_RETURN_IF_ERROR(
ValidateMetadataSchema(metadata, field_index, spec().schema));
ValidateMetadataSchema(metadata, field_index, spec().schema));
return field_index;
}
};
64 changes: 64 additions & 0 deletions tensorstore/driver/zarr/driver_test.cc
@@ -1125,6 +1125,70 @@ TEST(ZarrDriverTest, CreateLittleEndianUnaligned) {
}));
}

TEST(ZarrDriverTest, OpenWithoutField) {
::nlohmann::json json_spec{
{"driver", "zarr"},
{"kvstore",
{
{"driver", "memory"},
{"path", "prefix/"},
}},
{"metadata",
{
{"compressor", {{"id", "blosc"}}},
{"dtype", ::nlohmann::json::array_t{{"x", "|b1"}, {"y", "<i2"}}},
{"shape", {100, 100}},
{"chunks", {3, 2}},
}},
};
auto context = Context::Default();

auto store_res = tensorstore::Open(json_spec, context, tensorstore::OpenMode::create,
tensorstore::ReadWriteMode::read_write);
ASSERT_TRUE(store_res.status().ok()) << store_res.status();
}

TEST(ZarrDriverTest, OpenExistingWithoutField) {
::nlohmann::json json_spec{
{"driver", "zarr"},
{"kvstore",
{
{"driver", "memory"},
{"path", "prefix/"},
}},
{"metadata",
{
{"compressor", {{"id", "blosc"}}},
// I changed the standard test here to ensure that it was getting opened as a byte array
{"dtype", ::nlohmann::json::array_t{{"y", "<i2"}, {"x", "<f4"}}},
{"shape", {100, 100}},
{"chunks", {3, 2}},
}},
};
auto context = Context::Default();

auto store_res = tensorstore::Open(json_spec, context, tensorstore::OpenMode::create,
tensorstore::ReadWriteMode::read_write);
ASSERT_TRUE(store_res.status().ok()) << store_res.status();

auto creation_store = store_res.value();
auto creation_dtype = creation_store.dtype();
auto creation_rank = creation_store.rank();
auto creation_domain = creation_store.domain();

json_spec.erase("metadata");

auto store_res_open = tensorstore::Open(json_spec, context, tensorstore::OpenMode::open,
tensorstore::ReadWriteMode::read_write);

ASSERT_TRUE(store_res_open.status().ok()) << store_res_open.status();
auto open_store = store_res_open.value();

EXPECT_EQ(creation_store.dtype(), open_store.dtype());
EXPECT_EQ(creation_store.rank(), open_store.rank());
EXPECT_EQ(creation_store.domain(), open_store.domain());
}

TEST(ZarrDriverTest, CreateComplexWithFillValue) {
::nlohmann::json json_spec{
{"driver", "zarr"},
46 changes: 42 additions & 4 deletions tensorstore/driver/zarr/dtype.cc
@@ -214,9 +214,10 @@ namespace {
/// \param value The zarr metadata "dtype" JSON specification.
/// \param out[out] Must be non-null. Filled with the parsed dtype on success.
/// \error `absl::StatusCode::kInvalidArgument` if `value` is invalid.
Result<ZarrDType> ParseDTypeNoDerived(const nlohmann::json& value) {
Result<ZarrDType> ParseDTypeNoDerived(const nlohmann::json& value, const ParseDTypeOptions& options) {
ZarrDType out;
if (value.is_string()) {
const_cast<ParseDTypeOptions&>(options).treat_struct_as_byte_array = false;
// Single field.
out.has_fields = false;
out.fields.resize(1);
@@ -247,6 +248,10 @@ Result<ZarrDType> ParseDTypeNoDerived(const nlohmann::json& value) {
switch (i) {
case 0:
if (internal_json::JsonRequireValueAs(v, &field.name).ok()) {
// This SHOULD only be the case if a field was not provided
if (field_i > 0 && field_i == (out.fields.size() - 1) && field.name.empty()) {
return absl::OkStatus();
}
if (!field.name.empty()) return absl::OkStatus();
}
return absl::InvalidArgumentError(tensorstore::StrCat(
@@ -279,14 +284,39 @@ Result<ZarrDType> ParseDTypeNoDerived(const nlohmann::json& value) {
});
});
if (!parse_result.ok()) return parse_result;

if (options.treat_struct_as_byte_array) {

// Check if we've already added a synthetic field
if (!out.fields.back().name.empty()) {
// Convert struct dtype to a single byte array dtype.
ZarrDType::Field byte_array_field;
byte_array_field.name = "";
byte_array_field.dtype = dtype_v<std::byte>;
byte_array_field.endian = endian::native;
byte_array_field.encoded_dtype = "|V";
byte_array_field.flexible_shape = {out.bytes_per_outer_element};
byte_array_field.num_inner_elements = 0; // I don't think I need to set this explicitly
byte_array_field.byte_offset = 0; // I don't think I need to set this explicitly
byte_array_field.num_bytes = 0; // This will get set properly elsewhere, but let's init it to 0 anyway
out.fields.push_back({byte_array_field});
}
}
return out;
}

} // namespace

absl::Status ValidateDType(ZarrDType& dtype) {
dtype.bytes_per_outer_element = 0;
for (size_t field_i = 0; field_i < dtype.fields.size(); ++field_i) {
size_t num_fields = dtype.fields.size();
// TODO: Implement better logic and name
bool flag = false;
if (dtype.fields.back().name.empty() && dtype.fields.size() > 1) {
--num_fields;
flag = true;
}
for (size_t field_i = 0; field_i < num_fields; ++field_i) {
auto& field = dtype.fields[field_i];
if (std::any_of(
dtype.fields.begin(), dtype.fields.begin() + field_i,
@@ -317,11 +347,19 @@ absl::Status ValidateDType(ZarrDType& dtype) {
"Total number of bytes per outer array element is too large");
}
}
if (flag) {
dtype.fields[num_fields].field_shape = {dtype.bytes_per_outer_element};
// Check that we haven't already added the size to the encoding
if (dtype.fields[num_fields].encoded_dtype.size() == 2) {
dtype.fields[num_fields].encoded_dtype += std::to_string(dtype.bytes_per_outer_element);
dtype.fields[num_fields].num_bytes = dtype.bytes_per_outer_element;
}
}
return absl::OkStatus();
}

Result<ZarrDType> ParseDType(const nlohmann::json& value) {
TENSORSTORE_ASSIGN_OR_RETURN(ZarrDType dtype, ParseDTypeNoDerived(value));
Result<ZarrDType> ParseDType(const nlohmann::json& value, const ParseDTypeOptions& options) {
TENSORSTORE_ASSIGN_OR_RETURN(ZarrDType dtype, ParseDTypeNoDerived(value, options));
TENSORSTORE_RETURN_IF_ERROR(ValidateDType(dtype));
return dtype;
}
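
The `ValidateDType` changes above skip the trailing empty-named field while summing `bytes_per_outer_element`, then size that synthetic field to cover the whole element and append the size to its `"|V"` encoding. A simplified standalone sketch of that logic follows. `Field` and `FinalizeFields` are hypothetical stand-ins for `ZarrDType::Field` and `ValidateDType`, and the sketch checks for the literal `"|V"` prefix where the PR checks `encoded_dtype.size() == 2`.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, reduced version of ZarrDType::Field.
struct Field {
  std::string name;           // an empty name marks the synthetic field
  std::string encoded_dtype;  // e.g. "<i2", or "|V" before the size is known
  std::size_t num_bytes = 0;
};

// Sums the sizes of the real fields. If a trailing synthetic field is present,
// it is sized to span one whole outer element. Returns bytes per element.
std::size_t FinalizeFields(std::vector<Field>& fields) {
  std::size_t num_real = fields.size();
  const bool has_synthetic = fields.size() > 1 && fields.back().name.empty();
  if (has_synthetic) --num_real;

  std::size_t bytes_per_outer_element = 0;
  for (std::size_t i = 0; i < num_real; ++i) {
    bytes_per_outer_element += fields[i].num_bytes;
  }

  if (has_synthetic) {
    Field& synthetic = fields.back();
    // Append the size only once, so repeated validation is idempotent.
    if (synthetic.encoded_dtype == "|V") {
      synthetic.encoded_dtype += std::to_string(bytes_per_outer_element);
    }
    synthetic.num_bytes = bytes_per_outer_element;
  }
  return bytes_per_outer_element;
}
```

With two real fields of 60 and 10 bytes, as in the `TwoNamedFieldsCharAndInt` test (`10 * 2 * 3 + 2 * 5`), the synthetic field in this sketch ends up with `num_bytes` of 70 and the encoding `"|V70"`. Calling the function a second time leaves the encoding unchanged.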
5 changes: 4 additions & 1 deletion tensorstore/driver/zarr/dtype.h
@@ -121,10 +121,13 @@ struct ZarrDType {
const ZarrDType& dtype);
};

struct ParseDTypeOptions {
bool treat_struct_as_byte_array = true;
};
/// Parses a zarr metadata "dtype" JSON specification.
///
/// \error `absl::StatusCode::kInvalidArgument` if `value` is not valid.
Result<ZarrDType> ParseDType(const ::nlohmann::json& value);
Result<ZarrDType> ParseDType(const ::nlohmann::json& value, const ParseDTypeOptions& options = {});

/// Validates `dtype` and computes derived values.
///
34 changes: 32 additions & 2 deletions tensorstore/driver/zarr/dtype_test.cc
@@ -170,7 +170,13 @@ void CheckDType(const ::nlohmann::json& json, const ZarrDType& expected) {
TENSORSTORE_ASSERT_OK_AND_ASSIGN(auto dtype, ParseDType(json));
EXPECT_EQ(expected, dtype);
// Check round trip.
EXPECT_EQ(json, ::nlohmann::json(dtype));
auto dtype_json = ::nlohmann::json(dtype);
if (json != dtype_json) {
ASSERT_TRUE(dtype_json.is_array() && dtype_json.back().is_array() && dtype_json.back()[0].is_string());
ASSERT_TRUE(dtype_json.back()[0].get<std::string>().empty());
dtype_json.erase(dtype_json.end() - 1); // Remove the last element
ASSERT_EQ(json, dtype_json);
}
}

TEST(ParseDType, SimpleStringBool) {
@@ -213,8 +219,20 @@ TEST(ParseDType, SingleNamedFieldChar) {
/*.num_inner_elements=*/10,
/*.byte_offset=*/0,
/*.num_bytes=*/10},
{{
/*.encoded_dtype=*/"|V10",
/*.dtype=*/dtype_v<std::byte>,
/*.endian=*/endian::native,
/*.flexible_shape=*/{10},
},
/*.outer_shape=*/{},
/*.name=*/"",
/*.field_shape=*/{10},
/*.num_inner_elements=*/0,/*Not set yet*/
/*.byte_offset=*/0,/*Not set yet*/
/*.num_bytes=*/10},
},
/*.bytes_per_outer_element=*/10,
/*.bytes_per_outer_element=*/10/*Won't change*/,
});
}

@@ -250,6 +268,18 @@ TEST(ParseDType, TwoNamedFieldsCharAndInt) {
/*.num_inner_elements=*/5,
/*.byte_offset=*/10 * 2 * 3,
/*.num_bytes=*/2 * 5},
{{
/*.encoded_dtype=*/"|V70",
/*.dtype=*/dtype_v<std::byte>,
/*.endian=*/endian::native,
/*.flexible_shape=*/{70},
},
/*.outer_shape=*/{},
/*.name=*/"",
/*.field_shape=*/{70},
/*.num_inner_elements=*/0,/*Not set yet*/
/*.byte_offset=*/0,/*Not set yet*/
/*.num_bytes=*/70},
},
/*.bytes_per_outer_element=*/10 * 2 * 3 + 2 * 5,
});