From a4dbc742e1a5e32b74b374d1c31bba5c2bd5f1d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Andr=C3=A9=20Moreau?= Date: Thu, 30 May 2024 17:02:39 -0400 Subject: [PATCH] Start exposing in-process DVC client APIs --- CMakeLists.txt | 1 + channels/CMakeLists.txt | 5 + channels/DvcServer.cpp | 173 +++++++++ dll/CMakeLists.txt | 2 + dll/MsRdpEx.cpp | 18 +- dll/RdpDvcClient.cpp | 409 ++++++++++++++++++++++ dll/RdpDvcClient.h | 87 +++++ dll/RdpInstance.cpp | 13 + dll/RdpSettings.cpp | 7 + dotnet/Devolutions.MsRdpEx/Bindings.cs | 4 + dotnet/Devolutions.MsRdpEx/RdpInstance.cs | 6 + dotnet/MsRdpEx_App/MainDlg.cs | 12 +- dotnet/MsRdpEx_App/MsRdpEx_App.csproj | 2 + dotnet/MsRdpEx_App/RdpChannel.cs | 193 ++++++++++ include/MsRdpEx/RdpInstance.h | 2 + include/MsRdpEx/RdpSettings.h | 1 + 16 files changed, 932 insertions(+), 3 deletions(-) create mode 100644 channels/CMakeLists.txt create mode 100644 channels/DvcServer.cpp create mode 100644 dll/RdpDvcClient.cpp create mode 100644 dll/RdpDvcClient.h create mode 100644 dotnet/MsRdpEx_App/RdpChannel.cs diff --git a/CMakeLists.txt b/CMakeLists.txt index 8709768..a42c6aa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -104,6 +104,7 @@ include_directories("${CMAKE_SOURCE_DIR}/com") if(WITH_NATIVE) add_subdirectory(dll) add_subdirectory(exe) + add_subdirectory(channels) endif() if(WITH_DOTNET) diff --git a/channels/CMakeLists.txt b/channels/CMakeLists.txt new file mode 100644 index 0000000..06f6e9a --- /dev/null +++ b/channels/CMakeLists.txt @@ -0,0 +1,5 @@ + +add_executable(DvcServer + DvcServer.cpp) + +target_link_libraries(DvcServer wtsapi32.lib) diff --git a/channels/DvcServer.cpp b/channels/DvcServer.cpp new file mode 100644 index 0000000..484fd14 --- /dev/null +++ b/channels/DvcServer.cpp @@ -0,0 +1,173 @@ +#include +#include +#include +#include +#include + +DWORD OpenVirtualChannel(const char* channelName, HANDLE* phFile) +{ + HANDLE hWTSHandle = NULL; + HANDLE hWTSFileHandle; + PVOID vcFileHandlePtr = NULL; + DWORD len; + DWORD rc = ERROR_SUCCESS; + + hWTSHandle = WTSVirtualChannelOpenEx(WTS_CURRENT_SESSION, (LPSTR)channelName, WTS_CHANNEL_OPTION_DYNAMIC); + + if (!hWTSHandle) + { + rc = GetLastError(); + printf("WTSVirtualChannelOpenEx API Call Failed: GetLastError() = %d\n", GetLastError()); + goto exitpt; + } + + BOOL bSuccess = WTSVirtualChannelQuery(hWTSHandle, WTSVirtualFileHandle, &vcFileHandlePtr, &len); + + if (!bSuccess) + { + rc = GetLastError(); + goto exitpt; + } + + if (len != sizeof(HANDLE)) + { + rc = ERROR_INVALID_PARAMETER; + goto exitpt; + } + + hWTSFileHandle = *(HANDLE*)vcFileHandlePtr; + + bSuccess = DuplicateHandle(GetCurrentProcess(), + hWTSFileHandle, GetCurrentProcess(), phFile, 0, FALSE, DUPLICATE_SAME_ACCESS); + + if (!bSuccess) + { + rc = GetLastError(); + goto exitpt; + } + + rc = ERROR_SUCCESS; +exitpt: + if (vcFileHandlePtr) + { + WTSFreeMemory(vcFileHandlePtr); + } + if (hWTSHandle) + { + WTSVirtualChannelClose(hWTSHandle); + } + return rc; +} + +DWORD WriteVirtualChannelMessage(HANDLE hFile, ULONG cbSize, BYTE* pBuffer) +{ + BYTE WriteBuffer[1024]; + DWORD dwWritten; + BOOL bSuccess; + HANDLE hEvent; + + hEvent = CreateEvent(NULL, FALSE, FALSE, NULL); + + OVERLAPPED overlapped = { 0 }; + overlapped.hEvent = hEvent; + + bSuccess = WriteFile(hFile, pBuffer, cbSize, &dwWritten, &overlapped); + + if (!bSuccess) + { + if (GetLastError() == ERROR_IO_PENDING) + { + DWORD dwStatus = WaitForSingleObject(overlapped.hEvent, 10000); + bSuccess = GetOverlappedResult(hFile, &overlapped, &dwWritten, FALSE); + } + } + + if (!bSuccess) + { + DWORD error = GetLastError(); + return error; + } + + return 0; +} + +DWORD HandleVirtualChannel(HANDLE hFile) +{ + BYTE ReadBuffer[CHANNEL_PDU_LENGTH]; + DWORD dwRead; + BYTE b = 0; + CHANNEL_PDU_HEADER* pHdr = (CHANNEL_PDU_HEADER*)ReadBuffer; + BOOL bSuccess; + HANDLE hEvent; + + const char* cmd = "whoami"; + ULONG cbSize = strlen(cmd) + 1; + BYTE* pBuffer = (BYTE*)cmd; + WriteVirtualChannelMessage(hFile, cbSize, pBuffer); + + hEvent = CreateEvent(NULL, FALSE, FALSE, NULL); + + do + { + OVERLAPPED overlapped = { 0 }; + DWORD TotalRead = 0; + + do { + overlapped.hEvent = hEvent; + bSuccess = ReadFile(hFile, ReadBuffer, sizeof(ReadBuffer), &dwRead, &overlapped); + + if (!bSuccess) + { + if (GetLastError() == ERROR_IO_PENDING) + { + DWORD dwStatus = WaitForSingleObject(overlapped.hEvent, INFINITE); + bSuccess = GetOverlappedResult(hFile, &overlapped, &dwRead, FALSE); + } + } + + if (!bSuccess) + { + DWORD error = GetLastError(); + return error; + } + + printf("read %d bytes\n", dwRead); + + ULONG packetSize = dwRead - sizeof(*pHdr); + TotalRead += packetSize; + PBYTE pData = (PBYTE)(pHdr + 1); + + printf(">> %s\n", (const char*)pData); + + } while (0 == (pHdr->flags & CHANNEL_FLAG_LAST)); + + } while (true); + + return 0; +} + +INT _cdecl wmain(INT argc, __in_ecount(argc) WCHAR** argv) +{ + DWORD rc; + HANDLE hFile; + const char* channelName = "DvcSample"; + + printf("Opening %s dynamic virtual channel\n", channelName); + rc = OpenVirtualChannel(channelName, &hFile); + + if (ERROR_SUCCESS != rc) + { + printf("Failed to open %s dynamic virtual channel\n", channelName); + return 0; + } + else + { + printf("%s dynamic virtual channel is opened\n", channelName); + } + + HandleVirtualChannel(hFile); + + CloseHandle(hFile); + + return 0; +} diff --git a/dll/CMakeLists.txt b/dll/CMakeLists.txt index d8aa8eb..6a04ff3 100644 --- a/dll/CMakeLists.txt +++ b/dll/CMakeLists.txt @@ -61,6 +61,8 @@ set(MSRDPEX_SOURCES RdpProcess.cpp RdpInstance.cpp RdpSettings.cpp + RdpDvcClient.cpp + RdpDvcClient.h TSObjects.cpp TSObjects.h MsRdpEx.cpp diff --git a/dll/MsRdpEx.cpp b/dll/MsRdpEx.cpp index 3f517f2..176ea20 100644 --- a/dll/MsRdpEx.cpp +++ b/dll/MsRdpEx.cpp @@ -5,11 +5,15 @@ #include +#include + #include #include #include +#include "RdpDvcClient.h" + static HMODULE g_hModule = NULL; static bool g_AxHookEnabled = true; @@ -30,9 +34,19 @@ HRESULT STDAPICALLTYPE DllGetClassObject(REFCLSID rclsid, REFIID riid, LPVOID* p HRESULT hr = E_UNEXPECTED; char clsid[MSRDPEX_GUID_STRING_SIZE]; char iid[MSRDPEX_GUID_STRING_SIZE]; + const GUID* pclsid = reinterpret_cast(&rclsid); + const GUID* piid = reinterpret_cast(&riid); + + MsRdpEx_GuidBinToStr(pclsid, clsid, 0); + MsRdpEx_GuidBinToStr(piid, iid, 0); - MsRdpEx_GuidBinToStr((GUID*)&rclsid, clsid, 0); - MsRdpEx_GuidBinToStr((GUID*)&riid, iid, 0); + CMsRdpExInstance* instance = MsRdpEx_InstanceManager_FindBySessionId((GUID*) pclsid); + + if (instance) { + hr = DllGetClassObject_DvcPlugin(rclsid, riid, ppv, (void*) instance); + MsRdpEx_LogPrint(DEBUG, "DllGetClassObject_DvcPlugin(%s, %s) with instance %p, hr = 0x%08X", clsid, iid, hr, instance); + return hr; + } if (g_IsClientProcess) { if (g_IsOOBClient) { diff --git a/dll/RdpDvcClient.cpp b/dll/RdpDvcClient.cpp new file mode 100644 index 0000000..6257d65 --- /dev/null +++ b/dll/RdpDvcClient.cpp @@ -0,0 +1,409 @@ + +#include "RdpDvcClient.h" + +#include "MsRdpEx.h" + +#include + +#include + +// +// CRdpDvcClient class +// + +// IUnknown methods + +STDMETHODIMP CRdpDvcClient::QueryInterface(REFIID riid, void** ppv) +{ + HRESULT hr = S_OK; + + char iid[MSRDPEX_GUID_STRING_SIZE]; + MsRdpEx_GuidBinToStr((GUID*)&riid, iid, 0); + + if (!ppv) + return E_INVALIDARG; + + *ppv = NULL; + + if (riid == (REFIID) IID_IUnknown) { + *ppv = this; + } + else if (riid == (REFIID) IID_IWTSVirtualChannelCallback) { + MsRdpEx_LogPrint(DEBUG, "CRdpDvcClient::QueryInterface(IID_IWTSVirtualChannelCallback)"); + *ppv = static_cast(this); + } + + if (nullptr != *ppv) { + ((IUnknown*)*ppv)->AddRef(); + } + else { + hr = E_NOINTERFACE; + } + + MsRdpEx_LogPrint(DEBUG, "CRdpDvcClient::QueryInterface(%s), hr = 0x%08X", iid, hr); + + return hr; +} + +STDMETHODIMP_(ULONG) CRdpDvcClient::AddRef(void) +{ + return InterlockedIncrement(&m_refCount); +} + +STDMETHODIMP_(ULONG) CRdpDvcClient::Release(void) +{ + ULONG refCount = InterlockedDecrement(&m_refCount); + + if (refCount != 0) { + return refCount; + } + + delete this; + return 0; +} + +// IWTSVirtualChannelCallback methods + +HRESULT STDMETHODCALLTYPE CRdpDvcClient::OnDataReceived(ULONG cbSize, BYTE* pBuffer) +{ + MsRdpEx_LogPrint(DEBUG, "CRdpDvcClient::OnDataReceived(%s)", (const char*)pBuffer); + HRESULT hr = m_pChannel->Write(cbSize, pBuffer, NULL); + return S_OK; +} + +HRESULT STDMETHODCALLTYPE CRdpDvcClient::OnClose(void) +{ + MsRdpEx_LogPrint(DEBUG, "CRdpDvcClient::OnClose()"); + return S_OK; +} + +// Additional methods + +void CRdpDvcClient::SetChannel(IWTSVirtualChannel* pChannel) +{ + m_pChannel = pChannel; +} + +void CRdpDvcClient::SetListener(IWTSListener* pListener) +{ + m_pListener = pListener; +} + +// Additional methods + +CRdpDvcClient::CRdpDvcClient(void) +{ + +} + +CRdpDvcClient::~CRdpDvcClient() +{ + +} + +// +// CRdpDvcListener class +// + +// IUnknown methods + +STDMETHODIMP CRdpDvcListener::QueryInterface(REFIID riid, void** ppv) +{ + HRESULT hr = S_OK; + + char iid[MSRDPEX_GUID_STRING_SIZE]; + MsRdpEx_GuidBinToStr((GUID*)&riid, iid, 0); + + if (!ppv) + return E_INVALIDARG; + + *ppv = NULL; + + if (riid == IID_IUnknown) { + *ppv = this; + } + else if (riid == IID_IWTSListenerCallback) { + *ppv = static_cast(this); + } + + if (nullptr != *ppv) { + ((IUnknown*)*ppv)->AddRef(); + } + else { + hr = E_NOINTERFACE; + } + + MsRdpEx_LogPrint(DEBUG, "CRdpDvcListener::QueryInterface(%s), hr = 0x%08X", iid, hr); + + return hr; +} + +STDMETHODIMP_(ULONG) CRdpDvcListener::AddRef(void) +{ + return InterlockedIncrement(&m_refCount); +} + +STDMETHODIMP_(ULONG) CRdpDvcListener::Release(void) +{ + ULONG refCount = InterlockedDecrement(&m_refCount); + + if (refCount != 0) { + return refCount; + } + + delete this; + return 0; +} + +// IWTSListenerCallback methods + +HRESULT STDMETHODCALLTYPE CRdpDvcListener::OnNewChannelConnection(IWTSVirtualChannel* pChannel, + BSTR data, BOOL* pbAccept, IWTSVirtualChannelCallback** ppCallback) +{ + HRESULT hr = S_OK; + CRdpDvcClient* dvcClient = new CRdpDvcClient(); + IWTSVirtualChannelCallback* pIWTSVirtualChannelCallback = NULL; + + MsRdpEx_LogPrint(DEBUG, "CRdpDvcListener::OnNewChannelConnection()"); + + hr = dvcClient->QueryInterface(IID_IWTSVirtualChannelCallback, (void**)&pIWTSVirtualChannelCallback); + + if (FAILED(hr)) { + return hr; + } + + dvcClient->SetChannel(pChannel); + + *pbAccept = TRUE; + *ppCallback = pIWTSVirtualChannelCallback; + + return hr; +} + +// Additional methods + +CRdpDvcListener::CRdpDvcListener(void) +{ + +} + +CRdpDvcListener::~CRdpDvcListener() +{ + +} + +// +// CRdpDvcPlugin class +// + +// IUnknown methods + +STDMETHODIMP CRdpDvcPlugin::QueryInterface(REFIID riid, void** ppv) +{ + HRESULT hr = S_OK; + + char iid[MSRDPEX_GUID_STRING_SIZE]; + MsRdpEx_GuidBinToStr((GUID*)&riid, iid, 0); + + MsRdpEx_LogPrint(DEBUG, "CRdpDvcPlugin::QueryInterface(%s)", iid); + + if (!ppv) + return E_INVALIDARG; + + *ppv = NULL; + + if (riid == IID_IUnknown) { + *ppv = this; + } + else if (riid == IID_IWTSPlugin) { + *ppv = static_cast(this); + } + + if (nullptr != *ppv) { + ((IUnknown*)*ppv)->AddRef(); + } + else { + hr = E_NOINTERFACE; + } + + return hr; +} + +STDMETHODIMP_(ULONG) CRdpDvcPlugin::AddRef(void) +{ + return InterlockedIncrement(&m_refCount); +} + +STDMETHODIMP_(ULONG) CRdpDvcPlugin::Release(void) +{ + ULONG refCount = InterlockedDecrement(&m_refCount); + + if (refCount != 0) { + return refCount; + } + + delete this; + return 0; +} + +// IWTSPlugin methods + +HRESULT STDMETHODCALLTYPE CRdpDvcPlugin::Initialize(IWTSVirtualChannelManager* pChannelMgr) +{ + HRESULT hr = S_OK; + CRdpDvcListener* dvcListener = new CRdpDvcListener(); + IWTSListener* pIWTSListener = NULL; + IWTSListenerCallback* pIWTSListenerCallback = NULL; + + hr = dvcListener->QueryInterface(IID_IWTSListenerCallback, (void**)&pIWTSListenerCallback); + + if (FAILED(hr)) { + return hr; + } + + hr = pChannelMgr->CreateListener("DvcSample", 0, pIWTSListenerCallback, &pIWTSListener); + + return hr; +} + +HRESULT STDMETHODCALLTYPE CRdpDvcPlugin::Connected(void) +{ + MsRdpEx_LogPrint(DEBUG, "CRdpDvcPlugin::Connected()"); + return S_OK; +} + +HRESULT STDMETHODCALLTYPE CRdpDvcPlugin::Disconnected(DWORD dwDisconnectCode) +{ + MsRdpEx_LogPrint(DEBUG, "CRdpDvcPlugin::Disconnected()"); + return S_OK; +} + +HRESULT STDMETHODCALLTYPE CRdpDvcPlugin::Terminated(void) +{ + MsRdpEx_LogPrint(DEBUG, "CRdpDvcPlugin::Terminated()"); + return S_OK; +} + +// Additional methods + +CRdpDvcPlugin::CRdpDvcPlugin(void) +{ + +} + +CRdpDvcPlugin::~CRdpDvcPlugin() +{ + +} + +// CDvcPluginClassFactory class + +class CDvcPluginClassFactory : IClassFactory +{ +public: + CDvcPluginClassFactory(CMsRdpExInstance* instance) + { + m_instance = instance; + } + + ~CDvcPluginClassFactory() + { + + } + + // IUnknown interface +public: + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, LPVOID* ppvObject) + { + HRESULT hr = E_NOINTERFACE; + + char iid[MSRDPEX_GUID_STRING_SIZE]; + MsRdpEx_GuidBinToStr((GUID*)&riid, iid, 0); + + MsRdpEx_LogPrint(DEBUG, "CDvcPluginClassFactory::QueryInterface(%s)", iid); + + if (riid == IID_IUnknown) { + *ppvObject = (LPVOID)((IUnknown*)this); + m_refCount++; + return S_OK; + } + if (riid == IID_IClassFactory) { + *ppvObject = (LPVOID)((IClassFactory*)this); + m_refCount++; + return S_OK; + } + + return hr; + } + + ULONG STDMETHODCALLTYPE AddRef() + { + return InterlockedIncrement(&m_refCount); + } + + ULONG STDMETHODCALLTYPE Release() + { + ULONG refCount = InterlockedDecrement(&m_refCount); + + if (refCount != 0) { + return refCount; + } + + delete this; + return 0; + } + + // IClassFactory interface +public: + HRESULT STDMETHODCALLTYPE CreateInstance(IUnknown* pUnkOuter, REFIID riid, LPVOID* ppvObject) + { + HRESULT hr = E_NOINTERFACE; + + char iid[MSRDPEX_GUID_STRING_SIZE]; + MsRdpEx_GuidBinToStr((GUID*)&riid, iid, 0); + + if (riid == IID_IWTSPlugin) { + IUnknown* wtsPlugin = NULL; + IMsRdpExInstance* rdpInstance = (IMsRdpExInstance*)m_instance; + rdpInstance->GetWTSPluginObject((void**)&wtsPlugin); + + if (wtsPlugin) { + MsRdpEx_LogPrint(DEBUG, "CDvcPluginClassFactory using registered WTSPlugin"); + hr = wtsPlugin->QueryInterface(riid, ppvObject); + } + else { + MsRdpEx_LogPrint(DEBUG, "CDvcPluginClassFactory using built-in WTSPlugin"); + CRdpDvcPlugin* dvcPlugin = new CRdpDvcPlugin(); + hr = dvcPlugin->QueryInterface(riid, ppvObject); + } + } + + MsRdpEx_LogPrint(DEBUG, "CDvcPluginClassFactory::CreateInstance(%s), hr = 0x%08X", iid, hr); + + return hr; + } + + HRESULT STDMETHODCALLTYPE LockServer(BOOL fLock) + { + MsRdpEx_LogPrint(DEBUG, "CDvcPluginClassFactory::LockServer"); + return S_OK; + } + +private: + ULONG m_refCount = 0; + CMsRdpExInstance* m_instance = NULL; +}; + +HRESULT STDAPICALLTYPE DllGetClassObject_DvcPlugin(REFCLSID rclsid, REFIID riid, LPVOID* ppv, void* instance) +{ + HRESULT hr = E_NOINTERFACE; + + if (riid == (REFIID) IID_IClassFactory) + { + CDvcPluginClassFactory* classFactory = new CDvcPluginClassFactory((CMsRdpExInstance*) instance); + *ppv = (LPVOID) classFactory; + hr = S_OK; + } + + return hr; +} diff --git a/dll/RdpDvcClient.h b/dll/RdpDvcClient.h new file mode 100644 index 0000000..472321a --- /dev/null +++ b/dll/RdpDvcClient.h @@ -0,0 +1,87 @@ +#ifndef MSRDPEX_DVC_CLIENT_H +#define MSRDPEX_DVC_CLIENT_H + +#include "MsRdpEx.h" + +#include +#include +#include +#include + +#include + +class CRdpDvcClient : + public IWTSVirtualChannelCallback +{ +public: + // IUnknown methods + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void** ppv) override; + ULONG STDMETHODCALLTYPE AddRef() override; + ULONG STDMETHODCALLTYPE Release() override; + + // IWTSVirtualChannelCallback methods + HRESULT STDMETHODCALLTYPE OnDataReceived(ULONG cbSize, BYTE* pBuffer) override; + HRESULT STDMETHODCALLTYPE OnClose(void) override; + + // Additional methods + void SetChannel(IWTSVirtualChannel* pChannel); + void SetListener(IWTSListener* pListener); + + CRdpDvcClient(void); + virtual ~CRdpDvcClient(); + +private: + ULONG m_refCount = 0; + IWTSVirtualChannel* m_pChannel = NULL; + IWTSListener* m_pListener = NULL; +}; + +class CRdpDvcListener : + public IWTSListenerCallback +{ +public: + // IUnknown methods + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void** ppv) override; + ULONG STDMETHODCALLTYPE AddRef() override; + ULONG STDMETHODCALLTYPE Release() override; + + // IWTSListenerCallback methods + HRESULT STDMETHODCALLTYPE OnNewChannelConnection(IWTSVirtualChannel* pChannel, + BSTR data, BOOL* pbAccept, IWTSVirtualChannelCallback** ppCallback) override; + + // Additional methods + CRdpDvcListener(void); + virtual ~CRdpDvcListener(); +private: + ULONG m_refCount = 0; +}; + +class CRdpDvcPlugin : + public IWTSPlugin +{ +public: + // IUnknown methods + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void** ppv) override; + ULONG STDMETHODCALLTYPE AddRef() override; + ULONG STDMETHODCALLTYPE Release() override; + + // IWTSPlugin methods + HRESULT STDMETHODCALLTYPE Initialize(IWTSVirtualChannelManager* pChannelMgr) override; + HRESULT STDMETHODCALLTYPE Connected() override; + HRESULT STDMETHODCALLTYPE Disconnected(DWORD dwDisconnectCode) override; + HRESULT STDMETHODCALLTYPE Terminated() override; + + // Additional methods + CRdpDvcPlugin(void); + virtual ~CRdpDvcPlugin(); + +private: + ULONG m_refCount = 0; + IWTSVirtualChannel* m_pChannel = NULL; +}; + +extern "C" const GUID CLSID_IMsRdpExDVCPlugin; + +HRESULT STDAPICALLTYPE DllGetClassObject_DvcPlugin(REFCLSID rclsid, REFIID riid, LPVOID* ppv, void* instance); + +#endif /* MSRDPEX_DVC_CLIENT_H */ diff --git a/dll/RdpInstance.cpp b/dll/RdpInstance.cpp index 1dcdc1b..d7b9b02 100644 --- a/dll/RdpInstance.cpp +++ b/dll/RdpInstance.cpp @@ -243,6 +243,18 @@ class CMsRdpExInstance : public IMsRdpExInstance m_LastMousePosY = posY; } + HRESULT STDMETHODCALLTYPE GetWTSPluginObject(LPVOID* ppvObject) + { + *ppvObject = m_WTSPlugin; + return S_OK; + } + + HRESULT STDMETHODCALLTYPE SetWTSPluginObject(LPVOID pvObject) + { + m_WTSPlugin = (IUnknown*)pvObject; + return S_OK; + } + public: GUID m_sessionId; ULONG m_refCount = NULL; @@ -257,6 +269,7 @@ class CMsRdpExInstance : public IMsRdpExInstance CMsRdpExtendedSettings* m_pMsRdpExtendedSettings = NULL; int32_t m_LastMousePosX = 0; int32_t m_LastMousePosY = 0; + IUnknown* m_WTSPlugin = NULL; }; CMsRdpExInstance* CMsRdpExInstance_New(CMsRdpClient* pMsRdpClient) diff --git a/dll/RdpSettings.cpp b/dll/RdpSettings.cpp index 42fbf36..b9d6d56 100644 --- a/dll/RdpSettings.cpp +++ b/dll/RdpSettings.cpp @@ -604,6 +604,13 @@ HRESULT __stdcall CMsRdpExtendedSettings::get_Property(BSTR bstrPropertyName, VA pValue->intVal = (INT) m_MouseJigglerMethod; hr = S_OK; } + else if (MsRdpEx_StringEquals(propName, "MsRdpEx_SessionId")) { + pValue->vt = VT_BSTR; + char sessionId[MSRDPEX_GUID_STRING_SIZE]; + MsRdpEx_GuidBinToStr((GUID*)&m_sessionId, sessionId, 0); + pValue->bstrVal = _com_util::ConvertStringToBSTR(sessionId); + hr = S_OK; + } else { hr = m_pMsRdpExtendedSettings->get_Property(bstrPropertyName, pValue); } diff --git a/dotnet/Devolutions.MsRdpEx/Bindings.cs b/dotnet/Devolutions.MsRdpEx/Bindings.cs index 8e4aa82..5d3812c 100644 --- a/dotnet/Devolutions.MsRdpEx/Bindings.cs +++ b/dotnet/Devolutions.MsRdpEx/Bindings.cs @@ -143,6 +143,10 @@ bool GetShadowBitmap(ref IntPtr phDC, ref IntPtr phBitmap, ref IntPtr pBitmapDat [MethodImpl(MethodImplOptions.PreserveSig)] void SetLastMousePosition(Int32 posX, Int32 posY); + + void GetWTSPluginObject(out IntPtr plugin); + + void SetWTSPluginObject(IntPtr plugin); } public static class Bindings diff --git a/dotnet/Devolutions.MsRdpEx/RdpInstance.cs b/dotnet/Devolutions.MsRdpEx/RdpInstance.cs index b59a4f1..08a508f 100644 --- a/dotnet/Devolutions.MsRdpEx/RdpInstance.cs +++ b/dotnet/Devolutions.MsRdpEx/RdpInstance.cs @@ -1,4 +1,5 @@ using System; +using System.Runtime.InteropServices; namespace MsRdpEx { @@ -38,5 +39,10 @@ public bool GetShadowBitmap(ref IntPtr phDC, ref IntPtr phBitmap, ref IntPtr pBi return iface.GetShadowBitmap(ref phDC, ref phBitmap, ref pBitmapData, ref pBitmapWidth, ref pBitmapHeight, ref pBitmapStep); } + + public object WTSPlugin + { + set { iface.SetWTSPluginObject(Marshal.GetIUnknownForObject(value)); } + } } } diff --git a/dotnet/MsRdpEx_App/MainDlg.cs b/dotnet/MsRdpEx_App/MainDlg.cs index cd823c6..fba2f0a 100644 --- a/dotnet/MsRdpEx_App/MainDlg.cs +++ b/dotnet/MsRdpEx_App/MainDlg.cs @@ -439,13 +439,17 @@ private void btnConnect_Click(object sender, EventArgs e) AxMSTSCLib.AxMsRdpClient9NotSafeForScripting rdp = rdpView.rdpClient; + Guid sessionId = Guid.Empty; + RdpDvcPlugin wtsPlugin = new RdpDvcPlugin(); + if (axHookEnabled) { RdpInstance rdpInstance = new RdpInstance((IMsRdpExInstance)rdp.GetOcx()); rdpInstance.OutputMirrorEnabled = false; rdpInstance.VideoRecordingEnabled = false; + rdpInstance.WTSPlugin = wtsPlugin; - Guid sessionId = rdpInstance.SessionId; + sessionId = rdpInstance.SessionId; Debug.WriteLine("SessionId: {0}", sessionId); } @@ -463,6 +467,12 @@ private void btnConnect_Click(object sender, EventArgs e) rdpView.ClientSize = DesktopSize; rdpView.Text = String.Format("{0} ({1})", rdp.Server, axName); + // https://learn.microsoft.com/en-us/windows/win32/termserv/dvc-plug-in-registration + string pluginCLSID = "7009F103-4B7E-48E2-81BC-46AB3FC1B64C"; + pluginCLSID = sessionId.ToString("D"); + string pluginDlls = String.Format("{0}:{{{1}}}", rdpExDll, pluginCLSID); + rdp.AdvancedSettings.PluginDlls = pluginDlls; + try { object RequestUseNewOutputPresenter = true; extendedSettings.set_Property("RequestUseNewOutputPresenter", ref RequestUseNewOutputPresenter); diff --git a/dotnet/MsRdpEx_App/MsRdpEx_App.csproj b/dotnet/MsRdpEx_App/MsRdpEx_App.csproj index 922dc09..4ec387e 100644 --- a/dotnet/MsRdpEx_App/MsRdpEx_App.csproj +++ b/dotnet/MsRdpEx_App/MsRdpEx_App.csproj @@ -9,6 +9,7 @@ True $(CMakeOutputPath) True + 11 @@ -17,6 +18,7 @@ 1.0.0 1.0.0.0 1.0.0.0 + Always diff --git a/dotnet/MsRdpEx_App/RdpChannel.cs b/dotnet/MsRdpEx_App/RdpChannel.cs new file mode 100644 index 0000000..cfb1afc --- /dev/null +++ b/dotnet/MsRdpEx_App/RdpChannel.cs @@ -0,0 +1,193 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; + +namespace MsRdpEx_App +{ + [ComImport] + [Guid("a1230207-d6a7-11d8-b9fd-000bdbd1f198")] + [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] + public interface IWTSVirtualChannel + { + void Write(uint cbSize, IntPtr pBuffer, + [MarshalAs(UnmanagedType.IUnknown)] object pReserved); + void Close(); + } + + [ComImport] + [Guid("a1230204-d6a7-11d8-b9fd-000bdbd1f198")] + [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] + public interface IWTSVirtualChannelCallback + { + void OnDataReceived(uint cbSize, IntPtr pBuffer); + void OnClose(); + } + + [ComImport] + [Guid("a1230206-9a39-4d58-8674-cdb4dff4e73b")] + [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] + public interface IWTSListener + { + void GetConfiguration( + [MarshalAs(UnmanagedType.IUnknown)] out object ppPropertyBag); + } + + [ComImport] + [Guid("a1230203-d6a7-11d8-b9fd-000bdbd1f198")] + [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] + public interface IWTSListenerCallback + { + void OnNewChannelConnection( + IWTSVirtualChannel pChannel, + [MarshalAs(UnmanagedType.BStr)] string data, + [MarshalAs(UnmanagedType.Bool)] out bool pAccept, + out IWTSVirtualChannelCallback pCallback); + } + + [ComImport] + [Guid("a1230205-d6a7-11d8-b9fd-000bdbd1f198")] + [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] + public interface IWTSVirtualChannelManager + { + void CreateListener( + [MarshalAs(UnmanagedType.LPStr)] string pszChannelName, + uint uFlags, + [MarshalAs(UnmanagedType.Interface)] IWTSListenerCallback pListenerCallback, + [MarshalAs(UnmanagedType.Interface)] out IWTSListener ppListener); + } + + [ComImport] + [Guid("a1230201-1439-4e62-a414-190d0ac3d40e")] + [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] + public interface IWTSPlugin + { + void Initialize([MarshalAs(UnmanagedType.Interface)] IWTSVirtualChannelManager pChannelMgr); + void Connected(); + void Disconnected(uint dwDisconnectCode); + void Terminated(); + } + + public class RdpDvcClient : IWTSVirtualChannelCallback + { + public string channelName; + private IWTSVirtualChannel wtsChannel; + + public event EventHandler OnChannelClose; + + public RdpDvcClient(string name, IWTSVirtualChannel wtsChannel) + { + this.channelName = name; + this.wtsChannel = wtsChannel; + } + + public void SendRawBuffer(uint cbSize, IntPtr pBuffer) + { + wtsChannel?.Write(cbSize, pBuffer, null); + } + + void IWTSVirtualChannelCallback.OnDataReceived(uint cbSize, IntPtr pBuffer) + { + + } + + void IWTSVirtualChannelCallback.OnClose() + { + OnChannelClose?.Invoke(this, EventArgs.Empty); + } + } + + public class RdpDvcListener : IWTSListenerCallback + { + public int maxCount; + public string channelName; + public IWTSListener wtsListener; + + private List clients = new List(); + public List Clients => clients; + + public RdpDvcListener(string name, int maxCount) + { + this.channelName = name; + this.maxCount = maxCount; + } + + void IWTSListenerCallback.OnNewChannelConnection(IWTSVirtualChannel pChannel, + [MarshalAs(UnmanagedType.BStr)] string data, + [MarshalAs(UnmanagedType.Bool)] out bool pAccept, + out IWTSVirtualChannelCallback pCallback) + { + if ((maxCount != -1) && (clients.Count >= maxCount)) + { + pAccept = false; + pCallback = null; + return; + } + + RdpDvcClient client = new RdpDvcClient(channelName, pChannel); + pAccept = true; + pCallback = client; + + client.OnChannelClose += OnChannelClose; + clients.Add(client); + } + + private void OnChannelClose(object sender, EventArgs e) + { + RdpDvcClient client = sender as RdpDvcClient; + clients.Remove(client); + } + + public void OnConnected(object sender, EventArgs e) + { + + } + + public void OnDisconnected(object sender, EventArgs e) + { + + } + + public void OnTerminated(object sender, EventArgs e) + { + + } + } + + public class RdpDvcPlugin : IWTSPlugin + { + public event EventHandler OnConnected; + public event EventHandler OnDisconnected; + public event EventHandler OnTerminated; + + private Dictionary listeners = new Dictionary(); + + void IWTSPlugin.Initialize(IWTSVirtualChannelManager pChannelMgr) + { + string channelName = "DvcSample"; + RdpDvcListener listener = new RdpDvcListener(channelName, -1); + + pChannelMgr.CreateListener(listener.channelName, + 0, listener, out listener.wtsListener); + + listeners[listener.channelName] = listener; + this.OnConnected += listener.OnConnected; + this.OnDisconnected += listener.OnDisconnected; + this.OnTerminated += listener.OnTerminated; + } + + void IWTSPlugin.Connected() + { + OnConnected?.Invoke(this, EventArgs.Empty); + } + + void IWTSPlugin.Disconnected(uint disconnectCode) + { + OnDisconnected?.Invoke(this, EventArgs.Empty); + } + + void IWTSPlugin.Terminated() + { + OnTerminated?.Invoke(this, EventArgs.Empty); + } + } +} diff --git a/include/MsRdpEx/RdpInstance.h b/include/MsRdpEx/RdpInstance.h index e98d80b..9800c7a 100644 --- a/include/MsRdpEx/RdpInstance.h +++ b/include/MsRdpEx/RdpInstance.h @@ -33,6 +33,8 @@ struct __declspec(novtable) virtual void __stdcall UnlockShadowBitmap() = 0; virtual void __stdcall GetLastMousePosition(int32_t* posX, int32_t* posY) = 0; virtual void __stdcall SetLastMousePosition(int32_t posX, int32_t posY) = 0; + virtual HRESULT __stdcall GetWTSPluginObject(LPVOID* ppvObject) = 0; + virtual HRESULT __stdcall SetWTSPluginObject(LPVOID pvObject) = 0; }; class CMsRdpExInstance; diff --git a/include/MsRdpEx/RdpSettings.h b/include/MsRdpEx/RdpSettings.h index c5fbd61..699415d 100644 --- a/include/MsRdpEx/RdpSettings.h +++ b/include/MsRdpEx/RdpSettings.h @@ -68,6 +68,7 @@ class CMsRdpExtendedSettings : public IMsRdpExtendedSettings bool m_MouseJigglerEnabled = false; uint32_t m_MouseJigglerInterval = 60; uint32_t m_MouseJigglerMethod = 0; + IUnknown* m_pWTSPlugin = NULL; }; #ifdef __cplusplus