Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor updateStopRules and add tests #2361

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 32 additions & 106 deletions cmd/metricscollector/v1beta1/file-metricscollector/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ import (
"fmt"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -141,21 +140,15 @@ func printMetricsFile(mFile string) {
}

func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string, fileFormat commonv1beta1.FileFormat) {
// First metric is objective in metricNames array.
objMetric := strings.Split(*metricNames, ";")[0]
objType := commonv1beta1.ObjectiveType(*objectiveType)

// metricStartStep is the dict where key = metric name, value = start step.
// We should apply early stopping rule only if metric is reported at least "start_step" times.
metricStartStep := make(map[string]int)
for _, stopRule := range stopRules {
if stopRule.StartStep != 0 {
metricStartStep[stopRule.Name] = stopRule.StartStep
}
rules, err := filemc.NewRuleSet(objMetric, objType, stopRules)
if err != nil {
klog.Fatalf("NewRuleSet failed: %v", err)
}

// For objective metric we calculate best optimal value from the recorded metrics.
// This is workaround for Median Stop algorithm.
// TODO (andreyvelich): Think about it, maybe define latest, max or min strategy type in stop-rule as well ?
var optimalObjValue *float64

// Check that metric file exists.
checkMetricFile(mFile)

Expand All @@ -171,6 +164,11 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string, f
klog.Fatalf("Failed to create new Process from pid %v, error: %v", mainProcPid, err)
}

// Get list of regular expressions from filters.
metricRegList := filemc.GetFilterRegexpList(filters)

liveRuleMetrics := rules.LiveMetrics()

// Start watch log lines.
t, _ := tail.TailFile(mFile, tail.Config{Follow: true})
for line := range t.Lines {
Expand All @@ -180,14 +178,10 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string, f

switch fileFormat {
case commonv1beta1.TextFormat:
// Get list of regular expressions from filters.
var metricRegList []*regexp.Regexp
metricRegList = filemc.GetFilterRegexpList(filters)

// Check if log line contains metric from stop rules.
isRuleLine := false
for _, rule := range stopRules {
if strings.Contains(logText, rule.Name) {
for _, name := range liveRuleMetrics {
if strings.Contains(logText, name) {
isRuleLine = true
break
}
Expand All @@ -211,53 +205,46 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string, f
klog.Fatalf("Unable to parse value %v to float for metric %v", metricValue, metricName)
}

// stopRules contains array of EarlyStoppingRules that has not been reached yet.
// After rule is reached we delete appropriate element from the array.
for idx, rule := range stopRules {
if metricName != rule.Name {
// liveRuleMetrics contains the names of the EarlyStoppingRules that have not been reached yet.
for _, name := range liveRuleMetrics {
if metricName != name {
continue
}
stopRules, optimalObjValue = updateStopRules(stopRules, optimalObjValue, metricValue, metricStartStep, rule, idx)
err = rules.UpdateMetric(name, metricValue)
if err != nil {
klog.Fatalf("Unable to UpdateMetric %s %v", name, err)
}
}
}
}
case commonv1beta1.JsonFormat:
var logJsonObj map[string]interface{}
var logJsonObj map[string]any
if err = json.Unmarshal([]byte(logText), &logJsonObj); err != nil {
klog.Fatalf("Failed to unmarshal logs in %v format, log: %s, error: %v", commonv1beta1.JsonFormat, logText, err)
}
// Check if log line contains metric from stop rules.
isRuleLine := false
for _, rule := range stopRules {
if _, exist := logJsonObj[rule.Name]; exist {
isRuleLine = true
break
}
}
// If log line doesn't contain appropriate metric, continue track file.
if !isRuleLine {
continue
}

// stopRules contains array of EarlyStoppingRules that has not been reached yet.
// After rule is reached we delete appropriate element from the array.
for idx, rule := range stopRules {
value, exist := logJsonObj[rule.Name].(string)
// liveRuleMetrics contains the names of the EarlyStoppingRules that have not been reached yet.
for _, name := range liveRuleMetrics {
value, exist := logJsonObj[name].(string)
if !exist {
continue
}
metricValue, err := strconv.ParseFloat(strings.TrimSpace(value), 64)
if err != nil {
klog.Fatalf("Unable to parse value %v to float for metric %v", metricValue, rule.Name)
klog.Fatalf("Unable to parse value %v to float for metric %v", metricValue, name)
}
err = rules.UpdateMetric(name, metricValue)
if err != nil {
klog.Fatalf("Unable to UpdateMetric %s %v", name, err)
}
stopRules, optimalObjValue = updateStopRules(stopRules, optimalObjValue, metricValue, metricStartStep, rule, idx)
}
default:
klog.Fatalf("Format must be set to %v or %v", commonv1beta1.TextFormat, commonv1beta1.JsonFormat)
}

// If stopRules array is empty, Trial is early stopped.
if len(stopRules) == 0 {
liveRuleMetrics = rules.LiveMetrics()
// If liveRuleMetrics array is empty, Trial is early stopped.
if len(liveRuleMetrics) == 0 {
klog.Info("Training container is early stopped")
isEarlyStopped = true

Expand Down Expand Up @@ -329,67 +316,6 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string, f
}
}

// updateStopRules applies one reported metric value to a single early
// stopping rule. It returns the rule list (with the rule removed if it has
// been reached) and the best objective value recorded so far.
//
// Parameters:
//   - stopRules: the rules that have not been reached yet.
//   - optimalObjValue: best objective value so far, nil before the first report.
//   - metricValue: the value just reported for rule.Name.
//   - metricStartStep: remaining number of reports, keyed by metric name,
//     before the rule may be applied; it is decremented in place.
//   - rule, ruleIdx: the rule being evaluated and its index in stopRules.
//
// The process is terminated through klog.Fatalf when rule.Value cannot be
// parsed as a float.
func updateStopRules(
	stopRules []commonv1beta1.EarlyStoppingRule,
	optimalObjValue *float64,
	metricValue float64,
	metricStartStep map[string]int,
	rule commonv1beta1.EarlyStoppingRule,
	ruleIdx int,
) ([]commonv1beta1.EarlyStoppingRule, *float64) {

	// First metric is objective in metricNames array.
	objMetric := strings.Split(*metricNames, ";")[0]
	objType := commonv1beta1.ObjectiveType(*objectiveType)

	// For the objective metric the rule is compared against the best value
	// recorded so far rather than the latest one.
	if rule.Name == objMetric {
		if optimalObjValue == nil ||
			(objType == commonv1beta1.ObjectiveTypeMaximize && metricValue > *optimalObjValue) ||
			(objType == commonv1beta1.ObjectiveTypeMinimize && metricValue < *optimalObjValue) {
			optimalObjValue = &metricValue
		}
		// Assign best optimal value to metric value.
		metricValue = *optimalObjValue
	}

	// Count down the remaining start steps while the metric is reported.
	// The rule is applied on the report that exhausts the counter and on every
	// later report. The counter is never pushed below zero; previously it went
	// negative and the "!= 0" check then disabled the rule forever.
	if remaining, ok := metricStartStep[rule.Name]; ok && remaining > 0 {
		remaining--
		metricStartStep[rule.Name] = remaining
		if remaining > 0 {
			return stopRules, optimalObjValue
		}
	}

	ruleValue, err := strconv.ParseFloat(rule.Value, 64)
	if err != nil {
		klog.Fatalf("Unable to parse value %v to float for rule metric %v", rule.Value, rule.Name)
	}

	// Metric value can be equal, less or greater than stop rule.
	reached := false
	switch rule.Comparison {
	case commonv1beta1.ComparisonTypeEqual:
		reached = metricValue == ruleValue
	case commonv1beta1.ComparisonTypeLess:
		reached = metricValue < ruleValue
	case commonv1beta1.ComparisonTypeGreater:
		reached = metricValue > ruleValue
	}
	if reached {
		// Deleting suitable stop rule from the array.
		return deleteStopRule(stopRules, ruleIdx), optimalObjValue
	}
	return stopRules, optimalObjValue
}

// deleteStopRule removes the rule at idx from stopRules and returns the
// shortened slice. The last element is moved into position idx, so the order
// of the remaining rules is not preserved and the input slice is modified.
//
// The process is terminated through klog.Fatalf when idx is negative or not
// smaller than len(stopRules).
func deleteStopRule(stopRules []commonv1beta1.EarlyStoppingRule, idx int) []commonv1beta1.EarlyStoppingRule {
	last := len(stopRules) - 1
	// Reject negative indexes too, so the failure is reported through the log
	// instead of surfacing as a runtime index panic.
	if idx < 0 || idx > last {
		klog.Fatalf("Index %v out of range stopRules: %v", idx, stopRules)
	}
	stopRules[idx] = stopRules[last]
	// Zero the vacated slot before truncating the slice.
	stopRules[last] = commonv1beta1.EarlyStoppingRule{}
	return stopRules[:last]
}

func main() {
flag.Var(&stopRules, "stop-rule", "The list of early stopping stop rules")
flag.Parse()
Expand Down
147 changes: 147 additions & 0 deletions pkg/metricscollector/v1beta1/file-metricscollector/rules.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package sidecarmetricscollector

import (
"fmt"
"math"
"strconv"

commonv1beta1 "github.com/kubeflow/katib/pkg/apis/controller/common/v1beta1"
)

// RuleSet keeps a list of early stopping rules together with per-rule state.
// spec and status are parallel slices: status[i] belongs to spec[i].
type RuleSet struct {
// spec holds the early stopping rules in the order they were supplied.
spec []commonv1beta1.EarlyStoppingRule
// status holds, for each rule, the pruner that evaluates it and whether the
// rule has already been reached.
status []struct {
pruner earlyStoppingPruner
reach bool
}
}

// NewRuleSet builds a RuleSet for the given early stopping rules.
//
// Every rule gets a pruner from defaultFactory. A rule whose Name equals
// objMetric additionally has its pruner wrapped in an objPruner configured
// with objType, so that it is evaluated against the best objective value
// instead of the latest one. The spec slice is stored as is, not copied.
//
// An error is returned, together with a nil RuleSet, as soon as a pruner
// cannot be created for one of the rules.
func NewRuleSet(
	objMetric string,
	objType commonv1beta1.ObjectiveType,
	spec []commonv1beta1.EarlyStoppingRule,
) (*RuleSet, error) {
	// One status entry per rule, index-aligned with spec.
	status := make([]struct {
		pruner earlyStoppingPruner
		reach  bool
	}, len(spec))

	for idx := range spec {
		base, err := defaultFactory(spec[idx])
		if err != nil {
			return nil, err
		}
		status[idx].pruner = base
		if spec[idx].Name != objMetric {
			continue
		}
		// NaN marks "no objective value recorded yet".
		status[idx].pruner = &objPruner{
			objType:         objType,
			optimalObjValue: math.NaN(),
			sub:             base,
		}
	}

	return &RuleSet{spec: spec, status: status}, nil
}

// LiveMetrics returns the metric names of the rules that have not been
// reached yet, in rule order. An empty result means every rule is reached.
//
// Each name is reported once even when several live rules share a metric.
// UpdateMetric already visits every rule carrying the given name, so a caller
// that calls it once per returned name would otherwise feed the same value to
// those rules several times.
func (s *RuleSet) LiveMetrics() []string {
	names := make([]string, 0, len(s.spec))
	for i := range s.spec {
		if s.status[i].reach {
			continue
		}
		name := s.spec[i].Name
		// A linear scan is used for de-duplication to avoid allocating a
		// set on every call.
		dup := false
		for _, seen := range names {
			if seen == name {
				dup = true
				break
			}
		}
		if !dup {
			names = append(names, name)
		}
	}
	return names
}

// UpdateMetric feeds metricValue to every rule named name that has not been
// reached yet and marks a rule as reached when its pruner says so.
//
// The first pruner error is returned immediately; rules after the failing one
// are not updated in that call.
func (s *RuleSet) UpdateMetric(name string, metricValue float64) error {
	for i := range s.spec {
		if s.status[i].reach || s.spec[i].Name != name {
			continue
		}

		reached, err := s.status[i].pruner.Pruner(metricValue)
		if err != nil {
			return err
		}
		// reach is known to be false here, so storing the result directly is
		// the same as setting it only on success.
		s.status[i].reach = reached
	}
	return nil
}

// earlyStoppingPruner evaluates one early stopping rule against the values
// reported for its metric.
type earlyStoppingPruner interface {
// Pruner is called with a reported metric value and returns true when the
// rule is reached; RuleSet.UpdateMetric marks the rule as reached on true.
Pruner(metricValue float64) (bool, error)
}

// defaultFactory creates the basicPruner for an early stopping rule.
//
// It returns an error when the rule's comparison is not one of greater, less
// or equal, or when the rule's Value cannot be parsed as a float64.
func defaultFactory(rule commonv1beta1.EarlyStoppingRule) (earlyStoppingPruner, error) {
	// Reject unsupported comparisons before looking at the value.
	switch rule.Comparison {
	case commonv1beta1.ComparisonTypeEqual,
		commonv1beta1.ComparisonTypeLess,
		commonv1beta1.ComparisonTypeGreater:
	default:
		return nil, fmt.Errorf("unknown rule comparison: %s", rule.Comparison)
	}

	target, err := strconv.ParseFloat(rule.Value, 64)
	if err != nil {
		return nil, fmt.Errorf("unable to parse value to float for rule metric %s: %w", rule.Name, err)
	}

	pruner := &basicPruner{
		target:    target,
		startStep: rule.StartStep,
		cmp:       rule.Comparison,
	}
	return pruner, nil
}

// basicPruner compares each reported metric value with a fixed target.
type basicPruner struct {
// target is the parsed rule value the metric is compared against.
target float64
// step counts how many times Pruner has been called.
step int
// startStep delays the comparison: when it is positive, Pruner returns false
// while step is still smaller than startStep.
startStep int
// cmp selects the comparison (less, greater or equal) applied to the metric.
cmp commonv1beta1.ComparisonType
}

// Pruner counts one more report of the metric and tells whether the rule is
// reached for metricValue.
//
// While startStep is positive and fewer than startStep reports have been
// counted, it returns false without comparing. Afterwards the result is the
// outcome of comparing metricValue with target using cmp. An error is
// returned for a comparison type other than less, greater or equal.
func (p *basicPruner) Pruner(metricValue float64) (bool, error) {
	p.step++
	if warmUp := p.startStep > 0 && p.step < p.startStep; warmUp {
		return false, nil
	}

	var reached bool
	switch p.cmp {
	case commonv1beta1.ComparisonTypeEqual:
		reached = metricValue == p.target
	case commonv1beta1.ComparisonTypeGreater:
		reached = metricValue > p.target
	case commonv1beta1.ComparisonTypeLess:
		reached = metricValue < p.target
	default:
		return false, fmt.Errorf("unknown rule comparison: %s", p.cmp)
	}
	return reached, nil
}

// objPruner wraps another pruner and passes it the best objective value seen
// so far instead of the latest reported value.
type objPruner struct {
// objType decides whether a larger (maximize) or a smaller (minimize) value
// replaces the current best.
objType commonv1beta1.ObjectiveType
// optimalObjValue is the best value seen so far; NaN means that no value has
// been recorded yet.
optimalObjValue float64
// sub is the wrapped pruner that receives the best value.
sub earlyStoppingPruner
}

// Pruner updates the best objective value with metricValue and evaluates the
// wrapped pruner against that best value rather than against metricValue.
//
// For objective metric we calculate best optimal value from the recorded metrics.
// This is workaround for Median Stop algorithm.
// TODO (andreyvelich): Think about it, maybe define latest, max or min strategy type in stop-rule as well ?
func (p *objPruner) Pruner(metricValue float64) (bool, error) {
	best := p.optimalObjValue
	switch {
	case math.IsNaN(best):
		// Nothing recorded yet: the first value is the best one.
		best = metricValue
	case p.objType == commonv1beta1.ObjectiveTypeMaximize && metricValue > best:
		best = metricValue
	case p.objType == commonv1beta1.ObjectiveTypeMinimize && metricValue < best:
		best = metricValue
	}
	p.optimalObjValue = best

	// The wrapped pruner sees the best value, not the latest one.
	return p.sub.Pruner(best)
}
Loading
Loading