Skip to content

Commit

Permalink
Allow classic COM interfaces with get_self (#1314)
Browse files Browse the repository at this point in the history
* Allow classic COM interfaces with get_self

Fixes #1312

* Fix mingw builds

---------

Co-authored-by: Kenny Kerr <[email protected]>
  • Loading branch information
sylveon and kennykerr authored Jun 27, 2023
1 parent d3bb275 commit 297454e
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
6 changes: 6 additions & 0 deletions strings/base_implements.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,12 @@ WINRT_EXPORT namespace winrt
return &static_cast<impl::produce<D, default_interface<I>>*>(get_abi(from))->shim();
}

template <typename D, typename I>
D* get_self(com_ptr<I> const& from) noexcept
{
return static_cast<D*>(static_cast<impl::producer<D, I>*>(from.get()));
}

template <typename D, typename I>
[[deprecated]] D* from_abi(I const& from) noexcept
{
Expand Down
3 changes: 3 additions & 0 deletions strings/base_meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ WINRT_EXPORT namespace winrt
template <typename T>
struct com_ptr;

template <typename D, typename I>
D* get_self(com_ptr<I> const& from) noexcept;

namespace param
{
template <typename T>
Expand Down
49 changes: 49 additions & 0 deletions test/old_tests/UnitTests/interop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ using namespace Windows::Foundation;

namespace
{
struct IClassicComInterface : ::IUnknown {};

struct ClassicCom : implements<ClassicCom, IClassicComInterface> {};

struct Stringable : implements<Stringable, IStringable>
{
Stringable(std::wstring_view const& value = L"Stringable") : m_value(value)
Expand All @@ -30,8 +34,16 @@ namespace
object->AddRef();
return object->Release();
}

template <typename T>
uint32_t get_ref_count(com_ptr<T> const& object)
{
return get_ref_count(object.get());
}
}

template <> inline constexpr winrt::guid winrt::impl::guid_v<IClassicComInterface>{ 0xc136bb75, 0xbc03, 0x41a6, { 0xa5, 0xdc, 0x5e, 0xfa, 0x67, 0x92, 0x4e, 0xbf } };

TEST_CASE("interop")
{
uint32_t const before = get_module_lock();
Expand Down Expand Up @@ -108,6 +120,43 @@ TEST_CASE("self")
REQUIRE(get_ref_count(object) == 1);
object = nullptr;

strong = weak.get();
REQUIRE(!strong);
}

TEST_CASE("self_classic_com")
{
com_ptr<ClassicCom> strong = make_self<ClassicCom>();

REQUIRE(get_ref_count(strong.get()) == 1);

com_ptr<IClassicComInterface> object = strong.as<IClassicComInterface>();

REQUIRE(get_ref_count(strong.get()) == 2);

ClassicCom* ptr = get_self<ClassicCom>(object);
REQUIRE(ptr == strong.get());

REQUIRE(get_ref_count(strong.get()) == 2);
strong = nullptr;
REQUIRE(get_ref_count(object) == 1);

strong = get_self<ClassicCom>(object)->get_strong();
REQUIRE(get_ref_count(object) == 2);
strong = nullptr;
REQUIRE(get_ref_count(object) == 1);

weak_ref<ClassicCom> weak = get_self<ClassicCom>(object)->get_weak();
REQUIRE(get_ref_count(object) == 1); // <-- still just one!

strong = weak.get();
REQUIRE(strong);
REQUIRE(get_ref_count(object) == 2);

strong = nullptr;
REQUIRE(get_ref_count(object) == 1);
object = nullptr;

strong = weak.get();
REQUIRE(!strong);
}

0 comments on commit 297454e

Please sign in to comment.