Skip to content

Commit

Permalink
refactor: modified struct
Browse files Browse the repository at this point in the history
  • Loading branch information
parmesant committed Dec 26, 2024
1 parent e60a884 commit 54d76b9
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 63 deletions.
33 changes: 9 additions & 24 deletions src/correlation/correlation_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,39 +16,24 @@
*
*/

use datafusion::common::tree_node::TreeNode;
use itertools::Itertools;

use crate::{
query::{TableScanVisitor, QUERY_SESSION},
rbac::{
map::SessionKey,
role::{Action, Permission},
Users,
},
use crate::rbac::{
map::SessionKey,
role::{Action, Permission},
Users,
};

use super::CorrelationError;

async fn get_tables_from_query(query: &str) -> Result<TableScanVisitor, CorrelationError> {
let session_state = QUERY_SESSION.state();
let raw_logical_plan = session_state
.create_logical_plan(query)
.await
.map_err(|err| CorrelationError::AnyhowError(err.into()))?;

let mut visitor = TableScanVisitor::default();
let _ = raw_logical_plan.visit(&mut visitor);
Ok(visitor)
}
use super::{CorrelationError, TableConfig};

pub async fn user_auth_for_query(
session_key: &SessionKey,
query: &str,
table_configs: &[TableConfig],
) -> Result<(), CorrelationError> {
let tables = get_tables_from_query(query).await?;
let tables = table_configs.iter().map(|t| &t.table_name).collect_vec();
let permissions = Users.get_permissions(session_key);

for table_name in tables.into_inner().iter() {
for table_name in tables {
let mut authorized = false;

// in permission check if user can run query on the stream.
Expand Down
39 changes: 15 additions & 24 deletions src/correlation/http_handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use relative_path::RelativePathBuf;
use crate::{
option::CONFIG,
storage::{CORRELATION_DIRECTORY, PARSEABLE_ROOT_DIRECTORY},
utils::{actix::extract_session_key_from_req, uid::Uid},
utils::actix::extract_session_key_from_req,
};

use super::{
Expand Down Expand Up @@ -53,7 +53,7 @@ pub async fn get(req: HttpRequest) -> Result<impl Responder, CorrelationError> {

let correlation = CORRELATIONS.get_correlation_by_id(correlation_id).await?;

if user_auth_for_query(&session_key, &correlation.query)
if user_auth_for_query(&session_key, &correlation.table_configs)
.await
.is_ok()
{
Expand All @@ -68,10 +68,10 @@ pub async fn post(req: HttpRequest, body: Bytes) -> Result<impl Responder, Corre
.map_err(|err| CorrelationError::AnyhowError(anyhow::Error::msg(err.to_string())))?;

let correlation_request: CorrelationRequest = serde_json::from_slice(&body)?;
let correlation: CorrelationConfig = correlation_request.into();

// validate user's query auth
user_auth_for_query(&session_key, &correlation.query).await?;
correlation_request.validate(&session_key).await?;

let correlation: CorrelationConfig = correlation_request.into();

// Save to disk
let store = CONFIG.storage().get_object_store();
Expand All @@ -80,10 +80,7 @@ pub async fn post(req: HttpRequest, body: Bytes) -> Result<impl Responder, Corre
// Save to memory
CORRELATIONS.update(&correlation).await?;

Ok(format!(
"Saved correlation with ID- {}",
correlation.id.to_string()
))
Ok(format!("Saved correlation with ID- {}", correlation.id))
}

pub async fn modify(req: HttpRequest, body: Bytes) -> Result<impl Responder, CorrelationError> {
Expand All @@ -95,17 +92,14 @@ pub async fn modify(req: HttpRequest, body: Bytes) -> Result<impl Responder, Cor
.get("correlation_id")
.ok_or(CorrelationError::Metadata("No correlation ID Provided"))?;

let correlation_request: CorrelationRequest = serde_json::from_slice(&body)?;
// validate whether user has access to this correlation object or not
let correlation = CORRELATIONS.get_correlation_by_id(correlation_id).await?;
user_auth_for_query(&session_key, &correlation.table_configs).await?;

// validate user's query auth
user_auth_for_query(&session_key, &correlation_request.query).await?;
let correlation_request: CorrelationRequest = serde_json::from_slice(&body)?;
correlation_request.validate(&session_key).await?;

let correlation: CorrelationConfig = CorrelationConfig {
version: correlation_request.version,
id: Uid::from_string(correlation_id)
.map_err(|err| CorrelationError::AnyhowError(anyhow::Error::msg(err.to_string())))?,
query: correlation_request.query,
};
let correlation = correlation_request.generate_correlation_config(correlation_id.to_owned());

// Save to disk
let store = CONFIG.storage().get_object_store();
Expand All @@ -114,10 +108,7 @@ pub async fn modify(req: HttpRequest, body: Bytes) -> Result<impl Responder, Cor
// Save to memory
CORRELATIONS.update(&correlation).await?;

Ok(format!(
"Modified correlation with ID- {}",
correlation.id.to_string()
))
Ok(format!("Modified correlation with ID- {}", correlation.id))
}

pub async fn delete(req: HttpRequest) -> Result<impl Responder, CorrelationError> {
Expand All @@ -132,14 +123,14 @@ pub async fn delete(req: HttpRequest) -> Result<impl Responder, CorrelationError
let correlation = CORRELATIONS.get_correlation_by_id(correlation_id).await?;

// validate user's query auth
user_auth_for_query(&session_key, &correlation.query).await?;
user_auth_for_query(&session_key, &correlation.table_configs).await?;

// Delete from disk
let store = CONFIG.storage().get_object_store();
let path = RelativePathBuf::from_iter([
PARSEABLE_ROOT_DIRECTORY,
CORRELATION_DIRECTORY,
&format!("{}", correlation.id),
&correlation.id.to_string(),
]);
store.delete_object(&path).await?;

Expand Down
127 changes: 114 additions & 13 deletions src/correlation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
*
*/

use std::collections::HashSet;

use actix_web::http::header::ContentType;
use chrono::Utc;
use correlation_utils::user_auth_for_query;
use datafusion::error::DataFusionError;
use http::StatusCode;
use itertools::Itertools;
use once_cell::sync::Lazy;
Expand All @@ -27,8 +31,8 @@ use tokio::sync::RwLock;
use tracing::{trace, warn};

use crate::{
handlers::http::rbac::RBACError, option::CONFIG, rbac::map::SessionKey,
storage::ObjectStorageError, utils::uid::Uid,
handlers::http::rbac::RBACError, option::CONFIG, query::QUERY_SESSION, rbac::map::SessionKey,
storage::ObjectStorageError, users::filters::FilterQuery, utils::get_hash,
};

pub mod correlation_utils;
Expand Down Expand Up @@ -69,7 +73,10 @@ impl Correlation {

let mut user_correlations = vec![];
for c in correlations {
if user_auth_for_query(session_key, &c.query).await.is_ok() {
if user_auth_for_query(session_key, &c.table_configs)
.await
.is_ok()
{
user_correlations.push(c);
}
}
Expand All @@ -81,10 +88,7 @@ impl Correlation {
correlation_id: &str,
) -> Result<CorrelationConfig, CorrelationError> {
let read = self.0.read().await;
let correlation = read
.iter()
.find(|c| c.id.to_string() == correlation_id)
.cloned();
let correlation = read.iter().find(|c| c.id == correlation_id).cloned();

if let Some(c) = correlation {
Ok(c)
Expand All @@ -110,7 +114,7 @@ impl Correlation {
let index = read_access
.iter()
.enumerate()
.find(|(_, c)| c.id.to_string() == correlation_id)
.find(|(_, c)| c.id == correlation_id)
.to_owned();

if let Some((index, _)) = index {
Expand All @@ -126,6 +130,7 @@ impl Correlation {
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum CorrelationVersion {
V1,
}
Expand All @@ -134,8 +139,12 @@ pub enum CorrelationVersion {
#[serde(rename_all = "camelCase")]
pub struct CorrelationConfig {
pub version: CorrelationVersion,
pub id: Uid,
pub query: String,
pub id: String,
pub table_configs: Vec<TableConfig>,
pub join_config: JoinConfig,
pub filter: Option<FilterQuery>,
pub start_time: Option<String>,
pub end_time: Option<String>,
}

impl CorrelationConfig {}
Expand All @@ -144,16 +153,89 @@ impl CorrelationConfig {}
#[serde(rename_all = "camelCase")]
pub struct CorrelationRequest {
pub version: CorrelationVersion,
pub query: String,
pub table_configs: Vec<TableConfig>,
pub join_config: JoinConfig,
pub filter: Option<FilterQuery>,
pub start_time: Option<String>,
pub end_time: Option<String>,
}

impl From<CorrelationRequest> for CorrelationConfig {
fn from(val: CorrelationRequest) -> Self {
Self {
version: val.version,
id: crate::utils::uid::gen(),
query: val.query,
id: get_hash(Utc::now().timestamp_micros().to_string().as_str()),
table_configs: val.table_configs,
join_config: val.join_config,
filter: val.filter,
start_time: val.start_time,
end_time: val.end_time,
}
}
}

impl CorrelationRequest {
pub fn generate_correlation_config(self, id: String) -> CorrelationConfig {
CorrelationConfig {
version: self.version,
id,
table_configs: self.table_configs,
join_config: self.join_config,
filter: self.filter,
start_time: self.start_time,
end_time: self.end_time,
}
}

/// This function will validate the TableConfigs, JoinConfig, and user auth
pub async fn validate(&self, session_key: &SessionKey) -> Result<(), CorrelationError> {
let ctx = &QUERY_SESSION;

let h1: HashSet<&String> = self.table_configs.iter().map(|t| &t.table_name).collect();
let h2 = HashSet::from([&self.join_config.table_one, &self.join_config.table_two]);

// check if table config tables are the same
if h1.len() != 2 {
return Err(CorrelationError::Metadata(
"Must provide config for two unique tables",
));
}

// check that the tables mentioned in join config are
// the same as those in table config
if h1 != h2 {
return Err(CorrelationError::Metadata(
"Must provide same tables for join config and table config",
));
}

// check if user has access to table
user_auth_for_query(session_key, &self.table_configs).await?;

// to validate table config, we need to check whether the mentioned fields
// are present in the table or not
for table_config in self.table_configs.iter() {
// table config check
let df = ctx.table(&table_config.table_name).await?;

let mut selected_fields = table_config
.selected_fields
.iter()
.map(|c| c.as_str())
.collect_vec();
let join_field = if table_config.table_name == self.join_config.table_one {
&self.join_config.field_one
} else {
&self.join_config.field_two
};

selected_fields.push(join_field.as_str());

// if this errors out then the table config is incorrect or join config is incorrect
df.select_columns(selected_fields.as_slice())?;
}

Ok(())
}
}

Expand All @@ -171,6 +253,8 @@ pub enum CorrelationError {
AnyhowError(#[from] anyhow::Error),
#[error("Unauthorized")]
Unauthorized,
#[error("DataFusion Error: {0}")]
DataFusion(#[from] DataFusionError),
}

impl actix_web::ResponseError for CorrelationError {
Expand All @@ -182,6 +266,7 @@ impl actix_web::ResponseError for CorrelationError {
Self::UserDoesNotExist(_) => StatusCode::NOT_FOUND,
Self::AnyhowError(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::Unauthorized => StatusCode::BAD_REQUEST,
Self::DataFusion(_) => StatusCode::INTERNAL_SERVER_ERROR,
}
}

Expand All @@ -191,3 +276,19 @@ impl actix_web::ResponseError for CorrelationError {
.body(self.to_string())
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TableConfig {
pub selected_fields: Vec<String>,
pub table_name: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct JoinConfig {
pub table_one: String,
pub field_one: String,
pub table_two: String,
pub field_two: String,
}
2 changes: 1 addition & 1 deletion src/handlers/http/modal/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ impl Server {
),
)
.service(
web::resource("/correlation/{correlation_id}")
web::resource("/{correlation_id}")
.route(
web::get()
.to(correlation::http_handlers::get)
Expand Down
2 changes: 1 addition & 1 deletion src/storage/object_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ pub trait ObjectStorage: Send + Sync + 'static {
let path = RelativePathBuf::from_iter([
PARSEABLE_ROOT_DIRECTORY,
CORRELATION_DIRECTORY,
&format!("{}", correlation.id),
&format!("{}.json", correlation.id),
]);
self.put_object(&path, to_bytes(correlation)).await?;
Ok(())
Expand Down

0 comments on commit 54d76b9

Please sign in to comment.