Refactor and enhance pipeline with batch processing #607

Open
wants to merge 1 commit into base: main
26 changes: 15 additions & 11 deletions cmd/tls/handlers/handlers.go
@@ -5,10 +5,10 @@ import (
"strconv"
"time"

"github.com/jmpsec/osctrl/pkg/metrics"
"github.com/jmpsec/osctrl/pkg/carves"
"github.com/jmpsec/osctrl/pkg/environments"
"github.com/jmpsec/osctrl/pkg/logging"
"github.com/jmpsec/osctrl/pkg/metrics"
"github.com/jmpsec/osctrl/pkg/nodes"
"github.com/jmpsec/osctrl/pkg/queries"
"github.com/jmpsec/osctrl/pkg/settings"
@@ -90,16 +90,17 @@ var validPlatform = map[string]bool{

// HandlersTLS to keep all handlers for TLS
type HandlersTLS struct {
Envs *environments.Environment
EnvsMap *environments.MapEnvironments
Nodes *nodes.NodeManager
Tags *tags.TagManager
Queries *queries.Queries
Carves *carves.Carves
Settings *settings.Settings
SettingsMap *settings.MapSettings
Metrics *metrics.Metrics
Logs *logging.LoggerTLS
Envs *environments.Environment
EnvsMap *environments.MapEnvironments
Nodes *nodes.NodeManager
Tags *tags.TagManager
Queries *queries.Queries
Carves *carves.Carves
Settings *settings.Settings
SettingsMap *settings.MapSettings
Metrics *metrics.Metrics
Logs *logging.LoggerTLS
WriteHandler *batchWriter
}

// TLSResponse to be returned to requests
@@ -186,6 +187,9 @@ func CreateHandlersTLS(opts ...Option) *HandlersTLS {
for _, opt := range opts {
opt(h)
}
// All these opt functions need to be refactored to reduce unnecessary complexity
// For now, we hardcode the values for testing
h.WriteHandler = newBatchWriter(50, time.Minute, *h.Nodes)
return h
}
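
The two values hardcoded above (50 events, a one-minute flush timeout) are the batch writer's tuning knobs. A minimal sketch of how they could be supplied through the existing functional-options pattern instead of being hardcoded; WithBatchWriter and the batchSize/batchTimeout fields are hypothetical and not part of this PR, and the sketch assumes Option is defined as func(*HandlersTLS):

// Hypothetical option in the same style as the existing With* options.
// It only records the desired settings; the writer itself is still created
// in CreateHandlersTLS after all options have run, so h.Nodes is already set.
func WithBatchWriter(batchSize int, timeout time.Duration) Option {
	return func(h *HandlersTLS) {
		h.batchSize = batchSize   // hypothetical field on HandlersTLS
		h.batchTimeout = timeout  // hypothetical field on HandlersTLS
	}
}

CreateHandlersTLS would then call newBatchWriter(h.batchSize, h.batchTimeout, *h.Nodes) after the options loop, falling back to the current defaults when the option is not supplied.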

64 changes: 20 additions & 44 deletions cmd/tls/handlers/post.go
@@ -157,18 +157,12 @@ func (h *HandlersTLS) ConfigHandler(w http.ResponseWriter, r *http.Request) {
utils.HTTPResponse(w, "", http.StatusInternalServerError, []byte(""))
return
}
// Check if provided node_key is valid and if so, update node
// We need to update the node info in another go routine
if node, err := h.Nodes.GetByKey(t.NodeKey); err == nil {
ip := utils.GetIP(r)
if err := h.Nodes.RecordIPAddress(ip, node); err != nil {
h.Inc(metricConfigErr)
log.Err(err).Msg("error recording IP address")
}
// Refresh last config for node
if err := h.Nodes.ConfigRefresh(node, ip, len(body)); err != nil {
h.Inc(metricConfigErr)
log.Err(err).Msg("error refreshing last config")
}
h.WriteHandler.addEvent(writeEvent{NodeID: node.ID, IP: ip})
log.Debug().Msgf("node-uuid: %s with nodeid %d added to batch writer for config update", node.UUID, node.ID)

// Record ingested data
requestSize.WithLabelValues(string(env.UUID), "ConfigHandler").Observe(float64(len(body)))
log.Debug().Msgf("node UUID: %s in %s environment ingested %d bytes for ConfigHandler endpoint", node.UUID, env.Name, len(body))
@@ -328,22 +322,17 @@ func (h *HandlersTLS) QueryReadHandler(w http.ResponseWriter, r *http.Request) {
// Record ingested data
requestSize.WithLabelValues(string(env.UUID), "QueryRead").Observe(float64(len(body)))
log.Debug().Msgf("node UUID: %s in %s environment ingested %d bytes for QueryReadHandler endpoint", node.UUID, env.Name, len(body))
ip := utils.GetIP(r)
if err := h.Nodes.RecordIPAddress(ip, node); err != nil {
h.Inc(metricReadErr)
log.Err(err).Msg("error recording IP address")
}

nodeInvalid = false
qs, accelerate, err = h.Queries.NodeQueries(node)
if err != nil {
h.Inc(metricReadErr)
log.Err(err).Msg("error getting queries from db")
}
// Refresh last query read request
if err := h.Nodes.QueryReadRefresh(node, ip, len(body)); err != nil {
h.Inc(metricReadErr)
log.Err(err).Msg("error refreshing last query read")
}
// Refresh node last seen
ip := utils.GetIP(r)
Copilot AI commented on Mar 1, 2025:
[nitpick] In the QueryReadHandler, utils.GetIP(r) is called a second time even though it was already obtained earlier. Consider capturing the IP address once and reusing it to avoid potential inconsistencies.
Suggested change: drop the duplicate line
ip := utils.GetIP(r)
h.WriteHandler.addEvent(writeEvent{NodeID: node.ID, IP: ip})
log.Debug().Msgf("node-uuid: %s with nodeid %d added to batch writer for query read update", node.UUID, node.ID)
} else {
log.Err(err).Msg("GetByKey")
nodeInvalid = true
@@ -413,11 +402,7 @@ func (h *HandlersTLS) QueryWriteHandler(w http.ResponseWriter, r *http.Request)
// Record ingested data
requestSize.WithLabelValues(string(env.UUID), "QueryWrite").Observe(float64(len(body)))
log.Debug().Msgf("node UUID: %s in %s environment ingested %d bytes for QueryWriteHandler endpoint", node.UUID, env.Name, len(body))
ip := utils.GetIP(r)
if err := h.Nodes.RecordIPAddress(ip, node); err != nil {
h.Inc(metricWriteErr)
log.Err(err).Msg("error recording IP address")
}

nodeInvalid = false
for name, c := range t.Queries {
var carves []types.QueryCarveScheduled
@@ -432,10 +417,9 @@
}
}
}
if err := h.Nodes.QueryWriteRefresh(node, ip, len(body)); err != nil {
h.Inc(metricWriteErr)
log.Err(err).Msg("error refreshing last query write")
}
// Refresh node last seen
ip := utils.GetIP(r)
h.WriteHandler.addEvent(writeEvent{NodeID: node.ID, IP: ip})
// Process submitted results and mark query as processed
go h.Logs.ProcessLogQueryResult(t, env.ID, (*h.EnvsMap)[env.Name].DebugHTTP)
} else {
@@ -668,11 +652,7 @@ func (h *HandlersTLS) CarveInitHandler(w http.ResponseWriter, r *http.Request) {
// Record ingested data
requestSize.WithLabelValues(string(env.UUID), "CarveInit").Observe(float64(len(body)))
log.Debug().Msgf("node UUID: %s in %s environment ingested %d bytes for CarveInitHandler endpoint", node.UUID, env.Name, len(body))
ip := utils.GetIP(r)
if err := h.Nodes.RecordIPAddress(ip, node); err != nil {
h.Inc(metricInitErr)
log.Err(err).Msg("error recording IP address")
}

initCarve = true
carveSessionID = generateCarveSessionID()
// Process carve init
@@ -681,11 +661,9 @@ func (h *HandlersTLS) CarveInitHandler(w http.ResponseWriter, r *http.Request) {
log.Err(err).Msg("error procesing carve init")
initCarve = false
}
// Refresh last carve request
if err := h.Nodes.CarveRefresh(node, ip, len(body)); err != nil {
h.Inc(metricInitErr)
log.Err(err).Msg("error refreshing last carve init")
}
// Refresh last seen
ip := utils.GetIP(r)
h.WriteHandler.addEvent(writeEvent{NodeID: node.ID, IP: ip})
}
// Prepare response
response := types.CarveInitResponse{Success: initCarve, SessionID: carveSessionID}
@@ -748,11 +726,9 @@ func (h *HandlersTLS) CarveBlockHandler(w http.ResponseWriter, r *http.Request)
blockCarve = true
// Process received block
go h.ProcessCarveBlock(t, env.Name, carve.UUID, env.ID)
// Refresh last carve request
if err := h.Nodes.CarveRefreshByUUID(carve.UUID, utils.GetIP(r), len(body)); err != nil {
h.Inc(metricBlockErr)
log.Err(err).Msg("error refreshing last carve init")
}
// Refresh last seen
ip := utils.GetIP(r)
h.WriteHandler.addEvent(writeEvent{NodeID: carve.NodeID, IP: ip})
}
// Prepare response
response := types.CarveBlockResponse{Success: blockCarve}
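
Every handler in this file now repeats the same two lines: capture the client IP with utils.GetIP and enqueue a writeEvent for the node. A minimal sketch of factoring that pair into one helper inside package handlers; recordNodeSeen is hypothetical and not part of this PR:

// recordNodeSeen captures the request IP once and hands the metadata update
// to the batch writer; callers pass the node ID they already resolved.
func (h *HandlersTLS) recordNodeSeen(r *http.Request, nodeID uint) {
	ip := utils.GetIP(r)
	h.WriteHandler.addEvent(writeEvent{NodeID: nodeID, IP: ip})
}

Each call site would then shrink to a single h.recordNodeSeen(r, node.ID) call (or carve.NodeID in CarveBlockHandler), and the duplicate GetIP call flagged by the Copilot comment disappears by construction.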
93 changes: 93 additions & 0 deletions cmd/tls/handlers/writers.go
@@ -0,0 +1,93 @@
package handlers

import (
"time"

"github.com/jmpsec/osctrl/pkg/nodes"
"github.com/rs/zerolog/log"
)

// writeEvent represents a single update request.
type writeEvent struct {
Collaborator commented on writeEvent:
Either we have a typed event for each update, or something generic enough that can be used for any updates
(a sketch of a generic variant follows this file's diff)

NodeID uint
IP string
}

// batchWriter encapsulates the batching logic.
type batchWriter struct {
events chan writeEvent
batchSize int // minimum number of events before flushing
timeout time.Duration // maximum wait time before flushing
nodesRepo nodes.NodeManager
}

// newBatchWriter creates and starts a new batch writer.
func newBatchWriter(batchSize int, timeout time.Duration, repo nodes.NodeManager) *batchWriter {
bw := &batchWriter{
events: make(chan writeEvent, 2000), // buffer size as needed
batchSize: batchSize,
timeout: timeout,
nodesRepo: repo,
}
go bw.run()
return bw
}

// addEvent sends a new write event to the batch writer.
func (bw *batchWriter) addEvent(ev writeEvent) {
bw.events <- ev
}

// run is the background worker that collects and flushes events.
func (bw *batchWriter) run() {
batch := make(map[uint]writeEvent)
timer := time.NewTimer(bw.timeout)
defer timer.Stop()
for {
select {
case ev, ok := <-bw.events:
if !ok {
// Channel closed: flush any remaining events.
if len(batch) > 0 {
bw.flush(batch)
}
return
}
// Overwrite any existing event for the same NodeID.
batch[ev.NodeID] = ev

// Flush if we have reached the batch size threshold.
if len(batch) >= bw.batchSize {
if !timer.Stop() {
<-timer.C // drain the timer channel if necessary
}
bw.flush(batch)
batch = make(map[uint]writeEvent)
timer.Reset(bw.timeout)
}
case <-timer.C:
if len(batch) > 0 {
bw.flush(batch)
batch = make(map[uint]writeEvent)
}
timer.Reset(bw.timeout)
}
}
}

// flush performs the bulk update for a batch of events.
func (bw *batchWriter) flush(batch map[uint]writeEvent) {

nodeIDs := make([]uint, 0, len(batch))
for _, ev := range batch {
nodeIDs = append(nodeIDs, ev.NodeID)

// TODO: Implement the actual update logic.
// Update the node's IP address.
// Since the IP address changes infrequently, no need to update in bulk.
}
log.Info().Int("count", len(batch)).Msg("flushed batch")
if err := bw.nodesRepo.RefreshLastSeenBatch(nodeIDs); err != nil {
log.Err(err).Msg("refreshing last seen batch failed")
}
}
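
On the collaborator comment about writeEvent above: a minimal sketch of the "generic enough" variant, where each event carries which metadata field it refreshes, so a single writer can serve config, query read/write, and carve updates. The updateKind type, its constants, and genericWriteEvent are hypothetical and not part of this PR:

// updateKind tells the flusher which node metadata the event refreshes.
type updateKind int

const (
	updateLastSeen updateKind = iota
	updateLastConfig
	updateLastQueryRead
	updateLastQueryWrite
	updateLastCarve
)

// genericWriteEvent is one struct for any update, distinguished by Kind,
// so flush can bucket events per kind before issuing each bulk refresh.
type genericWriteEvent struct {
	Kind   updateKind
	NodeID uint
	IP     string
	Bytes  int // request size, when the update wants to record it
}

flush would then group the batch by Kind and call the matching bulk refresh per group; the deduplication key would become the (NodeID, Kind) pair rather than NodeID alone.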
17 changes: 0 additions & 17 deletions pkg/logging/dispatch.go
@@ -20,19 +20,6 @@ func (l *LoggerTLS) DispatchLogs(data []byte, uuid, logType, environment string,
log.Debug().Msgf("dispatching logs to %s", l.Logging)
}
l.Log(logType, data, environment, uuid, debug)
// Refresh last logging request
if logType == types.StatusLog {
// Update metadata for node
if err := l.Nodes.RefreshLastStatus(uuid); err != nil {
log.Err(err).Msg("error refreshing last status")
}
}
if logType == types.ResultLog {
// Update metadata for node
if err := l.Nodes.RefreshLastResult(uuid); err != nil {
log.Err(err).Msg("error refreshing last result")
}
}
}

// DispatchQueries - Helper to dispatch queries
Expand All @@ -42,10 +29,6 @@ func (l *LoggerTLS) DispatchQueries(queryData types.QueryWriteData, node nodes.O
if err != nil {
log.Err(err).Msg("error preparing data")
}
// Refresh last query write request
if err := l.Nodes.RefreshLastQueryWrite(node.UUID); err != nil {
log.Err(err).Msg("error refreshing last query write")
}
// Send data to storage
// FIXME allow multiple types of logging
if debug {