-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
62eb515
commit a08ebd3
Showing
5 changed files
with
259 additions
and
0 deletions.
There are no files selected for viewing
36 changes: 36 additions & 0 deletions
36
src/main/scala/elevate/heuristic_search/heuristics/EvolutionaryAlgorithm.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package elevate.heuristic_search.heuristics | ||
|
||
import elevate.heuristic_search | ||
import elevate.heuristic_search.util.Solution | ||
import elevate.heuristic_search.{Heuristic, HeuristicPanel} | ||
|
||
class EvolutionaryAlgorithm[P] extends Heuristic[P] { | ||
|
||
|
||
override def start(panel: HeuristicPanel[P], solution: Solution[P], depth: Int, samples: Int): heuristic_search.ExplorationResult[P] = { | ||
|
||
// Step One: Generate the initial population of individuals randomly. (First generation) | ||
// Step Two: Repeat the following regenerational steps until termination: | ||
// Evaluate the fitness of each individual in the population (time limit, sufficient fitness achieved, etc.) | ||
// Select the fittest individuals for reproduction. (Parents) | ||
// Breed new individuals through crossover and mutation operations to give birth to offspring. | ||
// Replace the least-fit individuals of the population with new individuals. | ||
|
||
|
||
// similar to local search? | ||
// get mutations | ||
|
||
// get Neighborhood of starting element as initial population | ||
|
||
// Evaluate Neighborhood | ||
|
||
// select fittest individuals | ||
|
||
// get neighborhoods for each neighbor | ||
|
||
// repeat | ||
|
||
|
||
null | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
91 changes: 91 additions & 0 deletions
91
src/main/scala/elevate/heuristic_search/heuristics/LocalSearch2.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
package elevate.heuristic_search.heuristics | ||
|
||
import elevate.core.Strategy | ||
import elevate.heuristic_search._ | ||
import elevate.heuristic_search.util.Solution | ||
|
||
import scala.util.Random | ||
|
||
class LocalSearch2[P] extends Heuristic[P] { | ||
|
||
// check: do not terminate here if budget is left over | ||
// check: start from beginning, but do not make similar rewrites | ||
// check: maybe save history to access performance instead of execution? | ||
// todo: test | ||
|
||
import scala.collection.mutable | ||
|
||
def start(panel: HeuristicPanel[P], initialSolution: Solution[P], depth: Int, samples: Int): ExplorationResult[P] = { | ||
var solution: Solution[P] = initialSolution | ||
val initialSolutionValue: Option[Double] = panel.f(solution) | ||
var solutionValue: Option[Double] = initialSolutionValue | ||
var counter = 1 // Start with 1 to account for the initial solution evaluation | ||
// cache that identifies a rewrite by its rewrite sequence (Strategy and Location) | ||
val cache = mutable.Map[Seq[(Strategy[P], Int)], Option[Double]](solution.rewrite_sequence() -> initialSolutionValue) | ||
val tabu = mutable.Set.empty[Seq[(Strategy[P], Int)]] | ||
|
||
do { | ||
// Get neighborhood and ensure we don't exceed the sample limit | ||
val Ns = Random.shuffle(panel.N(solution).filterNot(elem => tabu.contains(elem.rewrite_sequence()))).take(samples - counter) | ||
|
||
// if no neighbor is left add this to tabu list and repeat from initial solution | ||
if (Ns.isEmpty || Ns.forall(ns => tabu.contains(ns.rewrite_sequence()))) { | ||
|
||
// add solution to tabu list | ||
tabu.add(solution.rewrite_sequence()) | ||
|
||
// start from beginning | ||
solution = initialSolution | ||
solutionValue = initialSolutionValue | ||
} else { | ||
|
||
|
||
// Evaluate the neighborhood using the cache to avoid re-evaluations | ||
val betterNeighbor = Ns | ||
.flatMap { ns => | ||
if (counter < samples) { | ||
cache.getOrElseUpdate(ns.rewrite_sequence(), { | ||
// If ns is not in cache, evaluate and cache it | ||
val value = panel.f(ns) | ||
counter += 1 | ||
value | ||
}).map(fns => (ns, fns)) | ||
} else { | ||
// If the sample limit is reached, skip further evaluations | ||
None | ||
} | ||
} | ||
.minByOption(_._2) | ||
|
||
// Check if a better solution was found | ||
betterNeighbor match { | ||
// If so, choose this as the new solution | ||
case Some((bestNeighbor, bestValue)) if bestValue < solutionValue.getOrElse(Double.MaxValue) => | ||
solution = bestNeighbor | ||
solutionValue = Some(bestValue) | ||
case _ => | ||
// If not, restart from the initial solution if budget is left | ||
|
||
// do not sample terminal expression again | ||
tabu.add(solution.rewrite_sequence()) | ||
|
||
if (counter < samples) { | ||
solution = initialSolution | ||
solutionValue = initialSolutionValue | ||
} else { | ||
// Otherwise, terminate | ||
return ExplorationResult(solution, solutionValue, None) | ||
} | ||
} | ||
} | ||
|
||
} while (counter < samples) | ||
|
||
ExplorationResult( | ||
solution, | ||
solutionValue, | ||
None | ||
) | ||
} | ||
|
||
} |
123 changes: 123 additions & 0 deletions
123
src/main/scala/elevate/heuristic_search/heuristics/MCTS2.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
//package elevate.heuristic_search.heuristics | ||
// | ||
//trait GameState { | ||
// def getLegalActions: List[GameAction] | ||
// | ||
// def move(action: GameAction): GameState | ||
// | ||
// def isTerminal: Boolean | ||
// | ||
// def getReward: Double | ||
//} | ||
// | ||
//case class GameAction(index: Int) | ||
// | ||
//import scala.util.Random | ||
// | ||
//case class Node(state: GameState, parent: Option[Node], action: Option[GameAction]) { | ||
// var wins: Double = 0 | ||
// var visits: Int = 0 | ||
// val children: collection.mutable.ListBuffer[Node] = collection.mutable.ListBuffer() | ||
// | ||
// def uctValue(explorationParameter: Double = Math.sqrt(2)): Double = { | ||
// if (visits == 0) Double.MaxValue | ||
// else wins.toDouble / visits + explorationParameter * Math.sqrt(Math.log(parent.map(_.visits.toDouble).getOrElse(1.0)) / visits) | ||
// } | ||
// | ||
// def addChild(childState: GameState, action: GameAction): Node = { | ||
// val childNode = Node(childState, Some(this), Some(action)) | ||
// children += childNode | ||
// childNode | ||
// } | ||
// | ||
// def isFullyExpanded: Boolean = children.size == state.getLegalActions.size | ||
// | ||
// def bestChild: Node = children.maxBy((child: Node) => child.wins.toDouble / child.visits) | ||
// | ||
// | ||
// def selectPromisingNode: Node = if (isFullyExpanded) children.maxBy(_.uctValue()) else this | ||
//} | ||
// | ||
//class MCTS(iterationLimit: Int) { | ||
// def search(initialState: GameState): GameState = { | ||
// val rootNode = Node(initialState, None, None) | ||
// | ||
// for (_ <- 1 to iterationLimit) { | ||
// val promisingNode = selectPromisingNode(rootNode) | ||
// val expandedNode = expandNode(promisingNode) | ||
// val simulationResult = simulateRandomPlayout(expandedNode) | ||
// backPropagate(expandedNode, simulationResult) | ||
// } | ||
// | ||
// printTree(rootNode, 0) | ||
// | ||
// rootNode.bestChild.state | ||
// } | ||
// | ||
// private def selectPromisingNode(node: Node): Node = { | ||
// var currentNode = node | ||
// while (!currentNode.state.isTerminal && currentNode.isFullyExpanded) { | ||
// currentNode = currentNode.children.maxBy(_.uctValue()) | ||
// } | ||
// currentNode | ||
// } | ||
// | ||
// private def expandNode(node: Node): Node = { | ||
// val untriedActions = node.state.getLegalActions.filterNot(action => node.children.exists(_.action.contains(action))) | ||
// if (untriedActions.nonEmpty) { | ||
// val action = untriedActions.head | ||
// val newState = node.state.move(action) | ||
// node.addChild(newState, action) | ||
// } else { | ||
// node | ||
// } | ||
// } | ||
// | ||
// private def simulateRandomPlayout(node: Node): Double = { | ||
// var tempState = node.state | ||
// while (!tempState.isTerminal) { | ||
// val legalActions = tempState.getLegalActions | ||
// val randomAction = legalActions(Random.nextInt(legalActions.size)) | ||
// tempState = tempState.move(randomAction) | ||
// } | ||
// tempState.getReward | ||
// } | ||
// | ||
// private def backPropagate(node: Node, reward: Double): Unit = { | ||
// var tempNode: Option[Node] = Some(node) | ||
// while (tempNode.isDefined) { | ||
// tempNode.get.visits += 1 | ||
// tempNode.get.wins += reward | ||
// tempNode = tempNode.get.parent | ||
// } | ||
// } | ||
// | ||
// def printTree(node: Node, depth: Int = 0): Unit = { | ||
// // Print the current node with indentation based on its depth in the tree | ||
// println(" " * depth * 2 + node.state + " " + node.wins + " " + node.visits) | ||
// | ||
// // Recursively print all the children | ||
// for (child <- node.children) { | ||
// printTree(child, depth + 1) | ||
// } | ||
// } | ||
//} | ||
// | ||
// | ||
//case class DummyGameState(currentPlayer: Int, movesLeft: Int) extends GameState { | ||
// override def getLegalActions: List[GameAction] = (1 to movesLeft).map(GameAction).toList | ||
// | ||
// override def move(action: GameAction): GameState = DummyGameState(-currentPlayer, movesLeft - 1) | ||
// | ||
// override def isTerminal: Boolean = movesLeft == 0 | ||
// | ||
// override def getReward: Double = if (movesLeft % 2 == 0) currentPlayer else -currentPlayer | ||
//} | ||
// | ||
//object Main extends App { | ||
// val initialState = DummyGameState(1, 3) | ||
// val mcts = new MCTS(iterationLimit = 10) | ||
// val finalState = mcts.search(initialState) | ||
// | ||
// println(s"Best move leads to state: $finalState") | ||
//} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters