Skip to content

Commit

Permalink
Merge pull request #81 from theislab/add/geodesic_costs
Browse files Browse the repository at this point in the history
example on cost geodesic on current example for custom costs
  • Loading branch information
ArinaDanilina authored Oct 8, 2024
2 parents 5b9d4e0 + b687d4d commit 2107658
Showing 1 changed file with 126 additions and 10 deletions.
136 changes: 126 additions & 10 deletions examples/problems/200_custom_cost_matrices.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
{
"data": {
"text/plain": [
"FGWProblem[('1', '2'), ('0', '1')]"
"FGWProblem[('0', '1'), ('1', '2')]"
]
},
"execution_count": 3,
Expand All @@ -131,13 +131,129 @@
"fgw"
]
},
{
"cell_type": "markdown",
"id": "b17bc6c3",
"metadata": {},
"source": [
"## Setting cost functions individually for each term\n",
"\n",
"The cost functions can be set manually for each term by passing a dictionary to the {attr}`~moscot.problems.generic.FGWProblem.prepare.cost` argument in {meth}`~moscot.problems.generic.FGWProblem.prepare`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2696ddf8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n",
"\u001b[34mINFO \u001b[0m Computing pca with `\u001b[33mn_comps\u001b[0m=\u001b[1;36m30\u001b[0m` for `xy` using `adata.X` \n"
]
},
{
"data": {
"text/plain": [
"FGWProblem[('0', '1'), ('1', '2')]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fgw = fgw.prepare(\n",
" key=\"batch\",\n",
" x_attr=\"X_pca\",\n",
" y_attr=\"X_pca\",\n",
" cost={\"xy\": \"geodesic\", \"x\": \"euclidean\", \"y\": \"sq_euclidean\"},\n",
")\n",
"fgw"
]
},
{
"cell_type": "markdown",
"id": "c8d71b53",
"metadata": {},
"source": [
"We can check which cost functions have been set:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7690fd3b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<ott.geometry.costs.Euclidean at 0x20be3f91bd0>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fgw[(\"0\", \"1\")].x.cost"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f9219522",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<ott.geometry.costs.SqEuclidean at 0x20bac2088d0>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fgw[(\"0\", \"1\")].y.cost"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "8cc8678f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'geodesic'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fgw[(\"0\", \"1\")].xy.cost"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "625be056",
"metadata": {},
"source": [
"## Passing cost matrices one by one\n",
"## Passing cost matrices manually\n",
"\n",
"We can pass the custom cost matrices by accessing the {class}`~moscot.base.problems.OTProblem`.\n",
"The method {meth}`~moscot.base.problems.OTProblem.set_xy` allows to pass a custom cost matrix\n",
Expand All @@ -154,7 +270,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 8,
"id": "ca81a426",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -185,7 +301,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"id": "2195c41a",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -227,7 +343,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 10,
"id": "e21b574c",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -284,7 +400,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 11,
"id": "4b99031c",
"metadata": {},
"outputs": [],
Expand All @@ -304,7 +420,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 12,
"id": "7c6f8eb6",
"metadata": {},
"outputs": [
Expand All @@ -314,7 +430,7 @@
"OTProblem[stage='prepared', shape=(20, 20)]"
]
},
"execution_count": 8,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -347,7 +463,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 13,
"id": "cbb8b363",
"metadata": {},
"outputs": [
Expand All @@ -357,7 +473,7 @@
"OTProblem[stage='prepared', shape=(20, 20)]"
]
},
"execution_count": 9,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
Expand Down

0 comments on commit 2107658

Please sign in to comment.