From 104864f50156ff38bcf1551a775a0f3f44307ea8 Mon Sep 17 00:00:00 2001 From: Tom Plant Date: Sun, 27 Oct 2024 12:27:06 +1100 Subject: [PATCH] Use AMSI for archive malware scanning * Remove `#include ` and add necessary includes for AMSI in `src/AppInstallerCommonCore/Archive.cpp` * Initialize AMSI, create a session, scan the file, and handle results in `ScanZipFile` function * Add tests for new archive formats in `src/AppInstallerCLITests/Archive.cpp` - Add test cases for 7z, Rar, TarGz, and TarBz2 archive formats - Verify extraction and scanning of these new archive formats --- src/AppInstallerCLITests/Archive.cpp | 169 ++++++++++++++++++----- src/AppInstallerCommonCore/Archive.cpp | 177 ++++++++++++++----------- 2 files changed, 233 insertions(+), 113 deletions(-) diff --git a/src/AppInstallerCLITests/Archive.cpp b/src/AppInstallerCLITests/Archive.cpp index 4d415761f1..32d826314f 100644 --- a/src/AppInstallerCLITests/Archive.cpp +++ b/src/AppInstallerCLITests/Archive.cpp @@ -1,34 +1,135 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. -#include "pch.h" -#include "TestCommon.h" -#include - -using namespace AppInstaller::Archive; -using namespace TestCommon; - -constexpr std::string_view s_ZipFile = "TestZip.zip"; - -TEST_CASE("Extract_ZipArchive", "[archive]") -{ - TestCommon::TempDirectory tempDirectory("TempDirectory"); - TestDataFile testZip(s_ZipFile); - - const auto& testZipPath = testZip.GetPath(); - const auto& tempDirectoryPath = tempDirectory.GetPath(); - - HRESULT hr = TryExtractArchive(testZipPath, tempDirectoryPath); - - std::filesystem::path expectedPath = tempDirectoryPath / "test.txt"; - REQUIRE(SUCCEEDED(hr)); - REQUIRE(std::filesystem::exists(expectedPath)); -} - -TEST_CASE("Scan_ZipArchive", "[archive]") -{ - TestDataFile testZip(s_ZipFile); - - const auto& testZipPath = testZip.GetPath(); - bool result = ScanZipFile(testZipPath); - REQUIRE(result); -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include "pch.h" +#include "TestCommon.h" +#include + +using namespace AppInstaller::Archive; +using namespace TestCommon; + +constexpr std::string_view s_ZipFile = "TestZip.zip"; +constexpr std::string_view s_7zFile = "Test7z.7z"; +constexpr std::string_view s_RarFile = "TestRar.rar"; +constexpr std::string_view s_TarGzFile = "TestTarGz.tar.gz"; +constexpr std::string_view s_TarBz2File = "TestTarBz2.tar.bz2"; + +TEST_CASE("Extract_ZipArchive", "[archive]") +{ + TestCommon::TempDirectory tempDirectory("TempDirectory"); + TestDataFile testZip(s_ZipFile); + + const auto& testZipPath = testZip.GetPath(); + const auto& tempDirectoryPath = tempDirectory.GetPath(); + + HRESULT hr = TryExtractArchive(testZipPath, tempDirectoryPath); + + std::filesystem::path expectedPath = tempDirectoryPath / "test.txt"; + REQUIRE(SUCCEEDED(hr)); + REQUIRE(std::filesystem::exists(expectedPath)); +} + +TEST_CASE("Scan_ZipArchive", "[archive]") +{ + TestDataFile testZip(s_ZipFile); + + const auto& testZipPath = testZip.GetPath(); + bool result = ScanZipFile(testZipPath); + REQUIRE(result); +} + +TEST_CASE("Extract_7zArchive", "[archive]") +{ + TestCommon::TempDirectory tempDirectory("TempDirectory"); + TestDataFile test7z(s_7zFile); + + const auto& test7zPath = test7z.GetPath(); + const auto& tempDirectoryPath = tempDirectory.GetPath(); + + HRESULT hr = TryExtractArchive(test7zPath, tempDirectoryPath); + + std::filesystem::path expectedPath = tempDirectoryPath / "test.txt"; + REQUIRE(SUCCEEDED(hr)); + REQUIRE(std::filesystem::exists(expectedPath)); +} + +TEST_CASE("Scan_7zArchive", "[archive]") +{ + TestDataFile test7z(s_7zFile); + + const auto& test7zPath = test7z.GetPath(); + bool result = ScanZipFile(test7zPath); + REQUIRE(result); +} + +TEST_CASE("Extract_RarArchive", "[archive]") +{ + TestCommon::TempDirectory tempDirectory("TempDirectory"); + TestDataFile testRar(s_RarFile); + + const auto& testRarPath = testRar.GetPath(); + const auto& tempDirectoryPath = tempDirectory.GetPath(); + + HRESULT hr = TryExtractArchive(testRarPath, tempDirectoryPath); + + std::filesystem::path expectedPath = tempDirectoryPath / "test.txt"; + REQUIRE(SUCCEEDED(hr)); + REQUIRE(std::filesystem::exists(expectedPath)); +} + +TEST_CASE("Scan_RarArchive", "[archive]") +{ + TestDataFile testRar(s_RarFile); + + const auto& testRarPath = testRar.GetPath(); + bool result = ScanZipFile(testRarPath); + REQUIRE(result); +} + +TEST_CASE("Extract_TarGzArchive", "[archive]") +{ + TestCommon::TempDirectory tempDirectory("TempDirectory"); + TestDataFile testTarGz(s_TarGzFile); + + const auto& testTarGzPath = testTarGz.GetPath(); + const auto& tempDirectoryPath = tempDirectory.GetPath(); + + HRESULT hr = TryExtractArchive(testTarGzPath, tempDirectoryPath); + + std::filesystem::path expectedPath = tempDirectoryPath / "test.txt"; + REQUIRE(SUCCEEDED(hr)); + REQUIRE(std::filesystem::exists(expectedPath)); +} + +TEST_CASE("Scan_TarGzArchive", "[archive]") +{ + TestDataFile testTarGz(s_TarGzFile); + + const auto& testTarGzPath = testTarGz.GetPath(); + bool result = ScanZipFile(testTarGzPath); + REQUIRE(result); +} + +TEST_CASE("Extract_TarBz2Archive", "[archive]") +{ + TestCommon::TempDirectory tempDirectory("TempDirectory"); + TestDataFile testTarBz2(s_TarBz2File); + + const auto& testTarBz2Path = testTarBz2.GetPath(); + const auto& tempDirectoryPath = tempDirectory.GetPath(); + + HRESULT hr = TryExtractArchive(testTarBz2Path, tempDirectoryPath); + + std::filesystem::path expectedPath = tempDirectoryPath / "test.txt"; + REQUIRE(SUCCEEDED(hr)); + REQUIRE(std::filesystem::exists(expectedPath)); +} + +TEST_CASE("Scan_TarBz2Archive", "[archive]") +{ + TestDataFile testTarBz2(s_TarBz2File); + + const auto& testTarBz2Path = testTarBz2.GetPath(); + bool result = ScanZipFile(testTarBz2Path); + REQUIRE(result); +} + diff --git a/src/AppInstallerCommonCore/Archive.cpp b/src/AppInstallerCommonCore/Archive.cpp index 95766519e9..cbf19b142e 100644 --- a/src/AppInstallerCommonCore/Archive.cpp +++ b/src/AppInstallerCommonCore/Archive.cpp @@ -1,79 +1,98 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. -#include "pch.h" -#include "Public/winget/Archive.h" - -// TODO: Move include statement to pch.h and resolve build errors -#pragma warning( push ) -#pragma warning ( disable : 4189 4244 26451 ) -#include -#pragma warning ( pop ) - -namespace AppInstaller::Archive -{ - using unique_pidlist_absolute = wil::unique_any; - using unique_lpitemidlist = wil::unique_any; - - HRESULT TryExtractArchive(const std::filesystem::path& archivePath, const std::filesystem::path& destPath) - { - wil::com_ptr pFileOperation; - RETURN_IF_FAILED(CoCreateInstance(CLSID_FileOperation, NULL, CLSCTX_ALL, IID_PPV_ARGS(&pFileOperation))); - RETURN_IF_FAILED(pFileOperation->SetOperationFlags(FOF_NO_UI)); - - wil::com_ptr pShellItemTo; - RETURN_IF_FAILED(SHCreateItemFromParsingName(destPath.c_str(), NULL, IID_PPV_ARGS(&pShellItemTo))); - - unique_pidlist_absolute pidlFull; - RETURN_IF_FAILED(SHParseDisplayName(archivePath.c_str(), NULL, &pidlFull, 0, NULL)); - - wil::com_ptr pArchiveShellFolder; - RETURN_IF_FAILED(SHBindToObject(NULL, pidlFull.get(), NULL, IID_PPV_ARGS(&pArchiveShellFolder))); - - wil::com_ptr pEnumIdList; - RETURN_IF_FAILED(pArchiveShellFolder->EnumObjects(nullptr, SHCONTF_FOLDERS | SHCONTF_NONFOLDERS, &pEnumIdList)); - - unique_lpitemidlist pidlChild; - ULONG nFetched; - while (pEnumIdList->Next(1, wil::out_param_ptr(pidlChild), &nFetched) == S_OK && nFetched == 1) - { - wil::com_ptr pShellItemFrom; - STRRET strFolderName; - WCHAR szFolderName[MAX_PATH]; - RETURN_IF_FAILED(pArchiveShellFolder->GetDisplayNameOf(pidlChild.get(), SHGDN_INFOLDER | SHGDN_FORPARSING, &strFolderName)); - RETURN_IF_FAILED(StrRetToBuf(&strFolderName, pidlChild.get(), szFolderName, MAX_PATH)); - RETURN_IF_FAILED(SHCreateItemWithParent(pidlFull.get(), pArchiveShellFolder.get(), pidlChild.get(), IID_PPV_ARGS(&pShellItemFrom))); - RETURN_IF_FAILED(pFileOperation->CopyItem(pShellItemFrom.get(), pShellItemTo.get(), NULL, NULL)); - } - - RETURN_IF_FAILED(pFileOperation->PerformOperations()); - return S_OK; - } - -#ifndef AICLI_DISABLE_TEST_HOOKS - static bool* s_ScanArchiveResult_TestHook_Override = nullptr; - - void TestHook_SetScanArchiveResult_Override(bool* status) - { - s_ScanArchiveResult_TestHook_Override = status; - } -#endif - - bool ScanZipFile(const std::filesystem::path& zipPath) - { -#ifndef AICLI_DISABLE_TEST_HOOKS - if (s_ScanArchiveResult_TestHook_Override) - { - return *s_ScanArchiveResult_TestHook_Override; - } -#endif - - std::ifstream instream{ zipPath, std::ios::in | std::ios::binary }; - std::vector data{ { std::istreambuf_iterator{ instream } }, std::istreambuf_iterator{} }; - - uint8_t* buffer = &data[0]; - uint64_t flag = 0; - int scanResult = pure_zip(buffer, data.size(), flag); - - return scanResult == 0; - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include "pch.h" +#include "Public/winget/Archive.h" + +#include +#include +#include +#include + +namespace AppInstaller::Archive +{ + using unique_pidlist_absolute = wil::unique_any; + using unique_lpitemidlist = wil::unique_any; + + HRESULT TryExtractArchive(const std::filesystem::path& archivePath, const std::filesystem::path& destPath) + { + wil::com_ptr pFileOperation; + RETURN_IF_FAILED(CoCreateInstance(CLSID_FileOperation, NULL, CLSCTX_ALL, IID_PPV_ARGS(&pFileOperation))); + RETURN_IF_FAILED(pFileOperation->SetOperationFlags(FOF_NO_UI)); + + wil::com_ptr pShellItemTo; + RETURN_IF_FAILED(SHCreateItemFromParsingName(destPath.c_str(), NULL, IID_PPV_ARGS(&pShellItemTo))); + + unique_pidlist_absolute pidlFull; + RETURN_IF_FAILED(SHParseDisplayName(archivePath.c_str(), NULL, &pidlFull, 0, NULL)); + + wil::com_ptr pArchiveShellFolder; + RETURN_IF_FAILED(SHBindToObject(NULL, pidlFull.get(), NULL, IID_PPV_ARGS(&pArchiveShellFolder))); + + wil::com_ptr pEnumIdList; + RETURN_IF_FAILED(pArchiveShellFolder->EnumObjects(nullptr, SHCONTF_FOLDERS | SHCONTF_NONFOLDERS, &pEnumIdList)); + + unique_lpitemidlist pidlChild; + ULONG nFetched; + while (pEnumIdList->Next(1, wil::out_param_ptr(pidlChild), &nFetched) == S_OK && nFetched == 1) + { + wil::com_ptr pShellItemFrom; + STRRET strFolderName; + WCHAR szFolderName[MAX_PATH]; + RETURN_IF_FAILED(pArchiveShellFolder->GetDisplayNameOf(pidlChild.get(), SHGDN_INFOLDER | SHGDN_FORPARSING, &strFolderName)); + RETURN_IF_FAILED(StrRetToBuf(&strFolderName, pidlChild.get(), szFolderName, MAX_PATH)); + RETURN_IF_FAILED(SHCreateItemWithParent(pidlFull.get(), pArchiveShellFolder.get(), pidlChild.get(), IID_PPV_ARGS(&pShellItemFrom))); + RETURN_IF_FAILED(pFileOperation->CopyItem(pShellItemFrom.get(), pShellItemTo.get(), NULL, NULL)); + } + + RETURN_IF_FAILED(pFileOperation->PerformOperations()); + return S_OK; + } + +#ifndef AICLI_DISABLE_TEST_HOOKS + static bool* s_ScanArchiveResult_TestHook_Override = nullptr; + + void TestHook_SetScanArchiveResult_Override(bool* status) + { + s_ScanArchiveResult_TestHook_Override = status; + } +#endif + + bool ScanZipFile(const std::filesystem::path& zipPath) + { +#ifndef AICLI_DISABLE_TEST_HOOKS + if (s_ScanArchiveResult_TestHook_Override) + { + return *s_ScanArchiveResult_TestHook_Override; + } +#endif + + HRESULT hr = S_OK; + wil::com_ptr_nothrow amsiContext; + wil::com_ptr_nothrow amsiSession; + + hr = AmsiInitialize(L"WinGet", &amsiContext); + if (FAILED(hr)) + { + return false; + } + + hr = AmsiOpenSession(amsiContext.get(), &amsiSession); + if (FAILED(hr)) + { + AmsiUninitialize(amsiContext.get()); + return false; + } + + std::ifstream instream{ zipPath, std::ios::in | std::ios::binary }; + std::vector data{ { std::istreambuf_iterator{ instream } }, std::istreambuf_iterator{} }; + + AMSI_RESULT result = AMSI_RESULT_CLEAN; + hr = AmsiScanBuffer(amsiContext.get(), data.data(), data.size(), zipPath.filename().c_str(), amsiSession.get(), &result); + + AmsiCloseSession(amsiContext.get(), amsiSession.get()); + AmsiUninitialize(amsiContext.get()); + + return SUCCEEDED(hr) && (result == AMSI_RESULT_CLEAN || result == AMSI_RESULT_NOT_DETECTED); + } +} +