diff --git a/attestation-agent/app/build.rs b/attestation-agent/app/build.rs index a070ee979..f4daf0488 100644 --- a/attestation-agent/app/build.rs +++ b/attestation-agent/app/build.rs @@ -9,9 +9,17 @@ use ttrpc_codegen::{Codegen, Customize, ProtobufCustomize}; fn main() -> std::io::Result<()> { #[cfg(feature = "grpc")] { - tonic_build::compile_protos("../protos/keyprovider.proto")?; - tonic_build::compile_protos("../protos/getresource.proto")?; - tonic_build::compile_protos("../protos/attestation-agent.proto")?; + tonic_build::configure() + .build_server(true) + .protoc_arg("--experimental_allow_proto3_optional") + .compile( + &[ + "../protos/keyprovider.proto", + "../protos/getresource.proto", + "../protos/attestation-agent.proto", + ], + &["../protos"], + )?; } #[cfg(feature = "ttrpc")] diff --git a/attestation-agent/app/src/rpc/attestation/mod.rs b/attestation-agent/app/src/rpc/attestation/mod.rs index b2fbf27ef..2e7e8e2a0 100644 --- a/attestation-agent/app/src/rpc/attestation/mod.rs +++ b/attestation-agent/app/src/rpc/attestation/mod.rs @@ -20,7 +20,10 @@ pub mod grpc { use attestation::attestation_agent_service_server::{ AttestationAgentService, AttestationAgentServiceServer, }; - use attestation::{GetEvidenceRequest, GetEvidenceResponse, GetTokenRequest, GetTokenResponse}; + use attestation::{ + ExtendRuntimeMeasurementRequest, ExtendRuntimeMeasurementResponse, GetEvidenceRequest, + GetEvidenceResponse, GetTokenRequest, GetTokenResponse, + }; use std::net::SocketAddr; use tonic::{transport::Server, Request, Response, Status}; @@ -84,6 +87,35 @@ pub mod grpc { Result::Ok(Response::new(reply)) } + + async fn extend_runtime_measurement( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + + let attestation_agent_mutex_clone = Arc::clone(&ASYNC_ATTESTATION_AGENT); + let mut attestation_agent = attestation_agent_mutex_clone.lock().await; + + debug!("Call AA to extend runtime measurement ..."); + + attestation_agent + .extend_runtime_measurement(request.events, request.register_index) + .await + .map_err(|e| { + error!("Call AA to extend runtime measurement failed: {}", e); + Status::internal(format!( + "[ERROR:{}] AA extend runtime measurement failed: {}", + AGENT_NAME, e + )) + })?; + + debug!("Extend runtime measurement successfully!"); + + let reply = ExtendRuntimeMeasurementResponse {}; + + Result::Ok(Response::new(reply)) + } } pub async fn start_grpc_service(socket: SocketAddr) -> Result<()> { @@ -176,6 +208,34 @@ pub mod ttrpc { ::ttrpc::Result::Ok(reply) } + + async fn extend_runtime_measurement( + &self, + _ctx: &::ttrpc::r#async::TtrpcContext, + req: attestation_agent::ExtendRuntimeMeasurementRequest, + ) -> ::ttrpc::Result { + debug!("Call AA to extend runtime measurement ..."); + + let attestation_agent_mutex_clone = ASYNC_ATTESTATION_AGENT.clone(); + let mut attestation_agent = attestation_agent_mutex_clone.lock().await; + + attestation_agent + .extend_runtime_measurement(req.Events, req.RegisterIndex) + .await + .map_err(|e| { + error!("Call AA to extend runtime measurement failed: {}", e); + let mut error_status = ::ttrpc::proto::Status::new(); + error_status.set_code(Code::INTERNAL); + error_status.set_message(format!( + "[ERROR:{}] AA extend runtime measurement failed: {}", + AGENT_NAME, e + )); + ::ttrpc::Error::RpcStatus(error_status) + })?; + + let reply = attestation_agent::ExtendRuntimeMeasurementResponse::new(); + ::ttrpc::Result::Ok(reply) + } } pub fn start_ttrpc_service() -> Result> { diff --git a/attestation-agent/attester/src/lib.rs b/attestation-agent/attester/src/lib.rs index b1ade1fb1..acb8f1c1a 100644 --- a/attestation-agent/attester/src/lib.rs +++ b/attestation-agent/attester/src/lib.rs @@ -59,6 +59,16 @@ pub trait Attester { /// The parameter `report_data` will be used as the user input of the /// evidence to avoid reply attack. async fn get_evidence(&self, report_data: Vec) -> Result; + + /// Extend TEE specific dynamic measurement register + /// to enable dynamic measurement capabilities for input data at runtime. + async fn extend_runtime_measurement( + &self, + _events: Vec>, + _register_index: Option, + ) -> Result<()> { + bail!("Unimplemented") + } } // Detect which TEE platform the KBC running environment is. diff --git a/attestation-agent/attester/src/tdx/mod.rs b/attestation-agent/attester/src/tdx/mod.rs index 63db5e991..b5b47e777 100644 --- a/attestation-agent/attester/src/tdx/mod.rs +++ b/attestation-agent/attester/src/tdx/mod.rs @@ -65,6 +65,28 @@ impl Attester for TdxAttester { serde_json::to_string(&evidence) .map_err(|e| anyhow!("Serialize TDX evidence failed: {:?}", e)) } + + async fn extend_runtime_measurement( + &self, + events: Vec>, + _register_index: Option, + ) -> Result<()> { + for event in events { + match tdx_attest_rs::tdx_att_extend(&event) { + tdx_attest_rs::tdx_attest_error_t::TDX_ATTEST_SUCCESS => { + log::debug!("TDX extend runtime measurement succeeded.") + } + error_code => { + bail!( + "TDX Attester: Failed to extend RTMR. Error code: {:?}", + error_code + ); + } + } + } + + Ok(()) + } } #[cfg(test)] diff --git a/attestation-agent/kbs_protocol/src/token_provider/aa/attestation_agent.rs b/attestation-agent/kbs_protocol/src/token_provider/aa/attestation_agent.rs index 1d27b2c7d..e256e663b 100644 --- a/attestation-agent/kbs_protocol/src/token_provider/aa/attestation_agent.rs +++ b/attestation-agent/kbs_protocol/src/token_provider/aa/attestation_agent.rs @@ -513,16 +513,265 @@ impl ::protobuf::reflect::ProtobufValue for GetTokenResponse { type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; } +#[derive(PartialEq,Clone,Default,Debug)] +// @@protoc_insertion_point(message:attestation_agent.ExtendRuntimeMeasurementRequest) +pub struct ExtendRuntimeMeasurementRequest { + // message fields + // @@protoc_insertion_point(field:attestation_agent.ExtendRuntimeMeasurementRequest.Events) + pub Events: ::std::vec::Vec<::std::vec::Vec>, + // @@protoc_insertion_point(field:attestation_agent.ExtendRuntimeMeasurementRequest.RegisterIndex) + pub RegisterIndex: ::std::option::Option, + // special fields + // @@protoc_insertion_point(special_field:attestation_agent.ExtendRuntimeMeasurementRequest.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ExtendRuntimeMeasurementRequest { + fn default() -> &'a ExtendRuntimeMeasurementRequest { + ::default_instance() + } +} + +impl ExtendRuntimeMeasurementRequest { + pub fn new() -> ExtendRuntimeMeasurementRequest { + ::std::default::Default::default() + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(2); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_vec_simpler_accessor::<_, _>( + "Events", + |m: &ExtendRuntimeMeasurementRequest| { &m.Events }, + |m: &mut ExtendRuntimeMeasurementRequest| { &mut m.Events }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "RegisterIndex", + |m: &ExtendRuntimeMeasurementRequest| { &m.RegisterIndex }, + |m: &mut ExtendRuntimeMeasurementRequest| { &mut m.RegisterIndex }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ExtendRuntimeMeasurementRequest", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ExtendRuntimeMeasurementRequest { + const NAME: &'static str = "ExtendRuntimeMeasurementRequest"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.Events.push(is.read_bytes()?); + }, + 16 => { + self.RegisterIndex = ::std::option::Option::Some(is.read_uint64()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + for value in &self.Events { + my_size += ::protobuf::rt::bytes_size(1, &value); + }; + if let Some(v) = self.RegisterIndex { + my_size += ::protobuf::rt::uint64_size(2, v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + for v in &self.Events { + os.write_bytes(1, &v)?; + }; + if let Some(v) = self.RegisterIndex { + os.write_uint64(2, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ExtendRuntimeMeasurementRequest { + ExtendRuntimeMeasurementRequest::new() + } + + fn clear(&mut self) { + self.Events.clear(); + self.RegisterIndex = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ExtendRuntimeMeasurementRequest { + static instance: ExtendRuntimeMeasurementRequest = ExtendRuntimeMeasurementRequest { + Events: ::std::vec::Vec::new(), + RegisterIndex: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ExtendRuntimeMeasurementRequest { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ExtendRuntimeMeasurementRequest").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ExtendRuntimeMeasurementRequest { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ExtendRuntimeMeasurementRequest { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +#[derive(PartialEq,Clone,Default,Debug)] +// @@protoc_insertion_point(message:attestation_agent.ExtendRuntimeMeasurementResponse) +pub struct ExtendRuntimeMeasurementResponse { + // special fields + // @@protoc_insertion_point(special_field:attestation_agent.ExtendRuntimeMeasurementResponse.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ExtendRuntimeMeasurementResponse { + fn default() -> &'a ExtendRuntimeMeasurementResponse { + ::default_instance() + } +} + +impl ExtendRuntimeMeasurementResponse { + pub fn new() -> ExtendRuntimeMeasurementResponse { + ::std::default::Default::default() + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(0); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ExtendRuntimeMeasurementResponse", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ExtendRuntimeMeasurementResponse { + const NAME: &'static str = "ExtendRuntimeMeasurementResponse"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ExtendRuntimeMeasurementResponse { + ExtendRuntimeMeasurementResponse::new() + } + + fn clear(&mut self) { + self.special_fields.clear(); + } + + fn default_instance() -> &'static ExtendRuntimeMeasurementResponse { + static instance: ExtendRuntimeMeasurementResponse = ExtendRuntimeMeasurementResponse { + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ExtendRuntimeMeasurementResponse { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ExtendRuntimeMeasurementResponse").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ExtendRuntimeMeasurementResponse { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ExtendRuntimeMeasurementResponse { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + static file_descriptor_proto_data: &'static [u8] = b"\ \n\x17attestation-agent.proto\x12\x11attestation_agent\"6\n\x12GetEviden\ ceRequest\x12\x20\n\x0bRuntimeData\x18\x01\x20\x01(\x0cR\x0bRuntimeData\ \"1\n\x13GetEvidenceResponse\x12\x1a\n\x08Evidence\x18\x01\x20\x01(\x0cR\ \x08Evidence\"/\n\x0fGetTokenRequest\x12\x1c\n\tTokenType\x18\x01\x20\ \x01(\tR\tTokenType\"(\n\x10GetTokenResponse\x12\x14\n\x05Token\x18\x01\ - \x20\x01(\x0cR\x05Token2\xcc\x01\n\x17AttestationAgentService\x12\\\n\ - \x0bGetEvidence\x12%.attestation_agent.GetEvidenceRequest\x1a&.attestati\ - on_agent.GetEvidenceResponse\x12S\n\x08GetToken\x12\".attestation_agent.\ - GetTokenRequest\x1a#.attestation_agent.GetTokenResponseb\x06proto3\ + \x20\x01(\x0cR\x05Token\"v\n\x1fExtendRuntimeMeasurementRequest\x12\x16\ + \n\x06Events\x18\x01\x20\x03(\x0cR\x06Events\x12)\n\rRegisterIndex\x18\ + \x02\x20\x01(\x04H\0R\rRegisterIndex\x88\x01\x01B\x10\n\x0e_RegisterInde\ + x\"\"\n\x20ExtendRuntimeMeasurementResponse2\xd2\x02\n\x17AttestationAge\ + ntService\x12\\\n\x0bGetEvidence\x12%.attestation_agent.GetEvidenceReque\ + st\x1a&.attestation_agent.GetEvidenceResponse\x12S\n\x08GetToken\x12\".a\ + ttestation_agent.GetTokenRequest\x1a#.attestation_agent.GetTokenResponse\ + \x12\x83\x01\n\x18ExtendRuntimeMeasurement\x122.attestation_agent.Extend\ + RuntimeMeasurementRequest\x1a3.attestation_agent.ExtendRuntimeMeasuremen\ + tResponseb\x06proto3\ "; /// `FileDescriptorProto` object which was a source for this generated file @@ -540,11 +789,13 @@ pub fn file_descriptor() -> &'static ::protobuf::reflect::FileDescriptor { file_descriptor.get(|| { let generated_file_descriptor = generated_file_descriptor_lazy.get(|| { let mut deps = ::std::vec::Vec::with_capacity(0); - let mut messages = ::std::vec::Vec::with_capacity(4); + let mut messages = ::std::vec::Vec::with_capacity(6); messages.push(GetEvidenceRequest::generated_message_descriptor_data()); messages.push(GetEvidenceResponse::generated_message_descriptor_data()); messages.push(GetTokenRequest::generated_message_descriptor_data()); messages.push(GetTokenResponse::generated_message_descriptor_data()); + messages.push(ExtendRuntimeMeasurementRequest::generated_message_descriptor_data()); + messages.push(ExtendRuntimeMeasurementResponse::generated_message_descriptor_data()); let mut enums = ::std::vec::Vec::with_capacity(0); ::protobuf::reflect::GeneratedFileDescriptor::new_generated( file_descriptor_proto(), diff --git a/attestation-agent/kbs_protocol/src/token_provider/aa/attestation_agent_ttrpc.rs b/attestation-agent/kbs_protocol/src/token_provider/aa/attestation_agent_ttrpc.rs index 96a8c394e..2e4d29cae 100644 --- a/attestation-agent/kbs_protocol/src/token_provider/aa/attestation_agent_ttrpc.rs +++ b/attestation-agent/kbs_protocol/src/token_provider/aa/attestation_agent_ttrpc.rs @@ -43,6 +43,11 @@ impl AttestationAgentServiceClient { let mut cres = super::attestation_agent::GetTokenResponse::new(); ::ttrpc::async_client_request!(self, ctx, req, "attestation_agent.AttestationAgentService", "GetToken", cres); } + + pub async fn extend_runtime_measurement(&self, ctx: ttrpc::context::Context, req: &super::attestation_agent::ExtendRuntimeMeasurementRequest) -> ::ttrpc::Result { + let mut cres = super::attestation_agent::ExtendRuntimeMeasurementResponse::new(); + ::ttrpc::async_client_request!(self, ctx, req, "attestation_agent.AttestationAgentService", "ExtendRuntimeMeasurement", cres); + } } struct GetEvidenceMethod { @@ -67,6 +72,17 @@ impl ::ttrpc::r#async::MethodHandler for GetTokenMethod { } } +struct ExtendRuntimeMeasurementMethod { + service: Arc>, +} + +#[async_trait] +impl ::ttrpc::r#async::MethodHandler for ExtendRuntimeMeasurementMethod { + async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, req: ::ttrpc::Request) -> ::ttrpc::Result<::ttrpc::Response> { + ::ttrpc::async_request_handler!(self, ctx, req, attestation_agent, ExtendRuntimeMeasurementRequest, extend_runtime_measurement); + } +} + #[async_trait] pub trait AttestationAgentService: Sync { async fn get_evidence(&self, _ctx: &::ttrpc::r#async::TtrpcContext, _: super::attestation_agent::GetEvidenceRequest) -> ::ttrpc::Result { @@ -75,6 +91,9 @@ pub trait AttestationAgentService: Sync { async fn get_token(&self, _ctx: &::ttrpc::r#async::TtrpcContext, _: super::attestation_agent::GetTokenRequest) -> ::ttrpc::Result { Err(::ttrpc::Error::RpcStatus(::ttrpc::get_status(::ttrpc::Code::NOT_FOUND, "/attestation_agent.AttestationAgentService/GetToken is not supported".to_string()))) } + async fn extend_runtime_measurement(&self, _ctx: &::ttrpc::r#async::TtrpcContext, _: super::attestation_agent::ExtendRuntimeMeasurementRequest) -> ::ttrpc::Result { + Err(::ttrpc::Error::RpcStatus(::ttrpc::get_status(::ttrpc::Code::NOT_FOUND, "/attestation_agent.AttestationAgentService/ExtendRuntimeMeasurement is not supported".to_string()))) + } } pub fn create_attestation_agent_service(service: Arc>) -> HashMap { @@ -88,6 +107,9 @@ pub fn create_attestation_agent_service(service: Arc); + methods.insert("ExtendRuntimeMeasurement".to_string(), + Box::new(ExtendRuntimeMeasurementMethod{service: service.clone()}) as Box); + ret.insert("attestation_agent.AttestationAgentService".to_string(), ::ttrpc::r#async::Service{ methods, streams }); ret } diff --git a/attestation-agent/lib/src/lib.rs b/attestation-agent/lib/src/lib.rs index cec73aec0..ff0eda218 100644 --- a/attestation-agent/lib/src/lib.rs +++ b/attestation-agent/lib/src/lib.rs @@ -79,6 +79,13 @@ pub trait AttestationAPIs { /// Get TEE hardware signed evidence that includes the runtime data. async fn get_evidence(&mut self, runtime_data: &[u8]) -> Result>; + + /// Extend runtime measurement register + async fn extend_runtime_measurement( + &mut self, + events: Vec>, + register_index: Option, + ) -> Result<()>; } /// Attestation agent to provide attestation service. @@ -193,4 +200,18 @@ impl AttestationAPIs for AttestationAgent { let evidence = attester.get_evidence(runtime_data.to_vec()).await?; Ok(evidence.into_bytes()) } + + /// Extend runtime measurement register + async fn extend_runtime_measurement( + &mut self, + events: Vec>, + register_index: Option, + ) -> Result<()> { + let tee_type = detect_tee_type().ok_or(anyhow!("no supported tee type found!"))?; + let attester = TryInto::::try_into(tee_type)?; + attester + .extend_runtime_measurement(events, register_index) + .await?; + Ok(()) + } } diff --git a/attestation-agent/protos/attestation-agent.proto b/attestation-agent/protos/attestation-agent.proto index ec233db14..a6a455464 100644 --- a/attestation-agent/protos/attestation-agent.proto +++ b/attestation-agent/protos/attestation-agent.proto @@ -18,7 +18,15 @@ message GetTokenResponse { bytes Token = 1; } +message ExtendRuntimeMeasurementRequest { + repeated bytes Events = 1; + optional uint64 RegisterIndex = 2; +} + +message ExtendRuntimeMeasurementResponse {} + service AttestationAgentService { rpc GetEvidence(GetEvidenceRequest) returns (GetEvidenceResponse) {}; rpc GetToken(GetTokenRequest) returns (GetTokenResponse) {}; + rpc ExtendRuntimeMeasurement(ExtendRuntimeMeasurementRequest) returns (ExtendRuntimeMeasurementResponse) {}; }