Skip to content

Commit

Permalink
Use AMSI for archive malware scanning
Browse files Browse the repository at this point in the history
* Remove `#include <pure.h>` and add necessary includes for AMSI in `src/AppInstallerCommonCore/Archive.cpp`
* Initialize AMSI, create a session, scan the file, and handle results in `ScanZipFile` function
* Add tests for new archive formats in `src/AppInstallerCLITests/Archive.cpp`
  - Add test cases for 7z, Rar, TarGz, and TarBz2 archive formats
  - Verify extraction and scanning of these new archive formats
  • Loading branch information
pl4nty committed Oct 27, 2024
1 parent fe939b8 commit 104864f
Show file tree
Hide file tree
Showing 2 changed files with 233 additions and 113 deletions.
169 changes: 135 additions & 34 deletions src/AppInstallerCLITests/Archive.cpp
Original file line number Diff line number Diff line change
@@ -1,34 +1,135 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include "pch.h"
#include "TestCommon.h"
#include <winget/Archive.h>

using namespace AppInstaller::Archive;
using namespace TestCommon;

constexpr std::string_view s_ZipFile = "TestZip.zip";

TEST_CASE("Extract_ZipArchive", "[archive]")
{
TestCommon::TempDirectory tempDirectory("TempDirectory");
TestDataFile testZip(s_ZipFile);

const auto& testZipPath = testZip.GetPath();
const auto& tempDirectoryPath = tempDirectory.GetPath();

HRESULT hr = TryExtractArchive(testZipPath, tempDirectoryPath);

std::filesystem::path expectedPath = tempDirectoryPath / "test.txt";
REQUIRE(SUCCEEDED(hr));
REQUIRE(std::filesystem::exists(expectedPath));
}

TEST_CASE("Scan_ZipArchive", "[archive]")
{
TestDataFile testZip(s_ZipFile);

const auto& testZipPath = testZip.GetPath();
bool result = ScanZipFile(testZipPath);
REQUIRE(result);
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include "pch.h"
#include "TestCommon.h"
#include <winget/Archive.h>

using namespace AppInstaller::Archive;
using namespace TestCommon;

constexpr std::string_view s_ZipFile = "TestZip.zip";
constexpr std::string_view s_7zFile = "Test7z.7z";
constexpr std::string_view s_RarFile = "TestRar.rar";
constexpr std::string_view s_TarGzFile = "TestTarGz.tar.gz";
constexpr std::string_view s_TarBz2File = "TestTarBz2.tar.bz2";

TEST_CASE("Extract_ZipArchive", "[archive]")
{
TestCommon::TempDirectory tempDirectory("TempDirectory");
TestDataFile testZip(s_ZipFile);

const auto& testZipPath = testZip.GetPath();
const auto& tempDirectoryPath = tempDirectory.GetPath();

HRESULT hr = TryExtractArchive(testZipPath, tempDirectoryPath);

std::filesystem::path expectedPath = tempDirectoryPath / "test.txt";
REQUIRE(SUCCEEDED(hr));
REQUIRE(std::filesystem::exists(expectedPath));
}

TEST_CASE("Scan_ZipArchive", "[archive]")
{
TestDataFile testZip(s_ZipFile);

const auto& testZipPath = testZip.GetPath();
bool result = ScanZipFile(testZipPath);
REQUIRE(result);
}

TEST_CASE("Extract_7zArchive", "[archive]")
{
TestCommon::TempDirectory tempDirectory("TempDirectory");
TestDataFile test7z(s_7zFile);

const auto& test7zPath = test7z.GetPath();
const auto& tempDirectoryPath = tempDirectory.GetPath();

HRESULT hr = TryExtractArchive(test7zPath, tempDirectoryPath);

std::filesystem::path expectedPath = tempDirectoryPath / "test.txt";
REQUIRE(SUCCEEDED(hr));
REQUIRE(std::filesystem::exists(expectedPath));
}

TEST_CASE("Scan_7zArchive", "[archive]")
{
TestDataFile test7z(s_7zFile);

const auto& test7zPath = test7z.GetPath();
bool result = ScanZipFile(test7zPath);
REQUIRE(result);
}

TEST_CASE("Extract_RarArchive", "[archive]")
{
TestCommon::TempDirectory tempDirectory("TempDirectory");
TestDataFile testRar(s_RarFile);

const auto& testRarPath = testRar.GetPath();
const auto& tempDirectoryPath = tempDirectory.GetPath();

HRESULT hr = TryExtractArchive(testRarPath, tempDirectoryPath);

std::filesystem::path expectedPath = tempDirectoryPath / "test.txt";
REQUIRE(SUCCEEDED(hr));
REQUIRE(std::filesystem::exists(expectedPath));
}

TEST_CASE("Scan_RarArchive", "[archive]")
{
TestDataFile testRar(s_RarFile);

const auto& testRarPath = testRar.GetPath();
bool result = ScanZipFile(testRarPath);
REQUIRE(result);
}

TEST_CASE("Extract_TarGzArchive", "[archive]")
{
TestCommon::TempDirectory tempDirectory("TempDirectory");
TestDataFile testTarGz(s_TarGzFile);

const auto& testTarGzPath = testTarGz.GetPath();
const auto& tempDirectoryPath = tempDirectory.GetPath();

HRESULT hr = TryExtractArchive(testTarGzPath, tempDirectoryPath);

std::filesystem::path expectedPath = tempDirectoryPath / "test.txt";
REQUIRE(SUCCEEDED(hr));
REQUIRE(std::filesystem::exists(expectedPath));
}

TEST_CASE("Scan_TarGzArchive", "[archive]")
{
TestDataFile testTarGz(s_TarGzFile);

const auto& testTarGzPath = testTarGz.GetPath();
bool result = ScanZipFile(testTarGzPath);
REQUIRE(result);
}

TEST_CASE("Extract_TarBz2Archive", "[archive]")
{
TestCommon::TempDirectory tempDirectory("TempDirectory");
TestDataFile testTarBz2(s_TarBz2File);

const auto& testTarBz2Path = testTarBz2.GetPath();
const auto& tempDirectoryPath = tempDirectory.GetPath();

HRESULT hr = TryExtractArchive(testTarBz2Path, tempDirectoryPath);

std::filesystem::path expectedPath = tempDirectoryPath / "test.txt";
REQUIRE(SUCCEEDED(hr));
REQUIRE(std::filesystem::exists(expectedPath));
}

TEST_CASE("Scan_TarBz2Archive", "[archive]")
{
TestDataFile testTarBz2(s_TarBz2File);

const auto& testTarBz2Path = testTarBz2.GetPath();
bool result = ScanZipFile(testTarBz2Path);
REQUIRE(result);
}

177 changes: 98 additions & 79 deletions src/AppInstallerCommonCore/Archive.cpp
Original file line number Diff line number Diff line change
@@ -1,79 +1,98 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include "pch.h"
#include "Public/winget/Archive.h"

// TODO: Move include statement to pch.h and resolve build errors
#pragma warning( push )
#pragma warning ( disable : 4189 4244 26451 )
#include <pure.h>
#pragma warning ( pop )

namespace AppInstaller::Archive
{
using unique_pidlist_absolute = wil::unique_any<PIDLIST_ABSOLUTE, decltype(&::CoTaskMemFree), ::CoTaskMemFree>;
using unique_lpitemidlist = wil::unique_any<LPITEMIDLIST, decltype(&::CoTaskMemFree), ::CoTaskMemFree>;

HRESULT TryExtractArchive(const std::filesystem::path& archivePath, const std::filesystem::path& destPath)
{
wil::com_ptr<IFileOperation> pFileOperation;
RETURN_IF_FAILED(CoCreateInstance(CLSID_FileOperation, NULL, CLSCTX_ALL, IID_PPV_ARGS(&pFileOperation)));
RETURN_IF_FAILED(pFileOperation->SetOperationFlags(FOF_NO_UI));

wil::com_ptr<IShellItem> pShellItemTo;
RETURN_IF_FAILED(SHCreateItemFromParsingName(destPath.c_str(), NULL, IID_PPV_ARGS(&pShellItemTo)));

unique_pidlist_absolute pidlFull;
RETURN_IF_FAILED(SHParseDisplayName(archivePath.c_str(), NULL, &pidlFull, 0, NULL));

wil::com_ptr<IShellFolder> pArchiveShellFolder;
RETURN_IF_FAILED(SHBindToObject(NULL, pidlFull.get(), NULL, IID_PPV_ARGS(&pArchiveShellFolder)));

wil::com_ptr<IEnumIDList> pEnumIdList;
RETURN_IF_FAILED(pArchiveShellFolder->EnumObjects(nullptr, SHCONTF_FOLDERS | SHCONTF_NONFOLDERS, &pEnumIdList));

unique_lpitemidlist pidlChild;
ULONG nFetched;
while (pEnumIdList->Next(1, wil::out_param_ptr<LPITEMIDLIST*>(pidlChild), &nFetched) == S_OK && nFetched == 1)
{
wil::com_ptr<IShellItem> pShellItemFrom;
STRRET strFolderName;
WCHAR szFolderName[MAX_PATH];
RETURN_IF_FAILED(pArchiveShellFolder->GetDisplayNameOf(pidlChild.get(), SHGDN_INFOLDER | SHGDN_FORPARSING, &strFolderName));
RETURN_IF_FAILED(StrRetToBuf(&strFolderName, pidlChild.get(), szFolderName, MAX_PATH));
RETURN_IF_FAILED(SHCreateItemWithParent(pidlFull.get(), pArchiveShellFolder.get(), pidlChild.get(), IID_PPV_ARGS(&pShellItemFrom)));
RETURN_IF_FAILED(pFileOperation->CopyItem(pShellItemFrom.get(), pShellItemTo.get(), NULL, NULL));
}

RETURN_IF_FAILED(pFileOperation->PerformOperations());
return S_OK;
}

#ifndef AICLI_DISABLE_TEST_HOOKS
static bool* s_ScanArchiveResult_TestHook_Override = nullptr;

void TestHook_SetScanArchiveResult_Override(bool* status)
{
s_ScanArchiveResult_TestHook_Override = status;
}
#endif

bool ScanZipFile(const std::filesystem::path& zipPath)
{
#ifndef AICLI_DISABLE_TEST_HOOKS
if (s_ScanArchiveResult_TestHook_Override)
{
return *s_ScanArchiveResult_TestHook_Override;
}
#endif

std::ifstream instream{ zipPath, std::ios::in | std::ios::binary };
std::vector<uint8_t> data{ { std::istreambuf_iterator<char>{ instream } }, std::istreambuf_iterator<char>{} };

uint8_t* buffer = &data[0];
uint64_t flag = 0;
int scanResult = pure_zip(buffer, data.size(), flag);

return scanResult == 0;
}
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include "pch.h"
#include "Public/winget/Archive.h"

#include <amsi.h>
#include <comdef.h>
#include <fstream>
#include <vector>

namespace AppInstaller::Archive
{
using unique_pidlist_absolute = wil::unique_any<PIDLIST_ABSOLUTE, decltype(&::CoTaskMemFree), ::CoTaskMemFree>;
using unique_lpitemidlist = wil::unique_any<LPITEMIDLIST, decltype(&::CoTaskMemFree), ::CoTaskMemFree>;

HRESULT TryExtractArchive(const std::filesystem::path& archivePath, const std::filesystem::path& destPath)
{
wil::com_ptr<IFileOperation> pFileOperation;
RETURN_IF_FAILED(CoCreateInstance(CLSID_FileOperation, NULL, CLSCTX_ALL, IID_PPV_ARGS(&pFileOperation)));
RETURN_IF_FAILED(pFileOperation->SetOperationFlags(FOF_NO_UI));

wil::com_ptr<IShellItem> pShellItemTo;
RETURN_IF_FAILED(SHCreateItemFromParsingName(destPath.c_str(), NULL, IID_PPV_ARGS(&pShellItemTo)));

unique_pidlist_absolute pidlFull;
RETURN_IF_FAILED(SHParseDisplayName(archivePath.c_str(), NULL, &pidlFull, 0, NULL));

wil::com_ptr<IShellFolder> pArchiveShellFolder;
RETURN_IF_FAILED(SHBindToObject(NULL, pidlFull.get(), NULL, IID_PPV_ARGS(&pArchiveShellFolder)));

wil::com_ptr<IEnumIDList> pEnumIdList;
RETURN_IF_FAILED(pArchiveShellFolder->EnumObjects(nullptr, SHCONTF_FOLDERS | SHCONTF_NONFOLDERS, &pEnumIdList));

unique_lpitemidlist pidlChild;
ULONG nFetched;
while (pEnumIdList->Next(1, wil::out_param_ptr<LPITEMIDLIST*>(pidlChild), &nFetched) == S_OK && nFetched == 1)
{
wil::com_ptr<IShellItem> pShellItemFrom;
STRRET strFolderName;
WCHAR szFolderName[MAX_PATH];
RETURN_IF_FAILED(pArchiveShellFolder->GetDisplayNameOf(pidlChild.get(), SHGDN_INFOLDER | SHGDN_FORPARSING, &strFolderName));
RETURN_IF_FAILED(StrRetToBuf(&strFolderName, pidlChild.get(), szFolderName, MAX_PATH));
RETURN_IF_FAILED(SHCreateItemWithParent(pidlFull.get(), pArchiveShellFolder.get(), pidlChild.get(), IID_PPV_ARGS(&pShellItemFrom)));
RETURN_IF_FAILED(pFileOperation->CopyItem(pShellItemFrom.get(), pShellItemTo.get(), NULL, NULL));
}

RETURN_IF_FAILED(pFileOperation->PerformOperations());
return S_OK;
}

#ifndef AICLI_DISABLE_TEST_HOOKS
static bool* s_ScanArchiveResult_TestHook_Override = nullptr;

void TestHook_SetScanArchiveResult_Override(bool* status)
{
s_ScanArchiveResult_TestHook_Override = status;
}
#endif

bool ScanZipFile(const std::filesystem::path& zipPath)
{
#ifndef AICLI_DISABLE_TEST_HOOKS
if (s_ScanArchiveResult_TestHook_Override)
{
return *s_ScanArchiveResult_TestHook_Override;
}
#endif

HRESULT hr = S_OK;
wil::com_ptr_nothrow<IUnknown> amsiContext;
wil::com_ptr_nothrow<IUnknown> amsiSession;

hr = AmsiInitialize(L"WinGet", &amsiContext);
if (FAILED(hr))
{
return false;
}

hr = AmsiOpenSession(amsiContext.get(), &amsiSession);
if (FAILED(hr))
{
AmsiUninitialize(amsiContext.get());
return false;
}

std::ifstream instream{ zipPath, std::ios::in | std::ios::binary };
std::vector<uint8_t> data{ { std::istreambuf_iterator<char>{ instream } }, std::istreambuf_iterator<char>{} };

AMSI_RESULT result = AMSI_RESULT_CLEAN;
hr = AmsiScanBuffer(amsiContext.get(), data.data(), data.size(), zipPath.filename().c_str(), amsiSession.get(), &result);

AmsiCloseSession(amsiContext.get(), amsiSession.get());
AmsiUninitialize(amsiContext.get());

return SUCCEEDED(hr) && (result == AMSI_RESULT_CLEAN || result == AMSI_RESULT_NOT_DETECTED);
}
}

0 comments on commit 104864f

Please sign in to comment.