diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml index b532117c..5d068701 100644 --- a/arch/arch_config_schema.yaml +++ b/arch/arch_config_schema.yaml @@ -131,6 +131,10 @@ properties: enum: - GET - POST + http_headers: + type: object + additionalProperties: + type: string additionalProperties: false required: - name diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index fbafe7b9..f1250499 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -242,6 +242,7 @@ pub struct EndpointDetails { pub path: Option, #[serde(rename = "http_method")] pub method: Option, + pub http_headers: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 9782698e..98d230a5 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -317,27 +317,35 @@ impl StreamContext { }; let http_method = endpoint.method.unwrap_or_default().to_string(); - let mut headers = vec![ + let mut headers: HashMap<_, _> = [ (ARCH_UPSTREAM_HOST_HEADER, endpoint.name.as_str()), (":method", &http_method), (":path", &path), (":authority", endpoint.name.as_str()), ("content-type", "application/json"), ("x-envoy-max-retries", "3"), - ]; + ] + .into_iter() + .collect(); if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); + headers.insert(REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()); } if self.traceparent.is_some() { - headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap())); + headers.insert(TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()); + } + + // override http headers that are set in the prompt target + let http_headers = endpoint.http_headers.unwrap_or_default(); + for (key, value) in http_headers.iter() { + headers.insert(key.as_str(), value.as_str()); } let call_args = CallArgs::new( ARCH_INTERNAL_CLUSTER_NAME, &path, - headers, + headers.into_iter().collect(), Some(tool_params_json_str.as_bytes()), vec![], Duration::from_secs(5), diff --git a/demos/currency_exchange/arch_config.yaml b/demos/currency_exchange/arch_config.yaml index f8776c48..c27e1ff2 100644 --- a/demos/currency_exchange/arch_config.yaml +++ b/demos/currency_exchange/arch_config.yaml @@ -8,7 +8,7 @@ listener: llm_providers: - name: gpt-4o - access_key: $OPENAI_API_KEY + access_key: $OPENAI_API_KEY1 provider_interface: openai model: gpt-4o @@ -33,6 +33,8 @@ prompt_targets: endpoint: name: frankfurther_api path: /v1/latest?base=USD&symbols={currency_symbol} + http_headers: + Authorization: "Bearer $FRANKFURT_API_KEY" system_prompt: | You are a helpful assistant. Show me the currency symbol you want to convert from USD. @@ -41,11 +43,12 @@ prompt_targets: endpoint: name: frankfurther_api path: /v1/currencies + http_headers: + Authorization: "Bearer $FRANKFURT_API_KEY" endpoints: frankfurther_api: - endpoint: api.frankfurter.dev:443 - protocol: https + endpoint: host.docker.internal:1122 tracing: random_sampling: 100