Skip to content

Commit

Permalink
feat: client returns the desired data directly
Browse files Browse the repository at this point in the history
And our error type stores all error information

Signed-off-by: Richard Zak <[email protected]>
  • Loading branch information
rjzak committed Dec 25, 2023
1 parent e78fbb8 commit cf3fb7f
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 56 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ repository = "https://github.com/malwaredb/vt-client"
exclude = ["testdata"]

[dependencies]
anyhow = { version = "1.0", features = ["std"] }
chrono = { version = "0.4", features = ["clock", "serde"], default-features = false }
reqwest = { version = "0.11", features = ["multipart", "rustls-tls"], default-features = false }
serde = { version = "1.0", features = ["derive"] }
Expand Down
11 changes: 4 additions & 7 deletions src/filereport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,6 @@ pub struct SigmaAnalysisStats {
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Context;
use rstest::rstest;

#[rstest]
Expand All @@ -328,9 +327,8 @@ mod tests {
#[case(include_str!("../../testdata/ddecc35aa198f401948c73a0d53fd93c4ecb770198ad7db308de026745c56b71.json"), "Win32 EXE")]
#[case(include_str!("../../testdata/de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740.json"), "ELF")]
fn deserialize_valid_report(#[case] report: &str, #[case] file_type: &str) {
let report: FileReportRequestResponse = serde_json::from_str(report)
.context("failed to deserialize VT report")
.unwrap();
let report: FileReportRequestResponse =
serde_json::from_str(report).expect("failed to deserialize VT report");

if let FileReportRequestResponse::Data(data) = report {
if file_type == "Mach-O" {
Expand All @@ -354,9 +352,8 @@ mod tests {
#[case(include_str!("../../testdata/not_found.json"))]
#[case(include_str!("../../testdata/wrong_key.json"))]
fn deserialize_errors(#[case] contents: &str) {
let report: FileReportRequestResponse = serde_json::from_str(contents)
.context("failed to deserialize VT error response")
.unwrap();
let report: FileReportRequestResponse =
serde_json::from_str(contents).expect("failed to deserialize VT error response");

match report {
FileReportRequestResponse::Data(_) => panic!("Should have been an error type!"),
Expand Down
6 changes: 2 additions & 4 deletions src/filerescan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@ pub struct FileRescanRequestData {
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Context;

#[test]
fn deserialize_valid_response() {
const RESPONSE: &str = include_str!("../../testdata/rescan.json");

let rescan: FileRescanRequestResponse = serde_json::from_str(RESPONSE)
.context("failed to deserialize VT rescan")
.unwrap();
let rescan: FileRescanRequestResponse =
serde_json::from_str(RESPONSE).expect("failed to deserialize VT rescan");

if let FileRescanRequestResponse::Data(data) = rescan {
assert_eq!(data.rescan_type, "analysis");
Expand Down
130 changes: 86 additions & 44 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
pub mod filereport;
pub mod filerescan;

use crate::filereport::FileReportRequestResponse;
use crate::filerescan::FileRescanRequestResponse;
use crate::filereport::{FileReportData, FileReportRequestResponse};
use crate::filerescan::{FileRescanRequestData, FileRescanRequestResponse};

use std::fmt::{Display, Formatter};
use std::str::FromStr;
use std::string::FromUtf8Error;

use anyhow::{Context, Result};
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::multipart::Form;
use serde::{Deserialize, Serialize};
use zeroize::Zeroizing;

/// Capture the error from VirusTotal, plus parsing or networking errors along the way
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VirusTotalError {
pub message: String,
Expand All @@ -26,6 +28,39 @@ impl Display for VirusTotalError {

impl std::error::Error for VirusTotalError {}

impl From<reqwest::Error> for VirusTotalError {
fn from(err: reqwest::Error) -> Self {
let url = if let Some(url) = err.url() {
format!(" loading {}", url.as_str())
} else {
"".into()
};
Self {
message: "Http error".into(),
code: format!("Error {url} {}", err),
}
}
}

impl From<serde_json::Error> for VirusTotalError {
fn from(err: serde_json::Error) -> Self {
Self {
message: "Json error".into(),
code: format!("Json error at line {}: {}", err.line(), err),
}
}
}

impl From<FromUtf8Error> for VirusTotalError {
fn from(err: FromUtf8Error) -> Self {
Self {
message: "UTF-8 decoding error error".into(),
code: err.to_string(),
}
}
}

/// VirusTotal client object
#[derive(Clone)]
pub struct VirusTotalClient {
/// The API key used to interact with VirusTotal
Expand All @@ -34,7 +69,9 @@ pub struct VirusTotalClient {

impl VirusTotalClient {
const API_KEY: &'static str = "x-apikey";
const KEY_LEN: usize = 64;

/// New VirusTotal client given an API key, assuming it's valid
pub fn new(key: &str) -> Self {
Self {
key: Zeroizing::new(key.to_string()),
Expand All @@ -50,7 +87,8 @@ impl VirusTotalClient {
headers
}

pub async fn get_report(&self, file_hash: &str) -> Result<FileReportRequestResponse> {
/// Get a file report from VirusTotal for an MD5, SHA-1, or SHA-256 hash. It's assumed to be valid.
pub async fn get_report(&self, file_hash: &str) -> Result<FileReportData, VirusTotalError> {
let client = reqwest::Client::new();
let body = client
.get(format!(
Expand All @@ -62,15 +100,20 @@ impl VirusTotalClient {
.bytes()
.await?;

let json_response = String::from_utf8(body.to_ascii_lowercase())
.context("failed to convert response to string")?;
let report: FileReportRequestResponse =
serde_json::from_str(&json_response).context("failed to deserialize VT report")?;
let json_response = String::from_utf8(body.to_ascii_lowercase())?;
let report: FileReportRequestResponse = serde_json::from_str(&json_response)?;

Ok(report)
match report {
FileReportRequestResponse::Data(data) => Ok(data),
FileReportRequestResponse::Error(error) => Err(error),
}
}

pub async fn request_rescan(&self, file_hash: &str) -> Result<FileRescanRequestResponse> {
/// Request VirusTotal rescan a file for an MD5, SHA-1, or SHA-256 hash. It's assumed to be valid.
pub async fn request_rescan(
&self,
file_hash: &str,
) -> Result<FileRescanRequestData, VirusTotalError> {
let client = reqwest::Client::new();
let body = client
.post(format!(
Expand All @@ -83,34 +126,33 @@ impl VirusTotalClient {
.bytes()
.await?;

let json_response = String::from_utf8(body.to_ascii_lowercase())
.context("failed to convert response to string")?;
let report: FileRescanRequestResponse = serde_json::from_str(&json_response)
.context("failed to deserialize VT rescan request")?;
let json_response = String::from_utf8(body.to_ascii_lowercase())?;
let report: FileRescanRequestResponse = serde_json::from_str(&json_response)?;

Ok(report)
match report {
FileRescanRequestResponse::Data(data) => Ok(data),
FileRescanRequestResponse::Error(error) => Err(error),
}
}

/// Submit a file to VirusTotal.
pub async fn submit(
&self,
data: Vec<u8>,
name: Option<String>,
) -> Result<FileRescanRequestResponse> {
) -> Result<FileRescanRequestData, VirusTotalError> {
let client = reqwest::Client::new();
let form = if let Some(file_name) = name {
Form::new().part(
"file",
reqwest::multipart::Part::bytes(data)
.file_name(file_name)
.mime_str("application/octet-stream")
.context("failed to set mime type")?,
.mime_str("application/octet-stream")?,
)
} else {
Form::new().part(
"file",
reqwest::multipart::Part::bytes(data)
.mime_str("application/octet-stream")
.context("failed to set mime type")?,
reqwest::multipart::Part::bytes(data).mime_str("application/octet-stream")?,
)
};

Expand All @@ -124,12 +166,28 @@ impl VirusTotalClient {
.await?
.bytes()
.await?;
let json_response = String::from_utf8(body.to_ascii_lowercase())
.context("failed to convert response to string")?;
let report: FileRescanRequestResponse = serde_json::from_str(&json_response)
.context("failed to deserialize VT rescan request")?;
let json_response = String::from_utf8(body.to_ascii_lowercase())?;
let report: FileRescanRequestResponse = serde_json::from_str(&json_response)?;

Ok(report)
match report {
FileRescanRequestResponse::Data(data) => Ok(data),
FileRescanRequestResponse::Error(error) => Err(error),
}
}
}

/// Get a VirusTotal client from a key, checking that the key is the expected length.
impl FromStr for VirusTotalClient {
type Err = &'static str;

fn from_str(key: &str) -> Result<Self, Self::Err> {
if key.len() != VirusTotalClient::KEY_LEN {
Err("Invalid API key length")
} else {
Ok(Self {
key: Zeroizing::new(key.to_string()),
})
}
}
}

Expand All @@ -149,29 +207,13 @@ mod test {
.get_report(HASH)
.await
.expect("failed to get or parse VT scan report");

match report {
FileReportRequestResponse::Data(data) => {
assert!(data.attributes.last_analysis_results.len() > 10);
}
FileReportRequestResponse::Error(error) => {
panic!("VT Report Error {error}");
}
}
assert!(report.attributes.last_analysis_results.len() > 10);

let rescan = client
.request_rescan(HASH)
.await
.expect("failed to get or parse VT rescan response");

match rescan {
FileRescanRequestResponse::Data(data) => {
assert_eq!(data.rescan_type, "analysis");
}
FileRescanRequestResponse::Error(error) => {
panic!("VT Rescan Error {error}");
}
}
assert_eq!(rescan.rescan_type, "analysis");

const ELF: &[u8] = include_bytes!("../testdata/elf_haiku_x86");
client
Expand Down

0 comments on commit cf3fb7f

Please sign in to comment.