Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds SmoothView #185

Merged
merged 4 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion include/tensorwrapper/detail_/polymorphic_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,13 @@ class PolymorphicBase {
* @throw None No throw guarantee.
*/
bool are_equal(const_base_reference rhs) const noexcept {
// Downcast *this so it can be passed to are_equal_
const_base_reference plhs = static_cast<const_base_reference>(*this);
return are_equal_(rhs) && rhs.are_equal_(plhs);

// This line is necessary if are_equal_ is overriden in BaseType
const PolymorphicBase& rhs_upcast = rhs;

return are_equal_(rhs) && rhs_upcast.are_equal_(plhs);
}

/** @brief Determines if *this and @p rhs are polymorphically different.
Expand Down
61 changes: 61 additions & 0 deletions include/tensorwrapper/detail_/view_traits.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright 2024 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <type_traits>

namespace tensorwrapper::detail_ {

/** @brief Is the cast from @p FromType to @p ToType just adding const?
*
* A common TMP pattern in implementing views is needing to convert mutable
* views to read-only views. This trait can be used to compare the template
* type parameters of two views (assuming the views are templated on what
* object they are acting like) in order to determine if they represent a
* conversion from @p FromType to @p ToType such that @p ToType is
* `const FromType`. If @p ToType is `const FromType` this template variable
* will be set to true, otherwise it will be set to false.
*
* @tparam FromType The type we are converting from.
* @tparam ToType The type we are converting to.
*/
template<typename FromType, typename ToType>
constexpr bool is_mutable_to_immutable_cast_v =
!std::is_const_v<FromType> && // FromType is NOT read-only
std::is_const_v<ToType> && // ToType is read-only
std::is_same_v<const FromType, ToType>; // They differ by const-ness

/** @brief Disables a templated function except when
* `is_mutable_to_immutable_cast_v<FromType, ToType>` evaluates to true.
*
* If `View` is a template class with template parameter type `T`, we want the
* implicit conversion from `View<T>` to `View<const T>` to exist. In practice,
* this leaves us with two options: partial specialization of `View` for
* const-qualified types or use of SFINAE to disable the conversion. We prefer
* the latter as the former requires us to duplicate the entirety of the
* class. This template type will disable the accompanying function via SFINAE
* if @p ToType is not `const FromType`.
*
* @tparam FromType The type we are converting from. Expected to be the
* template type parameter of the view we are casting from.
* @tparam ToType The type we are converting to. Expected to be the template
* type parameter of the view we are casting to.
*/
template<typename FromType, typename ToType>
using enable_if_mutable_to_immutable_cast_t =
std::enable_if_t<is_mutable_to_immutable_cast_v<FromType, ToType>>;

} // namespace tensorwrapper::detail_
57 changes: 52 additions & 5 deletions include/tensorwrapper/shape/shape_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
#include <cstddef>
#include <memory>
#include <tensorwrapper/detail_/polymorphic_base.hpp>
#include <tensorwrapper/shape/shape_traits.hpp>
#include <tensorwrapper/shape/smooth_view.hpp>

namespace tensorwrapper::shape {

/** @brief Code factorization for the various types of shapes.
Expand All @@ -34,19 +37,29 @@ namespace tensorwrapper::shape {
* - get_rank_()
* - get_size_()
*/
class ShapeBase : public detail_::PolymorphicBase<ShapeBase> {
class ShapeBase : public tensorwrapper::detail_::PolymorphicBase<ShapeBase> {
private:
/// Type implementing the traits of this
using traits_type = ShapeTraits<ShapeBase>;

public:
/// Type all shapes inherit from
using shape_base = ShapeBase;
using shape_base = typename traits_type::shape_base;

/// Type of a pointer to the base of a shape object
using base_pointer = std::unique_ptr<shape_base>;
using base_pointer = typename traits_type::base_pointer;

/// Type used to hold the rank of a tensor
using rank_type = unsigned short;
using rank_type = typename traits_type::rank_type;

/// Type used to specify the number of elements in the shape
using size_type = std::size_t;
using size_type = typename traits_type::size_type;

/// Type of an object acting like a mutable reference to a Smooth shape
using smooth_reference = SmoothView<Smooth>;

/// Type of an object acting like a read-only reference to a Smooth shape
using const_smooth_reference = SmoothView<const Smooth>;

/// No-op for ShapeBase because ShapeBase has no state
ShapeBase() noexcept = default;
Expand Down Expand Up @@ -83,6 +96,34 @@ class ShapeBase : public detail_::PolymorphicBase<ShapeBase> {
*/
size_type size() const noexcept { return get_size_(); }

/** @brief Returns a view of *this as a Smooth object.
*
* It is possible to view any shape as a smooth shape. For more exotic
* shapes this may require flattening nestings and padding dimensions.
* This method ultimately dispatches to the as_smooth_ overload of the
* derived class to control how to smooth the shape out.
*
* @return A view of *this consistent with thinking of *this as a Smooth
* object.
*
* @throw std::bad_alloc if there is a problem allocating the view. Strong
* throw guarantee.
*/
smooth_reference as_smooth() { return as_smooth_(); }

/** @brief Returns a read-only view of *this as a Smooth object.
*
* This method works the same as the non-const version except that the
* resulting view is read-only.
*
* @return A read-only view of *this consistent with thinking of *this as
* a Smooth object.
*
* @throw std::bad_alloc if there is a problem allocating the view. Strong
* throw guarantee.
*/
const_smooth_reference as_smooth() const { return as_smooth_(); }

protected:
/** @brief Used to implement rank().
*
Expand All @@ -108,6 +149,12 @@ class ShapeBase : public detail_::PolymorphicBase<ShapeBase> {
* subject to a no-throw guarantee.
*/
virtual size_type get_size_() const noexcept = 0;

/// Derived class should override to be consistent with as_smooth()
virtual smooth_reference as_smooth_() = 0;

/// Derived class should override to be consistent with as_smooth() const
virtual const_smooth_reference as_smooth_() const = 0;
};

} // namespace tensorwrapper::shape
34 changes: 34 additions & 0 deletions include/tensorwrapper/shape/shape_fwd.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2024 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

namespace tensorwrapper::shape {
namespace detail_ {

template<typename SmoothType>
class SmoothViewPIMPL;

}

class ShapeBase;

class Smooth;

template<typename SmoothType>
class SmoothView;

} // namespace tensorwrapper::shape
72 changes: 72 additions & 0 deletions include/tensorwrapper/shape/shape_traits.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright 2024 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <memory>
#include <tensorwrapper/shape/shape_fwd.hpp>

namespace tensorwrapper::shape {

template<typename ShapeType>
struct ShapeTraits;

template<>
struct ShapeTraits<ShapeBase> {
using shape_base = ShapeBase;
using base_pointer = std::unique_ptr<shape_base>;
using rank_type = unsigned short;
using size_type = std::size_t;
};

template<>
struct ShapeTraits<const ShapeBase> {
using shape_base = ShapeBase;
using base_pointer = std::unique_ptr<shape_base>;
using rank_type = unsigned short;
using size_type = std::size_t;
};

template<>
struct ShapeTraits<Smooth> : public ShapeTraits<ShapeBase> {
using value_type = Smooth;
using const_value_type = const value_type;
using reference = value_type&;
using const_reference = const value_type&;
using pointer = value_type*;
using const_pointer = const value_type*;
};

template<>
struct ShapeTraits<const Smooth> : public ShapeTraits<const ShapeBase> {
using value_type = Smooth;
using const_value_type = const value_type;
using reference = const value_type&;
using const_reference = const value_type&;
using pointer = const value_type*;
using const_pointer = const value_type*;
};

template<typename T>
struct ShapeTraits<SmoothView<T>> {
using smooth_traits = ShapeTraits<T>;
using pimpl_type = detail_::SmoothViewPIMPL<T>;
using const_pimpl_type =
detail_::SmoothViewPIMPL<typename smooth_traits::const_value_type>;
using pimpl_pointer = std::unique_ptr<pimpl_type>;
using const_pimpl_pointer = std::unique_ptr<const_pimpl_type>;
};

} // namespace tensorwrapper::shape
6 changes: 6 additions & 0 deletions include/tensorwrapper/shape/smooth.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,12 @@ class Smooth : public ShapeBase {
size_type(1), std::multiplies<size_type>());
}

smooth_reference as_smooth_() override { return smooth_reference(*this); }

virtual const_smooth_reference as_smooth_() const override {
return const_smooth_reference(*this);
}

/// Implements are_equal by calling ShapeBase::are_equal_impl_
bool are_equal_(const ShapeBase& rhs) const noexcept override {
return are_equal_impl_<Smooth>(rhs);
Expand Down
Loading
Loading