diff --git a/cmd/forward.go b/cmd/forward.go index 9bc7ce24..e69825bf 100644 --- a/cmd/forward.go +++ b/cmd/forward.go @@ -31,6 +31,11 @@ var forwardCmd = &cobra.Command{ return fmt.Errorf("the forward target server host at position %d is identical to the source server host %s", i, sourceUrl.Host) } } + for _, isTrino := range forward.PrestoFlagsArray.IsTrino { + if isTrino { + return fmt.Errorf("forward command does not support Trino yet") + } + } return nil }, Short: "Watch incoming query workloads from the first Presto cluster (cluster 0) and forward them to the rest clusters.", @@ -39,6 +44,7 @@ var forwardCmd = &cobra.Command{ func init() { RootCmd.AddCommand(forwardCmd) forward.PrestoFlagsArray.Install(forwardCmd) + _ = forwardCmd.Flags().MarkHidden("trino") wd, _ := os.Getwd() forwardCmd.Flags().StringVarP(&forward.OutputPath, "output-path", "o", wd, "Output directory path") forwardCmd.Flags().StringVarP(&forward.RunName, "name", "n", fmt.Sprintf("forward_%s", time.Now().Format(utils.DirectoryNameTimeFormat)), `Assign a name to this run. (default: "forward_")`) diff --git a/cmd/forward/main.go b/cmd/forward/main.go index 296c6b96..9d58639d 100644 --- a/cmd/forward/main.go +++ b/cmd/forward/main.go @@ -2,12 +2,17 @@ package forward import ( "context" - "fmt" "github.com/spf13/cobra" + "net/http" + "os" + "os/signal" + "path/filepath" "pbench/log" "pbench/presto" "pbench/utils" "sync" + "sync/atomic" + "syscall" "time" ) @@ -17,49 +22,141 @@ var ( RunName string PollInterval time.Duration - runningTasks sync.WaitGroup + runningTasks sync.WaitGroup + failedToForward atomic.Uint32 + forwarded atomic.Uint32 ) -type QueryHistory struct { - QueryId string `presto:"query_id"` - Query string `presto:"query"` - Created *time.Time `presto:"created"` -} - func Run(_ *cobra.Command, _ []string) { - //OutputPath = filepath.Join(OutputPath, RunName) - //utils.PrepareOutputDirectory(OutputPath) - // - //// also start to write logs to the output directory from this point on. - //logPath := filepath.Join(OutputPath, "forward.log") - //flushLog := utils.InitLogFile(logPath) - //defer flushLog() + OutputPath = filepath.Join(OutputPath, RunName) + utils.PrepareOutputDirectory(OutputPath) - prestoClusters := PrestoFlagsArray.Assemble() + // also start to write logs to the output directory from this point on. + logPath := filepath.Join(OutputPath, "forward.log") + flushLog := utils.InitLogFile(logPath) + defer flushLog() + + ctx, cancel := context.WithCancel(context.Background()) + timeToExit := make(chan os.Signal, 1) + signal.Notify(timeToExit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + // Handle SIGINT, SIGTERM, and SIGQUIT. When ctx is canceled, in-progress MySQL transactions and InfluxDB operations will roll back. + go func() { + sig := <-timeToExit + if sig != nil { + log.Info().Msg("abort forwarding") + cancel() + } + }() + + prestoClusters := PrestoFlagsArray.Pivot() // The design here is to forward the traffic from cluster 0 to the rest. sourceClusterSize := 0 clients := make([]*presto.Client, 0, len(prestoClusters)) for i, cluster := range prestoClusters { clients = append(clients, cluster.NewPrestoClient()) - if stats, _, err := clients[i].GetClusterInfo(context.Background()); err != nil { - log.Fatal().Err(err).Msgf("cannot connect to cluster at position %d", i) + // Check if we can connect to the cluster. + if stats, _, err := clients[i].GetClusterInfo(ctx); err != nil { + log.Fatal().Err(err).Msgf("cannot connect to cluster at position %d: %s", i, cluster.ServerUrl) } else if i == 0 { sourceClusterSize = stats.ActiveWorkers } else if stats.ActiveWorkers != sourceClusterSize { - log.Warn().Msgf("source cluster size does not match target cluster %d size (%d != %d)", i, stats.ActiveWorkers, sourceClusterSize) + log.Warn().Msgf("the source cluster and target cluster %d do not match in size (%d != %d)", i, sourceClusterSize, stats.ActiveWorkers) } } sourceClient := clients[0] trueValue := true - states, _, err := sourceClient.GetQueryState(context.Background(), &presto.GetQueryStatsOptions{ - IncludeAllQueries: &trueValue, - IncludeAllQueryProgressStats: nil, - ExcludeResourceGroupPathInfo: nil, - QueryTextSizeLimit: nil, - }) - if err != nil { - log.Fatal().Err(err).Msgf("cannot get query states") + // lastQueryStateCheckCutoffTime is the query create time of the most recent query in the previous batch. + // We only look at queries created later than this timestamp in the following batch. + lastQueryStateCheckCutoffTime := time.Time{} + firstBatch := true + // Keep running until the source cluster becomes unavailable or the user interrupts or quits using Ctrl + C or Ctrl + D. + for ctx.Err() == nil { + states, _, err := sourceClient.GetQueryState(ctx, &presto.GetQueryStatsOptions{IncludeAllQueries: &trueValue}) + if err != nil { + log.Error().Err(err).Msgf("failed to get query states") + break + } + newCutoffTime := time.Time{} + for _, state := range states { + if !state.CreateTime.After(lastQueryStateCheckCutoffTime) { + // We looked at this query in the previous batch. + continue + } + if newCutoffTime.Before(state.CreateTime) { + newCutoffTime = state.CreateTime + } + if !firstBatch { + runningTasks.Add(1) + go forwardQuery(ctx, &state, clients) + } + } + firstBatch = false + if newCutoffTime.After(lastQueryStateCheckCutoffTime) { + lastQueryStateCheckCutoffTime = newCutoffTime + } + timer := time.NewTimer(PollInterval) + select { + case <-ctx.Done(): + case <-timer.C: + } + } + runningTasks.Wait() + // This causes the signal handler to exit. + close(timeToExit) + log.Info().Uint32("forwarded", forwarded.Load()).Uint32("failed_to_forward", failedToForward.Load()). + Msgf("finished forwarding queries") +} + +func forwardQuery(ctx context.Context, queryState *presto.QueryStateInfo, clients []*presto.Client) { + defer runningTasks.Done() + queryInfo, _, queryInfoErr := clients[0].GetQueryInfo(ctx, queryState.QueryId, false, nil) + if queryInfoErr != nil { + log.Error().Str("query_id", queryState.QueryId).Err(queryInfoErr).Msg("failed to get query info for forwarding") + failedToForward.Add(1) + return + } + SessionPropertyHeader := clients[0].GenerateSessionParamsHeaderValue(queryInfo.Session.CollectSessionProperties()) + successful, failed := atomic.Uint32{}, atomic.Uint32{} + forwardedQueries := sync.WaitGroup{} + for i := 1; i < len(clients); i++ { + forwardedQueries.Add(1) + go func(client *presto.Client) { + defer forwardedQueries.Done() + clientResult, _, queryErr := client.Query(ctx, queryInfo.Query, func(req *http.Request) { + if queryInfo.Session.Catalog != nil { + req.Header.Set(presto.CatalogHeader, *queryInfo.Session.Catalog) + } + if queryInfo.Session.Schema != nil { + req.Header.Set(presto.SchemaHeader, *queryInfo.Session.Schema) + } + req.Header.Set(presto.SessionHeader, SessionPropertyHeader) + req.Header.Set(presto.SourceHeader, queryInfo.QueryId) + }) + if queryErr != nil { + log.Error().Str("source_query_id", queryInfo.QueryId). + Str("target_host", client.GetHost()).Err(queryErr).Msg("failed to execute query") + failed.Add(1) + return + } + rowCount := 0 + drainErr := clientResult.Drain(ctx, func(qr *presto.QueryResults) error { + rowCount += len(qr.Data) + return nil + }) + if drainErr != nil { + log.Error().Str("source_query_id", queryInfo.QueryId). + Str("target_host", client.GetHost()).Err(drainErr).Msg("failed to fetch query result") + failed.Add(1) + return + } + successful.Add(1) + log.Info().Str("source_query_id", queryInfo.QueryId). + Str("target_host", client.GetHost()).Int("row_count", rowCount).Msg("query executed successfully") + }(clients[i]) } - fmt.Printf("%#v", states) + forwardedQueries.Wait() + log.Info().Str("source_query_id", queryInfo.QueryId).Uint32("successful", successful.Load()). + Uint32("failed", failed.Load()).Msg("query forwarding finished") + forwarded.Add(1) } diff --git a/cmd/replay.go b/cmd/replay.go index 9595514b..aa956b21 100644 --- a/cmd/replay.go +++ b/cmd/replay.go @@ -24,6 +24,9 @@ We also expect the queries in this CSV file are sorted by "create_time" in ascen } utils.ExpandHomeDirectory(&replay.OutputPath) utils.ExpandHomeDirectory(&args[0]) + if replay.PrestoFlags.IsTrino { + return fmt.Errorf("replay command does not support Trino yet") + } return nil }, Run: replay.Run, @@ -33,6 +36,7 @@ func init() { RootCmd.AddCommand(replayCmd) wd, _ := os.Getwd() replay.PrestoFlags.Install(replayCmd) + _ = replayCmd.Flags().MarkHidden("trino") replayCmd.Flags().StringVarP(&replay.OutputPath, "output-path", "o", wd, "Output directory path") replayCmd.Flags().StringVarP(&replay.RunName, "name", "n", fmt.Sprintf("replay_%s", time.Now().Format(utils.DirectoryNameTimeFormat)), `Assign a name to this run. (default: "replay_")`) } diff --git a/presto/client.go b/presto/client.go index d54af8a6..84c4b9e2 100644 --- a/presto/client.go +++ b/presto/client.go @@ -91,6 +91,10 @@ func GenerateHttpQueryParameter(v any) string { return queryBuilder.String() } +func (c *Client) GetHost() string { + return c.serverUrl.Host +} + func (c *Client) setHeader(key, value string) { if c.isTrino { key = strings.Replace(key, "X-Presto", "X-Trino", 1) diff --git a/presto/client_test.go b/presto/client_test.go index 141e84c6..3e080e97 100644 --- a/presto/client_test.go +++ b/presto/client_test.go @@ -5,6 +5,7 @@ import ( "errors" "github.com/stretchr/testify/assert" "pbench/presto" + "pbench/presto/query_json" "strings" "syscall" "testing" @@ -42,9 +43,14 @@ func TestQuery(t *testing.T) { assert.Equal(t, 150000, rowCount) buf := &strings.Builder{} - _, err = client.GetQueryInfo(context.Background(), qr.Id, false, buf) + var queryInfo *query_json.QueryInfo + queryInfo, _, err = client.GetQueryInfo(context.Background(), qr.Id, false, buf) assert.Nil(t, err) + assert.Nil(t, queryInfo) assert.Greater(t, buf.Len(), 0) + queryInfo, _, err = client.GetQueryInfo(context.Background(), qr.Id, true, nil) + assert.Nil(t, err) + assert.Equal(t, qr.Id, queryInfo.QueryId) } func TestGenerateQueryParameter(t *testing.T) { diff --git a/presto/query.go b/presto/query.go index 2ff7fc4e..55ee6439 100644 --- a/presto/query.go +++ b/presto/query.go @@ -4,6 +4,7 @@ import ( "context" "io" "net/http" + "pbench/presto/query_json" ) func (c *Client) requestQueryResults(ctx context.Context, req *http.Request) (*QueryResults, *http.Response, error) { @@ -49,7 +50,9 @@ func (c *Client) CancelQuery(ctx context.Context, nextUri string, opts ...Reques return c.requestQueryResults(ctx, req) } -func (c *Client) GetQueryInfo(ctx context.Context, queryId string, pretty bool, writer io.Writer, opts ...RequestOption) (*http.Response, error) { +// GetQueryInfo retrieves the query JSON for the given query ID. +// If writer is nil, we return deserialized QueryInfo. Otherwise, we just return the raw buffer. +func (c *Client) GetQueryInfo(ctx context.Context, queryId string, pretty bool, writer io.Writer, opts ...RequestOption) (*query_json.QueryInfo, *http.Response, error) { urlStr := "v1/query/" + queryId if pretty { urlStr += "?pretty" @@ -57,11 +60,20 @@ func (c *Client) GetQueryInfo(ctx context.Context, queryId string, pretty bool, req, err := c.NewRequest("GET", urlStr, nil, opts...) if err != nil { - return nil, err + return nil, nil, err + } + var ( + resp *http.Response + queryInfo *query_json.QueryInfo + ) + if writer != nil { + resp, err = c.Do(ctx, req, writer) + } else { + queryInfo = new(query_json.QueryInfo) + resp, err = c.Do(ctx, req, queryInfo) } - resp, err := c.Do(ctx, req, writer) if err != nil { - return resp, err + return nil, resp, err } - return resp, nil + return queryInfo, resp, nil } diff --git a/presto/query_json/session.go b/presto/query_json/session.go index 1d043f8f..38d07ce2 100644 --- a/presto/query_json/session.go +++ b/presto/query_json/session.go @@ -51,3 +51,19 @@ func (s *Session) PrepareForInsert() { s.SessionPropertiesJson = string(jsonBytes[:len(jsonBytes)-1]) } } + +func (s *Session) CollectSessionProperties() map[string]any { + sessionParams := make(map[string]any) + if s == nil { + return sessionParams + } + for k, v := range s.SystemProperties { + sessionParams[k] = v + } + for catalog, catalogProps := range s.CatalogProperties { + for k, v := range catalogProps { + sessionParams[catalog+"."+k] = v + } + } + return sessionParams +} diff --git a/presto/query_state.go b/presto/query_state.go index 8c575c66..dd4fa258 100644 --- a/presto/query_state.go +++ b/presto/query_state.go @@ -11,16 +11,16 @@ import ( // https://github.com/prestodb/presto/blob/master/presto-main/src/main/java/com/facebook/presto/server/QueryStateInfo.java // Unused fields are commented out for now. type QueryStateInfo struct { - QueryId string `json:"queryId"` - QueryState string `json:"queryState"` + QueryId string `json:"queryId"` + //QueryState string `json:"queryState"` //ResourceGroupId []string `json:"resourceGroupId"` - Query string `json:"query"` - QueryTruncated bool `json:"queryTruncated"` - CreateTime time.Time `json:"createTime"` - User string `json:"user"` - Authenticated bool `json:"authenticated"` - Source string `json:"source"` - Catalog string `json:"catalog"` + //Query string `json:"query"` + //QueryTruncated bool `json:"queryTruncated"` + CreateTime time.Time `json:"createTime"` + //User string `json:"user"` + //Authenticated bool `json:"authenticated"` + //Source string `json:"source,omitempty"` + //Catalog string `json:"catalog"` //Progress struct { // ElapsedTimeMillis int `json:"elapsedTimeMillis"` // QueuedTimeMillis int `json:"queuedTimeMillis"` @@ -61,7 +61,7 @@ func (c *Client) GetQueryState(ctx context.Context, reqOpt *GetQueryStatsOptions } infoArray := make([]QueryStateInfo, 0, 16) - resp, err := c.Do(ctx, req, infoArray) + resp, err := c.Do(ctx, req, &infoArray) if err != nil { return nil, resp, err } diff --git a/stage/stage_utils.go b/stage/stage_utils.go index 9530930d..db479835 100644 --- a/stage/stage_utils.go +++ b/stage/stage_utils.go @@ -270,7 +270,7 @@ func (s *Stage) saveQueryJsonFile(result *QueryResult) { checkErr(err) if err == nil { // We need to save the query json file even if the stage context is canceled. - _, err = s.Client.GetQueryInfo(utils.GetCtxWithTimeout(time.Second*5), result.QueryId, false, queryJsonFile) + _, _, err = s.Client.GetQueryInfo(utils.GetCtxWithTimeout(time.Second*5), result.QueryId, false, queryJsonFile) checkErr(err) checkErr(queryJsonFile.Close()) } diff --git a/utils/presto_flags.go b/utils/presto_flags.go index fee56d62..91020d67 100644 --- a/utils/presto_flags.go +++ b/utils/presto_flags.go @@ -56,7 +56,8 @@ func (a *PrestoFlagsArray) Install(cmd *cobra.Command) { cmd.Flags().StringArrayVarP(&a.Password, "password", "p", []string{""}, "Presto user password (optional)") } -func (a *PrestoFlagsArray) Assemble() []PrestoFlags { +// Pivot generates PrestoFlags array that is suitable for creating Presto clients conveniently. +func (a *PrestoFlagsArray) Pivot() []PrestoFlags { ret := make([]PrestoFlags, 0, len(a.ServerUrl)) for _, url := range a.ServerUrl { ret = append(ret, PrestoFlags{