Skip to content

Commit

Permalink
Merge pull request #66 from theislab/fix/gw_examples
Browse files Browse the repository at this point in the history
Fix examples using gw
  • Loading branch information
ArinaDanilina authored Mar 14, 2024
2 parents 47989fb + f979112 commit b89576b
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 63 deletions.
47 changes: 17 additions & 30 deletions examples/problems/100_tagged_arrays.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"warnings.simplefilter(\"ignore\", FutureWarning)\n",
"\n",
"from moscot import datasets\n",
"from moscot.problems.generic import GWProblem\n",
"from moscot.problems.generic import FGWProblem\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
Expand Down Expand Up @@ -95,31 +95,18 @@
"source": [
"## Prepare the problem\n",
"\n",
"We instantiate and prepare a {class}`~moscot.problems.generic.GWProblem` to demonstrate the role of the {class}`~moscot.utils.tagged_array.TaggedArray`."
"We instantiate and prepare a {class}`~moscot.problems.generic.FGWProblem` to demonstrate the role of the {class}`~moscot.utils.tagged_array.TaggedArray`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "16f0c3a9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Ordering \u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'0-0'\u001b[0m, \u001b[32m'1-0'\u001b[0m, \u001b[32m'2-0'\u001b[0m, \u001b[32m'3-0'\u001b[0m, \u001b[32m'4-0'\u001b[0m, \u001b[32m'5-0'\u001b[0m, \u001b[32m'6-0'\u001b[0m, \u001b[32m'7-0'\u001b[0m, \u001b[32m'8-0'\u001b[0m, \u001b[32m'9-0'\u001b[0m, \n",
" \u001b[32m'10-0'\u001b[0m, \u001b[32m'11-0'\u001b[0m, \u001b[32m'12-0'\u001b[0m, \u001b[32m'13-0'\u001b[0m, \u001b[32m'14-0'\u001b[0m, \u001b[32m'15-0'\u001b[0m, \u001b[32m'16-0'\u001b[0m, \u001b[32m'17-0'\u001b[0m, \u001b[32m'18-0'\u001b[0m, \n",
" \u001b[32m'19-0'\u001b[0m, \u001b[32m'0-1'\u001b[0m, \u001b[32m'1-1'\u001b[0m, \u001b[32m'2-1'\u001b[0m, \u001b[32m'3-1'\u001b[0m, \u001b[32m'4-1'\u001b[0m, \u001b[32m'5-1'\u001b[0m, \u001b[32m'6-1'\u001b[0m, \u001b[32m'7-1'\u001b[0m, \u001b[32m'8-1'\u001b[0m, \n",
" \u001b[32m'9-1'\u001b[0m, \u001b[32m'10-1'\u001b[0m, \u001b[32m'11-1'\u001b[0m, \u001b[32m'12-1'\u001b[0m, \u001b[32m'13-1'\u001b[0m, \u001b[32m'14-1'\u001b[0m, \u001b[32m'15-1'\u001b[0m, \u001b[32m'16-1'\u001b[0m, \u001b[32m'17-1'\u001b[0m, \n",
" \u001b[32m'18-1'\u001b[0m, \u001b[32m'19-1'\u001b[0m\u001b[1m]\u001b[0m, \n",
" \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m in ascending order. \n"
]
}
],
"outputs": [],
"source": [
"gwp = GWProblem(adata)\n",
"gwp = gwp.prepare(key=\"batch\", x_attr=\"X_pca\", y_attr=\"X_pca\", joint_attr=\"X_pca\")"
"fgw = FGWProblem(adata)\n",
"fgw = fgw.prepare(key=\"batch\", x_attr=\"X_pca\", y_attr=\"X_pca\", joint_attr=\"X_pca\")"
]
},
{
Expand All @@ -134,7 +121,7 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "e4ca2aea",
"id": "820b9261",
"metadata": {},
"outputs": [
{
Expand All @@ -154,7 +141,7 @@
" [ 2.024, -1.597, -0.591, ..., -0.114, 0.29 , -0.332],\n",
" [-0.938, 2.426, -0.128, ..., 0.194, 0.829, -0.438],\n",
" [ 2.709, -2.885, -0.925, ..., 0.315, 0.334, 0.078]],\n",
" dtype=float32), tag=<Tag.POINT_CLOUD: 'point_cloud'>, cost=<ott.geometry.costs.SqEuclidean object at 0x000002837922CC10>)"
" dtype=float32), tag=<Tag.POINT_CLOUD: 'point_cloud'>, cost=<ott.geometry.costs.SqEuclidean object at 0x000002095006ED50>)"
]
},
"execution_count": 4,
Expand All @@ -163,7 +150,7 @@
}
],
"source": [
"gwp[\"0\", \"1\"].xy"
"fgw[(\"0\", \"1\")].xy"
]
},
{
Expand All @@ -187,7 +174,7 @@
"data": {
"text/plain": [
"(<Tag.POINT_CLOUD: 'point_cloud'>,\n",
" <ott.geometry.costs.SqEuclidean at 0x2837922cc10>)"
" <ott.geometry.costs.SqEuclidean at 0x2095006ed50>)"
]
},
"execution_count": 5,
Expand All @@ -196,7 +183,7 @@
}
],
"source": [
"gwp[\"0\", \"1\"].xy.tag, gwp[\"0\", \"1\"].xy.cost"
"fgw[\"0\", \"1\"].xy.tag, fgw[\"0\", \"1\"].xy.cost"
]
},
{
Expand Down Expand Up @@ -247,7 +234,7 @@
}
],
"source": [
"gwp[\"0\", \"1\"].xy.data_src, gwp[\"0\", \"1\"].xy.data_tgt"
"fgw[\"0\", \"1\"].xy.data_src, fgw[\"0\", \"1\"].xy.data_tgt"
]
},
{
Expand Down Expand Up @@ -287,7 +274,7 @@
}
],
"source": [
"gwp[\"0\", \"1\"].x.data_src, gwp[\"0\", \"1\"].x.data_tgt"
"fgw[\"0\", \"1\"].x.data_src, fgw[\"0\", \"1\"].x.data_tgt"
]
},
{
Expand Down Expand Up @@ -335,7 +322,7 @@
}
],
"source": [
"gwp[\"0\", \"1\"].xy.tag, gwp[\"0\", \"1\"].xy.data_tgt"
"fgw[\"0\", \"1\"].xy.tag, fgw[\"0\", \"1\"].xy.data_tgt"
]
},
{
Expand Down Expand Up @@ -366,14 +353,14 @@
],
"source": [
"rng = np.random.default_rng(seed=42)\n",
"obs_names_0 = gwp[\"0\", \"1\"].adata_src.obs_names\n",
"obs_names_1 = gwp[\"0\", \"1\"].adata_tgt.obs_names\n",
"obs_names_0 = fgw[\"0\", \"1\"].adata_src.obs_names\n",
"obs_names_1 = fgw[\"0\", \"1\"].adata_tgt.obs_names\n",
"\n",
"cost_linear_01 = np.abs(rng.normal(size=(len(obs_names_0), len(obs_names_1))))\n",
"cm_linear = pd.DataFrame(data=cost_linear_01, index=obs_names_0, columns=obs_names_1)\n",
"\n",
"gwp[\"0\", \"1\"].set_xy(cm_linear, tag=\"cost_matrix\")\n",
"gwp[\"0\", \"1\"].xy.tag, gwp[\"0\", \"1\"].xy.data_tgt"
"fgw[\"0\", \"1\"].set_xy(cm_linear, tag=\"cost_matrix\")\n",
"fgw[\"0\", \"1\"].xy.tag, fgw[\"0\", \"1\"].xy.data_tgt"
]
}
],
Expand Down
21 changes: 4 additions & 17 deletions examples/solvers/300_quad_problems_basic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"source": [
"# Quadratic problems\n",
"\n",
"This example shows how to solve quadratic problems, e.g., the {class}`~moscot.problems.time.LineageProblem`, the {class}`~moscot.problems.spatiotemporal.SpatioTemporalProblem`, the {class}`~moscot.problems.space.MappingProblem`, the {class}`~moscot.problems.space.AlignmentProblem`, and the {class}`~moscot.problems.generic.GWProblem`.\n",
"This example shows how to solve quadratic problems, e.g., the {class}`~moscot.problems.time.LineageProblem`, the {class}`~moscot.problems.spatiotemporal.SpatioTemporalProblem`, the {class}`~moscot.problems.space.MappingProblem`, the {class}`~moscot.problems.space.AlignmentProblem`, the {class}`~moscot.problems.generic.GWProblem`, and the {class}`~moscot.problems.generic.FGWProblem`.\n",
"\n",
":::{seealso}\n",
"- See {doc}`400_quad_problems_advanced` for an advanced example on how to solve quadratic problems.\n",
Expand Down Expand Up @@ -38,7 +38,7 @@
"warnings.simplefilter(\"ignore\", FutureWarning)\n",
"\n",
"from moscot import datasets\n",
"from moscot.problems.generic import GWProblem\n",
"from moscot.problems.generic import FGWProblem, GWProblem\n",
"\n",
"import numpy as np\n",
"\n",
Expand Down Expand Up @@ -100,13 +100,6 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Ordering \u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'0-0'\u001b[0m, \u001b[32m'1-0'\u001b[0m, \u001b[32m'2-0'\u001b[0m, \u001b[32m'3-0'\u001b[0m, \u001b[32m'4-0'\u001b[0m, \u001b[32m'5-0'\u001b[0m, \u001b[32m'6-0'\u001b[0m, \u001b[32m'7-0'\u001b[0m, \u001b[32m'8-0'\u001b[0m, \u001b[32m'9-0'\u001b[0m, \n",
" \u001b[32m'10-0'\u001b[0m, \u001b[32m'11-0'\u001b[0m, \u001b[32m'12-0'\u001b[0m, \u001b[32m'13-0'\u001b[0m, \u001b[32m'14-0'\u001b[0m, \u001b[32m'15-0'\u001b[0m, \u001b[32m'16-0'\u001b[0m, \u001b[32m'17-0'\u001b[0m, \u001b[32m'18-0'\u001b[0m, \n",
" \u001b[32m'19-0'\u001b[0m, \u001b[32m'0-1'\u001b[0m, \u001b[32m'1-1'\u001b[0m, \u001b[32m'2-1'\u001b[0m, \u001b[32m'3-1'\u001b[0m, \u001b[32m'4-1'\u001b[0m, \u001b[32m'5-1'\u001b[0m, \u001b[32m'6-1'\u001b[0m, \u001b[32m'7-1'\u001b[0m, \u001b[32m'8-1'\u001b[0m, \n",
" \u001b[32m'9-1'\u001b[0m, \u001b[32m'10-1'\u001b[0m, \u001b[32m'11-1'\u001b[0m, \u001b[32m'12-1'\u001b[0m, \u001b[32m'13-1'\u001b[0m, \u001b[32m'14-1'\u001b[0m, \u001b[32m'15-1'\u001b[0m, \u001b[32m'16-1'\u001b[0m, \u001b[32m'17-1'\u001b[0m, \n",
" \u001b[32m'18-1'\u001b[0m, \u001b[32m'19-1'\u001b[0m\u001b[1m]\u001b[0m, \n",
" \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m in ascending order. \n",
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n",
"\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m1\u001b[0m` problems \n",
"\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n"
]
Expand All @@ -122,12 +115,6 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Ordering \u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'0-0'\u001b[0m, \u001b[32m'1-0'\u001b[0m, \u001b[32m'2-0'\u001b[0m, \u001b[32m'3-0'\u001b[0m, \u001b[32m'4-0'\u001b[0m, \u001b[32m'5-0'\u001b[0m, \u001b[32m'6-0'\u001b[0m, \u001b[32m'7-0'\u001b[0m, \u001b[32m'8-0'\u001b[0m, \u001b[32m'9-0'\u001b[0m, \n",
" \u001b[32m'10-0'\u001b[0m, \u001b[32m'11-0'\u001b[0m, \u001b[32m'12-0'\u001b[0m, \u001b[32m'13-0'\u001b[0m, \u001b[32m'14-0'\u001b[0m, \u001b[32m'15-0'\u001b[0m, \u001b[32m'16-0'\u001b[0m, \u001b[32m'17-0'\u001b[0m, \u001b[32m'18-0'\u001b[0m, \n",
" \u001b[32m'19-0'\u001b[0m, \u001b[32m'0-1'\u001b[0m, \u001b[32m'1-1'\u001b[0m, \u001b[32m'2-1'\u001b[0m, \u001b[32m'3-1'\u001b[0m, \u001b[32m'4-1'\u001b[0m, \u001b[32m'5-1'\u001b[0m, \u001b[32m'6-1'\u001b[0m, \u001b[32m'7-1'\u001b[0m, \u001b[32m'8-1'\u001b[0m, \n",
" \u001b[32m'9-1'\u001b[0m, \u001b[32m'10-1'\u001b[0m, \u001b[32m'11-1'\u001b[0m, \u001b[32m'12-1'\u001b[0m, \u001b[32m'13-1'\u001b[0m, \u001b[32m'14-1'\u001b[0m, \u001b[32m'15-1'\u001b[0m, \u001b[32m'16-1'\u001b[0m, \u001b[32m'17-1'\u001b[0m, \n",
" \u001b[32m'18-1'\u001b[0m, \u001b[32m'19-1'\u001b[0m\u001b[1m]\u001b[0m, \n",
" \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m in ascending order. \n",
"\u001b[34mINFO \u001b[0m Solving `\u001b[1;36m1\u001b[0m` problems \n",
"\u001b[34mINFO \u001b[0m Solving problem OTProblem\u001b[1m[\u001b[0m\u001b[33mstage\u001b[0m=\u001b[32m'prepared'\u001b[0m, \u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m. \n",
"max difference: 0.021854\n"
Expand All @@ -141,9 +128,9 @@
" x_attr={\"attr\": \"obsm\", \"key\": \"spatial\"},\n",
" y_attr={\"attr\": \"obsm\", \"key\": \"spatial\"},\n",
")\n",
"gwp = gwp.solve(alpha=1.0, epsilon=1e-1)\n",
"gwp = gwp.solve(epsilon=1e-1)\n",
"\n",
"fgwp = GWProblem(adata)\n",
"fgwp = FGWProblem(adata)\n",
"fgwp = fgwp.prepare(\n",
" key=\"batch\",\n",
" x_attr={\"attr\": \"obsm\", \"key\": \"spatial\"},\n",
Expand Down
18 changes: 2 additions & 16 deletions examples/solvers/400_quad_problems_advanced.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"source": [
"# Quadratic problems (advanced)\n",
"\n",
"This example shows an advanced quadratic problems usage, e.g., the {class}`~moscot.problems.time.LineageProblem`, the {class}`~moscot.problems.spatiotemporal.SpatioTemporalProblem`, the {class}`~moscot.problems.space.MappingProblem`, the {class}`~moscot.problems.space.AlignmentProblem`, and the {class}`~moscot.problems.generic.GWProblem`.\n",
"This example shows an advanced quadratic problems usage, e.g., the {class}`~moscot.problems.time.LineageProblem`, the {class}`~moscot.problems.spatiotemporal.SpatioTemporalProblem`, the {class}`~moscot.problems.space.MappingProblem`, the {class}`~moscot.problems.space.AlignmentProblem`, the {class}`~moscot.problems.generic.GWProblem`, and the {class}`~moscot.problems.generic.FGWProblem`.\n",
"\n",
":::{seealso}\n",
"- See {doc}`300_quad_problems_basic` for an introduction on how to solve quadratic problems.\n",
Expand Down Expand Up @@ -84,19 +84,6 @@
"id": "bcd55475",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Ordering \u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'0-0'\u001b[0m, \u001b[32m'1-0'\u001b[0m, \u001b[32m'2-0'\u001b[0m, \u001b[32m'3-0'\u001b[0m, \u001b[32m'4-0'\u001b[0m, \u001b[32m'5-0'\u001b[0m, \u001b[32m'6-0'\u001b[0m, \u001b[32m'7-0'\u001b[0m, \u001b[32m'8-0'\u001b[0m, \u001b[32m'9-0'\u001b[0m, \n",
" \u001b[32m'10-0'\u001b[0m, \u001b[32m'11-0'\u001b[0m, \u001b[32m'12-0'\u001b[0m, \u001b[32m'13-0'\u001b[0m, \u001b[32m'14-0'\u001b[0m, \u001b[32m'15-0'\u001b[0m, \u001b[32m'16-0'\u001b[0m, \u001b[32m'17-0'\u001b[0m, \u001b[32m'18-0'\u001b[0m, \n",
" \u001b[32m'19-0'\u001b[0m, \u001b[32m'0-1'\u001b[0m, \u001b[32m'1-1'\u001b[0m, \u001b[32m'2-1'\u001b[0m, \u001b[32m'3-1'\u001b[0m, \u001b[32m'4-1'\u001b[0m, \u001b[32m'5-1'\u001b[0m, \u001b[32m'6-1'\u001b[0m, \u001b[32m'7-1'\u001b[0m, \u001b[32m'8-1'\u001b[0m, \n",
" \u001b[32m'9-1'\u001b[0m, \u001b[32m'10-1'\u001b[0m, \u001b[32m'11-1'\u001b[0m, \u001b[32m'12-1'\u001b[0m, \u001b[32m'13-1'\u001b[0m, \u001b[32m'14-1'\u001b[0m, \u001b[32m'15-1'\u001b[0m, \u001b[32m'16-1'\u001b[0m, \u001b[32m'17-1'\u001b[0m, \n",
" \u001b[32m'18-1'\u001b[0m, \u001b[32m'19-1'\u001b[0m\u001b[1m]\u001b[0m, \n",
" \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m in ascending order. \n",
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n"
]
},
{
"data": {
"text/plain": [
Expand Down Expand Up @@ -189,7 +176,7 @@
}
],
"source": [
"gwp = gwp.solve(alpha=0.5, epsilon=1e-1, min_iterations=0, max_iterations=1)"
"gwp = gwp.solve(epsilon=1e-1, min_iterations=0, max_iterations=1)"
]
},
{
Expand Down Expand Up @@ -227,7 +214,6 @@
"source": [
"ls_kwargs = {\"min_iterations\": 10, \"max_iterations\": 1000, \"threshold\": 0.01}\n",
"gwp = gwp.solve(\n",
" alpha=0.5,\n",
" epsilon=1e-1,\n",
" threshold=0.1,\n",
" min_iterations=2,\n",
Expand Down

0 comments on commit b89576b

Please sign in to comment.