diff --git a/notebooks/21_dc-resistivity-inversion-w-beta-cooling.ipynb b/notebooks/21_dc-resistivity-inversion-w-beta-cooling.ipynb index 1e0bc8b..1396738 100644 --- a/notebooks/21_dc-resistivity-inversion-w-beta-cooling.ipynb +++ b/notebooks/21_dc-resistivity-inversion-w-beta-cooling.ipynb @@ -14,10 +14,11 @@ "id": "0888671c-f9b1-4b4e-8eb2-f8aca720920e", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:17.552484Z", - "iopub.status.busy": "2025-10-14T21:26:17.551940Z", - "iopub.status.idle": "2025-10-14T21:26:19.456066Z", - "shell.execute_reply": "2025-10-14T21:26:19.455476Z" + "iopub.execute_input": "2025-10-20T18:45:12.294642Z", + "iopub.status.busy": "2025-10-20T18:45:12.294242Z", + "iopub.status.idle": "2025-10-20T18:45:13.853287Z", + "shell.execute_reply": "2025-10-20T18:45:13.852711Z", + "shell.execute_reply.started": "2025-10-20T18:45:12.294613Z" } }, "outputs": [], @@ -49,10 +50,11 @@ "id": "0b8dce76-9a8e-42f2-8ddf-f3dcae0674a5", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:19.458206Z", - "iopub.status.busy": "2025-10-14T21:26:19.457897Z", - "iopub.status.idle": "2025-10-14T21:26:19.465731Z", - "shell.execute_reply": "2025-10-14T21:26:19.465157Z" + "iopub.execute_input": "2025-10-20T18:45:13.854039Z", + "iopub.status.busy": "2025-10-20T18:45:13.853737Z", + "iopub.status.idle": "2025-10-20T18:45:13.863482Z", + "shell.execute_reply": "2025-10-20T18:45:13.862795Z", + "shell.execute_reply.started": "2025-10-20T18:45:13.854020Z" } }, "outputs": [ @@ -82,10 +84,11 @@ "id": "8db572c3-2960-42ca-9d52-0547431c5a80", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:19.467442Z", - "iopub.status.busy": "2025-10-14T21:26:19.467251Z", - "iopub.status.idle": "2025-10-14T21:26:19.471743Z", - "shell.execute_reply": "2025-10-14T21:26:19.471300Z" + "iopub.execute_input": "2025-10-20T18:45:13.864747Z", + "iopub.status.busy": "2025-10-20T18:45:13.864403Z", + "iopub.status.idle": "2025-10-20T18:45:13.870833Z", + "shell.execute_reply": "2025-10-20T18:45:13.870067Z", + "shell.execute_reply.started": "2025-10-20T18:45:13.864714Z" } }, "outputs": [ @@ -113,10 +116,11 @@ "id": "14621398-5fe7-4fe2-843d-7d067e026e18", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:19.473522Z", - "iopub.status.busy": "2025-10-14T21:26:19.473265Z", - "iopub.status.idle": "2025-10-14T21:26:19.476447Z", - "shell.execute_reply": "2025-10-14T21:26:19.475979Z" + "iopub.execute_input": "2025-10-20T18:45:13.872139Z", + "iopub.status.busy": "2025-10-20T18:45:13.871750Z", + "iopub.status.idle": "2025-10-20T18:45:13.877033Z", + "shell.execute_reply": "2025-10-20T18:45:13.876471Z", + "shell.execute_reply.started": "2025-10-20T18:45:13.872103Z" } }, "outputs": [], @@ -131,10 +135,11 @@ "id": "d994d154-91bb-4fb2-82ce-e3dc7e9f66cf", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:19.478370Z", - "iopub.status.busy": "2025-10-14T21:26:19.478102Z", - "iopub.status.idle": "2025-10-14T21:26:19.488735Z", - "shell.execute_reply": "2025-10-14T21:26:19.488175Z" + "iopub.execute_input": "2025-10-20T18:45:13.878014Z", + "iopub.status.busy": "2025-10-20T18:45:13.877781Z", + "iopub.status.idle": "2025-10-20T18:45:13.893415Z", + "shell.execute_reply": "2025-10-20T18:45:13.892812Z", + "shell.execute_reply.started": "2025-10-20T18:45:13.877993Z" } }, "outputs": [], @@ -148,10 +153,11 @@ "id": "0dc7c54b-8df0-42ac-8ab9-01302ffb7b59", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:19.490570Z", - "iopub.status.busy": "2025-10-14T21:26:19.490309Z", - "iopub.status.idle": "2025-10-14T21:26:19.493370Z", - "shell.execute_reply": "2025-10-14T21:26:19.492675Z" + "iopub.execute_input": "2025-10-20T18:45:13.895342Z", + "iopub.status.busy": "2025-10-20T18:45:13.895122Z", + "iopub.status.idle": "2025-10-20T18:45:13.898898Z", + "shell.execute_reply": "2025-10-20T18:45:13.898218Z", + "shell.execute_reply.started": "2025-10-20T18:45:13.895323Z" } }, "outputs": [], @@ -166,10 +172,11 @@ "id": "79ca22c8-4258-400f-8bf6-01b4f0c63641", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:19.495762Z", - "iopub.status.busy": "2025-10-14T21:26:19.495219Z", - "iopub.status.idle": "2025-10-14T21:26:19.499836Z", - "shell.execute_reply": "2025-10-14T21:26:19.499211Z" + "iopub.execute_input": "2025-10-20T18:45:13.899792Z", + "iopub.status.busy": "2025-10-20T18:45:13.899594Z", + "iopub.status.idle": "2025-10-20T18:45:13.908810Z", + "shell.execute_reply": "2025-10-20T18:45:13.908125Z", + "shell.execute_reply.started": "2025-10-20T18:45:13.899774Z" } }, "outputs": [], @@ -188,10 +195,11 @@ "id": "d3abdcae-d68d-4c56-8633-e25dc6804dd2", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:19.502113Z", - "iopub.status.busy": "2025-10-14T21:26:19.501766Z", - "iopub.status.idle": "2025-10-14T21:26:20.326331Z", - "shell.execute_reply": "2025-10-14T21:26:20.325760Z" + "iopub.execute_input": "2025-10-20T18:45:13.909808Z", + "iopub.status.busy": "2025-10-20T18:45:13.909542Z", + "iopub.status.idle": "2025-10-20T18:45:14.725648Z", + "shell.execute_reply": "2025-10-20T18:45:14.725143Z", + "shell.execute_reply.started": "2025-10-20T18:45:13.909782Z" } }, "outputs": [ @@ -272,10 +280,11 @@ "id": "78fb6825-7a38-43c8-abf8-96177abe07d5", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:20.328134Z", - "iopub.status.busy": "2025-10-14T21:26:20.327886Z", - "iopub.status.idle": "2025-10-14T21:26:20.367713Z", - "shell.execute_reply": "2025-10-14T21:26:20.367039Z" + "iopub.execute_input": "2025-10-20T18:45:14.726329Z", + "iopub.status.busy": "2025-10-20T18:45:14.726144Z", + "iopub.status.idle": "2025-10-20T18:45:14.774665Z", + "shell.execute_reply": "2025-10-20T18:45:14.773909Z", + "shell.execute_reply.started": "2025-10-20T18:45:14.726312Z" } }, "outputs": [], @@ -318,10 +327,11 @@ "id": "ad513789-95df-41c1-a1c0-9b686fab6341", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:20.369779Z", - "iopub.status.busy": "2025-10-14T21:26:20.369574Z", - "iopub.status.idle": "2025-10-14T21:26:20.374230Z", - "shell.execute_reply": "2025-10-14T21:26:20.373690Z" + "iopub.execute_input": "2025-10-20T18:45:14.775658Z", + "iopub.status.busy": "2025-10-20T18:45:14.775409Z", + "iopub.status.idle": "2025-10-20T18:45:14.782176Z", + "shell.execute_reply": "2025-10-20T18:45:14.781258Z", + "shell.execute_reply.started": "2025-10-20T18:45:14.775635Z" } }, "outputs": [], @@ -339,10 +349,11 @@ "id": "f9bceaab-ece7-4d57-8b58-734586d56c25", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:20.376321Z", - "iopub.status.busy": "2025-10-14T21:26:20.376047Z", - "iopub.status.idle": "2025-10-14T21:26:20.383519Z", - "shell.execute_reply": "2025-10-14T21:26:20.382877Z" + "iopub.execute_input": "2025-10-20T18:45:14.783471Z", + "iopub.status.busy": "2025-10-20T18:45:14.783109Z", + "iopub.status.idle": "2025-10-20T18:45:14.794382Z", + "shell.execute_reply": "2025-10-20T18:45:14.793523Z", + "shell.execute_reply.started": "2025-10-20T18:45:14.783435Z" } }, "outputs": [], @@ -365,10 +376,11 @@ "id": "36c320ef-116c-4794-ab10-a64d1f15ccf6", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:20.385500Z", - "iopub.status.busy": "2025-10-14T21:26:20.385253Z", - "iopub.status.idle": "2025-10-14T21:26:20.390199Z", - "shell.execute_reply": "2025-10-14T21:26:20.389565Z" + "iopub.execute_input": "2025-10-20T18:45:14.795686Z", + "iopub.status.busy": "2025-10-20T18:45:14.795350Z", + "iopub.status.idle": "2025-10-20T18:45:14.802981Z", + "shell.execute_reply": "2025-10-20T18:45:14.801897Z", + "shell.execute_reply.started": "2025-10-20T18:45:14.795651Z" } }, "outputs": [], @@ -387,10 +399,11 @@ "id": "b381c61e-b219-47a2-9a7f-3f9b9b487179", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:20.392158Z", - "iopub.status.busy": "2025-10-14T21:26:20.391927Z", - "iopub.status.idle": "2025-10-14T21:26:20.395242Z", - "shell.execute_reply": "2025-10-14T21:26:20.394732Z" + "iopub.execute_input": "2025-10-20T18:45:14.804464Z", + "iopub.status.busy": "2025-10-20T18:45:14.804063Z", + "iopub.status.idle": "2025-10-20T18:45:14.811023Z", + "shell.execute_reply": "2025-10-20T18:45:14.810057Z", + "shell.execute_reply.started": "2025-10-20T18:45:14.804426Z" } }, "outputs": [], @@ -412,10 +425,11 @@ "id": "777cf480-633c-4a1f-88c3-ae027697aba7", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:20.397186Z", - "iopub.status.busy": "2025-10-14T21:26:20.396955Z", - "iopub.status.idle": "2025-10-14T21:26:20.808136Z", - "shell.execute_reply": "2025-10-14T21:26:20.807395Z" + "iopub.execute_input": "2025-10-20T18:45:14.812345Z", + "iopub.status.busy": "2025-10-20T18:45:14.811985Z", + "iopub.status.idle": "2025-10-20T18:45:15.195302Z", + "shell.execute_reply": "2025-10-20T18:45:15.194761Z", + "shell.execute_reply.started": "2025-10-20T18:45:14.812306Z" } }, "outputs": [], @@ -452,17 +466,18 @@ "id": "445f4b23-8877-49e6-8b4a-33049e731855", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:20.810962Z", - "iopub.status.busy": "2025-10-14T21:26:20.810462Z", - "iopub.status.idle": "2025-10-14T21:26:20.814803Z", - "shell.execute_reply": "2025-10-14T21:26:20.814324Z" + "iopub.execute_input": "2025-10-20T18:45:15.196091Z", + "iopub.status.busy": "2025-10-20T18:45:15.195899Z", + "iopub.status.idle": "2025-10-20T18:45:15.200182Z", + "shell.execute_reply": "2025-10-20T18:45:15.199588Z", + "shell.execute_reply.started": "2025-10-20T18:45:15.196074Z" } }, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 15, @@ -489,10 +504,11 @@ "id": "9320114b-435f-48f8-8ec1-d830678e2a3c", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:20.816699Z", - "iopub.status.busy": "2025-10-14T21:26:20.816450Z", - "iopub.status.idle": "2025-10-14T21:26:20.819442Z", - "shell.execute_reply": "2025-10-14T21:26:20.818876Z" + "iopub.execute_input": "2025-10-20T18:45:15.200939Z", + "iopub.status.busy": "2025-10-20T18:45:15.200743Z", + "iopub.status.idle": "2025-10-20T18:45:15.214842Z", + "shell.execute_reply": "2025-10-20T18:45:15.214219Z", + "shell.execute_reply.started": "2025-10-20T18:45:15.200923Z" } }, "outputs": [], @@ -517,10 +533,11 @@ "id": "9042f903-4d02-4abe-9de1-2dfacfb3b3a0", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:20.821210Z", - "iopub.status.busy": "2025-10-14T21:26:20.821014Z", - "iopub.status.idle": "2025-10-14T21:26:27.791405Z", - "shell.execute_reply": "2025-10-14T21:26:27.790627Z" + "iopub.execute_input": "2025-10-20T18:45:15.215813Z", + "iopub.status.busy": "2025-10-20T18:45:15.215583Z", + "iopub.status.idle": "2025-10-20T18:45:20.221145Z", + "shell.execute_reply": "2025-10-20T18:45:20.220472Z", + "shell.execute_reply.started": "2025-10-20T18:45:15.215793Z" } }, "outputs": [ @@ -531,13 +548,7 @@ "/home/santi/.miniforge3/envs/inversion_ideas/lib/python3.13/site-packages/simpeg/electromagnetics/static/resistivity/simulation_2d.py:768: RuntimeWarning: invalid value encountered in divide\n", " r_hat = r_vec / r[:, None]\n", "/home/santi/.miniforge3/envs/inversion_ideas/lib/python3.13/site-packages/simpeg/electromagnetics/static/resistivity/simulation_2d.py:795: RuntimeWarning: invalid value encountered in divide\n", - " alpha[not_top] = (ky * k1e(ky * r) / k0e(ky * r) * r_dot_n)[not_top]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ + " alpha[not_top] = (ky * k1e(ky * r) / k0e(ky * r) * r_dot_n)[not_top]\n", "/home/santi/.miniforge3/envs/inversion_ideas/lib/python3.13/site-packages/pymatsolver/solvers.py:415: FutureWarning: In Future pymatsolver v0.4.0, passing a vector of shape (n, 1) to the solve method will return an array with shape (n, 1), instead of always returning a flattened array. This is to be consistent with numpy.linalg.solve broadcasting.\n", " return self.solve(val)\n" ] @@ -607,10 +618,11 @@ "id": "e81a3051-9874-40f6-95ca-d1f8559ec1bb", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:27.794617Z", - "iopub.status.busy": "2025-10-14T21:26:27.794273Z", - "iopub.status.idle": "2025-10-14T21:26:27.801569Z", - "shell.execute_reply": "2025-10-14T21:26:27.800716Z" + "iopub.execute_input": "2025-10-20T18:45:20.222058Z", + "iopub.status.busy": "2025-10-20T18:45:20.221827Z", + "iopub.status.idle": "2025-10-20T18:45:20.227899Z", + "shell.execute_reply": "2025-10-20T18:45:20.227284Z", + "shell.execute_reply.started": "2025-10-20T18:45:20.222038Z" } }, "outputs": [], @@ -667,17 +679,18 @@ "id": "bb78df59-73b9-4c45-aeca-feb9a9cdcfe1", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:26:27.804792Z", - "iopub.status.busy": "2025-10-14T21:26:27.804117Z", - "iopub.status.idle": "2025-10-14T21:28:03.030533Z", - "shell.execute_reply": "2025-10-14T21:28:03.029976Z" + "iopub.execute_input": "2025-10-20T18:45:20.228936Z", + "iopub.status.busy": "2025-10-20T18:45:20.228691Z", + "iopub.status.idle": "2025-10-20T18:47:17.008221Z", + "shell.execute_reply": "2025-10-20T18:47:17.007362Z", + "shell.execute_reply.started": "2025-10-20T18:45:20.228912Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "dc5ebc07d17443aeacf559e2064d7d3f", + "model_id": "d56819732741432ea1a8482223138ef2", "version_major": 2, "version_minor": 0 }, @@ -688,27 +701,6 @@ "metadata": {}, "output_type": "display_data" }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO: ⚠️ Reached maximum number of Gauss-Newton iterations (2).\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO: ⚠️ Reached maximum number of Gauss-Newton iterations (2).\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO: ⚠️ Reached maximum number of Gauss-Newton iterations (2).\n" - ] - }, { "name": "stderr", "output_type": "stream", @@ -734,13 +726,164 @@ { "cell_type": "code", "execution_count": 20, + "id": "74fe6998-1343-4a93-a293-648ae8077567", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-20T18:47:17.009526Z", + "iopub.status.busy": "2025-10-20T18:47:17.009191Z", + "iopub.status.idle": "2025-10-20T18:47:17.036776Z", + "shell.execute_reply": "2025-10-20T18:47:17.035952Z", + "shell.execute_reply.started": "2025-10-20T18:47:17.009493Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━┓\n",
+       "┃ iteration  model                        objective_value  conj_grad_iters  line_search_iters  step_norm    ┃\n",
+       "┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━┩\n",
+       "│ 0         │ [-4.58151669 -4.58151669    │ 4.45e+04        │ 0               │ 0                 │ 0         │   │\n",
+       "│           │ -4.58151669 ... -4.58151669 │                 │                 │                   │           │   │\n",
+       "│           │ -4.58151669                 │                 │                 │                   │           │   │\n",
+       "│           │  -4.58151669]               │                 │                 │                   │           │   │\n",
+       "│ 1         │ [-4.5815244  -4.58152295    │ 4.45e+04        │ 0               │ 1                 │ 1.42e+01  │ 0 │\n",
+       "│           │ -4.58153199 ... -4.5815263  │                 │                 │                   │           │   │\n",
+       "│           │ -4.58154021                 │                 │                 │                   │           │   │\n",
+       "│           │  -4.58185369]               │                 │                 │                   │           │   │\n",
+       "│ 2         │ [-4.58152553 -4.58152388    │ 1.27e+04        │ 0               │ 1                 │ 8.19e+00  │ 0 │\n",
+       "│           │ -4.58153433 ... -4.58152755 │                 │                 │                   │           │   │\n",
+       "│           │ -4.58154342                 │                 │                 │                   │           │   │\n",
+       "│           │  -4.58189799]               │                 │                 │                   │           │   │\n",
+       "└───────────┴─────────────────────────────┴─────────────────┴─────────────────┴───────────────────┴───────────┴───┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1miteration\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmodel \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mobjective_value\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mconj_grad_iters\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mline_search_iters\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mstep_norm\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━┩\n", + "│ 0 │ [-4.58151669 -4.58151669 │ 4.45e+04 │ 0 │ 0 │ 0 │ │\n", + "│ │ -4.58151669 ... -4.58151669 │ │ │ │ │ │\n", + "│ │ -4.58151669 │ │ │ │ │ │\n", + "│ │ -4.58151669] │ │ │ │ │ │\n", + "│ 1 │ [-4.5815244 -4.58152295 │ 4.45e+04 │ 0 │ 1 │ 1.42e+01 │ 0 │\n", + "│ │ -4.58153199 ... -4.5815263 │ │ │ │ │ │\n", + "│ │ -4.58154021 │ │ │ │ │ │\n", + "│ │ -4.58185369] │ │ │ │ │ │\n", + "│ 2 │ [-4.58152553 -4.58152388 │ 1.27e+04 │ 0 │ 1 │ 8.19e+00 │ 0 │\n", + "│ │ -4.58153433 ... -4.58152755 │ │ │ │ │ │\n", + "│ │ -4.58154342 │ │ │ │ │ │\n", + "│ │ -4.58189799] │ │ │ │ │ │\n", + "└───────────┴─────────────────────────────┴─────────────────┴─────────────────┴───────────────────┴───────────┴───┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━┓\n",
+       "┃ iteration  model                        objective_value  conj_grad_iters  line_search_iters  step_norm    ┃\n",
+       "┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━┩\n",
+       "│ 0         │ [-4.58152553 -4.58152388    │ 3.86e+03        │ 0               │ 0                 │ 0         │   │\n",
+       "│           │ -4.58153433 ... -4.58152755 │                 │                 │                   │           │   │\n",
+       "│           │ -4.58154342                 │                 │                 │                   │           │   │\n",
+       "│           │  -4.58189799]               │                 │                 │                   │           │   │\n",
+       "│ 1         │ [-4.58152557 -4.5815239     │ 3.86e+03        │ 0               │ 1                 │ 5.70e+00  │ 0 │\n",
+       "│           │ -4.58153457 ... -4.5815272  │                 │                 │                   │           │   │\n",
+       "│           │ -4.58154266                 │                 │                 │                   │           │   │\n",
+       "│           │  -4.58188797]               │                 │                 │                   │           │   │\n",
+       "│ 2         │ [-4.58152614 -4.58152436    │ 1.32e+03        │ 0               │ 1                 │ 3.53e+00  │ 0 │\n",
+       "│           │ -4.5815355  ... -4.58152848 │                 │                 │                   │           │   │\n",
+       "│           │ -4.58154566                 │                 │                 │                   │           │   │\n",
+       "│           │  -4.58192929]               │                 │                 │                   │           │   │\n",
+       "└───────────┴─────────────────────────────┴─────────────────┴─────────────────┴───────────────────┴───────────┴───┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1miteration\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmodel \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mobjective_value\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mconj_grad_iters\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mline_search_iters\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mstep_norm\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━┩\n", + "│ 0 │ [-4.58152553 -4.58152388 │ 3.86e+03 │ 0 │ 0 │ 0 │ │\n", + "│ │ -4.58153433 ... -4.58152755 │ │ │ │ │ │\n", + "│ │ -4.58154342 │ │ │ │ │ │\n", + "│ │ -4.58189799] │ │ │ │ │ │\n", + "│ 1 │ [-4.58152557 -4.5815239 │ 3.86e+03 │ 0 │ 1 │ 5.70e+00 │ 0 │\n", + "│ │ -4.58153457 ... -4.5815272 │ │ │ │ │ │\n", + "│ │ -4.58154266 │ │ │ │ │ │\n", + "│ │ -4.58188797] │ │ │ │ │ │\n", + "│ 2 │ [-4.58152614 -4.58152436 │ 1.32e+03 │ 0 │ 1 │ 3.53e+00 │ 0 │\n", + "│ │ -4.5815355 ... -4.58152848 │ │ │ │ │ │\n", + "│ │ -4.58154566 │ │ │ │ │ │\n", + "│ │ -4.58192929] │ │ │ │ │ │\n", + "└───────────┴─────────────────────────────┴─────────────────┴─────────────────┴───────────────────┴───────────┴───┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━┓\n",
+       "┃ iteration  model                        objective_value  conj_grad_iters  line_search_iters  step_norm    ┃\n",
+       "┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━┩\n",
+       "│ 0         │ [-4.58152614 -4.58152436    │ 4.99e+02        │ 0               │ 0                 │ 0         │   │\n",
+       "│           │ -4.5815355  ... -4.58152848 │                 │                 │                   │           │   │\n",
+       "│           │ -4.58154566                 │                 │                 │                   │           │   │\n",
+       "│           │  -4.58192929]               │                 │                 │                   │           │   │\n",
+       "│ 1         │ [-4.58152756 -4.58152543    │ 4.99e+02        │ 0               │ 1                 │ 2.89e+00  │ 0 │\n",
+       "│           │ -4.58153819 ... -4.58153061 │                 │                 │                   │           │   │\n",
+       "│           │ -4.58155102                 │                 │                 │                   │           │   │\n",
+       "│           │  -4.58200435]               │                 │                 │                   │           │   │\n",
+       "│ 2         │ [-4.58152851 -4.58152616    │ 2.90e+02        │ 0               │ 1                 │ 1.54e+00  │ 0 │\n",
+       "│           │ -4.58153993 ... -4.58153216 │                 │                 │                   │           │   │\n",
+       "│           │ -4.58155483                 │                 │                 │                   │           │   │\n",
+       "│           │  -4.58205792]               │                 │                 │                   │           │   │\n",
+       "└───────────┴─────────────────────────────┴─────────────────┴─────────────────┴───────────────────┴───────────┴───┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1miteration\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmodel \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mobjective_value\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mconj_grad_iters\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mline_search_iters\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mstep_norm\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━┩\n", + "│ 0 │ [-4.58152614 -4.58152436 │ 4.99e+02 │ 0 │ 0 │ 0 │ │\n", + "│ │ -4.5815355 ... -4.58152848 │ │ │ │ │ │\n", + "│ │ -4.58154566 │ │ │ │ │ │\n", + "│ │ -4.58192929] │ │ │ │ │ │\n", + "│ 1 │ [-4.58152756 -4.58152543 │ 4.99e+02 │ 0 │ 1 │ 2.89e+00 │ 0 │\n", + "│ │ -4.58153819 ... -4.58153061 │ │ │ │ │ │\n", + "│ │ -4.58155102 │ │ │ │ │ │\n", + "│ │ -4.58200435] │ │ │ │ │ │\n", + "│ 2 │ [-4.58152851 -4.58152616 │ 2.90e+02 │ 0 │ 1 │ 1.54e+00 │ 0 │\n", + "│ │ -4.58153993 ... -4.58153216 │ │ │ │ │ │\n", + "│ │ -4.58155483 │ │ │ │ │ │\n", + "│ │ -4.58205792] │ │ │ │ │ │\n", + "└───────────┴─────────────────────────────┴─────────────────┴─────────────────┴───────────────────┴───────────┴───┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for minimizer_log in inversion.log.minimizer_logs:\n", + " if minimizer_log is not None:\n", + " minimizer_log.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, "id": "6af9c936-5cf8-43bd-bfcc-e857715de2bd", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:28:03.644079Z", - "iopub.status.busy": "2025-10-14T21:28:03.643889Z", - "iopub.status.idle": "2025-10-14T21:28:03.646642Z", - "shell.execute_reply": "2025-10-14T21:28:03.646217Z" + "iopub.execute_input": "2025-10-20T18:47:17.041026Z", + "iopub.status.busy": "2025-10-20T18:47:17.040686Z", + "iopub.status.idle": "2025-10-20T18:47:17.045062Z", + "shell.execute_reply": "2025-10-20T18:47:17.044220Z", + "shell.execute_reply.started": "2025-10-20T18:47:17.040991Z" } }, "outputs": [], @@ -750,14 +893,15 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 22, "id": "e7c77d48-d7ce-42fc-8792-106f55a2accf", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:28:03.648768Z", - "iopub.status.busy": "2025-10-14T21:28:03.648367Z", - "iopub.status.idle": "2025-10-14T21:28:04.024786Z", - "shell.execute_reply": "2025-10-14T21:28:04.024265Z" + "iopub.execute_input": "2025-10-20T18:47:17.046442Z", + "iopub.status.busy": "2025-10-20T18:47:17.046060Z", + "iopub.status.idle": "2025-10-20T18:47:17.495182Z", + "shell.execute_reply": "2025-10-20T18:47:17.494658Z", + "shell.execute_reply.started": "2025-10-20T18:47:17.046406Z" } }, "outputs": [ @@ -804,14 +948,15 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 23, "id": "024667e4-4f2c-4dff-866d-2edda36810a6", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:28:04.026395Z", - "iopub.status.busy": "2025-10-14T21:28:04.026227Z", - "iopub.status.idle": "2025-10-14T21:28:05.759093Z", - "shell.execute_reply": "2025-10-14T21:28:05.758511Z" + "iopub.execute_input": "2025-10-20T18:47:17.495934Z", + "iopub.status.busy": "2025-10-20T18:47:17.495728Z", + "iopub.status.idle": "2025-10-20T18:47:19.324077Z", + "shell.execute_reply": "2025-10-20T18:47:19.323492Z", + "shell.execute_reply.started": "2025-10-20T18:47:17.495917Z" } }, "outputs": [ @@ -878,14 +1023,15 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 24, "id": "d5ce0fdc-a42c-41bb-86d9-859eb690f35f", "metadata": { "execution": { - "iopub.execute_input": "2025-10-14T21:28:05.760784Z", - "iopub.status.busy": "2025-10-14T21:28:05.760589Z", - "iopub.status.idle": "2025-10-14T21:28:05.763550Z", - "shell.execute_reply": "2025-10-14T21:28:05.763057Z" + "iopub.execute_input": "2025-10-20T18:47:19.324929Z", + "iopub.status.busy": "2025-10-20T18:47:19.324675Z", + "iopub.status.idle": "2025-10-20T18:47:19.328043Z", + "shell.execute_reply": "2025-10-20T18:47:19.327426Z", + "shell.execute_reply.started": "2025-10-20T18:47:19.324910Z" } }, "outputs": [], diff --git a/src/inversion_ideas/base/__init__.py b/src/inversion_ideas/base/__init__.py index 3d2ae0c..8fc6aad 100644 --- a/src/inversion_ideas/base/__init__.py +++ b/src/inversion_ideas/base/__init__.py @@ -4,7 +4,7 @@ from .conditions import Condition from .directive import Directive -from .minimizer import Minimizer +from .minimizer import Minimizer, MinimizerResult from .objective_function import Combo, Objective, Scaled from .simulation import Simulation @@ -13,6 +13,7 @@ "Condition", "Directive", "Minimizer", + "MinimizerResult", "Objective", "Scaled", "Simulation", diff --git a/src/inversion_ideas/base/minimizer.py b/src/inversion_ideas/base/minimizer.py index 887d64d..756762c 100644 --- a/src/inversion_ideas/base/minimizer.py +++ b/src/inversion_ideas/base/minimizer.py @@ -3,19 +3,50 @@ """ from abc import ABC, abstractmethod -from collections.abc import Generator +from collections.abc import Callable, Generator from ..typing import Model from .objective_function import Objective +class MinimizerResult(dict): + """ + Dictionary to store results of a single minimization iteration. + + This class is a child of ``dict``, but allows to access the values through + attributes. + + Notes + ----- + Inspired in the :class:`scipy.optimize.OptimizeResult`. + """ + + def __getattr__(self, name): + try: + return self[name] + except KeyError as e: + raise AttributeError(name) from e + + __setattr__ = dict.__setitem__ # type: ignore[assignment] + __delattr__ = dict.__delitem__ # type: ignore[assignment] + + def __dir__(self): + return list(self.keys()) + + class Minimizer(ABC): """ Base class to represent minimizers as generators. """ @abstractmethod - def __call__(self, objective: Objective, initial_model: Model) -> Generator[Model]: + def __call__( + self, + objective: Objective, + initial_model: Model, + *, + callback: Callable[[MinimizerResult], None] | None = None, + ) -> Generator[Model]: """ Minimize objective function. @@ -25,6 +56,9 @@ def __call__(self, objective: Objective, initial_model: Model) -> Generator[Mode Objective function to be minimized. initial_model : (n_params) array Initial model used to start the minimization. + callback : callable, optional + Callable that gets called after each iteration. + Takes a :class:`inversion_ideas.base.MinimizerResult` as argument. Returns ------- diff --git a/src/inversion_ideas/inversion.py b/src/inversion_ideas/inversion.py index 5a3a8a8..27f0177 100644 --- a/src/inversion_ideas/inversion.py +++ b/src/inversion_ideas/inversion.py @@ -9,9 +9,14 @@ import typing from collections.abc import Callable +from rich.console import Group, RenderableType +from rich.live import Live +from rich.spinner import Spinner +from rich.tree import Tree + from .base import Condition, Directive, Minimizer, Objective -from .inversion_log import InversionLog, InversionLogRich -from .typing import Model +from .inversion_log import InversionLog, InversionLogRich, MinimizerLog +from .typing import Log, Model from .utils import get_logger @@ -40,11 +45,15 @@ class Inversion: no limit on the total amount of iterations. cache_models : bool, optional Whether to cache each model after each iteration. - log : InversionLog or bool, optional - Instance of :class:`InversionLog` to store information about the inversion. + log : Log or bool, optional + Instance of :class:`InversionLog` to store information about the inversion, + or any object that follows the :class:`inversion_ideas.typing.Log` protocol. If `True`, a default :class:`InversionLog` is going to be used. If `False`, no log will be assigned to the inversion, and :attr:`Inversion.log` will be ``None``. + log_minimizers : bool, optional + Whether to log the minimizers or not. Logging minimizers is only possible when + the ``minimizer`` is an instance of :class:`inversion_ideas.base.Minimizer``. minimizer_kwargs : dict, optional Extra arguments that will be passed to the ``minimizer`` when called. """ @@ -59,7 +68,8 @@ def __init__( stopping_criteria: Condition | Callable[[Model], bool], max_iterations: int | None = None, cache_models=False, - log: "InversionLog | bool" = True, + log: Log | InversionLog | bool = True, + log_minimizers: bool = True, minimizer_kwargs: dict | None = None, ): self.objective_function = objective_function @@ -72,6 +82,7 @@ def __init__( if minimizer_kwargs is None: minimizer_kwargs = {} self.minimizer_kwargs = minimizer_kwargs + self._log_minimizers = log_minimizers # Assign log if log is False: @@ -86,6 +97,11 @@ def __init__( # Assign model as a copy of the initial model self.model = initial_model.copy() + # TODO: Support for handling custom callbacks for the minimizer + if log is not None and "callback" in self.minimizer_kwargs: + msg = "Passing a custom callback for the minimizer is not yet supported." + raise NotImplementedError(msg) + def __next__(self): """ Run next iteration in the inversion. @@ -137,10 +153,20 @@ def __next__(self): directive(self.model, self.counter) # Minimize objective function + # --------------------------- if isinstance(self.minimizer, Minimizer): - # Keep only the last model of the minimizer iterator + # Generate a new minimizer log for this iteration + minimizer_kwargs = self.minimizer_kwargs.copy() + if self.log is not None and self.log_minimizers: + minimizer_log = MinimizerLog() + self.minimizer_logs.append(minimizer_log) + minimizer_kwargs["callback"] = minimizer_log.update + + # Unpack the generator and keep only the last model *_, model = self.minimizer( - self.objective_function, self.model, **self.minimizer_kwargs + self.objective_function, + self.model, + **minimizer_kwargs, ) else: model = self.minimizer( @@ -185,6 +211,22 @@ def models(self) -> list: self._models = [self.initial_model] return self._models + @property + def log_minimizers(self) -> bool: + """Whether if minimizers will be logged or not.""" + return self._log_minimizers and isinstance(self.minimizer, Minimizer) + + @property + def minimizer_logs(self) -> list[None | MinimizerLog] | None: + """ + Logs of minimizers. + """ + if not self.log_minimizers: + return None + if not hasattr(self, "_minimizer_logs"): + self._minimizer_logs = [None] + return self._minimizer_logs + def run(self, show_log=True) -> Model: """ Run the inversion. @@ -195,11 +237,29 @@ def run(self, show_log=True) -> Model: Whether to show the ``log`` (if it's defined) during the inversion. """ if show_log and self.log is not None: - if not hasattr(self.log, "live"): + if not isinstance(self.log, RenderableType): raise NotImplementedError() - with self.log.live() as live: + + spinner = Spinner( + name="dots", text="Starting inversion...", style="green", speed=1 + ) + log = Tree(self.log) if self.log_minimizers else self.log + group = Group(log, spinner) + + with Live(group, refresh_per_second=10) as live: for _ in self: - live.refresh() + if self.log_minimizers: + minimizer_log = self.minimizer_logs[self.counter] + if minimizer_log is not None: + renderable = minimizer_log.__rich__() + renderable.title = ( + f"Minimizer log for iteration {self.counter}" + ) + log.add(renderable) + spinner.text = f"Running iteration {self.counter + 1}..." + group.renderables.pop(-1) + live.refresh() + else: for _ in self: pass diff --git a/src/inversion_ideas/inversion_log.py b/src/inversion_ideas/inversion_log.py index b679acf..7833997 100644 --- a/src/inversion_ideas/inversion_log.py +++ b/src/inversion_ideas/inversion_log.py @@ -4,12 +4,15 @@ import numbers import typing +import warnings from collections.abc import Callable, Iterable -from rich.console import Console +from rich.console import Console, RenderableType from rich.live import Live from rich.table import Table +from .base import MinimizerResult + try: import pandas # noqa: ICN001 except ImportError: @@ -19,6 +22,29 @@ from .typing import Model +def _get_fmt(value): + """ + Guess fmt of object based on its value. + + Parameters + ---------- + value : Any + + Returns + ------- + fmt : str + """ + if isinstance(value, bool): + fmt = "" + elif isinstance(value, numbers.Integral): + fmt = "d" + elif isinstance(value, numbers.Real): + fmt = ".2e" + else: + fmt = "" + return fmt + + class Column(typing.NamedTuple): """ Column for the ``InversionLog``. @@ -44,11 +70,19 @@ class InversionLog: """ def __init__( - self, columns: typing.Mapping[str, Column | Callable[[int, Model], typing.Any]] + self, + columns: typing.Mapping[str, Column | Callable[[int, Model], typing.Any]], ): for name, column in columns.items(): self.add_column(name, column) + def update(self, iteration: int, model: Model): + """ + Update the log. + """ + for name, column in self.columns.items(): + self.log[name].append(column.callable(iteration, model)) + @property def has_records(self) -> bool: """ @@ -111,13 +145,6 @@ def log(self) -> dict[str, list]: self._log: dict[str, list] = {col: [] for col in self.columns} return self._log - def update(self, iteration: int, model: Model): - """ - Update the log. - """ - for name, column in self.columns.items(): - self.log[name].append(column.callable(iteration, model)) - def to_pandas(self, index_col=0): """ Generate a ``pandas.DataFrame`` out of the log. @@ -129,7 +156,7 @@ def to_pandas(self, index_col=0): return pandas.DataFrame(self.log).set_index(index) @classmethod - def create_from(cls, objective_function: Combo) -> typing.Self: + def create_from(cls, objective_function: Combo, **kwargs) -> typing.Self: r""" Create the standard log for a classic inversion. @@ -138,6 +165,8 @@ def create_from(cls, objective_function: Combo) -> typing.Self: objective_function : Combo Combo objective function with two elements: the data misfit and the regularization (including a trade-off parameter). + kwargs : dict + Keyword arguments passed to the constructor of the class. Returns ------- @@ -193,7 +222,7 @@ def create_from(cls, objective_function: Combo) -> typing.Self: fmt=".2e", ), } - return cls(columns) + return cls(columns, **kwargs) class InversionLogRich(InversionLog): @@ -216,6 +245,26 @@ def __init__(self, columns: dict[str, Callable | Column], **kwargs): super().__init__(columns) self.kwargs = kwargs + def __rich__(self) -> RenderableType: + """ + Return the log as a Rich renderable. + """ + return self.table + + # def update_group(self, iteration: int): + # self.update_table() + # + # # Create a tree to add the minimizer log + # if self.minimizer_logs is not None: + # minimizer_log = self.minimizer_logs[iteration] + # if minimizer_log is not None: + # tree = Tree(self.table) + # tree.add(Panel(minimizer_log, title="Minimizer log")) + # self.group.renderables.append(tree) + # return + # + # self.group.renderables.append(panel) + @property def table(self) -> Table: """ @@ -229,17 +278,25 @@ def table(self) -> Table: def show(self): """ - Show table. + Show log through a Rich console. """ console = Console() - console.print(self.table) + console.print(self) def live(self, **kwargs): """ Context manager for live update of the table. """ + warnings.warn("live will be removed", FutureWarning, stacklevel=2) return Live(self.table, **kwargs) + def update(self, iteration: int, model: Model): + """ + Update the log. + """ + super().update(iteration, model) + self.update_table() + def update_table(self): """ Add row to the table given the latest inverted model. @@ -248,29 +305,71 @@ def update_table(self): ---------- model : (n_params) array """ + # TODO: Check that each entry in the log has the same amount of elements row = [] for name, column in self.columns.items(): value = self.log[name][-1] # last element in the log - fmt = column.fmt if column.fmt is not None else self._get_fmt(value) + fmt = column.fmt if column.fmt is not None else _get_fmt(value) row.append(f"{value:{fmt}}") self.table.add_row(*row) - def _get_fmt(self, value): - if isinstance(value, bool): - fmt = "" - elif isinstance(value, numbers.Integral): - fmt = "d" - elif isinstance(value, numbers.Real): - fmt = ".2e" - else: - fmt = "" - return fmt - def update(self, iteration: int, model: Model): +class MinimizerLog: + """Class to store results of a minimizer in the form of a log.""" + + def update(self, minimizer_result: MinimizerResult): """ - Update the log. + Use as callback for :class:`inversion_ideas.base.Minimizer`. + """ + for field, value in minimizer_result.items(): + if field not in self.log: + self.log[field] = [] + self.log[field].append(value) + self._update_table() - Update the table as well. + @property + def log(self) -> dict[str, list]: + """Returns the log.""" + if not hasattr(self, "_log"): + self._log: dict[str, list] = {} + return self._log + + def __rich__(self) -> Table: """ - super().update(iteration, model) - self.update_table() + Return the log as a Rich renderable. + """ + return self.table + + @property + def table(self) -> Table: + """ + Table for the inversion log. + """ + if not hasattr(self, "_table"): + self._table = Table() + if not self._table.columns: + for column_name in self.log: + self._table.add_column(column_name) + return self._table + + def _update_table(self): + """ + Add last row in the log to the Rich table. + + Parameters + ---------- + model : (n_params) array + """ + row = [] + for values in self.log.values(): + value = values[-1] # last element in the log + fmt = _get_fmt(value) + row.append(f"{value:{fmt}}") + self.table.add_row(*row) + + def show(self): + """ + Show log through a Rich console. + """ + console = Console() + console.print(self) diff --git a/src/inversion_ideas/minimize/__init__.py b/src/inversion_ideas/minimize/__init__.py index dd2bea9..07c49e9 100644 --- a/src/inversion_ideas/minimize/__init__.py +++ b/src/inversion_ideas/minimize/__init__.py @@ -2,7 +2,8 @@ Minimizer functions and classes. """ +from ..base import MinimizerResult from ._functions import conjugate_gradient from ._minimizers import GaussNewtonConjugateGradient -__all__ = ["GaussNewtonConjugateGradient", "conjugate_gradient"] +__all__ = ["GaussNewtonConjugateGradient", "MinimizerResult", "conjugate_gradient"] diff --git a/src/inversion_ideas/minimize/_minimizers.py b/src/inversion_ideas/minimize/_minimizers.py index ba852cb..0103fbb 100644 --- a/src/inversion_ideas/minimize/_minimizers.py +++ b/src/inversion_ideas/minimize/_minimizers.py @@ -9,7 +9,7 @@ import numpy as np from scipy.sparse.linalg import cg -from ..base import Condition, Minimizer, Objective +from ..base import Condition, Minimizer, MinimizerResult, Objective from ..errors import ConvergenceWarning from ..typing import Model, Preconditioner from ..utils import get_logger @@ -62,16 +62,33 @@ def __call__( self, objective: Objective, initial_model: Model, + *, preconditioner: Preconditioner | Callable[[Model], Preconditioner] | None = None, + callback: Callable[[MinimizerResult], None] | None = None, ) -> Generator[Model]: """ Create iterator over Gauss-Newton minimization. + + Parameters + ---------- + objective : Objective + Objective function that will get minimized. + initial_model : (n_params) array + Initial model to start the minimization. + preconditioner : (n_params, n_params) array, sparray or LinearOperator or Callable, optional + Matrix used as preconditioner in the conjugant gradient algorithm. + If None, no preconditioner will be used. + A callable can be passed to build the preconditioner dynamically: such + callable should take a single ``initial_model`` argument and return an + array, `sparray` or a `LinearOperator`. + callback : callable, optional + Callable that gets called after each iteration. """ - # Define a static preconditioner for all Gauss-Newton iterations cg_kwargs = self.cg_kwargs.copy() + # Define a static preconditioner for all Gauss-Newton iterations if preconditioner is not None: if "M" in self.cg_kwargs: msg = "Cannot simultanously pass `preconditioner` and `M`." @@ -84,10 +101,23 @@ def __call__( cg_kwargs["M"] = preconditioner # Perform Gauss-Newton iterations + # ------------------------------- iteration = 0 phi_prev_value = np.inf # value of the objective function on previous model model = initial_model.copy() + # Run callback before first yield + if callback is not None: + minimizer_result = MinimizerResult( + iteration=iteration, + model=model, + objective_value=objective(model), + conj_grad_code=0, + line_search_iters=0, + step_norm=0, + ) + callback(minimizer_result) + # Yield initial model, so the generator is never empty yield model @@ -95,7 +125,7 @@ def __call__( while True: # Stop if reached max number of iterations if iteration >= self.maxiter: - get_logger().info( + get_logger().debug( "⚠️ Reached maximum number of Gauss-Newton iterations " f"({self.maxiter})." ) @@ -115,17 +145,17 @@ def __call__( # Apply Conjugate Gradient to get search direction gradient, hessian = objective.gradient(model), objective.hessian(model) - search_direction, info = cg(hessian, -gradient, **cg_kwargs) - if info != 0: + search_direction, cg_iters = cg(hessian, -gradient, **cg_kwargs) + if cg_iters != 0: warnings.warn( "Conjugate gradient convergence to tolerance not achieved after " - f"{info} number of iterations.", + f"{cg_iters} number of iterations.", ConvergenceWarning, stacklevel=2, ) # Perform line search - alpha, n_ls_iters = backtracking_line_search( + alpha, line_search_iters = backtracking_line_search( objective, model, search_direction, @@ -136,16 +166,29 @@ def __call__( if alpha is None: msg = ( "Couldn't find a valid alpha, obtained None. " - f"Ran {n_ls_iters} iterations." + f"Ran {line_search_iters} iterations." ) raise RuntimeError(msg) # Perform model step - model += alpha * search_direction + step = alpha * search_direction + model += step # Update cached values and iteration counter phi_prev_value = phi_value iteration += 1 - # Yield inverted model for the current Gauss-Newon iteration + # Run callback before next yield + if callback is not None: + minimizer_result = MinimizerResult( + iteration=iteration, + model=model, + objective_value=phi_value, + conj_grad_code=cg_iters, + line_search_iters=line_search_iters, + step_norm=float(np.linalg.norm(step)), + ) + callback(minimizer_result) + + # Yield inverted model for the current Gauss-Newton iteration yield model diff --git a/src/inversion_ideas/typing.py b/src/inversion_ideas/typing.py index f907550..98ddfbf 100644 --- a/src/inversion_ideas/typing.py +++ b/src/inversion_ideas/typing.py @@ -2,13 +2,17 @@ Custom types used for type hints. """ -from typing import Protocol, TypeAlias +from collections.abc import Callable +from typing import TYPE_CHECKING, Protocol, TypeAlias import numpy as np import numpy.typing as npt from scipy.sparse import sparray from scipy.sparse.linalg import LinearOperator +if TYPE_CHECKING: + from .base import MinimizerResult + Model: TypeAlias = npt.NDArray[np.float64] """ Type alias to represent models in the inversion framework as 1D arrays. @@ -35,3 +39,15 @@ def update_irls(self, model: Model) -> None: def activate_irls(self, model_previous: Model) -> None: raise NotImplementedError + + +class Log(Protocol): + """ + Protocol to define inversion and minimizer logs. + """ + + def update(self, iteration: int, model: Model) -> None: + raise NotImplementedError + + def get_minimizer_callback(self) -> Callable[["MinimizerResult"], None]: + raise NotImplementedError