Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[*] improve log message for storage view content #1718

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/ctranslate2/storage_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,9 @@ namespace ctranslate2 {

friend std::ostream& operator<<(std::ostream& os, const StorageView& storage);

// Writes the elements reachable from `data` to `os` as nested, comma-separated
// bracketed lists, starting at dimension `dim` of `shape`.
//   os      output stream receiving the text
//   data    pointer to the element buffer, indexed from `offset`
//   shape   extent of each dimension
//   dim     index into `shape` of the dimension printed by this call
//   offset  index added to the loop counter when reading from `data`
//   indent  number of spaces written before the opening bracket
template <typename T>
void print_tensor(std::ostream& os, const T* data, const std::vector<dim_t>& shape, size_t dim, size_t offset, int indent) const;

protected:
DataType _dtype = DataType::FLOAT32;
Device _device = Device::CPU;
Expand Down
19 changes: 17 additions & 2 deletions python/tests/test_storage_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,26 @@ def test_storageview_cpu(dtype, name):
with pytest.raises(AttributeError, match="CPU"):
s.__cuda_array_interface__

assert str(s) == " 1 1 1 ... 1 1 1\n[cpu:0 %s storage viewed as 2x4]" % name
expected_output = (
"Data (2D Matrix):"
"\n[[1, 1, 1, 1], "
"\n[1, 1, 1, 1]]"
"\n[device:{}:{}, dtype:{}, storage viewed as {}x{}]"
).format(s.device, s.device_index, name, s.shape[0], s.shape[1])

assert str(s) == expected_output

x[0][2] = 3
x[1][3] = 8
assert str(s) == " 1 1 3 ... 1 1 8\n[cpu:0 %s storage viewed as 2x4]" % name

expected_output = (
"Data (2D Matrix):"
"\n[[1, 1, 3, 1], "
"\n[1, 1, 1, 8]]"
"\n[device:{}:{}, dtype:{}, storage viewed as {}x{}]"
).format(s.device, s.device_index, name, s.shape[0], s.shape[1])

assert str(s) == expected_output

y = np.array(x)
assert test_utils.array_equal(x, y)
Expand Down
70 changes: 51 additions & 19 deletions src/storage_view.cc
Original file line number Diff line number Diff line change
Expand Up @@ -441,31 +441,35 @@ namespace ctranslate2 {
}

std::ostream& operator<<(std::ostream& os, const StorageView& storage) {
// Create a printable copy of the storage
StorageView printable(storage.dtype());
printable.copy_from(storage);

// Check the data type and print accordingly
TYPE_DISPATCH(
printable.dtype(),
const auto* values = printable.data<T>();
if (printable.size() <= PRINT_MAX_VALUES) {
for (dim_t i = 0; i < printable.size(); ++i) {
os << ' ';
print_value(os, values[i]);
}
}
else {
for (dim_t i = 0; i < PRINT_MAX_VALUES / 2; ++i) {
os << ' ';
print_value(os, values[i]);
}
os << " ...";
for (dim_t i = printable.size() - (PRINT_MAX_VALUES / 2); i < printable.size(); ++i) {
os << ' ';
print_value(os, values[i]);
}
const auto& shape = printable.shape();

// Print tensor contents based on dimensionality
if (shape.empty()) { // Scalar case
os << "Data (Scalar): " << values[0] << std::endl;
} else {
os << "Data (" << shape.size() << "D ";
if (shape.size() == 1)
os << "Vector";
else if (shape.size() == 2)
os << "Matrix";
else
os << "Tensor";
os << "):" << std::endl;
printable.print_tensor(os, values, shape, 0, 0, 0);
os << std::endl;
}
os << std::endl);
os << "[" << device_to_str(storage.device(), storage.device_index())
<< " " << dtype_name(storage.dtype()) << " storage viewed as ";
);

os << "[device:" << device_to_str(storage.device(), storage.device_index())
<< ", dtype:" << dtype_name(storage.dtype()) << ", storage viewed as ";
if (storage.is_scalar())
os << "scalar";
else {
Expand All @@ -479,6 +483,34 @@ namespace ctranslate2 {
return os;
}

template <typename T>
void StorageView::print_tensor(std::ostream& os, const T* data, const std::vector<dim_t>& shape, size_t dim, size_t offset, int indent) const {
std::string indentation(indent, ' ');

os << indentation << "[";
bool is_last_dim = (dim == shape.size() - 1);
for (dim_t i = 0; i < shape[dim]; ++i) {
if (i > 0) {
os << ", ";
if (!is_last_dim) {
os << "\n" << std::string(indent, ' ');
}
}

if (i == PRINT_MAX_VALUES / 2 && shape[dim] > PRINT_MAX_VALUES) {
os << "...";
i = shape[dim] - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part
} else {
if (is_last_dim) {
os << +data[offset + i];
} else {
print_tensor(os, data, shape, dim + 1, offset + i * shape[dim + 1], indent);
}
}
}
os << "]";
}

#define DECLARE_IMPL(T) \
template \
StorageView::StorageView(Shape shape, T init, Device device); \
Expand Down
Loading