Skip to content

Commit

Permalink
feat: incredibly hacky solution to native GHA cache support for artifacts
Browse files Browse the repository at this point in the history
  • Loading branch information
erikreinert committed Dec 26, 2024
1 parent f5c6a26 commit dcee476
Show file tree
Hide file tree
Showing 6 changed files with 420 additions and 9 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/vorpal.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@ jobs:
- run: ./dist/vorpal keys generate

- env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
run: |
./dist/vorpal start \
--registry-backend "s3" \
--registry-backend-s3-bucket "altf4llc-vorpal-registry" \
- uses: actions/github-script@v6
with:
script: |
core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
- run: |
./dist/vorpal start --registry-backend "gha" \
> worker_output.log 2>&1 &
WORKER_PID=$(echo $!)
echo "WORKER_PID=$WORKER_PID" >> $GITHUB_ENV
Expand Down
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ async fn main() -> Result<()> {

if services.contains("registry") {
let backend = match registry_backend.as_str() {
"gha" => RegistryServerBackend::GHA,
"local" => RegistryServerBackend::Local,
"s3" => RegistryServerBackend::S3,
_ => RegistryServerBackend::Unknown,
Expand Down
3 changes: 3 additions & 0 deletions registry/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ edition = "2021"
anyhow = { default-features = false, version = "1" }
aws-config = { default-features = false, features = ["behavior-version-latest", "rt-tokio", "rustls", "sso"], version = "1" }
aws-sdk-s3 = { default-features = false, version = "1" }
reqwest = { default-features = false, version = "0", features = ["json", "rustls-tls"] }
rsa = { default-features = false, version = "0" }
serde = { version = "1.0", features = ["derive"] }
sha2 = "0.10"
tokio = { default-features = false, features = ["process", "rt-multi-thread"], version = "1" }
tokio-stream = { default-features = false, features = ["io-util"], version = "0" }
tonic = { default-features = false, version = "0" }
Expand Down
222 changes: 222 additions & 0 deletions registry/src/gha.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
use anyhow::{anyhow, Context, Result};
use reqwest::{
header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_RANGE, CONTENT_TYPE},
Client, StatusCode,
};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::{
fs::File,
io::{Read, Seek, SeekFrom},
path::Path,
sync::Arc,
};
use tokio::sync::Semaphore;
use tracing::info;

// Salt mixed into the cache-version hash (see `get_cache_version`); kept in
// lockstep with the versionSalt used by the official actions/toolkit cache
// client so computed versions stay comparable.
const VERSION_SALT: &str = "1.0";
// api-version sent in the Accept header on every Actions cache API request.
const API_VERSION: &str = "6.0-preview.1";

#[derive(Debug, Serialize, Deserialize)]
pub struct ArtifactCacheEntry {
pub archive_location: String,
pub cache_key: String,
pub cache_version: String,
pub scope: String,
}

/// Request body for reserving a cache entry before uploading chunks.
///
/// The service expects camelCase JSON: `key` and `version` are unaffected,
/// but the optional size must be sent as `cacheSize`, hence `rename_all`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ReserveCacheRequest {
    /// Cache key to reserve.
    pub key: String,
    /// Version hash computed by `get_cache_version`.
    pub version: String,
    /// Total archive size in bytes, if known; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_size: Option<u64>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ReserveCacheResponse {
pub cache_id: u64,
}

// Request body for finalizing (committing) an uploaded cache entry.
#[derive(Debug, Serialize, Deserialize)]
pub struct CommitCacheRequest {
    // Total size in bytes of the uploaded archive; the service validates it
    // against the bytes actually received.
    pub size: u64,
}

/// Thin HTTP client for the GitHub Actions cache service.
///
/// Construct with [`CacheClient::new`], which reads the runtime token and
/// service URL from environment variables injected into Actions jobs.
#[derive(Debug)]
pub struct CacheClient {
    // reqwest client pre-configured with Accept + Authorization headers.
    client: Client,
    // Service root taken verbatim from ACTIONS_CACHE_URL.
    base_url: String,
}

impl CacheClient {
    /// Build a client from the environment GitHub Actions injects into jobs.
    ///
    /// Reads `ACTIONS_RUNTIME_TOKEN` (bearer auth) and `ACTIONS_CACHE_URL`
    /// (service base URL), and installs the `Accept`/`Authorization` default
    /// headers that every cache API request requires.
    ///
    /// # Errors
    /// Fails if either environment variable is missing or a header value
    /// cannot be built from it.
    pub fn new() -> Result<Self> {
        let token = std::env::var("ACTIONS_RUNTIME_TOKEN")
            .context("ACTIONS_RUNTIME_TOKEN environment variable not found")?;
        let base_url = std::env::var("ACTIONS_CACHE_URL")
            .context("ACTIONS_CACHE_URL environment variable not found")?;

        let mut headers = HeaderMap::new();
        headers.insert(
            ACCEPT,
            HeaderValue::from_str(&format!("application/json;api-version={API_VERSION}"))?,
        );
        headers.insert(
            "Authorization",
            HeaderValue::from_str(&format!("Bearer {token}"))?,
        );

        let client = Client::builder()
            .user_agent("rust/github-actions-cache")
            .default_headers(headers)
            .build()?;

        Ok(Self { client, base_url })
    }

    /// Look up a cache entry for any of `keys` at the version derived from
    /// `paths` / `compression_method` / `enable_cross_os_archive`.
    ///
    /// Returns `Ok(None)` on a cache miss (the service answers 204 No
    /// Content), `Ok(Some(entry))` on a hit, and an error for any other
    /// status code.
    pub async fn get_cache_entry(
        &self,
        keys: &[String],
        paths: &[String],
        compression_method: Option<String>,
        enable_cross_os_archive: bool,
    ) -> Result<Option<ArtifactCacheEntry>> {
        let version = get_cache_version(paths, compression_method, enable_cross_os_archive)?;
        let keys_str = keys.join(",");
        let url = format!(
            "{}/_apis/artifactcache/cache?keys={}&version={}",
            self.base_url, keys_str, version
        );

        let response = self.client.get(&url).send().await?;

        match response.status() {
            StatusCode::NO_CONTENT => Ok(None),
            StatusCode::OK => {
                let entry = response.json::<ArtifactCacheEntry>().await?;
                Ok(Some(entry))
            }
            status => Err(anyhow!("Unexpected status code: {}", status)),
        }
    }

    /// Reserve a cache entry for `key`, returning the id to upload against.
    ///
    /// Must be called before [`save_cache`](Self::save_cache); the service
    /// allocates a `cache_id` that the chunk uploads and the final commit
    /// reference.
    ///
    /// # Errors
    /// Propagates HTTP errors (e.g. a conflict when the key is already
    /// reserved) via `error_for_status`.
    pub async fn reserve_cache(
        &self,
        key: &str,
        paths: &[String],
        compression_method: Option<String>,
        enable_cross_os_archive: bool,
        cache_size: Option<u64>,
    ) -> Result<ReserveCacheResponse> {
        let version = get_cache_version(paths, compression_method, enable_cross_os_archive)?;
        let url = format!("{}/_apis/artifactcache/caches", self.base_url);

        let request = ReserveCacheRequest {
            cache_size,
            key: key.to_string(),
            version,
        };

        let response = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .await?
            .error_for_status()?;

        Ok(response.json().await?)
    }

    /// Upload `archive_path` in `chunk_size`-byte ranges (at most
    /// `concurrency` uploads in flight) against a previously reserved
    /// `cache_id`, then commit the entry.
    ///
    /// # Errors
    /// Fails if `chunk_size` is zero, on any file I/O error, or if any
    /// chunk PATCH or the final commit POST returns an error status.
    pub async fn save_cache(
        &self,
        cache_id: u64,
        archive_path: &Path,
        concurrency: usize,
        chunk_size: usize,
    ) -> Result<()> {
        // step_by(0) panics; fail up front with a useful message instead.
        if chunk_size == 0 {
            return Err(anyhow!("chunk_size must be greater than zero"));
        }

        let file = File::open(archive_path)?;
        let file_size = file.metadata()?.len();
        let url = format!("{}/_apis/artifactcache/caches/{}", self.base_url, cache_id);

        info!("Uploading cache file with size: {} bytes", file_size);

        // Semaphore bounds in-flight uploads; the file handle is shared
        // behind a mutex because std::fs::File seeks are stateful.
        // NOTE(review): this performs blocking file I/O on the async
        // runtime; acceptable for small chunks, but spawn_blocking would be
        // cleaner — confirm before scaling up.
        let semaphore = Arc::new(Semaphore::new(concurrency));
        let mut tasks = Vec::new();
        let file = Arc::new(tokio::sync::Mutex::new(file));

        for chunk_start in (0..file_size).step_by(chunk_size) {
            let chunk_end = (chunk_start + chunk_size as u64 - 1).min(file_size - 1);
            let permit = semaphore.clone().acquire_owned().await?;
            let client = self.client.clone();
            let url = url.clone();
            let file = file.clone();

            let task = tokio::spawn(async move {
                let _permit = permit; // keep permit alive for this upload

                // Read this chunk while holding the lock, then release it
                // BEFORE the network call. Holding the guard across the
                // PATCH (as the original did) serializes every upload and
                // defeats the `concurrency` limit entirely.
                let buffer = {
                    let mut file = file.lock().await;
                    file.seek(SeekFrom::Start(chunk_start))?;
                    let mut buffer = vec![0; (chunk_end - chunk_start + 1) as usize];
                    file.read_exact(&mut buffer)?;
                    buffer
                };

                let range = format!("bytes {}-{}/{}", chunk_start, chunk_end, file_size);
                let response = client
                    .patch(&url)
                    .header(CONTENT_TYPE, "application/octet-stream")
                    .header(CONTENT_RANGE, &range)
                    .body(buffer)
                    .send()
                    .await?
                    .error_for_status()?;

                info!("Uploaded chunk response: {}", response.status());

                Result::<()>::Ok(())
            });

            tasks.push(task);
        }

        // Wait for all upload tasks; `??` surfaces both join and upload errors.
        for task in tasks {
            task.await??;
        }

        // Finalize the entry; the service validates the committed size.
        info!("Committing cache");
        let commit_request = CommitCacheRequest { size: file_size };
        self.client
            .post(&url)
            .json(&commit_request)
            .send()
            .await?
            .error_for_status()?;

        info!("Cache saved successfully");
        Ok(())
    }
}

/// Compute the version hash that namespaces cache entries.
///
/// The cached paths, the optional compression method, a `"windows-only"`
/// marker (on Windows, unless cross-OS archives are enabled), and a salt
/// are joined with `|` and SHA-256 hashed; the lowercase hex digest is the
/// version string sent to the cache API.
fn get_cache_version(
    paths: &[String],
    compression_method: Option<String>,
    enable_cross_os_archive: bool,
) -> Result<String> {
    let mut parts: Vec<String> = paths.to_vec();

    // Option is iterable: appends the method only when present.
    parts.extend(compression_method);

    if cfg!(windows) && !enable_cross_os_archive {
        parts.push(String::from("windows-only"));
    }

    parts.push(String::from(VERSION_SALT));

    let mut digest = Sha256::new();
    digest.update(parts.join("|"));
    Ok(format!("{:x}", digest.finalize()))
}
Loading

0 comments on commit dcee476

Please sign in to comment.