diff --git a/fao_models/beam_pipelines/__init__.py b/fao_models/beam_pipelines/__init__.py new file mode 100644 index 0000000..9e57ea5 --- /dev/null +++ b/fao_models/beam_pipelines/__init__.py @@ -0,0 +1,8 @@ +import sys +import os + +# Get the parent directory path +parent_dir = os.path.dirname(os.path.abspath(__file__)) + +# Add the parent directory to the Python module search path +sys.path.append(parent_dir) diff --git a/fao_models/beam_pipelines/test_initialPCollection.py b/fao_models/beam_pipelines/test_initialPCollection.py new file mode 100644 index 0000000..c3e94b8 --- /dev/null +++ b/fao_models/beam_pipelines/test_initialPCollection.py @@ -0,0 +1,43 @@ +#%% +import apache_beam as beam +from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions +import unittest +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to +import logging +import geopandas +logger = logging.getLogger(__name__) + +# load gdf and compute centroid from geometry +gdf = geopandas.read_file('C:\\Users\\kyle\\Downloads\\ALL_centroids_completed_v1_\\ALL_centroids_completed_v1_.shp') +gdf.loc[:,'centroid'] = gdf.geometry.centroid +print(gdf.head()) +#%% +# convert centroid (a GeoSeries geometry), to a native python list of lat,lon) +gdf.loc[:,'latlon'] = gdf.centroid.apply(lambda x: [x.y, x.x]) +print(gdf.dtypes) +print(gdf.head()) +#%% +# construct list of global_id, latlon tuples for the pipeline +features = gdf[['global_id', 'latlon']].values.tolist() +print(features[:5]) + + +#%% +# https://beam.apache.org/documentation/pipelines/test-your-pipeline/#testing-transforms +expected_output = features[:5] +def test_pipe(argv=None, save_main_session=True): + """Main entry point;""" + # read in a gdf and construct begnning PCollection from gdf in-memory + + # do we need to convert each record in gdf to a list or dict? + with TestPipeline(runner=beam.runners.DirectRunner()) as p: + + pipe_features = p | beam.Create(features[:5]) # if you change this to features[:6] the test will raise AssertionError + assert_that(pipe_features,equal_to(expected_output), label='check features') + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + test_pipe() diff --git a/fao_models/beam_pipelines/test_pyshp.py b/fao_models/beam_pipelines/test_pyshp.py new file mode 100644 index 0000000..0d3bc79 --- /dev/null +++ b/fao_models/beam_pipelines/test_pyshp.py @@ -0,0 +1,7 @@ +# works but i don't think we'll be able ot make centroid lat lon with this package easily +import shapefile + +input_file = 'C:\\Users\\kyle\\Downloads\\ALL_centroids_completed_v1_\\ALL_centroids_completed_v1_.shp' +sf = shapefile.Reader(input_file) +print(sf.fields) +print(sf.records()[0:10]) \ No newline at end of file diff --git a/fao_models/test_inference_pipeline_steps.ipynb b/fao_models/test_inference_pipeline_steps.ipynb new file mode 100644 index 0000000..be13fbc --- /dev/null +++ b/fao_models/test_inference_pipeline_steps.ipynb @@ -0,0 +1,890 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import numpy as np\n", + "import ee\n", + "import io\n", + "from google.api_core import exceptions, retry\n", + "import google.auth\n", + "from models import get_model\n", + "\n", + "PROJECT = \"pc530-fao-fra-rss\" # change to your cloud project name\n", + "# ee.Initialize(project=PROJECT)\n", + "\n", + "## INIT WITH HIGH VOLUME ENDPOINT\n", + "credentials, _ = google.auth.default()\n", + "ee.Initialize(\n", + "credentials,\n", + "project=PROJECT,\n", + "opt_url=\"https://earthengine-highvolume.googleapis.com\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def get_ee_img(coords):\n", + " \"\"\"retrieve s2 image composite from ee at given coordinates. coords is a tuple of (lon, lat) in degrees.\"\"\"\n", + " ## MAKE S2 COMPOSITE IN HEXAGONS ##########################################\n", + " # Using Cloud Score + for cloud/cloud-shadow masking\n", + " # Harmonized Sentinel-2 Level 2A collection.\n", + " s2 = ee.ImageCollection(\"COPERNICUS/S2_SR_HARMONIZED\")\n", + "\n", + " # Cloud Score+ image collection. Note Cloud Score+ is produced from Sentinel-2\n", + " # Level 1C data and can be applied to either L1C or L2A collections.\n", + " csPlus = ee.ImageCollection(\"GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED\")\n", + "\n", + " # Use 'cs' or 'cs_cdf', depending on your use case; see docs for guidance.\n", + " QA_BAND = \"cs_cdf\"\n", + "\n", + " # The threshold for masking; values between 0.50 and 0.65 generally work well.\n", + " # Higher values will remove thin clouds, haze & cirrus shadows.\n", + " CLEAR_THRESHOLD = 0.50\n", + "\n", + " # Make a clear median composite.\n", + " sampleImage = (\n", + " s2.filterDate(\"2023-01-01\", \"2023-12-31\")\n", + " .filterBounds(ee.Geometry.Point(coords[0], coords[1]).buffer(64*10)) # only images touching 64 pixel centroid buffer\n", + " .linkCollection(csPlus, [QA_BAND])\n", + " .map(lambda img: img.updateMask(img.select(QA_BAND).gte(CLEAR_THRESHOLD)))\n", + " .median()\n", + " .select([\"B4\", \"B3\", \"B2\", \"B8\"], [\"R\", \"G\", \"B\", \"N\"])\n", + " )\n", + " return sampleImage\n", + "\n", + "@retry.Retry()\n", + "def get_patch(coords, image, format=\"NPY\"):\n", + " \"\"\"Uses ee.data.ComputePixels() to get a 32x32 patch centered on the coordinates, as a numpy array.\"\"\"\n", + " \n", + " # Output resolution in meters.\n", + " SCALE = 10\n", + "\n", + " # Pre-compute a geographic coordinate system.\n", + " proj = ee.Projection(\"EPSG:4326\").atScale(SCALE).getInfo()\n", + "\n", + " # Get scales in degrees out of the transform.\n", + " SCALE_X = proj[\"transform\"][0]\n", + " SCALE_Y = -proj[\"transform\"][4]\n", + "\n", + " # Patch size in pixels.\n", + " PATCH_SIZE = 32\n", + "\n", + " # Offset to the upper left corner.\n", + " OFFSET_X = -SCALE_X * PATCH_SIZE / 2\n", + " OFFSET_Y = -SCALE_Y * PATCH_SIZE / 2\n", + " \n", + " REQUEST = {\n", + " \"fileFormat\": \"NPY\",\n", + " \"grid\": {\n", + " \"dimensions\": {\"width\": PATCH_SIZE, \"height\": PATCH_SIZE},\n", + " \"affineTransform\": {\n", + " \"scaleX\": SCALE_X,\n", + " \"shearX\": 0,\n", + " \"shearY\": 0,\n", + " \"scaleY\": SCALE_Y,\n", + " },\n", + " \"crsCode\": proj[\"crs\"],\n", + " },\n", + " }\n", + " \n", + " request = dict(REQUEST)\n", + " request[\"fileFormat\"] = format\n", + " request[\"expression\"] = image\n", + " request[\"grid\"][\"affineTransform\"][\"translateX\"] = coords[0] + OFFSET_X\n", + " request[\"grid\"][\"affineTransform\"][\"translateY\"] = coords[1] + OFFSET_Y\n", + " return np.load(io.BytesIO(ee.data.computePixels(request)))" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[102.19, -1.54]\n" + ] + } + ], + "source": [ + "id, latlon = 1233804841, [102.19,-1.54]#[-257.82, -1.54]#[-172.3490007781034,-13.523357265222518] #[-60.25204,3.86655]#\n", + "print(latlon)\n", + "image = get_ee_img(latlon)\n", + "patch = get_patch(latlon, image)\n", + "# print(patch)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5dc9a303a9e14bb7adf30cd4e5f9f896", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Map(center=[-1.54, 102.19000000000001], controls=(WidgetControl(options=['position', 'transparent_bg'], widget…" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import geemap\n", + "Map = geemap.Map()\n", + "Map.addLayer(image, {\"bands\": [\"R\", \"G\", \"B\"], \"min\": 0, \"max\": 2000}, \"S2\")\n", + "Map.addLayer(ee.Geometry.Point(latlon), {\"color\": \"red\"}, \"Centroid\")\n", + "Map.centerObject(ee.Geometry.Point(latlon), 18)\n", + "Map" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[[ 644. 962. 828. 3541. ]\n", + " [ 615. 932. 781. 3420. ]\n", + " [ 614. 865. 816. 3196. ]\n", + " ...\n", + " [ 603.5 805.5 795. 3013.5]\n", + " [ 582. 807. 770.5 3060.5]\n", + " [ 574.5 798. 781. 3071. ]]\n", + "\n", + " [[ 617. 927. 808. 3608. ]\n", + " [ 590. 904. 800. 3384. ]\n", + " [ 600.5 870.5 823. 3148. ]\n", + " ...\n", + " [ 625. 819. 834.5 2952.5]\n", + " [ 598. 817. 773. 2983.5]\n", + " [ 596.5 816. 782. 2982. ]]\n", + "\n", + " [[ 636. 920. 824. 3344. ]\n", + " [ 622. 912. 808. 3340. ]\n", + " [ 626. 878.5 823. 3180. ]\n", + " ...\n", + " [ 591.5 785. 790. 2993. ]\n", + " [ 571. 793. 780. 3107. ]\n", + " [ 598. 850. 854. 2998. ]]\n", + "\n", + " ...\n", + "\n", + " [[ 788. 1060. 904. 3174. ]\n", + " [ 779. 1061. 904. 3190. ]\n", + " [ 764. 1126. 932. 3516. ]\n", + " ...\n", + " [ 580. 754. 669. 3180. ]\n", + " [ 594. 748. 672. 3233. ]\n", + " [ 546. 754. 722. 3280. ]]\n", + "\n", + " [[ 807. 1100. 948. 3096. ]\n", + " [ 799.5 1106. 949. 3285. ]\n", + " [ 757. 1050. 881.5 3664. ]\n", + " ...\n", + " [ 682. 825.5 639. 2925. ]\n", + " [ 616.5 778. 609. 3011.5]\n", + " [ 563.5 739. 580.5 3170. ]]\n", + "\n", + " [[ 860. 1152. 973. 3366. ]\n", + " [ 818. 1164. 954. 3552. ]\n", + " [ 810. 1058. 913. 3420. ]\n", + " ...\n", + " [ 698.5 834.5 658.5 3043. ]\n", + " [ 654. 813.5 673.5 2957. ]\n", + " [ 642. 798. 724. 3111. ]]]\n" + ] + } + ], + "source": [ + "import tensorflow as tf\n", + "from numpy.lib.recfunctions import structured_to_unstructured\n", + "unstruct = structured_to_unstructured(patch)\n", + "print(unstruct)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[[0.0644 0.0962 0.0828 0.3541 ]\n", + " [0.0615 0.0932 0.0781 0.342 ]\n", + " [0.0614 0.0865 0.0816 0.3196 ]\n", + " ...\n", + " [0.06035 0.08055 0.0795 0.30135]\n", + " [0.0582 0.0807 0.07705 0.30605]\n", + " [0.05745 0.0798 0.0781 0.3071 ]]\n", + "\n", + " [[0.0617 0.0927 0.0808 0.3608 ]\n", + " [0.059 0.0904 0.08 0.3384 ]\n", + " [0.06005 0.08705 0.0823 0.3148 ]\n", + " ...\n", + " [0.0625 0.0819 0.08345 0.29525]\n", + " [0.0598 0.0817 0.0773 0.29835]\n", + " [0.05965 0.0816 0.0782 0.2982 ]]\n", + "\n", + " [[0.0636 0.092 0.0824 0.3344 ]\n", + " [0.0622 0.0912 0.0808 0.334 ]\n", + " [0.0626 0.08785 0.0823 0.318 ]\n", + " ...\n", + " [0.05915 0.0785 0.079 0.2993 ]\n", + " [0.0571 0.0793 0.078 0.3107 ]\n", + " [0.0598 0.085 0.0854 0.2998 ]]\n", + "\n", + " ...\n", + "\n", + " [[0.0788 0.106 0.0904 0.3174 ]\n", + " [0.0779 0.1061 0.0904 0.319 ]\n", + " [0.0764 0.1126 0.0932 0.3516 ]\n", + " ...\n", + " [0.058 0.0754 0.0669 0.318 ]\n", + " [0.0594 0.0748 0.0672 0.3233 ]\n", + " [0.0546 0.0754 0.0722 0.328 ]]\n", + "\n", + " [[0.0807 0.11 0.0948 0.3096 ]\n", + " [0.07995 0.1106 0.0949 0.3285 ]\n", + " [0.0757 0.105 0.08815 0.3664 ]\n", + " ...\n", + " [0.0682 0.08255 0.0639 0.2925 ]\n", + " [0.06165 0.0778 0.0609 0.30115]\n", + " [0.05635 0.0739 0.05805 0.317 ]]\n", + "\n", + " [[0.086 0.1152 0.0973 0.3366 ]\n", + " [0.0818 0.1164 0.0954 0.3552 ]\n", + " [0.081 0.1058 0.0913 0.342 ]\n", + " ...\n", + " [0.06985 0.08345 0.06585 0.3043 ]\n", + " [0.0654 0.08135 0.06735 0.2957 ]\n", + " [0.0642 0.0798 0.0724 0.3111 ]]]\n" + ] + } + ], + "source": [ + "rescaled = unstruct.astype(np.float64) / 10000\n", + "print(rescaled)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[[0.0644 0.0617 0.0636 ... 0.0788 0.0807 0.086 ]\n", + " [0.0962 0.0927 0.092 ... 0.106 0.11 0.1152 ]\n", + " [0.0828 0.0808 0.0824 ... 0.0904 0.0948 0.0973 ]\n", + " [0.3541 0.3608 0.3344 ... 0.3174 0.3096 0.3366 ]]\n", + "\n", + " [[0.0615 0.059 0.0622 ... 0.0779 0.07995 0.0818 ]\n", + " [0.0932 0.0904 0.0912 ... 0.1061 0.1106 0.1164 ]\n", + " [0.0781 0.08 0.0808 ... 0.0904 0.0949 0.0954 ]\n", + " [0.342 0.3384 0.334 ... 0.319 0.3285 0.3552 ]]\n", + "\n", + " [[0.0614 0.06005 0.0626 ... 0.0764 0.0757 0.081 ]\n", + " [0.0865 0.08705 0.08785 ... 0.1126 0.105 0.1058 ]\n", + " [0.0816 0.0823 0.0823 ... 0.0932 0.08815 0.0913 ]\n", + " [0.3196 0.3148 0.318 ... 0.3516 0.3664 0.342 ]]\n", + "\n", + " ...\n", + "\n", + " [[0.06035 0.0625 0.05915 ... 0.058 0.0682 0.06985]\n", + " [0.08055 0.0819 0.0785 ... 0.0754 0.08255 0.08345]\n", + " [0.0795 0.08345 0.079 ... 0.0669 0.0639 0.06585]\n", + " [0.30135 0.29525 0.2993 ... 0.318 0.2925 0.3043 ]]\n", + "\n", + " [[0.0582 0.0598 0.0571 ... 0.0594 0.06165 0.0654 ]\n", + " [0.0807 0.0817 0.0793 ... 0.0748 0.0778 0.08135]\n", + " [0.07705 0.0773 0.078 ... 0.0672 0.0609 0.06735]\n", + " [0.30605 0.29835 0.3107 ... 0.3233 0.30115 0.2957 ]]\n", + "\n", + " [[0.05745 0.05965 0.0598 ... 0.0546 0.05635 0.0642 ]\n", + " [0.0798 0.0816 0.085 ... 0.0754 0.0739 0.0798 ]\n", + " [0.0781 0.0782 0.0854 ... 0.0722 0.05805 0.0724 ]\n", + " [0.3071 0.2982 0.2998 ... 0.328 0.317 0.3111 ]]]\n" + ] + } + ], + "source": [ + "transposed = np.transpose(rescaled, (1, 2, 0))\n", + "print(transposed)" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# and then not sure if we need this step, converting to a tf.tensor\n", + "# image_raw = tf.io.serialize_tensor(\n", + "# transposed.astype(np.float32)\n", + "# )\n", + "# print(image_raw)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[[[0.0644 0.0617 0.0636 0.0658 ]\n", + " [0.0651 0.0646 0.0643 0.06795]\n", + " [0.0716 0.07045 0.0649 0.0644 ]\n", + " ...\n", + " [0.3304 0.32775 0.3296 0.334 ]\n", + " [0.3416 0.3672 0.384 0.3524 ]\n", + " [0.326 0.3174 0.3096 0.3366 ]]\n", + "\n", + " [[0.0615 0.059 0.0622 0.0643 ]\n", + " [0.06375 0.0645 0.0642 0.0644 ]\n", + " [0.06665 0.0704 0.0644 0.0637 ]\n", + " ...\n", + " [0.322 0.33095 0.3356 0.3268 ]\n", + " [0.3417 0.3648 0.3728 0.3463 ]\n", + " [0.3348 0.319 0.3285 0.3552 ]]\n", + "\n", + " [[0.0614 0.06005 0.0626 0.06285]\n", + " [0.064 0.0626 0.0619 0.0626 ]\n", + " [0.0655 0.0696 0.06555 0.0647 ]\n", + " ...\n", + " [0.321 0.3272 0.3308 0.3284 ]\n", + " [0.3413 0.3564 0.3516 0.3428 ]\n", + " [0.3376 0.3516 0.3664 0.342 ]]\n", + "\n", + " ...\n", + "\n", + " [[0.06035 0.0625 0.05915 0.05825]\n", + " [0.0599 0.0612 0.0619 0.0632 ]\n", + " [0.0617 0.06035 0.0589 0.0582 ]\n", + " ...\n", + " [0.2951 0.3009 0.31485 0.3148 ]\n", + " [0.3068 0.3204 0.3357 0.3202 ]\n", + " [0.3277 0.318 0.2925 0.3043 ]]\n", + "\n", + " [[0.0582 0.0598 0.0571 0.0574 ]\n", + " [0.0606 0.0612 0.0621 0.0624 ]\n", + " [0.061 0.05905 0.05725 0.0606 ]\n", + " ...\n", + " [0.3031 0.3086 0.3176 0.3122 ]\n", + " [0.3204 0.3372 0.35 0.3344 ]\n", + " [0.3377 0.3233 0.30115 0.2957 ]]\n", + "\n", + " [[0.05745 0.05965 0.0598 0.0582 ]\n", + " [0.0596 0.0616 0.0624 0.0613 ]\n", + " [0.0602 0.0582 0.057 0.0606 ]\n", + " ...\n", + " [0.31725 0.3078 0.3176 0.3304 ]\n", + " [0.3344 0.3404 0.3416 0.3596 ]\n", + " [0.3548 0.328 0.317 0.3111 ]]]]\n" + ] + } + ], + "source": [ + "# how to get numpy array of 4 bands into correct shape for model prediction\n", + "reshaped = np.reshape(transposed, (1, 32, 32, 4))\n", + "print(reshaped)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model found: resnet\n", + "tf.Tensor([[0.42960128]], shape=(1, 1), dtype=float32)\n" + ] + } + ], + "source": [ + "# 30-epoch model with 86% binary accuracy\n", + "model_name = \"resnet\"\n", + "optimizer = \"adam\"\n", + "loss_function = \"binary_crossentropy\"\n", + "checkpoint = \"C:\\\\fao-models\\\\saved_models\\\\resnet-epochs30-batch64-lr001\\\\best_model.h5\"\n", + "\n", + "# load several model versions into memory..\n", + "model = get_model(model_name, optimizer=optimizer, loss_fn=loss_function)\n", + "model.load_weights(checkpoint)\n", + "# when mode.trainable was set to False, was getting error at model.load_weights(checkpoint): ValueError: axes don't match array\n", + "# https://stackoverflow.com/questions/51944836/keras-load-model-valueerror-axes-dont-match-array\n", + "model.trainable = True \n", + "\n", + "# print(model.summary())\n", + "prediction = model(reshaped)\n", + "print(prediction)" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model found: resnet\n", + "tf.Tensor([[0.47269127]], shape=(1, 1), dtype=float32)\n" + ] + } + ], + "source": [ + "# 15 epoch model with 98% binary accuracy \n", + "model_name = \"resnet\"\n", + "optimizer = \"adam\"\n", + "loss_function = \"binary_crossentropy\"\n", + "checkpoint = \"C:\\\\fao-models\\\\saved_models\\\\resnet-epochs5-batch64-lr001-seed5-lrdecay5\\\\best_model.h5\"\n", + "# load several model versions into memory..\n", + "model = get_model(model_name, optimizer=optimizer, loss_fn=loss_function)\n", + "model.load_weights(checkpoint)\n", + "# when mode.trainable was set to False, was getting error at model.load_weights(checkpoint): ValueError: axes don't match array\n", + "# https://stackoverflow.com/questions/51944836/keras-load-model-valueerror-axes-dont-match-array\n", + "model.trainable = True \n", + "\n", + "# print(model.summary())\n", + "prediction = model(reshaped)\n", + "print(prediction)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "gee", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}