diff --git a/dll/ApiHooks.cpp b/dll/ApiHooks.cpp index 081b834..5cf693f 100644 --- a/dll/ApiHooks.cpp +++ b/dll/ApiHooks.cpp @@ -717,6 +717,86 @@ ATOM Hook_RegisterClassW(WNDCLASSW* wndClass) return wndClassAtom; } +#define ID_MAINDLG_COMPUTER_TEXTBOX 5012 + +static DLGPROC Real_MainDlgProc = NULL; + +HWND MsRdpEx_GetMsgWindowHandle() +{ + char buffer[32]; + HWND hWndMsg = NULL; + + if (GetEnvironmentVariableA("MSRDPEX_HWNDMSG", buffer, sizeof(buffer)) > 0) + { + hWndMsg = (HWND)_strtoui64(buffer, NULL, 0); + } + + return hWndMsg; +} + +INT_PTR CALLBACK Hook_MainDlgProc(HWND hWnd, UINT uMsg, WPARAM wParam, LPARAM lParam) +{ + INT_PTR dlgResult = FALSE; + bool interceptedCall = false; + + HWND hWndMsg = MsRdpEx_GetMsgWindowHandle(); + + if (hWndMsg) + { + if (uMsg == WM_DESTROY) + { + MsRdpEx_LogPrint(DEBUG, "MainDlgProc WM_DESTROY"); + PostMessage(hWndMsg, 0x0401, NULL, NULL); + } + else if (uMsg == WM_COMMAND) + { + MsRdpEx_LogPrint(DEBUG, "MainDlgProc WM_COMMAND: %d", (int)wParam); + + if (wParam == IDOK) // "Connect" button + { + char connectionStringA[256]; + GetDlgItemTextA(hWnd, ID_MAINDLG_COMPUTER_TEXTBOX, connectionStringA, 256); + MsRdpEx_LogPrint(DEBUG, "MainDlgProc: Connect(%s)", connectionStringA); + + PostMessage(hWndMsg, 0x0402, NULL, NULL); + interceptedCall = true; + dlgResult = TRUE; + } + } + } + + if (!interceptedCall) { + dlgResult = Real_MainDlgProc(hWnd, uMsg, wParam, lParam); + } + + return dlgResult; +} + +HWND (WINAPI* Real_CreateDialogParamW)(HINSTANCE hInstance, + LPCWSTR lpTemplateName, HWND hWndParent, + DLGPROC lpDialogFunc, LPARAM dwInitParam) = CreateDialogParamW; + +HWND Hook_CreateDialogParamW(HINSTANCE hInstance, + LPCWSTR lpTemplateName, HWND hWndParent, + DLGPROC lpDialogFunc, LPARAM dwInitParam) +{ + HWND hWndDialog; + HWND hWndMsg = MsRdpEx_GetMsgWindowHandle(); + + if (hWndMsg) + { + if (lpTemplateName == MAKEINTRESOURCE(15001)) { + MsRdpEx_LogPrint(DEBUG, "CreateDialogParamW: CMainDlg"); + Real_MainDlgProc = lpDialogFunc; + lpDialogFunc = Hook_MainDlgProc; + } + } + + hWndDialog = Real_CreateDialogParamW(hInstance, lpTemplateName, hWndParent, lpDialogFunc, dwInitParam); + + return hWndDialog; +} + BOOL (WINAPI * Real_CredReadW)(LPCWSTR TargetName, DWORD Type, DWORD Flags, PCREDENTIALW* Credential) = CredReadW; BOOL Hook_CredReadW(LPCWSTR TargetName, DWORD Type, DWORD Flags, PCREDENTIALW* Credential) @@ -1040,6 +1120,7 @@ LONG MsRdpEx_AttachHooks() MSRDPEX_DETOUR_ATTACH(Real_StretchBlt, Hook_StretchBlt); MSRDPEX_DETOUR_ATTACH(Real_RegisterClassExW, Hook_RegisterClassExW); MSRDPEX_DETOUR_ATTACH(Real_RegisterClassW, Hook_RegisterClassW); + MSRDPEX_DETOUR_ATTACH(Real_CreateDialogParamW, Hook_CreateDialogParamW); MSRDPEX_DETOUR_ATTACH(Real_CredReadW, Hook_CredReadW); //MSRDPEX_DETOUR_ATTACH(Real_CryptProtectMemory, Hook_CryptProtectMemory); @@ -1093,6 +1174,7 @@ LONG MsRdpEx_DetachHooks() MSRDPEX_DETOUR_DETACH(Real_StretchBlt, Hook_StretchBlt); MSRDPEX_DETOUR_DETACH(Real_RegisterClassExW, Hook_RegisterClassExW); MSRDPEX_DETOUR_DETACH(Real_RegisterClassW, Hook_RegisterClassW); + MSRDPEX_DETOUR_DETACH(Real_CreateDialogParamW, Hook_CreateDialogParamW); MSRDPEX_DETOUR_DETACH(Real_CredReadW, Hook_CredReadW); //MSRDPEX_DETOUR_DETACH(Real_CryptProtectMemory, Hook_CryptProtectMemory); diff --git a/dll/CMakeLists.txt b/dll/CMakeLists.txt index c9e0734..1da2e98 100644 --- a/dll/CMakeLists.txt +++ b/dll/CMakeLists.txt @@ -48,7 +48,6 @@ set(MSRDPEX_SOURCES String.c MsRdpClient.cpp MsRdpClient.h - AxDll.cpp ApiHooks.cpp OutputMirror.c VideoRecorder.c @@ -69,9 +68,20 @@ add_library(MsRdpEx_Dll SHARED ${MSRDPEX_SOURCES} ${MSRDPEX_HEADERS} ${MSRDPEX_RESOURCES} + AxDll.cpp MsRdpEx.def) -set_target_properties(MsRdpEx_Dll PROPERTIES OUTPUT_NAME "MsRdpEx") +set(MSRDPEX_LIBS + detours + version.lib + userenv.lib + user32.lib + rpcrt4.lib + ws2_32.lib + secur32.lib + credui.lib + advapi32.lib) -target_link_libraries(MsRdpEx_Dll detours - version.lib userenv.lib user32.lib rpcrt4.lib ws2_32.lib secur32.lib credui.lib advapi32.lib) +target_link_libraries(MsRdpEx_Dll ${MSRDPEX_LIBS}) + +set_target_properties(MsRdpEx_Dll PROPERTIES OUTPUT_NAME "MsRdpEx") diff --git a/dll/MsRdpEx.def b/dll/MsRdpEx.def index 21bae31..7257e31 100644 --- a/dll/MsRdpEx.def +++ b/dll/MsRdpEx.def @@ -13,10 +13,15 @@ EXPORTS DllCancelAuthentication DllDeleteSavedCreds DllPreCleanUp - MsRdpEx_CreateInstance - MsRdpEx_LaunchProcess MsRdpEx_GetClaimsToken MsRdpEx_LogoffClaimsToken MsRdpEx_CancelAuthentication MsRdpEx_DeleteSavedCreds MsRdpEx_PreCleanUp + MsRdpEx_InitPaths + MsRdpEx_GetPath + MsRdpEx_CreateInstance + MsRdpEx_LaunchProcess + MsRdpExProcess_CreateInstance + MsRdpEx_GetArgumentVector + MsRdpEx_FreeArgumentVector \ No newline at end of file diff --git a/dll/Paths.c b/dll/Paths.c index 446f556..229b8d5 100644 --- a/dll/Paths.c +++ b/dll/Paths.c @@ -17,6 +17,8 @@ static char g_MSTSCAX_DLL_PATH[MSRDPEX_MAX_PATH] = { 0 }; static char g_MSRDC_EXE_PATH[MSRDPEX_MAX_PATH] = { 0 }; static char g_RDCLIENTAX_DLL_PATH[MSRDPEX_MAX_PATH] = { 0 }; +static char g_DEFAULT_RDP_PATH[MSRDPEX_MAX_PATH] = { 0 }; + bool MsRdpEx_PathCchRenameExtension(char* pszPath, size_t cchPath, const char* pszExt) { size_t length = strlen(pszPath); @@ -122,6 +124,10 @@ bool MsRdpEx_InitPaths(uint32_t pathIds) } } + if (pathIds & MSRDPEX_DEFAULT_RDP_PATH) { + ExpandEnvironmentStringsA("%UserProfile%\\Documents\\Default.rdp", g_DEFAULT_RDP_PATH, MSRDPEX_MAX_PATH); + } + return true; } @@ -166,6 +172,10 @@ const char* MsRdpEx_GetPath(uint32_t pathId) case MSRDPEX_RDCLIENTAX_DLL_PATH: path = (const char*) g_RDCLIENTAX_DLL_PATH; break; + + case MSRDPEX_DEFAULT_RDP_PATH: + path = (const char*) g_DEFAULT_RDP_PATH; + break; } return path; diff --git a/dll/RdpProcess.cpp b/dll/RdpProcess.cpp index e81bd41..59147ef 100644 --- a/dll/RdpProcess.cpp +++ b/dll/RdpProcess.cpp @@ -9,23 +9,6 @@ extern "C" const GUID IID_IMsRdpExProcess; -struct __declspec(novtable) - IMsRdpExProcess : public IUnknown -{ -public: - virtual void __stdcall SetFileName(const char* filename) = 0; - virtual void __stdcall SetArguments(const char* arguments) = 0; - virtual void __stdcall SetArgumentBlock(const char* argumentBlock) = 0; - virtual void __stdcall SetEnvironmentBlock(const char* environmentBlock) = 0; - virtual void __stdcall SetWorkingDirectory(const char* workingDirectory) = 0; - virtual HRESULT __stdcall StartWithInfo() = 0; - virtual HRESULT __stdcall Start(int argc, char** argv, const char* appName, const char* axName) = 0; - virtual HRESULT __stdcall Stop(uint32_t exitCode) = 0; - virtual HRESULT __stdcall Wait(uint32_t milliseconds) = 0; - virtual uint32_t __stdcall GetProcessId() = 0; - virtual uint32_t __stdcall GetExitCode() = 0; -}; - class CMsRdpExProcess : public IMsRdpExProcess { public: @@ -33,6 +16,7 @@ class CMsRdpExProcess : public IMsRdpExProcess { m_refCount = 0; m_exitCode = 0; + m_hasExited = false; m_filename = NULL; m_arguments = NULL; m_argumentBlock = NULL; diff --git a/exe/msrdcex/CMakeLists.txt b/exe/msrdcex/CMakeLists.txt index ed648e8..a6e0f30 100644 --- a/exe/msrdcex/CMakeLists.txt +++ b/exe/msrdcex/CMakeLists.txt @@ -11,7 +11,7 @@ windows_rc_generate_version_info( source_group("Resources" FILES msrdcex.rc) add_executable(msrdcex WIN32 - msrdcex.c + msrdcex.cpp msrdcex.rc) target_link_libraries(msrdcex MsRdpEx_Dll) diff --git a/exe/msrdcex/msrdcex.cpp b/exe/msrdcex/msrdcex.cpp new file mode 100644 index 0000000..f2cb82e --- /dev/null +++ b/exe/msrdcex/msrdcex.cpp @@ -0,0 +1,120 @@ +#include + +LRESULT CALLBACK WrapperMsgWindowProc(HWND hWnd, UINT uMsg, WPARAM wParam, LPARAM lParam) +{ + switch (uMsg) + { + case 0x401: // WM_QUIT + PostQuitMessage(uMsg); + break; + + case 0x402: // Connect button + PostQuitMessage(uMsg); + break; + + case WM_DESTROY: + PostQuitMessage(0); + return 0; + } + + return DefWindowProc(hWnd, uMsg, wParam, lParam); +} + +HWND CreateWrapperMsgWindow(HINSTANCE hInstance) +{ + WNDCLASS wc = { 0 }; + wc.lpfnWndProc = WrapperMsgWindowProc; + wc.hInstance = hInstance; + wc.lpszClassName = L"MsRdpEx_WrapperMsgWindow"; + + if (!RegisterClass(&wc)) { + return NULL; + } + + HWND hWndMsg = CreateWindowEx(0, + L"MsRdpEx_WrapperMsgWindow", + L"MsRdpEx_WrapperMsgWindow", + 0, 0, 0, 0, 0, HWND_MESSAGE, + NULL, hInstance, NULL); + + return hWndMsg; +} + +BOOL MsRdpEx_SetMsgWindowHandle(HWND hWndMsg) +{ + char buffer[32]; + _ui64toa((unsigned long long)hWndMsg, buffer, 10); + return SetEnvironmentVariableA("MSRDPEX_HWNDMSG", buffer) > 0; +} + +int WINAPI wWinMain( + _In_ HINSTANCE hInstance, + _In_opt_ HINSTANCE hPrevInstance, + _In_ LPWSTR lpCmdLine, + _In_ int nShowCmd) +{ + HRESULT hr; + char mstsc_args[2048]; + char msrdc_args[2048]; + IMsRdpExProcess* mstsc = NULL; + IMsRdpExProcess* msrdc = NULL; + + if (__argc >= 2) + { + // we launched msrdc with command-line arguments + hr = MsRdpEx_LaunchProcess(-1, NULL, NULL, "msrdc"); + return 0; + } + + MsRdpEx_InitPaths(MSRDPEX_ALL_PATHS); + + HWND hWndMsg = CreateWrapperMsgWindow(hInstance); + + MsRdpEx_SetMsgWindowHandle(hWndMsg); + + const char* mstsc_exe = MsRdpEx_GetPath(MSRDPEX_MSTSC_EXE_PATH); + const char* msrdc_exe = MsRdpEx_GetPath(MSRDPEX_MSRDC_EXE_PATH); + const char* default_rdp = MsRdpEx_GetPath(MSRDPEX_DEFAULT_RDP_PATH); + + hr = MsRdpExProcess_CreateInstance((LPVOID*)&mstsc); + + sprintf_s(mstsc_args, sizeof(mstsc_args) - 1, "\"%s\"", + mstsc_exe); + + mstsc->AddRef(); + mstsc->SetFileName(mstsc_exe); + mstsc->SetArguments(mstsc_args); + hr = mstsc->StartWithInfo(); + + MSG msg = { 0 }; + + while (GetMessage(&msg, hWndMsg, 0, 0)) + { + TranslateMessage(&msg); + DispatchMessage(&msg); + } + + int exitCode = (int)msg.wParam; + + mstsc->Stop(0); + mstsc->Release(); + mstsc = NULL; + + if (exitCode == 0x402) + { + hr = MsRdpExProcess_CreateInstance((LPVOID*)&msrdc); + + sprintf_s(msrdc_args, sizeof(msrdc_args) - 1, "\"%s\" \"%s\"", + msrdc_exe, default_rdp); + + msrdc->AddRef(); + msrdc->SetFileName(msrdc_exe); + msrdc->SetArguments(msrdc_args); + msrdc->StartWithInfo(); + msrdc->Wait(INFINITE); + msrdc->Release(); + msrdc = NULL; + } + + return 0; +} diff --git a/exe/mstscex/CMakeLists.txt b/exe/mstscex/CMakeLists.txt index cf6a009..c521dab 100644 --- a/exe/mstscex/CMakeLists.txt +++ b/exe/mstscex/CMakeLists.txt @@ -11,7 +11,7 @@ windows_rc_generate_version_info( source_group("Resources" FILES mstscex.rc) add_executable(mstscex WIN32 - mstscex.c + mstscex.cpp mstscex.rc) target_link_libraries(mstscex MsRdpEx_Dll) diff --git a/exe/mstscex/mstscex.c b/exe/mstscex/mstscex.c deleted file mode 100644 index c0ebe30..0000000 --- a/exe/mstscex/mstscex.c +++ /dev/null @@ -1,15 +0,0 @@ -#include - -int WINAPI wWinMain( - _In_ HINSTANCE hInstance, - _In_opt_ HINSTANCE hPrevInstance, - _In_ LPWSTR lpCmdLine, - _In_ int nShowCmd) -{ - HRESULT hr; - const char* axName = "mstsc"; - - hr = MsRdpEx_LaunchProcess(-1, NULL, NULL, axName); - - return 0; -} diff --git a/exe/msrdcex/msrdcex.c b/exe/mstscex/mstscex.cpp similarity index 69% rename from exe/msrdcex/msrdcex.c rename to exe/mstscex/mstscex.cpp index 3a44de5..1a09964 100644 --- a/exe/msrdcex/msrdcex.c +++ b/exe/mstscex/mstscex.cpp @@ -7,9 +7,8 @@ int WINAPI wWinMain( _In_ int nShowCmd) { HRESULT hr; - const char* axName = "msrdc"; - hr = MsRdpEx_LaunchProcess(-1, NULL, NULL, axName); + hr = MsRdpEx_LaunchProcess(-1, NULL, NULL, "mstsc"); return 0; } diff --git a/include/MsRdpEx/MsRdpEx.h b/include/MsRdpEx/MsRdpEx.h index 8b7e6fd..44c6cef 100644 --- a/include/MsRdpEx/MsRdpEx.h +++ b/include/MsRdpEx/MsRdpEx.h @@ -64,6 +64,7 @@ HMODULE MsRdpEx_LoadLibrary(const char* filename); #define MSRDPEX_MSTSCAX_DLL_PATH 0x00000200 #define MSRDPEX_MSRDC_EXE_PATH 0x00000400 #define MSRDPEX_RDCLIENTAX_DLL_PATH 0x00000800 +#define MSRDPEX_DEFAULT_RDP_PATH 0x00001000 #define MSRDPEX_ALL_PATHS 0xFFFFFFFF bool MsRdpEx_InitPaths(uint32_t pathIds); diff --git a/include/MsRdpEx/RdpProcess.h b/include/MsRdpEx/RdpProcess.h index 29aa419..750bf11 100644 --- a/include/MsRdpEx/RdpProcess.h +++ b/include/MsRdpEx/RdpProcess.h @@ -3,6 +3,25 @@ #include +#include + +struct __declspec(novtable) + IMsRdpExProcess : public IUnknown +{ +public: + virtual void __stdcall SetFileName(const char* filename) = 0; + virtual void __stdcall SetArguments(const char* arguments) = 0; + virtual void __stdcall SetArgumentBlock(const char* argumentBlock) = 0; + virtual void __stdcall SetEnvironmentBlock(const char* environmentBlock) = 0; + virtual void __stdcall SetWorkingDirectory(const char* workingDirectory) = 0; + virtual HRESULT __stdcall StartWithInfo() = 0; + virtual HRESULT __stdcall Start(int argc, char** argv, const char* appName, const char* axName) = 0; + virtual HRESULT __stdcall Stop(uint32_t exitCode) = 0; + virtual HRESULT __stdcall Wait(uint32_t milliseconds) = 0; + virtual uint32_t __stdcall GetProcessId() = 0; + virtual uint32_t __stdcall GetExitCode() = 0; +}; + #ifdef __cplusplus extern "C" { #endif