-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adding a notebook to save an SBI model
- Loading branch information
1 parent
a00c620
commit 64b8cdf
Showing
1 changed file
with
153 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d467063e-00c2-48e1-a214-434767a4bc37", | ||
"metadata": {}, | ||
"source": [ | ||
"# Quick train SBI\n", | ||
"Then save the posterior as a pkl using the evaluate module." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "486dda47-bf7b-45ea-88fe-55960d81c4bb", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"ename": "ModuleNotFoundError", | ||
"evalue": "No module named 'sbi'", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", | ||
"Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01msbi\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m# from sbi import inference\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msbi\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01minference\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m SNPE\n", | ||
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'sbi'" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import sbi\n", | ||
"from sbi.inference import SNPE\n", | ||
"from sbi.inference.base import infer\n", | ||
"import torch" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "f64b72b1-3c46-45af-932e-59512b2adbc8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# this is necessary to import modules from this repo\n", | ||
"import sys\n", | ||
"sys.path.append('..')\n", | ||
"from src.scripts import evaluate" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "cd5034fb-94da-4b3d-b5ca-89f16ed98ff9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def simulator(thetas):#, percent_errors):\n", | ||
" # just plop the pendulum within here\n", | ||
" m, b = thetas\n", | ||
" x = np.linspace(0, 100, 101)\n", | ||
" rs = np.random.RandomState()#2147483648)# \n", | ||
" # I'm thinking sigma could actually be a function of x\n", | ||
" # if we want to get fancy down the road\n", | ||
" sigma = 10\n", | ||
" ε = rs.normal(loc=0, scale=sigma, size = len(x)) \n", | ||
" return m * x + b + ε" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "9fe446e0-e80e-4c6a-a67e-8a8bd19d2787", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"ename": "NameError", | ||
"evalue": "name 'torch' is not defined", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", | ||
"Cell \u001b[0;32mIn[4], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m num_dim \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[0;32m----> 3\u001b[0m low_bounds \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m10\u001b[39m])\n\u001b[1;32m 4\u001b[0m high_bounds \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;241m10\u001b[39m, \u001b[38;5;241m10\u001b[39m])\n\u001b[1;32m 6\u001b[0m prior \u001b[38;5;241m=\u001b[39m sbi\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mBoxUniform(low \u001b[38;5;241m=\u001b[39m low_bounds, high \u001b[38;5;241m=\u001b[39m high_bounds)\n", | ||
"\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"num_dim = 2\n", | ||
"\n", | ||
"low_bounds = torch.tensor([0, -10])\n", | ||
"high_bounds = torch.tensor([10, 10])\n", | ||
"\n", | ||
"prior = sbi.utils.BoxUniform(low = low_bounds, high = high_bounds)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "b4d1c9af-cedc-483b-92ba-8bb1d195bccf", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"ename": "NameError", | ||
"evalue": "name 'infer' is not defined", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", | ||
"Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m posterior \u001b[38;5;241m=\u001b[39m \u001b[43minfer\u001b[49m(simulator, prior, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSNPE\u001b[39m\u001b[38;5;124m\"\u001b[39m, num_simulations\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10000\u001b[39m)\n", | ||
"\u001b[0;31mNameError\u001b[0m: name 'infer' is not defined" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"posterior = infer(simulator, prior, \"SNPE\", num_simulations=10000)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "07042d00-c9b1-494b-8870-d93ad18e2e11", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"inference_model = evaluate.InferenceModel()\n", | ||
"path = \"saved_models/sbi/\"\n", | ||
"model_name = \"sbi_linear\"\n", | ||
"inference_model.save_model_pkl(self, path, model_name, posterior)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |