Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: data source internal identifier and query sanitizer #9

Merged
Merged 3 commits on Aug 31, 2024.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ go 1.21
require (
github.com/auth0/go-jwt-middleware/v2 v2.2.1
github.com/gin-contrib/cors v1.7.2
github.com/gin-contrib/secure v1.1.0
github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.20.0
github.com/google/generative-ai-go v0.8.0
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQ
github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E=
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk=
github.com/gin-contrib/secure v1.1.0 h1:wy/psCWbgUBDCLH13KgB/m06NHXb1jczSTRp+H2hK7E=
github.com/gin-contrib/secure v1.1.0/go.mod h1:LtEfyy326NRwgkUq8ac6npf845L0L9B8yfEaLcxMHIc=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
Expand Down
63 changes: 25 additions & 38 deletions internal/controllers/core.controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package controllers
import (
"fmt"
"net/http"
"strings"

"github.com/gin-gonic/gin"
_ "github.com/lib/pq"
Expand All @@ -16,14 +15,14 @@ import (
type CoreController struct {
CoreService services.CoreService
DataSourceService services.DataSourceService
TrinoService services.EngineService
EngineService services.EngineService
}

func NewCoreController(coreService services.CoreService, dataSourceService services.DataSourceService, trinoService services.EngineService) CoreController {
func NewCoreController(coreService services.CoreService, dataSourceService services.DataSourceService, engineService services.EngineService) CoreController {
return CoreController{
CoreService: coreService,
DataSourceService: dataSourceService,
TrinoService: trinoService,
EngineService: engineService,
}
}

Expand All @@ -44,7 +43,6 @@ func NewCoreController(coreService services.CoreService, dataSourceService servi
// @Router /prompts/generate [post]
func (self *CoreController) GenerateQuery(c *gin.Context) {
var generateQueryBody models.GenerateQueryBody
var queryResult models.QueryResult

_ownerId, _ := c.Get("owner_id")
ownerId, ok := _ownerId.(string)
Expand All @@ -65,7 +63,7 @@ func (self *CoreController) GenerateQuery(c *gin.Context) {
}

if err := c.ShouldBindJSON(&generateQueryBody); err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Failed to read request body: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Failed to read request body: %v\n", ownerId, sub, err)
c.JSON(http.StatusBadRequest, models.HTTPError{
Error: "bad_request",
Description: "Failed to read request body.",
Expand All @@ -74,9 +72,9 @@ func (self *CoreController) GenerateQuery(c *gin.Context) {
}

// Get Data source info and secret
ds, err := self.DataSourceService.GetByName(generateQueryBody.DataSourceName, ownerId)
ds, err := self.DataSourceService.GetByName(generateQueryBody.DataSourceName, ownerId, false)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Data source provided is invalid: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Data source provided is invalid: %v\n", ownerId, sub, err)
c.JSON(http.StatusBadRequest, models.HTTPError{
Error: "bad_request",
Description: "Data source provided is invalid.",
Expand All @@ -85,10 +83,9 @@ func (self *CoreController) GenerateQuery(c *gin.Context) {
}

// Extract schemas info from Data source
schemas, err := self.DataSourceService.GetDataSourceSchemas(ds.Name, ds.OrganizationId)
logger.Info.Println(schemas)
schemas, err := self.DataSourceService.GetDataSourceSchemas(ds.ID)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Error extracting metadata: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Error extracting metadata: %v\n", ownerId, sub, err)
c.JSON(http.StatusBadRequest, models.HTTPError{
Error: "bad_request",
Description: "Error extracting data source schema",
Expand All @@ -97,31 +94,27 @@ func (self *CoreController) GenerateQuery(c *gin.Context) {
}

schemasYaml, _ := yaml.Marshal(&schemas)
catalogName := self.TrinoService.GetCatalogName(ds.Name, ds.OrganizationId)
var mergedPrompt string = fmt.Sprintf("I have a PostgreSQL Catalog in Trino named %s with the following database schema:\n\n%s\n\nGive me an SQL Trino Query that provides the following information: %s\n\nTo accomplish the task correctly please consider including schema on the query and return only a query without additional text.", catalogName, string(schemasYaml), generateQueryBody.Text)
var mergedPrompt string = fmt.Sprintf("I have a PostgreSQL Catalog in Trino named data_source_%s with the following database schema:\n\n%s\n\nGive me an SQL Trino Query that provides the following information: %s\n\nTo accomplish the task correctly please consider including schema on the query and return only a query without additional text.", ds.ID.Hex(), string(schemasYaml), generateQueryBody.Text)
logger.Info.Println(mergedPrompt)

// Generate query
queryResult, err = self.CoreService.PromptGemini(mergedPrompt)
query, err := self.CoreService.PromptGemini(mergedPrompt)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Error processing prompt: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Error processing prompt: %v\n", ownerId, sub, err)
c.JSON(http.StatusInternalServerError, models.HTTPError{
Error: "internal_server_error",
Description: "Error processing prompt",
})
return
}

query := strings.ReplaceAll(queryResult.QueryMarkdown, "```", "")
query = strings.ReplaceAll(query, "sql", "")

var results []map[string]interface{}

if generateQueryBody.Execute {
// Get data from Data source using generated query
results, err = self.TrinoService.GetRawData(query)
results, err = self.EngineService.GetRawData(query)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Failed to get data from data source:: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Failed to get data from data source: %v\n", ownerId, sub, err)
c.JSON(http.StatusInternalServerError, models.HTTPError{
Error: "bad_request",
Description: "Failed to get data from data source.",
Expand All @@ -133,7 +126,7 @@ func (self *CoreController) GenerateQuery(c *gin.Context) {
// save activity
activity := models.GenerateQueryActivity{
Prompt: generateQueryBody.Text,
Query: queryResult.QueryMarkdown,
Query: query,
Data: results,
UserId: sub,
MergedPrompt: mergedPrompt,
Expand Down Expand Up @@ -162,7 +155,6 @@ func (self *CoreController) GenerateQuery(c *gin.Context) {
// @Router /prompts/improve [post]
func (self *CoreController) ImproveQuery(c *gin.Context) {
var improveQueryBody models.ImproveQueryBody
var queryResult models.QueryResult

_ownerId, _ := c.Get("owner_id")
ownerId, ok := _ownerId.(string)
Expand All @@ -183,7 +175,7 @@ func (self *CoreController) ImproveQuery(c *gin.Context) {
}

if err := c.ShouldBindJSON(&improveQueryBody); err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Failed to read request body: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Failed to read request body: %v\n", ownerId, sub, err)
c.JSON(http.StatusBadRequest, models.HTTPError{
Error: "bad_request",
Description: "Failed to read request body.",
Expand All @@ -192,9 +184,9 @@ func (self *CoreController) ImproveQuery(c *gin.Context) {
}

// Get Data source info and secret
ds, err := self.DataSourceService.GetByName(improveQueryBody.DataSourceName, ownerId)
ds, err := self.DataSourceService.GetByName(improveQueryBody.DataSourceName, ownerId, false)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Data source provided is invalid: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Data source provided is invalid: %v\n", ownerId, sub, err)
c.JSON(http.StatusBadRequest, models.HTTPError{
Error: "bad_request",
Description: "Data source provided is invalid.",
Expand All @@ -203,10 +195,9 @@ func (self *CoreController) ImproveQuery(c *gin.Context) {
}

// Extract schemas info from Data source
schemas, err := self.DataSourceService.GetDataSourceSchemas(ds.Name, ds.OrganizationId)
logger.Info.Println(schemas)
schemas, err := self.DataSourceService.GetDataSourceSchemas(ds.ID)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Error extracting metadata: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Error extracting metadata: %v\n", ownerId, sub, err)
c.JSON(http.StatusBadRequest, models.HTTPError{
Error: "bad_request",
Description: "Error extracting data source schema",
Expand All @@ -215,31 +206,27 @@ func (self *CoreController) ImproveQuery(c *gin.Context) {
}

schemasYaml, _ := yaml.Marshal(&schemas)
catalogName := self.TrinoService.GetCatalogName(ds.Name, ds.OrganizationId)
var mergedPrompt string = fmt.Sprintf("I have a PostgreSQL Catalog in Trino named %s with the following database schema:\n\n%s\n\nEnhance this SQL Trino query for improved readability and performance: %s\n\nTo accomplish the task correctly please consider including schema on the query and return only a query without additional text.", catalogName, string(schemasYaml), improveQueryBody.Query)
var mergedPrompt string = fmt.Sprintf("I have a PostgreSQL Catalog in Trino named data_source_%s with the following database schema:\n\n%s\n\nEnhance this SQL Trino query for improved readability and performance: %s\n\nTo accomplish the task correctly please consider including schema on the query and return only a query without additional text.", ds.ID.Hex(), string(schemasYaml), improveQueryBody.Query)
logger.Info.Println(mergedPrompt)

// Generate query
queryResult, err = self.CoreService.PromptGemini(mergedPrompt)
query, err := self.CoreService.PromptGemini(mergedPrompt)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Error processing prompt: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Error processing prompt: %v\n", ownerId, sub, err)
c.JSON(http.StatusInternalServerError, models.HTTPError{
Error: "internal_server_error",
Description: "Error processing prompt",
})
return
}

query := strings.ReplaceAll(queryResult.QueryMarkdown, "```", "")
query = strings.ReplaceAll(query, "sql", "")

var results []map[string]interface{}

if improveQueryBody.Execute {
// Get data from Data source using generated query
results, err = self.TrinoService.GetRawData(query)
results, err = self.EngineService.GetRawData(query)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Failed to get data from data source:: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Failed to get data from data source: %v\n", ownerId, sub, err)
c.JSON(http.StatusInternalServerError, models.HTTPError{
Error: "bad_request",
Description: "Failed to get data from data source.",
Expand All @@ -251,7 +238,7 @@ func (self *CoreController) ImproveQuery(c *gin.Context) {
// save activity
activity := models.ImproveQueryActivity{
OriginalQuery: improveQueryBody.Query,
ImprovedQuery: queryResult.QueryMarkdown,
ImprovedQuery: query,
Data: results,
UserId: sub,
MergedPrompt: mergedPrompt,
Expand Down
23 changes: 11 additions & 12 deletions internal/controllers/ds.controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (self *DataSourceController) GetDataSourceByName(c *gin.Context) {
})
}

dataSource, err := self.DataSourceService.GetByName(name, ownerId)
dataSource, err := self.DataSourceService.GetByName(name, ownerId, false)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] failed to get data source by name: %v\n", ownerId, sub, err))

Expand Down Expand Up @@ -180,13 +180,10 @@ func (self *DataSourceController) CreateDataSource(c *gin.Context) {

newDataSource.CreatedBy = sub
newDataSource.OrganizationId = ownerId
// newDataSource.UserInfo.Name = userData.Name
// newDataSource.UserInfo.Picture = userData.Picture
newDataSource.CreatedAt = time.Now()
newDataSource.UpdatedAt = time.Now()

if err := self.validate.Struct(newDataSource); err != nil {
logger.Info.Println(newDataSource)
validationErr := err.(validatorv10.ValidationErrors)
logger.Error.Println(fmt.Printf("[%s][%s] One or more data source fields are invalid: %s\n", ownerId, sub, validationErr))
c.JSON(http.StatusBadRequest, models.HTTPError{
Expand All @@ -196,7 +193,9 @@ func (self *DataSourceController) CreateDataSource(c *gin.Context) {
return
}

if err := self.DataSourceService.Create(newDataSource); err != nil {
createdDataSource, err := self.DataSourceService.Create(newDataSource)

if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Failed to create data source: %v\n", ownerId, sub, err))
c.JSON(http.StatusInternalServerError, models.HTTPError{
Error: "internal_server_error",
Expand All @@ -205,7 +204,7 @@ func (self *DataSourceController) CreateDataSource(c *gin.Context) {
return
}

c.JSON(http.StatusOK, models.HTTPSuccess{Message: "success"})
c.JSON(http.StatusOK, createdDataSource)
}

// @BasePath /
Expand All @@ -219,7 +218,7 @@ func (self *DataSourceController) CreateDataSource(c *gin.Context) {
// @Produce json
// @Param name path string true "Data Source Name" string
// @Param Data_Source body models.UpdateRequestDataSourceBody true "Data Source"
// @Success 200 {object} models.HTTPSuccess
// @Success 200 {object} models.DataSource
// @Failure 400 {object} models.HTTPError
// @Failure 401 {object} models.HTTPError
// @Router /datasources/{name} [put]
Expand Down Expand Up @@ -257,7 +256,7 @@ func (self *DataSourceController) UpdateDataSourceByName(c *gin.Context) {

updateDataSource.UpdatedAt = time.Now()

err := self.DataSourceService.Update(name, ownerId, updateDataSource)
updatedDataSource, err := self.DataSourceService.Update(name, ownerId, updateDataSource)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] failed to update data source: %v\n", ownerId, sub, err))
c.JSON(http.StatusBadRequest, models.HTTPError{
Expand All @@ -267,7 +266,7 @@ func (self *DataSourceController) UpdateDataSourceByName(c *gin.Context) {
return
}

c.JSON(http.StatusOK, models.HTTPSuccess{Message: "success"})
c.JSON(http.StatusOK, updatedDataSource)
}

// @BasePath /
Expand All @@ -280,7 +279,7 @@ func (self *DataSourceController) UpdateDataSourceByName(c *gin.Context) {
// @Accept json
// @Produce json
// @Param name path string true "Data Source Name" string
// @Success 200 {object} models.HTTPSuccess
// @Success 200 {object} models.DataSource
// @Failure 400 {object} models.HTTPError
// @Failure 401 {object} models.HTTPError
// @Failure 500 {object} models.HTTPError
Expand All @@ -306,7 +305,7 @@ func (self *DataSourceController) DeleteDataSourceByName(c *gin.Context) {

name := c.Param("name")

err := self.DataSourceService.Delete(name, ownerId)
deletedDataSource, err := self.DataSourceService.Delete(name, ownerId)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Failed to delete data source: %v\n", ownerId, sub, err))
c.JSON(http.StatusBadRequest, models.HTTPError{
Expand All @@ -316,7 +315,7 @@ func (self *DataSourceController) DeleteDataSourceByName(c *gin.Context) {
return
}

c.JSON(http.StatusOK, models.HTTPSuccess{Message: "success"})
c.JSON(http.StatusOK, deletedDataSource)
}

func (self *DataSourceController) RegisterDataSourceRoutes(rg *gin.RouterGroup) {
Expand Down
9 changes: 6 additions & 3 deletions internal/middlewares/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middlewares

import (
"os"
"strings"

jwtmiddleware "github.com/auth0/go-jwt-middleware/v2"
"github.com/auth0/go-jwt-middleware/v2/validator"
Expand All @@ -19,12 +20,14 @@ func SetVarsToContext() gin.HandlerFunc {
claims := c.Request.Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims)
customClaims := claims.CustomClaims.(*CustomClaims)

sub := strings.Split(claims.RegisteredClaims.Subject, "|")[1]

if customClaims.OrganizationId != "" {
c.Set("owner_id", customClaims.OrganizationId)
c.Set("sub", claims.RegisteredClaims.Subject)
c.Set("sub", sub)
} else {
c.Set("owner_id", claims.RegisteredClaims.Subject)
c.Set("sub", claims.RegisteredClaims.Subject)
c.Set("owner_id", sub)
c.Set("sub", sub)
}
} else {
c.Set("owner_id", "poligono")
Expand Down
4 changes: 0 additions & 4 deletions internal/models/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ type ImproveQueryBody struct {
Execute bool `json:"execute"`
}

type QueryResult struct {
QueryMarkdown string `json:"query"`
}

type GenerateQueryActivity struct {
ID primitive.ObjectID `json:"-" bson:"_id,omitempty"`
Prompt string `json:"prompt" bson:"prompt"`
Expand Down
2 changes: 1 addition & 1 deletion internal/models/datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const (
)

type DataSource struct {
ID primitive.ObjectID `json:"-" bson:"omitempty,_id" swaggerignore:"true"`
ID primitive.ObjectID `json:"-" bson:"_id" swaggerignore:"true"`
Name string `json:"name" bson:"name" validate:"required"`
OrganizationId string `json:"organization_id" bson:"organization_id" validate:"required" swaggerignore:"true"`
CreatedBy string `json:"-" bson:"created_by" validate:"required"`
Expand Down
22 changes: 8 additions & 14 deletions internal/models/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,18 @@ type SQLSchema struct {
Name string `json:"name"`
}

type Schemas struct {
Schemas []Schema
}

type Schema struct {
ID primitive.ObjectID `json:"-" bson:"omitempty,_id" yaml:"-"`
Name string `json:"name" bson:"name"`
OrganizationId string `json:"organization_id" bson:"organization_id" yaml:"-"`
DataSourceName string `json:"data_source_name" bson:"data_source_name" yaml:"-"`
Description string `json:"description" bson:"description" yaml:"description,omitempty"`
Tables []Table `json:"tables" bson:"tables"`
ID primitive.ObjectID `json:"-" bson:"_id" yaml:"-"`
Name string `json:"name" bson:"name"`
DataSourceId primitive.ObjectID `json:"data_source_id" bson:"data_source_id" yaml:"-"`
Description string `json:"description" bson:"description" yaml:"description,omitempty"`
Tables []Table `json:"tables" bson:"tables"`
}

type UpdateSchema struct {
Name string `json:"name" bson:"name"`
DataSourceName string `json:"data_source_name" bson:"data_source_name" yaml:"-"`
Description string `json:"description" bson:"description" yaml:"description,omitempty"`
Tables []Table `json:"tables" bson:"tables"`
Name string `json:"name" bson:"name"`
Description string `json:"description" bson:"description" yaml:"description,omitempty"`
Tables []Table `json:"tables" bson:"tables"`
}

type SQLTable struct {
Expand Down
Loading
Loading