-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use AMSI for archive malware scanning
* Remove `#include <pure.h>` and add necessary includes for AMSI in `src/AppInstallerCommonCore/Archive.cpp` * Initialize AMSI, create a session, scan the file, and handle results in `ScanZipFile` function * Add tests for new archive formats in `src/AppInstallerCLITests/Archive.cpp` - Add test cases for 7z, Rar, TarGz, and TarBz2 archive formats - Verify extraction and scanning of these new archive formats
- Loading branch information
Showing
2 changed files
with
233 additions
and
113 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,34 +1,135 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT License. | ||
#include "pch.h" | ||
#include "TestCommon.h" | ||
#include <winget/Archive.h> | ||
|
||
using namespace AppInstaller::Archive; | ||
using namespace TestCommon; | ||
|
||
constexpr std::string_view s_ZipFile = "TestZip.zip"; | ||
|
||
TEST_CASE("Extract_ZipArchive", "[archive]") | ||
{ | ||
TestCommon::TempDirectory tempDirectory("TempDirectory"); | ||
TestDataFile testZip(s_ZipFile); | ||
|
||
const auto& testZipPath = testZip.GetPath(); | ||
const auto& tempDirectoryPath = tempDirectory.GetPath(); | ||
|
||
HRESULT hr = TryExtractArchive(testZipPath, tempDirectoryPath); | ||
|
||
std::filesystem::path expectedPath = tempDirectoryPath / "test.txt"; | ||
REQUIRE(SUCCEEDED(hr)); | ||
REQUIRE(std::filesystem::exists(expectedPath)); | ||
} | ||
|
||
TEST_CASE("Scan_ZipArchive", "[archive]") | ||
{ | ||
TestDataFile testZip(s_ZipFile); | ||
|
||
const auto& testZipPath = testZip.GetPath(); | ||
bool result = ScanZipFile(testZipPath); | ||
REQUIRE(result); | ||
} | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT License. | ||
#include "pch.h" | ||
#include "TestCommon.h" | ||
#include <winget/Archive.h> | ||
|
||
using namespace AppInstaller::Archive; | ||
using namespace TestCommon; | ||
|
||
constexpr std::string_view s_ZipFile = "TestZip.zip"; | ||
constexpr std::string_view s_7zFile = "Test7z.7z"; | ||
constexpr std::string_view s_RarFile = "TestRar.rar"; | ||
constexpr std::string_view s_TarGzFile = "TestTarGz.tar.gz"; | ||
constexpr std::string_view s_TarBz2File = "TestTarBz2.tar.bz2"; | ||
|
||
TEST_CASE("Extract_ZipArchive", "[archive]") | ||
{ | ||
TestCommon::TempDirectory tempDirectory("TempDirectory"); | ||
TestDataFile testZip(s_ZipFile); | ||
|
||
const auto& testZipPath = testZip.GetPath(); | ||
const auto& tempDirectoryPath = tempDirectory.GetPath(); | ||
|
||
HRESULT hr = TryExtractArchive(testZipPath, tempDirectoryPath); | ||
|
||
std::filesystem::path expectedPath = tempDirectoryPath / "test.txt"; | ||
REQUIRE(SUCCEEDED(hr)); | ||
REQUIRE(std::filesystem::exists(expectedPath)); | ||
} | ||
|
||
TEST_CASE("Scan_ZipArchive", "[archive]") | ||
{ | ||
TestDataFile testZip(s_ZipFile); | ||
|
||
const auto& testZipPath = testZip.GetPath(); | ||
bool result = ScanZipFile(testZipPath); | ||
REQUIRE(result); | ||
} | ||
|
||
TEST_CASE("Extract_7zArchive", "[archive]") | ||
{ | ||
TestCommon::TempDirectory tempDirectory("TempDirectory"); | ||
TestDataFile test7z(s_7zFile); | ||
|
||
const auto& test7zPath = test7z.GetPath(); | ||
const auto& tempDirectoryPath = tempDirectory.GetPath(); | ||
|
||
HRESULT hr = TryExtractArchive(test7zPath, tempDirectoryPath); | ||
|
||
std::filesystem::path expectedPath = tempDirectoryPath / "test.txt"; | ||
REQUIRE(SUCCEEDED(hr)); | ||
REQUIRE(std::filesystem::exists(expectedPath)); | ||
} | ||
|
||
TEST_CASE("Scan_7zArchive", "[archive]") | ||
{ | ||
TestDataFile test7z(s_7zFile); | ||
|
||
const auto& test7zPath = test7z.GetPath(); | ||
bool result = ScanZipFile(test7zPath); | ||
REQUIRE(result); | ||
} | ||
|
||
TEST_CASE("Extract_RarArchive", "[archive]") | ||
{ | ||
TestCommon::TempDirectory tempDirectory("TempDirectory"); | ||
TestDataFile testRar(s_RarFile); | ||
|
||
const auto& testRarPath = testRar.GetPath(); | ||
const auto& tempDirectoryPath = tempDirectory.GetPath(); | ||
|
||
HRESULT hr = TryExtractArchive(testRarPath, tempDirectoryPath); | ||
|
||
std::filesystem::path expectedPath = tempDirectoryPath / "test.txt"; | ||
REQUIRE(SUCCEEDED(hr)); | ||
REQUIRE(std::filesystem::exists(expectedPath)); | ||
} | ||
|
||
TEST_CASE("Scan_RarArchive", "[archive]") | ||
{ | ||
TestDataFile testRar(s_RarFile); | ||
|
||
const auto& testRarPath = testRar.GetPath(); | ||
bool result = ScanZipFile(testRarPath); | ||
REQUIRE(result); | ||
} | ||
|
||
TEST_CASE("Extract_TarGzArchive", "[archive]") | ||
{ | ||
TestCommon::TempDirectory tempDirectory("TempDirectory"); | ||
TestDataFile testTarGz(s_TarGzFile); | ||
|
||
const auto& testTarGzPath = testTarGz.GetPath(); | ||
const auto& tempDirectoryPath = tempDirectory.GetPath(); | ||
|
||
HRESULT hr = TryExtractArchive(testTarGzPath, tempDirectoryPath); | ||
|
||
std::filesystem::path expectedPath = tempDirectoryPath / "test.txt"; | ||
REQUIRE(SUCCEEDED(hr)); | ||
REQUIRE(std::filesystem::exists(expectedPath)); | ||
} | ||
|
||
TEST_CASE("Scan_TarGzArchive", "[archive]") | ||
{ | ||
TestDataFile testTarGz(s_TarGzFile); | ||
|
||
const auto& testTarGzPath = testTarGz.GetPath(); | ||
bool result = ScanZipFile(testTarGzPath); | ||
REQUIRE(result); | ||
} | ||
|
||
TEST_CASE("Extract_TarBz2Archive", "[archive]") | ||
{ | ||
TestCommon::TempDirectory tempDirectory("TempDirectory"); | ||
TestDataFile testTarBz2(s_TarBz2File); | ||
|
||
const auto& testTarBz2Path = testTarBz2.GetPath(); | ||
const auto& tempDirectoryPath = tempDirectory.GetPath(); | ||
|
||
HRESULT hr = TryExtractArchive(testTarBz2Path, tempDirectoryPath); | ||
|
||
std::filesystem::path expectedPath = tempDirectoryPath / "test.txt"; | ||
REQUIRE(SUCCEEDED(hr)); | ||
REQUIRE(std::filesystem::exists(expectedPath)); | ||
} | ||
|
||
TEST_CASE("Scan_TarBz2Archive", "[archive]") | ||
{ | ||
TestDataFile testTarBz2(s_TarBz2File); | ||
|
||
const auto& testTarBz2Path = testTarBz2.GetPath(); | ||
bool result = ScanZipFile(testTarBz2Path); | ||
REQUIRE(result); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,79 +1,98 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT License. | ||
#include "pch.h" | ||
#include "Public/winget/Archive.h" | ||
|
||
// TODO: Move include statement to pch.h and resolve build errors | ||
#pragma warning( push ) | ||
#pragma warning ( disable : 4189 4244 26451 ) | ||
#include <pure.h> | ||
#pragma warning ( pop ) | ||
|
||
namespace AppInstaller::Archive | ||
{ | ||
using unique_pidlist_absolute = wil::unique_any<PIDLIST_ABSOLUTE, decltype(&::CoTaskMemFree), ::CoTaskMemFree>; | ||
using unique_lpitemidlist = wil::unique_any<LPITEMIDLIST, decltype(&::CoTaskMemFree), ::CoTaskMemFree>; | ||
|
||
HRESULT TryExtractArchive(const std::filesystem::path& archivePath, const std::filesystem::path& destPath) | ||
{ | ||
wil::com_ptr<IFileOperation> pFileOperation; | ||
RETURN_IF_FAILED(CoCreateInstance(CLSID_FileOperation, NULL, CLSCTX_ALL, IID_PPV_ARGS(&pFileOperation))); | ||
RETURN_IF_FAILED(pFileOperation->SetOperationFlags(FOF_NO_UI)); | ||
|
||
wil::com_ptr<IShellItem> pShellItemTo; | ||
RETURN_IF_FAILED(SHCreateItemFromParsingName(destPath.c_str(), NULL, IID_PPV_ARGS(&pShellItemTo))); | ||
|
||
unique_pidlist_absolute pidlFull; | ||
RETURN_IF_FAILED(SHParseDisplayName(archivePath.c_str(), NULL, &pidlFull, 0, NULL)); | ||
|
||
wil::com_ptr<IShellFolder> pArchiveShellFolder; | ||
RETURN_IF_FAILED(SHBindToObject(NULL, pidlFull.get(), NULL, IID_PPV_ARGS(&pArchiveShellFolder))); | ||
|
||
wil::com_ptr<IEnumIDList> pEnumIdList; | ||
RETURN_IF_FAILED(pArchiveShellFolder->EnumObjects(nullptr, SHCONTF_FOLDERS | SHCONTF_NONFOLDERS, &pEnumIdList)); | ||
|
||
unique_lpitemidlist pidlChild; | ||
ULONG nFetched; | ||
while (pEnumIdList->Next(1, wil::out_param_ptr<LPITEMIDLIST*>(pidlChild), &nFetched) == S_OK && nFetched == 1) | ||
{ | ||
wil::com_ptr<IShellItem> pShellItemFrom; | ||
STRRET strFolderName; | ||
WCHAR szFolderName[MAX_PATH]; | ||
RETURN_IF_FAILED(pArchiveShellFolder->GetDisplayNameOf(pidlChild.get(), SHGDN_INFOLDER | SHGDN_FORPARSING, &strFolderName)); | ||
RETURN_IF_FAILED(StrRetToBuf(&strFolderName, pidlChild.get(), szFolderName, MAX_PATH)); | ||
RETURN_IF_FAILED(SHCreateItemWithParent(pidlFull.get(), pArchiveShellFolder.get(), pidlChild.get(), IID_PPV_ARGS(&pShellItemFrom))); | ||
RETURN_IF_FAILED(pFileOperation->CopyItem(pShellItemFrom.get(), pShellItemTo.get(), NULL, NULL)); | ||
} | ||
|
||
RETURN_IF_FAILED(pFileOperation->PerformOperations()); | ||
return S_OK; | ||
} | ||
|
||
#ifndef AICLI_DISABLE_TEST_HOOKS | ||
static bool* s_ScanArchiveResult_TestHook_Override = nullptr; | ||
|
||
void TestHook_SetScanArchiveResult_Override(bool* status) | ||
{ | ||
s_ScanArchiveResult_TestHook_Override = status; | ||
} | ||
#endif | ||
|
||
bool ScanZipFile(const std::filesystem::path& zipPath) | ||
{ | ||
#ifndef AICLI_DISABLE_TEST_HOOKS | ||
if (s_ScanArchiveResult_TestHook_Override) | ||
{ | ||
return *s_ScanArchiveResult_TestHook_Override; | ||
} | ||
#endif | ||
|
||
std::ifstream instream{ zipPath, std::ios::in | std::ios::binary }; | ||
std::vector<uint8_t> data{ { std::istreambuf_iterator<char>{ instream } }, std::istreambuf_iterator<char>{} }; | ||
|
||
uint8_t* buffer = &data[0]; | ||
uint64_t flag = 0; | ||
int scanResult = pure_zip(buffer, data.size(), flag); | ||
|
||
return scanResult == 0; | ||
} | ||
} | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT License. | ||
#include "pch.h" | ||
#include "Public/winget/Archive.h" | ||
|
||
#include <amsi.h> | ||
#include <comdef.h> | ||
#include <fstream> | ||
#include <vector> | ||
|
||
namespace AppInstaller::Archive | ||
{ | ||
using unique_pidlist_absolute = wil::unique_any<PIDLIST_ABSOLUTE, decltype(&::CoTaskMemFree), ::CoTaskMemFree>; | ||
using unique_lpitemidlist = wil::unique_any<LPITEMIDLIST, decltype(&::CoTaskMemFree), ::CoTaskMemFree>; | ||
|
||
HRESULT TryExtractArchive(const std::filesystem::path& archivePath, const std::filesystem::path& destPath) | ||
{ | ||
wil::com_ptr<IFileOperation> pFileOperation; | ||
RETURN_IF_FAILED(CoCreateInstance(CLSID_FileOperation, NULL, CLSCTX_ALL, IID_PPV_ARGS(&pFileOperation))); | ||
RETURN_IF_FAILED(pFileOperation->SetOperationFlags(FOF_NO_UI)); | ||
|
||
wil::com_ptr<IShellItem> pShellItemTo; | ||
RETURN_IF_FAILED(SHCreateItemFromParsingName(destPath.c_str(), NULL, IID_PPV_ARGS(&pShellItemTo))); | ||
|
||
unique_pidlist_absolute pidlFull; | ||
RETURN_IF_FAILED(SHParseDisplayName(archivePath.c_str(), NULL, &pidlFull, 0, NULL)); | ||
|
||
wil::com_ptr<IShellFolder> pArchiveShellFolder; | ||
RETURN_IF_FAILED(SHBindToObject(NULL, pidlFull.get(), NULL, IID_PPV_ARGS(&pArchiveShellFolder))); | ||
|
||
wil::com_ptr<IEnumIDList> pEnumIdList; | ||
RETURN_IF_FAILED(pArchiveShellFolder->EnumObjects(nullptr, SHCONTF_FOLDERS | SHCONTF_NONFOLDERS, &pEnumIdList)); | ||
|
||
unique_lpitemidlist pidlChild; | ||
ULONG nFetched; | ||
while (pEnumIdList->Next(1, wil::out_param_ptr<LPITEMIDLIST*>(pidlChild), &nFetched) == S_OK && nFetched == 1) | ||
{ | ||
wil::com_ptr<IShellItem> pShellItemFrom; | ||
STRRET strFolderName; | ||
WCHAR szFolderName[MAX_PATH]; | ||
RETURN_IF_FAILED(pArchiveShellFolder->GetDisplayNameOf(pidlChild.get(), SHGDN_INFOLDER | SHGDN_FORPARSING, &strFolderName)); | ||
RETURN_IF_FAILED(StrRetToBuf(&strFolderName, pidlChild.get(), szFolderName, MAX_PATH)); | ||
RETURN_IF_FAILED(SHCreateItemWithParent(pidlFull.get(), pArchiveShellFolder.get(), pidlChild.get(), IID_PPV_ARGS(&pShellItemFrom))); | ||
RETURN_IF_FAILED(pFileOperation->CopyItem(pShellItemFrom.get(), pShellItemTo.get(), NULL, NULL)); | ||
} | ||
|
||
RETURN_IF_FAILED(pFileOperation->PerformOperations()); | ||
return S_OK; | ||
} | ||
|
||
#ifndef AICLI_DISABLE_TEST_HOOKS | ||
static bool* s_ScanArchiveResult_TestHook_Override = nullptr; | ||
|
||
void TestHook_SetScanArchiveResult_Override(bool* status) | ||
{ | ||
s_ScanArchiveResult_TestHook_Override = status; | ||
} | ||
#endif | ||
|
||
bool ScanZipFile(const std::filesystem::path& zipPath) | ||
{ | ||
#ifndef AICLI_DISABLE_TEST_HOOKS | ||
if (s_ScanArchiveResult_TestHook_Override) | ||
{ | ||
return *s_ScanArchiveResult_TestHook_Override; | ||
} | ||
#endif | ||
|
||
HRESULT hr = S_OK; | ||
wil::com_ptr_nothrow<IUnknown> amsiContext; | ||
wil::com_ptr_nothrow<IUnknown> amsiSession; | ||
|
||
hr = AmsiInitialize(L"WinGet", &amsiContext); | ||
if (FAILED(hr)) | ||
{ | ||
return false; | ||
} | ||
|
||
hr = AmsiOpenSession(amsiContext.get(), &amsiSession); | ||
if (FAILED(hr)) | ||
{ | ||
AmsiUninitialize(amsiContext.get()); | ||
return false; | ||
} | ||
|
||
std::ifstream instream{ zipPath, std::ios::in | std::ios::binary }; | ||
std::vector<uint8_t> data{ { std::istreambuf_iterator<char>{ instream } }, std::istreambuf_iterator<char>{} }; | ||
|
||
AMSI_RESULT result = AMSI_RESULT_CLEAN; | ||
hr = AmsiScanBuffer(amsiContext.get(), data.data(), data.size(), zipPath.filename().c_str(), amsiSession.get(), &result); | ||
|
||
AmsiCloseSession(amsiContext.get(), amsiSession.get()); | ||
AmsiUninitialize(amsiContext.get()); | ||
|
||
return SUCCEEDED(hr) && (result == AMSI_RESULT_CLEAN || result == AMSI_RESULT_NOT_DETECTED); | ||
} | ||
} | ||
|