diff --git a/fao_models/beam_pipelines/__init__.py b/fao_models/beam_pipelines/__init__.py
new file mode 100644
index 0000000..9e57ea5
--- /dev/null
+++ b/fao_models/beam_pipelines/__init__.py
@@ -0,0 +1,8 @@
+import sys
+import os
+
+# Get the parent directory path
+parent_dir = os.path.dirname(os.path.abspath(__file__))
+
+# Add the parent directory to the Python module search path
+sys.path.append(parent_dir)
diff --git a/fao_models/beam_pipelines/test_initialPCollection.py b/fao_models/beam_pipelines/test_initialPCollection.py
new file mode 100644
index 0000000..c3e94b8
--- /dev/null
+++ b/fao_models/beam_pipelines/test_initialPCollection.py
@@ -0,0 +1,43 @@
+#%%
+import apache_beam as beam
+from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions
+import unittest
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+import logging
+import geopandas
+logger = logging.getLogger(__name__)
+
+# load gdf and compute centroid from geometry
+gdf = geopandas.read_file('C:\\Users\\kyle\\Downloads\\ALL_centroids_completed_v1_\\ALL_centroids_completed_v1_.shp')
+gdf.loc[:,'centroid'] = gdf.geometry.centroid
+print(gdf.head())
+#%%
+# convert centroid (a GeoSeries geometry), to a native python list of lat,lon)
+gdf.loc[:,'latlon'] = gdf.centroid.apply(lambda x: [x.y, x.x])
+print(gdf.dtypes)
+print(gdf.head())
+#%%
+# construct list of global_id, latlon tuples for the pipeline
+features = gdf[['global_id', 'latlon']].values.tolist()
+print(features[:5])
+
+
+#%%
+# https://beam.apache.org/documentation/pipelines/test-your-pipeline/#testing-transforms
+expected_output = features[:5]
+def test_pipe(argv=None, save_main_session=True):
+ """Main entry point;"""
+ # read in a gdf and construct begnning PCollection from gdf in-memory
+
+ # do we need to convert each record in gdf to a list or dict?
+ with TestPipeline(runner=beam.runners.DirectRunner()) as p:
+
+ pipe_features = p | beam.Create(features[:5]) # if you change this to features[:6] the test will raise AssertionError
+ assert_that(pipe_features,equal_to(expected_output), label='check features')
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ test_pipe()
diff --git a/fao_models/beam_pipelines/test_pyshp.py b/fao_models/beam_pipelines/test_pyshp.py
new file mode 100644
index 0000000..0d3bc79
--- /dev/null
+++ b/fao_models/beam_pipelines/test_pyshp.py
@@ -0,0 +1,7 @@
+# works but i don't think we'll be able ot make centroid lat lon with this package easily
+import shapefile
+
+input_file = 'C:\\Users\\kyle\\Downloads\\ALL_centroids_completed_v1_\\ALL_centroids_completed_v1_.shp'
+sf = shapefile.Reader(input_file)
+print(sf.fields)
+print(sf.records()[0:10])
\ No newline at end of file
diff --git a/fao_models/test_inference_pipeline_steps.ipynb b/fao_models/test_inference_pipeline_steps.ipynb
new file mode 100644
index 0000000..be13fbc
--- /dev/null
+++ b/fao_models/test_inference_pipeline_steps.ipynb
@@ -0,0 +1,890 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "import ee\n",
+ "import io\n",
+ "from google.api_core import exceptions, retry\n",
+ "import google.auth\n",
+ "from models import get_model\n",
+ "\n",
+ "PROJECT = \"pc530-fao-fra-rss\" # change to your cloud project name\n",
+ "# ee.Initialize(project=PROJECT)\n",
+ "\n",
+ "## INIT WITH HIGH VOLUME ENDPOINT\n",
+ "credentials, _ = google.auth.default()\n",
+ "ee.Initialize(\n",
+ "credentials,\n",
+ "project=PROJECT,\n",
+ "opt_url=\"https://earthengine-highvolume.googleapis.com\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def get_ee_img(coords):\n",
+ " \"\"\"retrieve s2 image composite from ee at given coordinates. coords is a tuple of (lon, lat) in degrees.\"\"\"\n",
+ " ## MAKE S2 COMPOSITE IN HEXAGONS ##########################################\n",
+ " # Using Cloud Score + for cloud/cloud-shadow masking\n",
+ " # Harmonized Sentinel-2 Level 2A collection.\n",
+ " s2 = ee.ImageCollection(\"COPERNICUS/S2_SR_HARMONIZED\")\n",
+ "\n",
+ " # Cloud Score+ image collection. Note Cloud Score+ is produced from Sentinel-2\n",
+ " # Level 1C data and can be applied to either L1C or L2A collections.\n",
+ " csPlus = ee.ImageCollection(\"GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED\")\n",
+ "\n",
+ " # Use 'cs' or 'cs_cdf', depending on your use case; see docs for guidance.\n",
+ " QA_BAND = \"cs_cdf\"\n",
+ "\n",
+ " # The threshold for masking; values between 0.50 and 0.65 generally work well.\n",
+ " # Higher values will remove thin clouds, haze & cirrus shadows.\n",
+ " CLEAR_THRESHOLD = 0.50\n",
+ "\n",
+ " # Make a clear median composite.\n",
+ " sampleImage = (\n",
+ " s2.filterDate(\"2023-01-01\", \"2023-12-31\")\n",
+ " .filterBounds(ee.Geometry.Point(coords[0], coords[1]).buffer(64*10)) # only images touching 64 pixel centroid buffer\n",
+ " .linkCollection(csPlus, [QA_BAND])\n",
+ " .map(lambda img: img.updateMask(img.select(QA_BAND).gte(CLEAR_THRESHOLD)))\n",
+ " .median()\n",
+ " .select([\"B4\", \"B3\", \"B2\", \"B8\"], [\"R\", \"G\", \"B\", \"N\"])\n",
+ " )\n",
+ " return sampleImage\n",
+ "\n",
+ "@retry.Retry()\n",
+ "def get_patch(coords, image, format=\"NPY\"):\n",
+ " \"\"\"Uses ee.data.ComputePixels() to get a 32x32 patch centered on the coordinates, as a numpy array.\"\"\"\n",
+ " \n",
+ " # Output resolution in meters.\n",
+ " SCALE = 10\n",
+ "\n",
+ " # Pre-compute a geographic coordinate system.\n",
+ " proj = ee.Projection(\"EPSG:4326\").atScale(SCALE).getInfo()\n",
+ "\n",
+ " # Get scales in degrees out of the transform.\n",
+ " SCALE_X = proj[\"transform\"][0]\n",
+ " SCALE_Y = -proj[\"transform\"][4]\n",
+ "\n",
+ " # Patch size in pixels.\n",
+ " PATCH_SIZE = 32\n",
+ "\n",
+ " # Offset to the upper left corner.\n",
+ " OFFSET_X = -SCALE_X * PATCH_SIZE / 2\n",
+ " OFFSET_Y = -SCALE_Y * PATCH_SIZE / 2\n",
+ " \n",
+ " REQUEST = {\n",
+ " \"fileFormat\": \"NPY\",\n",
+ " \"grid\": {\n",
+ " \"dimensions\": {\"width\": PATCH_SIZE, \"height\": PATCH_SIZE},\n",
+ " \"affineTransform\": {\n",
+ " \"scaleX\": SCALE_X,\n",
+ " \"shearX\": 0,\n",
+ " \"shearY\": 0,\n",
+ " \"scaleY\": SCALE_Y,\n",
+ " },\n",
+ " \"crsCode\": proj[\"crs\"],\n",
+ " },\n",
+ " }\n",
+ " \n",
+ " request = dict(REQUEST)\n",
+ " request[\"fileFormat\"] = format\n",
+ " request[\"expression\"] = image\n",
+ " request[\"grid\"][\"affineTransform\"][\"translateX\"] = coords[0] + OFFSET_X\n",
+ " request[\"grid\"][\"affineTransform\"][\"translateY\"] = coords[1] + OFFSET_Y\n",
+ " return np.load(io.BytesIO(ee.data.computePixels(request)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[102.19, -1.54]\n"
+ ]
+ }
+ ],
+ "source": [
+ "id, latlon = 1233804841, [102.19,-1.54]#[-257.82, -1.54]#[-172.3490007781034,-13.523357265222518] #[-60.25204,3.86655]#\n",
+ "print(latlon)\n",
+ "image = get_ee_img(latlon)\n",
+ "patch = get_patch(latlon, image)\n",
+ "# print(patch)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "5dc9a303a9e14bb7adf30cd4e5f9f896",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map(center=[-1.54, 102.19000000000001], controls=(WidgetControl(options=['position', 'transparent_bg'], widget…"
+ ]
+ },
+ "execution_count": 45,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import geemap\n",
+ "Map = geemap.Map()\n",
+ "Map.addLayer(image, {\"bands\": [\"R\", \"G\", \"B\"], \"min\": 0, \"max\": 2000}, \"S2\")\n",
+ "Map.addLayer(ee.Geometry.Point(latlon), {\"color\": \"red\"}, \"Centroid\")\n",
+ "Map.centerObject(ee.Geometry.Point(latlon), 18)\n",
+ "Map"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[[ 644. 962. 828. 3541. ]\n",
+ " [ 615. 932. 781. 3420. ]\n",
+ " [ 614. 865. 816. 3196. ]\n",
+ " ...\n",
+ " [ 603.5 805.5 795. 3013.5]\n",
+ " [ 582. 807. 770.5 3060.5]\n",
+ " [ 574.5 798. 781. 3071. ]]\n",
+ "\n",
+ " [[ 617. 927. 808. 3608. ]\n",
+ " [ 590. 904. 800. 3384. ]\n",
+ " [ 600.5 870.5 823. 3148. ]\n",
+ " ...\n",
+ " [ 625. 819. 834.5 2952.5]\n",
+ " [ 598. 817. 773. 2983.5]\n",
+ " [ 596.5 816. 782. 2982. ]]\n",
+ "\n",
+ " [[ 636. 920. 824. 3344. ]\n",
+ " [ 622. 912. 808. 3340. ]\n",
+ " [ 626. 878.5 823. 3180. ]\n",
+ " ...\n",
+ " [ 591.5 785. 790. 2993. ]\n",
+ " [ 571. 793. 780. 3107. ]\n",
+ " [ 598. 850. 854. 2998. ]]\n",
+ "\n",
+ " ...\n",
+ "\n",
+ " [[ 788. 1060. 904. 3174. ]\n",
+ " [ 779. 1061. 904. 3190. ]\n",
+ " [ 764. 1126. 932. 3516. ]\n",
+ " ...\n",
+ " [ 580. 754. 669. 3180. ]\n",
+ " [ 594. 748. 672. 3233. ]\n",
+ " [ 546. 754. 722. 3280. ]]\n",
+ "\n",
+ " [[ 807. 1100. 948. 3096. ]\n",
+ " [ 799.5 1106. 949. 3285. ]\n",
+ " [ 757. 1050. 881.5 3664. ]\n",
+ " ...\n",
+ " [ 682. 825.5 639. 2925. ]\n",
+ " [ 616.5 778. 609. 3011.5]\n",
+ " [ 563.5 739. 580.5 3170. ]]\n",
+ "\n",
+ " [[ 860. 1152. 973. 3366. ]\n",
+ " [ 818. 1164. 954. 3552. ]\n",
+ " [ 810. 1058. 913. 3420. ]\n",
+ " ...\n",
+ " [ 698.5 834.5 658.5 3043. ]\n",
+ " [ 654. 813.5 673.5 2957. ]\n",
+ " [ 642. 798. 724. 3111. ]]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "import tensorflow as tf\n",
+ "from numpy.lib.recfunctions import structured_to_unstructured\n",
+ "unstruct = structured_to_unstructured(patch)\n",
+ "print(unstruct)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[[0.0644 0.0962 0.0828 0.3541 ]\n",
+ " [0.0615 0.0932 0.0781 0.342 ]\n",
+ " [0.0614 0.0865 0.0816 0.3196 ]\n",
+ " ...\n",
+ " [0.06035 0.08055 0.0795 0.30135]\n",
+ " [0.0582 0.0807 0.07705 0.30605]\n",
+ " [0.05745 0.0798 0.0781 0.3071 ]]\n",
+ "\n",
+ " [[0.0617 0.0927 0.0808 0.3608 ]\n",
+ " [0.059 0.0904 0.08 0.3384 ]\n",
+ " [0.06005 0.08705 0.0823 0.3148 ]\n",
+ " ...\n",
+ " [0.0625 0.0819 0.08345 0.29525]\n",
+ " [0.0598 0.0817 0.0773 0.29835]\n",
+ " [0.05965 0.0816 0.0782 0.2982 ]]\n",
+ "\n",
+ " [[0.0636 0.092 0.0824 0.3344 ]\n",
+ " [0.0622 0.0912 0.0808 0.334 ]\n",
+ " [0.0626 0.08785 0.0823 0.318 ]\n",
+ " ...\n",
+ " [0.05915 0.0785 0.079 0.2993 ]\n",
+ " [0.0571 0.0793 0.078 0.3107 ]\n",
+ " [0.0598 0.085 0.0854 0.2998 ]]\n",
+ "\n",
+ " ...\n",
+ "\n",
+ " [[0.0788 0.106 0.0904 0.3174 ]\n",
+ " [0.0779 0.1061 0.0904 0.319 ]\n",
+ " [0.0764 0.1126 0.0932 0.3516 ]\n",
+ " ...\n",
+ " [0.058 0.0754 0.0669 0.318 ]\n",
+ " [0.0594 0.0748 0.0672 0.3233 ]\n",
+ " [0.0546 0.0754 0.0722 0.328 ]]\n",
+ "\n",
+ " [[0.0807 0.11 0.0948 0.3096 ]\n",
+ " [0.07995 0.1106 0.0949 0.3285 ]\n",
+ " [0.0757 0.105 0.08815 0.3664 ]\n",
+ " ...\n",
+ " [0.0682 0.08255 0.0639 0.2925 ]\n",
+ " [0.06165 0.0778 0.0609 0.30115]\n",
+ " [0.05635 0.0739 0.05805 0.317 ]]\n",
+ "\n",
+ " [[0.086 0.1152 0.0973 0.3366 ]\n",
+ " [0.0818 0.1164 0.0954 0.3552 ]\n",
+ " [0.081 0.1058 0.0913 0.342 ]\n",
+ " ...\n",
+ " [0.06985 0.08345 0.06585 0.3043 ]\n",
+ " [0.0654 0.08135 0.06735 0.2957 ]\n",
+ " [0.0642 0.0798 0.0724 0.3111 ]]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "rescaled = unstruct.astype(np.float64) / 10000\n",
+ "print(rescaled)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[[0.0644 0.0617 0.0636 ... 0.0788 0.0807 0.086 ]\n",
+ " [0.0962 0.0927 0.092 ... 0.106 0.11 0.1152 ]\n",
+ " [0.0828 0.0808 0.0824 ... 0.0904 0.0948 0.0973 ]\n",
+ " [0.3541 0.3608 0.3344 ... 0.3174 0.3096 0.3366 ]]\n",
+ "\n",
+ " [[0.0615 0.059 0.0622 ... 0.0779 0.07995 0.0818 ]\n",
+ " [0.0932 0.0904 0.0912 ... 0.1061 0.1106 0.1164 ]\n",
+ " [0.0781 0.08 0.0808 ... 0.0904 0.0949 0.0954 ]\n",
+ " [0.342 0.3384 0.334 ... 0.319 0.3285 0.3552 ]]\n",
+ "\n",
+ " [[0.0614 0.06005 0.0626 ... 0.0764 0.0757 0.081 ]\n",
+ " [0.0865 0.08705 0.08785 ... 0.1126 0.105 0.1058 ]\n",
+ " [0.0816 0.0823 0.0823 ... 0.0932 0.08815 0.0913 ]\n",
+ " [0.3196 0.3148 0.318 ... 0.3516 0.3664 0.342 ]]\n",
+ "\n",
+ " ...\n",
+ "\n",
+ " [[0.06035 0.0625 0.05915 ... 0.058 0.0682 0.06985]\n",
+ " [0.08055 0.0819 0.0785 ... 0.0754 0.08255 0.08345]\n",
+ " [0.0795 0.08345 0.079 ... 0.0669 0.0639 0.06585]\n",
+ " [0.30135 0.29525 0.2993 ... 0.318 0.2925 0.3043 ]]\n",
+ "\n",
+ " [[0.0582 0.0598 0.0571 ... 0.0594 0.06165 0.0654 ]\n",
+ " [0.0807 0.0817 0.0793 ... 0.0748 0.0778 0.08135]\n",
+ " [0.07705 0.0773 0.078 ... 0.0672 0.0609 0.06735]\n",
+ " [0.30605 0.29835 0.3107 ... 0.3233 0.30115 0.2957 ]]\n",
+ "\n",
+ " [[0.05745 0.05965 0.0598 ... 0.0546 0.05635 0.0642 ]\n",
+ " [0.0798 0.0816 0.085 ... 0.0754 0.0739 0.0798 ]\n",
+ " [0.0781 0.0782 0.0854 ... 0.0722 0.05805 0.0724 ]\n",
+ " [0.3071 0.2982 0.2998 ... 0.328 0.317 0.3111 ]]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "transposed = np.transpose(rescaled, (1, 2, 0))\n",
+ "print(transposed)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# and then not sure if we need this step, converting to a tf.tensor\n",
+ "# image_raw = tf.io.serialize_tensor(\n",
+ "# transposed.astype(np.float32)\n",
+ "# )\n",
+ "# print(image_raw)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[[[0.0644 0.0617 0.0636 0.0658 ]\n",
+ " [0.0651 0.0646 0.0643 0.06795]\n",
+ " [0.0716 0.07045 0.0649 0.0644 ]\n",
+ " ...\n",
+ " [0.3304 0.32775 0.3296 0.334 ]\n",
+ " [0.3416 0.3672 0.384 0.3524 ]\n",
+ " [0.326 0.3174 0.3096 0.3366 ]]\n",
+ "\n",
+ " [[0.0615 0.059 0.0622 0.0643 ]\n",
+ " [0.06375 0.0645 0.0642 0.0644 ]\n",
+ " [0.06665 0.0704 0.0644 0.0637 ]\n",
+ " ...\n",
+ " [0.322 0.33095 0.3356 0.3268 ]\n",
+ " [0.3417 0.3648 0.3728 0.3463 ]\n",
+ " [0.3348 0.319 0.3285 0.3552 ]]\n",
+ "\n",
+ " [[0.0614 0.06005 0.0626 0.06285]\n",
+ " [0.064 0.0626 0.0619 0.0626 ]\n",
+ " [0.0655 0.0696 0.06555 0.0647 ]\n",
+ " ...\n",
+ " [0.321 0.3272 0.3308 0.3284 ]\n",
+ " [0.3413 0.3564 0.3516 0.3428 ]\n",
+ " [0.3376 0.3516 0.3664 0.342 ]]\n",
+ "\n",
+ " ...\n",
+ "\n",
+ " [[0.06035 0.0625 0.05915 0.05825]\n",
+ " [0.0599 0.0612 0.0619 0.0632 ]\n",
+ " [0.0617 0.06035 0.0589 0.0582 ]\n",
+ " ...\n",
+ " [0.2951 0.3009 0.31485 0.3148 ]\n",
+ " [0.3068 0.3204 0.3357 0.3202 ]\n",
+ " [0.3277 0.318 0.2925 0.3043 ]]\n",
+ "\n",
+ " [[0.0582 0.0598 0.0571 0.0574 ]\n",
+ " [0.0606 0.0612 0.0621 0.0624 ]\n",
+ " [0.061 0.05905 0.05725 0.0606 ]\n",
+ " ...\n",
+ " [0.3031 0.3086 0.3176 0.3122 ]\n",
+ " [0.3204 0.3372 0.35 0.3344 ]\n",
+ " [0.3377 0.3233 0.30115 0.2957 ]]\n",
+ "\n",
+ " [[0.05745 0.05965 0.0598 0.0582 ]\n",
+ " [0.0596 0.0616 0.0624 0.0613 ]\n",
+ " [0.0602 0.0582 0.057 0.0606 ]\n",
+ " ...\n",
+ " [0.31725 0.3078 0.3176 0.3304 ]\n",
+ " [0.3344 0.3404 0.3416 0.3596 ]\n",
+ " [0.3548 0.328 0.317 0.3111 ]]]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# how to get numpy array of 4 bands into correct shape for model prediction\n",
+ "reshaped = np.reshape(transposed, (1, 32, 32, 4))\n",
+ "print(reshaped)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model found: resnet\n",
+ "tf.Tensor([[0.42960128]], shape=(1, 1), dtype=float32)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 30-epoch model with 86% binary accuracy\n",
+ "model_name = \"resnet\"\n",
+ "optimizer = \"adam\"\n",
+ "loss_function = \"binary_crossentropy\"\n",
+ "checkpoint = \"C:\\\\fao-models\\\\saved_models\\\\resnet-epochs30-batch64-lr001\\\\best_model.h5\"\n",
+ "\n",
+ "# load several model versions into memory..\n",
+ "model = get_model(model_name, optimizer=optimizer, loss_fn=loss_function)\n",
+ "model.load_weights(checkpoint)\n",
+ "# when mode.trainable was set to False, was getting error at model.load_weights(checkpoint): ValueError: axes don't match array\n",
+ "# https://stackoverflow.com/questions/51944836/keras-load-model-valueerror-axes-dont-match-array\n",
+ "model.trainable = True \n",
+ "\n",
+ "# print(model.summary())\n",
+ "prediction = model(reshaped)\n",
+ "print(prediction)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model found: resnet\n",
+ "tf.Tensor([[0.47269127]], shape=(1, 1), dtype=float32)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 15 epoch model with 98% binary accuracy \n",
+ "model_name = \"resnet\"\n",
+ "optimizer = \"adam\"\n",
+ "loss_function = \"binary_crossentropy\"\n",
+ "checkpoint = \"C:\\\\fao-models\\\\saved_models\\\\resnet-epochs5-batch64-lr001-seed5-lrdecay5\\\\best_model.h5\"\n",
+ "# load several model versions into memory..\n",
+ "model = get_model(model_name, optimizer=optimizer, loss_fn=loss_function)\n",
+ "model.load_weights(checkpoint)\n",
+ "# when mode.trainable was set to False, was getting error at model.load_weights(checkpoint): ValueError: axes don't match array\n",
+ "# https://stackoverflow.com/questions/51944836/keras-load-model-valueerror-axes-dont-match-array\n",
+ "model.trainable = True \n",
+ "\n",
+ "# print(model.summary())\n",
+ "prediction = model(reshaped)\n",
+ "print(prediction)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "gee",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}