
Cache grid pos #4

Open
Chan-Dong-Jun opened this issue Aug 19, 2024 · 3 comments

@Chan-Dong-Jun
Owner

What's the problem this feature will solve?
There should be a method to cache the positions of agents at each step of the simulation so that we can display the agents on the grid visualization.

Describe the solution you'd like

  1. Cache the positions of the agents at each step.
  2. Modify the visualization module. Currently, it needs the model object to read the grid attribute it displays. It should instead construct the grid directly from the cached file, so that the model object is no longer needed.
Chan-Dong-Jun self-assigned this Aug 19, 2024
@Chan-Dong-Jun
Owner Author

Main edit

Caching methods

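import pandas as pd
from mesa import Agent, Model
from mesa.space import MultiGrid


class CacheableModel:
    # Excerpt: self.model, self.cache_file_path and self._total_steps are set elsewhere in the class.

    def get_grid_dataframe(self, cache_file_path: str | None = None):
        # Use the directory passed in, falling back to the one configured on the instance
        cache_file_path = cache_file_path or self.cache_file_path
        grid_state = {
            'width': self.model.grid.width,
            'height': self.model.grid.height,
            'agents': []
        }
        for x in range(grid_state['width']):
            for y in range(grid_state['height']):
                # _grid is Mesa's internal (private) cell storage
                cell_contents = self.model.grid._grid[x][y]
                if cell_contents:
                    # SingleGrid cells hold one agent, MultiGrid cells hold a list
                    if not hasattr(cell_contents, "__iter__"):
                        cell_contents = [cell_contents]
                    for agent in cell_contents:
                        agent_state = {
                            'pos_x': agent.pos[0],
                            'pos_y': agent.pos[1],
                            'unique_id': agent.unique_id,
                            'wealth': agent.wealth,  # model-specific attribute
                            # **agent.__dict__
                        }
                        grid_state['agents'].append(agent_state)
        # Zero-pad the step number so the cached files sort in step order
        padding = len(str(self._total_steps)) - 1
        filename = f"{cache_file_path}/grid_data_{self.model._steps:0{padding}}.parquet"

        # Convert to DataFrame
        df = pd.DataFrame(grid_state['agents'])

        # Save DataFrame to Parquet
        df.to_parquet(filename)
        return df

    @staticmethod
    def reconstruct_grid(filename, *attributes_list):
        # Load the DataFrame from Parquet
        df = pd.read_parquet(filename)

        # Infer the grid size from the occupied cells, assuming positions start from 0
        # (an empty last row or column is lost)
        width = int(df['pos_x'].max()) + 1
        height = int(df['pos_y'].max()) + 1
        # MultiGrid, because several agents can share a cell
        grid = MultiGrid(width, height, torus=False)

        # One placeholder model is enough to satisfy the Agent constructor
        placeholder_model = Model()
        # Restore the requested attributes, defaulting to wealth
        attributes = attributes_list or ("wealth",)

        # Add agents to the grid
        for _, row in df.iterrows():
            agent = Agent(int(row['unique_id']), placeholder_model)
            for attribute in attributes:
                setattr(agent, attribute, row[attribute])
            grid.place_agent(agent, (int(row['pos_x']), int(row['pos_y'])))

        return grid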

get_grid_dataframe caches the positions of the agents and writes them to a Parquet file. reconstruct_grid takes the Parquet file and returns a grid object populated with the agents.
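A Mesa-free sketch of the same round trip may help when testing the cache format in isolation. It assumes a cell map `{(x, y): contents}` in place of the Mesa grid and uses `csv` on an in-memory buffer as a stand-in for Parquet, so it runs with the standard library only; `flatten_cells`, `write_rows` and `rebuild_cells` are hypothetical names. Unlike `reconstruct_grid`, it stores the grid size with the data, because `max() + 1` shrinks the grid whenever the last row or column happens to be empty.

```python
import csv
import io
from collections import defaultdict

# Hypothetical column layout: one row per agent, grid size repeated on every row
FIELDS = ["pos_x", "pos_y", "unique_id", "width", "height"]


def flatten_cells(cells, width, height):
    """Turn {(x, y): contents} into one row per agent.

    contents may be a single id (like a SingleGrid cell), a list of ids
    (like a MultiGrid cell) or None for an empty cell.
    """
    rows = []
    for (x, y), contents in cells.items():
        if contents is None:
            continue  # empty single-occupancy cell
        if not isinstance(contents, list):
            contents = [contents]  # normalise a single occupant to a list
        for unique_id in contents:
            rows.append({"pos_x": x, "pos_y": y, "unique_id": unique_id,
                         "width": width, "height": height})
    return rows


def write_rows(rows):
    """Serialise the rows to CSV text (stand-in for DataFrame.to_parquet)."""
    buffer = io.StringIO(newline="")
    writer = csv.DictWriter(buffer, fieldnames=FIELDS)
    writer.writeheader()
    writer.writerows(rows)
    return buffer.getvalue()


def rebuild_cells(text):
    """Read the rows back and group agent ids by cell."""
    cells = defaultdict(list)
    width = height = 0
    for row in csv.DictReader(io.StringIO(text, newline="")):
        width, height = int(row["width"]), int(row["height"])
        cells[(int(row["pos_x"]), int(row["pos_y"]))].append(int(row["unique_id"]))
    return width, height, dict(cells)


if __name__ == "__main__":
    cells = {(0, 0): 1, (1, 0): [2, 3], (2, 2): None}
    width, height, rebuilt = rebuild_cells(write_rows(flatten_cells(cells, 5, 5)))
    print(width, height)  # 5 5 (max() + 1 would give 2 and 1)
    print(rebuilt)  # {(0, 0): [1], (1, 0): [2, 3]}
```

Storing the size on every row still loses it for an empty snapshot; keeping it in file-level metadata would avoid that.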

Visualisation module

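import os
from typing import Any

import mesa
import solara
from matplotlib.figure import Figure

# _draw_grid, _draw_network_grid and _draw_continuous_space are the private helpers from
# Mesa's matplotlib visualisation module that this component was adapted from.


@solara.component
def SpaceMatplotlib(model, agent_portrayal, dependencies: list[Any] | None = None):
    space_fig = Figure()
    space_ax = space_fig.subplots()

    # Rebuild the grid for the current step from the cache instead of reading model.grid
    filename = f'output_dir/grid_data_{(model._steps + 1):03}.parquet'
    if os.path.exists(filename):
        space = CacheableModel.reconstruct_grid(filename)
    else:
        # No cached file for this step: fall back to the live model's space,
        # which is model.grid in most models and model.space in some
        space = getattr(model, "grid", None)
        if space is None:
            space = model.space
    if isinstance(space, mesa.space.NetworkGrid):
        _draw_network_grid(space, space_ax, agent_portrayal)
    elif isinstance(space, mesa.space.ContinuousSpace):
        _draw_continuous_space(space, space_ax, agent_portrayal)
    else:
        _draw_grid(space, space_ax, agent_portrayal)
    solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)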

SpaceMatplotlib now reconstructs the grid object from the cached Parquet file for the current step, instead of reading the grid attribute of the model object.
This is still a work in progress, but the grid visualisation can already read directly from the cached files.
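The writer and the reader currently build the file name independently: `get_grid_dataframe` pads to `len(str(self._total_steps)) - 1` digits, while `SpaceMatplotlib` hardcodes three digits. One option is a single helper that both sides call. The sketch below (hypothetical `cache_filename`, assuming steps run from 0 to `total_steps`) pads to `len(str(total_steps))` digits, one more than the current writer, so the final step also fits and sorting by name always gives step order.

```python
def cache_filename(directory: str, step: int, total_steps: int,
                   prefix: str = "grid_data") -> str:
    """Build the cache file name for one step (hypothetical shared helper).

    Every step in 0..total_steps gets the same number of digits, so the
    names sort lexically in step order.
    """
    width = len(str(total_steps))
    return f"{directory}/{prefix}_{step:0{width}d}.parquet"


if __name__ == "__main__":
    print(cache_filename("output_dir", 7, 1000))  # output_dir/grid_data_0007.parquet
    print(cache_filename("output_dir", 42, 1000, prefix="model_data"))  # output_dir/model_data_0042.parquet
```

Switching to it would change the names of already-written files, so writer and reader would need to move over together.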

@Chan-Dong-Jun
Owner Author

Main edit

Dummy model

import mesa


class TestModel(mesa.Model):
    """A dummy model that only advances the step counter.

    N, width and height are accepted so the usual model parameters can be
    passed in, but no agents, grid, schedule or DataCollector are created:
    everything the visualiser displays is read from the cached files.
    """

    def __init__(self, N, width, height):
        super().__init__()

    def step(self):
        # Advance Mesa's internal step counter, which the visualisation
        # components use to pick the cached file for the current step
        self._steps += 1

TestModel is a dummy model passed to SolaraViz. It lets the visualiser step forward while the displayed data is read from the cached files.

Limitations:

  • The visualiser breaks once there is no more cached data to read, so it must be stopped before the last cached file is read.
  • The visualiser's step can only increase. A possible improvement is a playback button that decrements the step number.
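Both limitations could be addressed by driving the display from a replay cursor over the cached steps instead of from the dummy model's ever-increasing `_steps`. A minimal sketch, assuming the list of cached step numbers is known up front (for example, parsed from the file names); `CacheReplay` is a hypothetical class, not part of Mesa or Solara:

```python
class CacheReplay:
    """Step forwards and backwards through cached snapshots (hypothetical).

    steps is the list of step numbers that have a cache file.
    """

    def __init__(self, steps):
        if not steps:
            raise ValueError("no cached steps to replay")
        self.steps = sorted(steps)
        self.index = 0

    @property
    def current_step(self):
        return self.steps[self.index]

    @property
    def at_end(self):
        return self.index == len(self.steps) - 1

    def forward(self):
        # Clamp at the last cached step instead of reading a missing file
        self.index = min(self.index + 1, len(self.steps) - 1)
        return self.current_step

    def back(self):
        # Clamp at the first cached step
        self.index = max(self.index - 1, 0)
        return self.current_step


if __name__ == "__main__":
    replay = CacheReplay([2, 0, 1])
    print([replay.forward() for _ in range(3)])  # [1, 2, 2]
    print(replay.back())  # 1
```

Moving past either end clamps rather than raising, so the viewer keeps showing the last available snapshot. Connecting `forward()` and `back()` to playback controls is not shown.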

@Chan-Dong-Jun
Owner Author

Main edit

Visualisation module

import glob
from typing import Any

import pandas as pd
import solara
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator


@solara.component
def PlotMatplotlib(model, measure, dependencies: list[Any] | None = None):
    fig = Figure()
    ax = fig.subplots()

    # TODO: Check
    # glob returns files in arbitrary order; sort so the zero-padded step numbers line up
    model_files = sorted(glob.glob("output_dir/model_data_*.parquet"))
    model_dfs = [pd.read_parquet(model_file) for model_file in model_files]
    # Keep only the rows up to the current step
    df = pd.concat(model_dfs, ignore_index=True).iloc[:model._steps + 1]

    if isinstance(measure, str):
        ax.plot(df.loc[:, measure])
        ax.set_ylabel(measure)
    elif isinstance(measure, dict):
        for m, color in measure.items():
            ax.plot(df.loc[:, m], label=m, color=color)
        fig.legend()
    elif isinstance(measure, list | tuple):
        for m in measure:
            ax.plot(df.loc[:, m], label=m)
        fig.legend()
    # Set integer x axis
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    solara.FigureMatplotlib(fig, dependencies=dependencies)

PlotMatplotlib now plots the Matplotlib chart from the cached data.
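`glob.glob` makes no guarantee about the order of its results, so per-step files need sorting before they are concatenated; with zero-padded step numbers, sorting by name gives step order. A stdlib-only sketch, assuming each file holds a JSON list of model records standing in for one Parquet file; `load_model_series` is a hypothetical name:

```python
import glob
import json
import os
import tempfile


def load_model_series(directory: str, current_step: int) -> list[dict]:
    """Concatenate per-step model records in step order, up to current_step.

    Assumes one JSON file per step named model_data_<zero-padded step>.json,
    each holding a list of records (a stand-in for the Parquet files).
    """
    records = []
    # glob order is arbitrary; sorting the zero-padded names restores step order
    for path in sorted(glob.glob(os.path.join(directory, "model_data_*.json"))):
        with open(path) as f:
            records.extend(json.load(f))
    # Keep records 0..current_step, mirroring the [: model._steps + 1] slice
    return records[: current_step + 1]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as directory:
        for step in (2, 0, 1):  # written out of order on purpose
            with open(os.path.join(directory, f"model_data_{step:03d}.json"), "w") as f:
                json.dump([{"step": step}], f)
        print([r["step"] for r in load_model_series(directory, current_step=1)])  # [0, 1]
```

Re-reading every file on each render grows linearly with the number of cached steps; caching the concatenated result between renders is one way to avoid that.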

