[skip test] Add Example for VisionEncoderDecoder
Showing 1 changed file with 265 additions and 0 deletions.
examples/python/annotation/image/VisionEncoderDecoderForImageCaptioning.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![JohnSnowLabs](https://sparknlp.org/assets/images/logo.png)\n",
"\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/image/VisionEncoderDecoderForImageCaptioning.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! wget -q http://setup.johnsnowlabs.com/colab.sh -O - | bash"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## VisionEncoderDecoderForImageCaptioning Annotator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook we generate captions for images using Spark NLP. The annotator uses the Vision Transformer (ViT) to encode the images and then GPT2 to generate caption tokens. This model is rather heavy, so make sure you have enough RAM and, if possible, use an accelerator such as a GPU."
]
},
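The encoder-decoder generation loop described above can be sketched in plain Python. This is a toy for illustration only, not the actual ViT/GPT2 implementation: the "encoder", "decoder", and vocabulary below are all invented stand-ins.

```python
# Toy sketch of encoder-decoder caption generation with greedy decoding.
# encode_image stands in for ViT; decode_step stands in for the GPT2 head.

def encode_image(pixels):
    # A real ViT produces a sequence of patch embeddings; here we just
    # summarize the pixels into a single feature value.
    return sum(pixels) / len(pixels)

def decode_step(feature, generated):
    # A real GPT2 head returns logits over a vocabulary and the next token
    # is sampled or argmaxed; this toy decoder emits a fixed caption.
    caption = ["a", "small", "bird", "<eos>"]
    return caption[len(generated)]

def generate_caption(pixels, max_len=16):
    # Encode once, then decode token by token until the end-of-sequence
    # marker or the length limit is reached.
    feature = encode_image(pixels)
    tokens = []
    while len(tokens) < max_len:
        token = decode_step(feature, tokens)
        if token == "<eos>":
            break
        tokens.append(token)
    return " ".join(tokens)

print(generate_caption([0.1, 0.5, 0.9]))  # a small bird
```

The real annotator runs this loop inside Spark, batching images per partition; the structure (encode once, decode autoregressively) is the same.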
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Downloading Images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!wget -q https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/images/images.zip"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import shutil\n",
"shutil.unpack_archive(\"images.zip\", \"images\", \"zip\")"
]
},
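After unpacking, it can help to sanity-check which image files were extracted. A small standard-library sketch (demonstrated on a throwaway directory rather than the real `images/` folder, so it is self-contained):

```python
import pathlib
import tempfile

# The sample set mixes .JPEG and .jpg, so match suffixes case-insensitively.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}

def list_images(folder):
    # Return the sorted file names in `folder` that look like images.
    root = pathlib.Path(folder)
    return sorted(p.name for p in root.iterdir()
                  if p.suffix.lower() in IMAGE_SUFFIXES)

# Demonstrate on a temporary directory with dummy files.
with tempfile.TemporaryDirectory() as tmp:
    for name in ["hen.JPEG", "bluetick.jpg", "notes.txt"]:
        (pathlib.Path(tmp) / name).touch()
    print(list_images(tmp))  # ['bluetick.jpg', 'hen.JPEG']
```

In the notebook you would call `list_images("images/images")` after the unzip step to confirm the ten sample images are present.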
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Start Spark Session"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sparknlp\n",
"from sparknlp.base import *\n",
"from sparknlp.annotator import *\n",
"from pyspark.sql import SparkSession\n",
"\n",
"spark = sparknlp.start()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data_df = spark.read.format(\"image\").option(\"dropInvalid\", value = True).load(path=\"images/images/\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pipeline with VisionEncoderDecoderForImageCaptioning"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image_assembler = ImageAssembler() \\\n",
"    .setInputCol(\"image\") \\\n",
"    .setOutputCol(\"image_assembler\")\n",
"\n",
"image_captioning = VisionEncoderDecoderForImageCaptioning \\\n",
"    .pretrained() \\\n",
"    .setInputCols([\"image_assembler\"]) \\\n",
"    .setOutputCol(\"caption\")\n",
"\n",
"pipeline = Pipeline(stages=[\n",
"    image_assembler,\n",
"    image_captioning,\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------------+---------------------------------------------------------+\n",
"|image_name       |result                                                   |\n",
"+-----------------+---------------------------------------------------------+\n",
"|palace.JPEG      |[a large room filled with furniture and a large window]  |\n",
"|egyptian_cat.jpeg|[a cat laying on a couch next to another cat]            |\n",
"|hippopotamus.JPEG|[a brown bear in a body of water]                        |\n",
"|hen.JPEG         |[a flock of chickens standing next to each other]        |\n",
"|ostrich.JPEG     |[a large bird standing on top of a lush green field]     |\n",
"|junco.JPEG       |[a small bird standing on a wet ground]                  |\n",
"|bluetick.jpg     |[a small dog standing on a wooden floor]                 |\n",
"|chihuahua.jpg    |[a small brown dog wearing a blue sweater]               |\n",
"|tractor.JPEG     |[a man is standing in a field with a tractor]            |\n",
"|ox.JPEG          |[a large brown cow standing on top of a lush green field]|\n",
"+-----------------+---------------------------------------------------------+\n",
"\n"
]
}
],
"source": [
"model = pipeline.fit(data_df)\n",
"image_df = model.transform(data_df)\n",
"image_df \\\n",
"    .selectExpr(\"reverse(split(image.origin, '/'))[0] as image_name\", \"caption.result\") \\\n",
"    .show(truncate = False)"
]
},
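The `selectExpr` above strips the directory from `image.origin` to get a bare file name for display. The same transformation in plain Python, as a sketch of what that SQL expression does (not part of the pipeline itself):

```python
def image_name(origin):
    # Equivalent of reverse(split(image.origin, '/'))[0] in the selectExpr:
    # keep only the last path segment after the final '/'.
    return origin.rsplit("/", 1)[-1]

print(image_name("file:/content/images/images/ox.JPEG"))  # ox.JPEG
```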
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Light Pipeline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use the annotator in a light pipeline, we use the method `fullAnnotateImage`, which accepts two kinds of input:\n",
"1. A path to a single image\n",
"2. A list of paths to images"
]
},
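A rough sketch of how a method like `fullAnnotateImage` can accept either call style is to normalize its input to a list first. The helper below is hypothetical, written only to illustrate the pattern; it is not Spark NLP's actual implementation:

```python
def normalize_inputs(target):
    # Accept either a single path or a list of paths and always return a
    # list, mirroring the two call styles fullAnnotateImage supports.
    if isinstance(target, str):
        return [target]
    return list(target)

print(normalize_inputs("images/images/hen.JPEG"))
print(normalize_inputs(["a.jpg", "b.jpg"]))
```

Either way, the result is one entry per image, which is why `annotations_result` is indexed with `[0]` in the cells below.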
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['image_assembler', 'caption'])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"light_pipeline = LightPipeline(model)\n",
"annotations_result = light_pipeline.fullAnnotateImage(\"images/images/hippopotamus.JPEG\")\n",
"annotations_result[0].keys()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To process several images at once, pass a list of paths instead of a single path."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['image_assembler', 'caption'])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"images = [\"images/images/bluetick.jpg\", \"images/images/palace.JPEG\", \"images/images/hen.JPEG\"]\n",
"annotations_result = light_pipeline.fullAnnotateImage(images)\n",
"annotations_result[0].keys()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Annotation(document, 0, 37, a small dog standing on a wooden floor, Map(nChannels -> 3, image -> 0, height -> 500, origin -> images/images/bluetick.jpg, mode -> 16, width -> 333), [])]\n",
"[Annotation(document, 0, 52, a large room filled with furniture and a large window, Map(nChannels -> 3, image -> 0, height -> 334, origin -> images/images/palace.JPEG, mode -> 16, width -> 500), [])]\n",
"[Annotation(document, 0, 46, a flock of chickens standing next to each other, Map(nChannels -> 3, image -> 0, height -> 375, origin -> images/images/hen.JPEG, mode -> 16, width -> 500), [])]\n"
]
}
],
"source": [
"for result in annotations_result:\n",
"    print(result['caption'])"
]
}
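Each entry in `annotations_result` maps output column names to annotations, and the caption text lives in the annotation's `result` field, as the printed `Annotation(...)` lines above show. A plain-Python sketch of pulling the caption strings out, using dicts as stand-ins for the Annotation objects:

```python
# Mocked LightPipeline output: one dict per image, keyed by output column,
# with the generated text in each annotation's "result" field.
annotations_result = [
    {"caption": [{"result": "a small dog standing on a wooden floor"}]},
    {"caption": [{"result": "a flock of chickens standing next to each other"}]},
]

def extract_captions(results):
    # Flatten the per-image annotation lists into one list of caption strings.
    return [ann["result"]
            for row in results
            for ann in row["caption"]]

print(extract_captions(annotations_result))
```

With the real objects you would read `annotation.result` as an attribute instead of a dict key; the shape of the traversal is the same.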
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}