From 00846cd7d8edac8979adfd4193289530ee50b909 Mon Sep 17 00:00:00 2001 From: zhuoyuan-liu <117184318+zhuoyuan-liu@users.noreply.github.com> Date: Fri, 28 Feb 2025 10:44:32 +0100 Subject: [PATCH] Refactor and enhance pipeline with batch processing (#15) * Implement batch processing for node updates and refactor IP address handling * Remove history IP table * Remove unused function * remove history name tables * Remove unused GetMetadata function from NodeMetadata * Keep remove unused function --- cmd/tls/handlers/handlers.go | 26 +++-- cmd/tls/handlers/post.go | 64 ++++------- cmd/tls/handlers/writers.go | 93 ++++++++++++++++ pkg/logging/dispatch.go | 17 --- pkg/nodes/ipaddress.go | 117 -------------------- pkg/nodes/metadata.go | 17 --- pkg/nodes/names.go | 205 ----------------------------------- pkg/nodes/nodes.go | 204 +--------------------------------- 8 files changed, 131 insertions(+), 612 deletions(-) create mode 100644 cmd/tls/handlers/writers.go delete mode 100644 pkg/nodes/ipaddress.go delete mode 100644 pkg/nodes/names.go diff --git a/cmd/tls/handlers/handlers.go b/cmd/tls/handlers/handlers.go index 303ba087..6b6ea517 100644 --- a/cmd/tls/handlers/handlers.go +++ b/cmd/tls/handlers/handlers.go @@ -5,10 +5,10 @@ import ( "strconv" "time" - "github.com/jmpsec/osctrl/pkg/metrics" "github.com/jmpsec/osctrl/pkg/carves" "github.com/jmpsec/osctrl/pkg/environments" "github.com/jmpsec/osctrl/pkg/logging" + "github.com/jmpsec/osctrl/pkg/metrics" "github.com/jmpsec/osctrl/pkg/nodes" "github.com/jmpsec/osctrl/pkg/queries" "github.com/jmpsec/osctrl/pkg/settings" @@ -90,16 +90,17 @@ var validPlatform = map[string]bool{ // HandlersTLS to keep all handlers for TLS type HandlersTLS struct { - Envs *environments.Environment - EnvsMap *environments.MapEnvironments - Nodes *nodes.NodeManager - Tags *tags.TagManager - Queries *queries.Queries - Carves *carves.Carves - Settings *settings.Settings - SettingsMap *settings.MapSettings - Metrics *metrics.Metrics - Logs *logging.LoggerTLS + Envs *environments.Environment + EnvsMap *environments.MapEnvironments + Nodes *nodes.NodeManager + Tags *tags.TagManager + Queries *queries.Queries + Carves *carves.Carves + Settings *settings.Settings + SettingsMap *settings.MapSettings + Metrics *metrics.Metrics + Logs *logging.LoggerTLS + WriteHandler *batchWriter } // TLSResponse to be returned to requests @@ -186,6 +187,9 @@ func CreateHandlersTLS(opts ...Option) *HandlersTLS { for _, opt := range opts { opt(h) } + // All these opt function need be refactored to reduce unnecessary complexity + // For now, we hardcode the values for testing + h.WriteHandler = newBatchWriter(50, time.Minute, *h.Nodes) return h } diff --git a/cmd/tls/handlers/post.go b/cmd/tls/handlers/post.go index b4787fa5..6c65fb16 100644 --- a/cmd/tls/handlers/post.go +++ b/cmd/tls/handlers/post.go @@ -157,18 +157,12 @@ func (h *HandlersTLS) ConfigHandler(w http.ResponseWriter, r *http.Request) { utils.HTTPResponse(w, "", http.StatusInternalServerError, []byte("")) return } - // Check if provided node_key is valid and if so, update node + // We need to update the node info in another go routine if node, err := h.Nodes.GetByKey(t.NodeKey); err == nil { ip := utils.GetIP(r) - if err := h.Nodes.RecordIPAddress(ip, node); err != nil { - h.Inc(metricConfigErr) - log.Err(err).Msg("error recording IP address") - } - // Refresh last config for node - if err := h.Nodes.ConfigRefresh(node, ip, len(body)); err != nil { - h.Inc(metricConfigErr) - log.Err(err).Msg("error refreshing last config") - } + h.WriteHandler.addEvent(writeEvent{NodeID: node.ID, IP: ip}) + log.Debug().Msgf("node-uuid: %s with nodeid %d added to batch writer for config update", node.UUID, node.ID) + // Record ingested data requestSize.WithLabelValues(string(env.UUID), "ConfigHandler").Observe(float64(len(body))) log.Debug().Msgf("node UUID: %s in %s environment ingested %d bytes for ConfigHandler endpoint", node.UUID, env.Name, len(body)) @@ -328,22 +322,17 @@ func (h *HandlersTLS) QueryReadHandler(w http.ResponseWriter, r *http.Request) { // Record ingested data requestSize.WithLabelValues(string(env.UUID), "QueryRead").Observe(float64(len(body))) log.Debug().Msgf("node UUID: %s in %s environment ingested %d bytes for QueryReadHandler endpoint", node.UUID, env.Name, len(body)) - ip := utils.GetIP(r) - if err := h.Nodes.RecordIPAddress(ip, node); err != nil { - h.Inc(metricReadErr) - log.Err(err).Msg("error recording IP address") - } + nodeInvalid = false qs, accelerate, err = h.Queries.NodeQueries(node) if err != nil { h.Inc(metricReadErr) log.Err(err).Msg("error getting queries from db") } - // Refresh last query read request - if err := h.Nodes.QueryReadRefresh(node, ip, len(body)); err != nil { - h.Inc(metricReadErr) - log.Err(err).Msg("error refreshing last query read") - } + // Refresh node last seen + ip := utils.GetIP(r) + h.WriteHandler.addEvent(writeEvent{NodeID: node.ID, IP: ip}) + log.Debug().Msgf("node-uuid: %s with nodeid %d added to batch writer for query read update", node.UUID, node.ID) } else { log.Err(err).Msg("GetByKey") nodeInvalid = true @@ -413,11 +402,7 @@ func (h *HandlersTLS) QueryWriteHandler(w http.ResponseWriter, r *http.Request) // Record ingested data requestSize.WithLabelValues(string(env.UUID), "QueryWrite").Observe(float64(len(body))) log.Debug().Msgf("node UUID: %s in %s environment ingested %d bytes for QueryWriteHandler endpoint", node.UUID, env.Name, len(body)) - ip := utils.GetIP(r) - if err := h.Nodes.RecordIPAddress(ip, node); err != nil { - h.Inc(metricWriteErr) - log.Err(err).Msg("error recording IP address") - } + nodeInvalid = false for name, c := range t.Queries { var carves []types.QueryCarveScheduled @@ -432,10 +417,9 @@ func (h *HandlersTLS) QueryWriteHandler(w http.ResponseWriter, r *http.Request) } } } - if err := h.Nodes.QueryWriteRefresh(node, ip, len(body)); err != nil { - h.Inc(metricWriteErr) - log.Err(err).Msg("error refreshing last query write") - } + // Refresh node last seen + ip := utils.GetIP(r) + h.WriteHandler.addEvent(writeEvent{NodeID: node.ID, IP: ip}) // Process submitted results and mark query as processed go h.Logs.ProcessLogQueryResult(t, env.ID, (*h.EnvsMap)[env.Name].DebugHTTP) } else { @@ -668,11 +652,7 @@ func (h *HandlersTLS) CarveInitHandler(w http.ResponseWriter, r *http.Request) { // Record ingested data requestSize.WithLabelValues(string(env.UUID), "CarveInit").Observe(float64(len(body))) log.Debug().Msgf("node UUID: %s in %s environment ingested %d bytes for CarveInitHandler endpoint", node.UUID, env.Name, len(body)) - ip := utils.GetIP(r) - if err := h.Nodes.RecordIPAddress(ip, node); err != nil { - h.Inc(metricInitErr) - log.Err(err).Msg("error recording IP address") - } + initCarve = true carveSessionID = generateCarveSessionID() // Process carve init @@ -681,11 +661,9 @@ func (h *HandlersTLS) CarveInitHandler(w http.ResponseWriter, r *http.Request) { log.Err(err).Msg("error procesing carve init") initCarve = false } - // Refresh last carve request - if err := h.Nodes.CarveRefresh(node, ip, len(body)); err != nil { - h.Inc(metricInitErr) - log.Err(err).Msg("error refreshing last carve init") - } + // Refresh last seen + ip := utils.GetIP(r) + h.WriteHandler.addEvent(writeEvent{NodeID: node.ID, IP: ip}) } // Prepare response response := types.CarveInitResponse{Success: initCarve, SessionID: carveSessionID} @@ -748,11 +726,9 @@ func (h *HandlersTLS) CarveBlockHandler(w http.ResponseWriter, r *http.Request) blockCarve = true // Process received block go h.ProcessCarveBlock(t, env.Name, carve.UUID, env.ID) - // Refresh last carve request - if err := h.Nodes.CarveRefreshByUUID(carve.UUID, utils.GetIP(r), len(body)); err != nil { - h.Inc(metricBlockErr) - log.Err(err).Msg("error refreshing last carve init") - } + // Refresh last seen + ip := utils.GetIP(r) + h.WriteHandler.addEvent(writeEvent{NodeID: carve.NodeID, IP: ip}) } // Prepare response response := types.CarveBlockResponse{Success: blockCarve} diff --git a/cmd/tls/handlers/writers.go b/cmd/tls/handlers/writers.go new file mode 100644 index 00000000..8243e3de --- /dev/null +++ b/cmd/tls/handlers/writers.go @@ -0,0 +1,93 @@ +package handlers + +import ( + "time" + + "github.com/jmpsec/osctrl/pkg/nodes" + "github.com/rs/zerolog/log" +) + +// writeEvent represents a single update request. +type writeEvent struct { + NodeID uint + IP string +} + +// batchWriter encapsulates the batching logic. +type batchWriter struct { + events chan writeEvent + batchSize int // minimum number of events before flushing + timeout time.Duration // maximum wait time before flushing + nodesRepo nodes.NodeManager +} + +// newBatchWriter creates and starts a new batch writer. +func newBatchWriter(batchSize int, timeout time.Duration, repo nodes.NodeManager) *batchWriter { + bw := &batchWriter{ + events: make(chan writeEvent, 2000), // buffer size as needed + batchSize: batchSize, + timeout: timeout, + nodesRepo: repo, + } + go bw.run() + return bw +} + +// addEvent sends a new write event to the batch writer. +func (bw *batchWriter) addEvent(ev writeEvent) { + bw.events <- ev +} + +// run is the background worker that collects and flushes events. +func (bw *batchWriter) run() { + batch := make(map[uint]writeEvent) + timer := time.NewTimer(bw.timeout) + defer timer.Stop() + for { + select { + case ev, ok := <-bw.events: + if !ok { + // Channel closed: flush any remaining events. + if len(batch) > 0 { + bw.flush(batch) + } + return + } + // Overwrite any existing event for the same NodeID. + batch[ev.NodeID] = ev + + // Flush if we have reached the batch size threshold. + if len(batch) >= bw.batchSize { + if !timer.Stop() { + <-timer.C // drain the timer channel if necessary + } + bw.flush(batch) + batch = make(map[uint]writeEvent) + timer.Reset(bw.timeout) + } + case <-timer.C: + if len(batch) > 0 { + bw.flush(batch) + batch = make(map[uint]writeEvent) + } + timer.Reset(bw.timeout) + } + } +} + +// flush performs the bulk update for a batch of events. +func (bw *batchWriter) flush(batch map[uint]writeEvent) { + + nodeIDs := make([]uint, 0, len(batch)) + for _, ev := range batch { + nodeIDs = append(nodeIDs, ev.NodeID) + + // TODO: Implement the actual update logic. + // Update the node's IP address. + // Since the IP address changes infrequently, no need to update in bulk. + } + log.Info().Int("count", len(batch)).Msg("flushed batch") + if err := bw.nodesRepo.RefreshLastSeenBatch(nodeIDs); err != nil { + log.Err(err).Msg("refreshing last seen batch failed") + } +} diff --git a/pkg/logging/dispatch.go b/pkg/logging/dispatch.go index 72b0a6e3..ad79df00 100644 --- a/pkg/logging/dispatch.go +++ b/pkg/logging/dispatch.go @@ -20,19 +20,6 @@ func (l *LoggerTLS) DispatchLogs(data []byte, uuid, logType, environment string, log.Debug().Msgf("dispatching logs to %s", l.Logging) } l.Log(logType, data, environment, uuid, debug) - // Refresh last logging request - if logType == types.StatusLog { - // Update metadata for node - if err := l.Nodes.RefreshLastStatus(uuid); err != nil { - log.Err(err).Msg("error refreshing last status") - } - } - if logType == types.ResultLog { - // Update metadata for node - if err := l.Nodes.RefreshLastResult(uuid); err != nil { - log.Err(err).Msg("error refreshing last result") - } - } } // DispatchQueries - Helper to dispatch queries @@ -42,10 +29,6 @@ func (l *LoggerTLS) DispatchQueries(queryData types.QueryWriteData, node nodes.O if err != nil { log.Err(err).Msg("error preparing data") } - // Refresh last query write request - if err := l.Nodes.RefreshLastQueryWrite(node.UUID); err != nil { - log.Err(err).Msg("error refreshing last query write") - } // Send data to storage // FIXME allow multiple types of logging if debug { diff --git a/pkg/nodes/ipaddress.go b/pkg/nodes/ipaddress.go deleted file mode 100644 index 4aa2be9f..00000000 --- a/pkg/nodes/ipaddress.go +++ /dev/null @@ -1,117 +0,0 @@ -package nodes - -import ( - "fmt" - - "gorm.io/gorm" -) - -// NodeHistoryIPAddress to keep track of all IP Addresses for nodes -type NodeHistoryIPAddress struct { - gorm.Model - UUID string `gorm:"index"` - IPAddress string - Count int -} - -// UpdateIPAddress to update the node IP Address -func (n *NodeManager) UpdateIPAddress(ipaddress string, node OsqueryNode) error { - data := OsqueryNode{ - IPAddress: "", - } - if (ipaddress != "") && (ipaddress != node.IPAddress) { - data.IPAddress = ipaddress - e := NodeHistoryIPAddress{ - UUID: node.UUID, - IPAddress: ipaddress, - Count: 1, - } - if err := n.NewHistoryIPAddress(e); err != nil { - return fmt.Errorf("newNodeHistoryIPAddress %v", err) - } - if err := n.DB.Model(&node).Updates(data).Error; err != nil { - return fmt.Errorf("Updates %v", err) - } - } else { - if err := n.IncHistoryIPAddress(node.UUID, ipaddress); err != nil { - return fmt.Errorf("incNodeHistoryIPAddress %v", err) - } - } - return nil -} - -// RecordIPAddress to update and archive the node IP Address -func (n *NodeManager) RecordIPAddress(ipaddress string, node OsqueryNode) error { - if ipaddress == "" { - return nil - } - if !n.SeenIPAddress(node.UUID, ipaddress) { - e := NodeHistoryIPAddress{ - UUID: node.UUID, - IPAddress: ipaddress, - Count: 1, - } - if err := n.NewHistoryIPAddress(e); err != nil { - return fmt.Errorf("newNodeHistoryIPAddress %v", err) - } - } else { - if err := n.IncHistoryIPAddress(node.UUID, ipaddress); err != nil { - return fmt.Errorf("newNodeHistoryIPAddress %v", err) - } - } - return nil -} - -// UpdateIPAddressByUUID to update node IP Address by UUID -func (n *NodeManager) UpdateIPAddressByUUID(ipaddress, uuid string) error { - node, err := n.GetByUUID(uuid) - if err != nil { - return fmt.Errorf("getNodeByUUID %v", err) - } - return n.UpdateIPAddress(ipaddress, node) -} - -// UpdateIPAddressByKey to update node IP Address by node_key -func (n *NodeManager) UpdateIPAddressByKey(ipaddress, nodekey string) error { - node, err := n.GetByKey(nodekey) - if err != nil { - return fmt.Errorf("getNodeByKey %v", err) - } - return n.UpdateIPAddress(ipaddress, node) -} - -// NewHistoryIPAddress to insert new entry for the history of IP Addresses -func (n *NodeManager) NewHistoryIPAddress(entry NodeHistoryIPAddress) error { - if err := n.DB.Create(&entry).Error; err != nil { - return err - } - return nil -} - -// GetHistoryIPAddress to retrieve the History IP Address record by UUID and the IP Address -func (n *NodeManager) GetHistoryIPAddress(uuid, ipaddress string) (NodeHistoryIPAddress, error) { - var nodeip NodeHistoryIPAddress - if err := n.DB.Where("uuid = ? AND ip_address = ?", uuid, ipaddress).Order("updated_at").First(&nodeip).Error; err != nil { - return nodeip, err - } - return nodeip, nil -} - -// IncHistoryIPAddress to increase the count for this IP Address -func (n *NodeManager) IncHistoryIPAddress(uuid, ipaddress string) error { - nodeip, err := n.GetHistoryIPAddress(uuid, ipaddress) - if err != nil { - return fmt.Errorf("getNodeHistoryIPAddress %v", err) - } - if err := n.DB.Model(&nodeip).Update("count", nodeip.Count+1).Error; err != nil { - return fmt.Errorf("Update %v", err) - } - return nil -} - -// SeenIPAddress to check if an IP Address has been seen per node by UUID -func (n *NodeManager) SeenIPAddress(uuid, ipaddress string) bool { - var results int64 - n.DB.Model(&NodeHistoryIPAddress{}).Where("uuid = ? AND ip_address = ?", uuid, ipaddress).Count(&results) - return (results > 0) -} diff --git a/pkg/nodes/metadata.go b/pkg/nodes/metadata.go index 3d27b75d..77d55400 100644 --- a/pkg/nodes/metadata.go +++ b/pkg/nodes/metadata.go @@ -14,20 +14,3 @@ type NodeMetadata struct { PlatformVersion string BytesReceived int } - -// GetMetadata to extract the metadata struct from a node -func (n *NodeManager) GetMetadata(node OsqueryNode) NodeMetadata { - return NodeMetadata{ - IPAddress: node.IPAddress, - Username: node.Username, - OsqueryUser: node.OsqueryUser, - Hostname: node.Hostname, - Localname: node.Localname, - ConfigHash: node.ConfigHash, - DaemonHash: node.DaemonHash, - OsqueryVersion: node.OsqueryVersion, - Platform: node.Platform, - PlatformVersion: node.PlatformVersion, - BytesReceived: node.BytesReceived, - } -} diff --git a/pkg/nodes/names.go b/pkg/nodes/names.go deleted file mode 100644 index d124c4c0..00000000 --- a/pkg/nodes/names.go +++ /dev/null @@ -1,205 +0,0 @@ -package nodes - -import ( - "fmt" - - "gorm.io/gorm" -) - -// NodeHistoryHostname to keep track of all IP Addresses for nodes -type NodeHistoryHostname struct { - gorm.Model - UUID string `gorm:"index"` - Hostname string - Count int -} - -// NodeHistoryLocalname to keep track of all IP Addresses for nodes -type NodeHistoryLocalname struct { - gorm.Model - UUID string `gorm:"index"` - Localname string - Count int -} - -// NodeHistoryUsername to keep track of all usernames for nodes -type NodeHistoryUsername struct { - gorm.Model - UUID string `gorm:"index"` - Username string - Count int -} - -// NewHistoryHostname to insert new entry for the history of Hostnames -func (n *NodeManager) NewHistoryHostname(entry NodeHistoryHostname) error { - if err := n.DB.Create(&entry).Error; err != nil { - return fmt.Errorf("Create newNodeHistoryHostname %v", err) - } - return nil -} - -// NewHistoryLocalname to insert new entry for the history of Localnames -func (n *NodeManager) NewHistoryLocalname(entry NodeHistoryLocalname) error { - if err := n.DB.Create(&entry).Error; err != nil { - return fmt.Errorf("Create newNodeHistoryLocalname %v", err) - } - return nil -} - -// NewHistoryUsername to insert new entry for the history of Usernames -func (n *NodeManager) NewHistoryUsername(entry NodeHistoryUsername) error { - if err := n.DB.Create(&entry).Error; err != nil { - return fmt.Errorf("Create newNodeHistoryUsername %v", err) - } - return nil -} - -// SeenUsername to check if an username has been seen per node by UUID -func (n *NodeManager) SeenUsername(uuid, username string) bool { - var results int64 - n.DB.Model(&NodeHistoryUsername{}).Where("uuid = ? AND username = ?", uuid, username).Count(&results) - return (results > 0) -} - -// SeenHostname to check if an hostname has been seen per node by UUID -func (n *NodeManager) SeenHostname(uuid, hostname string) bool { - var results int64 - n.DB.Model(&NodeHistoryHostname{}).Where("uuid = ? AND hostname = ?", uuid, hostname).Count(&results) - return (results > 0) -} - -// SeenLocalname to check if an localname has been seen per node by UUID -func (n *NodeManager) SeenLocalname(uuid, localname string) bool { - var results int64 - n.DB.Model(&NodeHistoryLocalname{}).Where("uuid = ? AND localname = ?", uuid, localname).Count(&results) - return (results > 0) -} - -// RecordLocalname to update and archive the node localname -func (n *NodeManager) RecordLocalname(localname string, node OsqueryNode) error { - if localname == "" { - return nil - } - if !n.SeenLocalname(node.UUID, localname) { - e := NodeHistoryLocalname{ - UUID: node.UUID, - Localname: localname, - Count: 1, - } - if err := n.NewHistoryLocalname(e); err != nil { - return fmt.Errorf("newNodeHistoryLocalname %v", err) - } - } else { - if err := n.IncHistoryLocalname(node.UUID, localname); err != nil { - return fmt.Errorf("newNodeHistoryLocalname %v", err) - } - } - return nil -} - -// RecordHostname to update and archive the node hostname -func (n *NodeManager) RecordHostname(hostname string, node OsqueryNode) error { - if hostname == "" { - return nil - } - if !n.SeenLocalname(node.UUID, hostname) { - e := NodeHistoryHostname{ - UUID: node.UUID, - Hostname: hostname, - Count: 1, - } - if err := n.NewHistoryHostname(e); err != nil { - return fmt.Errorf("newNodeHistoryHostname %v", err) - } - } else { - if err := n.IncHistoryLocalname(node.UUID, hostname); err != nil { - return fmt.Errorf("newNodeHistoryHostname %v", err) - } - } - return nil -} - -// RecordUsername to update and archive the node username -func (n *NodeManager) RecordUsername(username string, node OsqueryNode) error { - if username == "" { - return nil - } - if !n.SeenUsername(node.UUID, username) { - e := NodeHistoryUsername{ - UUID: node.UUID, - Username: username, - Count: 1, - } - if err := n.NewHistoryUsername(e); err != nil { - return fmt.Errorf("newNodeHistoryUsername %v", err) - } - } else { - if err := n.IncHistoryUsername(node.UUID, username); err != nil { - return fmt.Errorf("newNodeHistoryUsername %v", err) - } - } - return nil -} - -// GetHistoryLocalname to retrieve the History localname record by UUID and the localname -func (n *NodeManager) GetHistoryLocalname(uuid, localname string) (NodeHistoryLocalname, error) { - var nodeLocalname NodeHistoryLocalname - if err := n.DB.Where("uuid = ? AND localname = ?", uuid, localname).Order("updated_at").First(&nodeLocalname).Error; err != nil { - return nodeLocalname, err - } - return nodeLocalname, nil -} - -// GetHistoryHostname to retrieve the History hostname record by UUID and the hostname -func (n *NodeManager) GetHistoryHostname(uuid, hostname string) (NodeHistoryHostname, error) { - var nodeHostname NodeHistoryHostname - if err := n.DB.Where("uuid = ? AND hostname = ?", uuid, hostname).Order("updated_at").First(&nodeHostname).Error; err != nil { - return nodeHostname, err - } - return nodeHostname, nil -} - -// GetHistoryUsername to retrieve the History username record by UUID and the username -func (n *NodeManager) GetHistoryUsername(uuid, username string) (NodeHistoryUsername, error) { - var nodeUsername NodeHistoryUsername - if err := n.DB.Where("uuid = ? AND username = ?", uuid, username).Order("updated_at").First(&nodeUsername).Error; err != nil { - return nodeUsername, err - } - return nodeUsername, nil -} - -// IncHistoryLocalname to increase the count for this localname -func (n *NodeManager) IncHistoryLocalname(uuid, localname string) error { - nodeLocalname, err := n.GetHistoryLocalname(uuid, localname) - if err != nil { - return fmt.Errorf("getNodeHistoryLocalname %v", err) - } - if err := n.DB.Model(&nodeLocalname).Update("count", nodeLocalname.Count+1).Error; err != nil { - return fmt.Errorf("Update %v", err) - } - return nil -} - -// IncHistoryUsername to increase the count for this username -func (n *NodeManager) IncHistoryUsername(uuid, username string) error { - nodeUsername, err := n.GetHistoryUsername(uuid, username) - if err != nil { - return fmt.Errorf("getNodeHistoryUsername %v", err) - } - if err := n.DB.Model(&nodeUsername).Update("count", nodeUsername.Count+1).Error; err != nil { - return fmt.Errorf("Update %v", err) - } - return nil -} - -// IncHistoryHostname to increase the count for this hostname -func (n *NodeManager) IncHistoryHostname(uuid, localname string) error { - nodeLocalname, err := n.GetHistoryHostname(uuid, localname) - if err != nil { - return fmt.Errorf("getNodeHistoryLocalname %v", err) - } - if err := n.DB.Model(&nodeLocalname).Update("count", nodeLocalname.Count+1).Error; err != nil { - return fmt.Errorf("Update %v", err) - } - return nil -} diff --git a/pkg/nodes/nodes.go b/pkg/nodes/nodes.go index a327a6ca..4c24abb4 100644 --- a/pkg/nodes/nodes.go +++ b/pkg/nodes/nodes.go @@ -105,22 +105,6 @@ func CreateNodes(backend *gorm.DB) *NodeManager { if err := backend.AutoMigrate(&ArchiveOsqueryNode{}); err != nil { log.Fatal().Msgf("Failed to AutoMigrate table (archive_osquery_nodes): %v", err) } - // table node_history_ipaddress - if err := backend.AutoMigrate(&NodeHistoryIPAddress{}); err != nil { - log.Fatal().Msgf("Failed to AutoMigrate table (node_history_ipaddress): %v", err) - } - // table node_history_hostname - if err := backend.AutoMigrate(&NodeHistoryHostname{}); err != nil { - log.Fatal().Msgf("Failed to AutoMigrate table (node_history_hostname): %v", err) - } - // table node_history_localname - if err := backend.AutoMigrate(&NodeHistoryLocalname{}); err != nil { - log.Fatal().Msgf("Failed to AutoMigrate table (node_history_localname): %v", err) - } - // table node_history_username - if err := backend.AutoMigrate(&NodeHistoryUsername{}); err != nil { - log.Fatal().Msgf("Failed to AutoMigrate table (node_history_username): %v", err) - } return n } @@ -343,30 +327,18 @@ func (n *NodeManager) UpdateMetadataByUUID(uuid string, metadata NodeMetadata) e "bytes_received": node.BytesReceived + metadata.BytesReceived, } // Record username - if err := n.RecordUsername(metadata.Username, node); err != nil { - return fmt.Errorf("RecordUsername %v", err) - } if metadata.Username != node.Username && metadata.Username != "" { updates["username"] = metadata.Username } // Record hostname - if err := n.RecordHostname(metadata.Hostname, node); err != nil { - return fmt.Errorf("RecordHostname %v", err) - } if metadata.Hostname != node.Hostname && metadata.Hostname != "" { updates["hostname"] = metadata.Hostname } // Record localname - if err := n.RecordLocalname(metadata.Localname, node); err != nil { - return fmt.Errorf("RecordLocalname %v", err) - } if metadata.Localname != node.Localname && metadata.Localname != "" { updates["localname"] = metadata.Localname } // Record IP address - if err := n.RecordIPAddress(metadata.IPAddress, node); err != nil { - return fmt.Errorf("RecordIPAddress %v", err) - } if metadata.IPAddress != node.IPAddress && metadata.IPAddress != "" { updates["ip_address"] = metadata.IPAddress } @@ -394,35 +366,6 @@ func (n *NodeManager) Create(node *OsqueryNode) error { if err := n.DB.Create(&node).Error; err != nil { return fmt.Errorf("Create %v", err) } - h := NodeHistoryHostname{ - UUID: node.UUID, - Hostname: node.Hostname, - } - if err := n.NewHistoryHostname(h); err != nil { - return fmt.Errorf("newNodeHistoryHostname %v", err) - } - l := NodeHistoryLocalname{ - UUID: node.UUID, - Localname: node.Localname, - } - if err := n.NewHistoryLocalname(l); err != nil { - return fmt.Errorf("newNodeHistoryLocalname %v", err) - } - i := NodeHistoryIPAddress{ - UUID: node.UUID, - IPAddress: node.IPAddress, - Count: 1, - } - if err := n.NewHistoryIPAddress(i); err != nil { - return fmt.Errorf("newNodeHistoryIPAddress %v", err) - } - u := NodeHistoryUsername{ - UUID: node.UUID, - Username: node.Username, - } - if err := n.NewHistoryUsername(u); err != nil { - return fmt.Errorf("newNodeHistoryUsername %v", err) - } return nil } @@ -475,57 +418,6 @@ func (n *NodeManager) ArchiveDeleteByUUID(uuid string) error { return nil } -// RefreshLastEventByUUID to refresh the last status log for this node -func (n *NodeManager) RefreshLastEventByUUID(uuid, event string) error { - node, err := n.GetByUUID(uuid) - if err != nil { - return fmt.Errorf("getNodeByUUID %v", err) - } - return n.RefreshLastEvent(node, event) -} - -// RefreshLastEventByKey to refresh the last status log for this node -func (n *NodeManager) RefreshLastEventByKey(nodeKey, event string) error { - node, err := n.GetByKey(nodeKey) - if err != nil { - return err - } - return n.RefreshLastEvent(node, event) -} - -// RefreshLastEvent to refresh the last status log for this node -func (n *NodeManager) RefreshLastEvent(node OsqueryNode, event string) error { - if err := n.DB.Model(&node).Update(event, time.Now()).Error; err != nil { - return fmt.Errorf("Update %v", err) - } - return nil -} - -// RefreshLastStatus to refresh the last status log for this node -func (n *NodeManager) RefreshLastStatus(uuid string) error { - return n.RefreshLastEventByUUID(uuid, "last_status") -} - -// RefreshLastResult to refresh the last result log for this node -func (n *NodeManager) RefreshLastResult(uuid string) error { - return n.RefreshLastEventByUUID(uuid, "last_result") -} - -// RefreshLastConfig to refresh the last configuration for this node -func (n *NodeManager) RefreshLastConfig(nodeKey string) error { - return n.RefreshLastEventByKey(nodeKey, "last_config") -} - -// RefreshLastQueryRead to refresh the last on-demand query read for this node -func (n *NodeManager) RefreshLastQueryRead(nodeKey string) error { - return n.RefreshLastEventByKey(nodeKey, "last_query_read") -} - -// RefreshLastQueryWrite to refresh the last on-demand query write for this node -func (n *NodeManager) RefreshLastQueryWrite(uuid string) error { - return n.RefreshLastEventByUUID(uuid, "last_query_write") -} - // Helper to convert an enrolled osquery node into an archived osquery node func nodeArchiveFromNode(node OsqueryNode, trigger string) ArchiveOsqueryNode { return ArchiveOsqueryNode{ @@ -559,24 +451,6 @@ func nodeArchiveFromNode(node OsqueryNode, trigger string) ArchiveOsqueryNode { } } -// IncreaseBytesByUUID to update received bytes by UUID -func (n *NodeManager) IncreaseBytesByUUID(uuid string, incBytes int) error { - node, err := n.GetByUUID(uuid) - if err != nil { - return fmt.Errorf("getNodeByUUID %v", err) - } - return n.IncreaseBytes(node, incBytes) -} - -// IncreaseBytesByKey to update received bytes by node_key -func (n *NodeManager) IncreaseBytesByKey(nodekey string, incBytes int) error { - node, err := n.GetByKey(nodekey) - if err != nil { - return fmt.Errorf("getNodeByKey %v", err) - } - return n.IncreaseBytes(node, incBytes) -} - // IncreaseBytes to update received bytes per node func (n *NodeManager) IncreaseBytes(node OsqueryNode, incBytes int) error { if err := n.DB.Model(&node).Update("bytes_received", node.BytesReceived+incBytes).Error; err != nil { @@ -585,19 +459,9 @@ func (n *NodeManager) IncreaseBytes(node OsqueryNode, incBytes int) error { return nil } -// ConfigRefresh to perform all needed update operations per node in a config request -func (n *NodeManager) ConfigRefresh(node OsqueryNode, lastIp string, incBytes int) error { - updates := map[string]interface{}{ - "last_config": time.Now(), - "bytes_received": node.BytesReceived + incBytes, - } - if lastIp != "" { - updates["ip_address"] = lastIp - } - if err := n.DB.Model(&node).Updates(updates).Error; err != nil { - return fmt.Errorf("Updates %v", err) - } - return nil +func (n *NodeManager) RefreshLastSeenBatch(nodeID []uint) error { + + return n.DB.Model(&OsqueryNode{}).Where("id IN ?", nodeID).UpdateColumn("last_config", time.Now()).Error } // MetadataRefresh to perform all needed update operations per node to keep metadata refreshed @@ -607,65 +471,3 @@ func (n *NodeManager) MetadataRefresh(node OsqueryNode, updates map[string]inter } return nil } - -// QueryReadRefresh to perform all needed update operations per node in a query read request -func (n *NodeManager) QueryReadRefresh(node OsqueryNode, lastIp string, incBytes int) error { - updates := map[string]interface{}{ - "last_query_read": time.Now(), - "bytes_received": node.BytesReceived + incBytes, - } - if lastIp != "" { - updates["ip_address"] = lastIp - } - if err := n.DB.Model(&node).Updates(updates).Error; err != nil { - return fmt.Errorf("Updates %v", err) - } - return nil -} - -// QueryWriteRefresh to perform all needed update operations per node in a query write request -func (n *NodeManager) QueryWriteRefresh(node OsqueryNode, lastIp string, incBytes int) error { - updates := map[string]interface{}{ - "last_query_write": time.Now(), - "bytes_received": node.BytesReceived + incBytes, - } - if lastIp != "" { - updates["ip_address"] = lastIp - } - if err := n.DB.Model(&node).Updates(updates).Error; err != nil { - return fmt.Errorf("Updates %v", err) - } - return nil -} - -// CarveRefresh to perform all needed update operations per node in a carve request -func (n *NodeManager) CarveRefresh(node OsqueryNode, lastIp string, incBytes int) error { - updates := map[string]interface{}{ - "bytes_received": node.BytesReceived + incBytes, - } - if lastIp != "" { - updates["ip_address"] = lastIp - } - if err := n.DB.Model(&node).Updates(updates).Error; err != nil { - return fmt.Errorf("Updates %v", err) - } - return nil -} - -// CarveRefreshByUUID to perform all needed update operations per node in a carve request -func (n *NodeManager) CarveRefreshByUUID(uuid, lastIp string, incBytes int) error { - node, err := n.GetByUUID(uuid) - if err != nil { - return fmt.Errorf("getNodeByUUID %v", err) - } - updates := map[string]interface{}{ - "bytes_received": node.BytesReceived + incBytes, - } - if lastIp != "" { - updates["ip_address"] = lastIp - } - if err := n.DB.Model(&node).Updates(updates).Error; err != nil { - return fmt.Errorf("Updates %v", err) - } - return nil -}