Skip to content

Commit

Permalink
License request implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
TheDaChicken committed Oct 22, 2023
1 parent 8568474 commit 0c186ad
Show file tree
Hide file tree
Showing 11 changed files with 348 additions and 357 deletions.
54 changes: 44 additions & 10 deletions lib/mfcdm/mfcdm/MediaFoundationCdm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
#include "MediaFoundationCdmFactory.h"
#include "MediaFoundationCdmModule.h"
#include "MediaFoundationCdmSession.h"
#include "utils/PMPHostWrapper.h"
#include "Log.h"

#include "utils/PMPHostWrapper.h"
#include <functional>

MediaFoundationCdm::MediaFoundationCdm() = default;
MediaFoundationCdm::~MediaFoundationCdm() = default;
Expand Down Expand Up @@ -92,34 +93,67 @@ bool MediaFoundationCdm::CreateSessionAndGenerateRequest(SessionType sessionType
const std::vector<uint8_t>& initData,
SessionClient* client)
{
auto session = std::make_unique<MediaFoundationCdmSession>(client);
auto session = std::make_shared<MediaFoundationCdmSession>(client);

if (!session->Initialize(m_module.get(), sessionType))
{
return false;
}

int session_token = next_session_token_++;
if (!session->GenerateRequest(initDataType, initData))
// when session id is identified, callback is ran.
// this meant to be able to access UpdateSession()
// inside MF callback because then session id is known.
int sessionToken = m_nextSessionToken++;
m_pendingSessions.emplace(sessionToken, session);

if (!session->GenerateRequest(initDataType, initData,
std::bind(&MediaFoundationCdm::OnNewSessionId, this, sessionToken, std::placeholders::_1)))
{
return false;
}

m_cdm_sessions.emplace(session_token, std::move(session));
return true;
}

void MediaFoundationCdm::LoadSession(SessionType session_type,
const std::string &session_id)
void MediaFoundationCdm::LoadSession(SessionType sessionType, const std::string& sessionId)
{

}

void MediaFoundationCdm::UpdateSession(const std::string &session_id)
bool MediaFoundationCdm::UpdateSession(const std::string& sessionId,
const std::vector<uint8_t>& response)
{
if (!m_module)
return false;

auto* session = GetSession(sessionId);
if (!session)
{
Log(MFCDM::MFLOG_ERROR, "Couldn't find session in created sessions.");
return false;
}

return session->Update(response);
}

void MediaFoundationCdm::OnNewSessionId(int sessionToken, std::string_view sessionId)
{
auto itr = m_pendingSessions.find(sessionToken);
assert(itr != m_pendingSessions.end());

auto session = std::move(itr->second);
assert(session);

m_pendingSessions.erase(itr);

m_sessions.emplace(sessionId, std::move(session));
}

MediaFoundationCdmSession* MediaFoundationCdm::GetSession(const std::string& sessionId) const
{
auto itr = m_sessions.find(sessionId);
if (itr == m_sessions.end())
return nullptr;

return itr->second.get();
}

9 changes: 6 additions & 3 deletions lib/mfcdm/mfcdm/MediaFoundationCdm.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,17 @@ class MediaFoundationCdm {
SessionClient* client);

void LoadSession(SessionType session_type, const std::string& session_id);
void UpdateSession(const std::string& session_id);
bool UpdateSession(const std::string& session_id, const std::vector<uint8_t>& response);

private:
void SetupPMPServer() const;
MediaFoundationCdmSession* GetSession(const std::string& sessionId) const;
void OnNewSessionId(int sessionToken, std::string_view sessionId);

MediaFoundationSession m_session;
std::unique_ptr<MediaFoundationCdmModule> m_module;

int next_session_token_{0};
std::map<int, std::unique_ptr<MediaFoundationCdmSession>> m_cdm_sessions;
int m_nextSessionToken = 0;
std::map<int, std::shared_ptr<MediaFoundationCdmSession>> m_pendingSessions;
std::map<std::string, std::shared_ptr<MediaFoundationCdmSession>> m_sessions;
};
10 changes: 5 additions & 5 deletions lib/mfcdm/mfcdm/MediaFoundationCdmFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ bool MediaFoundationCdmFactory::Initialize()
{
const winrt::com_ptr<IMFMediaEngineClassFactory4> classFactory = winrt::create_instance<IMFMediaEngineClassFactory4>(
CLSID_MFMediaEngineClassFactory, CLSCTX_INPROC_SERVER);
const std::wstring keySystemWide = ConvertUtf8ToWide(m_keySystem);
const std::wstring keySystemWide = WIDE::ConvertUtf8ToWide(m_keySystem);

return SUCCEEDED(classFactory->CreateContentDecryptionModuleFactory(
keySystemWide.c_str(), IID_PPV_ARGS(&m_cdmFactory)));
}

bool MediaFoundationCdmFactory::IsTypeSupported(std::string_view keySystem) const
{
return m_cdmFactory->IsTypeSupported(ConvertUtf8ToWide(keySystem).c_str(), nullptr);
return m_cdmFactory->IsTypeSupported(WIDE::ConvertUtf8ToWide(keySystem).c_str(), nullptr);
}

/*!
Expand Down Expand Up @@ -130,7 +130,7 @@ bool BuildCdmAccessConfigurations(const MediaFoundationCdmConfig& cdmConfig,
// Persistent state
ScopedPropVariant persisted_state;
if (FAILED(InitPropVariantFromUInt32(cdmConfig.allow_persistent_state
? MF_MEDIAKEYS_REQUIREMENT_REQUIRED
? MF_MEDIAKEYS_REQUIREMENT_OPTIONAL
: MF_MEDIAKEYS_REQUIREMENT_NOT_ALLOWED,
persisted_state.ptr())))
{
Expand All @@ -147,7 +147,7 @@ bool BuildCdmAccessConfigurations(const MediaFoundationCdmConfig& cdmConfig,
// Distinctive ID
ScopedPropVariant allow_distinctive_identifier;
if (FAILED(InitPropVariantFromUInt32(cdmConfig.allow_distinctive_identifier
? MF_MEDIAKEYS_REQUIREMENT_REQUIRED
? MF_MEDIAKEYS_REQUIREMENT_OPTIONAL
: MF_MEDIAKEYS_REQUIREMENT_NOT_ALLOWED,
allow_distinctive_identifier.ptr())))
{
Expand Down Expand Up @@ -193,7 +193,7 @@ bool MediaFoundationCdmFactory::CreateMfCdm(const MediaFoundationCdmConfig& cdmC
const std::filesystem::path& cdmPath,
std::unique_ptr<MediaFoundationCdmModule>& mfCdm) const
{
const auto key_system_str = ConvertUtf8ToWide(m_keySystem);
const auto key_system_str = WIDE::ConvertUtf8ToWide(m_keySystem);
if (!m_cdmFactory->IsTypeSupported(key_system_str.c_str(), nullptr))
{
Log(MFCDM::MFLOG_ERROR, "%s is not supported by MF CdmFactory", m_keySystem);
Expand Down
92 changes: 78 additions & 14 deletions lib/mfcdm/mfcdm/MediaFoundationCdmSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include "utils/Wide.h"
#include "Log.h"

#include <functional>
#include <iostream>
#include <ostream>

Expand Down Expand Up @@ -62,13 +61,35 @@ KeyStatus ToCdmKeyStatus(MF_MEDIAKEY_STATUS status)
return MFKeyUsable;
case MF_MEDIAKEY_STATUS_EXPIRED:
return MFKeyExpired;
// This is for legacy use and should not happen in normal cases. Map it to
// internal error in case it happens.
case MF_MEDIAKEY_STATUS_OUTPUT_NOT_ALLOWED:
return MFKeyError;
case MF_MEDIAKEY_STATUS_OUTPUT_DOWNSCALED:
return MFKeyDownScaled;
case MF_MEDIAKEY_STATUS_INTERNAL_ERROR:
// Output not allowed is legacy use? Should not happen in normal cases
case MF_MEDIAKEY_STATUS_OUTPUT_NOT_ALLOWED:
return MFKeyError;
case MF_MEDIAKEY_STATUS_STATUS_PENDING:
return MFKeyPending;
case MF_MEDIAKEY_STATUS_RELEASED:
return MFKeyReleased;
case MF_MEDIAKEY_STATUS_OUTPUT_RESTRICTED:
return MFKeyRestricted;
}
}

std::vector<std::unique_ptr<KeyInfo>> ToCdmKeysInfo(const MFMediaKeyStatus* key_statuses,
int count)
{
std::vector<std::unique_ptr<KeyInfo>> keys_info;
keys_info.reserve(count);
for (int i = 0; i < count; ++i)
{
const auto& key_status = key_statuses[i];
keys_info.push_back(std::make_unique<KeyInfo>(
std::vector(key_status.pbKeyId, key_status.pbKeyId + key_status.cbKeyId),
ToCdmKeyStatus(key_status.eMediaKeyStatus))
);
}
return keys_info;
}

class SessionCallbacks : public winrt::implements<
Expand All @@ -77,28 +98,33 @@ class SessionCallbacks : public winrt::implements<
public:
using SessionMessage =
std::function<void(const std::vector<uint8_t>& message, std::string_view destinationUrl)>;
using KeyChanged =
std::function<void()>;

SessionCallbacks(SessionMessage sessionMessage) : m_sessionMessage(std::move(sessionMessage)){};
SessionCallbacks(SessionMessage sessionMessage, KeyChanged keyChanged)
: m_sessionMessage(std::move(sessionMessage)), m_keyChanged(keyChanged){};

IFACEMETHODIMP KeyMessage(MF_MEDIAKEYSESSION_MESSAGETYPE message_type,
const BYTE* message,
DWORD message_size,
LPCWSTR destination_url) final
{
Log(MFCDM::MFLOG_DEBUG, "Message size: %d Destination Url: %S",
Log(MFCDM::MFLOG_DEBUG, "Message size: %i Destination Url: %S",
message_size, destination_url);
m_sessionMessage(std::vector(message, message + message_size),
ConvertWideToUTF8(destination_url));
WIDE::ConvertWideToUTF8(destination_url));
return S_OK;
}

IFACEMETHODIMP KeyStatusChanged() final
{
std::cout << "KeyStatusChanged" << std::endl;
Log(MFCDM::MFLOG_DEBUG, "KeyStatusChanged");
m_keyChanged();
return S_OK;
}
private:
SessionMessage m_sessionMessage;
KeyChanged m_keyChanged;
};

MediaFoundationCdmSession::MediaFoundationCdmSession(SessionClient* client)
Expand All @@ -111,7 +137,8 @@ bool MediaFoundationCdmSession::Initialize(MediaFoundationCdmModule* mfCdm,
SessionType sessionType)
{
const auto session_callbacks = winrt::make<SessionCallbacks>(
std::bind(&MediaFoundationCdmSession::OnSessionMessage, this, std::placeholders::_1, std::placeholders::_2)
std::bind(&MediaFoundationCdmSession::OnSessionMessage, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&MediaFoundationCdmSession::OnKeyChange, this)
);
// |mf_cdm_session_| holds a ref count to |session_callbacks|.
if (FAILED(mfCdm->CreateSession(ToMFSessionType(sessionType), session_callbacks.get(),
Expand All @@ -124,8 +151,11 @@ bool MediaFoundationCdmSession::Initialize(MediaFoundationCdmModule* mfCdm,
}

bool MediaFoundationCdmSession::GenerateRequest(InitDataType initDataType,
const std::vector<uint8_t>& initData)
const std::vector<uint8_t>& initData,
SessionCreatedFunc created)
{
m_sessionCreated = std::move(created);

if (FAILED(mfCdmSession->GenerateRequest(InitDataTypeToString(initDataType), initData.data(),
static_cast<DWORD>(initData.size()))))
{
Expand All @@ -146,15 +176,49 @@ bool MediaFoundationCdmSession::Update(const std::vector<uint8_t>& response)
}

void MediaFoundationCdmSession::OnSessionMessage(const std::vector<uint8_t>& message,
std::string_view destinationUrl) const
std::string_view destinationUrl)
{
if (!m_client)
return;

if (m_sessionCreated)
{
m_sessionCreated(GetSessionId());
m_sessionCreated = SessionCreatedFunc();
}

m_client->OnSessionMessage(GetSessionId(), message, destinationUrl);
}

void MediaFoundationCdmSession::OnKeyChange() const
{
if (!m_client || !mfCdmSession)
return;

ScopedCoMem<MFMediaKeyStatus> keyStatuses;

UINT count = 0;
if (FAILED(mfCdmSession->GetKeyStatuses(&keyStatuses, &count)))
{
Log(MFCDM::MFLOG_ERROR, "Failed to get key statuses.");
return;
}

m_client->OnKeyChange(GetSessionId(), ToCdmKeysInfo(keyStatuses.get(), count));

for (UINT i = 0; i < count; ++i)
{
const auto& key_status = keyStatuses.get()[i];
if (key_status.pbKeyId)
CoTaskMemFree(key_status.pbKeyId);
}
}

std::string MediaFoundationCdmSession::GetSessionId() const
{
if (!mfCdmSession)
return "";

ScopedCoMem<wchar_t> sessionId;

if (FAILED(mfCdmSession->GetSessionId(&sessionId)))
Expand All @@ -163,5 +227,5 @@ std::string MediaFoundationCdmSession::GetSessionId() const
return "";
}

return ConvertWideToUTF8(sessionId.get());
}
return WIDE::ConvertWideToUTF8(sessionId.get());
}
12 changes: 10 additions & 2 deletions lib/mfcdm/mfcdm/MediaFoundationCdmSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include "MediaFoundationCdmTypes.h"

#include <functional>

#include <unknwn.h>
#include <winrt/base.h>

Expand All @@ -20,19 +22,25 @@ class MediaFoundationCdmModule;

class MediaFoundationCdmSession {
public:
using SessionCreatedFunc = std::function<void(std::string_view sessionId)>;

MediaFoundationCdmSession(SessionClient* client);

bool Initialize(MediaFoundationCdmModule* mfCdm, SessionType sessionType);

bool GenerateRequest(InitDataType initDataType, const std::vector<uint8_t>& initData);
bool GenerateRequest(InitDataType initDataType,
const std::vector<uint8_t>& initData,
SessionCreatedFunc created);
bool Update(const std::vector<uint8_t>& response);

std::string GetSessionId() const;

private:

void OnSessionMessage(const std::vector<uint8_t>& message, std::string_view destinationUrl) const;
void OnSessionMessage(const std::vector<uint8_t>& message, std::string_view destinationUrl);
void OnKeyChange() const;

winrt::com_ptr<IMFContentDecryptionModuleSession> mfCdmSession;
SessionClient* m_client;
SessionCreatedFunc m_sessionCreated;
};
Loading

0 comments on commit 0c186ad

Please sign in to comment.