Showing 2 changed files with 379 additions and 9 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,370 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# textgenrnn 1.3 Encoding Text\n", | ||
"by [Max Woolf](http://minimaxir.com)\n", | ||
"\n", | ||
"*Max's open-source projects are supported by his [Patreon](https://www.patreon.com/minimaxir). If you found this project helpful, any monetary contributions to the Patreon are appreciated and will be put to good creative use.*" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Intro\n", | ||
"\n", | ||
"textgenrnn can also be used to generate sentence vectors much more powerful than traditional word vectors.\n", | ||
"\n", | ||
"**IMPORTANT NOTE**: The sentence vectors only account for the first `max_length - 1` tokens. (in the pretrained model, that is the first **39 characters**). If you want more robust sentence vectors, train a new model with a very high `max_length` and/or use word-level training." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Using TensorFlow backend.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from textgenrnn import textgenrnn\n", | ||
"\n", | ||
"textgen = textgenrnn()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The function `encode_text_vectors` takes the Attention layer output of the model and can use PCA and TSNE to compress it into a more reasonable size.\n", | ||
"\n", | ||
"The size of the Attention layer is `dim_embeddings + (rnn_size * rnn_layers)`. In the case of the included pretrained model, the size is `100 + (128 * 2) = 356`.\n", | ||
"\n", | ||
"By default, `encode_text_vectors` uses PCA to project and calibrate this high-dimensional output to the number of provided texts, or 50D, whichever is lower." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[[ 4.2585015e+00 -3.2080102e+00 -8.6409906e-03 7.3456152e-07]\n", | ||
" [-3.3668094e+00 -7.7063727e-01 -7.0967728e-01 7.3456215e-07]\n", | ||
" [ 2.1060679e+00 4.5051994e+00 -3.9030526e-02 7.3456152e-07]\n", | ||
" [-2.9977567e+00 -5.2655154e-01 7.5734943e-01 7.3456158e-07]]\n", | ||
"(4, 4)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"texts = ['Never gonna give you up, never gonna let you down',\n", | ||
" 'Never gonna run around and desert you',\n", | ||
" 'Never gonna make you cry, never gonna say goodbye',\n", | ||
" 'Never gonna tell a lie and hurt you']\n", | ||
"\n", | ||
"word_vector = textgen.encode_text_vectors(texts)\n", | ||
"\n", | ||
"print(word_vector)\n", | ||
"print(word_vector.shape)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Additionally, you can pass `tsne_dims` to further project the texts into 2D or 3D; great for data visualization. (NB: t-SNE is a random-seeded algorithm; for consistent output, set `tsne_seed` to make the output deterministic)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[[115.05732 149.46983 ]\n", | ||
" [-35.552177 7.491257]\n", | ||
" [110.7458 3.17162 ]\n", | ||
" [-31.240635 153.78947 ]]\n", | ||
"(4, 2)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"word_vector = textgen.encode_text_vectors(texts, tsne_dims=2, tsne_seed=123)\n", | ||
"\n", | ||
"print(str(word_vector))\n", | ||
"print(word_vector.shape)" | ||
] | ||
}, | ||
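{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Since the 2D t-SNE output is intended for visualization, here is a minimal plotting sketch (not part of the original notebook; it assumes `matplotlib` is installed):" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import matplotlib.pyplot as plt\n", | ||
"\n", | ||
"# Scatter the 2D t-SNE projection and label each point with its text.\n", | ||
"plt.scatter(word_vector[:, 0], word_vector[:, 1])\n", | ||
"for text, (x, y) in zip(texts, word_vector):\n", | ||
"    plt.annotate(text[:20], (x, y))\n", | ||
"plt.show()" | ||
] | ||
}, | ||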
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"If you want to encode a single text, you'll have to set `pca_dims=None`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[[ 4.78423387e-01 -6.65371776e-01 -1.54521123e-01 \n", | ||
"(1, 356)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"word_vector = textgen.encode_text_vectors(\"What is love?\", pca_dims=None)\n", | ||
"\n", | ||
"print(str(word_vector)[0:50])\n", | ||
"print(word_vector.shape)" | ||
] | ||
}, | ||
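{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"As a quick sanity check of the `max_length` note from the intro (a hypothetical sketch, not part of the original notebook): if only the first 39 characters are encoded, two texts that share those characters should produce (near-)identical raw vectors." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"\n", | ||
"# Both texts share the same first 39 characters but diverge afterward.\n", | ||
"text_a = \"Never gonna give you up, never gonna let you down\"\n", | ||
"text_b = \"Never gonna give you up, never gonna let anyone tell you otherwise\"\n", | ||
"\n", | ||
"vec_a = textgen.encode_text_vectors(text_a, pca_dims=None)\n", | ||
"vec_b = textgen.encode_text_vectors(text_b, pca_dims=None)\n", | ||
"\n", | ||
"# If the truncation behaves as described, this should print True.\n", | ||
"print(np.allclose(vec_a, vec_b))" | ||
] | ||
}, | ||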
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"You can also have the model return the `pca` object, which can then be used to learn more about the projection, and/or used in an encoding pipeline to transform any arbitrary text." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"PCA(copy=True, iterated_power='auto', n_components=50, random_state=None,\n", | ||
" svd_solver='auto', tol=0.0, whiten=False)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"word_vector, pca = textgen.encode_text_vectors(texts, return_pca=True)\n", | ||
"\n", | ||
"print(pca)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"array([5.6863421e-01, 4.1706368e-01, 1.4302170e-02, 2.8613451e-14],\n", | ||
" dtype=float32)" | ||
] | ||
}, | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"pca.explained_variance_ratio_" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"In this case, 56.9% of the variance is explained by the 1st component, and 98.5% of the variance is explained by the first 2 components." | ||
] | ||
}, | ||
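{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"For reference, those cumulative figures can be recomputed directly from `explained_variance_ratio_` (a minimal sketch using NumPy, not part of the original notebook):" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"\n", | ||
"# Cumulative share of variance explained by the first k components.\n", | ||
"print(np.cumsum(pca.explained_variance_ratio_))" | ||
] | ||
}, | ||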
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[[ 0.7431461 -0.30268374 0.30652896 -0.27054822]]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"def transform_text(text, textgen, pca):\n", | ||
" text = textgen.encode_text_vectors(text, pca_dims=None)\n", | ||
" text = pca.transform(text)\n", | ||
" return text\n", | ||
"\n", | ||
"single_encoded_text = transform_text(\"Never gonna give\", textgen, pca)\n", | ||
"\n", | ||
"print(single_encoded_text)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Sentence Vector Similarity" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"For example you could calculate pairwise similarity..." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[[ 0.8607758 -0.78297377 0.04231021 -0.65008384]]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from sklearn.metrics.pairwise import cosine_similarity\n", | ||
"\n", | ||
"word_vectors = textgen.encode_text_vectors(texts)\n", | ||
"similarity = cosine_similarity(single_encoded_text, word_vectors)\n", | ||
"\n", | ||
"print(similarity)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"...or use textgenrnn's native similarity metrics!" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": { | ||
"scrolled": true | ||
}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[('Never gonna give you up, never gonna let you down', 0.8607758),\n", | ||
" ('Never gonna make you cry, never gonna say goodbye', 0.042310208),\n", | ||
" ('Never gonna tell a lie and hurt you', -0.65008384),\n", | ||
" ('Never gonna run around and desert you', -0.78297377)]" | ||
] | ||
}, | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"textgen.similarity(\"Never gonna give\", texts)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"By default similarity is calculated using the PCA-transformed values, but you can calculate similarity on the raw values as well if needed." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[('Never gonna tell a lie and hurt you', 0.18147705),\n", | ||
" ('Never gonna run around and desert you', 0.17993625),\n", | ||
" ('Never gonna give you up, never gonna let you down', 0.17391011),\n", | ||
" ('Never gonna make you cry, never gonna say goodbye', 0.053340655)]" | ||
] | ||
}, | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"textgen.similarity(\"Never gonna give\", texts, use_pca=False)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# LICENSE\n", | ||
"\n", | ||
"MIT License\n", | ||
"\n", | ||
"Copyright (c) 2018 Max Woolf\n", | ||
"\n", | ||
"Permission is hereby granted, free of charge, to any person obtaining a copy\n", | ||
"of this software and associated documentation files (the \"Software\"), to deal\n", | ||
"in the Software without restriction, including without limitation the rights\n", | ||
"to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", | ||
"copies of the Software, and to permit persons to whom the Software is\n", | ||
"furnished to do so, subject to the following conditions:\n", | ||
"\n", | ||
"The above copyright notice and this permission notice shall be included in all\n", | ||
"copies or substantial portions of the Software.\n", | ||
"\n", | ||
"THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", | ||
"IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", | ||
"FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", | ||
"AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", | ||
"LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", | ||
"OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n", | ||
"SOFTWARE." | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |