Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

assigner: implements task assignment for scans and dummy mock test #95

Merged
merged 7 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cmd

import (
"fmt"
"github.com/shinobistack/gokakashi/internal/assigner"
"github.com/shinobistack/gokakashi/internal/db"
"log"
"os"
Expand Down Expand Up @@ -103,7 +104,9 @@ func handleConfigV1() {
// Populate the database
db.PopulateDatabase(configDB, cfg)

// log.Println("Shutting down goKakashi gracefully...")
// ToDo: To be go routine who independently and routinely checks and assigns scans in agentTasks table
go assigner.StartAssigner(cfg.Site.Host, cfg.Site.Port, cfg.Site.APIToken, 1*time.Minute)

}

func handleConfigV0() {
Expand Down
202 changes: 202 additions & 0 deletions internal/assigner/assigner.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
package assigner

import (
"bytes"
"encoding/json"
"fmt"
"github.com/google/uuid"
"github.com/shinobistack/gokakashi/internal/restapi/v1/agents"
"github.com/shinobistack/gokakashi/internal/restapi/v1/agenttasks"
"github.com/shinobistack/gokakashi/internal/restapi/v1/scans"
"log"
"net/http"
"net/url"
"strings"
"time"
)

func normalizeServer(server string) string {
if !strings.HasPrefix(server, "http://") && !strings.HasPrefix(server, "https://") {
server = "http://" + server // Default to HTTP
}
return server
}

func constructURL(server string, port int, path string) string {
base := normalizeServer(server)
u, err := url.Parse(base)
if err != nil {
log.Fatalf("Invalid server URL: %s", base)
}
if u.Port() == "" {
u.Host = fmt.Sprintf("%s:%d", u.Host, port)
}
u.Path = path
return u.String()
}

func StartAssigner(server string, port int, token string, interval time.Duration) {
log.Println("Starting the periodic task assigner...")
ticker := time.NewTicker(interval)
defer ticker.Stop()

for range ticker.C {
AssignTasks(server, port, token)
}

}

func AssignTasks(server string, port int, token string) {
log.Println("Assigner now begins assigning your scans")
// Step 1: Fetch scans needing assignment
pendingScans, err := fetchPendingScans(server, port, token, "scan_pending")
if err != nil {
log.Printf("Error fetching pending scans: %v", err)
return
}

if len(pendingScans) == 0 {
log.Println("No pending scans to assign.")
return
}

// Step 2: Fetch available agents
availableAgents, err := fetchAvailableAgents(server, port, token, "connected")
if err != nil {
log.Printf("Error fetching available agents: %v", err)
return
}

if len(availableAgents) == 0 {
log.Println("No agents available for assignment.")
return
}

// log.Printf("Agents are available: %v", availableAgents)

// Step 3: Assign scans to agents
// ToDo: to explore task assignment for better efficiency
for i, scan := range pendingScans {
// Check if scan is already assigned
if isScanAssigned(server, port, token, scan.ID) {
log.Printf("Scan ID %s is already assigned. Skipping.", scan.ID)
continue
}

// Select agent using round-robin
agent := availableAgents[i%len(availableAgents)]
if err := createAgentTask(server, port, token, agent.ID, scan.ID); err != nil {
log.Printf("Failed to assign scan %s to agent %d: %v", scan.ID, agent.ID, err)
} else {
log.Printf("Successfully assigned scan %s to agent %d", scan.ID, agent.ID)
}

}
}

func fetchPendingScans(server string, port int, token, status string) ([]scans.GetScanResponse, error) {
url := constructURL(server, port, "/api/v1/scans") + fmt.Sprintf("?status=%s", status)

req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request for pending scans: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)

resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("error making request: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("server responded with status: %d", resp.StatusCode)
}

var scans []scans.GetScanResponse
if err := json.NewDecoder(resp.Body).Decode(&scans); err != nil {
return nil, fmt.Errorf("failed to decode scans response: %w", err)
}

return scans, nil
}

func fetchAvailableAgents(server string, port int, token, status string) ([]agents.GetAgentResponse, error) {
url := constructURL(server, port, "/api/v1/agents") + fmt.Sprintf("?status=%s", status)

req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token)

resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("server responded with status: %d", resp.StatusCode)
}

var agents []agents.GetAgentResponse
if err := json.NewDecoder(resp.Body).Decode(&agents); err != nil {
return nil, err
}

return agents, nil
}

func isScanAssigned(server string, port int, token string, scanID uuid.UUID) bool {
url := constructURL(server, port, "/api/v1/agents/tasks") + fmt.Sprintf("?scan_id=%s", scanID)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
log.Printf("Error checking scan assignment: %v", err)
return false
}

req.Header.Set("Authorization", "Bearer "+token)

resp, err := http.DefaultClient.Do(req)
if err != nil {
log.Printf("Error checking scan assignment: %v", err)
return false
}
defer resp.Body.Close()

return resp.StatusCode == http.StatusOK
}

func createAgentTask(server string, port int, token string, agentID int, scanID uuid.UUID) error {
url := constructURL(server, port, fmt.Sprintf("/api/v1/agents/%d/tasks", agentID))

reqBody := agenttasks.CreateAgentTaskRequest{
AgentID: agentID,
ScanID: scanID,
Status: "pending",
CreatedAt: time.Now(),
}

reqBodyJSON, _ := json.Marshal(reqBody)

req, err := http.NewRequest("POST", url, bytes.NewBuffer(reqBodyJSON))
if err != nil {
return err
}

req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Content-Type", "application/json")

resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusCreated {
return fmt.Errorf("server responded with status: %d", resp.StatusCode)
}

return nil
}
129 changes: 129 additions & 0 deletions internal/assigner/assigner_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package assigner_test

import (
"encoding/json"
"github.com/google/uuid"
"github.com/shinobistack/gokakashi/internal/assigner"
"net/http"
"net/http/httptest"
"testing"
)

type MockScan struct {
ID uuid.UUID `json:"id"`
Status string `json:"status"`
}

type MockAgent struct {
ID int `json:"id"`
Status string `json:"status"`
}

func TestAssignTasks(t *testing.T) {
// Mock data
mockScans := []MockScan{
{ID: uuid.New(), Status: "scan_pending"},
{ID: uuid.New(), Status: "scan_pending"},
{ID: uuid.New(), Status: "scan_pending"},
{ID: uuid.New(), Status: "scan_pending"},
{ID: uuid.New(), Status: "scan_pending"},
{ID: uuid.New(), Status: "scan_pending"},
{ID: uuid.New(), Status: "scan_pending"},
{ID: uuid.New(), Status: "scan_pending"},
}
mockAgents := []MockAgent{
{ID: 1, Status: "connected"},
{ID: 2, Status: "connected"},
}

// Mock server
scanHandler := func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/v1/scans" && r.URL.Query().Get("status") == "scan_pending" {
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(mockScans); err != nil {
http.Error(w, "Failed to encode mock scans", http.StatusInternalServerError)
return
}
} else if r.URL.Path == "/api/v1/agents" && r.URL.Query().Get("status") == "connected" {
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(mockAgents); err != nil {
http.Error(w, "Failed to encode mock agents", http.StatusInternalServerError)
return
}
} else if r.URL.Path == "/api/v1/agents/tasks" {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusNotFound)
}
}

mockServer := httptest.NewServer(http.HandlerFunc(scanHandler))
defer mockServer.Close()

// Run the assigner logic
assigner.AssignTasks(mockServer.URL, 0, "mock-token")

t.Log("Ensure tasks are assigned to agents in round-robin fashion.")
}

func TestAssignTasksWithNoAgents(t *testing.T) {
mockScans := []MockScan{
{ID: uuid.New(), Status: "scan_pending"},
}

scanHandler := func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/v1/scans" && r.URL.Query().Get("status") == "scan_pending" {
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(mockScans); err != nil {
http.Error(w, "Failed to encode mock scans", http.StatusInternalServerError)
return
}
} else if r.URL.Path == "/api/v1/agents" {
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode([]MockAgent{}); err != nil {
http.Error(w, "Failed to encode mock agents", http.StatusInternalServerError)
return
} // No agents available
} else {
w.WriteHeader(http.StatusNotFound)
}
}

mockServer := httptest.NewServer(http.HandlerFunc(scanHandler))
defer mockServer.Close()

assigner.AssignTasks(mockServer.URL, 0, "mock-token")

t.Log("Ensure no assignments are made when no agents are available.")
}

func TestAssignTasksWithNoScans(t *testing.T) {
mockAgents := []MockAgent{
{ID: 1, Status: "connected"},
}

scanHandler := func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/v1/scans" {
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode([]MockScan{}); err != nil {
http.Error(w, "Failed to encode mock scans", http.StatusInternalServerError)
return
} // No scans available
} else if r.URL.Path == "/api/v1/agents" && r.URL.Query().Get("status") == "connected" {
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(mockAgents); err != nil {
http.Error(w, "Failed to encode mock agents", http.StatusInternalServerError)
return
}
} else {
w.WriteHeader(http.StatusNotFound)
}
}

mockServer := httptest.NewServer(http.HandlerFunc(scanHandler))
defer mockServer.Close()

assigner.AssignTasks(mockServer.URL, 0, "mock-token")

t.Log("Ensure no assignments are made when no scans are pending.")
}
10 changes: 9 additions & 1 deletion internal/restapi/v1/agents/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"github.com/shinobistack/gokakashi/ent"
"github.com/shinobistack/gokakashi/ent/agents"
"github.com/swaggest/usecase/status"
"time"
)
Expand All @@ -24,6 +25,7 @@ type ListAgentsResponse struct {
}

type PollAgentsRequest struct {
Status string `query:"status"`
}

type PollAgentsResponse struct {
Expand Down Expand Up @@ -78,7 +80,13 @@ func GetAgent(client *ent.Client) func(ctx context.Context, req GetAgentRequest,

func PollAgents(client *ent.Client) func(ctx context.Context, req PollAgentsRequest, res *[]PollAgentsResponse) error {
return func(ctx context.Context, req PollAgentsRequest, res *[]PollAgentsResponse) error {
agentsList, err := client.Agents.Query().All(ctx)
query := client.Agents.Query()

if req.Status != "" {
query = query.Where(agents.Status(req.Status))
}

agentsList, err := query.All(ctx)
if err != nil {
return status.Wrap(err, status.Internal)
}
Expand Down
Loading
Loading