Skip to content

Commit

Permalink
Implemented parallel evaluation of the organisms for double pole bala…
Browse files Browse the repository at this point in the history
…ncing experiment
  • Loading branch information
yaricom committed Jul 16, 2024
1 parent 98a2d1c commit 0907603
Show file tree
Hide file tree
Showing 8 changed files with 452 additions and 285 deletions.
15 changes: 0 additions & 15 deletions examples/pole/common.go

This file was deleted.

109 changes: 109 additions & 0 deletions examples/pole2/cart2pole.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package pole2

import (
"context"
"fmt"
"github.com/yaricom/goNEAT/v4/experiment"
"github.com/yaricom/goNEAT/v4/experiment/utils"
"github.com/yaricom/goNEAT/v4/neat"
"github.com/yaricom/goNEAT/v4/neat/genetics"
)

type cartDoublePoleGenerationEvaluator struct {
// The output path to store execution results
OutputPath string
// The flag to indicate whether to apply Markov evaluation variant
Markov bool

// The flag to indicate whether to use continuous activation or discrete
ActionType ActionType
}

// NewCartDoublePoleGenerationEvaluator is the generations evaluator for double-pole balancing experiment: both Markov and non-Markov versions
func NewCartDoublePoleGenerationEvaluator(outDir string, markov bool, actionType ActionType) experiment.GenerationEvaluator {
return &cartDoublePoleGenerationEvaluator{
OutputPath: outDir,
Markov: markov,
ActionType: actionType,
}
}

// GenerationEvaluate Perform evaluation of one epoch on double pole balancing
func (e *cartDoublePoleGenerationEvaluator) GenerationEvaluate(ctx context.Context, pop *genetics.Population, epoch *experiment.Generation) error {
options, ok := neat.FromContext(ctx)
if !ok {
return neat.ErrNEATOptionsNotFound
}
cartPole := NewCartPole(e.Markov)

cartPole.nonMarkovLong = false
cartPole.generalizationTest = false

// Evaluate each organism on a test
for _, org := range pop.Organisms {
winner, err := OrganismEvaluate(org, cartPole, e.ActionType)
if err != nil {
return err
}

if winner && (epoch.Champion == nil || org.Fitness > epoch.Champion.Fitness) {
// This will be winner in Markov case
epoch.Solved = true
epoch.WinnerNodes = len(org.Genotype.Nodes)
epoch.WinnerGenes = org.Genotype.Extrons()
epoch.WinnerEvals = options.PopSize*epoch.Id + org.Genotype.Id
epoch.Champion = org
org.IsWinner = true
}
}

// Check for winner in Non-Markov case
if !e.Markov {
epoch.Solved = false
// evaluate generalization tests
if champion, err := EvaluateOrganismGeneralization(pop.Species, cartPole, e.ActionType); err != nil {
return err
} else if champion.IsWinner {
epoch.Solved = true
epoch.WinnerNodes = len(champion.Genotype.Nodes)
epoch.WinnerGenes = champion.Genotype.Extrons()
epoch.WinnerEvals = options.PopSize*epoch.Id + champion.Genotype.Id
epoch.Champion = champion
}
}

// Fill statistics about current epoch
epoch.FillPopulationStatistics(pop)

// Only print to file every print_every generation
if epoch.Solved || epoch.Id%options.PrintEvery == 0 {
if _, err := utils.WritePopulationPlain(e.OutputPath, pop, epoch); err != nil {
neat.ErrorLog(fmt.Sprintf("Failed to dump population, reason: %s\n", err))
return err
}
}

if epoch.Solved {
// print winner organism's statistics
org := epoch.Champion
utils.PrintActivationDepth(org, true)

genomeFile := "pole2_winner_genome"
// Prints the winner organism to file!
if orgPath, err := utils.WriteGenomePlain(genomeFile, e.OutputPath, org, epoch); err != nil {
neat.ErrorLog(fmt.Sprintf("Failed to dump winner organism's genome, reason: %s\n", err))
} else {
neat.InfoLog(fmt.Sprintf("Generation #%d winner's genome dumped to: %s\n", epoch.Id, orgPath))
}

// Prints the winner organism's phenotype to the Cytoscape JSON file!
if orgPath, err := utils.WriteGenomeCytoscapeJSON(genomeFile, e.OutputPath, org, epoch); err != nil {
neat.ErrorLog(fmt.Sprintf("Failed to dump winner organism's phenome Cytoscape JSON graph, reason: %s\n", err))
} else {
neat.InfoLog(fmt.Sprintf("Generation #%d winner's phenome Cytoscape JSON graph dumped to: %s\n",
epoch.Id, orgPath))
}
}

return nil
}
138 changes: 138 additions & 0 deletions examples/pole2/cart2pole_parallel.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package pole2

import (
"context"
"fmt"
"github.com/yaricom/goNEAT/v4/experiment"
"github.com/yaricom/goNEAT/v4/experiment/utils"
"github.com/yaricom/goNEAT/v4/neat"
"github.com/yaricom/goNEAT/v4/neat/genetics"
"sync"
)

type cartDoublePoleParallelGenerationEvaluator struct {
cartDoublePoleGenerationEvaluator
}

type parallelEvaluationResult struct {
genomeId int
fitness float64
error float64
winner bool
err error
}

// NewCartDoublePoleParallelGenerationEvaluator is the generations evaluator for double-pole balancing experiment: both Markov and non-Markov versions
func NewCartDoublePoleParallelGenerationEvaluator(outDir string, markov bool, actionType ActionType) experiment.GenerationEvaluator {
return &cartDoublePoleParallelGenerationEvaluator{
cartDoublePoleGenerationEvaluator{
OutputPath: outDir,
Markov: markov,
ActionType: actionType,
},
}
}

func (e *cartDoublePoleParallelGenerationEvaluator) GenerationEvaluate(ctx context.Context, pop *genetics.Population, epoch *experiment.Generation) error {
options, ok := neat.FromContext(ctx)
if !ok {
return neat.ErrNEATOptionsNotFound
}

organismMapping := make(map[int]*genetics.Organism)

popSize := len(pop.Organisms)
resChan := make(chan parallelEvaluationResult, popSize)
// The wait group to wait for all GO routines
var wg sync.WaitGroup

// Evaluate each organism in generation
for _, org := range pop.Organisms {
if _, ok = organismMapping[org.Genotype.Id]; ok {
return fmt.Errorf("organism with %d already exists in mapping", org.Genotype.Id)
}
organismMapping[org.Genotype.Id] = org
wg.Add(1)

// run in separate GO thread
go func(organism *genetics.Organism, actionType ActionType, resChan chan<- parallelEvaluationResult, wg *sync.WaitGroup) {
defer wg.Done()

// create simulator and evaluate
cartPole := NewCartPole(e.Markov)
cartPole.nonMarkovLong = false
cartPole.generalizationTest = false

winner, err := OrganismEvaluate(organism, cartPole, actionType)
if err != nil {
resChan <- parallelEvaluationResult{err: err}
return
}

// create result
result := parallelEvaluationResult{
genomeId: organism.Genotype.Id,
fitness: organism.Fitness,
error: organism.Error,
winner: winner,
}
resChan <- result

}(org, e.ActionType, resChan, &wg)
}

// wait for evaluation results
wg.Wait()
close(resChan)

for result := range resChan {
if result.err != nil {
return result.err
}
// find and update original organism
org, ok := organismMapping[result.genomeId]
if ok {
org.Fitness = result.fitness
org.Error = result.error
} else {
return fmt.Errorf("organism not found in mapping for id: %d", result.genomeId)
}

if result.winner && (epoch.Champion == nil || org.Fitness > epoch.Champion.Fitness) {
// This will be winner in Markov case
epoch.Solved = true
epoch.WinnerNodes = len(org.Genotype.Nodes)
epoch.WinnerGenes = org.Genotype.Extrons()
epoch.WinnerEvals = options.PopSize*epoch.Id + org.Genotype.Id
epoch.Champion = org
org.IsWinner = true
}
}

// Fill statistics about current epoch
epoch.FillPopulationStatistics(pop)

if epoch.Solved {
// print winner organism's statistics
org := epoch.Champion
utils.PrintActivationDepth(org, true)

genomeFile := "pole2_parallel_winner_genome"
// Prints the winner organism to file!
if orgPath, err := utils.WriteGenomePlain(genomeFile, e.OutputPath, org, epoch); err != nil {
neat.ErrorLog(fmt.Sprintf("Failed to dump winner organism's genome, reason: %s\n", err))
} else {
neat.InfoLog(fmt.Sprintf("Generation #%d winner's genome dumped to: %s\n", epoch.Id, orgPath))
}

// Prints the winner organism's phenotype to the Cytoscape JSON file!
if orgPath, err := utils.WriteGenomeCytoscapeJSON(genomeFile, e.OutputPath, org, epoch); err != nil {
neat.ErrorLog(fmt.Sprintf("Failed to dump winner organism's phenome Cytoscape JSON graph, reason: %s\n", err))
} else {
neat.InfoLog(fmt.Sprintf("Generation #%d winner's phenome Cytoscape JSON graph dumped to: %s\n",
epoch.Id, orgPath))
}
}

return nil
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package pole
package pole2

import (
"fmt"
Expand Down
Loading

0 comments on commit 0907603

Please sign in to comment.