diff --git a/lib/mfcdm/mfcdm/MediaFoundationCdm.cpp b/lib/mfcdm/mfcdm/MediaFoundationCdm.cpp index eda007493..aac1ff0f3 100644 --- a/lib/mfcdm/mfcdm/MediaFoundationCdm.cpp +++ b/lib/mfcdm/mfcdm/MediaFoundationCdm.cpp @@ -11,9 +11,10 @@ #include "MediaFoundationCdmFactory.h" #include "MediaFoundationCdmModule.h" #include "MediaFoundationCdmSession.h" +#include "utils/PMPHostWrapper.h" #include "Log.h" -#include "utils/PMPHostWrapper.h" +#include MediaFoundationCdm::MediaFoundationCdm() = default; MediaFoundationCdm::~MediaFoundationCdm() = default; @@ -92,34 +93,67 @@ bool MediaFoundationCdm::CreateSessionAndGenerateRequest(SessionType sessionType const std::vector& initData, SessionClient* client) { - auto session = std::make_unique(client); + auto session = std::make_shared(client); if (!session->Initialize(m_module.get(), sessionType)) { return false; } - int session_token = next_session_token_++; - if (!session->GenerateRequest(initDataType, initData)) + // when session id is identified, callback is ran. + // this meant to be able to access UpdateSession() + // inside MF callback because then session id is known. + int sessionToken = m_nextSessionToken++; + m_pendingSessions.emplace(sessionToken, session); + + if (!session->GenerateRequest(initDataType, initData, + std::bind(&MediaFoundationCdm::OnNewSessionId, this, sessionToken, std::placeholders::_1))) { return false; } - - m_cdm_sessions.emplace(session_token, std::move(session)); return true; } -void MediaFoundationCdm::LoadSession(SessionType session_type, - const std::string &session_id) +void MediaFoundationCdm::LoadSession(SessionType sessionType, const std::string& sessionId) { - + } -void MediaFoundationCdm::UpdateSession(const std::string &session_id) +bool MediaFoundationCdm::UpdateSession(const std::string& sessionId, + const std::vector& response) { + if (!m_module) + return false; + + auto* session = GetSession(sessionId); + if (!session) + { + Log(MFCDM::MFLOG_ERROR, "Couldn't find session in created sessions."); + return false; + } + return session->Update(response); } +void MediaFoundationCdm::OnNewSessionId(int sessionToken, std::string_view sessionId) +{ + auto itr = m_pendingSessions.find(sessionToken); + assert(itr != m_pendingSessions.end()); + + auto session = std::move(itr->second); + assert(session); + + m_pendingSessions.erase(itr); + m_sessions.emplace(sessionId, std::move(session)); +} + +MediaFoundationCdmSession* MediaFoundationCdm::GetSession(const std::string& sessionId) const +{ + auto itr = m_sessions.find(sessionId); + if (itr == m_sessions.end()) + return nullptr; + return itr->second.get(); +} diff --git a/lib/mfcdm/mfcdm/MediaFoundationCdm.h b/lib/mfcdm/mfcdm/MediaFoundationCdm.h index 8798f9df5..96b560a69 100644 --- a/lib/mfcdm/mfcdm/MediaFoundationCdm.h +++ b/lib/mfcdm/mfcdm/MediaFoundationCdm.h @@ -40,14 +40,17 @@ class MediaFoundationCdm { SessionClient* client); void LoadSession(SessionType session_type, const std::string& session_id); - void UpdateSession(const std::string& session_id); + bool UpdateSession(const std::string& session_id, const std::vector& response); private: void SetupPMPServer() const; + MediaFoundationCdmSession* GetSession(const std::string& sessionId) const; + void OnNewSessionId(int sessionToken, std::string_view sessionId); MediaFoundationSession m_session; std::unique_ptr m_module; - int next_session_token_{0}; - std::map> m_cdm_sessions; + int m_nextSessionToken = 0; + std::map> m_pendingSessions; + std::map> m_sessions; }; diff --git a/lib/mfcdm/mfcdm/MediaFoundationCdmFactory.cpp b/lib/mfcdm/mfcdm/MediaFoundationCdmFactory.cpp index 636ebf055..1901fe2da 100644 --- a/lib/mfcdm/mfcdm/MediaFoundationCdmFactory.cpp +++ b/lib/mfcdm/mfcdm/MediaFoundationCdmFactory.cpp @@ -34,7 +34,7 @@ bool MediaFoundationCdmFactory::Initialize() { const winrt::com_ptr classFactory = winrt::create_instance( CLSID_MFMediaEngineClassFactory, CLSCTX_INPROC_SERVER); - const std::wstring keySystemWide = ConvertUtf8ToWide(m_keySystem); + const std::wstring keySystemWide = WIDE::ConvertUtf8ToWide(m_keySystem); return SUCCEEDED(classFactory->CreateContentDecryptionModuleFactory( keySystemWide.c_str(), IID_PPV_ARGS(&m_cdmFactory))); @@ -42,7 +42,7 @@ bool MediaFoundationCdmFactory::Initialize() bool MediaFoundationCdmFactory::IsTypeSupported(std::string_view keySystem) const { - return m_cdmFactory->IsTypeSupported(ConvertUtf8ToWide(keySystem).c_str(), nullptr); + return m_cdmFactory->IsTypeSupported(WIDE::ConvertUtf8ToWide(keySystem).c_str(), nullptr); } /*! @@ -130,7 +130,7 @@ bool BuildCdmAccessConfigurations(const MediaFoundationCdmConfig& cdmConfig, // Persistent state ScopedPropVariant persisted_state; if (FAILED(InitPropVariantFromUInt32(cdmConfig.allow_persistent_state - ? MF_MEDIAKEYS_REQUIREMENT_REQUIRED + ? MF_MEDIAKEYS_REQUIREMENT_OPTIONAL : MF_MEDIAKEYS_REQUIREMENT_NOT_ALLOWED, persisted_state.ptr()))) { @@ -147,7 +147,7 @@ bool BuildCdmAccessConfigurations(const MediaFoundationCdmConfig& cdmConfig, // Distinctive ID ScopedPropVariant allow_distinctive_identifier; if (FAILED(InitPropVariantFromUInt32(cdmConfig.allow_distinctive_identifier - ? MF_MEDIAKEYS_REQUIREMENT_REQUIRED + ? MF_MEDIAKEYS_REQUIREMENT_OPTIONAL : MF_MEDIAKEYS_REQUIREMENT_NOT_ALLOWED, allow_distinctive_identifier.ptr()))) { @@ -193,7 +193,7 @@ bool MediaFoundationCdmFactory::CreateMfCdm(const MediaFoundationCdmConfig& cdmC const std::filesystem::path& cdmPath, std::unique_ptr& mfCdm) const { - const auto key_system_str = ConvertUtf8ToWide(m_keySystem); + const auto key_system_str = WIDE::ConvertUtf8ToWide(m_keySystem); if (!m_cdmFactory->IsTypeSupported(key_system_str.c_str(), nullptr)) { Log(MFCDM::MFLOG_ERROR, "%s is not supported by MF CdmFactory", m_keySystem); diff --git a/lib/mfcdm/mfcdm/MediaFoundationCdmSession.cpp b/lib/mfcdm/mfcdm/MediaFoundationCdmSession.cpp index dfc351a26..343ea2240 100644 --- a/lib/mfcdm/mfcdm/MediaFoundationCdmSession.cpp +++ b/lib/mfcdm/mfcdm/MediaFoundationCdmSession.cpp @@ -13,7 +13,6 @@ #include "utils/Wide.h" #include "Log.h" -#include #include #include @@ -62,13 +61,35 @@ KeyStatus ToCdmKeyStatus(MF_MEDIAKEY_STATUS status) return MFKeyUsable; case MF_MEDIAKEY_STATUS_EXPIRED: return MFKeyExpired; - // This is for legacy use and should not happen in normal cases. Map it to - // internal error in case it happens. - case MF_MEDIAKEY_STATUS_OUTPUT_NOT_ALLOWED: - return MFKeyError; + case MF_MEDIAKEY_STATUS_OUTPUT_DOWNSCALED: + return MFKeyDownScaled; case MF_MEDIAKEY_STATUS_INTERNAL_ERROR: + // Output not allowed is legacy use? Should not happen in normal cases + case MF_MEDIAKEY_STATUS_OUTPUT_NOT_ALLOWED: return MFKeyError; + case MF_MEDIAKEY_STATUS_STATUS_PENDING: + return MFKeyPending; + case MF_MEDIAKEY_STATUS_RELEASED: + return MFKeyReleased; + case MF_MEDIAKEY_STATUS_OUTPUT_RESTRICTED: + return MFKeyRestricted; + } +} + +std::vector> ToCdmKeysInfo(const MFMediaKeyStatus* key_statuses, + int count) +{ + std::vector> keys_info; + keys_info.reserve(count); + for (int i = 0; i < count; ++i) + { + const auto& key_status = key_statuses[i]; + keys_info.push_back(std::make_unique( + std::vector(key_status.pbKeyId, key_status.pbKeyId + key_status.cbKeyId), + ToCdmKeyStatus(key_status.eMediaKeyStatus)) + ); } + return keys_info; } class SessionCallbacks : public winrt::implements< @@ -77,28 +98,33 @@ class SessionCallbacks : public winrt::implements< public: using SessionMessage = std::function& message, std::string_view destinationUrl)>; + using KeyChanged = + std::function; - SessionCallbacks(SessionMessage sessionMessage) : m_sessionMessage(std::move(sessionMessage)){}; + SessionCallbacks(SessionMessage sessionMessage, KeyChanged keyChanged) + : m_sessionMessage(std::move(sessionMessage)), m_keyChanged(keyChanged){}; IFACEMETHODIMP KeyMessage(MF_MEDIAKEYSESSION_MESSAGETYPE message_type, const BYTE* message, DWORD message_size, LPCWSTR destination_url) final { - Log(MFCDM::MFLOG_DEBUG, "Message size: %d Destination Url: %S", + Log(MFCDM::MFLOG_DEBUG, "Message size: %i Destination Url: %S", message_size, destination_url); m_sessionMessage(std::vector(message, message + message_size), - ConvertWideToUTF8(destination_url)); + WIDE::ConvertWideToUTF8(destination_url)); return S_OK; } IFACEMETHODIMP KeyStatusChanged() final { - std::cout << "KeyStatusChanged" << std::endl; + Log(MFCDM::MFLOG_DEBUG, "KeyStatusChanged"); + m_keyChanged(); return S_OK; } private: SessionMessage m_sessionMessage; + KeyChanged m_keyChanged; }; MediaFoundationCdmSession::MediaFoundationCdmSession(SessionClient* client) @@ -111,7 +137,8 @@ bool MediaFoundationCdmSession::Initialize(MediaFoundationCdmModule* mfCdm, SessionType sessionType) { const auto session_callbacks = winrt::make( - std::bind(&MediaFoundationCdmSession::OnSessionMessage, this, std::placeholders::_1, std::placeholders::_2) + std::bind(&MediaFoundationCdmSession::OnSessionMessage, this, std::placeholders::_1, std::placeholders::_2), + std::bind(&MediaFoundationCdmSession::OnKeyChange, this) ); // |mf_cdm_session_| holds a ref count to |session_callbacks|. if (FAILED(mfCdm->CreateSession(ToMFSessionType(sessionType), session_callbacks.get(), @@ -124,8 +151,11 @@ bool MediaFoundationCdmSession::Initialize(MediaFoundationCdmModule* mfCdm, } bool MediaFoundationCdmSession::GenerateRequest(InitDataType initDataType, - const std::vector& initData) + const std::vector& initData, + SessionCreatedFunc created) { + m_sessionCreated = std::move(created); + if (FAILED(mfCdmSession->GenerateRequest(InitDataTypeToString(initDataType), initData.data(), static_cast(initData.size())))) { @@ -146,15 +176,49 @@ bool MediaFoundationCdmSession::Update(const std::vector& response) } void MediaFoundationCdmSession::OnSessionMessage(const std::vector& message, - std::string_view destinationUrl) const + std::string_view destinationUrl) { if (!m_client) return; + + if (m_sessionCreated) + { + m_sessionCreated(GetSessionId()); + m_sessionCreated = SessionCreatedFunc(); + } + m_client->OnSessionMessage(GetSessionId(), message, destinationUrl); } +void MediaFoundationCdmSession::OnKeyChange() const +{ + if (!m_client || !mfCdmSession) + return; + + ScopedCoMem keyStatuses; + + UINT count = 0; + if (FAILED(mfCdmSession->GetKeyStatuses(&keyStatuses, &count))) + { + Log(MFCDM::MFLOG_ERROR, "Failed to get key statuses."); + return; + } + + m_client->OnKeyChange(GetSessionId(), ToCdmKeysInfo(keyStatuses.get(), count)); + + for (UINT i = 0; i < count; ++i) + { + const auto& key_status = keyStatuses.get()[i]; + if (key_status.pbKeyId) + CoTaskMemFree(key_status.pbKeyId); + } +} + std::string MediaFoundationCdmSession::GetSessionId() const { + if (!mfCdmSession) + return ""; + ScopedCoMem sessionId; if (FAILED(mfCdmSession->GetSessionId(&sessionId))) @@ -163,5 +227,5 @@ std::string MediaFoundationCdmSession::GetSessionId() const return ""; } - return ConvertWideToUTF8(sessionId.get()); -} \ No newline at end of file + return WIDE::ConvertWideToUTF8(sessionId.get()); +} diff --git a/lib/mfcdm/mfcdm/MediaFoundationCdmSession.h b/lib/mfcdm/mfcdm/MediaFoundationCdmSession.h index 8cfd04c4e..c9628de8a 100644 --- a/lib/mfcdm/mfcdm/MediaFoundationCdmSession.h +++ b/lib/mfcdm/mfcdm/MediaFoundationCdmSession.h @@ -10,6 +10,8 @@ #include "MediaFoundationCdmTypes.h" +#include + #include #include @@ -20,19 +22,25 @@ class MediaFoundationCdmModule; class MediaFoundationCdmSession { public: + using SessionCreatedFunc = std::function; + MediaFoundationCdmSession(SessionClient* client); bool Initialize(MediaFoundationCdmModule* mfCdm, SessionType sessionType); - bool GenerateRequest(InitDataType initDataType, const std::vector& initData); + bool GenerateRequest(InitDataType initDataType, + const std::vector& initData, + SessionCreatedFunc created); bool Update(const std::vector& response); std::string GetSessionId() const; private: - void OnSessionMessage(const std::vector& message, std::string_view destinationUrl) const; + void OnSessionMessage(const std::vector& message, std::string_view destinationUrl); + void OnKeyChange() const; winrt::com_ptr mfCdmSession; SessionClient* m_client; + SessionCreatedFunc m_sessionCreated; }; diff --git a/lib/mfcdm/mfcdm/MediaFoundationCdmTypes.h b/lib/mfcdm/mfcdm/MediaFoundationCdmTypes.h index f978f8341..38490e638 100644 --- a/lib/mfcdm/mfcdm/MediaFoundationCdmTypes.h +++ b/lib/mfcdm/mfcdm/MediaFoundationCdmTypes.h @@ -6,12 +6,13 @@ * See LICENSES/README.md for more information. */ +#pragma once + #include +#include #include #include -#pragma once - enum SessionType : uint32_t { MFTemporary = 0, @@ -28,8 +29,26 @@ enum InitDataType : uint32_t enum KeyStatus : uint32_t { MFKeyUsable = 0, - MFKeyExpired = 1, - MFKeyError = 2 + MFKeyDownScaled = 1, + MFKeyPending = 2, + MFKeyExpired = 3, + MFKeyReleased = 4, + MFKeyRestricted = 5, + MFKeyError = 6 +}; + +struct KeyInfo +{ + KeyInfo(std::vector keyId, KeyStatus status) + : keyId(std::move(keyId)), + status(status) + { + + } + std::vector keyId; + KeyStatus status; + + bool operator==(KeyInfo const& other) const { return keyId == other.keyId; } }; class SessionClient @@ -37,7 +56,10 @@ class SessionClient public: virtual ~SessionClient() = default; - virtual void OnSessionMessage(std::string_view session, + virtual void OnSessionMessage(std::string_view sessionId, const std::vector& message, std::string_view destinationUrl) = 0; + + virtual void OnKeyChange(std::string_view sessionId, + std::vector> keys) = 0; }; diff --git a/lib/mfcdm/mfcdm/utils/Wide.h b/lib/mfcdm/mfcdm/utils/Wide.h index 20c016ef9..ad8866a13 100644 --- a/lib/mfcdm/mfcdm/utils/Wide.h +++ b/lib/mfcdm/mfcdm/utils/Wide.h @@ -8,11 +8,15 @@ #include -#include +#define NOGDI // Ignore useless header that creates useless macros +#include namespace UTILS { + namespace WIDE + { + static std::wstring ConvertUtf8ToWide(std::string_view str) { const int charCount = @@ -28,8 +32,8 @@ namespace UTILS static std::string ConvertWideToUTF8(std::wstring_view wstr) { - const int charCount = - WideCharToMultiByte(CP_UTF8, 0, wstr.data(), static_cast(wstr.length()), nullptr, 0, nullptr, nullptr); + const int charCount = WideCharToMultiByte( + CP_UTF8, 0, wstr.data(), static_cast(wstr.length()), nullptr, 0, nullptr, nullptr); if (charCount <= 0) return {}; @@ -39,4 +43,6 @@ namespace UTILS return str; } + } //namespace WIDE + } //namespace UTILS diff --git a/src/decrypters/mediafoundation/MFCencSingleSampleDecrypter.cpp b/src/decrypters/mediafoundation/MFCencSingleSampleDecrypter.cpp index f8cd27cca..c563557f1 100644 --- a/src/decrypters/mediafoundation/MFCencSingleSampleDecrypter.cpp +++ b/src/decrypters/mediafoundation/MFCencSingleSampleDecrypter.cpp @@ -8,17 +8,18 @@ #include "MFCencSingleSampleDecrypter.h" -#include "../../utils/Base64Utils.h" -#include "../../utils/CurlUtils.h" -#include "../../utils/DigestMD5Utils.h" -#include "../../utils/FileUtils.h" -#include "../../utils/StringUtils.h" -#include "../../utils/Utils.h" -#include "../../utils/log.h" +#include "MFDecrypter.h" +#include "utils/Base64Utils.h" +#include "utils/CurlUtils.h" +#include "utils/DigestMD5Utils.h" +#include "utils/FileUtils.h" +#include "utils/StringUtils.h" +#include "utils/Utils.h" +#include "utils/log.h" +#include "utils/XMLUtils.h" #include "pugixml.hpp" -#include "MFDecrypter.h" -#include "mfcdm/MediaFoundationCdm.h" +#include #include #include @@ -51,12 +52,10 @@ CMFCencSingleSampleDecrypter::CMFCencSingleSampleDecrypter(CMFDecrypter& host, return; } - //m_wvCdmAdapter.insertssd(this); - if (m_host.IsDebugSaveLicense()) { - std::string debugFilePath = - FILESYS::PathCombine(m_host.GetProfilePath(), "9A04F079-9840-4286-AB92-E65BE0885F95.init"); + const std::string debugFilePath = + FILESYS::PathCombine(m_host.GetProfilePath(), "9A04F079-9840-4286-AB92-E65BE0885F95.init"); std::string data{reinterpret_cast(pssh.data()), pssh.size()}; FILESYS::SaveFile(debugFilePath, data, true); @@ -89,26 +88,21 @@ CMFCencSingleSampleDecrypter::CMFCencSingleSampleDecrypter(CMFDecrypter& host, m_host.GetCdm()->CreateSessionAndGenerateRequest(MFTemporary, MFCenc, m_pssh, this); - int retryCount = 0; - while (m_strSession.empty() && ++retryCount < 100) - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - - if (m_strSession.empty()) + if (sessionId.empty()) { LOG::LogF(LOGERROR, "Cannot perform License update, no session available"); return; } - if (skipSessionMessage) - return; +} - while (m_challenge.GetDataSize() > 0 && SendSessionMessage()) - ; +CMFCencSingleSampleDecrypter::~CMFCencSingleSampleDecrypter() +{ } -void CMFCencSingleSampleDecrypter::OnSessionMessage(std::string_view session, - const std::vector& message, - std::string_view destinationUrl) +void CMFCencSingleSampleDecrypter::ParsePlayReadyMessage(const std::vector& message, + std::string& challenge, + std::map& headers) { xml_document doc; @@ -116,326 +110,154 @@ void CMFCencSingleSampleDecrypter::OnSessionMessage(std::string_view session, xml_parse_result parseRes = doc.load_buffer(message.data(), message.size()); if (parseRes.status != status_ok) { - LOG::LogF(LOGERROR, "Failed to parse PlayReady session message", parseRes.status); + LOG::LogF(LOGERROR, "Failed to parse PlayReady session message %i", parseRes.status); return; } if (m_host.IsDebugSaveLicense()) { - std::string debugFilePath = - FILESYS::PathCombine(m_host.GetProfilePath(), "9A04F079-9840-4286-AB92-E65BE0885F95.message"); - + const std::string debugFilePath = FILESYS::PathCombine( + m_host.GetProfilePath(), "9A04F079-9840-4286-AB92-E65BE0885F95.message"); + doc.save_file(debugFilePath.c_str()); } - xml_node nodeKeyMessage = doc.child("PlayReadyKeyMessage"); - if (!nodeKeyMessage) + xml_node nodeAcquisition = doc.first_element_by_path("PlayReadyKeyMessage/LicenseAcquisition"); + if (!nodeAcquisition) { - LOG::LogF(LOGERROR, "Failed to get Playready's tag element."); + LOG::LogF(LOGERROR, "Failed to get Playready's tag element."); return; } - std::lock_guard lock(m_renewalLock); - - m_strSession = session; - //m_challenge.SetData(message.data(), message.size()); - LOG::LogF(LOGDEBUG, "Opened playready session ID: %s", m_strSession.c_str()); -} - -CMFCencSingleSampleDecrypter::~CMFCencSingleSampleDecrypter() -{ - //m_wvCdmAdapter.removessd(this); -} - -void CMFCencSingleSampleDecrypter::GetCapabilities(std::string_view key, - uint32_t media, - IDecrypter::DecrypterCapabilites& caps) -{ - caps = {0, m_hdcpVersion, m_hdcpLimit}; - - if (m_strSession.empty()) + xml_node nodeChallenge = nodeAcquisition.child("Challenge"); + if (!nodeChallenge) { - LOG::LogF(LOGDEBUG, "Session empty"); + LOG::LogF(LOGERROR, "Failed to get Playready's tag element."); return; } - caps.flags = IDecrypter::DecrypterCapabilites::SSD_SUPPORTS_DECODING; - - if (m_keys.empty()) + std::string encodingType; + encodingType = XML::GetAttrib(nodeChallenge, "encoding"); + if (encodingType != "base64encoded") { - LOG::LogF(LOGDEBUG, "Keys empty"); + LOG::LogF(LOGERROR, "Unknown challenge encoding %s", encodingType); return; } - if (!caps.hdcpLimit) - caps.hdcpLimit = m_resolutionLimit; - -} + challenge = BASE64::DecodeToStr(nodeChallenge.child_value()); -const char* CMFCencSingleSampleDecrypter::GetSessionId() -{ - return m_strSession.empty() ? nullptr : m_strSession.c_str(); -} + LOG::LogF(LOGDEBUG, "Challenge: encoding %s size %i", encodingType, challenge.size()); -void CMFCencSingleSampleDecrypter::CloseSessionId() -{ - if (!m_strSession.empty()) + if (xml_node nodeHeaders = nodeAcquisition.child("HttpHeaders")) { - LOG::LogF(LOGDEBUG, "Closing MF session ID: %s", m_strSession.c_str()); - //m_wvCdmAdapter.GetCdmAdapter()->CloseSession(++m_promiseId, m_strSession.data(), - // m_strSession.size()); - - LOG::LogF(LOGDEBUG, "MF session ID %s closed", m_strSession.c_str()); - m_strSession.clear(); + for (xml_node nodeHeader : nodeHeaders.children("HttpHeader")) + { + std::string name = nodeHeader.child_value("name"); + std::string value = nodeHeader.child_value("value"); + headers.insert({name, value}); + } } -} -AP4_DataBuffer CMFCencSingleSampleDecrypter::GetChallengeData() -{ - return m_challenge; + LOG::LogF(LOGDEBUG, "HttpHeaders: size %i", headers.size()); } -void CMFCencSingleSampleDecrypter::CheckLicenseRenewal() +void CMFCencSingleSampleDecrypter::OnSessionMessage(std::string_view session, + const std::vector& message, + std::string_view messageDestinationUrl) { - { - std::lock_guard lock(m_renewalLock); - if (!m_challenge.GetDataSize()) - return; - } - SendSessionMessage(); -} + std::string challenge; + std::map playReadyHeaders; -bool CMFCencSingleSampleDecrypter::SendSessionMessage() -{ - // StringUtils::Split(m_wvCdmAdapter.GetLicenseURL(), '|') - std::vector blocks{}; + ParsePlayReadyMessage(message, challenge, + playReadyHeaders); - if (blocks.size() != 4) - { - LOG::LogF(LOGERROR, "Wrong \"|\" blocks in license URL. Four blocks (req | header | body | " - "response) are expected in license URL"); - return false; - } + sessionId = session; + m_challenge.SetData(reinterpret_cast(challenge.data()), + static_cast(challenge.size())); + + LOG::LogF(LOGDEBUG, "Playready message session ID: %s", sessionId.c_str()); if (m_host.IsDebugSaveLicense()) { std::string debugFilePath = FILESYS::PathCombine( m_host.GetProfilePath(), "9A04F079-9840-4286-AB92-E65BE0885F95.challenge"); - std::string data{reinterpret_cast(m_challenge.GetData()), - m_challenge.GetDataSize()}; - FILESYS::SaveFile(debugFilePath, data, true); + + FILESYS::SaveFile(debugFilePath, challenge, true); } - //Process placeholder in GET String - std::string::size_type insPos(blocks[0].find("{SSM}")); - if (insPos != std::string::npos) + std::vector blocks; + if (!m_host.GetLicenseKey().empty()) { - if (insPos > 0 && blocks[0][insPos - 1] == 'B') + blocks = StringUtils::Split(m_host.GetLicenseKey(), '|'); + if (blocks.size() != 4) { - std::string msgEncoded{BASE64::Encode(m_challenge.GetData(), m_challenge.GetDataSize())}; - msgEncoded = STRING::URLEncode(msgEncoded); - blocks[0].replace(insPos - 1, 6, msgEncoded); - } - else - { - LOG::Log(LOGERROR, "Unsupported License request template (command)"); - return false; + LOG::LogF(LOGERROR, "Wrong \"|\" blocks in license URL. Four blocks (req | header | body | " + "response) are expected in license URL"); + return; } } - insPos = blocks[0].find("{HASH}"); - if (insPos != std::string::npos) + std::string destinationUrl; + if (!blocks.empty()) + { + destinationUrl = blocks[0]; + } + else { - DIGEST::MD5 md5; - md5.Update(m_challenge.GetData(), m_challenge.GetDataSize()); - md5.Finalize(); - blocks[0].replace(insPos, 6, md5.HexDigest()); + destinationUrl = messageDestinationUrl; } - CURL::CUrl file{blocks[0].c_str()}; + CURL::CUrl file(destinationUrl); file.AddHeader("Expect", ""); - std::string response; - std::string resLimit; - std::string contentType; - char buf[2048]; - bool serverCertRequest; - - //Process headers - std::vector headers{StringUtils::Split(blocks[1], '&')}; - for (std::string& headerStr : headers) + for (const auto& header: playReadyHeaders) { - std::vector header{StringUtils::Split(headerStr, '=')}; - if (!header.empty()) - { - StringUtils::Trim(header[0]); - std::string value; - if (header.size() > 1) - { - StringUtils::Trim(header[1]); - value = STRING::URLDecode(header[1]); - } - file.AddHeader(header[0].c_str(), value.c_str()); - } + file.AddHeader(header.first, header.second); } - //Process body - if (!blocks[2].empty()) + //Process headers + if(!blocks.empty()) { - if (blocks[2][0] == '%') - blocks[2] = STRING::URLDecode(blocks[2]); - - insPos = blocks[2].find("{SSM}"); - if (insPos != std::string::npos) + std::vector headers{StringUtils::Split(blocks[1], '&')}; + for (std::string& headerStr : headers) { - std::string::size_type sidPos(blocks[2].find("{SID}")); - std::string::size_type kidPos(blocks[2].find("{KID}")); - - char fullDecode = 0; - if (insPos > 1 && sidPos > 1 && kidPos > 1 && (blocks[2][0] == 'b' || blocks[2][0] == 'B') && - blocks[2][1] == '{') + std::vector header{StringUtils::Split(headerStr, '=')}; + if (!header.empty()) { - fullDecode = blocks[2][0]; - blocks[2] = blocks[2].substr(2, blocks[2].size() - 3); - insPos -= 2; - if (kidPos != std::string::npos) - kidPos -= 2; - if (sidPos != std::string::npos) - sidPos -= 2; - } - - size_t size_written(0); - - if (insPos > 0) - { - if (blocks[2][insPos - 1] == 'B' || blocks[2][insPos - 1] == 'b') - { - std::string msgEncoded{BASE64::Encode(m_challenge.GetData(), m_challenge.GetDataSize())}; - if (blocks[2][insPos - 1] == 'B') - { - msgEncoded = STRING::URLEncode(msgEncoded); - } - blocks[2].replace(insPos - 1, 6, msgEncoded); - size_written = msgEncoded.size(); - } - else if (blocks[2][insPos - 1] == 'D') - { - std::string msgEncoded{ - STRING::ToDecimal(m_challenge.GetData(), m_challenge.GetDataSize())}; - blocks[2].replace(insPos - 1, 6, msgEncoded); - size_written = msgEncoded.size(); - } - else - { - blocks[2].replace(insPos - 1, 6, reinterpret_cast(m_challenge.GetData()), - m_challenge.GetDataSize()); - size_written = m_challenge.GetDataSize(); - } - } - else - { - LOG::Log(LOGERROR, "Unsupported License request template (body / ?{SSM})"); - return false; - } - - if (sidPos != std::string::npos && insPos < sidPos) - sidPos += size_written, sidPos -= 6; - - if (kidPos != std::string::npos && insPos < kidPos) - kidPos += size_written, kidPos -= 6; - - size_written = 0; - - if (sidPos != std::string::npos) - { - if (sidPos > 0) - { - if (blocks[2][sidPos - 1] == 'B' || blocks[2][sidPos - 1] == 'b') - { - std::string msgEncoded{BASE64::Encode(m_strSession)}; - - if (blocks[2][sidPos - 1] == 'B') - { - msgEncoded = STRING::URLEncode(msgEncoded); - } - - blocks[2].replace(sidPos - 1, 6, msgEncoded); - size_written = msgEncoded.size(); - } - else - { - blocks[2].replace(sidPos - 1, 6, m_strSession.data(), m_strSession.size()); - size_written = m_strSession.size(); - } - } - else - { - LOG::LogF(LOGERROR, "Unsupported License request template (body / ?{SID})"); - return false; - } - } - - if (kidPos != std::string::npos) - { - if (sidPos < kidPos) - kidPos += size_written, kidPos -= 6; - - if (blocks[2][kidPos - 1] == 'H') - { - std::string keyIdUUID{StringUtils::ToHexadecimal(m_defaultKeyId)}; - blocks[2].replace(kidPos - 1, 6, keyIdUUID.c_str(), 32); - } - else - { - std::string kidUUID{ConvertKIDtoUUID(m_defaultKeyId)}; - blocks[2].replace(kidPos, 5, kidUUID.c_str(), 36); - } - } - - if (fullDecode) - { - std::string msgEncoded{BASE64::Encode(blocks[2])}; - if (fullDecode == 'B') + StringUtils::Trim(header[0]); + std::string value; + if (header.size() > 1) { - msgEncoded = STRING::URLEncode(msgEncoded); + StringUtils::Trim(header[1]); + value = STRING::URLDecode(header[1]); } - blocks[2] = msgEncoded; + file.AddHeader(header[0].c_str(), value.c_str()); } - } - - std::string encData{BASE64::Encode(blocks[2])}; - file.AddHeader("postdata", encData.c_str()); + } } - serverCertRequest = m_challenge.GetDataSize() == 2; - m_challenge.SetDataSize(0); + std::string encData{BASE64::Encode(challenge)}; + file.AddHeader("postdata", encData); - if (!file.Open()) + int statusCode = file.Open(); + if (statusCode == -1 || statusCode >= 400) { LOG::Log(LOGERROR, "License server returned failure"); - return false; + return; } + std::string response; + CURL::ReadStatus downloadStatus = CURL::ReadStatus::CHUNK_READ; while (downloadStatus == CURL::ReadStatus::CHUNK_READ) { downloadStatus = file.Read(response); } - resLimit = file.GetResponseHeader("X-Limit-Video"); - contentType = file.GetResponseHeader("Content-Type"); - - if (!resLimit.empty()) - { - std::string::size_type posMax = resLimit.find("max="); // log/check this - if (posMax != std::string::npos) - m_resolutionLimit = std::atoi(resLimit.data() + (posMax + 4)); - } - if (downloadStatus == CURL::ReadStatus::ERROR) { LOG::LogF(LOGERROR, "Could not read full SessionMessage response"); - return false; + return; } if (m_host.IsDebugSaveLicense()) @@ -445,44 +267,79 @@ bool CMFCencSingleSampleDecrypter::SendSessionMessage() FILESYS::SaveFile(debugFilePath, response, true); } - if (serverCertRequest && contentType.find("application/octet-stream") == std::string::npos) - serverCertRequest = false; + m_host.GetCdm()->UpdateSession( + sessionId, std::vector(response.data(), response.data() + response.size())); +} + +void CMFCencSingleSampleDecrypter::OnKeyChange(std::string_view sessionId, + std::vector> keys) +{ + LOG::LogF(LOGDEBUG, "Received %i keys", keys.size()); + for (const auto& key : keys) + { + char buf[36]; + buf[32] = 0; + AP4_FormatHex(key->keyId.data(), key->keyId.size(), buf); + + LOG::LogF(LOGDEBUG, "Key: %s status: %i", buf, key->status); + } + m_keys = std::move(keys); +} + +void CMFCencSingleSampleDecrypter::GetCapabilities(std::string_view key, + uint32_t media, + IDecrypter::DecrypterCapabilites& caps) +{ + caps = {IDecrypter::DecrypterCapabilites::SSD_SECURE_PATH | + IDecrypter::DecrypterCapabilites::SSD_ANNEXB_REQUIRED, + 0, m_hdcpLimit}; - //m_wvCdmAdapter.GetCdmAdapter()->UpdateSession( - // ++m_promiseId, m_strSession.data(), m_strSession.size(), - // reinterpret_cast(response.data()), response.size()); + if (sessionId.empty()) + { + LOG::LogF(LOGDEBUG, "Session empty"); + return; + } if (m_keys.empty()) { - LOG::LogF(LOGERROR, "License update not successful (no keys)"); - CloseSessionId(); - return false; + LOG::LogF(LOGDEBUG, "Keys empty"); + return; } - LOG::Log(LOGDEBUG, "License update successful"); - return true; + if (!caps.hdcpLimit) + caps.hdcpLimit = m_resolutionLimit; +} + +const char* CMFCencSingleSampleDecrypter::GetSessionId() +{ + return sessionId.empty() ? nullptr : sessionId.c_str(); } -void CMFCencSingleSampleDecrypter::AddSessionKey(const uint8_t* data, - size_t dataSize, - uint32_t status) +void CMFCencSingleSampleDecrypter::CloseSessionId() { - WVSKEY key; - std::vector::iterator res; + if (!sessionId.empty()) + { + LOG::LogF(LOGDEBUG, "Closing MF session ID: %s", sessionId.c_str()); + //m_wvCdmAdapter.GetCdmAdapter()->CloseSession(++m_promiseId, sessionId.data(), + // sessionId.size()); - key.m_keyId = std::string((const char*)data, dataSize); - if ((res = std::find(m_keys.begin(), m_keys.end(), key)) == m_keys.end()) - res = m_keys.insert(res, key); - res->status = static_cast(status); + LOG::LogF(LOGDEBUG, "MF session ID %s closed", sessionId.c_str()); + sessionId.clear(); + } +} + +AP4_DataBuffer CMFCencSingleSampleDecrypter::GetChallengeData() +{ + return m_challenge; } -bool CMFCencSingleSampleDecrypter::HasKeyId(std::string_view keyid) +bool CMFCencSingleSampleDecrypter::HasKeyId(std::string_view keyId) { - if (!keyid.empty()) + if (!keyId.empty()) { - for (const WVSKEY& key : m_keys) + for (const std::unique_ptr& key : m_keys) { - if (key.m_keyId == keyid) + if (key->keyId == STRING::ToVecUint8(keyId)) return true; } } @@ -598,12 +455,11 @@ void CMFCencSingleSampleDecrypter::SetDefaultKeyId(std::string_view keyId) void CMFCencSingleSampleDecrypter::AddKeyId(std::string_view keyId) { - WVSKEY key; - key.m_keyId = keyId; - key.status = MFKeyUsable; + std::unique_ptr key = std::make_unique( + std::vector(keyId.data(), keyId.data() + keyId.size()), MFKeyUsable); if (std::find(m_keys.begin(), m_keys.end(), key) == m_keys.end()) { - m_keys.push_back(key); + m_keys.push_back(std::move(key)); } } diff --git a/src/decrypters/mediafoundation/MFCencSingleSampleDecrypter.h b/src/decrypters/mediafoundation/MFCencSingleSampleDecrypter.h index 472e90959..f19365eba 100644 --- a/src/decrypters/mediafoundation/MFCencSingleSampleDecrypter.h +++ b/src/decrypters/mediafoundation/MFCencSingleSampleDecrypter.h @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -45,9 +46,9 @@ class ATTR_DLL_LOCAL CMFCencSingleSampleDecrypter : public Adaptive_CencSingleSa void OnSessionMessage(std::string_view session, const std::vector& message, std::string_view destinationUrl) override; + void OnKeyChange(std::string_view sessionId, std::vector> keys) override; - void AddSessionKey(const uint8_t* data, size_t dataSize, uint32_t status); - bool HasKeyId(std::string_view keyid); + bool HasKeyId(std::string_view keyId); virtual AP4_Result SetFragmentInfo(AP4_UI32 poolId, const std::vector& keyId, @@ -86,22 +87,17 @@ class ATTR_DLL_LOCAL CMFCencSingleSampleDecrypter : public Adaptive_CencSingleSa void AddKeyId(std::string_view keyId) override; private: - void CheckLicenseRenewal(); - bool SendSessionMessage(); + void ParsePlayReadyMessage(const std::vector& message, + std::string& challenge, + std::map& headers); CMFDecrypter& m_host; - std::string m_strSession; + std::string sessionId; std::vector m_pssh; AP4_DataBuffer m_challenge; std::string m_defaultKeyId; - struct WVSKEY - { - bool operator==(WVSKEY const& other) const { return m_keyId == other.m_keyId; }; - std::string m_keyId; - KeyStatus status; - }; - std::vector m_keys; + std::vector> m_keys; AP4_UI16 m_hdcpVersion; int m_hdcpLimit; @@ -143,7 +139,6 @@ class ATTR_DLL_LOCAL CMFCencSingleSampleDecrypter : public Adaptive_CencSingleSa bool m_isDrained; //std::list m_videoFrames; - std::mutex m_renewalLock; CryptoMode m_EncryptionMode; //std::optional m_currentVideoDecConfig; diff --git a/src/decrypters/mediafoundation/MFDecrypter.cpp b/src/decrypters/mediafoundation/MFDecrypter.cpp index e1b6962c4..bad909cbb 100644 --- a/src/decrypters/mediafoundation/MFDecrypter.cpp +++ b/src/decrypters/mediafoundation/MFDecrypter.cpp @@ -78,6 +78,8 @@ bool CMFDecrypter::OpenDRMSystem(std::string_view licenseURL, return false; } + m_strLicenseKey = licenseURL; + return m_cdm->Initialize({true, true}, "com.microsoft.playready.recommendation", m_strProfilePath); } diff --git a/src/decrypters/mediafoundation/MFDecrypter.h b/src/decrypters/mediafoundation/MFDecrypter.h index 17071bc4d..0adc8f270 100644 --- a/src/decrypters/mediafoundation/MFDecrypter.h +++ b/src/decrypters/mediafoundation/MFDecrypter.h @@ -79,12 +79,13 @@ class ATTR_DLL_LOCAL CMFDecrypter : public IDecrypter MediaFoundationCdm* GetCdm() const { return m_cdm; } + std::string GetLicenseKey() const { return m_strLicenseKey; } + private: MediaFoundationCdm* m_cdm; CMFCencSingleSampleDecrypter* m_decodingDecrypter; - std::string m_strProfilePath; std::string m_strLibraryPath; - + std::string m_strLicenseKey; bool m_isDebugSaveLicense; };