Skip to content

Commit

Permalink
fixed PartiallyObservableGameState.getVectorObservation() (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
DennisSoemers authored May 24, 2024
1 parent 2082f65 commit 5dc4403
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 18 deletions.
4 changes: 2 additions & 2 deletions src/rts/GameState.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class GameState {
// 4: unit type
// 5: current unit action
// 6: wall
public static final int numVectorObservationFeatureMaps = 6;
private static final int NUM_VECTOR_OBSERVATION_FEATURE_MAPS = 6;

/**
* Initializes the GameState with a PhysicalGameState and a UnitTypeTable
Expand Down Expand Up @@ -921,7 +921,7 @@ public static GameState fromJSON(String JSON, UnitTypeTable utt) {
*/
public int [][][] getVectorObservation(final int player){
if (vectorObservation == null) {
vectorObservation = new int[2][numVectorObservationFeatureMaps][pgs.height][pgs.width];
vectorObservation = new int[2][NUM_VECTOR_OBSERVATION_FEATURE_MAPS][pgs.height][pgs.width];
}
// hitpointsMatrix is vectorObservation[player][0]
// resourcesMatrix is vectorObservation[player][1]
Expand Down
146 changes: 130 additions & 16 deletions src/rts/PartiallyObservableGameState.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package rts;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

import rts.units.Unit;

/**
Expand All @@ -10,10 +13,19 @@
* @author santi
*/
public class PartiallyObservableGameState extends GameState {
/**
*
*/
protected int player; // the observer player

protected int observer; // the observer player

// Feature maps:
// 1: hit points
// 2: resources
// 3: player
// 4: unit type
// 5: current unit action
// 6: walls
// 7: which cells can I see?
// 8: for which cells do I know that my opponent can see them?
public static final int NUM_VECTOR_OBSERVATION_FEATURE_MAPS_PARTIAL_OBS = 8;

/**
* Creates a partially observable game state, from the point of view of 'player':
Expand All @@ -25,31 +37,32 @@ public PartiallyObservableGameState(GameState gs, int a_player) {
unitCancelationCounter = gs.unitCancelationCounter;
time = gs.time;

player = a_player;
observer = a_player;

unitActions.putAll(gs.unitActions);

List<Unit> toDelete = new LinkedList<>();
for (Unit u : pgs.getUnits()) {
if (u.getPlayer() != player) {
final List<Unit> toDelete = new LinkedList<>();
for (final Unit u : pgs.getUnits()) {
if (u.getPlayer() != observer) {
if (!observable(u.getX(), u.getY())) {
toDelete.add(u);
}
}
}
for (Unit u : toDelete)
for (final Unit u : toDelete)
removeUnit(u);
}

/**
* Returns whether the position is within view of the player
* @see rts.GameState#observable(int, int)
*/
public boolean observable(int x, int y) {
for (Unit u : pgs.getUnits()) {
if (u.getPlayer() == player) {
double d = Math.sqrt((u.getX() - x) * (u.getX() - x) + (u.getY() - y) * (u.getY() - y));
if (d <= u.getType().sightRadius)
@Override
public boolean observable(final int x, final int y) {
for (final Unit u : pgs.getUnits()) {
if (u.getPlayer() == observer) {
final int dSquared = (u.getX() - x) * (u.getX() - x) + (u.getY() - y) * (u.getY() - y);
if (dSquared <= u.getType().sightRadius * u.getType().sightRadius)
return true;
}
}
Expand All @@ -60,7 +73,108 @@ public boolean observable(int x, int y) {
/* (non-Javadoc)
* @see rts.GameState#clone()
*/
public PartiallyObservableGameState clone() {
return new PartiallyObservableGameState(super.clone(), player);
@Override
public PartiallyObservableGameState clone() {
return new PartiallyObservableGameState(super.clone(), observer);
}

@Override
public int [][][] getVectorObservation(final int player){
if (vectorObservation == null) {
vectorObservation = new int[2][NUM_VECTOR_OBSERVATION_FEATURE_MAPS_PARTIAL_OBS][pgs.height][pgs.width];
}

List<int[]> friendlyUnits = new ArrayList<>();
List<int[]> enemyUnits = new ArrayList<>();

// hitpointsMatrix is vectorObservation[player][0]
// resourcesMatrix is vectorObservation[player][1]
// playersMatrix is vectorObservation[player][2]
// unitTypesMatrix is vectorObservation[player][3]
// unitActionMatrix is vectorObservation[player][4]
// wallMatrix is vectorObservation[player][5]
// myVisibilityMatrix is vectorObservation[player][6]
// opponentVisibilityMatrix is vectorObservation[player][7]

for (int i=0; i<vectorObservation[player][0].length; i++) {
Arrays.fill(vectorObservation[player][0][i], 0);
Arrays.fill(vectorObservation[player][1][i], 0);
Arrays.fill(vectorObservation[player][2][i], 0);
Arrays.fill(vectorObservation[player][3][i], 0);
Arrays.fill(vectorObservation[player][4][i], 0);
Arrays.fill(vectorObservation[player][5][i], 0);
Arrays.fill(vectorObservation[player][6][i], 0);
Arrays.fill(vectorObservation[player][7][i], 0);
}

for (int i = 0; i < pgs.units.size(); i++) {
Unit u = pgs.units.get(i);
UnitActionAssignment uaa = unitActions.get(u);

vectorObservation[player][0][u.getY()][u.getX()] = u.getHitPoints();
vectorObservation[player][1][u.getY()][u.getX()] = u.getResources();

final int owner = u.getPlayer();
if (owner >= 0) { // Owned by a player, not neutral
vectorObservation[player][2][u.getY()][u.getX()] = ((u.getPlayer() + player) % 2) + 1;

// Split units based on owner (used for last two layers of the observation)
if (owner == player)
friendlyUnits.add(new int[]{u.getX(), u.getY(), u.getType().sightRadius});
else
enemyUnits.add(new int[]{u.getX(), u.getY(), u.getType().sightRadius});
}

vectorObservation[player][3][u.getY()][u.getX()] = u.getType().ID + 1;

if (uaa != null) {
vectorObservation[player][4][u.getY()][u.getX()] = uaa.action.type;
} else {
// Commented line of code is unnecessary: already initialised to 0
//vectorObservation[player][4][u.getY()][u.getX()] = UnitAction.TYPE_NONE;
}
}

// Encode the presence of walls
final int[] terrain = pgs.terrain;
for (int y = 0; y < pgs.height; ++y) {
System.arraycopy(terrain, y * pgs.width, vectorObservation[player][5][y], 0, pgs.width);
}

// Encode visibility
final int[][] playerVisibility = calculateVisibility(friendlyUnits, pgs.width, pgs.height);
final int[][] opponentVisibility = calculateVisibility(enemyUnits, pgs.width, pgs.height);

for (int y = 0; y < pgs.height; y++) {
System.arraycopy(playerVisibility, y * pgs.width, vectorObservation[player][6][y], 0, pgs.width);
System.arraycopy(opponentVisibility, y * pgs.width, vectorObservation[player][7][y], 0, pgs.width);
}

return vectorObservation[player];
}

private static int[][] calculateVisibility(final List<int[]> units, final int width, final int height) {
final int[][] visibility = new int[height][width];
for (final int[] unit : units) {
final int ux = unit[0];
final int uy = unit[1];
final int sightRadius = unit[2];
final int sightRadiusSquared = sightRadius * sightRadius;

for (int dy = -sightRadius; dy <= sightRadius; dy++) {
for (int dx = -sightRadius; dx <= sightRadius; dx++) {
final int x = ux + dx;
final int y = uy + dy;

if (x >= 0 && x < width && y >= 0 && y < height) {
final int distanceSquared = dx * dx + dy * dy;
if (distanceSquared <= sightRadiusSquared) {
visibility[y][x] = 1;
}
}
}
}
}
return visibility;
}
}

0 comments on commit 5dc4403

Please sign in to comment.