Skip to content

Commit

Permalink
Merge pull request #9 from poligonoio/feature/data-source-identifier-ref
Browse files: browse the repository at this point in the history
refactor: data source internal identifier and query sanitizer
  • Loading branch information
eddydecena authored Aug 31, 2024
2 parents b5e0213 + 45fa06e commit 878d2a6
Show file tree
Hide file tree
Showing 23 changed files with 255 additions and 302 deletions.
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ go 1.21
require (
github.com/auth0/go-jwt-middleware/v2 v2.2.1
github.com/gin-contrib/cors v1.7.2
github.com/gin-contrib/secure v1.1.0
github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.20.0
github.com/google/generative-ai-go v0.8.0
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQ
github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E=
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk=
github.com/gin-contrib/secure v1.1.0 h1:wy/psCWbgUBDCLH13KgB/m06NHXb1jczSTRp+H2hK7E=
github.com/gin-contrib/secure v1.1.0/go.mod h1:LtEfyy326NRwgkUq8ac6npf845L0L9B8yfEaLcxMHIc=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
Expand Down
63 changes: 25 additions & 38 deletions internal/controllers/core.controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package controllers
import (
"fmt"
"net/http"
"strings"

"github.com/gin-gonic/gin"
_ "github.com/lib/pq"
Expand All @@ -16,14 +15,14 @@ import (
type CoreController struct {
CoreService services.CoreService
DataSourceService services.DataSourceService
TrinoService services.EngineService
EngineService services.EngineService
}

func NewCoreController(coreService services.CoreService, dataSourceService services.DataSourceService, trinoService services.EngineService) CoreController {
func NewCoreController(coreService services.CoreService, dataSourceService services.DataSourceService, engineService services.EngineService) CoreController {
return CoreController{
CoreService: coreService,
DataSourceService: dataSourceService,
TrinoService: trinoService,
EngineService: engineService,
}
}

Expand All @@ -44,7 +43,6 @@ func NewCoreController(coreService services.CoreService, dataSourceService servi
// @Router /prompts/generate [post]
func (self *CoreController) GenerateQuery(c *gin.Context) {
var generateQueryBody models.GenerateQueryBody
var queryResult models.QueryResult

_ownerId, _ := c.Get("owner_id")
ownerId, ok := _ownerId.(string)
Expand All @@ -65,7 +63,7 @@ func (self *CoreController) GenerateQuery(c *gin.Context) {
}

if err := c.ShouldBindJSON(&generateQueryBody); err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Failed to read request body: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Failed to read request body: %v\n", ownerId, sub, err)
c.JSON(http.StatusBadRequest, models.HTTPError{
Error: "bad_request",
Description: "Failed to read request body.",
Expand All @@ -74,9 +72,9 @@ func (self *CoreController) GenerateQuery(c *gin.Context) {
}

// Get Data source info and secret
ds, err := self.DataSourceService.GetByName(generateQueryBody.DataSourceName, ownerId)
ds, err := self.DataSourceService.GetByName(generateQueryBody.DataSourceName, ownerId, false)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Data source provided is invalid: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Data source provided is invalid: %v\n", ownerId, sub, err)
c.JSON(http.StatusBadRequest, models.HTTPError{
Error: "bad_request",
Description: "Data source provided is invalid.",
Expand All @@ -85,10 +83,9 @@ func (self *CoreController) GenerateQuery(c *gin.Context) {
}

// Extract schemas info from Data source
schemas, err := self.DataSourceService.GetDataSourceSchemas(ds.Name, ds.OrganizationId)
logger.Info.Println(schemas)
schemas, err := self.DataSourceService.GetDataSourceSchemas(ds.ID)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Error extracting metadata: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Error extracting metadata: %v\n", ownerId, sub, err)
c.JSON(http.StatusBadRequest, models.HTTPError{
Error: "bad_request",
Description: "Error extracting data source schema",
Expand All @@ -97,31 +94,27 @@ func (self *CoreController) GenerateQuery(c *gin.Context) {
}

schemasYaml, _ := yaml.Marshal(&schemas)
catalogName := self.TrinoService.GetCatalogName(ds.Name, ds.OrganizationId)
var mergedPrompt string = fmt.Sprintf("I have a PostgreSQL Catalog in Trino named %s with the following database schema:\n\n%s\n\nGive me an SQL Trino Query that provides the following information: %s\n\nTo accomplish the task correctly please consider including schema on the query and return only a query without additional text.", catalogName, string(schemasYaml), generateQueryBody.Text)
var mergedPrompt string = fmt.Sprintf("I have a PostgreSQL Catalog in Trino named data_source_%s with the following database schema:\n\n%s\n\nGive me an SQL Trino Query that provides the following information: %s\n\nTo accomplish the task correctly please consider including schema on the query and return only a query without additional text.", ds.ID.Hex(), string(schemasYaml), generateQueryBody.Text)
logger.Info.Println(mergedPrompt)

// Generate query
queryResult, err = self.CoreService.PromptGemini(mergedPrompt)
query, err := self.CoreService.PromptGemini(mergedPrompt)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Error processing prompt: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Error processing prompt: %v\n", ownerId, sub, err)
c.JSON(http.StatusInternalServerError, models.HTTPError{
Error: "internal_server_error",
Description: "Error processing prompt",
})
return
}

query := strings.ReplaceAll(queryResult.QueryMarkdown, "```", "")
query = strings.ReplaceAll(query, "sql", "")

var results []map[string]interface{}

if generateQueryBody.Execute {
// Get data from Data source using generated query
results, err = self.TrinoService.GetRawData(query)
results, err = self.EngineService.GetRawData(query)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Failed to get data from data source:: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Failed to get data from data source: %v\n", ownerId, sub, err)
c.JSON(http.StatusInternalServerError, models.HTTPError{
Error: "bad_request",
Description: "Failed to get data from data source.",
Expand All @@ -133,7 +126,7 @@ func (self *CoreController) GenerateQuery(c *gin.Context) {
// save activity
activity := models.GenerateQueryActivity{
Prompt: generateQueryBody.Text,
Query: queryResult.QueryMarkdown,
Query: query,
Data: results,
UserId: sub,
MergedPrompt: mergedPrompt,
Expand Down Expand Up @@ -162,7 +155,6 @@ func (self *CoreController) GenerateQuery(c *gin.Context) {
// @Router /prompts/improve [post]
func (self *CoreController) ImproveQuery(c *gin.Context) {
var improveQueryBody models.ImproveQueryBody
var queryResult models.QueryResult

_ownerId, _ := c.Get("owner_id")
ownerId, ok := _ownerId.(string)
Expand All @@ -183,7 +175,7 @@ func (self *CoreController) ImproveQuery(c *gin.Context) {
}

if err := c.ShouldBindJSON(&improveQueryBody); err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Failed to read request body: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Failed to read request body: %v\n", ownerId, sub, err)
c.JSON(http.StatusBadRequest, models.HTTPError{
Error: "bad_request",
Description: "Failed to read request body.",
Expand All @@ -192,9 +184,9 @@ func (self *CoreController) ImproveQuery(c *gin.Context) {
}

// Get Data source info and secret
ds, err := self.DataSourceService.GetByName(improveQueryBody.DataSourceName, ownerId)
ds, err := self.DataSourceService.GetByName(improveQueryBody.DataSourceName, ownerId, false)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Data source provided is invalid: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Data source provided is invalid: %v\n", ownerId, sub, err)
c.JSON(http.StatusBadRequest, models.HTTPError{
Error: "bad_request",
Description: "Data source provided is invalid.",
Expand All @@ -203,10 +195,9 @@ func (self *CoreController) ImproveQuery(c *gin.Context) {
}

// Extract schemas info from Data source
schemas, err := self.DataSourceService.GetDataSourceSchemas(ds.Name, ds.OrganizationId)
logger.Info.Println(schemas)
schemas, err := self.DataSourceService.GetDataSourceSchemas(ds.ID)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Error extracting metadata: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Error extracting metadata: %v\n", ownerId, sub, err)
c.JSON(http.StatusBadRequest, models.HTTPError{
Error: "bad_request",
Description: "Error extracting data source schema",
Expand All @@ -215,31 +206,27 @@ func (self *CoreController) ImproveQuery(c *gin.Context) {
}

schemasYaml, _ := yaml.Marshal(&schemas)
catalogName := self.TrinoService.GetCatalogName(ds.Name, ds.OrganizationId)
var mergedPrompt string = fmt.Sprintf("I have a PostgreSQL Catalog in Trino named %s with the following database schema:\n\n%s\n\nEnhance this SQL Trino query for improved readability and performance: %s\n\nTo accomplish the task correctly please consider including schema on the query and return only a query without additional text.", catalogName, string(schemasYaml), improveQueryBody.Query)
var mergedPrompt string = fmt.Sprintf("I have a PostgreSQL Catalog in Trino named data_source_%s with the following database schema:\n\n%s\n\nEnhance this SQL Trino query for improved readability and performance: %s\n\nTo accomplish the task correctly please consider including schema on the query and return only a query without additional text.", ds.ID.Hex(), string(schemasYaml), improveQueryBody.Query)
logger.Info.Println(mergedPrompt)

// Generate query
queryResult, err = self.CoreService.PromptGemini(mergedPrompt)
query, err := self.CoreService.PromptGemini(mergedPrompt)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Error processing prompt: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Error processing prompt: %v\n", ownerId, sub, err)
c.JSON(http.StatusInternalServerError, models.HTTPError{
Error: "internal_server_error",
Description: "Error processing prompt",
})
return
}

query := strings.ReplaceAll(queryResult.QueryMarkdown, "```", "")
query = strings.ReplaceAll(query, "sql", "")

var results []map[string]interface{}

if improveQueryBody.Execute {
// Get data from Data source using generated query
results, err = self.TrinoService.GetRawData(query)
results, err = self.EngineService.GetRawData(query)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Failed to get data from data source:: %v\n", ownerId, sub, err))
logger.Error.Printf("[%s][%s] Failed to get data from data source: %v\n", ownerId, sub, err)
c.JSON(http.StatusInternalServerError, models.HTTPError{
Error: "bad_request",
Description: "Failed to get data from data source.",
Expand All @@ -251,7 +238,7 @@ func (self *CoreController) ImproveQuery(c *gin.Context) {
// save activity
activity := models.ImproveQueryActivity{
OriginalQuery: improveQueryBody.Query,
ImprovedQuery: queryResult.QueryMarkdown,
ImprovedQuery: query,
Data: results,
UserId: sub,
MergedPrompt: mergedPrompt,
Expand Down
23 changes: 11 additions & 12 deletions internal/controllers/ds.controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (self *DataSourceController) GetDataSourceByName(c *gin.Context) {
})
}

dataSource, err := self.DataSourceService.GetByName(name, ownerId)
dataSource, err := self.DataSourceService.GetByName(name, ownerId, false)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] failed to get data source by name: %v\n", ownerId, sub, err))

Expand Down Expand Up @@ -180,13 +180,10 @@ func (self *DataSourceController) CreateDataSource(c *gin.Context) {

newDataSource.CreatedBy = sub
newDataSource.OrganizationId = ownerId
// newDataSource.UserInfo.Name = userData.Name
// newDataSource.UserInfo.Picture = userData.Picture
newDataSource.CreatedAt = time.Now()
newDataSource.UpdatedAt = time.Now()

if err := self.validate.Struct(newDataSource); err != nil {
logger.Info.Println(newDataSource)
validationErr := err.(validatorv10.ValidationErrors)
logger.Error.Println(fmt.Printf("[%s][%s] One or more data source fields are invalid: %s\n", ownerId, sub, validationErr))
c.JSON(http.StatusBadRequest, models.HTTPError{
Expand All @@ -196,7 +193,9 @@ func (self *DataSourceController) CreateDataSource(c *gin.Context) {
return
}

if err := self.DataSourceService.Create(newDataSource); err != nil {
createdDataSource, err := self.DataSourceService.Create(newDataSource)

if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Failed to create data source: %v\n", ownerId, sub, err))
c.JSON(http.StatusInternalServerError, models.HTTPError{
Error: "internal_server_error",
Expand All @@ -205,7 +204,7 @@ func (self *DataSourceController) CreateDataSource(c *gin.Context) {
return
}

c.JSON(http.StatusOK, models.HTTPSuccess{Message: "success"})
c.JSON(http.StatusOK, createdDataSource)
}

// @BasePath /
Expand All @@ -219,7 +218,7 @@ func (self *DataSourceController) CreateDataSource(c *gin.Context) {
// @Produce json
// @Param name path string true "Data Source Name" string
// @Param Data_Source body models.UpdateRequestDataSourceBody true "Data Source"
// @Success 200 {object} models.HTTPSuccess
// @Success 200 {object} models.DataSource
// @Failure 400 {object} models.HTTPError
// @Failure 401 {object} models.HTTPError
// @Router /datasources/{name} [put]
Expand Down Expand Up @@ -257,7 +256,7 @@ func (self *DataSourceController) UpdateDataSourceByName(c *gin.Context) {

updateDataSource.UpdatedAt = time.Now()

err := self.DataSourceService.Update(name, ownerId, updateDataSource)
updatedDataSource, err := self.DataSourceService.Update(name, ownerId, updateDataSource)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] failed to update data source: %v\n", ownerId, sub, err))
c.JSON(http.StatusBadRequest, models.HTTPError{
Expand All @@ -267,7 +266,7 @@ func (self *DataSourceController) UpdateDataSourceByName(c *gin.Context) {
return
}

c.JSON(http.StatusOK, models.HTTPSuccess{Message: "success"})
c.JSON(http.StatusOK, updatedDataSource)
}

// @BasePath /
Expand All @@ -280,7 +279,7 @@ func (self *DataSourceController) UpdateDataSourceByName(c *gin.Context) {
// @Accept json
// @Produce json
// @Param name path string true "Data Source Name" string
// @Success 200 {object} models.HTTPSuccess
// @Success 200 {object} models.DataSource
// @Failure 400 {object} models.HTTPError
// @Failure 401 {object} models.HTTPError
// @Failure 500 {object} models.HTTPError
Expand All @@ -306,7 +305,7 @@ func (self *DataSourceController) DeleteDataSourceByName(c *gin.Context) {

name := c.Param("name")

err := self.DataSourceService.Delete(name, ownerId)
deletedDataSource, err := self.DataSourceService.Delete(name, ownerId)
if err != nil {
logger.Error.Println(fmt.Printf("[%s][%s] Failed to delete data source: %v\n", ownerId, sub, err))
c.JSON(http.StatusBadRequest, models.HTTPError{
Expand All @@ -316,7 +315,7 @@ func (self *DataSourceController) DeleteDataSourceByName(c *gin.Context) {
return
}

c.JSON(http.StatusOK, models.HTTPSuccess{Message: "success"})
c.JSON(http.StatusOK, deletedDataSource)
}

func (self *DataSourceController) RegisterDataSourceRoutes(rg *gin.RouterGroup) {
Expand Down
9 changes: 6 additions & 3 deletions internal/middlewares/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middlewares

import (
"os"
"strings"

jwtmiddleware "github.com/auth0/go-jwt-middleware/v2"
"github.com/auth0/go-jwt-middleware/v2/validator"
Expand All @@ -19,12 +20,14 @@ func SetVarsToContext() gin.HandlerFunc {
claims := c.Request.Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims)
customClaims := claims.CustomClaims.(*CustomClaims)

sub := strings.Split(claims.RegisteredClaims.Subject, "|")[1]

if customClaims.OrganizationId != "" {
c.Set("owner_id", customClaims.OrganizationId)
c.Set("sub", claims.RegisteredClaims.Subject)
c.Set("sub", sub)
} else {
c.Set("owner_id", claims.RegisteredClaims.Subject)
c.Set("sub", claims.RegisteredClaims.Subject)
c.Set("owner_id", sub)
c.Set("sub", sub)
}
} else {
c.Set("owner_id", "poligono")
Expand Down
4 changes: 0 additions & 4 deletions internal/models/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ type ImproveQueryBody struct {
Execute bool `json:"execute"`
}

type QueryResult struct {
QueryMarkdown string `json:"query"`
}

type GenerateQueryActivity struct {
ID primitive.ObjectID `json:"-" bson:"_id,omitempty"`
Prompt string `json:"prompt" bson:"prompt"`
Expand Down
2 changes: 1 addition & 1 deletion internal/models/datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const (
)

type DataSource struct {
ID primitive.ObjectID `json:"-" bson:"omitempty,_id" swaggerignore:"true"`
ID primitive.ObjectID `json:"-" bson:"_id" swaggerignore:"true"`
Name string `json:"name" bson:"name" validate:"required"`
OrganizationId string `json:"organization_id" bson:"organization_id" validate:"required" swaggerignore:"true"`
CreatedBy string `json:"-" bson:"created_by" validate:"required"`
Expand Down
22 changes: 8 additions & 14 deletions internal/models/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,18 @@ type SQLSchema struct {
Name string `json:"name"`
}

type Schemas struct {
Schemas []Schema
}

type Schema struct {
ID primitive.ObjectID `json:"-" bson:"omitempty,_id" yaml:"-"`
Name string `json:"name" bson:"name"`
OrganizationId string `json:"organization_id" bson:"organization_id" yaml:"-"`
DataSourceName string `json:"data_source_name" bson:"data_source_name" yaml:"-"`
Description string `json:"description" bson:"description" yaml:"description,omitempty"`
Tables []Table `json:"tables" bson:"tables"`
ID primitive.ObjectID `json:"-" bson:"_id" yaml:"-"`
Name string `json:"name" bson:"name"`
DataSourceId primitive.ObjectID `json:"data_source_id" bson:"data_source_id" yaml:"-"`
Description string `json:"description" bson:"description" yaml:"description,omitempty"`
Tables []Table `json:"tables" bson:"tables"`
}

type UpdateSchema struct {
Name string `json:"name" bson:"name"`
DataSourceName string `json:"data_source_name" bson:"data_source_name" yaml:"-"`
Description string `json:"description" bson:"description" yaml:"description,omitempty"`
Tables []Table `json:"tables" bson:"tables"`
Name string `json:"name" bson:"name"`
Description string `json:"description" bson:"description" yaml:"description,omitempty"`
Tables []Table `json:"tables" bson:"tables"`
}

type SQLTable struct {
Expand Down
Loading

0 comments on commit 878d2a6

Please sign in to comment.