diff --git a/Project.toml b/Project.toml
index c688ab9..000fa49 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,9 +6,15 @@ version = "0.1.0"
[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
CommonRLInterface = "d842c3ba-07a1-494f-bbec-f5741b0a3e98"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
+Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
+Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -17,17 +23,10 @@ Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
-julia = "1.7"
-BSON = "0.3.4"
-Bijections = "0.1.3"
-CommonRLInterface = "0.3.1"
-DataStructures = "0.18.11"
-Distributions = "0.25.48"
-ProgressMeter = "1.7.1"
-Scratch = "1.1.0"
-Suppressor = "0.2.0"
+julia = "1.10"
[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
diff --git a/README.md b/README.md
index 9b9053d..4267986 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
![AdaStress](docs/logo.svg)
-AdaStress is a software package that implements the Adaptive Stress Testing (AST) framework, which determines the likeliest failures for a system under test.
+AdaStress is a software package that implements and extends the Adaptive Stress Testing (AST) framework, which determines the likeliest failures for a system under test.
AdaStress provides three primary services:
- Interfaces between user simulations and the AST framework
@@ -33,8 +33,7 @@ AdaStress provides two basic simulation interfaces, **black-box** and **gray-box
Your simulation must inherit from the `BlackBox` or `GrayBox` type and implement the methods found in `src/interface/BlackBox.jl` or `src/interface/GrayBox.jl`.
## Further information
-For more detailed instructions on using AdaStress, see the [complete documentation](./docs/main.md). Example notebooks can be found in the `examples` directory. For background on original AST formulation, see
-> Lee, Ritchie, Ole J. Mengshoel, Anshu Saksena, Ryan W. Gardner, Daniel Genin, Joshua Silbermann, Michael Owen, and Mykel J. Kochenderfer. "Adaptive stress testing: Finding likely failure events with reinforcement learning." Journal of Artificial Intelligence Research 69 (2020): 1165-1201.
+For more detailed instructions on using AdaStress, see the [complete documentation](./docs/main.md). Example notebooks can be found in the `examples` directory. For background on the AST formulation, see the [original paper](https://doi.org/10.1613/jair.1.12190).
## License
AdaStress has been released under the NASA Open Source Agreement version 1.3, as detailed [here](docs/LICENSE.pdf).
diff --git a/docs/main.md b/docs/main.md
index 7cb4865..97ba99e 100644
--- a/docs/main.md
+++ b/docs/main.md
@@ -4,17 +4,16 @@
---
-- [Maintainers](#maintainers)
-- [Description](#description)
-- [Prerequisites](#prerequisites)
-- [Architecture](#architecture)
-- [Problem setup](#problem-setup)
-- [Interface](#interface)
-- [Serialization interface](#serialization-interface)
-- [Submodule management](#submodule-management)
-- [Solvers](#solvers)
-- [Analysis](#analysis)
-- [Acknowledgements](#acknowledgments)
+[Maintainers](#maintainers)\
+[Description](#description)\
+[Prerequisites](#prerequisites)\
+[Architecture](#architecture)\
+[Problem setup](#problem-setup)\
+[Interface](#interface)\
+[Serialization interface](#serialization-interface)\
+[Solvers](#solvers)\
+[Analysis](#analysis)\
+[Acknowledgements](#acknowledgments)
---
@@ -24,7 +23,7 @@
## Description
-AdaStress is a software package that implements the [adaptive stress testing (AST) framework](https://doi.org/10.1613/jair.1.12190), which determines the likeliest failures for a system under test.
+AdaStress is a software package that implements and extends the [adaptive stress testing (AST) framework](https://doi.org/10.1613/jair.1.12190), which determines the likeliest failures for a system under test.
AdaStress provides three primary services:
- Interfaces between user simulations and the AST framework
@@ -192,49 +191,6 @@ The serialization capabilities also make it easier to interact with other progra
An `ASTServer` and `ASTClient` can be created separately and configured to exchange a minimal amount of information to enable stress-testing. This exchange can be further encrypted in various ways, in order to obscure the system under test from the stress-testing agent. For an example of serialized stress-testing, see the notebooks in `examples/pedestrian`.
-## Submodule management
-
-The submodule manager allows optional and experimental features with heavy dependencies to be made available without increasing the loading time of the base package. The user can selectively enable and disable these submodules as needed. In the background, the submodule manager maintains an internal project environment with a minimal set of necessary dependencies, avoiding the need to load unused packages.
-
-This system is made necessary by certain limitations of the language, which does not currently support optional dependencies. A common solution involves creating multiple separate packages to extend a base package; however, we consider this approach somewhat of an anti-pattern, and have chosen not to employ it here. In future versions of AdaStress, the submodule system may be removed if a suitable alternative is possible.
-
-### Using submodules
-
-Submodules are managed through the following API:
-
-> - **`AdaStress.submodules()`**
-> List all available submodules.
-> - **`AdaStress.enabled()`**
-> List enabled submodules.
-> - **`AdaStress.enable(submodule)`**
-> Enable submodule(s). Accepts string or vector of strings. With zero arguments defaults to all associated submodules. Takes effect immediately.
-> - **`AdaStress.disable(submodule)`**
-> Disable submodule(s). Accepts string or vector of strings. With zero arguments defaults to all enabled submodules. Takes effect after Julia restart.
-> - **`AdaStress.load()`**
-> Load enabled submodules (necessary after Julia restart). Takes effect immediately.
-> - **`AdaStress.clean()`**
-> Forcibly remove temporary environment, purging all enabled submodules. Only necessary if submodule manager is corrupted and `disable` cannot restore functionality. Takes effect after Julia restart.
-
-Enabling a submodule can take several seconds, particularly the first time. Due to current limitations of the language, previously enabled submodules cannot be automatically loaded when a new Julia session is launched. The user should use the `load` command for this, as in the following example. In the first session, it is necessary to run
-
-> ```
-> julia> using AdaStress
-> julia> AdaStress.enable("SoftActorCritic")
-> ```
-while in later sessions, the user may simply run
-> ```
-> julia> using AdaStress
-> julia> AdaStress.load()
-> ```
-
-### Multiprocessing
-
-Due to current bugs in the language, many processes related to code loading and environment management are not truly atomic. This can lead to problems when submodules are used in multiprocessing, as occurs with policy-value verification analysis. In such cases, care should be taken when invoking the submodule manager API asynchronously. For an example of loading submodules on multiple processes, see the notebook `examples/pvv`.
-
-### Creating submodules
-
-Custom submodules are essentially regular Julia packages that reside within the AdaStress directory tree, complete with a UUID and `Project.toml` file. Submodules are associated with AdaStress via the `exclude` command, similarly to how source files are associated via `include`.
-
## Solvers
A solver object is a standalone entity representing an algorithm and its parameters. A solver can be applied to an `ASTMDP` or a function that generates an `ASTMDP`, producing a `Result` object, as in
@@ -276,9 +232,25 @@ For an example of a problem solved with MCTS, see the notebook `examples/walk1d`
Global solvers aim to produce an adversarial policy mapping from simulator state to environment instance. The output of the solver is a function that takes as input an observation of the system and returns an action. In this way, failure trajectories can be produced from any given initialization. This opens the door to a richer analysis of the system's weaknesses.
-#### Soft actor-critic
+#### Q-learning
->This feature is contained in a submodule, and must be explicitly enabled.
+Q-learning is a classic reinforcement learning algorithm that uses a table-based policy to map states to optimal actions. Exploration is driven by an epsilon-greedy action selection approach. The simplicity of the QL algorithm makes it a useful baseline for more advanced methods.
+
+| Parameter | Type | Default | Description |
+| - | - | - | - |
+| `state_mins` | `Vector{Float64}` | `[0.0]` | Minimum values of state vector |
+| `state_maxs` | `Vector{Float64}` | `[1.0]` | Maximum values of state vector |
+| `state_divs` | `Vector{Int64}` | `[10]` | State space grid size |
+| `act_mins` | `Vector{Float64}` | `[-3.0]` | Minimum values of actions (normalized) |
+| `act_maxs` | `Vector{Float64}` | `[3.0]` | Maximum values of actions (normalized) |
+| `act_divs` | `Vector{Int64}` | `[10]` | Action space grid size |
+| `num_episodes` | `Int64` | `1000` | Number of episodes |
+| `alpha` | `Float64` | `0.1` | Learning rate |
+| `gamma` | `Float64` | `1.0` | Discount factor |
+| `eps` | `Float64` | `0.25` | Exploration parameter |
+| `reverse_update` | `Bool` | `true` | Update table in time-reverse order |
+
+#### Soft actor-critic
Soft actor-critic (SAC) is a deep reinforcement learning algorithm that simultaneously learns a value function and a policy for the `ASTMDP`. Both take the form of neural networks, which can be used to generate failures online in real-time or analyze system properties offline. SAC offers the following tunable parameters:
@@ -286,8 +258,8 @@ Soft actor-critic (SAC) is a deep reinforcement learning algorithm that simultan
| - | - | - | - |
| `obs_dim` | `Int64` | none | Dimension of observation space |
| `act_dim` | `Int64` | none | Dimension of action space |
-| `act_mins` | `Vector{Float64}` | none | Minimum values of actions |
-| `act_maxs` | `Vector{Float64}` | none | Maximum values of actions |
+| `act_mins` | `Vector{Float64}` | none | Minimum values of actions (normalized) |
+| `act_maxs` | `Vector{Float64}` | none | Maximum values of actions (normalized) |
| `gamma` | `Float64` | `0.999` | Discount factor |
| `max_buffer_size` | `Int64` | `100000` | Maximum number of timesteps in buffer |
| `hidden_sizes` | `Vector{Int}` | `[100,100,100]` | Dimensions of hidden layers |
@@ -327,8 +299,6 @@ The analysis module provide methods to further analyze results.
### Policy-value verification
->This feature is contained in a submodule, and must be explicitly enabled.
-
Policy-value verification (PVV) is an experimental method of analyzing the output of a global solver. It assembles the policy network and value network (or ensemble of value networks) into a single value function over the state space. Then, given a set condition on the value function, the algorithm uses an adaptive refinement process to classify regions of state space that provably satisfy the condition, violate the condition, or are unprovable at the given tolerance.
As a matter of ongoing research, requirements concerning the safety of the system can be linked to conditions on the value function. For instance, a requirement that the possibility of failure not exceed $10^{-9}$ from a set of initial states (given some modeled environmental stochasticity) translates to a constraint on the value function. The validity and practicality of this analysis is largely dependent on the learning process and is still uncertain. Nonetheless, the approach can currently generate *approximate* artifacts that may be useful for casual and nonrigorous analysis of system performance.
@@ -341,4 +311,4 @@ For an example of a problem analyzed with PVV, see the notebook `examples/pvv`.
The adaptive stress testing framework was proposed and developed by Ritchie Lee during his PhD under the supervision of Prof. Mykel Kochenderfer (Stanford University). Ritchie directed the creation of AdaStress and was instrumental in shaping our particular approach to this problem.
-Some of the basic nomenclature in AdaStress is borrowed from the package `POMDPStressTesting.jl`, namely the `GrayBox` and `BlackBox` terminology. Note that the usage and interpretation of these terms differs between the packages. Code that is compatible with one package cannot immediately be used with the other without modification.
\ No newline at end of file
+Some of the basic nomenclature in AdaStress is borrowed from the package `POMDPStressTesting.jl`, namely the `GrayBox` and `BlackBox` terminology. Note that the usage and interpretation of these terms differs between the packages. Code that is compatible with one package cannot immediately be used with the other without modification.
diff --git a/examples/cartpole/Project.toml b/examples/cartpole/Project.toml
index 7d9a931..e7a6566 100644
--- a/examples/cartpole/Project.toml
+++ b/examples/cartpole/Project.toml
@@ -1,11 +1,7 @@
[deps]
AdaStress = "f8632b6a-8763-4da0-bfaf-5f7707adef25"
-BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
-Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-ReinforcementLearning = "158674fc-8238-5cab-b5ba-03dfc80d1318"
-StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
+ReinforcementLearningExperiments = "6bd458e5-1694-412f-b601-3a888375c491"
diff --git a/examples/cartpole/cartpole.ipynb b/examples/cartpole/cartpole.ipynb
index b386163..6ed3014 100644
--- a/examples/cartpole/cartpole.ipynb
+++ b/examples/cartpole/cartpole.ipynb
@@ -14,21 +14,10 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"id": "loose-program",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m project at `C:\\Users\\rlipkis\\.julia\\dev\\adastress\\examples\\cartpole`\n",
- "\u001b[32m\u001b[1m Resolving\u001b[22m\u001b[39m package versions...\n",
- "\u001b[32m\u001b[1m No Changes\u001b[22m\u001b[39m to `C:\\Users\\rlipkis\\.julia\\dev\\adastress\\examples\\cartpole\\Project.toml`\n",
- "\u001b[32m\u001b[1m No Changes\u001b[22m\u001b[39m to `C:\\Users\\rlipkis\\.julia\\dev\\adastress\\examples\\cartpole\\Manifest.toml`\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"using Pkg\n",
"Pkg.activate(\".\")\n",
@@ -36,117 +25,6 @@
"Pkg.instantiate()"
]
},
- {
- "cell_type": "markdown",
- "id": "municipal-teach",
- "metadata": {},
- "source": [
- "# Loading cart-pole controller (SUT)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "extreme-carbon",
- "metadata": {},
- "outputs": [],
- "source": [
- "# to learn\n",
- "using ReinforcementLearning\n",
- "\n",
- "# to load\n",
- "using BSON\n",
- "using Flux\n",
- "using NNlib\n",
- "using Random\n",
- "using StableRNGs"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "ready-revelation",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "learn_policy (generic function with 1 method)"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# This function can be invoked to re-train the DQN policy if desired.\n",
- "function learn_policy()\n",
- " seed = 123\n",
- " rng = StableRNG(seed)\n",
- " env = CartPoleEnv(; T = Float32, rng = rng)\n",
- " ns, na = length(state(env)), length(action_space(env))\n",
- "\n",
- " policy = Agent(\n",
- " policy = QBasedPolicy(\n",
- " learner = BasicDQNLearner(\n",
- " approximator = NeuralNetworkApproximator(\n",
- " model = Chain(\n",
- " Dense(ns, 128, relu; init = glorot_uniform(rng)),\n",
- " Dense(128, 128, relu; init = glorot_uniform(rng)),\n",
- " Dense(128, na; init = glorot_uniform(rng)),\n",
- " ) |> cpu,\n",
- " optimizer = ADAM(),\n",
- " ),\n",
- " batch_size = 32,\n",
- " min_replay_history = 100,\n",
- " loss_func = Flux.huber_loss,\n",
- " rng = rng,\n",
- " ),\n",
- " explorer = EpsilonGreedyExplorer(\n",
- " kind = :exp,\n",
- " ϵ_stable = 0.01,\n",
- " decay_steps = 500,\n",
- " rng = rng,\n",
- " ),\n",
- " ),\n",
- " trajectory = CircularArraySARTTrajectory(\n",
- " capacity = 1000,\n",
- " state = Vector{Float32} => (ns,),\n",
- " ),\n",
- " )\n",
- " stop_condition = StopAfterStep(100_000)\n",
- " hook = ComposedHook(TotalRewardPerEpisode(), TimePerStep())\n",
- " \n",
- " ex = Experiment(policy, env, stop_condition, hook, \"\")\n",
- " run(ex)\n",
- " policy = policy.policy\n",
- " policy.explorer.is_training = false\n",
- " BSON.@save \"dqn_policy.bson\" policy\n",
- "end"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "complimentary-renaissance",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "load_policy (generic function with 1 method)"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "load_policy() = BSON.load(\"dqn_policy.bson\")[:policy]"
- ]
- },
{
"cell_type": "markdown",
"id": "postal-contract",
@@ -157,21 +35,10 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"id": "introductory-thompson",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "false"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"using Random\n",
"using AdaStress\n",
@@ -179,31 +46,32 @@
"using Distributions\n",
"using Plots\n",
"using ProgressMeter\n",
+ "using ReinforcementLearningExperiments\n",
"ProgressMeter.ijulia_behavior(:clear)"
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
+ "id": "f2ec6694-4ae0-45c9-98be-238f60159c3a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ex = run(E`JuliaRL_DQN_CartPole`)\n",
+ "ex.policy.policy.explorer.ϵ_stable = 0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
"id": "appropriate-tender",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "CartPoleSim"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"Base.@kwdef mutable struct CartPoleSim <: Interface.GrayBox\n",
- " env::ReinforcementLearning.AbstractEnv = CartPoleEnv()\n",
+ " env::AbstractEnv = ex.env\n",
" tmax::Float64 = 50.0\n",
- " pi::ReinforcementLearning.AbstractPolicy = load_policy()\n",
+ " pi::AbstractPolicy = ex.policy.policy\n",
" x_dist::Interface.Environment = Interface.Environment(:wind => Normal(0.0, 0.01))\n",
" log::Dict{Symbol, Any} = Dict{Symbol, Any}()\n",
" logging::Bool = false\n",
@@ -212,24 +80,13 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"id": "copyrighted-possible",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "mdp_env (generic function with 1 method)"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"function Interface.reset!(sim::CartPoleSim)\n",
- " ReinforcementLearning.reset!(sim.env)\n",
+ " reset!(sim.env)\n",
" if sim.logging\n",
" sim.log[:s] = []\n",
" sim.log[:x] = []\n",
@@ -240,13 +97,13 @@
"\n",
"Interface.environment(sim::CartPoleSim) = sim.x_dist\n",
"\n",
- "Interface.observe(sim::CartPoleSim) = push!(copy(ReinforcementLearning.state(sim.env)), sim.env.t / sim.tmax)\n",
+ "Interface.observe(sim::CartPoleSim) = push!(copy(state(sim.env)), sim.env.t / sim.tmax)\n",
"\n",
"function Interface.step!(sim::CartPoleSim, x::Interface.EnvironmentValue)\n",
- " a = sim.pi(sim.env)\n",
- " s = ReinforcementLearning.state(sim.env)\n",
+ " a = plan!(sim.pi, sim.env)\n",
+ " s = state(sim.env)\n",
" s[2] += x[:wind]\n",
- " sim.env(a)\n",
+ " act!(sim.env, a)\n",
" if sim.logging\n",
" push!(sim.log[:s], Interface.observe(sim))\n",
" push!(sim.log[:x], x[:wind])\n",
@@ -266,11 +123,11 @@
"interval_dist(x::Real, l::Real, u::Real) = (l < x < u) ? min(x - l, u - x) : zero(x)\n",
"\n",
"function Interface.distance(sim::CartPoleSim)\n",
- " s = ReinforcementLearning.state(sim.env)\n",
+ " s = state(sim.env)\n",
" dx = interval_dist(s[1], sim.env.params.xthreshold)\n",
" dθ = interval_dist(s[3], sim.env.params.thetathreshold)\n",
" d = sqrt(dx^2 + dθ^2)\n",
- " return d\n",
+ " return Float64(d)\n",
"end\n",
"\n",
"function mdp_env(σ::Float64)\n",
@@ -291,1122 +148,10 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"id": "civic-knowing",
"metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"Random.seed!(0)\n",
"p = plot(title=\"\\\\theta\")\n",
@@ -1435,25 +180,10 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"id": "dangerous-clearing",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:05\u001b[39m\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "14 naturally-occurring failures found out of 10000 episodes.\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"Random.seed!(0)\n",
"c = 0\n",
@@ -1476,70 +206,37 @@
},
{
"cell_type": "code",
- "execution_count": 10,
- "id": "reduced-retreat",
+ "execution_count": null,
+ "id": "ddcd2347-5dfa-4799-b24c-f0066700439d",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "┌ Info: Enabled submodule SoftActorCritic.\n",
- "└ @ AdaStress C:\\Users\\rlipkis\\.julia\\dev\\adastress\\src\\utils.jl:93\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "AdaStress.enable(\"SoftActorCritic\")\n",
"using AdaStress.SoftActorCritic"
]
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"id": "crucial-citizenship",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\r",
- "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:03:42\u001b[39m\r\n",
- "\u001b[34m epoch: 150\u001b[39m\r\n",
- "\u001b[34m score: 75.326614\u001b[39m\r\n",
- "\u001b[34m stdev: 0.87628365\u001b[39m\r\n",
- "\u001b[34m fails: 0.96\u001b[39m\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "(MLPActorCritic(SoftActorCritic.SquashedGaussianMLPActor(Chain(Dense(5, 30, relu), Dense(30, 30, relu), Dense(30, 30, relu)), Dense(30, 1), Dense(30, 1), Float32[-3.0], Float32[3.0], Random._GLOBAL_RNG(), nothing, false), SoftActorCritic.MLPQFunction[SoftActorCritic.MLPQFunction(Chain(Dense(6, 30, relu), Dense(30, 30, relu), Dense(30, 30, relu), Dense(30, 1))), SoftActorCritic.MLPQFunction(Chain(Dense(6, 30, relu), Dense(30, 30, relu), Dense(30, 30, relu), Dense(30, 1))), SoftActorCritic.MLPQFunction(Chain(Dense(6, 30, relu), Dense(30, 30, relu), Dense(30, 30, relu), Dense(30, 1)))]), Dict{String, Any}(\"score\" => Any[-1.2982365f0, -0.45816767f0, -0.26920554f0, -0.24219736f0, -0.27620187f0, -0.18340729f0, -0.17002267f0, -1.0217903f0, -0.46302056f0, -0.3155263f0 … 75.00554f0, 74.15143f0, 73.943085f0, 75.13819f0, 75.35046f0, 74.6657f0, 75.3588f0, 75.16508f0, 74.918846f0, 75.326614f0], \"stdev\" => Any[0.09985117f0, 0.07230078f0, 0.07011744f0, 0.07710096f0, 0.07252703f0, 0.06111533f0, 0.060430225f0, 0.06532489f0, 0.061455257f0, 0.054803085f0 … 0.9653635f0, 0.9764626f0, 1.0118048f0, 0.9269699f0, 0.92986727f0, 0.95796734f0, 0.9261049f0, 0.909985f0, 0.8943766f0, 0.87628365f0], \"fails\" => Any[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.98, 0.95, 0.95, 0.98, 0.99, 0.99, 0.99, 1.0, 0.98, 0.96]))"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"Random.seed!(0)\n",
"\n",
- "sac = AdaStress.SoftActorCritic.SAC(;\n",
+ "sac = SAC(;\n",
" obs_dim=5, \n",
" act_dim=1,\n",
" act_mins=-3.0*ones(1),\n",
" act_maxs=3.0*ones(1),\n",
- " hidden_sizes=[30,30,30],\n",
- " q_optimizer=AdaBelief(1e-4),\n",
- " pi_optimizer=AdaBelief(1e-4),\n",
- " alpha_optimizer=AdaBelief(1e-4),\n",
+ " hidden_sizes=[32,64,32],\n",
+ " q_optimizer=SoftActorCritic.AdaBelief(1e-4),\n",
+ " pi_optimizer=SoftActorCritic.AdaBelief(1e-4),\n",
+ " alpha_optimizer=SoftActorCritic.AdaBelief(1e-4),\n",
" gamma=1.0,\n",
" num_q=3,\n",
" max_buffer_size=100_000,\n",
" batch_size=1024,\n",
- " epochs=150,\n",
+ " epochs=500,\n",
" steps_per_epoch=1_000,\n",
" start_steps=10_000,\n",
" max_ep_len=100,\n",
@@ -1556,127 +253,10 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
"id": "early-raise",
"metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"plot(info[\"score\"]; label=:none)"
]
@@ -1691,1952 +271,15 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": null,
"id": "accessible-cleaner",
"metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
+ "outputs": [],
"source": [
"Random.seed!(0)\n",
"pθ = plot(; title=\"\\\\theta\", ga=0.5, minorgrid=true, minorgridalpha=0.25)\n",
"px = plot(; title=\"x\", ga=0.5, minorgrid=true, minorgridalpha=0.25)\n",
- "mdp = mdp_env(0.1)\n",
+ "mdp = mdp_env(0.025)\n",
"mdp.sim.logging = true\n",
"n_eps = 100\n",
"\n",
@@ -3660,380 +303,10 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": null,
"id": "nasty-bunny",
- "metadata": {
- "scrolled": false
- },
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "metadata": {},
+ "outputs": [],
"source": [
"plot(mdp.sim.log[:x]; title=\"Sample disturbance\", label=:none, lc=:black, ga=0.5, minorgrid=true, minorgridalpha=0.25)"
]
@@ -4048,21 +321,10 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": null,
"id": "dominican-vitamin",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Rectangle (generic function with 1 method)"
- ]
- },
- "execution_count": 15,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"function Circle(xc::Real, yc::Real, r::Real; n::Int64=100)\n",
" θ = range(0, 2π; length=n+1)\n",
@@ -4080,23 +342,12 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": null,
"id": "instructional-pricing",
"metadata": {
"scrolled": true
},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "movie (generic function with 2 methods)"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"function movie(sim::CartPoleSim, filename::String=\"animation\")\n",
" cart_length = 0.5\n",
@@ -4171,35 +422,10 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": null,
"id": "mature-symphony",
- "metadata": {
- "scrolled": false
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "┌ Info: Saved animation to \n",
- "│ fn = C:\\Users\\rlipkis\\.julia\\dev\\adastress\\examples\\cartpole\\animation.mp4\n",
- "└ @ Plots C:\\Users\\rlipkis\\.julia\\packages\\Plots\\YAlrZ\\src\\animation.jl:114\n"
- ]
- },
- {
- "data": {
- "text/html": [
- ""
- ],
- "text/plain": [
- "Plots.AnimatedGif(\"C:\\\\Users\\\\rlipkis\\\\.julia\\\\dev\\\\adastress\\\\examples\\\\cartpole\\\\animation.mp4\")"
- ]
- },
- "execution_count": 17,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "metadata": {},
+ "outputs": [],
"source": [
"Random.seed!(0)\n",
"mdp = mdp_env(0.1)\n",
@@ -4220,15 +446,15 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Julia 1.7.0",
+ "display_name": "Julia 1.10.2",
"language": "julia",
- "name": "julia-1.7"
+ "name": "julia-1.10"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
- "version": "1.7.0"
+ "version": "1.10.2"
}
},
"nbformat": 4,
diff --git a/examples/cartpole/dqn_policy.bson b/examples/cartpole/dqn_policy.bson
deleted file mode 100644
index 91e9b79..0000000
Binary files a/examples/cartpole/dqn_policy.bson and /dev/null differ
diff --git a/examples/cas/cas.ipynb b/examples/cas/cas.ipynb
index 10e4589..26f0daf 100644
--- a/examples/cas/cas.ipynb
+++ b/examples/cas/cas.ipynb
@@ -19,23 +19,12 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"id": "limited-circuit",
"metadata": {
"scrolled": true
},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m project at `C:\\Users\\rlipkis\\.julia\\dev\\adastress\\examples\\cas`\n",
- "\u001b[32m\u001b[1m Resolving\u001b[22m\u001b[39m package versions...\n",
- "\u001b[32m\u001b[1m No Changes\u001b[22m\u001b[39m to `C:\\Users\\rlipkis\\.julia\\dev\\adastress\\examples\\cas\\Project.toml`\n",
- "\u001b[32m\u001b[1m No Changes\u001b[22m\u001b[39m to `C:\\Users\\rlipkis\\.julia\\dev\\adastress\\examples\\cas\\Manifest.toml`\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"using Pkg\n",
"Pkg.activate(\".\")\n",
@@ -45,7 +34,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"id": "corrected-rescue",
"metadata": {},
"outputs": [],
@@ -59,21 +48,10 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"id": "dated-revelation",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "mdp_env (generic function with 1 method)"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"function mdp_env(; kwargs...)\n",
" sim = SimpleACAS.Simulator(; n=2, seed=0, randomize=false, kwargs...)\n",
@@ -93,1491 +71,10 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"id": "indoor-lender",
"metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"mdp = mdp_env(; logging=true)\n",
"Random.seed!(0)\n",
@@ -1595,25 +92,10 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"id": "caring-timer",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:07\u001b[39m\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "2 naturally-occurring failures found out of 10000 episodes (0.02%).\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"Random.seed!(0)\n",
"c = 0\n",
@@ -1637,38 +119,10 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"id": "working-picking",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:01:10\u001b[39m\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "DataStructures.PriorityQueue{Any, Any, Base.Order.ForwardOrdering} with 10 entries:\n",
- " MCTSResult(Dict{Symbol, Any}[Dict(:cmd_1=>[3.32626, -4.53711, -0.7… => 12370.7\n",
- " MCTSResult(Dict{Symbol, Any}[Dict(:cmd_1=>[3.32626, -4.53711, -0.7… => 12372.2\n",
- " MCTSResult(Dict{Symbol, Any}[Dict(:cmd_1=>[3.32626, -4.53711, -0.7… => 12373.7\n",
- " MCTSResult(Dict{Symbol, Any}[Dict(:cmd_1=>[3.32626, -4.53711, -0.7… => 12375.2\n",
- " MCTSResult(Dict{Symbol, Any}[Dict(:cmd_1=>[3.32626, -4.53711, -0.7… => 12382.2\n",
- " MCTSResult(Dict{Symbol, Any}[Dict(:cmd_1=>[8.64674, -5.52582, 0.87… => 12393.9\n",
- " MCTSResult(Dict{Symbol, Any}[Dict(:cmd_1=>[3.32626, -4.53711, -0.7… => 12396.5\n",
- " MCTSResult(Dict{Symbol, Any}[Dict(:cmd_1=>[3.32626, -4.53711, -0.7… => 12407.3\n",
- " MCTSResult(Dict{Symbol, Any}[Dict(:cmd_1=>[-3.6445, -0.483991, -0.… => 12415.6\n",
- " MCTSResult(Dict{Symbol, Any}[Dict(:cmd_1=>[3.32626, -4.53711, -0.7… => 12422.9"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"mcts = AdaStress.Solvers.MCTS(num_iterations=100_000)\n",
"sol = mcts(mdp_env)"
@@ -1684,14265 +138,10 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"id": "equivalent-render",
- "metadata": {
- "scrolled": false
- },
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
+ "metadata": {},
+ "outputs": [],
"source": [
"mdp = mdp_env(; logging=true)\n",
"\n",
@@ -15966,15 +165,15 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Julia 1.7.0",
+ "display_name": "Julia 1.10.2",
"language": "julia",
- "name": "julia-1.7"
+ "name": "julia-1.10"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
- "version": "1.7.0"
+ "version": "1.10.2"
}
},
"nbformat": 4,
diff --git a/examples/fms/fms.ipynb b/examples/fms/fms.ipynb
index 8226b9f..f4f427e 100644
--- a/examples/fms/fms.ipynb
+++ b/examples/fms/fms.ipynb
@@ -18,21 +18,10 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"id": "dressed-reaction",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m project at `C:\\Users\\rlipkis\\.julia\\dev\\adastress\\examples\\fms`\n",
- "\u001b[32m\u001b[1m Resolving\u001b[22m\u001b[39m package versions...\n",
- "\u001b[32m\u001b[1m No Changes\u001b[22m\u001b[39m to `C:\\Users\\rlipkis\\.julia\\dev\\adastress\\examples\\fms\\Project.toml`\n",
- "\u001b[32m\u001b[1m No Changes\u001b[22m\u001b[39m to `C:\\Users\\rlipkis\\.julia\\dev\\adastress\\examples\\fms\\Manifest.toml`\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"using Pkg\n",
"Pkg.activate(\".\")\n",
@@ -42,19 +31,10 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"id": "raising-middle",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "┌ Info: Precompiling Plots [91a5bcdd-55d7-5caf-9e0b-520d859cae80]\n",
- "└ @ Base loading.jl:1423\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"using AdaStress\n",
"using Distributions\n",
@@ -75,21 +55,10 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"id": "continued-thriller",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Plan"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"\"\"\"\n",
"A waypoint consists of cartesian coordinates and time of arrival.\n",
@@ -117,21 +86,10 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"id": "editorial-peeing",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "evaluate"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"struct CollisionModule\n",
" d_crit::Float64 # critical separation threshold [miles]\n",
@@ -254,21 +212,10 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"id": "novel-livestock",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "FMSim"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"\"\"\"\n",
"Initial position of aircrafts, determined by Gaussian spread.\n",
@@ -302,7 +249,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"id": "flying-projector",
"metadata": {},
"outputs": [],
@@ -354,21 +301,10 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"id": "cultural-tourism",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "100.0"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"mdp = Interface.ASTMDP(FMSim(; num_aircraft=2); episodic=true)\n",
"mdp.reward.event_bonus = 100.0"
@@ -376,66 +312,20 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"id": "returning-wrist",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "AdaStress.Solvers.MonteCarloTreeSearch.MCTS(100000, 10, 1.0, 0.7, 1.0, nothing)"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"mcts = AdaStress.Solvers.MCTS(num_iterations=100_000)"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"id": "light-karaoke",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:09\u001b[39m39mm39m\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " 10.501736 seconds (139.77 M allocations: 6.454 GiB, 10.73% gc time, 21.19% compilation time)\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "DataStructures.PriorityQueue{Any, Any, Base.Order.ForwardOrdering} with 10 entries:\n",
- " MCTSResult(UInt32[0xaa01c11f, 0x75ac8a97, 0x90db374b, 0x77809412, … => 73.892\n",
- " MCTSResult(UInt32[0xaa01c11f, 0x75ac8a97, 0x90db374b, 0x77809412, … => 73.8928\n",
- " MCTSResult(UInt32[0xaa01c11f, 0x75ac8a97, 0x90db374b, 0x77809412, … => 73.9045\n",
- " MCTSResult(UInt32[0xaa01c11f, 0x75ac8a97, 0x90db374b, 0x77809412, … => 73.9092\n",
- " MCTSResult(UInt32[0xaa01c11f, 0x75ac8a97, 0x90db374b, 0x77809412, … => 73.9327\n",
- " MCTSResult(UInt32[0xaa01c11f, 0x75ac8a97, 0x90db374b, 0x77809412, … => 73.9383\n",
- " MCTSResult(UInt32[0xaa01c11f, 0x75ac8a97, 0x90db374b, 0x77809412, … => 73.9467\n",
- " MCTSResult(UInt32[0xaa01c11f, 0x75ac8a97, 0x90db374b, 0x77809412, … => 73.962\n",
- " MCTSResult(UInt32[0xaa01c11f, 0x75ac8a97, 0x90db374b, 0x77809412, … => 73.9684\n",
- " MCTSResult(UInt32[0xaa01c11f, 0x75ac8a97, 0x90db374b, 0x77809412, … => 74.0108"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"Random.seed!(0)\n",
"@time sol = mcts(() -> mdp)"
@@ -451,21 +341,10 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"id": "popular-garage",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "visualize (generic function with 1 method)"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"function visualize(sim::FMSim)\n",
" # all plans\n",
@@ -491,148 +370,10 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"id": "genuine-fifth",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Plan solved in 9.7e-6 seconds.\n",
- "Collision detected.\n",
- "CDS separation threshold: 528.0 ft\n",
- "Sep. at closest approach: 226.3 ft\n"
- ]
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"best_path = last(collect(keys(sol)))\n",
"AdaStress.Solvers.replay!(mdp, best_path)\n",
@@ -650,15 +391,15 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Julia 1.7.0",
+ "display_name": "Julia 1.10.2",
"language": "julia",
- "name": "julia-1.7"
+ "name": "julia-1.10"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
- "version": "1.7.0"
+ "version": "1.10.2"
}
},
"nbformat": 4,
diff --git a/examples/pedestrian/pedestrian-client.ipynb b/examples/pedestrian/pedestrian-client.ipynb
index 557a2e9..dd06c42 100644
--- a/examples/pedestrian/pedestrian-client.ipynb
+++ b/examples/pedestrian/pedestrian-client.ipynb
@@ -18,23 +18,12 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"id": "celtic-clerk",
"metadata": {
"scrolled": true
},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m project at `C:\\Users\\rlipkis\\.julia\\dev\\adastress\\examples\\pedestrian`\n",
- "\u001b[32m\u001b[1m Resolving\u001b[22m\u001b[39m package versions...\n",
- "\u001b[32m\u001b[1m No Changes\u001b[22m\u001b[39m to `C:\\Users\\rlipkis\\.julia\\dev\\adastress\\examples\\pedestrian\\Project.toml`\n",
- "\u001b[32m\u001b[1m No Changes\u001b[22m\u001b[39m to `C:\\Users\\rlipkis\\.julia\\dev\\adastress\\examples\\pedestrian\\Manifest.toml`\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"using Pkg\n",
"Pkg.activate(\".\")\n",
@@ -44,7 +33,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"id": "ongoing-regulation",
"metadata": {},
"outputs": [],
@@ -56,99 +45,40 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"id": "metropolitan-charm",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "AdaStress.Interface.ASTClient(ip\"156.68.48.136\", 2000, nothing, false, false, false, Any[])"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"client = Interface.ASTClient(; ip=Interface.getipaddr(), port=2000)"
]
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"id": "challenging-switch",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "┌ Info: ASTServer responded in 1666 milliseconds.\n",
- "└ @ AdaStress.Interface C:\\Users\\rlipkis\\.julia\\dev\\adastress\\src\\interface\\remote\\client.jl:131\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"Interface.connect!(client)"
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"id": "forced-poker",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "MCTS(10000, 10, 1.0, 0.85, 1.0, nothing)"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"mcts = MCTS(num_iterations=10_000, α=0.85)"
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"id": "brutal-brooklyn",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:01:51\u001b[39m\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "DataStructures.PriorityQueue{Any, Any, Base.Order.ForwardOrdering} with 10 entries:\n",
- " MCTSResult(UInt32[0xe67d225b, 0x5f9be755, 0x0aed2eeb, 0x26840a96, … => 1087.52\n",
- " MCTSResult(UInt32[0xef05d725, 0xde2fae7c, 0x2cf7afe5, 0x512ae58c, … => 1087.8\n",
- " MCTSResult(UInt32[0xe67d225b, 0xaccaa267, 0x9944dbf8, 0x20727eff, … => 1087.81\n",
- " MCTSResult(UInt32[0xe67d225b, 0x34e1b7a5, 0x27d8ee5b, 0x4d041c41, … => 1087.82\n",
- " MCTSResult(UInt32[0xe67d225b, 0xd21aebe7, 0xd339f33c, 0x6847add5, … => 1087.86\n",
- " MCTSResult(UInt32[0xe67d225b, 0x34e1b7a5, 0x900a8fe4, 0xc1741547, … => 1087.87\n",
- " MCTSResult(UInt32[0xe67d225b, 0x34e1b7a5, 0xe3ebdc90, 0x0daa5991, … => 1088.13\n",
- " MCTSResult(UInt32[0xe67d225b, 0x34e1b7a5, 0x900a8fe4, 0xab3981e6, … => 1088.23\n",
- " MCTSResult(UInt32[0xe67d225b, 0x34e1b7a5, 0x900a8fe4, 0x3a475166, … => 1088.29\n",
- " MCTSResult(UInt32[0xe67d225b, 0x17803b35, 0x78726223, 0x7ea848d4, … => 1088.46"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"Random.seed!(0)\n",
"sol = mcts(() -> Interface.generate_mdp(client))"
@@ -156,19 +86,10 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"id": "angry-speed",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "MonteCarloTreeSearch.total_size(mcts.tree) = 10000\n",
- "MonteCarloTreeSearch.max_depth(mcts.tree) = 8\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"@show MonteCarloTreeSearch.total_size(mcts.tree);\n",
"@show MonteCarloTreeSearch.max_depth(mcts.tree);"
@@ -176,7 +97,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"id": "competitive-apollo",
"metadata": {},
"outputs": [],
@@ -186,7 +107,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"id": "gentle-strengthening",
"metadata": {},
"outputs": [],
@@ -205,15 +126,15 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Julia 1.7.0",
+ "display_name": "Julia 1.10.2",
"language": "julia",
- "name": "julia-1.7"
+ "name": "julia-1.10"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
- "version": "1.7.0"
+ "version": "1.10.2"
}
},
"nbformat": 4,
diff --git a/examples/pedestrian/pedestrian-server.ipynb b/examples/pedestrian/pedestrian-server.ipynb
index 3875cf6..362dcc2 100644
--- a/examples/pedestrian/pedestrian-server.ipynb
+++ b/examples/pedestrian/pedestrian-server.ipynb
@@ -22,24 +22,10 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"id": "horizontal-queens",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m project at `C:\\Users\\rlipkis\\.julia\\dev\\adastress\\examples\\pedestrian`\n",
- "\u001b[32m\u001b[1m Resolving\u001b[22m\u001b[39m package versions...\n",
- "\u001b[32m\u001b[1m No Changes\u001b[22m\u001b[39m to `C:\\Users\\rlipkis\\.julia\\dev\\adastress\\examples\\pedestrian\\Project.toml`\n",
- "\u001b[32m\u001b[1m No Changes\u001b[22m\u001b[39m to `C:\\Users\\rlipkis\\.julia\\dev\\adastress\\examples\\pedestrian\\Manifest.toml`\n",
- "\u001b[32m\u001b[1mPrecompiling\u001b[22m\u001b[39m project...\n",
- "\u001b[32m ✓ \u001b[39mAdaStress\n",
- " 1 dependency successfully precompiled in 3 seconds (143 already precompiled)\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"using Pkg\n",
"Pkg.activate(\".\")\n",
@@ -57,7 +43,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"id": "painful-cincinnati",
"metadata": {},
"outputs": [],
@@ -69,21 +55,10 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"id": "interior-plenty",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "pedestrian_avoidance (generic function with 1 method)"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"abstract type Actor end\n",
"\n",
@@ -126,21 +101,10 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"id": "round-quest",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "update! (generic function with 3 methods)"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"distance(car::SelfDrivingCar, ped::Pedestrian) = sqrt(car.state[1]^2 + ped.state[1]^2)\n",
"\n",
@@ -206,7 +170,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"id": "nonprofit-miracle",
"metadata": {},
"outputs": [],
@@ -228,21 +192,10 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"id": "angry-speed",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "1000.0"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"mdp = AdaStress.ASTMDP(DriveSim())\n",
"mdp.reward.event_bonus = 1000.0"
@@ -250,56 +203,29 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"id": "japanese-mason",
"metadata": {
"scrolled": true
},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "AdaStress.Interface.ASTServer(ip\"0.0.0.0\", 2000, nothing, AdaStress.Interface.ASTMDP{AdaStress.Interface.UnobservableState, AdaStress.Interface.SeedAction}(DriveSim(0.0, 1.0, 15.0, SelfDrivingCar([-100.0, 10.0], [-10.0, 0.5], 0.5, 3.0, 3.0), Pedestrian([-10.0, 0.5], Normal{Float64}(μ=0.0, σ=0.25)), 2.0, Dict{Any, Any}(:ped => MVector{2, Float64}[[-10.0, 0.5]], :car => MVector{2, Float64}[[-100.0, 10.0]], :t => [0.0])), AdaStress.Interface.Reward(true, AdaStress.Interface.GradientHeuristic(), 1000.0, AdaStress.Interface.WeightedObjective(1.0, 1.0, 1.0)), false, 0, Dict{Symbol, AdaStress.Interface.VariableInfo}(), Random.TaskLocalRNG()), nothing, false, false, true)"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"server = AdaStress.ASTServer(mdp; ip=Interface.IPv4(0), port=2000)"
]
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"id": "amazing-limitation",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Enter password: ········\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "┌ Info: Private token set.\n",
- "└ @ AdaStress.Interface C:\\Users\\rlipkis\\.julia\\dev\\adastress\\src\\interface\\remote\\server.jl:50\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"Interface.set_password(server)"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"id": "earlier-optics",
"metadata": {},
"outputs": [],
@@ -317,21 +243,10 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"id": "personalized-japanese",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "┌ Info: Connected to AST client.\n",
- "│ conn = Sockets.TCPSocket(Base.Libc.WindowsRawSocket(0x00000000000004b4) open, 0 bytes waiting)\n",
- "└ @ AdaStress.Interface C:\\Users\\rlipkis\\.julia\\dev\\adastress\\src\\interface\\remote\\server.jl:128\n",
- "WARNING: using DataStructures.update! in module Main conflicts with an existing identifier.\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"using BSON\n",
"using DataStructures\n",
@@ -340,7 +255,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"id": "medieval-seminar",
"metadata": {},
"outputs": [],
@@ -350,7 +265,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
"id": "inclusive-ecuador",
"metadata": {},
"outputs": [],
@@ -363,148 +278,10 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": null,
"id": "judicial-space",
"metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"best_path = collect(keys(sol))[end]\n",
"AdaStress.Solvers.replay!(mdp, best_path)\n",
@@ -518,152 +295,10 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": null,
"id": "imposed-float",
- "metadata": {
- "scrolled": false
- },
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "metadata": {},
+ "outputs": [],
"source": [
"acc_paths = []\n",
"\n",
@@ -691,15 +326,15 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Julia 1.7.0",
+ "display_name": "Julia 1.10.2",
"language": "julia",
- "name": "julia-1.7"
+ "name": "julia-1.10"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
- "version": "1.7.0"
+ "version": "1.10.2"
}
},
"nbformat": 4,
diff --git a/examples/pvv/pvv.ipynb b/examples/pvv/pvv.ipynb
index a22845c..5b68d09 100644
--- a/examples/pvv/pvv.ipynb
+++ b/examples/pvv/pvv.ipynb
@@ -79,8 +79,8 @@
"metadata": {},
"outputs": [],
"source": [
- "using AdaStress.PolicyValueVerification\n",
- "using AdaStress.SoftActorCritic"
+ "using .PolicyValueVerification\n",
+ "using .SoftActorCritic"
]
},
{
@@ -119,795 +119,7 @@
"outputs": [
{
"data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
+ "image/svg+xml": "\n\n"
},
"execution_count": 7,
"metadata": {},
@@ -1007,9132 +219,7 @@
},
{
"data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
+ "image/svg+xml": "\n\n"
},
"execution_count": 11,
"metadata": {},
@@ -10151,12009 +238,7 @@
"outputs": [
{
"data": {
- "image/svg+xml": [
- "\n",
- "