Skip to content

Commit

Permalink
adding a notebook to save an SBI model
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed Jan 26, 2024
1 parent a00c620 commit 64b8cdf
Showing 1 changed file with 153 additions and 0 deletions.
153 changes: 153 additions & 0 deletions notebooks/train_SBI.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d467063e-00c2-48e1-a214-434767a4bc37",
"metadata": {},
"source": [
"# Quick train SBI\n",
"Then save the posterior as a pkl using the evaluate module."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "486dda47-bf7b-45ea-88fe-55960d81c4bb",
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'sbi'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01msbi\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m# from sbi import inference\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msbi\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01minference\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m SNPE\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'sbi'"
]
}
],
"source": [
"import sbi\n",
"from sbi.inference import SNPE\n",
"from sbi.inference.base import infer\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f64b72b1-3c46-45af-932e-59512b2adbc8",
"metadata": {},
"outputs": [],
"source": [
"# this is necessary to import modules from this repo\n",
"import sys\n",
"sys.path.append('..')\n",
"from src.scripts import evaluate"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cd5034fb-94da-4b3d-b5ca-89f16ed98ff9",
"metadata": {},
"outputs": [],
"source": [
"def simulator(thetas):#, percent_errors):\n",
" # just plop the pendulum within here\n",
" m, b = thetas\n",
" x = np.linspace(0, 100, 101)\n",
" rs = np.random.RandomState()#2147483648)# \n",
" # I'm thinking sigma could actually be a function of x\n",
" # if we want to get fancy down the road\n",
" sigma = 10\n",
" ε = rs.normal(loc=0, scale=sigma, size = len(x)) \n",
" return m * x + b + ε"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9fe446e0-e80e-4c6a-a67e-8a8bd19d2787",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'torch' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[4], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m num_dim \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[0;32m----> 3\u001b[0m low_bounds \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m10\u001b[39m])\n\u001b[1;32m 4\u001b[0m high_bounds \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;241m10\u001b[39m, \u001b[38;5;241m10\u001b[39m])\n\u001b[1;32m 6\u001b[0m prior \u001b[38;5;241m=\u001b[39m sbi\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mBoxUniform(low \u001b[38;5;241m=\u001b[39m low_bounds, high \u001b[38;5;241m=\u001b[39m high_bounds)\n",
"\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined"
]
}
],
"source": [
"num_dim = 2\n",
"\n",
"low_bounds = torch.tensor([0, -10])\n",
"high_bounds = torch.tensor([10, 10])\n",
"\n",
"prior = sbi.utils.BoxUniform(low = low_bounds, high = high_bounds)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b4d1c9af-cedc-483b-92ba-8bb1d195bccf",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'infer' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m posterior \u001b[38;5;241m=\u001b[39m \u001b[43minfer\u001b[49m(simulator, prior, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSNPE\u001b[39m\u001b[38;5;124m\"\u001b[39m, num_simulations\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10000\u001b[39m)\n",
"\u001b[0;31mNameError\u001b[0m: name 'infer' is not defined"
]
}
],
"source": [
"posterior = infer(simulator, prior, \"SNPE\", num_simulations=10000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "07042d00-c9b1-494b-8870-d93ad18e2e11",
"metadata": {},
"outputs": [],
"source": [
"inference_model = evaluate.InferenceModel()\n",
"path = \"saved_models/sbi/\"\n",
"model_name = \"sbi_linear\"\n",
"inference_model.save_model_pkl(self, path, model_name, posterior)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 64b8cdf

Please sign in to comment.