Skip to content

Commit

Permalink
tabu/local search wip
Browse files Browse the repository at this point in the history
  • Loading branch information
johanneslenfers committed Oct 30, 2024
1 parent 62eb515 commit a08ebd3
Show file tree
Hide file tree
Showing 5 changed files with 259 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package elevate.heuristic_search.heuristics

import elevate.heuristic_search
import elevate.heuristic_search.util.Solution
import elevate.heuristic_search.{Heuristic, HeuristicPanel}

class EvolutionaryAlgorithm[P] extends Heuristic[P] {


override def start(panel: HeuristicPanel[P], solution: Solution[P], depth: Int, samples: Int): heuristic_search.ExplorationResult[P] = {

// Step One: Generate the initial population of individuals randomly. (First generation)
// Step Two: Repeat the following regenerational steps until termination:
// Evaluate the fitness of each individual in the population (time limit, sufficient fitness achieved, etc.)
// Select the fittest individuals for reproduction. (Parents)
// Breed new individuals through crossover and mutation operations to give birth to offspring.
// Replace the least-fit individuals of the population with new individuals.


// similar to local search?
// get mutations

// get Neighborhood of starting element as initial population

// Evaluate Neighborhood

// select fittest individuals

// get neighborhoods for each neighbor

// repeat


null
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ import elevate.heuristic_search.util.Solution

class LocalSearch[P] extends Heuristic[P] {

// do not terminate here if budget is left over
// start from beginning, but do not make similar rewrites
// maybe save history to access performance instead of execution?


def start(panel: HeuristicPanel[P], initialSolution: Solution[P], depth: Int, samples: Int): ExplorationResult[P] = {
var solution: Solution[P] = initialSolution
var solutionValue: Option[Double] = panel.f(solution)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package elevate.heuristic_search.heuristics

import elevate.core.Strategy
import elevate.heuristic_search._
import elevate.heuristic_search.util.Solution

import scala.util.Random

class LocalSearch2[P] extends Heuristic[P] {

// check: do not terminate here if budget is left over
// check: start from beginning, but do not make similar rewrites
// check: maybe save history to access performance instead of execution?
// todo: test

import scala.collection.mutable

def start(panel: HeuristicPanel[P], initialSolution: Solution[P], depth: Int, samples: Int): ExplorationResult[P] = {
var solution: Solution[P] = initialSolution
val initialSolutionValue: Option[Double] = panel.f(solution)
var solutionValue: Option[Double] = initialSolutionValue
var counter = 1 // Start with 1 to account for the initial solution evaluation
// cache that identifies a rewrite by its rewrite sequence (Strategy and Location)
val cache = mutable.Map[Seq[(Strategy[P], Int)], Option[Double]](solution.rewrite_sequence() -> initialSolutionValue)
val tabu = mutable.Set.empty[Seq[(Strategy[P], Int)]]

do {
// Get neighborhood and ensure we don't exceed the sample limit
val Ns = Random.shuffle(panel.N(solution).filterNot(elem => tabu.contains(elem.rewrite_sequence()))).take(samples - counter)

// if no neighbor is left add this to tabu list and repeat from initial solution
if (Ns.isEmpty || Ns.forall(ns => tabu.contains(ns.rewrite_sequence()))) {

// add solution to tabu list
tabu.add(solution.rewrite_sequence())

// start from beginning
solution = initialSolution
solutionValue = initialSolutionValue
} else {


// Evaluate the neighborhood using the cache to avoid re-evaluations
val betterNeighbor = Ns
.flatMap { ns =>
if (counter < samples) {
cache.getOrElseUpdate(ns.rewrite_sequence(), {
// If ns is not in cache, evaluate and cache it
val value = panel.f(ns)
counter += 1
value
}).map(fns => (ns, fns))
} else {
// If the sample limit is reached, skip further evaluations
None
}
}
.minByOption(_._2)

// Check if a better solution was found
betterNeighbor match {
// If so, choose this as the new solution
case Some((bestNeighbor, bestValue)) if bestValue < solutionValue.getOrElse(Double.MaxValue) =>
solution = bestNeighbor
solutionValue = Some(bestValue)
case _ =>
// If not, restart from the initial solution if budget is left

// do not sample terminal expression again
tabu.add(solution.rewrite_sequence())

if (counter < samples) {
solution = initialSolution
solutionValue = initialSolutionValue
} else {
// Otherwise, terminate
return ExplorationResult(solution, solutionValue, None)
}
}
}

} while (counter < samples)

ExplorationResult(
solution,
solutionValue,
None
)
}

}
123 changes: 123 additions & 0 deletions src/main/scala/elevate/heuristic_search/heuristics/MCTS2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
//package elevate.heuristic_search.heuristics
//
//trait GameState {
// def getLegalActions: List[GameAction]
//
// def move(action: GameAction): GameState
//
// def isTerminal: Boolean
//
// def getReward: Double
//}
//
//case class GameAction(index: Int)
//
//import scala.util.Random
//
//case class Node(state: GameState, parent: Option[Node], action: Option[GameAction]) {
// var wins: Double = 0
// var visits: Int = 0
// val children: collection.mutable.ListBuffer[Node] = collection.mutable.ListBuffer()
//
// def uctValue(explorationParameter: Double = Math.sqrt(2)): Double = {
// if (visits == 0) Double.MaxValue
// else wins.toDouble / visits + explorationParameter * Math.sqrt(Math.log(parent.map(_.visits.toDouble).getOrElse(1.0)) / visits)
// }
//
// def addChild(childState: GameState, action: GameAction): Node = {
// val childNode = Node(childState, Some(this), Some(action))
// children += childNode
// childNode
// }
//
// def isFullyExpanded: Boolean = children.size == state.getLegalActions.size
//
// def bestChild: Node = children.maxBy((child: Node) => child.wins.toDouble / child.visits)
//
//
// def selectPromisingNode: Node = if (isFullyExpanded) children.maxBy(_.uctValue()) else this
//}
//
//class MCTS(iterationLimit: Int) {
// def search(initialState: GameState): GameState = {
// val rootNode = Node(initialState, None, None)
//
// for (_ <- 1 to iterationLimit) {
// val promisingNode = selectPromisingNode(rootNode)
// val expandedNode = expandNode(promisingNode)
// val simulationResult = simulateRandomPlayout(expandedNode)
// backPropagate(expandedNode, simulationResult)
// }
//
// printTree(rootNode, 0)
//
// rootNode.bestChild.state
// }
//
// private def selectPromisingNode(node: Node): Node = {
// var currentNode = node
// while (!currentNode.state.isTerminal && currentNode.isFullyExpanded) {
// currentNode = currentNode.children.maxBy(_.uctValue())
// }
// currentNode
// }
//
// private def expandNode(node: Node): Node = {
// val untriedActions = node.state.getLegalActions.filterNot(action => node.children.exists(_.action.contains(action)))
// if (untriedActions.nonEmpty) {
// val action = untriedActions.head
// val newState = node.state.move(action)
// node.addChild(newState, action)
// } else {
// node
// }
// }
//
// private def simulateRandomPlayout(node: Node): Double = {
// var tempState = node.state
// while (!tempState.isTerminal) {
// val legalActions = tempState.getLegalActions
// val randomAction = legalActions(Random.nextInt(legalActions.size))
// tempState = tempState.move(randomAction)
// }
// tempState.getReward
// }
//
// private def backPropagate(node: Node, reward: Double): Unit = {
// var tempNode: Option[Node] = Some(node)
// while (tempNode.isDefined) {
// tempNode.get.visits += 1
// tempNode.get.wins += reward
// tempNode = tempNode.get.parent
// }
// }
//
// def printTree(node: Node, depth: Int = 0): Unit = {
// // Print the current node with indentation based on its depth in the tree
// println(" " * depth * 2 + node.state + " " + node.wins + " " + node.visits)
//
// // Recursively print all the children
// for (child <- node.children) {
// printTree(child, depth + 1)
// }
// }
//}
//
//
//case class DummyGameState(currentPlayer: Int, movesLeft: Int) extends GameState {
// override def getLegalActions: List[GameAction] = (1 to movesLeft).map(GameAction).toList
//
// override def move(action: GameAction): GameState = DummyGameState(-currentPlayer, movesLeft - 1)
//
// override def isTerminal: Boolean = movesLeft == 0
//
// override def getReward: Double = if (movesLeft % 2 == 0) currentPlayer else -currentPlayer
//}
//
//object Main extends App {
// val initialState = DummyGameState(1, 3)
// val mcts = new MCTS(iterationLimit = 10)
// val finalState = mcts.search(initialState)
//
// println(s"Best move leads to state: $finalState")
//}
4 changes: 4 additions & 0 deletions src/main/scala/elevate/heuristic_search/util/Solution.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ case class Solution[P](
)
}

def rewrite_sequence(): Seq[(Strategy[P], Int)] = {
solutionSteps.map(step => (step.strategy, step.location))
}

def parent(): Solution[P] = {
Solution[P](solutionSteps.dropRight(1))
}
Expand Down

0 comments on commit a08ebd3

Please sign in to comment.