diff --git a/.github/workflows/cloud_provider.yml b/.github/workflows/cloud_provider.yml index 165eed2..9e62344 100644 --- a/.github/workflows/cloud_provider.yml +++ b/.github/workflows/cloud_provider.yml @@ -1,10 +1,10 @@ name: Cloud Provider Tests -on: - push: - branches: [ main ] +on: pull_request: - branches: [ "**" ] + branches: [main] + push: + branches: [main] permissions: contents: read @@ -133,6 +133,16 @@ jobs: env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + - name: Upload Java coverage to Codecov + uses: codecov/codecov-action@v5.4.3 + with: + files: ./java-coverage.xml + flags: java,${{ matrix.name }} + fail_ci_if_error: true + name: ${{ matrix.name }}-java-coverage + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + - name: Clean up remote directory (Make) if: always() env: diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml index 075c614..fef1ab8 100644 --- a/.github/workflows/codecov.yml +++ b/.github/workflows/codecov.yml @@ -4,7 +4,7 @@ on: push: branches: [ main ] pull_request: - branches: [ "**" ] + branches: [ main ] permissions: # added using https://github.com/step-security/secure-workflows contents: read diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 36acbef..c79ee87 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -16,7 +16,7 @@ on: branches: [ main ] pull_request: # The branches below must be a subset of the branches above - branches: [ "**" ] + branches: [ main ] schedule: - cron: '20 3 * * 4' diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index ad5ca58..a0c4d1e 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -2,7 +2,7 @@ on: push: branches: [ main ] pull_request: - branches: [ "**" ] + branches: [ main ] name: Go Unit Tests permissions: # added using https://github.com/step-security/secure-workflows diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 1f590ba..97e48a0 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -2,7 +2,7 @@ on: push: branches: [ main ] pull_request: - branches: [ "**" ] + branches: [ main ] name: golangci-lint permissions: # added using https://github.com/step-security/secure-workflows diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml new file mode 100644 index 0000000..92954f6 --- /dev/null +++ b/.github/workflows/java.yml @@ -0,0 +1,48 @@ +name: java + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build-test-coverage: + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up JDK 21 + uses: actions/setup-java@v4 + with: + distribution: temurin + java-version: '21' + cache: maven + + - name: Lint / formatting check + run: make lint-java + + - name: Run tests with coverage (mvn verify) + working-directory: java + run: mvn -q -DskipTests=false verify + + - name: Show coverage summary + if: always() + run: | + if [ -f java/target/site/jacoco/index.html ]; then + echo "JaCoCo report generated" + fi + if [ -f java/target/site/jacoco/jacoco.xml ]; then + grep -q '/dev/null | grep -v '_test.go' | grep -v '/testhelp/' || true); \ + if [ -n "$$violations" ]; then \ + echo 'ERROR: S2IAM_TEST_ variables found in library (non-test) source:'; \ + echo "$$violations"; \ + exit 1; \ + fi + +test-local: test-local-patterns test-local-go test-local-python test-local-java @echo "✓ All local tests passed" -test-go-local: +test-local-go: @echo "Running Go local tests..." cd go && go test -v ./... -test-python-local: +test-local-python: @echo "Running Python local tests..." - cd python && python3 -m pytest tests/ -v + cd python && python3 -m venv test-venv && \ + PIP_CACHE_DIR=$(HOME)/.cache/pip-test ./test-venv/bin/pip install -e '.[dev]' && \ + ./test-venv/bin/python -m pytest tests/ -v; \ + EXIT_CODE=$$?; \ + rm -rf test-venv; \ + exit $$EXIT_CODE + +test-local-java: + @echo "Running Java local tests..." + cd java && mvn -q -DskipTests=false test check-cloud-env: @if [ -z "$$S2IAM_TEST_CLOUD_PROVIDER" ] && [ -z "$$S2IAM_TEST_CLOUD_PROVIDER_NO_ROLE" ] && [ -z "$$S2IAM_TEST_ASSUME_ROLE" ]; then \ @@ -102,10 +135,10 @@ ifndef HOST endif on-remote-completed: - @echo "✓ All tests completed successfully" + @echo "ALL_TESTS_COMPLETED_OK" # Cloud test targets (designed to run ON cloud VMs) -on-remote-test: check-cloud-env on-remote-test-go on-remote-test-python +on-remote-test: check-cloud-env on-remote-test-java on-remote-test-go on-remote-test-python on-remote-test-go: check-cloud-env @echo "=== Running Go cloud tests ===" @@ -122,10 +155,19 @@ on-remote-test-python: check-cloud-env # Add src to PYTHONPATH so tests can import s2iam without installation cd python && PYTHONPATH=src python3 -m pytest tests/ -v --tb=short --cov=src/s2iam --cov-report=xml:coverage.xml --cov-report=html:htmlcov -dev-setup-ubuntu: dev-setup-ubuntu-go dev-setup-ubuntu-python +on-remote-test-java: check-cloud-env + @echo "=== Running Java cloud tests ===" + @echo "Environment: S2IAM_TEST_CLOUD_PROVIDER=$${S2IAM_TEST_CLOUD_PROVIDER:-}" + @echo "Environment: S2IAM_TEST_CLOUD_PROVIDER_NO_ROLE=$${S2IAM_TEST_CLOUD_PROVIDER_NO_ROLE:-}" + @echo "Environment: S2IAM_TEST_ASSUME_ROLE=$${S2IAM_TEST_ASSUME_ROLE:-}" + cd java && mvn -q -DskipTests=false verify + # Copy JaCoCo XML up one level for remote retrieval naming consistency + @if [ -f java/target/site/jacoco/jacoco.xml ]; then cp java/target/site/jacoco/jacoco.xml java-coverage.xml || true; fi + +dev-setup-ubuntu: dev-setup-ubuntu-go dev-setup-ubuntu-python dev-setup-ubuntu-java @echo "✓ Full Ubuntu/Debian development environment ready" -dev-setup-macos: dev-setup-macos-go dev-setup-macos-python +dev-setup-macos: dev-setup-macos-go dev-setup-macos-python dev-setup-macos-java @echo "✓ Full macOS development environment ready" dev-setup-common: @@ -161,6 +203,27 @@ dev-setup-ubuntu-python: dev-setup-common cd python && pip install -e .[dev] @echo "✓ Ubuntu Python development environment ready (no virtualenv)" +dev-setup-ubuntu-java: + @echo "Installing Java toolchain (OpenJDK 11 + Maven)..." + sudo apt update + sudo apt install -y openjdk-11-jdk maven + @echo "Priming Maven dependency cache (offline build support)..." + cd java && mvn -q -DskipTests dependency:go-offline || { echo "Maven dependency prefetch failed"; exit 1; } + @echo "✓ Java development environment ready" + +dev-setup-macos-java: + @if ! command -v brew >/dev/null 2>&1; then \ + echo "ERROR: Homebrew not found. Install from https://brew.sh first."; \ + exit 1; \ + fi + @echo "Installing Java toolchain (Temurin 11 + Maven + Spotless deps)..." + brew install openjdk@11 maven || { echo "Failed to install Java tooling"; exit 1; } + # Ensure JAVA_HOME is set for current shell usage note + @echo "Add to shell profile if not present: export JAVA_HOME=\`/usr/libexec/java_home -v 11\`" + @echo "Priming Maven dependency cache (offline build support)..." + cd java && mvn -q -DskipTests dependency:go-offline || { echo "Maven dependency prefetch failed"; exit 1; } + @echo "✓ macOS Java development environment ready" + dev-setup-macos-python: @if ! command -v brew >/dev/null 2>&1; then \ echo "ERROR: Homebrew not found. Install from https://brew.sh first."; \ @@ -184,7 +247,7 @@ dev-setup-azure: dev-setup-gcp: @echo "GCP dependencies installed via python3-google-auth and python3-google-auth-oauthlib" -lint: lint-go lint-python +lint: lint-go lint-python lint-java lint-go: @echo "Running Go linters..." @@ -212,7 +275,15 @@ lint-python: cd python && python3 -m black --check src/ tests/ cd python && python3 -m isort --check-only src tests -format: format-go format-python +format: format-go format-python format-java + +lint-java: + @echo "Running Java formatter check (Spotless)..." + cd java && mvn -q spotless:check || { echo "Java formatting issues found (run make format-java)"; exit 1; } + +format-java: + @echo "Formatting Java code (Spotless)..." + cd java && mvn -q spotless:apply format-go: @echo "Formatting Go code..." @@ -254,7 +325,7 @@ ssh-run-remote-tests: check-host ssh $(SSH_OPTS) $(HOST) \ "cd $(REMOTE_BASE_DIR)/$(UNIQUE_DIR) && env $(ENV_VARS) make $(TEST_TARGET) on-remote-completed" \ 2>&1 | tee $(HOST)-log - @if grep -q "✓ All tests completed successfully" $(HOST)-log; then \ + @if grep -q "ALL_TESTS_COMPLETED_OK" $(HOST)-log; then \ echo "✓ Remote tests passed on $(HOST)"; \ else \ echo "✗ Remote tests failed on $(HOST) - check $(HOST)-log"; \ @@ -263,7 +334,7 @@ ssh-run-remote-tests: check-host # Generic function to download coverage files # CI target - download coverage files from remote host -ssh-download-coverage: ssh-download-coverage-go ssh-download-coverage-python +ssh-download-coverage: ssh-download-coverage-go ssh-download-coverage-python ssh-download-coverage-java @echo "✓ All coverage files downloaded" # CI target - download Go coverage from remote host @@ -282,6 +353,13 @@ ssh-download-coverage-python: check-host if [ ! -s ./python-coverage-$$TIMESTAMP.xml ]; then echo "Python coverage file empty or missing"; exit 1; fi; \ cp ./python-coverage-$$TIMESTAMP.xml python-coverage.xml +ssh-download-coverage-java: check-host + @echo "Downloading Java coverage from $(HOST)..." + TIMESTAMP=$$(date +%Y%m%d-%H%M%S); \ + scp $(SSH_OPTS) $(HOST):$(REMOTE_BASE_DIR)/$(UNIQUE_DIR)/java/java-coverage.xml ./java-coverage-$$TIMESTAMP.xml || scp $(SSH_OPTS) $(HOST):$(REMOTE_BASE_DIR)/$(UNIQUE_DIR)/java/target/site/jacoco/jacoco.xml ./java-coverage-$$TIMESTAMP.xml; \ + if [ ! -s ./java-coverage-$$TIMESTAMP.xml ]; then echo "Java coverage file empty or missing"; exit 1; fi; \ + cp ./java-coverage-$$TIMESTAMP.xml java-coverage.xml + # Generic function to cleanup remote directory # CI target - cleanup remote directory ssh-cleanup-remote: check-host diff --git a/README.md b/README.md index 48d17e9..cf6fd08 100644 --- a/README.md +++ b/README.md @@ -7,9 +7,11 @@ This repository contains tools for the SingleStore IAM authentication system. [![Go report card](https://goreportcard.com/badge/github.com/singlestore-labs/singlestore-auth-iam/go)](https://goreportcard.com/report/github.com/singlestore-labs/singlestore-auth-iam/go) [![codecov](https://codecov.io/gh/singlestore-labs/singlestore-auth-iam/branch/main/graph/badge.svg)](https://codecov.io/gh/singlestore-labs/singlestore-auth-iam) -## Current status +## Current Status -This service is not yet available. This library may be updated before the service becomes available. +JWTs for engine access are ready for testing. +JWTs for the management API are not yet available. +APIs and language bindings may change before this is considered generally avaialble. ## Overview @@ -20,17 +22,14 @@ The `singlestore-auth-iam` library provides a seamless way to authenticate with ### Key Features -- **Multi-language support**: Go and Python libraries with identical functionality +- **Multi-language support**: Go (reference), Python, and Java implementations with converging functionality - **Automatic detection**: Discovers cloud provider and obtains credentials automatically - **Role assumption**: Assume different roles/service accounts for enhanced security - **Command-line tool**: Standalone CLI for scripts and CI/CD pipelines ### Future Plans -- Additional language support: Java, Node.js, and C++ (coming soon) - -## Current Status +- Additional language support: Node.js and C++ (planned) -This service is not yet available. This library may be updated before the service becomes available. ## Installation @@ -85,6 +84,47 @@ api_jwt = await s2iam.get_jwt_api() **[📖 Full Python Documentation →](python/README.md)** +### Java Library + +Add the Maven dependency (snapshot until first release): + +```xml + + com.singlestore + s2iam + 0.0.1-SNAPSHOT + +``` + +Basic usage: + +```java +import com.singlestore.s2iam.S2IAM; + +// Detect provider & get database JWT +String jwt = S2IAM.getDatabaseJWT("workspace-group-id"); + +// Get API JWT +String apiJwt = S2IAM.getAPIJWT(); +``` + +**Note:** Until GA, groupId/artifactId/version may change; pin exact versions and review release notes when updating. + +Advanced (Builder API & Assume Role): + +```java +import com.singlestore.s2iam.*; + +String jwt = S2IAMRequest.newRequest() + .databaseWorkspaceGroup("workspace-group-id") // or .api() + .assumeRole("arn:aws:iam::123456789012:role/AppRole") // AWS, or service account email (GCP), or Azure client ID + .audience("https://authsvc.singlestore.com") // GCP ONLY; throws if non-GCP + .timeout(java.time.Duration.ofSeconds(5)) + .get(); +``` + +Audience (GCP ONLY): Supplying an audience when not on GCP raises an exception (renamed from withGcpAudience to withAudience and now enforced). + ### Command Line Tool #### Installation @@ -150,6 +190,7 @@ The libraries automatically detect the cloud provider and obtain appropriate cre - **[Go Library Documentation](go/README.md)** - Complete Go API reference and examples - **[Python Library Documentation](python/README.md)** - Complete Python API reference and examples +- **Java**: See inline Javadoc and `java/README.md` (implementation evolving pre-GA) ## License This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. diff --git a/java/.eclipse-java-formatter.xml b/java/.eclipse-java-formatter.xml new file mode 100644 index 0000000..952cc6f --- /dev/null +++ b/java/.eclipse-java-formatter.xml @@ -0,0 +1,16 @@ + + + + + + + + + + + + + + + + diff --git a/java/README.md b/java/README.md new file mode 100644 index 0000000..9853aa8 --- /dev/null +++ b/java/README.md @@ -0,0 +1,95 @@ +SingleStore Auth IAM - Java +=========================== + +Status: ACTIVE DEVELOPMENT (parity tracking the Go reference). Breaking changes may still occur before GA. + +Overview +-------- +This Java library obtains short‑lived JWTs for SingleStore database (workspace group) or Management API access using native cloud provider identities (AWS / GCP / Azure). It auto‑detects the runtime cloud provider in seconds (target parity with Go implementation) and sends signed identity headers to the auth service which returns a JWT. + +Quick Start +----------- +```java +import com.singlestore.s2iam.S2IAM; + +// Database JWT (workspace group required) +String dbJwt = S2IAM.getDatabaseJWT("my-workspace-group-id"); + +// Management API JWT +String apiJwt = S2IAM.getAPIJWT(); +``` + +Fluent Builder API +------------------ +For advanced composition (assume role, custom timeout, explicit provider, custom server URL, GCP audience) use the builder: + +```java +import com.singlestore.s2iam.*; + +String jwt = S2IAMRequest.newRequest() + .databaseWorkspaceGroup("my-workspace-group-id") // or .api() + .assumeRole("arn:aws:iam::123456789012:role/AppRole") // AWS ARN, GCP service account email, or Azure client ID + .timeout(java.time.Duration.ofSeconds(5)) + .audience("https://authsvc.singlestore.com") // GCP ONLY (see below) + .get(); +``` + +GCP Audience (GCP ONLY) +----------------------- +Use `.audience()` (builder) or `Options.withAudience()` (static API) ONLY when the detected (or explicitly provided) provider is GCP. The audience parameter tunes the GCP identity token audience. If you specify an audience and the provider is not GCP, the library throws `S2IAMException` immediately. (Older name `withGcpAudience` was renamed to `withAudience` and now enforces this validation.) + +Assume Role / Impersonation +--------------------------- +- AWS: Provide an IAM role ARN (e.g., `arn:aws:iam::ACCOUNT:role/RoleName`). Session duration fixed to 3600s (parity with Go). Session name prefix: `SingleStoreAuth-`. +- GCP: Provide a service account email for impersonation. +- Azure: Provide a managed identity client (object) ID (UUID format). + +Validation is strict; malformed identifiers raise `S2IAMException` before network calls. + +Functional Options (Static API) +------------------------------- +```java +import com.singlestore.s2iam.options.Options; + +String apiJwt = S2IAM.getAPIJWT( + Options.withTimeout(Duration.ofSeconds(4)), + Options.withAudience("https://authsvc.singlestore.com") // only if running on GCP +); +``` + +Detection & Performance +----------------------- +Detection proceeds in two phases: +1. Fast phase (serial) – very quick heuristics per provider. +2. Full phase (concurrent with 5s default timeout) – parallel deeper probes. + +The first positive result short‑circuits. Typical success latency on real cloud instances is under a second (target parity with Go). + +Operational Notes +----------------- +All outbound requests include `User-Agent: s2iam-java/`. The library is fail-fast—any unexpected condition raises an exception rather than logging a warning. + +API Summary +----------- +Core static methods: +- `S2IAM.getDatabaseJWT(workspaceGroupId, JwtOption...)` +- `S2IAM.getAPIJWT(JwtOption...)` +- `S2IAM.detectProvider()` + +Builder: +- `S2IAMRequest.newRequest().databaseWorkspaceGroup(id)|api().assumeRole(id).audience(aud).timeout(d).provider(explicitProvider).serverUrl(url).get()` + +Selected Options helpers: +- `Options.withTimeout(Duration)` +- `Options.withAudience(String)` (GCP only) +- `Options.withAssumeRole(String)` +- `Options.withServerUrl(String)` +- `Options.withProvider(CloudProviderClient)` (explicit injection / test) + +Timeouts +-------- +Default detection + HTTP call timeout: 5s (aligned to Go reference). Override with `Options.withTimeout` or builder `.timeout()`. + +License +------- +MIT (see root LICENSE file). diff --git a/java/pom.xml b/java/pom.xml new file mode 100644 index 0000000..cc743fa --- /dev/null +++ b/java/pom.xml @@ -0,0 +1,87 @@ + + + 4.0.0 + com.singlestore + s2iam + 0.0.1-SNAPSHOT + SingleStore Auth IAM Java + SingleStore Auth IAM Java library (experimental) + + 11 + 11 + UTF-8 + 5.10.2 + + + + org.junit.jupiter + junit-jupiter + ${junit.version} + test + + + com.fasterxml.jackson.core + jackson-databind + 2.17.1 + + + software.amazon.awssdk + sts + 2.25.20 + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 3.2.5 + + true + + + + org.jacoco + jacoco-maven-plugin + 0.8.11 + + + + com/singlestore/** + + + + + + prepare-agent + + + + report + verify + + report + + + + + + com.diffplug.spotless + spotless-maven-plugin + 2.43.0 + + + + + ${project.basedir}/.eclipse-java-formatter.xml + + + + + + + + + diff --git a/java/src/main/java/com/singlestore/s2iam/CloudIdentity.java b/java/src/main/java/com/singlestore/s2iam/CloudIdentity.java new file mode 100644 index 0000000..104986d --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/CloudIdentity.java @@ -0,0 +1,46 @@ +package com.singlestore.s2iam; + +import java.util.Map; + +public class CloudIdentity { + private final CloudProviderType provider; + private final String identifier; + private final String accountId; + private final String region; + private final String resourceType; + private final Map additionalClaims; + + public CloudIdentity(CloudProviderType provider, String identifier, String accountId, + String region, String resourceType, Map additionalClaims) { + this.provider = provider; + this.identifier = identifier; + this.accountId = accountId; + this.region = region; + this.resourceType = resourceType; + this.additionalClaims = additionalClaims; + } + + public CloudProviderType getProvider() { + return provider; + } + + public String getIdentifier() { + return identifier; + } + + public String getAccountId() { + return accountId; + } + + public String getRegion() { + return region; + } + + public String getResourceType() { + return resourceType; + } + + public Map getAdditionalClaims() { + return additionalClaims; + } +} diff --git a/java/src/main/java/com/singlestore/s2iam/CloudProviderClient.java b/java/src/main/java/com/singlestore/s2iam/CloudProviderClient.java new file mode 100644 index 0000000..21ad1ab --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/CloudProviderClient.java @@ -0,0 +1,29 @@ +package com.singlestore.s2iam; + +import java.util.Map; + +public interface CloudProviderClient { + Exception detect(); // Full detection (may perform network); returns null on success or exception. + + Exception fastDetect(); // Fast in-process detection only; returns null if detected; else an + // exception. + + CloudProviderType getType(); // Provider type. + + CloudProviderClient assumeRole(String roleIdentifier); // Returns new client with assumed role. + + IdentityHeadersResult getIdentityHeaders(Map additionalParams); + + class IdentityHeadersResult { + public final Map headers; + public final CloudIdentity identity; + public final Exception error; + + public IdentityHeadersResult(Map headers, CloudIdentity identity, + Exception error) { + this.headers = headers; + this.identity = identity; + this.error = error; + } + } +} diff --git a/java/src/main/java/com/singlestore/s2iam/CloudProviderType.java b/java/src/main/java/com/singlestore/s2iam/CloudProviderType.java new file mode 100644 index 0000000..27418a5 --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/CloudProviderType.java @@ -0,0 +1,5 @@ +package com.singlestore.s2iam; + +public enum CloudProviderType { + aws, gcp, azure; +} diff --git a/java/src/main/java/com/singlestore/s2iam/Logger.java b/java/src/main/java/com/singlestore/s2iam/Logger.java new file mode 100644 index 0000000..262fede --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/Logger.java @@ -0,0 +1,8 @@ +package com.singlestore.s2iam; + +@FunctionalInterface +public interface Logger { + void logf(String format, Object... args); + + Logger STDOUT = (fmt, args) -> System.out.printf(fmt + "%n", args); +} diff --git a/java/src/main/java/com/singlestore/s2iam/ProviderContext.java b/java/src/main/java/com/singlestore/s2iam/ProviderContext.java new file mode 100644 index 0000000..4188dfd --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/ProviderContext.java @@ -0,0 +1,23 @@ +package com.singlestore.s2iam; + +import com.singlestore.s2iam.options.ProviderOptions; + +/** + * Internal thread-local context allowing builder to pass provider options (e.g. + * timeout) without widening public method signatures broadly. This is + * intentionally minimal and not part of the public documented API. + */ +final class ProviderContext { + private static final ThreadLocal CURRENT = new ThreadLocal<>(); + static void set(ProviderOptions po) { + CURRENT.set(po); + } + static ProviderOptions get() { + return CURRENT.get(); + } + static void clear() { + CURRENT.remove(); + } + private ProviderContext() { + } +} diff --git a/java/src/main/java/com/singlestore/s2iam/S2IAM.java b/java/src/main/java/com/singlestore/s2iam/S2IAM.java new file mode 100644 index 0000000..7c76c61 --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/S2IAM.java @@ -0,0 +1,399 @@ +package com.singlestore.s2iam; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.singlestore.s2iam.exceptions.NoCloudProviderDetectedException; +import com.singlestore.s2iam.exceptions.S2IAMException; +import com.singlestore.s2iam.options.*; +import com.singlestore.s2iam.providers.aws.AWSClient; +import com.singlestore.s2iam.providers.azure.AzureClient; +import com.singlestore.s2iam.providers.gcp.GCPClient; +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.*; +import java.util.concurrent.*; + +public final class S2IAM { + private S2IAM() { + } + + private static final String DEFAULT_SERVER = "https://authsvc.singlestore.com/auth/iam/:jwtType"; + private static final String LIB_NAME = "s2iam-java"; + private static final String LIB_VERSION = Optional + .ofNullable(S2IAM.class.getPackage().getImplementationVersion()).orElse("dev"); + private static final String USER_AGENT = LIB_NAME + "/" + LIB_VERSION; // derived dynamically + private static final ObjectMapper MAPPER = new ObjectMapper(); + + // Convenience API (database) + public static String getDatabaseJWT(String workspaceGroupId, JwtOption... opts) + throws S2IAMException { + if (workspaceGroupId == null || workspaceGroupId.isEmpty()) { + throw new S2IAMException("workspaceGroupId is required for database JWT"); + } + JwtOptions o = new JwtOptions(); + o.jwtType = JwtOptions.JWTType.database; + o.workspaceGroupId = workspaceGroupId; + o.serverUrl = DEFAULT_SERVER; + applyJwtOptions(o, opts); + return getJWT(o); + } + + // Synonym for getDatabaseJWT matching other language naming style + public static String getJwtDatabase(String workspaceGroupId, JwtOption... opts) + throws S2IAMException { + return getDatabaseJWT(workspaceGroupId, opts); + } + + // Convenience API (api) + public static String getAPIJWT(JwtOption... opts) throws S2IAMException { + JwtOptions o = new JwtOptions(); + o.jwtType = JwtOptions.JWTType.api; + o.serverUrl = DEFAULT_SERVER; + applyJwtOptions(o, opts); + return getJWT(o); + } + + // Synonym for getAPIJWT matching other language naming style + public static String getJwtApi(JwtOption... opts) throws S2IAMException { + return getAPIJWT(opts); + } + + // Overloads supporting provider options (primarily for builder convenience) + public static String getDatabaseJWT(String workspaceGroupId, JwtOption[] jwtOpts, + ProviderOption[] providerOpts) throws S2IAMException { + if (providerOpts != null && providerOpts.length > 0) { + // apply provider options globally before invoking regular path + ProviderOptions po = new ProviderOptions(); + for (ProviderOption p : providerOpts) + p.apply(po); + // Currently only timeout/logger/clients are meaningful; we can't thread through + // directly without refactor, so store once for detectProvider static use if + // needed. + // For minimal risk, just pass timeout via a thread-local. + ProviderContext.set(po); + try { + return getDatabaseJWT(workspaceGroupId, jwtOpts); + } finally { + ProviderContext.clear(); + } + } + return getDatabaseJWT(workspaceGroupId, jwtOpts); + } + + public static String getAPIJWT(JwtOption[] jwtOpts, ProviderOption[] providerOpts) + throws S2IAMException { + if (providerOpts != null && providerOpts.length > 0) { + ProviderOptions po = new ProviderOptions(); + for (ProviderOption p : providerOpts) + p.apply(po); + ProviderContext.set(po); + try { + return getAPIJWT(jwtOpts); + } finally { + ProviderContext.clear(); + } + } + return getAPIJWT(jwtOpts); + } + + // Provider detection + public static CloudProviderClient detectProvider(ProviderOption... opts) + throws NoCloudProviderDetectedException { + ProviderOptions po = new ProviderOptions(); + for (ProviderOption opt : opts) + opt.apply(po); + // Merge thread-local provider context (builder) if present and explicit opts + // didn't set. + ProviderOptions ctx = ProviderContext.get(); + if (ctx != null) { + if (po.timeout == null) + po.timeout = ctx.timeout; + if (po.logger == null) + po.logger = ctx.logger; + if (po.clients == null) + po.clients = ctx.clients; + } + if (po.logger == null && "true".equals(System.getenv("S2IAM_DEBUGGING"))) { + po.logger = Logger.STDOUT; + } + if (po.clients == null) { + po.clients = List.of(new AWSClient(po.logger), new GCPClient(po.logger), + new AzureClient(po.logger)); + } + boolean debug = "true".equals(System.getenv("S2IAM_DEBUGGING")); + // Fast detect first + long fastStart = System.nanoTime(); + if (debug && po.logger != null) { + po.logger.logf("detectProvider: starting fastDetect phase over providers=%d", + po.clients.size()); + } + Map fastErrors = new LinkedHashMap<>(); + for (CloudProviderClient c : po.clients) { + long s = System.nanoTime(); + Exception fe = c.fastDetect(); + long durMs = (System.nanoTime() - s) / 1_000_000L; + if (fe == null) { + if (debug && po.logger != null) { + po.logger.logf("detectProvider: fastDetect SUCCESS provider=%s totalFastPhaseMs=%d", + c.getClass().getSimpleName(), (System.nanoTime() - fastStart) / 1_000_000L); + } + return c; + } else if (debug && po.logger != null) { + po.logger.logf("detectProvider: fastDetect FAIL provider=%s err=%s durationMs=%d", + c.getClass().getSimpleName(), fe.getMessage(), durMs); + } + fastErrors.put(c.getClass().getSimpleName(), fe.getMessage()); + } + Duration timeout = po.timeout == null ? Duration.ofSeconds(5) : po.timeout; // align with Go + // default + if (debug && po.logger != null) { + po.logger.logf("detectProvider: entering concurrent detect phase timeoutMs=%d", + timeout.toMillis()); + } + ExecutorService exec = Executors.newFixedThreadPool(po.clients.size()); + CompletionService cs = new ExecutorCompletionService<>(exec); + List> futures = new ArrayList<>(); + for (CloudProviderClient c : po.clients) { + futures.add(cs.submit(() -> { + long s = System.nanoTime(); + Exception err = c.detect(); + long durMs = (System.nanoTime() - s) / 1_000_000L; + if (debug && po.logger != null) { + po.logger.logf("detectProvider: detect provider=%s result=%s durationMs=%d thread=%s", + c.getClass().getSimpleName(), err == null ? "SUCCESS" : ("ERR:" + err.getMessage()), + durMs, Thread.currentThread().getName()); + } + if (err == null) + return c; + else + throw err; + })); + } + exec.shutdown(); + long deadline = System.nanoTime() + timeout.toNanos(); + List errors = new ArrayList<>(); + Map detectErrors = new LinkedHashMap<>(); + for (int i = 0; i < futures.size(); i++) { + long remainingMs = (deadline - System.nanoTime()) / 1_000_000L; + if (remainingMs <= 0) + break; + try { + Future f = cs.poll(remainingMs, TimeUnit.MILLISECONDS); + if (f == null) + break; // timeout + CloudProviderClient found = f.get(); + for (Future other : futures) + if (!other.isDone()) + other.cancel(true); + return found; + } catch (ExecutionException ee) { + Throwable cause = ee.getCause(); + errors.add(cause); + if (debug && po.logger != null) { + po.logger.logf("detectProvider: provider failed error=%s remainingMs=%d", + cause == null ? "" : cause.getMessage(), remainingMs); + } + String key = cause == null ? "unknown" : cause.getClass().getSimpleName(); + detectErrors.put(key + "@" + i, cause == null ? "null" : cause.getMessage()); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + errors.add(ie); + break; + } + } + for (Future other : futures) + if (!other.isDone()) + other.cancel(true); + if (debug && po.logger != null) { + po.logger.logf("detectProvider: no provider detected errors=%d firstError=%s", errors.size(), + errors.isEmpty() || errors.get(0) == null ? "" : errors.get(0).getMessage()); + } + throw new NoCloudProviderDetectedException( + buildAggregateDetectMessage(fastErrors, detectErrors)); + } + + private static String safeTrunc(String s) { + if (s == null) + return ""; + if (s.length() > 60) + return s.substring(0, 57) + "..."; + return s; + } + + private static String buildAggregateDetectMessage(Map fastErrors, + Map detectErrors) { + List parts = new ArrayList<>(); + String fast = formatSection(fastErrors, 3); + if (!fast.isEmpty()) + parts.add("fast=" + fast); + String detect = formatSection(detectErrors, 3); + if (!detect.isEmpty()) + parts.add("detect=" + detect); + if (parts.isEmpty()) + return "no cloud provider detected"; + return "no cloud provider detected; " + String.join("; ", parts); + } + + private static void appendSection(StringBuilder sb, String label, Map src, + int max) { + if (src.isEmpty()) + return; + sb.append("; ").append(label).append('='); + int n = 0; + for (var e : src.entrySet()) { + if (n++ > 0) + sb.append(','); + sb.append(e.getKey()).append(':').append(safeTrunc(e.getValue())); + if (n >= max && src.size() > max) { + sb.append("+" + (src.size() - max) + "more"); + break; + } + } + } + + private static String formatSection(Map src, int max) { + if (src.isEmpty()) + return ""; + StringBuilder sb = new StringBuilder(); + int n = 0; + for (var e : src.entrySet()) { + if (n++ > 0) + sb.append(','); + sb.append(e.getKey()).append(':').append(safeTrunc(e.getValue())); + if (n >= max && src.size() > max) { + sb.append("+" + (src.size() - max) + "more"); + break; + } + } + return sb.toString(); + } + + private static void applyJwtOptions(JwtOptions o, JwtOption... opts) { + for (JwtOption opt : opts) + opt.apply(o); + if (o.timeout == null) + o.timeout = Duration.ofSeconds(5); + if (o.serverUrl == null || o.serverUrl.isEmpty()) + o.serverUrl = DEFAULT_SERVER; + } + + private static String getJWT(JwtOptions o) throws S2IAMException { + if (o.serverUrl == null || o.serverUrl.isEmpty()) { + throw new S2IAMException("server URL is required"); + } + if (o.provider == null) { + try { + o.provider = detectProvider(); + } catch (NoCloudProviderDetectedException e) { + throw new S2IAMException("failed to detect cloud provider", e); + } + } + // Enforce that audience param (if present) only used for GCP + if (o.additionalParams != null && o.additionalParams.containsKey("audience")) { + CloudProviderType t = o.provider.getType(); + if (t != CloudProviderType.gcp) { + throw new S2IAMException( + "audience parameter is only supported for GCP provider (detected=" + t + ")"); + } + } + boolean debug = "true".equals(System.getenv("S2IAM_DEBUGGING")); + CloudProviderClient provider = o.provider; + if (o.assumeRoleIdentifier != null && !o.assumeRoleIdentifier.isEmpty()) { + String id = o.assumeRoleIdentifier; + switch (provider.getType()) { + case aws: { + if (!id.startsWith("arn:")) + throw new S2IAMException("invalid AWS assumeRoleIdentifier (must start with 'arn:')"); + String[] arnParts = id.split(":"); + if (arnParts.length < 6 || arnParts[2].isEmpty() || arnParts[5].isEmpty()) + throw new S2IAMException("invalid AWS ARN format for assumeRoleIdentifier"); + if (!arnParts[2].equals("iam") && !arnParts[2].equals("sts")) + throw new S2IAMException("AWS assumeRoleIdentifier service must be iam or sts"); + break; + } + case gcp: { + if (!id.contains("@") || !id.endsWith(".gserviceaccount.com")) + throw new S2IAMException( + "invalid GCP assumeRoleIdentifier (expected service account email)"); + break; + } + case azure: { + String s = id.trim(); + if (s.length() != 36 || s.chars().filter(ch -> ch == '-').count() != 4) + throw new S2IAMException( + "invalid Azure assumeRoleIdentifier (expected GUID format xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx)"); + try { + java.util.UUID.fromString(s); + } catch (IllegalArgumentException iae) { + throw new S2IAMException("invalid Azure assumeRoleIdentifier (not a valid UUID)"); + } + break; + } + default : + throw new S2IAMException("assumeRoleIdentifier validation not implemented for provider: " + + provider.getType()); + } + provider = provider.assumeRole(id); + } + CloudProviderClient.IdentityHeadersResult res = provider.getIdentityHeaders(o.additionalParams); + if (res.error != null) + throw new S2IAMException("failed to get identity headers", res.error); + CloudIdentity identity = res.identity; + if (identity == null) + throw new S2IAMException("no identity returned by provider"); + + String url = o.serverUrl.replace(":cloudProvider", identity.getProvider().name()) + .replace(":jwtType", o.jwtType.name()); + String query = ""; + if (o.jwtType == JwtOptions.JWTType.database && o.workspaceGroupId != null + && !o.workspaceGroupId.isEmpty()) { + try { + query = "?workspaceGroupID=" + java.net.URLEncoder.encode(o.workspaceGroupId, + java.nio.charset.StandardCharsets.UTF_8); + } catch (Exception e) { + throw new S2IAMException("failed to URL encode workspaceGroupId", e); + } + } + Duration httpTimeout = o.timeout != null ? o.timeout : Timeouts.IDENTITY; // apply option + // timeout + HttpRequest.Builder rb = HttpRequest.newBuilder(URI.create(url + query)).timeout(httpTimeout) + .POST(HttpRequest.BodyPublishers.noBody()).header("User-Agent", USER_AGENT); + for (Map.Entry e : res.headers.entrySet()) + rb.header(e.getKey(), e.getValue()); + + if (debug && identity.getProvider() != null) { + Logger log = Logger.STDOUT; // simple fallback + log.logf("getJWT: requesting jwtType=%s provider=%s url=%s timeoutMs=%d", o.jwtType, + identity.getProvider(), url, httpTimeout.toMillis()); + } + + HttpClient client = HttpClient.newBuilder().connectTimeout(httpTimeout).build(); + HttpResponse response; + try { + response = client.send(rb.build(), HttpResponse.BodyHandlers.ofString()); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new S2IAMException("error calling authentication server (interrupted)", ie); + } catch (IOException ioe) { + throw new S2IAMException("error calling authentication server", ioe); + } + int sc = response.statusCode(); + if (sc != 200) { + throw new S2IAMException( + "authentication server returned status " + sc + ": " + response.body()); + } + try { + JsonNode node = MAPPER.readTree(response.body()); + String jwt = node.path("jwt").asText(); + if (jwt == null || jwt.isEmpty()) + throw new S2IAMException("received empty JWT from server"); + return jwt; + } catch (IOException e) { + throw new S2IAMException("cannot parse response", e); + } + } +} diff --git a/java/src/main/java/com/singlestore/s2iam/S2IAMRequest.java b/java/src/main/java/com/singlestore/s2iam/S2IAMRequest.java new file mode 100644 index 0000000..bcf0b05 --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/S2IAMRequest.java @@ -0,0 +1,144 @@ +package com.singlestore.s2iam; + +import com.singlestore.s2iam.exceptions.S2IAMException; +import com.singlestore.s2iam.options.JwtOption; +import com.singlestore.s2iam.options.Options; +import com.singlestore.s2iam.options.ProviderOption; +import java.time.Duration; +import java.util.*; + +/** + * Fluent builder for obtaining JWTs (database or API) with a more idiomatic + * Java experience than the varargs functional option style. It is purely a + * convenience layer over the existing static S2IAM methods and Options helpers. + * + * Usage example: + * + *
{@code
+ * String jwt = S2IAMRequest.newRequest().databaseWorkspaceGroup("wg-123")
+ *     .assumeRole("arn:aws:iam::123456789012:role/MyRole").timeout(Duration.ofSeconds(5)).get();
+ * }
+ */ +public final class S2IAMRequest { + private boolean apiMode = false; + private String workspaceGroupId; + private String assumeRoleId; + private Duration timeout; + private String serverUrl; + private final Map additionalParams = new LinkedHashMap<>(); + private CloudProviderClient provider; // optional explicit provider (skips detection) + + private S2IAMRequest() { + } + + public static S2IAMRequest newRequest() { + return new S2IAMRequest(); + } + + /** Select API JWT mode (no workspace group id). */ + public S2IAMRequest api() { + this.apiMode = true; + this.workspaceGroupId = null; + return this; + } + + /** Select database JWT mode and set the workspace group id. */ + public S2IAMRequest databaseWorkspaceGroup(String workspaceGroupId) { + this.apiMode = false; + this.workspaceGroupId = workspaceGroupId; + return this; + } + + /** + * Optional assume role identifier (provider specific: AWS role ARN, GCP service + * account email, Azure object id). + */ + public S2IAMRequest assumeRole(String assumeRoleId) { + this.assumeRoleId = assumeRoleId; + return this; + } + + /** Overall timeout applied to detection + identity HTTP calls. */ + public S2IAMRequest timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + /** Override the authentication server base URL (e.g., test server). */ + public S2IAMRequest serverUrl(String serverUrl) { + this.serverUrl = serverUrl; + return this; + } + + /** + * Provide explicit provider (e.g., FakeProvider in tests) to skip detection. + */ + public S2IAMRequest provider(CloudProviderClient provider) { + this.provider = provider; + return this; + } + + /** + * Add a raw additional parameter (forwarded as query parameter when supported). + */ + public S2IAMRequest param(String key, String value) { + if (key != null && value != null) + additionalParams.put(key, value); + return this; + } + + /** + * Set audience parameter (currently GCP-only). Using this when the provider is + * not GCP will cause an error during execution for explicit clarity. + */ + public S2IAMRequest audience(String audience) { + if (audience != null) + additionalParams.put("audience", audience); + return this; + } + + /** Execute the request and return the JWT string. */ + public String get() throws S2IAMException { + return execute(); + } + + private String execute() throws S2IAMException { + List jwtOpts = new ArrayList<>(); + List providerOpts = new ArrayList<>(); + if (assumeRoleId != null) + jwtOpts.add(Options.withAssumeRole(assumeRoleId)); + if (serverUrl != null) + jwtOpts.add(Options.withServerUrl(serverUrl)); + if (timeout != null) + providerOpts.add(Options.withTimeout(timeout)); + // Map additional params to existing explicit helpers (currently only audience) + if (additionalParams.containsKey("audience")) { + // Validate provider type if provider already set (explicit builder provider) or + // later after detection + if (provider != null && provider.getType() != CloudProviderType.gcp) { + throw new S2IAMException( + "audience is GCP-only and cannot be used with provider=" + provider.getType()); + } + jwtOpts.add(Options.withAudience(additionalParams.get("audience"))); + } + JwtOption[] jwtArr = jwtOpts.toArray(new JwtOption[0]); + ProviderOption[] providerArr = providerOpts.toArray(new ProviderOption[0]); + if (provider != null) { + jwtOpts.add(com.singlestore.s2iam.options.Options.withProvider(provider)); + jwtArr = jwtOpts.toArray(new JwtOption[0]); + } + if (apiMode) { + if (provider != null && additionalParams.containsKey("audience") + && provider.getType() != CloudProviderType.gcp) { + throw new S2IAMException( + "audience is GCP-only and cannot be used with provider=" + provider.getType()); + } + return S2IAM.getAPIJWT(jwtArr, providerArr); + } + if (workspaceGroupId == null || workspaceGroupId.isEmpty()) { + throw new S2IAMException( + "workspace group id required for database JWT (call api() for API JWT)"); + } + return S2IAM.getDatabaseJWT(workspaceGroupId, jwtArr, providerArr); + } +} diff --git a/java/src/main/java/com/singlestore/s2iam/Timeouts.java b/java/src/main/java/com/singlestore/s2iam/Timeouts.java new file mode 100644 index 0000000..d4f19de --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/Timeouts.java @@ -0,0 +1,25 @@ +package com.singlestore.s2iam; + +import java.time.Duration; + +/** + * Centralized timeout constants to keep provider client behavior consistent and + * enable quick tuning. Values chosen to keep overall test runtime low while + * allowing a modest network RTT on real cloud VMs. + */ +public final class Timeouts { + private Timeouts() { + } + + // Metadata detection (allow slower clouds / transient slowness) + public static final Duration DETECT = Duration.ofSeconds(5); + + // Identity / token retrieval baseline (metadata tokens, STS, MI, etc.) + public static final Duration IDENTITY = Duration.ofSeconds(10); + + // Secondary / follow-up metadata probes (instance details, subscription, etc.) + public static final Duration SECONDARY = Duration.ofSeconds(5); + + // Extended identity operations (future impersonation / long STS chains) + public static final Duration IDENTITY_EXTENDED = Duration.ofSeconds(15); +} diff --git a/java/src/main/java/com/singlestore/s2iam/exceptions/NoCloudProviderDetectedException.java b/java/src/main/java/com/singlestore/s2iam/exceptions/NoCloudProviderDetectedException.java new file mode 100644 index 0000000..2c0dfc1 --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/exceptions/NoCloudProviderDetectedException.java @@ -0,0 +1,11 @@ +package com.singlestore.s2iam.exceptions; + +public class NoCloudProviderDetectedException extends S2IAMException { + public NoCloudProviderDetectedException(String msg) { + super(msg); + } + + public NoCloudProviderDetectedException(String msg, Throwable cause) { + super(msg, cause); + } +} diff --git a/java/src/main/java/com/singlestore/s2iam/exceptions/S2IAMException.java b/java/src/main/java/com/singlestore/s2iam/exceptions/S2IAMException.java new file mode 100644 index 0000000..89af4ba --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/exceptions/S2IAMException.java @@ -0,0 +1,11 @@ +package com.singlestore.s2iam.exceptions; + +public class S2IAMException extends Exception { + public S2IAMException(String message) { + super(message); + } + + public S2IAMException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/java/src/main/java/com/singlestore/s2iam/options/JwtOption.java b/java/src/main/java/com/singlestore/s2iam/options/JwtOption.java new file mode 100644 index 0000000..1f050a9 --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/options/JwtOption.java @@ -0,0 +1,5 @@ +package com.singlestore.s2iam.options; + +public interface JwtOption { + void apply(JwtOptions o); +} diff --git a/java/src/main/java/com/singlestore/s2iam/options/JwtOptions.java b/java/src/main/java/com/singlestore/s2iam/options/JwtOptions.java new file mode 100644 index 0000000..dfbc4b7 --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/options/JwtOptions.java @@ -0,0 +1,18 @@ +package com.singlestore.s2iam.options; + +import com.singlestore.s2iam.CloudProviderClient; +import java.util.HashMap; +import java.util.Map; + +public class JwtOptions extends ProviderOptions { + public enum JWTType { + database, api + } + + public JWTType jwtType; + public String workspaceGroupId; + public String serverUrl; + public CloudProviderClient provider; + public Map additionalParams = new HashMap<>(); + public String assumeRoleIdentifier; +} diff --git a/java/src/main/java/com/singlestore/s2iam/options/Options.java b/java/src/main/java/com/singlestore/s2iam/options/Options.java new file mode 100644 index 0000000..b9cfaeb --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/options/Options.java @@ -0,0 +1,38 @@ +package com.singlestore.s2iam.options; + +import com.singlestore.s2iam.CloudProviderClient; +import java.time.Duration; + +public final class Options { + private Options() { + } + + public static JwtOption withServerUrl(String url) { + return o -> o.serverUrl = url; + } + + public static JwtOption withProvider(CloudProviderClient provider) { + return o -> o.provider = provider; + } + + public static JwtOption withAudience(String aud) { + return o -> o.additionalParams.put("audience", aud); + } + + public static JwtOption withAssumeRole(String role) { + return o -> o.assumeRoleIdentifier = role; + } + + // Re-export provider options for convenience + public static ProviderOption withTimeout(Duration d) { + return ProviderOption.withTimeout(d); + } + + public static ProviderOption withLogger(com.singlestore.s2iam.Logger l) { + return ProviderOption.withLogger(l); + } + + public static ProviderOption withClients(java.util.List c) { + return ProviderOption.withClients(c); + } +} diff --git a/java/src/main/java/com/singlestore/s2iam/options/ProviderOption.java b/java/src/main/java/com/singlestore/s2iam/options/ProviderOption.java new file mode 100644 index 0000000..cd9430c --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/options/ProviderOption.java @@ -0,0 +1,22 @@ +package com.singlestore.s2iam.options; + +import com.singlestore.s2iam.CloudProviderClient; +import com.singlestore.s2iam.Logger; +import java.time.Duration; +import java.util.List; + +public interface ProviderOption { + void apply(ProviderOptions o); + + static ProviderOption withLogger(Logger logger) { + return o -> o.logger = logger; + } + + static ProviderOption withClients(List clients) { + return o -> o.clients = clients; + } + + static ProviderOption withTimeout(Duration timeout) { + return o -> o.timeout = timeout; + } +} diff --git a/java/src/main/java/com/singlestore/s2iam/options/ProviderOptions.java b/java/src/main/java/com/singlestore/s2iam/options/ProviderOptions.java new file mode 100644 index 0000000..bd87c9b --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/options/ProviderOptions.java @@ -0,0 +1,12 @@ +package com.singlestore.s2iam.options; + +import com.singlestore.s2iam.CloudProviderClient; +import com.singlestore.s2iam.Logger; +import java.time.Duration; +import java.util.List; + +public class ProviderOptions { + public Logger logger; + public List clients; + public Duration timeout = Duration.ofSeconds(15); +} diff --git a/java/src/main/java/com/singlestore/s2iam/options/ServerUrlOption.java b/java/src/main/java/com/singlestore/s2iam/options/ServerUrlOption.java new file mode 100644 index 0000000..cf9cbd5 --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/options/ServerUrlOption.java @@ -0,0 +1,18 @@ +package com.singlestore.s2iam.options; + +public class ServerUrlOption implements JwtOption { + private final String url; + + public ServerUrlOption(String url) { + this.url = url; + } + + @Override + public void apply(JwtOptions o) { + o.serverUrl = url; + } + + public static ServerUrlOption of(String url) { + return new ServerUrlOption(url); + } +} diff --git a/java/src/main/java/com/singlestore/s2iam/options/WithProviders.java b/java/src/main/java/com/singlestore/s2iam/options/WithProviders.java new file mode 100644 index 0000000..aefd546 --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/options/WithProviders.java @@ -0,0 +1,22 @@ +package com.singlestore.s2iam.options; + +import com.singlestore.s2iam.CloudProviderClient; +import java.util.Arrays; +import java.util.List; + +public class WithProviders implements ProviderOption { + private final List clients; + + public WithProviders(CloudProviderClient... c) { + this.clients = Arrays.asList(c); + } + + @Override + public void apply(ProviderOptions o) { + o.clients = clients; + } + + public static WithProviders of(CloudProviderClient... c) { + return new WithProviders(c); + } +} diff --git a/java/src/main/java/com/singlestore/s2iam/providers/AbstractBaseClient.java b/java/src/main/java/com/singlestore/s2iam/providers/AbstractBaseClient.java new file mode 100644 index 0000000..52fb736 --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/providers/AbstractBaseClient.java @@ -0,0 +1,33 @@ +package com.singlestore.s2iam.providers; + +import com.singlestore.s2iam.*; +import java.util.Map; + +/** Base client with default unsupported identity header retrieval. */ +public abstract class AbstractBaseClient implements CloudProviderClient { + protected final Logger logger; + protected final String assumedRole; + + protected AbstractBaseClient(Logger logger, String assumedRole) { + this.logger = logger; + this.assumedRole = assumedRole; + } + + @Override + public CloudProviderClient assumeRole(String roleIdentifier) { + return newInstance(logger, roleIdentifier); + } + + protected abstract CloudProviderClient newInstance(Logger logger, String assumedRole); + + @Override + public IdentityHeadersResult getIdentityHeaders(Map additionalParams) { + return new IdentityHeadersResult(null, null, + new IllegalStateException("identity retrieval not implemented")); + } + + @Override + public Exception fastDetect() { + return new IllegalStateException("not detected"); + } +} diff --git a/java/src/main/java/com/singlestore/s2iam/providers/aws/AWSClient.java b/java/src/main/java/com/singlestore/s2iam/providers/aws/AWSClient.java new file mode 100644 index 0000000..852cac7 --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/providers/aws/AWSClient.java @@ -0,0 +1,204 @@ +package com.singlestore.s2iam.providers.aws; + +import com.singlestore.s2iam.*; +import com.singlestore.s2iam.providers.AbstractBaseClient; +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.HashMap; +import java.util.Map; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sts.StsClient; +import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; +import software.amazon.awssdk.services.sts.model.AssumeRoleResponse; +import software.amazon.awssdk.services.sts.model.GetCallerIdentityRequest; +import software.amazon.awssdk.services.sts.model.GetCallerIdentityResponse; + +public class AWSClient extends AbstractBaseClient { + // Detect order: (1) environment hints (fast), (2) IMDSv2 token endpoint, (3) + // legacy metadata path. + // Identity headers always reflect either the base credentials or an assumed + // role (if provided). + private static final String METADATA_BASE = System.getenv() + .getOrDefault("S2IAM_AWS_METADATA_BASE", "http://169.254.169.254"); + private volatile StsClient sts; + private AwsCredentialsProvider baseProvider; + + public AWSClient(Logger logger) { + super(logger, null); + } + private AWSClient(Logger logger, String assumed) { + super(logger, assumed); + } + + @Override + protected CloudProviderClient newInstance(Logger logger, String assumedRole) { + return new AWSClient(logger, assumedRole); + } + @Override + public CloudProviderType getType() { + return CloudProviderType.aws; + } + + @Override + public Exception detect() { + String[] envs = {"AWS_WEB_IDENTITY_TOKEN_FILE", "AWS_ROLE_ARN", "AWS_EXECUTION_ENV", + "AWS_REGION", "AWS_DEFAULT_REGION", "AWS_LAMBDA_FUNCTION_NAME"}; + for (String e : envs) + if (System.getenv(e) != null && !System.getenv(e).isEmpty()) + return null; + HttpClient client = HttpClient.newBuilder().connectTimeout(Timeouts.DETECT).build(); + boolean debug = "true".equals(System.getenv("S2IAM_DEBUGGING")) && logger != null; + try { + HttpRequest tokenReq = HttpRequest.newBuilder(URI.create(METADATA_BASE + "/latest/api/token")) + .timeout(Timeouts.DETECT).header("X-aws-ec2-metadata-token-ttl-seconds", "60") + .method("PUT", HttpRequest.BodyPublishers.noBody()).build(); + HttpResponse tokenResp = client.send(tokenReq, HttpResponse.BodyHandlers.ofString()); + if (tokenResp.statusCode() == 200) + return null; + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + return ie; + } catch (Exception ignored) { + if (debug) + logger.logf("AWSClient.detect: token endpoint error class=%s msg=%s", + ignored.getClass().getSimpleName(), ignored.getMessage()); + } + try { + HttpRequest req = HttpRequest.newBuilder(URI.create(METADATA_BASE + "/latest/meta-data/")) + .timeout(Timeouts.DETECT).GET().build(); + HttpResponse resp = client.send(req, HttpResponse.BodyHandlers.discarding()); + if (resp.statusCode() == 200) + return null; + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + return ie; + } catch (IOException e) { + if (debug) + logger.logf("AWSClient.detect: metadata path IO error class=%s msg=%s", + e.getClass().getSimpleName(), e.getMessage()); + return e; + } + return new IllegalStateException("not running on AWS"); + } + + @Override + public Exception fastDetect() { + String prop = System.getProperty("s2iam.test.awsFast", ""); + if (!prop.isEmpty()) + return null; + String[] envs = {"AWS_WEB_IDENTITY_TOKEN_FILE", "AWS_ROLE_ARN", "AWS_EXECUTION_ENV", + "AWS_REGION", "AWS_DEFAULT_REGION", "AWS_LAMBDA_FUNCTION_NAME"}; + for (String e : envs) { + String v = System.getenv(e); + if (v != null && !v.isEmpty()) + return null; + } + return new Exception("no aws fast path"); + } + + @Override + public IdentityHeadersResult getIdentityHeaders(Map additionalParams) { + try { + ensureSTS(); + GetCallerIdentityResponse who = sts + .getCallerIdentity(GetCallerIdentityRequest.builder().build()); + AwsCredentials baseCreds = baseProvider.resolveCredentials(); + Map headers = new HashMap<>(); + headers.put("X-AWS-Access-Key-ID", baseCreds.accessKeyId()); + if (baseCreds.secretAccessKey() != null) + headers.put("X-AWS-Secret-Access-Key", baseCreds.secretAccessKey()); + if (baseCreds instanceof AwsSessionCredentials) { + String token = ((AwsSessionCredentials) baseCreds).sessionToken(); + if (token != null && !token.isEmpty()) + headers.put("X-AWS-Session-Token", token); + } + String arn; + String account; + String resourceType; + String region; + if (assumedRole != null && !assumedRole.isEmpty()) { + AssumeRoleResponse assume = sts.assumeRole(AssumeRoleRequest.builder().roleArn(assumedRole) + .roleSessionName("SingleStoreAuth-" + (System.currentTimeMillis() / 1000L)) + .durationSeconds(3600).build()); + headers.put("X-AWS-Access-Key-ID", assume.credentials().accessKeyId()); + headers.put("X-AWS-Secret-Access-Key", assume.credentials().secretAccessKey()); + headers.put("X-AWS-Session-Token", assume.credentials().sessionToken()); + StsClient temp = StsClient.builder().region(sts.serviceClientConfiguration().region()) + .credentialsProvider( + () -> AwsSessionCredentials.create(assume.credentials().accessKeyId(), + assume.credentials().secretAccessKey(), assume.credentials().sessionToken())) + .build(); + GetCallerIdentityResponse assumedIdentity = temp + .getCallerIdentity(GetCallerIdentityRequest.builder().build()); + account = assumedIdentity.account(); + region = deriveRegion(assumedIdentity.arn()); + resourceType = deriveResourceTypeDetailed(assumedIdentity.arn()); + arn = assumedRole; + } else { + arn = who.arn(); + account = who.account(); + region = deriveRegion(arn); + resourceType = deriveResourceTypeDetailed(arn); + if (!headers.containsKey("X-AWS-Session-Token") + && System.getenv("AWS_SESSION_TOKEN") != null) { + headers.put("X-AWS-Session-Token", System.getenv("AWS_SESSION_TOKEN")); + } + } + Map extra = new HashMap<>(); + extra.put("account", account); + if (who.userId() != null && !who.userId().isEmpty()) + extra.put("userId", who.userId()); + CloudIdentity identity = new CloudIdentity(CloudProviderType.aws, arn, account, region, + resourceType, extra); + return new IdentityHeadersResult(headers, identity, null); + } catch (Exception e) { + return new IdentityHeadersResult(null, null, e); + } + } + + private void ensureSTS() { + if (sts != null) + return; + synchronized (this) { + if (sts == null) { + String region = System.getenv().getOrDefault("AWS_REGION", + System.getenv().getOrDefault("AWS_DEFAULT_REGION", "us-east-1")); + baseProvider = DefaultCredentialsProvider.create(); + sts = StsClient.builder().region(Region.of(region)).credentialsProvider(baseProvider) + .build(); + } + } + } + private static String deriveRegion(String arn) { + String[] parts = arn.split(":"); + return parts.length > 3 ? parts[3] : ""; + } + private static String deriveResourceTypeDetailed(String arn) { + if (arn.contains(":instance/")) + return "ec2"; + if (arn.contains(":assumed-role/")) + return "assumed-role"; + if (arn.contains(":role/")) + return "role"; + if (arn.contains(":user/")) + return "user"; + if (arn.contains(":lambda:")) + return "lambda"; + if (arn.contains(":task/")) + return "ecs-task"; + if (arn.contains(":cluster/")) + return "ecs-cluster"; + if (arn.contains(":function:")) + return "lambda"; + if (arn.contains(":iam::")) + return "iam"; + return "aws"; + } +} diff --git a/java/src/main/java/com/singlestore/s2iam/providers/azure/AzureClient.java b/java/src/main/java/com/singlestore/s2iam/providers/azure/AzureClient.java new file mode 100644 index 0000000..3a3fe54 --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/providers/azure/AzureClient.java @@ -0,0 +1,205 @@ +package com.singlestore.s2iam.providers.azure; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.singlestore.s2iam.*; +import com.singlestore.s2iam.providers.AbstractBaseClient; +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.HashMap; +import java.util.Map; +import java.util.regex.Pattern; + +public class AzureClient extends AbstractBaseClient { + public AzureClient(Logger logger) { + super(logger, null); + } + private AzureClient(Logger logger, String assumed) { + super(logger, assumed); + } + @Override + protected CloudProviderClient newInstance(Logger logger, String assumedRole) { + return new AzureClient(logger, assumedRole); + } + @Override + public CloudProviderType getType() { + return CloudProviderType.azure; + } + @Override + public Exception detect() { + if (System.getenv("AZURE_FEDERATED_TOKEN_FILE") != null || System.getenv("MSI_ENDPOINT") != null + || System.getenv("IDENTITY_ENDPOINT") != null) + return null; + boolean debug = "true".equals(System.getenv("S2IAM_DEBUGGING")) && logger != null; + HttpClient client = HttpClient.newBuilder().connectTimeout(Timeouts.DETECT).build(); + try { + HttpRequest req = HttpRequest + .newBuilder(URI.create("http://169.254.169.254/metadata/instance?api-version=2021-02-01")) + .header("Metadata", "true").timeout(Timeouts.DETECT).GET().build(); + HttpResponse resp = client.send(req, HttpResponse.BodyHandlers.discarding()); + if (resp.statusCode() == 200) { + try { + HttpRequest miReq = HttpRequest.newBuilder(URI.create( + "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https://management.azure.com/")) + .header("Metadata", "true").timeout(Timeouts.DETECT).GET().build(); + HttpResponse miResp = client.send(miReq, HttpResponse.BodyHandlers.discarding()); + int sc = miResp.statusCode(); + if (sc == 200) { + if (debug) + logger.logf("AzureClient.detect: classification=azure-with-mi status=%d", sc); + return null; + } + if (sc == 400 || sc == 403 || sc == 404) { + if (debug) + logger.logf("AzureClient.detect: classification=azure-no-role miStatus=%d", sc); + return null; + } + if (debug) + logger.logf("AzureClient.detect: classification=azure-mi-other status=%d (accepted)", + sc); + return null; + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + if (debug) + logger.logf("AzureClient.detect: classification=azure-mi-interrupted err=%s", + ie.getMessage()); + return null; + } catch (IOException e) { + if (debug) + logger.logf("AzureClient.detect: classification=azure-mi-io err=%s", e.getMessage()); + return null; + } + } + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + return ie; + } catch (IOException e) { + return e; + } + return new IllegalStateException("not running on Azure"); + } + @Override + public Exception fastDetect() { + String prop = System.getProperty("s2iam.test.azureFast", ""); + if (!prop.isEmpty()) + return null; + String tokenFile = System.getenv("AZURE_FEDERATED_TOKEN_FILE"); + String cid = System.getenv("AZURE_CLIENT_ID"); + String tid = System.getenv("AZURE_TENANT_ID"); + if (tokenFile != null && !tokenFile.isEmpty() && cid != null && !cid.isEmpty() && tid != null + && !tid.isEmpty()) + return null; + return new Exception("no azure fast path"); + } + @Override + public IdentityHeadersResult getIdentityHeaders(Map additionalParams) { + String resource = additionalParams.getOrDefault("azure_resource", + "https://management.azure.com/"); + String url = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=" + + resource; + HttpClient client = HttpClient.newBuilder().connectTimeout(Timeouts.IDENTITY).build(); + try { + HttpRequest req = HttpRequest.newBuilder(URI.create(url)).header("Metadata", "true") + .timeout(Timeouts.IDENTITY).GET().build(); + HttpResponse resp = client.send(req, HttpResponse.BodyHandlers.ofString()); + if (resp.statusCode() != 200 || resp.body().isEmpty()) { + return new IdentityHeadersResult(null, null, + new IllegalStateException("failed to get Azure MI token status=" + resp.statusCode())); + } + String body = resp.body(); + ObjectMapper om = new ObjectMapper(); + JsonNode node = om.readTree(body); + String token = node.path("access_token").asText(); + if (token == null || token.isEmpty()) + return new IdentityHeadersResult(null, null, new IllegalStateException("no access_token")); + Map headers = new HashMap<>(); + headers.put("Authorization", "Bearer " + token); + String accessToken = token; + String[] parts = accessToken.split("\\."); + String tenantId = ""; + String principalId = ""; + String clientId = node.path("client_id").asText(); + String subscriptionId = ""; + String region = ""; + String resourceType = "unknown"; + Map extra = new HashMap<>(); + ObjectMapper payloadMapper = om; + if (parts.length >= 2) { + try { + String payloadJson = new String(java.util.Base64.getUrlDecoder().decode(parts[1])); + JsonNode payload = payloadMapper.readTree(payloadJson); + if (payload.has("oid")) + principalId = payload.get("oid").asText(""); + else if (payload.has("sub")) + principalId = payload.get("sub").asText(""); + else if (payload.has("appid")) + principalId = payload.get("appid").asText(""); + if (payload.has("iss")) { + String iss = payload.get("iss").asText(""); + extra.put("iss", iss); + String[] segs = iss.split("/"); + for (int i = 0; i < segs.length; i++) + if ("tokens".equals(segs[i]) && i > 0) { + tenantId = segs[i - 1]; + break; + } + } + if (payload.has("xms_mirid")) { + String mirid = payload.get("xms_mirid").asText(""); + extra.put("xms_mirid", mirid); + String[] p = mirid.split("/"); + for (int i = 0; i < p.length; i++) { + if ("subscriptions".equals(p[i]) && i + 1 < p.length) + subscriptionId = p[i + 1]; + if ("providers".equals(p[i]) && i + 1 < p.length) + resourceType = p[i + 1]; + } + } + } catch (Exception ignored) { + } + } + if (principalId.isEmpty()) + principalId = clientId; + if (subscriptionId.isEmpty()) { + try { + HttpRequest instReq = HttpRequest + .newBuilder( + URI.create("http://169.254.169.254/metadata/instance?api-version=2021-02-01")) + .header("Metadata", "true").timeout(Timeouts.SECONDARY).GET().build(); + HttpResponse instResp = client.send(instReq, + HttpResponse.BodyHandlers.ofString()); + if (instResp.statusCode() == 200) { + try { + JsonNode inst = om.readTree(instResp.body()); + JsonNode compute = inst.path("compute"); + if (subscriptionId.isEmpty()) + subscriptionId = compute.path("subscriptionId").asText(""); + if (region.isEmpty()) + region = compute.path("location").asText(region); + } catch (Exception ignored) { + } + } + } catch (Exception ignored) { + } + } + Pattern guidPattern = Pattern.compile("^[0-9a-fA-F-]{32,36}$"); + if (principalId.isEmpty() || !guidPattern.matcher(principalId).find()) { + return new IdentityHeadersResult(headers, null, + new IllegalStateException("invalid principalId")); + } + if (!subscriptionId.isEmpty()) + extra.put("subscriptionId", subscriptionId); + if (!clientId.isEmpty()) + extra.put("clientId", clientId); + extra.put("principalId", principalId); + CloudIdentity identity = new CloudIdentity(CloudProviderType.azure, principalId, tenantId, + region, resourceType, extra); + return new IdentityHeadersResult(headers, identity, null); + } catch (Exception e) { + return new IdentityHeadersResult(null, null, e); + } + } +} diff --git a/java/src/main/java/com/singlestore/s2iam/providers/gcp/GCPClient.java b/java/src/main/java/com/singlestore/s2iam/providers/gcp/GCPClient.java new file mode 100644 index 0000000..9198acb --- /dev/null +++ b/java/src/main/java/com/singlestore/s2iam/providers/gcp/GCPClient.java @@ -0,0 +1,205 @@ +package com.singlestore.s2iam.providers.gcp; + +import com.singlestore.s2iam.*; +import com.singlestore.s2iam.providers.AbstractBaseClient; +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Base64; +import java.util.HashMap; +import java.util.Map; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +public class GCPClient extends AbstractBaseClient { + public GCPClient(Logger logger) { + super(logger, null); + } + private GCPClient(Logger logger, String assumed) { + super(logger, assumed); + } + @Override + protected CloudProviderClient newInstance(Logger logger, String assumedRole) { + return new GCPClient(logger, assumedRole); + } + @Override + public CloudProviderType getType() { + return CloudProviderType.gcp; + } + @Override + public Exception detect() { + if (System.getenv("GOOGLE_APPLICATION_CREDENTIALS") != null + || System.getenv("GCE_METADATA_HOST") != null) + return null; + boolean debug = "true".equals(System.getenv("S2IAM_DEBUGGING")) && logger != null; + String ep = "http://metadata.google.internal/computeMetadata/v1/instance/id"; + HttpClient client = HttpClient.newBuilder().connectTimeout(Timeouts.DETECT).build(); + try { + HttpRequest req = HttpRequest.newBuilder(URI.create(ep)).header("Metadata-Flavor", "Google") + .timeout(Timeouts.DETECT).GET().build(); + long start = System.nanoTime(); + HttpResponse resp = client.send(req, HttpResponse.BodyHandlers.discarding()); + long durMs = (System.nanoTime() - start) / 1_000_000L; + if (resp.statusCode() == 200) { + if (debug) + logger.logf("GCPClient.detect: success endpoint=%s status=%d durationMs=%d", ep, + resp.statusCode(), durMs); + return null; + } else { + if (debug) + logger.logf("GCPClient.detect: non-200 endpoint=%s status=%d durationMs=%d", ep, + resp.statusCode(), durMs); + return new IllegalStateException("metadata status=" + resp.statusCode()); + } + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + if (debug) + logger.logf("GCPClient.detect: interrupted endpoint=%s err=%s", ep, ie.getMessage()); + return ie; + } catch (IOException ioe) { + if (debug) + logger.logf("GCPClient.detect: io error endpoint=%s err=%s class=%s", ep, ioe.getMessage(), + ioe.getClass().getSimpleName()); + return ioe; + } catch (Exception e) { + if (debug) + logger.logf("GCPClient.detect: other error endpoint=%s err=%s class=%s", ep, e.getMessage(), + e.getClass().getSimpleName()); + return e; + } + } + @Override + public Exception fastDetect() { + String prop = System.getProperty("s2iam.test.gcpFast", ""); + if (!prop.isEmpty()) + return null; + String creds = System.getenv("GOOGLE_APPLICATION_CREDENTIALS"); + if (creds != null && creds.endsWith(".json")) + return null; + String host = System.getenv("GCE_METADATA_HOST"); + if (host != null && !host.isEmpty()) + return null; + return new IllegalStateException("fast detect: not gcp"); + } + @Override + public IdentityHeadersResult getIdentityHeaders(Map additionalParams) { + String audience = additionalParams.getOrDefault("audience", "https://authsvc.singlestore.com/"); + if (audience.endsWith("/")) + audience = audience.substring(0, audience.length() - 1); + HttpClient client = HttpClient.newBuilder().connectTimeout(Timeouts.IDENTITY).build(); + try { + String url = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity?audience=" + + audience + "&format=full"; + HttpRequest req = HttpRequest.newBuilder(URI.create(url)).header("Metadata-Flavor", "Google") + .timeout(Timeouts.IDENTITY).GET().build(); + HttpResponse resp = client.send(req, HttpResponse.BodyHandlers.ofString()); + if (resp.statusCode() != 200 || resp.body().isEmpty()) { + if (resp.statusCode() == 404) { + return new IdentityHeadersResult(null, null, + new IllegalStateException("gcp-no-role-identity-unavailable-404")); + } + return new IdentityHeadersResult(null, null, new IllegalStateException( + "failed to get GCP identity token status=" + resp.statusCode())); + } + Map headers = new HashMap<>(); + String token = resp.body(); + headers.put("Authorization", "Bearer " + token); + CloudIdentity identity = parseGCPIdentity(token); + return new IdentityHeadersResult(headers, identity, null); + } catch (Exception e) { + return new IdentityHeadersResult(null, null, e); + } + } + private CloudIdentity parseGCPIdentity(String jwt) { + try { + String[] parts = jwt.split("\\."); + if (parts.length < 2) { + return new CloudIdentity(CloudProviderType.gcp, "", "", "", "", Map.of()); + } + String jsonStr = new String(Base64.getUrlDecoder().decode(parts[1])); + JsonNode root = OM.readTree(jsonStr); + String sub = optText(root, "sub"); + String email = optText(root, "email"); + String identifier = sub == null ? "" : sub; + if (email != null && !email.isEmpty()) { + String emailVerified = optText(root, "email_verified"); + if ("true".equalsIgnoreCase(emailVerified)) { + identifier = email; + } + } + String resourceType = "instance"; + String region = ""; + JsonNode ce = root.get("google"); + if (ce == null) { + ce = root.get("compute_engine"); + } + if (ce != null && ce.isObject()) { + JsonNode zoneNode = ce.get("zone"); + if (zoneNode == null) { + JsonNode ce2 = root.get("compute_engine"); + if (ce2 != null && ce2.get("zone") != null) + zoneNode = ce2.get("zone"); + } + if (zoneNode != null && zoneNode.isTextual()) { + region = deriveRegionFromZone(zoneNode.asText()); + } + if (ce.get("instance_id") != null) + resourceType = "instance"; + } else { + int idx = jsonStr.indexOf("\"zone\":\""); + if (idx >= 0) { + int s = idx + 8; + int e = jsonStr.indexOf('"', s); + if (e > s) { + region = deriveRegionFromZone(jsonStr.substring(s, e)); + } + } + } + Map extra = new HashMap<>(); + for (String key : new String[]{"sub", "email", "aud", "iss", "azp", "kid", "project_number", + "project_id"}) { + String v = optText(root, key); + if (v != null && !v.isEmpty()) + extra.put(key, v); + } + JsonNode ceNode = root.get("google"); + if (ceNode != null && ceNode.get("compute_engine") != null) + ceNode = ceNode.get("compute_engine"); + if (ceNode == null) + ceNode = root.get("compute_engine"); + if (ceNode != null && ceNode.isObject()) { + copyIfText(extra, ceNode, "instance_id"); + copyIfText(extra, ceNode, "project_id"); + copyIfText(extra, ceNode, "zone"); + } + return new CloudIdentity(CloudProviderType.gcp, identifier, sub == null ? identifier : sub, + region, resourceType, extra); + } catch (Exception e) { + return new CloudIdentity(CloudProviderType.gcp, "", "", "", "", Map.of()); + } + } + private static String optText(JsonNode n, String field) { + JsonNode c = n.get(field); + return c != null && !c.isNull() ? c.asText() : null; + } + private static void copyIfText(Map dest, JsonNode node, String field) { + JsonNode v = node.get(field); + if (v != null && v.isTextual()) + dest.put(field, v.asText()); + } + private static String deriveRegionFromZone(String zoneVal) { + String zone = zoneVal; + if (zone.contains("/")) { + String[] parts = zone.split("/"); + zone = parts[parts.length - 1]; + } + String[] segs = zone.split("-"); + if (segs.length >= 3) { + return String.join("-", java.util.Arrays.copyOf(segs, segs.length - 1)); + } + return ""; + } + private static final ObjectMapper OM = new ObjectMapper(); +} diff --git a/java/src/test/java/com/singlestore/s2iam/FastPathDetectionTest.java b/java/src/test/java/com/singlestore/s2iam/FastPathDetectionTest.java new file mode 100644 index 0000000..90e86c9 --- /dev/null +++ b/java/src/test/java/com/singlestore/s2iam/FastPathDetectionTest.java @@ -0,0 +1,66 @@ +package com.singlestore.s2iam; + +import static org.junit.jupiter.api.Assertions.*; + +import com.singlestore.s2iam.exceptions.NoCloudProviderDetectedException; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; + +/** Parity fast-path detection tests (local only, skipped on real cloud). */ +public class FastPathDetectionTest { + + private boolean isCloudEnv() { + return env("S2IAM_TEST_CLOUD_PROVIDER") || env("S2IAM_TEST_CLOUD_PROVIDER_NO_ROLE") + || env("S2IAM_TEST_ASSUME_ROLE"); + } + + private boolean env(String k) { + return System.getenv(k) != null && !System.getenv(k).isEmpty(); + } + + @Test + void fastPathAWSViaEnv() throws Exception { + Assumptions.assumeFalse(isCloudEnv(), "local-only fast path test"); + // simulate fast path via dedicated system property hook + System.setProperty("s2iam.test.awsFast", "true"); + try { + CloudProviderClient c = S2IAM.detectProvider(); + assertEquals("aws", c.getType().name()); + } catch (NoCloudProviderDetectedException e) { + fail("expected fast path AWS detection"); + } finally { + System.clearProperty("s2iam.test.awsFast"); + } + } + + @Test + void fastPathGCPViaCredentials() throws Exception { + Assumptions.assumeFalse(isCloudEnv(), "local-only fast path test"); + System.setProperty("s2iam.test.gcpFast", "true"); + try { + CloudProviderClient c = S2IAM.detectProvider(); + assertEquals("gcp", c.getType().name()); + } catch (NoCloudProviderDetectedException e) { + fail("expected fast path GCP detection"); + } finally { + System.clearProperty("s2iam.test.gcpFast"); + } + } + + @Test + void fastPathAzureViaFederatedToken() throws Exception { + Assumptions.assumeFalse(isCloudEnv(), "local-only fast path test"); + System.setProperty("s2iam.test.azureFast", "true"); + try { + CloudProviderClient c = S2IAM.detectProvider(); + assertEquals("azure", c.getType().name()); + } catch (NoCloudProviderDetectedException e) { + fail("expected fast path Azure detection"); + } finally { + System.clearProperty("s2iam.test.azureFast"); + } + } + + // Local fast-path tests rely only on dedicated system properties + // (s2iam.test.*Fast). +} diff --git a/java/src/test/java/com/singlestore/s2iam/GoTestServer.java b/java/src/test/java/com/singlestore/s2iam/GoTestServer.java new file mode 100644 index 0000000..6f86c9e --- /dev/null +++ b/java/src/test/java/com/singlestore/s2iam/GoTestServer.java @@ -0,0 +1,213 @@ +package com.singlestore.s2iam; + +import java.io.*; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Duration; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import java.util.*; +import java.util.concurrent.TimeUnit; + +/** + * Lightweight manager to build and run the Go test server for integration tests + * (info-file based startup). + */ +class GoTestServer { + private Process process; + private int port = -1; + private final Path goDir; + private final List flags; + private Path infoFile; // path to JSON info file + private Map endpoints = new HashMap<>(); + + GoTestServer(Path repoRoot, String... extraFlags) { + this.goDir = resolveGoDir(repoRoot); + this.flags = new ArrayList<>(); + for (String f : extraFlags) + flags.add(f); + } + + private Path resolveGoDir(Path start) { + // If passed path has go/ then use it, else walk up a few levels + Path p = start.toAbsolutePath(); + for (int i = 0; i < 4; i++) { + if (Files.exists(p.resolve("go").resolve("go.mod"))) + return p.resolve("go"); + p = p.getParent(); + if (p == null) + break; + } + return start.resolve("go"); + } + + int getPort() { + return port; + } + + String getBaseURL() { + return "http://localhost:" + port; + } + + Map getEndpoints() { + return endpoints; + } + + void start() throws Exception { + if (process != null) + return; + if (!Files.exists(goDir.resolve("go.mod"))) { + throw new IllegalStateException("go dir missing go.mod: " + goDir); + } + // Build server binary with size-reducing flags (-s -w) and trimpath. Retry once + // if ENOSPC or + // cache corruption (.partial leftover) is detected. Avoid unconditional 'go + // clean' because it + // slows builds and can introduce transient races creating partially-downloaded + // modules when + // multiple servers build in quick succession. + IllegalStateException firstFailure = null; + try { + run(new ProcessBuilder("go", "build", "-trimpath", "-ldflags", "-s -w", "-o", + "s2iam_test_server", "./cmd/s2iam_test_server").directory(goDir.toFile())); + } catch (IllegalStateException e) { + firstFailure = e; + String msg = e.getMessage() == null ? "" : e.getMessage(); + if (msg.contains("no space left") || msg.contains(".partial")) { + // Targeted cleanup then force full rebuild of all packages + try { + run(new ProcessBuilder("go", "clean", "-cache", "-modcache").directory(goDir.toFile())); + } catch (Exception ignored) { + } + run(new ProcessBuilder("go", "build", "-a", "-trimpath", "-ldflags", "-s -w", "-o", + "s2iam_test_server", "./cmd/s2iam_test_server").directory(goDir.toFile())); + } else { + throw e; // Non-space issue: propagate immediately + } + } + if (!Files.exists(goDir.resolve("s2iam_test_server"))) { + // Provide context from first failure if available + if (firstFailure != null) + throw firstFailure; + throw new IllegalStateException("build failed - no binary (unknown reason)"); + } + // Prepare info file path inside goDir (avoids needing temp outside repo for + // simplicity) + // Use a unique temp info file per server instance to avoid cross-test + // contention + infoFile = Files.createTempFile(goDir, "s2iam_test_server_info", ".json"); + + List cmd = new ArrayList<>(); + cmd.add("./s2iam_test_server"); + cmd.add("--port"); + cmd.add("0"); + cmd.add("--info-file"); + cmd.add(infoFile.toString()); + cmd.add("--allowed-audiences"); + cmd.add("https://authsvc.singlestore.com,https://test.example.com"); + cmd.add("--timeout"); + cmd.add("2m"); + cmd.addAll(flags); + + ProcessBuilder pb = new ProcessBuilder(cmd).directory(goDir.toFile()); + // We intentionally drop stdout; errors to stderr for visibility in failures + pb.redirectOutput(ProcessBuilder.Redirect.DISCARD); + pb.redirectError(ProcessBuilder.Redirect.INHERIT); + process = pb.start(); + + // Poll info file for up to 10s (cloud VMs can be a bit slower, especially under + // load) + long deadline = System.currentTimeMillis() + Duration.ofSeconds(10).toMillis(); + Exception lastErr = null; + while (System.currentTimeMillis() < deadline) { + if (!process.isAlive()) { + throw new IllegalStateException("server exited early before writing info file"); + } + if (Files.exists(infoFile)) { + // Only attempt parse if file has non-zero size (avoid transient empty-file EOF) + try { + if (Files.size(infoFile) == 0) { + Thread.sleep(50); + continue; + } + } catch (IOException ignore) { + } + try { + parseInfo(); + if (port > 0) + return; // success + } catch (Exception e) { + String msg = e.getMessage() == null ? "" : e.getMessage(); + if (msg.contains("No content to map due to end-of-input")) { + // transient; ignore + } else { + lastErr = e; + } + } + } + Thread.sleep(100); + } + throw new IllegalStateException("timeout waiting for server info file: " + + (lastErr == null ? "unknown" : lastErr.getMessage())); + } + + void stop() { + if (process != null) { + process.destroy(); + try { + process.waitFor(1, TimeUnit.SECONDS); + } catch (InterruptedException ignored) { + Thread.currentThread().interrupt(); + } + if (process.isAlive()) + process.destroyForcibly(); + } + } + + private void run(ProcessBuilder pb) throws Exception { + // Capture both stdout and stderr so failures surface original command output + // (fail-fast rule) + pb.redirectErrorStream(true); + Process p = pb.start(); + ByteArrayOutputStream bout = new ByteArrayOutputStream(); + try (InputStream in = p.getInputStream()) { + byte[] buf = new byte[8192]; + int r; + while ((r = in.read(buf)) != -1) { + bout.write(buf, 0, r); + } + } + int code = p.waitFor(); + if (code != 0) { + String out = bout.toString(); + throw new IllegalStateException("command failed (" + code + "):\n" + out); + } + } + + private void parseInfo() throws IOException { + if (infoFile == null) + return; + ObjectMapper mapper = new ObjectMapper(); + try (Reader r = Files.newBufferedReader(infoFile)) { + InfoFile info = mapper.readValue(r, InfoFile.class); + if (info != null && info.server_info != null) { + this.port = info.server_info.port; + if (info.server_info.endpoints != null) { + this.endpoints.clear(); + this.endpoints.putAll(info.server_info.endpoints); + } + } + } + } + + // POJOs matching test server info-file structure + @JsonIgnoreProperties(ignoreUnknown = true) + private static class InfoFile { + public ServerInfo server_info; // snake case matches JSON + } + @JsonIgnoreProperties(ignoreUnknown = true) + private static class ServerInfo { + public int port; + public java.util.Map endpoints; + } +} diff --git a/java/src/test/java/com/singlestore/s2iam/LocalMockProvider.java b/java/src/test/java/com/singlestore/s2iam/LocalMockProvider.java new file mode 100644 index 0000000..116de2e --- /dev/null +++ b/java/src/test/java/com/singlestore/s2iam/LocalMockProvider.java @@ -0,0 +1,40 @@ +package com.singlestore.s2iam; + +import java.util.Map; + +/** + * Local mock provider used only for Java integration tests with Go server when + * no cloud metadata present. + */ +class LocalMockProvider implements CloudProviderClient { + @Override + public Exception detect() { + return null; + } + + @Override + public Exception fastDetect() { + return null; + } + + @Override + public CloudProviderType getType() { + return CloudProviderType.aws; + } + + @Override + public CloudProviderClient assumeRole(String roleIdentifier) { + return this; + } + + @Override + public IdentityHeadersResult getIdentityHeaders(Map additionalParams) { + // Provide minimal headers that Go verifier will accept for AWS path: access key + // + secret. + Map h = Map.of("X-AWS-Access-Key-ID", "TESTACCESSKEY", + "X-AWS-Secret-Access-Key", "TESTSECRET"); + CloudIdentity id = new CloudIdentity(CloudProviderType.aws, + "arn:aws:iam::000000000000:user/Test", "000000000000", "us-east-1", "iam-user", Map.of()); + return new IdentityHeadersResult(h, id, null); + } +} diff --git a/java/src/test/java/com/singlestore/s2iam/S2IAMAssumeRoleTest.java b/java/src/test/java/com/singlestore/s2iam/S2IAMAssumeRoleTest.java new file mode 100644 index 0000000..02d9873 --- /dev/null +++ b/java/src/test/java/com/singlestore/s2iam/S2IAMAssumeRoleTest.java @@ -0,0 +1,41 @@ +package com.singlestore.s2iam; + +import static org.junit.jupiter.api.Assertions.*; + +import com.singlestore.s2iam.exceptions.NoCloudProviderDetectedException; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; + +/** + * When S2IAM_TEST_ASSUME_ROLE is set we must detect the cloud provider and + * successfully assume the specified role (fail-fast; do not skip). This + * validates the direct assumeRole() path (identity headers only) separate from + * JWT issuance tests. + */ +public class S2IAMAssumeRoleTest { + + @Test + void testAssumeRoleMustSucceedWhenEnvSet() { + String roleArn = System.getenv("S2IAM_TEST_ASSUME_ROLE"); + if (roleArn == null || roleArn.isEmpty()) { + Assumptions.abort("S2IAM_TEST_ASSUME_ROLE not set - skipping assumeRole test"); + } + CloudProviderClient base; + try { + base = S2IAM.detectProvider(); + } catch (NoCloudProviderDetectedException e) { + fail("Expected provider detection to succeed when S2IAM_TEST_ASSUME_ROLE set: " + + e.getMessage()); + return; // unreachable + } + assertNotNull(base, "provider must be detected"); + CloudProviderClient assumed = base.assumeRole(roleArn); + assertNotNull(assumed, "assumeRole returned null client"); + CloudProviderClient.IdentityHeadersResult res = assumed.getIdentityHeaders(java.util.Map.of()); + assertNull(res.error, "assumeRole identity retrieval failed: " + + (res.error == null ? "" : res.error.getMessage())); + assertNotNull(res.identity, "identity missing after assumeRole"); + assertEquals(roleArn, res.identity.getIdentifier(), + "identity identifier should match requested role ARN"); + } +} diff --git a/java/src/test/java/com/singlestore/s2iam/S2IAMDetectionTest.java b/java/src/test/java/com/singlestore/s2iam/S2IAMDetectionTest.java new file mode 100644 index 0000000..f1417df --- /dev/null +++ b/java/src/test/java/com/singlestore/s2iam/S2IAMDetectionTest.java @@ -0,0 +1,38 @@ +package com.singlestore.s2iam; + +import static org.junit.jupiter.api.Assertions.*; + +import com.singlestore.s2iam.exceptions.NoCloudProviderDetectedException; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; + +public class S2IAMDetectionTest { + + @Test + void testDetectionSkipOrFail() { + String expectProvider = System.getenv("S2IAM_TEST_CLOUD_PROVIDER"); + boolean assumeRole = false; + if (expectProvider == null) { + String role = System.getenv("S2IAM_TEST_ASSUME_ROLE"); + if (role != null && !role.isEmpty()) { + assumeRole = true; + expectProvider = "aws"; // assume role only currently supported for AWS + } + } + if (expectProvider == null) + expectProvider = System.getenv("S2IAM_TEST_CLOUD_PROVIDER_NO_ROLE"); + try { + CloudProviderClient client = S2IAM.detectProvider(); + assertNotNull(client, "provider should not be null when detected"); + if (expectProvider != null) { + assertEquals(expectProvider.toLowerCase(), client.getType().name().toLowerCase(), + "detected provider mismatch"); + } + } catch (NoCloudProviderDetectedException e) { + if (expectProvider != null) { + fail("Cloud provider detection failed - expected to detect provider in test environment"); + } + Assumptions.abort("No cloud provider detected - not running in cloud environment"); + } + } +} diff --git a/java/src/test/java/com/singlestore/s2iam/S2IAMIdentityShapeTest.java b/java/src/test/java/com/singlestore/s2iam/S2IAMIdentityShapeTest.java new file mode 100644 index 0000000..b0f6f8f --- /dev/null +++ b/java/src/test/java/com/singlestore/s2iam/S2IAMIdentityShapeTest.java @@ -0,0 +1,137 @@ +package com.singlestore.s2iam; + +import static org.junit.jupiter.api.Assertions.*; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.singlestore.s2iam.exceptions.NoCloudProviderDetectedException; +import com.singlestore.s2iam.options.JwtOption; +import com.singlestore.s2iam.options.Options; +import com.singlestore.s2iam.options.ServerUrlOption; +import java.net.URI; +import java.net.http.*; +import java.util.*; +import java.util.Base64; +import java.util.regex.Pattern; +import org.junit.jupiter.api.*; + +/** + * Parity: provider-specific identity shape assertions (mirrors logic embedded + * in Go happy path test). + */ +public class S2IAMIdentityShapeTest { + static GoTestServer server; + static final ObjectMapper M = new ObjectMapper(); + + @BeforeAll + static void start() throws Exception { + server = new GoTestServer(java.nio.file.Path.of(".").toAbsolutePath()); + server.start(); + } + + @AfterAll + static void stop() { + if (server != null) + server.stop(); + } + + @Test + void identityShape() throws Exception { + CloudProviderClient provider; + try { + provider = S2IAM.detectProvider(); + } catch (NoCloudProviderDetectedException e) { + Assumptions.abort("no cloud provider"); + return; + } + + if (System.getenv("S2IAM_TEST_CLOUD_PROVIDER_NO_ROLE") != null) { + if (provider.getType() == CloudProviderType.gcp) { + TestSkipUtil.skipIfNoRoleProbe(provider, + Map.of("audience", "https://authsvc.singlestore.com")); + } else { + TestSkipUtil.skipIfNoRoleProbe(provider); + } + } + + // Preflight identity retrieval so we can apply skip logic (NO_ROLE / Azure MI + // unavailable) + // before invoking S2IAM.getDatabaseJWT (which would otherwise fail with an MI + // 400 on + // generic Azure-hosted runners lacking managed identity). Mirrors logic in + // S2IAMJwtHappyPathTest. + Map addl = new HashMap<>(); + boolean expectCloud = System.getenv("S2IAM_TEST_CLOUD_PROVIDER") != null + || System.getenv("S2IAM_TEST_ASSUME_ROLE") != null + || System.getenv("S2IAM_TEST_CLOUD_PROVIDER_NO_ROLE") != null; + if (provider.getType() == CloudProviderType.gcp && expectCloud) { + addl.put("audience", "https://authsvc.singlestore.com"); + } + CloudProviderClient.IdentityHeadersResult preflight = provider.getIdentityHeaders(addl); + TestSkipUtil.skipIfNoRole(provider, preflight); + TestSkipUtil.skipIfAzureMIUnavailable(provider, preflight); + assertNull(preflight.error, "identity header retrieval failed: " + + (preflight.error == null ? "" : preflight.error.getMessage())); + + List opts = new ArrayList<>(); + opts.add(ServerUrlOption.of( + server.getEndpoints().getOrDefault("auth", server.getBaseURL() + "/auth/iam/:jwtType"))); + if (provider.getType() == CloudProviderType.gcp) + opts.add(Options.withAudience("https://authsvc.singlestore.com")); + String jwt = S2IAM.getDatabaseJWT("test-workspace", opts.toArray(new JwtOption[0])); + assertNotNull(jwt); + String sub = decodeSub(jwt); + JsonNode req = fetchLastRequest(); + JsonNode id = req.path("identity"); + String identifier = id.path("identifier").asText(); + assertEquals(identifier, sub, "sub must equal identifier"); + + switch (provider.getType()) { + case aws: + assertTrue(identifier.startsWith("arn:aws:"), "AWS identifier should be ARN"); + String accountID = id.path("accountID").asText(); + assertTrue(accountID.matches("[0-9]{12}"), "AWS accountID should be 12 digits"); + break; + case gcp: + assertTrue(Pattern.compile("^[A-Za-z0-9_-]+@[A-Za-z0-9_-]+\\.iam\\.gserviceaccount\\.com$") + .matcher(identifier).find(), "GCP identifier should be service account email"); + String accountIDG = id.path("accountID").asText(); + assertTrue(accountIDG.matches("[0-9]{10,}"), "GCP accountID numeric"); + break; + case azure: + String tenant = id.path("accountID").asText(); + assertTrue(identifier.matches("[0-9a-fA-F-]{32,36}"), "Azure principal ID GUID format"); + // Tenant may be empty in some flows; if present ensure GUID-looking format + if (!tenant.isEmpty()) { + assertTrue(tenant.matches("[0-9a-fA-F-]{32,36}"), "Azure tenant ID GUID format"); + } + break; + default : + fail("Unknown provider type"); + } + } + + private static JsonNode fetchLastRequest() throws Exception { + HttpClient c = HttpClient.newHttpClient(); + HttpResponse resp = c.send( + HttpRequest.newBuilder(URI.create(server.getBaseURL() + "/info/requests")).GET().build(), + HttpResponse.BodyHandlers.ofString()); + if (resp.statusCode() != 200) + return null; + JsonNode arr = M.readTree(resp.body()); + if (!arr.isArray() || arr.size() == 0) + return null; + return arr.get(arr.size() - 1); + } + + private static String decodeSub(String jwt) throws Exception { + String[] p = jwt.split("\\."); + if (p.length < 2) + return null; + String pay = p[1]; + int r = pay.length() % 4; + if (r > 0) + pay += "====".substring(r); + return M.readTree(Base64.getUrlDecoder().decode(pay)).path("sub").asText(); + } +} diff --git a/java/src/test/java/com/singlestore/s2iam/S2IAMJwtAssumeRoleTest.java b/java/src/test/java/com/singlestore/s2iam/S2IAMJwtAssumeRoleTest.java new file mode 100644 index 0000000..825c5cc --- /dev/null +++ b/java/src/test/java/com/singlestore/s2iam/S2IAMJwtAssumeRoleTest.java @@ -0,0 +1,106 @@ +package com.singlestore.s2iam; + +import static org.junit.jupiter.api.Assertions.*; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.singlestore.s2iam.exceptions.NoCloudProviderDetectedException; +import com.singlestore.s2iam.options.JwtOption; +import com.singlestore.s2iam.options.Options; +import com.singlestore.s2iam.options.ServerUrlOption; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.*; +import java.util.Base64; +import org.junit.jupiter.api.*; + +/** + * Validates assume-role changes identity in JWT and server log (mirrors Go + * test). + */ +public class S2IAMJwtAssumeRoleTest { + static GoTestServer server; + static final ObjectMapper M = new ObjectMapper(); + + @BeforeAll + static void start() throws Exception { + server = new GoTestServer(java.nio.file.Path.of(".").toAbsolutePath()); + server.start(); + } + + @AfterAll + static void stop() { + if (server != null) + server.stop(); + } + + @Test + void assumeRoleDatabaseJWT() throws Exception { + String role = System.getenv("S2IAM_TEST_ASSUME_ROLE"); + if (role == null || role.isEmpty()) + Assumptions.abort("S2IAM_TEST_ASSUME_ROLE not set"); + + CloudProviderClient base; + try { + base = S2IAM.detectProvider(); + } catch (NoCloudProviderDetectedException e) { + fail("Provider detection must succeed when S2IAM_TEST_ASSUME_ROLE set"); + return; + } + + // Original identity + JWT + List opts = new ArrayList<>(); + opts.add(ServerUrlOption.of( + server.getEndpoints().getOrDefault("auth", server.getBaseURL() + "/auth/iam/:jwtType"))); + if (base.getType() == CloudProviderType.gcp) + opts.add(Options.withAudience("https://authsvc.singlestore.com")); + String originalJwt = S2IAM.getDatabaseJWT("test-workspace", opts.toArray(new JwtOption[0])); + String originalSub = decodeSub(originalJwt); + JsonNode originalReq = fetchLastRequest(); + String originalIdentifier = originalReq.path("identity").path("identifier").asText(); + assertEquals(originalIdentifier, originalSub, "pre-assume sub mismatch"); + + // Assume role path + List assumeOpts = new ArrayList<>(opts); + assumeOpts.add(Options.withAssumeRole(role)); + String assumedJwt = S2IAM.getDatabaseJWT("test-workspace", + assumeOpts.toArray(new JwtOption[0])); + String assumedSub = decodeSub(assumedJwt); + JsonNode assumedReq = fetchLastRequest(); + String assumedIdentifier = assumedReq.path("identity").path("identifier").asText(); + + assertNotEquals(originalIdentifier, assumedIdentifier, + "identity should change when assuming role"); + assertEquals(assumedIdentifier, assumedSub, "assumed JWT sub mismatch"); + String roleNameFragment = role.contains("/") ? role.substring(role.lastIndexOf('/') + 1) : role; + assertTrue(assumedIdentifier.contains(roleNameFragment), + "assumed identifier should contain role fragment"); + } + + private static JsonNode fetchLastRequest() throws Exception { + HttpClient c = HttpClient.newHttpClient(); + HttpResponse resp = c.send( + HttpRequest.newBuilder(URI.create(server.getBaseURL() + "/info/requests")).GET().build(), + HttpResponse.BodyHandlers.ofString()); + if (resp.statusCode() != 200) + return null; + JsonNode arr = M.readTree(resp.body()); + if (!arr.isArray() || arr.size() == 0) + return null; + return arr.get(arr.size() - 1); + } + + private static String decodeSub(String jwt) throws Exception { + String[] parts = jwt.split("\\."); + if (parts.length < 2) + return null; + String payload = parts[1]; + int rem = payload.length() % 4; + if (rem > 0) + payload += "====".substring(rem); + byte[] dec = Base64.getUrlDecoder().decode(payload); + return M.readTree(dec).path("sub").asText(); + } +} diff --git a/java/src/test/java/com/singlestore/s2iam/S2IAMJwtErrorCasesTest.java b/java/src/test/java/com/singlestore/s2iam/S2IAMJwtErrorCasesTest.java new file mode 100644 index 0000000..d3e1fdb --- /dev/null +++ b/java/src/test/java/com/singlestore/s2iam/S2IAMJwtErrorCasesTest.java @@ -0,0 +1,95 @@ +package com.singlestore.s2iam; + +import static org.junit.jupiter.api.Assertions.*; + +import com.singlestore.s2iam.exceptions.NoCloudProviderDetectedException; +import com.singlestore.s2iam.exceptions.S2IAMException; +import com.singlestore.s2iam.options.ServerUrlOption; +import java.nio.file.Path; +import org.junit.jupiter.api.*; + +public class S2IAMJwtErrorCasesTest { + GoTestServer base; + + @BeforeEach + void start() throws Exception { + base = new GoTestServer(Path.of(".").toAbsolutePath()); + base.start(); + } + + @AfterEach + void stop() { + if (base != null) + base.stop(); + } + + private String url() { + return base.getEndpoints().getOrDefault("auth", base.getBaseURL() + "/auth/iam/:jwtType"); + } + + @Test + void serverReturnsEmptyJWT() throws Exception { + assumeOrSkip(); + CloudProviderClient provider = S2IAM.detectProvider(); + if (System.getenv("S2IAM_TEST_CLOUD_PROVIDER_NO_ROLE") != null) { + if (provider.getType() == CloudProviderType.gcp) { + TestSkipUtil.skipIfNoRoleProbe(provider, + java.util.Map.of("audience", "https://authsvc.singlestore.com")); + } else { + TestSkipUtil.skipIfNoRoleProbe(provider); + } + } + // For generic non-cloud CI runners (Azure host w/o MI) abort identity-bearing + // tests. + TestSkipUtil.skipIfAzureMIUnavailable(provider, + provider.getIdentityHeaders(java.util.Collections.emptyMap())); + // Start dedicated server with flag --return-empty-jwt + base.stop(); + base = new GoTestServer(Path.of(".").toAbsolutePath(), "-return-empty-jwt"); + base.start(); + S2IAMException ex = assertThrows(S2IAMException.class, + () -> S2IAM.getDatabaseJWT("wg", ServerUrlOption.of(url()))); + assertTrue(ex.getMessage().contains("empty")); + } + + @Test + void serverReturnsError() throws Exception { + assumeOrSkip(); + CloudProviderClient provider = S2IAM.detectProvider(); + if (System.getenv("S2IAM_TEST_CLOUD_PROVIDER_NO_ROLE") != null) { + if (provider.getType() == CloudProviderType.gcp) { + TestSkipUtil.skipIfNoRoleProbe(provider, + java.util.Map.of("audience", "https://authsvc.singlestore.com")); + } else { + TestSkipUtil.skipIfNoRoleProbe(provider); + } + } + TestSkipUtil.skipIfAzureMIUnavailable(provider, + provider.getIdentityHeaders(java.util.Collections.emptyMap())); + base.stop(); + base = new GoTestServer(Path.of(".").toAbsolutePath(), "-return-error", "-error-code", "500"); + base.start(); + S2IAMException ex = assertThrows(S2IAMException.class, + () -> S2IAM.getAPIJWT(ServerUrlOption.of(url()))); + assertTrue(ex.getMessage().contains("500")); + } + + private void assumeOrSkip() throws Exception { + boolean expectCloud = expectCloud(); + try { + S2IAM.detectProvider(); + } catch (NoCloudProviderDetectedException e) { + if (expectCloud) { + fail("Cloud provider detection failed - expected to detect provider in test environment"); + } else { + Assumptions.abort("No cloud provider detected - skipping"); + } + } + } + + private boolean expectCloud() { + return System.getenv("S2IAM_TEST_CLOUD_PROVIDER") != null + || System.getenv("S2IAM_TEST_ASSUME_ROLE") != null + || System.getenv("S2IAM_TEST_CLOUD_PROVIDER_NO_ROLE") != null; + } +} diff --git a/java/src/test/java/com/singlestore/s2iam/S2IAMJwtHappyPathTest.java b/java/src/test/java/com/singlestore/s2iam/S2IAMJwtHappyPathTest.java new file mode 100644 index 0000000..3a28ab6 --- /dev/null +++ b/java/src/test/java/com/singlestore/s2iam/S2IAMJwtHappyPathTest.java @@ -0,0 +1,185 @@ +package com.singlestore.s2iam; + +import static org.junit.jupiter.api.Assertions.*; + +import com.fasterxml.jackson.databind.*; +import com.singlestore.s2iam.exceptions.NoCloudProviderDetectedException; +import com.singlestore.s2iam.exceptions.S2IAMException; +import com.singlestore.s2iam.options.JwtOption; +import com.singlestore.s2iam.options.Options; // for withAudience +import com.singlestore.s2iam.options.ServerUrlOption; +import java.net.*; +import java.net.http.*; +import java.nio.file.Path; +import java.util.*; +import java.util.Base64; +import org.junit.jupiter.api.*; + +/** Happy path JWT acquisition tests via Go test server. */ +public class S2IAMJwtHappyPathTest { + static GoTestServer server; + static final ObjectMapper M = new ObjectMapper(); + + @BeforeAll + static void startServer() throws Exception { + // Skip entirely if running in cloud provider detection only environments + // lacking local build + // tools? assume go present. + Path here = Path.of(".").toAbsolutePath(); + server = new GoTestServer(here); + server.start(); + } + + @AfterAll + static void stopServer() { + if (server != null) + server.stop(); + } + + @Test + void getDatabaseJWT() throws Exception { + assumeOrSkip(); + CloudProviderClient provider = S2IAM.detectProvider(); + Map addl = new HashMap<>(); + boolean realCloud = expectCloud(); + if (provider.getType() == CloudProviderType.gcp && realCloud) { + addl.put("audience", "https://authsvc.singlestore.com"); + } + CloudProviderClient.IdentityHeadersResult idRes = provider.getIdentityHeaders(addl); + TestSkipUtil.skipIfNoRole(provider, idRes); + TestSkipUtil.skipIfAzureMIUnavailable(provider, idRes); + assertNull(idRes.error, "identity header retrieval failed: " + + (idRes.error == null ? "" : idRes.error.getMessage())); + CloudIdentity cid = idRes.identity; + assertNotNull(cid, "client identity null"); + + java.util.List opts = new java.util.ArrayList<>(); + // Use dynamic auth endpoint from server info (endpoints map) which already + // includes :jwtType + opts.add(ServerUrlOption.of( + server.getEndpoints().getOrDefault("auth", server.getBaseURL() + "/auth/iam/:jwtType"))); + if (provider.getType() == CloudProviderType.gcp && realCloud) + opts.add(Options.withAudience("https://authsvc.singlestore.com")); + String jwt = S2IAM.getDatabaseJWT("wg-test", opts.toArray(new JwtOption[0])); + assertNotNull(jwt); + assertFalse(jwt.isEmpty()); + assertTrue(jwt.split("\\.").length >= 2, "looks like a JWT"); + + // Fetch server request log and verify identity parity + JsonNode lastReq = fetchLastRequest(); + assertNotNull(lastReq, "server request log empty"); + JsonNode identity = lastReq.path("identity"); + assertEquals(cid.getIdentifier(), identity.path("identifier").asText(), + "client/server identifier mismatch"); + assertEquals(cid.getProvider().name(), identity.path("provider").asText()); + + // Decode JWT payload (no signature verification – parity check for 'sub') + String sub = decodeSub(jwt); + assertEquals(cid.getIdentifier(), sub, "JWT sub mismatch"); + } + + @Test + void getAPIJWT() throws Exception { + assumeOrSkip(); + CloudProviderClient provider = S2IAM.detectProvider(); + Map addl = new HashMap<>(); + boolean realCloud = expectCloud(); + if (provider.getType() == CloudProviderType.gcp && realCloud) { + addl.put("audience", "https://authsvc.singlestore.com"); + } + CloudProviderClient.IdentityHeadersResult idRes = provider.getIdentityHeaders(addl); + TestSkipUtil.skipIfNoRole(provider, idRes); + TestSkipUtil.skipIfAzureMIUnavailable(provider, idRes); + assertNull(idRes.error); + CloudIdentity cid = idRes.identity; + java.util.List opts = new java.util.ArrayList<>(); + opts.add(ServerUrlOption.of( + server.getEndpoints().getOrDefault("auth", server.getBaseURL() + "/auth/iam/:jwtType"))); + if (provider.getType() == CloudProviderType.gcp && realCloud) + opts.add(Options.withAudience("https://authsvc.singlestore.com")); + String jwt = S2IAM.getAPIJWT(opts.toArray(new JwtOption[0])); + assertNotNull(jwt); + assertFalse(jwt.isEmpty()); + JsonNode lastReq = fetchLastRequest(); + assertEquals(cid.getIdentifier(), lastReq.path("identity").path("identifier").asText()); + assertEquals(cid.getProvider().name(), lastReq.path("identity").path("provider").asText()); + assertEquals(cid.getIdentifier(), decodeSub(jwt)); + } + + @Test + void getDatabaseJWT_GcpAudienceCustomLocal() throws Exception { + assumeOrSkip(); + CloudProviderClient provider = S2IAM.detectProvider(); + if (provider.getType() != CloudProviderType.gcp) { + Assumptions.abort("not GCP"); + } + if (System.getenv("S2IAM_TEST_CLOUD_PROVIDER_NO_ROLE") != null) { + TestSkipUtil.skipIfNoRoleProbe(provider, + Map.of("audience", "https://authsvc.singlestore.com")); + } + String audience = expectCloud() + ? "https://authsvc.singlestore.com" + : "https://test.example.com"; + String jwt = S2IAM.getDatabaseJWT("wg-test", + ServerUrlOption.of( + server.getEndpoints().getOrDefault("auth", server.getBaseURL() + "/auth/iam/:jwtType")), + Options.withAudience(audience)); + assertNotNull(jwt); + assertFalse(jwt.isEmpty()); + } + + @Test + void missingWorkspaceGroupId() { + S2IAMException ex = assertThrows(S2IAMException.class, () -> S2IAM.getDatabaseJWT("", + ServerUrlOption.of(server.getBaseURL() + "/auth/iam/:jwtType"))); + assertTrue(ex.getMessage().contains("workspaceGroupId")); + } + + private void assumeOrSkip() throws Exception { + boolean expectCloud = expectCloud(); + try { + // Quick detection attempt; if it fails and not expecting cloud, abort test via + // Assumptions + S2IAM.detectProvider(); + } catch (NoCloudProviderDetectedException e) { + if (expectCloud) { + fail("Cloud provider detection failed - expected to detect provider in test environment"); + } else { + Assumptions.abort("No cloud provider detected - skipping"); + } + } + } + + private boolean expectCloud() { + return System.getenv("S2IAM_TEST_CLOUD_PROVIDER") != null + || System.getenv("S2IAM_TEST_ASSUME_ROLE") != null + || System.getenv("S2IAM_TEST_CLOUD_PROVIDER_NO_ROLE") != null; + } + + private JsonNode fetchLastRequest() throws Exception { + String url = server.getBaseURL() + "/info/requests"; + HttpClient c = HttpClient.newHttpClient(); + HttpRequest r = HttpRequest.newBuilder(URI.create(url)).GET().build(); + HttpResponse resp = c.send(r, HttpResponse.BodyHandlers.ofString()); + if (resp.statusCode() != 200) + return null; + JsonNode arr = M.readTree(resp.body()); + if (!arr.isArray() || arr.size() == 0) + return null; + return arr.get(arr.size() - 1); + } + + private String decodeSub(String jwt) throws Exception { + String[] parts = jwt.split("\\."); + if (parts.length < 2) + return null; + String payload = parts[1]; + // Pad base64url if needed + int rem = payload.length() % 4; + if (rem > 0) + payload += "====".substring(rem); + byte[] decoded = Base64.getUrlDecoder().decode(payload); + JsonNode node = M.readTree(decoded); + return node.path("sub").asText(); + } +} diff --git a/java/src/test/java/com/singlestore/s2iam/S2IAMRequestBuilderTest.java b/java/src/test/java/com/singlestore/s2iam/S2IAMRequestBuilderTest.java new file mode 100644 index 0000000..41fd9df --- /dev/null +++ b/java/src/test/java/com/singlestore/s2iam/S2IAMRequestBuilderTest.java @@ -0,0 +1,149 @@ +package com.singlestore.s2iam; + +import static org.junit.jupiter.api.Assertions.*; + +import com.singlestore.s2iam.options.Options; +import com.singlestore.s2iam.exceptions.S2IAMException; +import com.singlestore.s2iam.exceptions.NoCloudProviderDetectedException; +import org.junit.jupiter.api.Assumptions; +import java.nio.file.Path; +import org.junit.jupiter.api.*; + +/** + * Basic tests for the S2IAMRequest fluent builder using the local Go test + * server. + */ +public class S2IAMRequestBuilderTest { + GoTestServer server; + + @BeforeEach + void start() throws Exception { + server = new GoTestServer(Path.of(".").toAbsolutePath()); + server.start(); + } + + @AfterEach + void stop() { + if (server != null) + server.stop(); + } + + private String url() { + return server.getEndpoints().getOrDefault("auth", server.getBaseURL() + "/auth/iam/:jwtType"); + } + + @Test + void databaseJwtViaBuilder() throws Exception { + CloudProviderClient provider = detectOrSkip(); + S2IAMRequest req = S2IAMRequest.newRequest().databaseWorkspaceGroup("wg-test").serverUrl(url()) + .timeout(java.time.Duration.ofSeconds(3)); + boolean realCloud = expectCloud(); + if (provider.getType() == CloudProviderType.gcp && realCloud) { + req.audience("https://authsvc.singlestore.com"); + } + // Inject provider to avoid second detection. + req.provider(provider); + String jwt = req.get(); + assertNotNull(jwt); + assertFalse(jwt.isEmpty()); + } + + @Test + void apiJwtViaBuilder() throws Exception { + CloudProviderClient provider = detectOrSkip(); + S2IAMRequest req = S2IAMRequest.newRequest().api().serverUrl(url()); + boolean realCloud = expectCloud(); + if (provider.getType() == CloudProviderType.gcp && realCloud) { + req.audience("https://authsvc.singlestore.com"); + } + req.provider(provider); + String jwt = req.get(); + assertNotNull(jwt); + assertFalse(jwt.isEmpty()); + } + + @Test + void missingWorkspaceGroupFails() { + S2IAMException ex = assertThrows(S2IAMException.class, () -> S2IAMRequest.newRequest().get()); + assertTrue(ex.getMessage().contains("workspace group id")); + } + + private CloudProviderClient detectOrSkip() throws Exception { + boolean expect = expectCloud(); + try { + CloudProviderClient p = S2IAM.detectProvider(); + if (!expect) { + Assumptions.abort("Cloud provider not explicitly requested - skipping"); + } + // NO_ROLE environments: skip early (identity tests would abort later anyway) + if (System.getenv("S2IAM_TEST_CLOUD_PROVIDER_NO_ROLE") != null + && System.getenv("S2IAM_TEST_CLOUD_PROVIDER") == null + && System.getenv("S2IAM_TEST_ASSUME_ROLE") == null) { + Assumptions.abort("No-role environment - skipping JWT builder tests"); + } + return p; + } catch (NoCloudProviderDetectedException e) { + if (expect) { + fail("Cloud provider detection failed - expected to detect provider in test environment"); + } + Assumptions.abort("No cloud provider detected - skipping"); + return null; // unreachable + } + } + + private boolean expectCloud() { + return System.getenv("S2IAM_TEST_CLOUD_PROVIDER") != null + || System.getenv("S2IAM_TEST_ASSUME_ROLE") != null + || System.getenv("S2IAM_TEST_CLOUD_PROVIDER_NO_ROLE") != null; + } + + @Test + void audience_only_allowed_for_gcp_builder_validation() { + S2IAMRequest r = S2IAMRequest.newRequest().api() + .provider(new FakeProviderLocal(CloudProviderType.aws)).audience("foo"); + assertThrows(S2IAMException.class, r::get, "Audience on non-GCP provider should error"); + } + + @Test + void audience_option_static_api_rejected_for_non_gcp_provider() { + FakeProviderLocal awsLike = new FakeProviderLocal(CloudProviderType.aws); + S2IAMException ex = assertThrows(S2IAMException.class, + () -> S2IAM.getAPIJWT(Options.withProvider(awsLike), Options.withAudience("notgcp"))); + assertTrue(ex.getMessage().toLowerCase().contains("gcp")); + } + + // Minimal local fake provider for audience validation tests (kept local so + // static + // analysis does not require cross-file lookup in test sources). + static class FakeProviderLocal implements CloudProviderClient { + private final CloudProviderType type; + FakeProviderLocal(CloudProviderType type) { + this.type = type; + } + @Override + public CloudProviderType getType() { + return type; + } + @Override + public CloudProviderClient assumeRole(String roleIdentifier) { + return this; + } + @Override + public Exception fastDetect() { + return null; + } + @Override + public Exception detect() { + return null; + } + @Override + public IdentityHeadersResult getIdentityHeaders( + java.util.Map additionalParams) { + java.util.Map headers = new java.util.HashMap<>(); + headers.put("X-Test", "ok"); + CloudIdentity id = new CloudIdentity(type, "local", null, null, null, + java.util.Collections.emptyMap()); + return new IdentityHeadersResult(headers, id, null); + } + } +} diff --git a/java/src/test/java/com/singlestore/s2iam/TestSkipUtil.java b/java/src/test/java/com/singlestore/s2iam/TestSkipUtil.java new file mode 100644 index 0000000..8937682 --- /dev/null +++ b/java/src/test/java/com/singlestore/s2iam/TestSkipUtil.java @@ -0,0 +1,95 @@ +package com.singlestore.s2iam; + +import java.util.Collections; +import java.util.Map; +import org.junit.jupiter.api.Assumptions; + +/** + * Centralized helper for skipping tests on *NO_ROLE* hosts where cloud + * identity/credentials are intentionally unavailable. Mirrors Go/Python + * semantics: when S2IAM_TEST_CLOUD_PROVIDER_NO_ROLE is set, identity retrieval + * failures for the designated provider are expected and tests that depend on a + * working identity should be skipped (aborted) rather than failed. + */ +final class TestSkipUtil { + private TestSkipUtil() { + } + + private static final String ENV_NO_ROLE = "S2IAM_TEST_CLOUD_PROVIDER_NO_ROLE"; + private static final String ENV_EXPECT = "S2IAM_TEST_CLOUD_PROVIDER"; + private static final String ENV_ASSUME = "S2IAM_TEST_ASSUME_ROLE"; + + static void skipIfNoRole(CloudProviderClient provider, + CloudProviderClient.IdentityHeadersResult res) { + if (System.getenv(ENV_NO_ROLE) == null) + return; // Not a NO_ROLE run + if (res == null || res.error == null) + return; // Nothing to evaluate + String msg = String.valueOf(res.error.getMessage()); + switch (provider.getType()) { + case gcp: + if (containsAny(msg, "gcp-no-role-identity-unavailable-404", + "failed to get GCP identity token status=404")) { + Assumptions.abort("GCP NO_ROLE host: identity unavailable (expected)"); + } + break; + case aws: + if (containsAny(msg, "Unable to load credentials from any of the providers", + "Failed to load credentials from IMDS")) { + Assumptions.abort("AWS NO_ROLE host: credentials unavailable (expected)"); + } + break; + case azure: + if (containsAny(msg, "failed to get Azure MI token status=400")) { + Assumptions.abort("Azure NO_ROLE host: managed identity unavailable (expected)"); + } + break; + default : + // future providers: fall through + } + } + + /** + * Skip when running on an Azure host without managed identity (MI 400/403/404) + * in a job that is NOT explicitly a cloud test (no expectation env vars). This + * occurs on generic GitHub-hosted runners (Azure VM without MI). We treat this + * as equivalent to "no cloud provider detected" for identity-bearing tests. + * Real cloud test jobs always set one of the expectation env vars and therefore + * won't skip here; they should provision MI or use *_NO_ROLE env to trigger the + * other skip path. + */ + static void skipIfAzureMIUnavailable(CloudProviderClient provider, + CloudProviderClient.IdentityHeadersResult res) { + if (provider.getType() != CloudProviderType.azure) + return; + if (System.getenv(ENV_EXPECT) != null || System.getenv(ENV_ASSUME) != null + || System.getenv(ENV_NO_ROLE) != null) { + return; // Cloud test run; let normal logic handle failures / skips + } + if (res == null || res.error == null) + return; + String msg = String.valueOf(res.error.getMessage()); + if (containsAny(msg, "failed to get Azure MI token status=400", + "failed to get Azure MI token status=403", "failed to get Azure MI token status=404")) { + Assumptions.abort("Azure MI unavailable on shared runner (treat as no identity)"); + } + } + + static void skipIfNoRoleProbe(CloudProviderClient provider) { + skipIfNoRoleProbe(provider, Collections.emptyMap()); + } + + static void skipIfNoRoleProbe(CloudProviderClient provider, Map params) { + if (System.getenv(ENV_NO_ROLE) == null) + return; + CloudProviderClient.IdentityHeadersResult res = provider.getIdentityHeaders(params); + skipIfNoRole(provider, res); + } + + private static boolean containsAny(String haystack, String... needles) { + for (String n : needles) + if (haystack.contains(n)) + return true; + return false; + } +} diff --git a/python/src/s2iam/api.py b/python/src/s2iam/api.py index 1a68230..a58ebd5 100644 --- a/python/src/s2iam/api.py +++ b/python/src/s2iam/api.py @@ -6,7 +6,8 @@ import os import queue import threading -from typing import Optional +import time +from typing import Any, Dict, List, NoReturn, Optional from .aws import new_client as new_aws_client from .azure import new_client as new_azure_client @@ -17,8 +18,13 @@ Logger, ) -DETECT_PROVIDER_DEFAULT_TIMEOUT: float = 5.0 -"""Default timeout (seconds) for provider detection (mirrors Go implementation).""" +DETECT_PROVIDER_DEFAULT_TIMEOUT: float = 10.0 +"""Default timeout (seconds) for provider detection. + +Rationale: Reliability over negative‑path speed. A larger ceiling avoids false +negatives on resource‑constrained CI VMs while early success short‑circuits so +real cloud latency remains low. Mirrors project policy decision (see PR notes). +""" class DefaultLogger: @@ -48,8 +54,12 @@ async def detect_provider( Raises: CloudProviderNotFound: If no provider can be detected """ - # Set up logger if debugging is enabled - if logger is None and os.environ.get("S2IAM_DEBUGGING") == "true": + # Set up logger only if explicit debugging flag is set; production code must not branch + # on test harness-only environment variables. Rich diagnostics are instead + # surfaced via aggregated exception messages below. + debugging = os.environ.get("S2IAM_DEBUGGING") == "true" + debug_timing = os.environ.get("S2IAM_DEBUG_TIMING") == "true" + if logger is None and debugging: logger = DefaultLogger() # Create default clients if none provided @@ -79,29 +89,67 @@ async def detect_provider( stop_event = threading.Event() all_errors: list[str] = [] errors_lock = threading.Lock() + # Structured per-provider status for enhanced error reporting. + # Each element: {provider, status=success|error|timeout|skipped, elapsed_ms?, error?} + provider_status: List[Dict[str, Any]] = [] + status_lock = threading.Lock() + + def record_status(entry: Dict[str, Any]) -> None: + with status_lock: + provider_status.append(entry) + + # Track per-thread event loops so we can cancel/stop them on global timeout + provider_loops: list[asyncio.AbstractEventLoop] = [] + loops_lock = threading.Lock() def test_provider_sync(client: CloudProviderClient) -> None: """Test a provider in a thread (like Go goroutine).""" if stop_event.is_set(): + record_status({"provider": client.get_type().value, "status": "skipped"}) return - + thread_start = time.monotonic() + if logger and (debugging or debug_timing): + logger.log(f"DETECT_THREAD_START provider={client.get_type().value} outer_timeout_s={timeout}") try: # Run the async detect() in this thread's event loop loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) + with loops_lock: + provider_loops.append(loop) try: - loop.run_until_complete(client.detect()) + detect_coro = client.detect() + # Run detect; if global stop_event is set while running, we attempt loop.stop() + loop.run_until_complete(detect_coro) # Success - put result in queue if we're first if not stop_event.is_set(): result_queue.put(client) stop_event.set() # Signal other threads to stop + elapsed_ms = int((time.monotonic() - thread_start) * 1000) + record_status( + { + "provider": client.get_type().value, + "status": "success", + "elapsed_ms": elapsed_ms, + } + ) + if logger and (debugging or debug_timing): + logger.log(f"DETECT_THREAD_SUCCESS provider={client.get_type().value} elapsed_ms={elapsed_ms}") finally: loop.close() except Exception as e: with errors_lock: all_errors.append(f"Provider {client.get_type().value} detection failed: {e}") - if logger: - logger.log(f"Provider {client.get_type().value} detection failed: {e}") + elapsed_ms = int((time.monotonic() - thread_start) * 1000) + record_status( + { + "provider": client.get_type().value, + "status": "error", + "elapsed_ms": elapsed_ms, + "error": str(e)[:400], + } + ) + if logger and (debugging or debug_timing): + logger.log(f"DETECT_THREAD_ERROR provider={client.get_type().value} elapsed_ms={elapsed_ms} error={e}") # Start threads for each provider (like Go goroutines) threads = [] @@ -112,17 +160,143 @@ def test_provider_sync(client: CloudProviderClient) -> None: threads.append(thread) # Wait for first result or timeout (like Go select) + detection_start = time.monotonic() + # Track how many threads have finished (success or error) to allow early exit when all done. + total_clients = len(clients) try: - result: CloudProviderClient = result_queue.get(timeout=timeout) + # Poll loop instead of single blocking get so we can detect early-failure condition. + remaining = timeout + interval = 0.05 # 50ms poll granularity + while remaining > 0: + start_poll = time.monotonic() + try: + result: CloudProviderClient = result_queue.get(timeout=min(interval, remaining)) + stop_event.set() # Ensure all threads stop + total_elapsed_ms = int((time.monotonic() - detection_start) * 1000) + if logger: + if debugging or debug_timing: + logger.log( + ( + "DETECT_COMPLETE status=success " + f"provider={result.get_type().value} " + f"total_elapsed_ms={total_elapsed_ms}" + ) + ) + else: + logger.log(f"Detected provider: {result.get_type().value}") + return result + except queue.Empty: + pass + # Early failure: if every thread has produced a terminal status (success/error/timeout/skipped) + with status_lock: + finished = sum( + 1 for ps in provider_status if ps["status"] in {"success", "error", "timeout", "skipped"} + ) + if finished >= total_clients and not any(ps["status"] == "success" for ps in provider_status): + # All threads ended without success -> raise immediately (no need to wait remaining timeout) + break + remaining -= time.monotonic() - start_poll + else: + # Loop ended naturally (remaining <= 0) without a success result + if logger and (debugging or debug_timing): + logger.log( + "DETECT_LOOP_COMPLETE no-success reason=timeout-before-result " + f"elapsed_ms={int((time.monotonic()-detection_start)*1000)}" + ) + + # No provider detected within timeout or all failed fast. stop_event.set() # Ensure all threads stop + except queue.Empty: # pragma: no cover - retained for compatibility; main loop handles logic + stop_event.set() + _raise_detection_timeout( + timeout=timeout, + detection_start=detection_start, + clients=clients, + threads=threads, + provider_status=provider_status, + status_lock=status_lock, + all_errors=all_errors, + logger=logger, + debugging=debugging, + debug_timing=debug_timing, + stop_event=stop_event, + provider_loops=provider_loops, + loops_lock=loops_lock, + ) - if logger: - logger.log(f"Detected provider: {result.get_type().value}") - return result + # If we broke out of loop without returning (early failure or timeout), raise composed error. + _raise_detection_timeout( + timeout=timeout, + detection_start=detection_start, + clients=clients, + threads=threads, + provider_status=provider_status, + status_lock=status_lock, + all_errors=all_errors, + logger=logger, + debugging=debugging, + debug_timing=debug_timing, + stop_event=stop_event, + provider_loops=provider_loops, + loops_lock=loops_lock, + ) - except queue.Empty: - # Timeout occurred; signal threads to stop and join briefly. - stop_event.set() - for thread in threads: - thread.join(timeout=0.05) - raise CloudProviderNotFound(f"Provider detection timed out after {timeout}s") + +def _raise_detection_timeout( + *, + timeout: float, + detection_start: float, + clients: list[CloudProviderClient], + threads: list[threading.Thread], + provider_status: list[dict[str, Any]], + status_lock: threading.Lock, + all_errors: list[str], + logger: Optional[Logger], + debugging: bool, + debug_timing: bool, + stop_event: threading.Event, + provider_loops: list[asyncio.AbstractEventLoop], + loops_lock: threading.Lock, +) -> NoReturn: + """Compose and raise CloudProviderNotFound for a detection timeout. + + Isolated to keep the main detect_provider flow skimmable and ease future + experiments (e.g., per-provider granular timeouts or retry policy integration). + """ + stop_event.set() + # Attempt to stop any active provider event loops to prevent post-timeout drift + with loops_lock: + for loop in provider_loops: + if loop.is_running(): + try: + loop.call_soon_threadsafe(loop.stop) + except Exception: # noqa: BLE001 - best effort cancellation + pass + for thread in threads: + thread.join(timeout=0.05) + total_elapsed_ms = int((time.monotonic() - detection_start) * 1000) + with status_lock: + known = {p["provider"] for p in provider_status} + for c in clients: + name = c.get_type().value + if name not in known: + provider_status.append({"provider": name, "status": "timeout"}) + if logger and (debugging or debug_timing): + joined_errors_dbg = " | ".join(all_errors)[:800] + logger.log( + ( + "DETECT_COMPLETE status=timeout " + f"total_elapsed_ms={total_elapsed_ms} timeout_s={timeout} " + f"errors='{joined_errors_dbg}'" + ) + ) + summary = ", ".join( + f"{ps['provider']}:{ps['status']}{('@'+str(ps['elapsed_ms'])+'ms') if 'elapsed_ms' in ps else ''}" + for ps in provider_status + ) + errors_str = " | ".join(all_errors)[:800] if all_errors else "" + raise CloudProviderNotFound( + "Provider detection timed out: " + f"timeout_s={timeout} total_elapsed_ms={total_elapsed_ms} providers={len(clients)} " + f"error_count={len(all_errors)} provider_status=[{summary}] errors=[{errors_str}]" + ) diff --git a/python/src/s2iam/azure/__init__.py b/python/src/s2iam/azure/__init__.py index 1ea5b8c..6e0a7bd 100644 --- a/python/src/s2iam/azure/__init__.py +++ b/python/src/s2iam/azure/__init__.py @@ -319,22 +319,87 @@ async def _extract_principal_id_and_claims(self, access_token: str) -> tuple[str return "unknown", {} async def _get_managed_identity_token(self, resource: str, client_id: Optional[str] = None) -> dict[str, str]: - """Get token from Azure managed identity endpoint.""" + """Get token from Azure managed identity endpoint with bounded exponential backoff. + + Retry policy rationale: + - 429 (throttling) and transient 5xx responses are retriable per Azure MSI guidance. + - 400/404 (e.g. identity not configured) are treated as hard failures (no retries) so we surface + absence quickly without extending detection / header acquisition latency. + - Network exceptions (connection reset, timeout) are treated as transient and retried. + - Default attempts: 6 (≈ < 3.2s worst-case added latency with 50ms base, capped delay 1.6s) mirroring + detect() logic. Environment overrides respected: S2IAM_AZURE_MI_RETRIES / S2IAM_AZURE_MI_BACKOFF_MS. + + This function is on the identity acquisition path (after successful provider detection) and + therefore can afford limited retries for robustness without materially impacting overall + provider classification speed (classification already done).""" url = "http://169.254.169.254/metadata/identity/oauth2/token" params = {"api-version": "2018-02-01", "resource": resource} - if client_id: params["client_id"] = client_id - async with aiohttp.ClientSession() as session: - async with session.get(url, params=params, headers={"Metadata": "true"}) as response: - if response.status == 200: - data = await response.json() - # Azure MI returns a JSON object with string fields like access_token, expires_in, etc. - return {str(k): str(v) for k, v in data.items() if isinstance(k, str)} - else: - text = await response.text() - raise Exception(f"Failed to get managed identity token: {response.status} - {text}") + # Read retry configuration (reuse detection env vars for consistency) + max_attempts_env = os.environ.get("S2IAM_AZURE_MI_RETRIES", "6") + base_backoff_ms_env = os.environ.get("S2IAM_AZURE_MI_BACKOFF_MS", "50") + try: + max_attempts = max(1, min(20, int(max_attempts_env))) + except ValueError: + max_attempts = 6 + try: + base_backoff_ms = max(1, min(5000, int(base_backoff_ms_env))) + except ValueError: + base_backoff_ms = 50 + + last_error: Optional[str] = None + for attempt in range(1, max_attempts + 1): + try: + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=2)) as session: + async with session.get(url, params=params, headers={"Metadata": "true"}) as response: + if response.status == 200: + data = await response.json() + token = {str(k): str(v) for k, v in data.items() if isinstance(k, str)} + if self._logger: + self._logger.log( + f"Azure: Managed identity token success (attempt {attempt}/{max_attempts})" + ) + return token + body_text = await response.text() + # Hard non-retriable statuses (identity absent / misconfiguration) + if response.status in (400, 404): + raise Exception( + f"Failed to get managed identity token: {response.status} - {body_text[:180]}" + ) + # Retriable statuses + if response.status == 429 or 500 <= response.status < 600: + last_error = f"status={response.status} body={body_text[:180]}" + if self._logger: + self._logger.log( + "Azure: Managed identity token transient error " + f"(attempt {attempt}/{max_attempts}) {last_error}" + ) + else: + # Non-retriable other status; surface immediately + raise Exception( + f"Failed to get managed identity token: {response.status} - {body_text[:180]}" + ) + except Exception as e: # noqa: BLE001 + # Network or other transient exception; decide to retry unless last attempt + last_error = f"exception={e}" + if attempt == max_attempts: + raise Exception(f"Failed to get managed identity token after retries: {last_error}") + if self._logger: + self._logger.log(f"Azure: Managed identity token exception (attempt {attempt}/{max_attempts}) {e}") + + # Backoff before next attempt if not returned / raised + if attempt < max_attempts: + delay = (base_backoff_ms / 1000.0) * (2 ** (attempt - 1)) + # Cap single delay to 1.6s here (shorter than detect() cap) to bound identity latency + delay = min(delay, 1.6) + await asyncio.sleep(delay) + + # If loop exits without returning, raise aggregated last error + raise Exception( + f"Failed to get managed identity token after {max_attempts} attempts: {last_error or 'unknown-error'}" + ) async def _get_instance_metadata(self) -> dict[str, Any]: """Get Azure instance metadata.""" diff --git a/python/src/s2iam/gcp/__init__.py b/python/src/s2iam/gcp/__init__.py index 2d273aa..07c68bd 100644 --- a/python/src/s2iam/gcp/__init__.py +++ b/python/src/s2iam/gcp/__init__.py @@ -1,6 +1,4 @@ -""" -Google Cloud Platform provider client implementation. -""" +"""Google Cloud Platform provider client implementation.""" import asyncio import os @@ -20,93 +18,132 @@ class GCPClient(CloudProviderClient): - """GCP implementation of CloudProviderClient.""" - - def __init__(self, logger: Optional[Logger] = None) -> None: + def __init__(self, logger: Optional[Logger] = None): self._logger = logger self._detected = False self._service_account_email: Optional[str] = None - self._identity: Optional[CloudIdentity] = None def _log(self, message: str) -> None: - """Log a message if logger is available.""" if self._logger: self._logger.log(f"GCP: {message}") - async def detect(self) -> None: - """Detect if running on GCP (full phase).""" - self._log("Starting GCP detection (full phase)") + async def detect(self) -> None: # noqa: D401 if self._detected: return - - # IMPORTANT (future maintainer / future-me): Do NOT add retries here unless you can - # produce a reproducible, provider-side behavioral change that: (a) manifests as - # a transient failure on the very first metadata probe AND (b) becomes a success - # within <1s WITHOUT any configuration / environment change. Historical context: - # A GCP detection timeout once occurred and a retry loop was briefly added. Root - # cause analysis showed the failure was due to logic (raising before queue publish), - # not actual transient unavailability of the metadata endpoint. Adding retries - # masked the underlying bug and only injected latency + flakiness surface area. - # - # Why single attempt is correct for GCP: - # 1. GCP metadata service is either immediately reachable or definitively absent. - # (Contrast: Azure IMDS managed identity can 429/throttle legitimately, hence - # bounded exponential retry ONLY on Azure identity acquisition.) - # 2. Fast failing keeps cross‑provider race tight and test suite duration low. - # 3. Retries make real configuration errors (firewall / network namespace / wrong - # cloud) slower to surface and harder to differentiate from genuine detection. - # 4. Every added retry path previously obscured a logic bug rather than fixing a - # platform instability. - # - # If you believe you need a retry, first capture: - # - exact wall clock timings - # - packet trace or tcpdump showing SYN/SYN-ACK delay OR DNS resolution latency - # - evidence that a second attempt (without any delay you inserted) would have - # succeeded (e.g., manual immediate second curl succeeds while first failed) - # and document that evidence in a linked issue. Without that, DO NOT ADD RETRIES. - # - # This comment is intentionally dry and procedural to discourage casual edits. - # Removing it or ignoring its instructions without evidence is a regression. - # Single bounded metadata probe (no retry). GCP metadata is either reachable promptly - # or not present; retries add latency and can mask a real negative signal. - self._log("Trying metadata service (link-local IP, single attempt)") - - # Use link-local IP (169.254.169.254) directly to avoid DNS resolution stalls that - # previously caused a thread to hang beyond the orchestrator timeout (leading to an - # Empty queue and overall detection timeout). Single attempt with explicit total timeout. + # Single metadata probe (no retries). If this times out or errors, treat as + # definitive negative (fast fail mirrors Go implementation). + self._log("Metadata probe (single attempt, link-local IP)") loop = asyncio.get_event_loop() start = loop.time() metadata_url = "http://169.254.169.254/computeMetadata/v1/instance/id" + # Single attempt wall clock budget. Increased to 10s (was 3s) to favor + # reliability on constrained CI VMs; early success returns immediately + # so typical latency stays low. + per_attempt_timeout = 10 + debugging = os.environ.get("S2IAM_DEBUGGING") == "true" + env_hint = "GCE_METADATA_HOST=set" if os.environ.get("GCE_METADATA_HOST") else "GCE_METADATA_HOST=unset" + cred_hint = ( + "GOOGLE_APPLICATION_CREDENTIALS=external_account" + if ( + os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") + and os.path.isfile(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", "")) + ) + else "GOOGLE_APPLICATION_CREDENTIALS=unset_or_non_external" + ) + + def classify(msg: str) -> str: + lower = msg.lower() + if any(p in lower for p in ("name or service not known", "temporary failure", "not known")): + return "dns" + if any(p in lower for p in ("timed out", "timeout")): + return "timeout" + if any(p in lower for p in ("refused", "connection reset")): + return "connect" + return "other" + try: - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=3)) as session: - async with session.get(metadata_url, headers={"Metadata-Flavor": "Google"}) as response: # noqa: S310 - if response.status != 200: - raise Exception(f"metadata status {response.status}") + trace_configs = [] + # Enable trace collection either when debugging explicitly OR if we later hit a timeout + want_trace = debugging + tc: Optional[aiohttp.TraceConfig] = None + if want_trace: + tc = aiohttp.TraceConfig() + + from typing import Any as _Any + + async def _trace_dns_start( + session: aiohttp.ClientSession, context: _Any, params: _Any + ) -> None: # noqa: D401 + self._log("TRACE dns_start") + + async def _trace_dns_end( + session: aiohttp.ClientSession, context: _Any, params: _Any + ) -> None: # noqa: D401 + self._log("TRACE dns_end") + + async def _trace_conn_start( + session: aiohttp.ClientSession, context: _Any, params: _Any + ) -> None: # noqa: D401 + self._log("TRACE connection_create_start") + + async def _trace_conn_end( + session: aiohttp.ClientSession, context: _Any, params: _Any + ) -> None: # noqa: D401 + self._log("TRACE connection_create_end") + + async def _trace_request_start( + session: aiohttp.ClientSession, context: _Any, params: _Any + ) -> None: # noqa: D401 + self._log("TRACE request_start") + + async def _trace_request_end( + session: aiohttp.ClientSession, context: _Any, params: _Any + ) -> None: # noqa: D401 + self._log("TRACE request_end") + + tc.on_dns_resolvehost_start.append(_trace_dns_start) + tc.on_dns_resolvehost_end.append(_trace_dns_end) + tc.on_connection_create_start.append(_trace_conn_start) + tc.on_connection_create_end.append(_trace_conn_end) + tc.on_request_start.append(_trace_request_start) + tc.on_request_end.append(_trace_request_end) + trace_configs.append(tc) + + async def _probe() -> None: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=per_attempt_timeout), trace_configs=trace_configs or None + ) as session: + async with session.get( + metadata_url, headers={"Metadata-Flavor": "Google"} + ) as response: # noqa: S310 + if response.status != 200: + raise Exception(f"metadata status {response.status}") + + try: + await asyncio.wait_for(_probe(), timeout=per_attempt_timeout + 0.25) # small guard margin + except asyncio.TimeoutError as te: # normalize to TimeoutError for classification + # If we timed out but did not previously enable traces, we cannot retroactively + # gather aiohttp phase hooks. Emit a concise marker for diagnostics. + if not debugging: + self._log("TRACE timeout_without_phase_detail") + raise TimeoutError("metadata probe wait_for timeout") from te elapsed_ms = int((loop.time() - start) * 1000) - self._log(f"Detected GCP metadata (elapsed={elapsed_ms}ms)") + self._log(f"Detected metadata (elapsed={elapsed_ms}ms)") self._detected = True return except Exception as e: # noqa: BLE001 elapsed_ms = int((loop.time() - start) * 1000) msg = str(e) or type(e).__name__ - lower = msg.lower() - if any(p in lower for p in ("name or service not known", "temporary failure", "not known")): - category = "dns" - elif any(p in lower for p in ("timed out", "timeout")): - category = "timeout" - elif any(p in lower for p in ("refused", "connection reset")): - category = "connect" - else: - category = "other" - self._log(f"Metadata probe failed (elapsed={elapsed_ms}ms category={category}): {msg}") - raise Exception( - "Not running on GCP: metadata service unavailable (single attempt to 169.254.169.254 failed): " - + f"{msg}" + category = classify(msg) + over_timeout = elapsed_ms > (per_attempt_timeout * 1000 + 300) + diag = ( + "Not running on GCP: metadata probe failed; " + f"elapsed_ms={elapsed_ms} category={category} timeout_s={per_attempt_timeout} " + f"over_timeout_margin={over_timeout} env=[{env_hint} {cred_hint}] " + f"exception_type={type(e).__name__} detail={msg}" ) - - raise Exception( - "Not running on GCP: no environment variable, metadata service, or default credentials detected" - ) + self._log(diag) + raise Exception(diag) async def fast_detect(self) -> None: """Fast detection: env/file only, no network.""" diff --git a/python/tests/test_cloud_validation.py b/python/tests/test_cloud_validation.py index d1ea4b9..c8ab30e 100644 --- a/python/tests/test_cloud_validation.py +++ b/python/tests/test_cloud_validation.py @@ -10,18 +10,13 @@ import s2iam from s2iam import CloudProviderType, JWTType -from .test_server_utils import GoTestServerManager +from .test_server_utils import get_shared_server from .testhelp import expect_cloud_provider_detected, require_cloud_role, validate_identity_and_jwt @pytest.fixture(scope="session") def test_server(): - server = GoTestServerManager(timeout_minutes=5) - try: - server.start() - yield server - finally: - server.stop() + return get_shared_server() @pytest.mark.asyncio diff --git a/python/tests/test_fastpath.py b/python/tests/test_fastpath.py index 1d15253..6b079c1 100644 --- a/python/tests/test_fastpath.py +++ b/python/tests/test_fastpath.py @@ -13,7 +13,7 @@ import s2iam from s2iam import CloudProviderType -from .test_server_utils import GoTestServerManager +from .test_server_utils import get_shared_server from .testhelp import expect_cloud_provider_detected, validate_identity_and_jwt @@ -114,26 +114,22 @@ async def _test_equivalent_functionality(self, normal_provider, fastpath_provide assert normal_identity.region == fastpath_identity.region, "Both providers should extract same region" # End-result validation using shared helper (mirrors Go shared happy-path code) - server = GoTestServerManager(timeout_minutes=1) - try: - server.start() - # Use helper with fast-path provider - provider_type = normal_provider.get_type() - audience = "https://authsvc.singlestore.com" if provider_type == CloudProviderType.GCP else None - _, fast_identity, claims = await validate_identity_and_jwt( - fastpath_provider, - workspace_group_id="test-workspace", - server_url=f"{server.server_url}/auth/iam/database", - audience=audience, - ) - # Cross-check that fast-path identity matches normal detection identity on critical fields - assert fast_identity.identifier == normal_identity.identifier, "Identifier mismatch" - assert fast_identity.provider == normal_identity.provider, "Provider type mismatch" - assert fast_identity.account_id == normal_identity.account_id, "Account ID mismatch" - assert fast_identity.region == normal_identity.region, "Region mismatch" - print("✓ Fast-path validation: identity and JWT claims consistent with normal detection") - finally: - server.stop() + server = get_shared_server() + # Use helper with fast-path provider + provider_type = normal_provider.get_type() + audience = "https://authsvc.singlestore.com" if provider_type == CloudProviderType.GCP else None + _, fast_identity, claims = await validate_identity_and_jwt( + fastpath_provider, + workspace_group_id="test-workspace", + server_url=f"{server.server_url}/auth/iam/database", + audience=audience, + ) + # Cross-check that fast-path identity matches normal detection identity on critical fields + assert fast_identity.identifier == normal_identity.identifier, "Identifier mismatch" + assert fast_identity.provider == normal_identity.provider, "Provider type mismatch" + assert fast_identity.account_id == normal_identity.account_id, "Account ID mismatch" + assert fast_identity.region == normal_identity.region, "Region mismatch" + print("✓ Fast-path validation: identity and JWT claims consistent with normal detection") print("✓ Fast-path and normal detection produced equivalent results") diff --git a/python/tests/test_integration.py b/python/tests/test_integration.py index 5fe150b..a81621d 100644 --- a/python/tests/test_integration.py +++ b/python/tests/test_integration.py @@ -11,17 +11,13 @@ import s2iam from s2iam import CloudProviderType, JWTType -from .test_server_utils import GoTestServerManager +from .test_server_utils import get_shared_server from .testhelp import TEST_DETECT_TIMEOUT, expect_cloud_provider_detected, require_cloud_role @pytest.fixture(scope="session") def test_server(): - """Fixture to manage the test server lifecycle.""" - server = GoTestServerManager(timeout_minutes=5) # Auto-shutdown after 5 minutes, random port - server.start() - yield server - server.stop() + return get_shared_server() @pytest.mark.asyncio diff --git a/python/tests/test_server_utils.py b/python/tests/test_server_utils.py index 04b1dfa..9102a59 100644 --- a/python/tests/test_server_utils.py +++ b/python/tests/test_server_utils.py @@ -30,6 +30,9 @@ def __init__( self.go_dir = self._get_go_directory(go_dir) self.actual_port: Optional[int] = None + # NOTE: Manager is intentionally not thread-safe because pytest runs tests sequentially. + # The server lives for the whole test session (or until explicit stop); process exit handles cleanup. + def _get_go_directory(self, go_dir: Optional[str]) -> str: """Get the Go directory location for the two known test scenarios.""" if go_dir: @@ -104,6 +107,8 @@ def start(self) -> None: self.info_file = os.path.join(self.go_dir, "s2iam_test_server_info.json") # Prepare server command (request random port with 0 and info-file) + # Use -shutdown-on-stdin-close and keep stdin open for lifetime of process; closing stdin (via process exit) + # triggers graceful shutdown. We never explicitly terminate in tests. server_cmd = [ "./s2iam_test_server", "-port", @@ -169,32 +174,7 @@ def start(self) -> None: logger.debug("Test server started successfully on port %s", self.actual_port) - # Removed _read_server_port; info-file polling replaces stdout parsing - - def stop(self) -> None: - """Stop the Go test server.""" - if self.process: - self.process.terminate() - try: - self.process.wait(timeout=5) - except subprocess.TimeoutExpired: - self.process.kill() - self.process.wait() - self.process = None - - # Ensure stderr file is closed - if hasattr(self, "_stderr_file") and self._stderr_file and not self._stderr_file.closed: - try: - self._stderr_file.flush() - except Exception: - pass - try: - self._stderr_file.close() - except Exception: - pass - - # Show debug log contents if available - self.show_debug_log() + # Info-file polling supplies the dynamically chosen port; no stdout parsing helper needed. def show_debug_log(self) -> None: """Display the contents of the Go server debug log.""" @@ -222,6 +202,18 @@ def __enter__(self): self.start() return self - def __exit__(self, exc_type, exc_val, exc_tb): - """Context manager exit.""" - self.stop() + def __exit__(self, exc_type, exc_val, exc_tb): # No explicit shutdown; allow exceptions to propagate + return False + + +_shared_server: Optional[GoTestServerManager] = None + + +def get_shared_server() -> GoTestServerManager: + """Return a singleton shared test server (starts on first use).""" + global _shared_server + if _shared_server is None: + mgr = GoTestServerManager(timeout_minutes=5) + mgr.start() + _shared_server = mgr + return _shared_server diff --git a/python/tests/testhelp.py b/python/tests/testhelp.py index 7ef43dd..42ea603 100644 --- a/python/tests/testhelp.py +++ b/python/tests/testhelp.py @@ -8,6 +8,7 @@ import base64 import json import os +import time from typing import Any, Dict, Optional, Tuple import pytest @@ -35,11 +36,29 @@ async def expect_cloud_provider_detected(timeout: float = TEST_DETECT_TIMEOUT) - ): pytest.skip("cloud provider required") + start = time.monotonic() try: provider = await s2iam.detect_provider(timeout=timeout) return provider - except s2iam.CloudProviderNotFound: - pytest.fail("Cloud provider detection failed - expected to detect provider in test environment") + except s2iam.CloudProviderNotFound as first_err: + first_elapsed_ms = int((time.monotonic() - start) * 1000) + retry_timeout = max(1.0, timeout * 0.5) + retry_start = time.monotonic() + try: + retry_provider = await s2iam.detect_provider(timeout=retry_timeout) + retry_elapsed_ms = int((time.monotonic() - retry_start) * 1000) + pytest.fail( + "Cloud provider detection failed first attempt but second immediate attempt succeeded; " + f"first_elapsed_ms={first_elapsed_ms} retry_elapsed_ms={retry_elapsed_ms} " + f"primary_error={first_err} retry_timeout_s={retry_timeout} provider={retry_provider.get_type().value}" + ) + except s2iam.CloudProviderNotFound as second_err: + retry_elapsed_ms = int((time.monotonic() - retry_start) * 1000) + pytest.fail( + "Cloud provider detection failed twice; " + f"first_elapsed_ms={first_elapsed_ms} second_elapsed_ms={retry_elapsed_ms} " + f"first_error={first_err} second_error={second_err} retry_timeout_s={retry_timeout}" + ) except s2iam.ProviderIdentityUnavailable: pytest.skip("cloud provider detected no identity")