diff --git a/load-keyword-extraction.ipynb b/load-keyword-extraction.ipynb
new file mode 100644
index 00000000..ff473315
--- /dev/null
+++ b/load-keyword-extraction.ipynb
@@ -0,0 +1,730 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Keyword Extraction"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This tutorial is available as an IPython notebook at [Malaya/example/keyword-extraction](https://github.com/huseinzol05/Malaya/tree/master/example/keyword-extraction)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import malaya"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# https://www.bharian.com.my/berita/nasional/2020/06/698386/isu-bersatu-tun-m-6-yang-lain-saman-muhyiddin\n",
+ "\n",
+ "string = \"\"\"\n",
+ "Dalam saman itu, plaintif memohon perisytiharan, antaranya mereka adalah ahli BERSATU yang sah, masih lagi memegang jawatan dalam parti (bagi pemegang jawatan) dan layak untuk bertanding pada pemilihan parti.\n",
+ "\n",
+ "Mereka memohon perisytiharan bahawa semua surat pemberhentian yang ditandatangani Muhammad Suhaimi bertarikh 28 Mei lalu dan pengesahan melalui mesyuarat Majlis Pimpinan Tertinggi (MPT) parti bertarikh 4 Jun lalu adalah tidak sah dan terbatal.\n",
+ "\n",
+ "Plaintif juga memohon perisytiharan bahawa keahlian Muhyiddin, Hamzah dan Muhammad Suhaimi di dalam BERSATU adalah terlucut, berkuat kuasa pada 28 Februari 2020 dan/atau 29 Februari 2020, menurut Fasal 10.2.3 perlembagaan parti.\n",
+ "\n",
+ "Yang turut dipohon, perisytiharan bahawa Seksyen 18C Akta Pertubuhan 1966 adalah tidak terpakai untuk menghalang pelupusan pertikaian berkenaan oleh mahkamah.\n",
+ "\n",
+ "Perisytiharan lain ialah Fasal 10.2.6 Perlembagaan BERSATU tidak terpakai di atas hal melucutkan/ memberhentikan keahlian semua plaintif.\n",
+ "\"\"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import re\n",
+ "\n",
+ "# Minimal cleaning: replace newlines, drop every character that is not a letter,\n",
+ "# hyphen, parenthesis or space, then collapse repeated spaces.\n",
+ "def cleaning(string):\n",
+ "    string = string.replace('\\n', ' ')\n",
+ "    string = re.sub(r'[^A-Za-z\\-() ]+', ' ', string).strip()\n",
+ "    string = re.sub(r'[ ]+', ' ', string).strip()\n",
+ "    return string\n",
+ "\n",
+ "string = cleaning(string)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Use RAKE algorithm\n",
+ "\n",
+ "Original implementation from [https://github.com/aneesha/RAKE](https://github.com/aneesha/RAKE). Malaya adds an attention mechanism to the RAKE algorithm.\n",
+ "\n",
+ "```python\n",
+ "def rake(\n",
+ " string: str,\n",
+ " model = None,\n",
+ " top_k: int = 5,\n",
+ " auto_ngram: bool = True,\n",
+ " ngram_method: str = 'bow',\n",
+ " ngram: Tuple[int, int] = (1, 1),\n",
+ " atleast: int = 1,\n",
+ " stop_words: List[str] = STOPWORDS,\n",
+ " **kwargs\n",
+ "):\n",
+ " \"\"\"\n",
+ " Extract keywords using Rake algorithm.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " string: str\n",
+ " model: Object, optional (default='None')\n",
+ " Transformer model or any model has `attention` method.\n",
+ " top_k: int, optional (default=5)\n",
+ " return top-k results.\n",
+ " auto_ngram: bool, optional (default=True)\n",
+ " If True, will generate keyword candidates using N suitable ngram. Else use `ngram_method`.\n",
+ " ngram_method: str, optional (default='bow')\n",
+ " Only usable if `auto_ngram` is False. supported ngram generator:\n",
+ "\n",
+ " * ``'bow'`` - bag-of-word.\n",
+ " * ``'skipgram'`` - bag-of-word with skip technique.\n",
+ " ngram: tuple, optional (default=(1,1))\n",
+ " n-grams size.\n",
+ " atleast: int, optional (default=1)\n",
+ " at least count appeared in the string to accept as candidate.\n",
+ " stop_words: list, (default=malaya.text.function.STOPWORDS)\n",
+ " list of stop words to remove. \n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " result: Tuple[float, str]\n",
+ " \"\"\"\n",
+ "```"
+ ]
+ },
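+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To make the scoring idea concrete, here is a minimal sketch of classic RAKE in plain Python. It is not Malaya's implementation: the tiny `STOP` set and the `rake_sketch` helper are hypothetical names used only for illustration, and the attention-based scoring is left out. Candidate phrases are runs of words between stop words, each word scores `degree / frequency`, and a phrase scores the sum of its word scores.\n",
+ "\n",
+ "```python\n",
+ "import re\n",
+ "from collections import defaultdict\n",
+ "\n",
+ "# Tiny illustrative stop-word list (an assumption, not Malaya's STOPWORDS).\n",
+ "STOP = {'dan', 'yang', 'di', 'itu', 'adalah', 'untuk', 'pada', 'dalam'}\n",
+ "\n",
+ "def rake_sketch(text, stop_words=STOP, top_k=5):\n",
+ "    words = re.findall(r'[A-Za-z]+', text.lower())\n",
+ "    # Candidate phrases are maximal runs of non-stop-words.\n",
+ "    phrases, current = [], []\n",
+ "    for word in words:\n",
+ "        if word in stop_words:\n",
+ "            if current:\n",
+ "                phrases.append(current)\n",
+ "            current = []\n",
+ "        else:\n",
+ "            current.append(word)\n",
+ "    if current:\n",
+ "        phrases.append(current)\n",
+ "    # Word score = degree / frequency; degree counts the words sharing a phrase.\n",
+ "    freq, degree = defaultdict(int), defaultdict(int)\n",
+ "    for phrase in phrases:\n",
+ "        for word in phrase:\n",
+ "            freq[word] += 1\n",
+ "            degree[word] += len(phrase)\n",
+ "    # Phrase score = sum of the scores of its words.\n",
+ "    scored = {}\n",
+ "    for phrase in phrases:\n",
+ "        scored[' '.join(phrase)] = sum(degree[w] / freq[w] for w in phrase)\n",
+ "    return sorted(((s, p) for p, s in scored.items()), reverse=True)[:top_k]\n",
+ "\n",
+ "result = rake_sketch('mesyuarat majlis pimpinan dan parti itu parti')\n",
+ "```\n",
+ "\n",
+ "Like `malaya.keyword_extraction.rake`, the sketch returns `(score, phrase)` tuples sorted from highest to lowest score."
+ ]
+ },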
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### auto-ngram\n",
+ "\n",
+ "This automatically generates n-grams of a suitable size as keyword candidates."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(0.11666666666666665, 'ditandatangani Muhammad Suhaimi bertarikh Mei'),\n",
+ " (0.08888888888888888, 'mesyuarat Majlis Pimpinan Tertinggi'),\n",
+ " (0.08888888888888888, 'Seksyen C Akta Pertubuhan'),\n",
+ " (0.05138888888888888, 'parti bertarikh Jun'),\n",
+ " (0.04999999999999999, 'keahlian Muhyiddin Hamzah')]"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.keyword_extraction.rake(string)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### auto-ngram with Attention\n",
+ "\n",
+ "This uses the attention mechanism of a transformer model for the scores. I will use `small-electra` in this example."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/transformers/electra/__init__.py:56: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/transformers/electra/modeling.py:240: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "Use keras.layers.Dense instead.\n",
+ "WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/tensorflow_core/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "Please use `layer.__call__` method instead.\n",
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/transformers/electra/__init__.py:79: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/transformers/electra/__init__.py:93: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/transformers/sampling.py:26: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/transformers/electra/__init__.py:115: multinomial (from tensorflow.python.ops.random_ops) is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "Use `tf.random.categorical` instead.\n",
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/transformers/electra/__init__.py:118: The name tf.InteractiveSession is deprecated. Please use tf.compat.v1.InteractiveSession instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/transformers/electra/__init__.py:119: The name tf.global_variables_initializer is deprecated. Please use tf.compat.v1.global_variables_initializer instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/transformers/electra/__init__.py:121: The name tf.get_collection is deprecated. Please use tf.compat.v1.get_collection instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/transformers/electra/__init__.py:122: The name tf.GraphKeys is deprecated. Please use tf.compat.v1.GraphKeys instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/transformers/electra/__init__.py:128: The name tf.train.Saver is deprecated. Please use tf.compat.v1.train.Saver instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/transformers/electra/__init__.py:130: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n",
+ "\n",
+ "INFO:tensorflow:Restoring parameters from /Users/huseinzolkepli/Malaya/electra-model/small/electra-small/model.ckpt\n"
+ ]
+ }
+ ],
+ "source": [
+ "electra = malaya.transformer.load(model = 'small-electra')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(0.2113546236771915, 'ditandatangani Muhammad Suhaimi bertarikh Mei'),\n",
+ " (0.1707678455680971, 'terlucut berkuat kuasa'),\n",
+ " (0.16650756665229807, 'Muhammad Suhaimi'),\n",
+ " (0.1620429894692799, 'mesyuarat Majlis Pimpinan Tertinggi'),\n",
+ " (0.08333952583953884, 'Seksyen C Akta Pertubuhan')]"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.keyword_extraction.rake(string, model = electra)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### fixed-ngram"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(0.0010991603139160087, 'parti memohon perisytiharan'),\n",
+ " (0.0010989640254270869, 'memohon perisytiharan Muhammad'),\n",
+ " (0.0010985209375133323, 'perisytiharan Muhammad Suhaimi'),\n",
+ " (0.0010972572356757605, 'memohon perisytiharan BERSATU'),\n",
+ " (0.0010970435210070695, 'memohon perisytiharan sah')]"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.keyword_extraction.rake(string, auto_ngram = False, ngram = (1, 3), \n",
+ " ngram_method = 'skipgram', skip = 3)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### fixed-ngram with Attention"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(0.007511555412415397, 'Suhaimi terlucut kuasa'),\n",
+ " (0.00726812348703141, 'Suhaimi terlucut Februari'),\n",
+ " (0.00725420955956774, 'Suhaimi terlucut berkuat'),\n",
+ " (0.007235384019369932, 'Muhyiddin Suhaimi terlucut'),\n",
+ " (0.00721164037502389, 'Hamzah Suhaimi terlucut')]"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.keyword_extraction.rake(string, model = electra, auto_ngram = False, ngram = (1, 3), \n",
+ " ngram_method = 'skipgram', skip = 3)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Use Textrank algorithm\n",
+ "\n",
+ "Malaya simply runs the TextRank algorithm on top of the networkx library.\n",
+ "\n",
+ "```python\n",
+ "def textrank(\n",
+ " string: str,\n",
+ " vectorizer,\n",
+ " top_k: int = 5,\n",
+ " auto_ngram: bool = True,\n",
+ " ngram_method: str = 'bow',\n",
+ " ngram: Tuple[int, int] = (1, 1),\n",
+ " atleast: int = 1,\n",
+ " stop_words: List[str] = STOPWORDS,\n",
+ " **kwargs\n",
+ "):\n",
+ " \"\"\"\n",
+ " Extract keywords using Textrank algorithm.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " string: str\n",
+ " vectorizer: Object, optional (default='None')\n",
+ " model has `fit_transform` or `vectorize` method.\n",
+ " top_k: int, optional (default=5)\n",
+ " return top-k results.\n",
+ " auto_ngram: bool, optional (default=True)\n",
+ " If True, will generate keyword candidates using N suitable ngram. Else use `ngram_method`.\n",
+ " ngram_method: str, optional (default='bow')\n",
+ " Only usable if `auto_ngram` is False. supported ngram generator:\n",
+ "\n",
+ " * ``'bow'`` - bag-of-word.\n",
+ " * ``'skipgram'`` - bag-of-word with skip technique.\n",
+ " ngram: tuple, optional (default=(1,1))\n",
+ " n-grams size.\n",
+ " atleast: int, optional (default=1)\n",
+ " at least count appeared in the string to accept as candidate.\n",
+ " stop_words: list, (default=malaya.text.function.STOPWORDS)\n",
+ " list of stop words to remove. \n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " result: Tuple[float, str]\n",
+ " \"\"\"\n",
+ "```"
+ ]
+ },
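+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a rough illustration of the ranking step, the sketch below runs a hand-written PageRank power iteration over a word co-occurrence graph. This is an assumption-laden stand-in, not Malaya's code: Malaya relies on networkx and takes a `vectorizer`, so its graph is presumably built differently, and `textrank_sketch`, `window` and `damping` are hypothetical names chosen for this example.\n",
+ "\n",
+ "```python\n",
+ "from collections import defaultdict\n",
+ "\n",
+ "def textrank_sketch(tokens, window=2, damping=0.85, iterations=50, top_k=5):\n",
+ "    # Undirected co-occurrence graph: tokens within `window` positions are linked.\n",
+ "    neighbours = defaultdict(set)\n",
+ "    for i, token in enumerate(tokens):\n",
+ "        for other in tokens[i + 1:i + 1 + window]:\n",
+ "            if other != token:\n",
+ "                neighbours[token].add(other)\n",
+ "                neighbours[other].add(token)\n",
+ "    nodes = sorted(neighbours)\n",
+ "    if not nodes:\n",
+ "        return []\n",
+ "    # Start from a uniform score and repeatedly redistribute it along the edges.\n",
+ "    score = {node: 1.0 / len(nodes) for node in nodes}\n",
+ "    for _ in range(iterations):\n",
+ "        score = {\n",
+ "            node: (1 - damping) / len(nodes)\n",
+ "            + damping * sum(score[n] / len(neighbours[n]) for n in neighbours[node])\n",
+ "            for node in nodes\n",
+ "        }\n",
+ "    return sorted(((s, n) for n, s in score.items()), reverse=True)[:top_k]\n",
+ "\n",
+ "tokens = ['plaintif', 'memohon', 'perisytiharan', 'memohon', 'parti']\n",
+ "ranked = textrank_sketch(tokens, window=1)\n",
+ "```\n",
+ "\n",
+ "Here `memohon` is linked to every other word, so it receives the highest score; the scores of all nodes sum to 1."
+ ]
+ },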
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.feature_extraction.text import TfidfVectorizer\n",
+ "tfidf = TfidfVectorizer()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### auto-ngram with TFIDF\n",
+ "\n",
+ "This automatically generates n-grams of a suitable size as keyword candidates."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(0.00015733542115111895, 'plaintif memohon perisytiharan'),\n",
+ " (0.00012558589872969095, 'Fasal perlembagaan parti'),\n",
+ " (0.00011512878779574369, 'Fasal Perlembagaan BERSATU'),\n",
+ " (0.00011505807280697136, 'parti'),\n",
+ " (0.00010763518916902933, 'memohon perisytiharan')]"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.keyword_extraction.textrank(string, vectorizer = tfidf)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### auto-ngram with Attention\n",
+ "\n",
+ "This automatically generates n-grams of a suitable size as keyword candidates, using a transformer model as the vectorizer."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "INFO:tensorflow:Restoring parameters from /Users/huseinzolkepli/Malaya/electra-model/small/electra-small/model.ckpt\n",
+ "WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/albert/tokenization.py:240: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
+ "\n",
+ "INFO:tensorflow:loading sentence piece model\n",
+ "WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/albert/modeling.py:116: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/albert/modeling.py:588: The name tf.assert_less_equal is deprecated. Please use tf.compat.v1.assert_less_equal instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/albert/modeling.py:1025: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
+ "\n",
+ "INFO:tensorflow:Restoring parameters from /Users/huseinzolkepli/Malaya/albert-model/base/albert-base/model.ckpt\n"
+ ]
+ }
+ ],
+ "source": [
+ "electra = malaya.transformer.load(model = 'small-electra')\n",
+ "albert = malaya.transformer.load(model = 'albert')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(6.3182663025223e-05, 'dipohon perisytiharan'),\n",
+ " (6.31674674645778e-05, 'pemegang jawatan'),\n",
+ " (6.316119389302752e-05, 'parti bertarikh Jun'),\n",
+ " (6.316104723812124e-05, 'Februari'),\n",
+ " (6.315819355276039e-05, 'plaintif')]"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.keyword_extraction.textrank(string, vectorizer = electra)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(7.94645241452814e-05, 'Fasal Perlembagaan BERSATU'),\n",
+ " (7.728400390215039e-05, 'mesyuarat Majlis Pimpinan Tertinggi'),\n",
+ " (7.506390584039057e-05, 'Muhammad Suhaimi'),\n",
+ " (7.503252483650059e-05, 'pengesahan'),\n",
+ " (7.502407753712274e-05, 'terbatal Plaintif')]"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.keyword_extraction.textrank(string, vectorizer = albert)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Or you can use any classification model to find keywords that are sensitive to a specific domain**."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:root:Load quantized model will cause accuracy drop.\n"
+ ]
+ }
+ ],
+ "source": [
+ "sentiment = malaya.sentiment.transformer(model = 'xlnet', quantized = True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(6.621038516352373e-05, 'plaintif memohon perisytiharan'),\n",
+ " (6.61143060050603e-05, 'ditandatangani Muhammad Suhaimi bertarikh Mei'),\n",
+ " (6.517221024654814e-05, 'terbatal Plaintif'),\n",
+ " (6.469109066728589e-05, 'terlucut berkuat kuasa'),\n",
+ " (6.450719772460985e-05, 'pengesahan')]"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.keyword_extraction.textrank(string, vectorizer = sentiment)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### fixed-ngram with Attention"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(1.7071539462023998e-09, 'perisytiharan ahli sah'),\n",
+ " (1.7071528386679705e-09, 'Fasal parti perisytiharan'),\n",
+ " (1.7071498274826471e-09, 'Plaintif perisytiharan keahlian'),\n",
+ " (1.7071355361007092e-09, 'Fasal dipohon perisytiharan'),\n",
+ " (1.707130673312775e-09, 'plaintif perisytiharan')]"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.keyword_extraction.textrank(string, vectorizer = electra, auto_ngram = False,\n",
+ " ngram = (1, 3), ngram_method = 'skipgram', skip = 3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(2.1995491577326747e-09, 'Perisytiharan Fasal melucutkan'),\n",
+ " (2.1990164283127147e-09, 'Pimpinan Tertinggi (MPT)'),\n",
+ " (2.1981574699825158e-09, 'Majlis Pimpinan (MPT)'),\n",
+ " (2.1980610020130363e-09, 'Perisytiharan Fasal BERSATU'),\n",
+ " (2.1973393621296214e-09, 'Perisytiharan Perlembagaan')]"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.keyword_extraction.textrank(string, vectorizer = albert, auto_ngram = False,\n",
+ " ngram = (1, 3), ngram_method = 'skipgram', skip = 3)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load Attention mechanism\n",
+ "\n",
+ "Use the attention mechanism of a transformer model to find important keywords."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### auto-ngram\n",
+ "\n",
+ "This automatically generates n-grams of a suitable size as keyword candidates."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(0.9452064568002397, 'menghalang pelupusan pertikaian'),\n",
+ " (0.007486688404188947, 'Fasal Perlembagaan BERSATU'),\n",
+ " (0.005130747276971111, 'ahli BERSATU'),\n",
+ " (0.005036595631722718, 'melucutkan memberhentikan keahlian'),\n",
+ " (0.004883706288857347, 'BERSATU')]"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.keyword_extraction.attention(string, model = electra)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(0.16196368022187793, 'plaintif memohon perisytiharan'),\n",
+ " (0.09294065744319371, 'memohon perisytiharan'),\n",
+ " (0.06902302277868422, 'plaintif'),\n",
+ " (0.05584840295920779, 'ditandatangani Muhammad Suhaimi bertarikh Mei'),\n",
+ " (0.05206225590337424, 'dipohon perisytiharan')]"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.keyword_extraction.attention(string, model = albert)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### fixed-ngram"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(0.15667043125587973, 'pelupusan pertikaian mahkamah'),\n",
+ " (0.15665311872357476, 'pertikaian mahkamah Perisytiharan'),\n",
+ " (0.15657934237804905, 'pertikaian mahkamah'),\n",
+ " (0.1563242367855659, 'menghalang pelupusan pertikaian'),\n",
+ " (0.1562270516451705, 'pelupusan pertikaian')]"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.keyword_extraction.attention(string, model = electra, auto_ngram = False,\n",
+ " ngram = (1, 3), ngram_method = 'bow')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(0.031264380566934015, 'saman plaintif memohon'),\n",
+ " (0.02621530292963218, 'plaintif memohon perisytiharan'),\n",
+ " (0.02573609954868083, 'Dalam saman plaintif'),\n",
+ " (0.022935623722179672, 'plaintif memohon'),\n",
+ " (0.019724791761830188, 'Mereka memohon perisytiharan')]"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.keyword_extraction.attention(string, model = albert, auto_ngram = False,\n",
+ " ngram = (1, 3), ngram_method = 'bow')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/load-similarity.ipynb b/load-similarity.ipynb
new file mode 100644
index 00000000..0e6a9587
--- /dev/null
+++ b/load-similarity.ipynb
@@ -0,0 +1,872 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Text Similarity"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This tutorial is available as an IPython notebook at [Malaya/example/similarity](https://github.com/huseinzol05/Malaya/tree/master/example/similarity)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This module is trained on both standard and local (including social media) language structures, so it is safe to use for both."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 5.55 s, sys: 1.09 s, total: 6.64 s\n",
+ "Wall time: 7.7 s\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "import malaya"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "string1 = 'Pemuda mogok lapar desak kerajaan prihatin isu iklim'\n",
+ "string2 = 'Perbincangan isu pembalakan perlu babit kerajaan negeri'\n",
+ "string3 = 'kerajaan perlu kisah isu iklim, pemuda mogok lapar'\n",
+ "string4 = 'Kerajaan dicadang tubuh jawatankuasa khas tangani isu alam sekitar'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "news1 = 'Tun Dr Mahathir Mohamad mengakui pembubaran Parlimen bagi membolehkan pilihan raya diadakan tidak sesuai dilaksanakan pada masa ini berikutan isu COVID-19'\n",
+ "tweet1 = 'DrM sembang pilihan raya tak boleh buat sebab COVID 19'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Calculate similarity using doc2vec\n",
+ "\n",
+ "We can pass any word vector interface provided by Malaya to the doc2vec similarity interface.\n",
+ "\n",
+ "Important parameters:\n",
+ "1. `aggregation`, aggregation function to accumulate word vectors. Default is `mean`.\n",
+ "\n",
+ " * ``'mean'`` - mean.\n",
+ " * ``'min'`` - min.\n",
+ " * ``'max'`` - max.\n",
+ " * ``'sum'`` - sum.\n",
+ " * ``'sqrt'`` - square root.\n",
+ " \n",
+ "2. `similarity`, distance function used to calculate similarity. Default is `cosine`.\n",
+ "\n",
+ " * ``'cosine'`` - cosine similarity.\n",
+ " * ``'euclidean'`` - euclidean similarity.\n",
+ " * ``'manhattan'`` - manhattan similarity."
+ ]
+ },
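+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The two parameters above can be pictured with a small plain-Python sketch: word vectors are first aggregated into one document vector, then two document vectors are compared. The 2-d `VECTORS` table is made up, and `aggregate` and `cosine` are hypothetical helpers, not Malaya functions; the `'sqrt'` aggregation is omitted because its exact definition is not given here.\n",
+ "\n",
+ "```python\n",
+ "import math\n",
+ "\n",
+ "# Toy 2-d word vectors (made-up numbers, purely for illustration).\n",
+ "VECTORS = {\n",
+ "    'kerajaan': [1.0, 0.0],\n",
+ "    'isu': [0.0, 1.0],\n",
+ "    'iklim': [1.0, 1.0],\n",
+ "}\n",
+ "\n",
+ "def aggregate(words, method='mean'):\n",
+ "    # Collect each vector dimension across the known words, then reduce it.\n",
+ "    columns = list(zip(*(VECTORS[w] for w in words if w in VECTORS)))\n",
+ "    if method == 'mean':\n",
+ "        return [sum(c) / len(c) for c in columns]\n",
+ "    if method == 'sum':\n",
+ "        return [sum(c) for c in columns]\n",
+ "    if method == 'min':\n",
+ "        return [min(c) for c in columns]\n",
+ "    if method == 'max':\n",
+ "        return [max(c) for c in columns]\n",
+ "    raise ValueError('unsupported aggregation: ' + method)\n",
+ "\n",
+ "def cosine(u, v):\n",
+ "    # Cosine similarity: dot product divided by the product of the norms.\n",
+ "    dot = sum(a * b for a, b in zip(u, v))\n",
+ "    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))\n",
+ "\n",
+ "left = aggregate(['kerajaan', 'isu'])\n",
+ "right = aggregate(['isu', 'iklim'])\n",
+ "similarity = cosine(left, right)\n",
+ "```\n",
+ "\n",
+ "Switching `method` changes the document vectors, and therefore the similarity, even though the underlying word vectors stay the same."
+ ]
+ },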
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Using word2vec\n",
+ "\n",
+ "I will use `load_news` here because loading the word2vec trained on Wikipedia takes a very long time, although the Wikipedia vectors are much more accurate."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "downloading frozen /Users/huseinzolkepli/Malaya/wordvector/news vocab\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "4.00MB [00:01, 2.03MB/s] \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "downloading frozen /Users/huseinzolkepli/Malaya/wordvector/news model\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "191MB [01:01, 3.13MB/s] \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/wordvector.py:114: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/wordvector.py:125: The name tf.InteractiveSession is deprecated. Please use tf.compat.v1.InteractiveSession instead.\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "vocab_news, embedded_news = malaya.wordvector.load_news()\n",
+ "w2v = malaya.wordvector.load(embedded_news, vocab_news)\n",
+ "doc2vec = malaya.similarity.doc2vec(w2v)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### predict for 2 strings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([0.899711], dtype=float32)"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "doc2vec.predict_proba([string1], [string2], aggregation = 'mean', soft = False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### predict batch of strings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([0.9215344, 0.853461 ], dtype=float32)"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "doc2vec.predict_proba([string1, string2], [string3, string4])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### visualize heatmap"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "doc2vec.heatmap([string1, string2, string3, string4])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Different similarity functions give different percentages."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Calculate similarity using deep encoder\n",
+ "\n",
+ "We can use any encoder model provided by Malaya with the encoder similarity interface, for example BERT, XLNET and skip-thought. These encoder models are not trained for similarity classification; they only encode the strings into vector representations.\n",
+ "\n",
+ "Important parameters:\n",
+ "\n",
+ "1. `similarity`, distance function used to calculate similarity. Default is `cosine`.\n",
+ "\n",
+ " * ``'cosine'`` - cosine similarity.\n",
+ " * ``'euclidean'`` - euclidean similarity.\n",
+ " * ``'manhattan'`` - manhattan similarity."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### using xlnet"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "INFO:tensorflow:memory input None\n",
+ "INFO:tensorflow:Use float type \n",
+ "INFO:tensorflow:Restoring parameters from /Users/huseinzolkepli/Malaya/xlnet-model/base/xlnet-base/model.ckpt\n"
+ ]
+ }
+ ],
+ "source": [
+ "xlnet = malaya.transformer.load(model = 'xlnet')\n",
+ "encoder = malaya.similarity.encoder(xlnet)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### predict for 2 strings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([0.8212017], dtype=float32)"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "encoder.predict_proba([string1], [string2])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### predict batch of strings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([0.8097714 , 0.78071797, 0.8244793 , 0.5807183 ], dtype=float32)"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "encoder.predict_proba([string1, string2, news1, news1], [string3, string4, tweet1, string1])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### visualize heatmap"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "encoder.heatmap([string1, string2, string3, string4])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### List available Transformer models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:root:tested on 20% test set.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
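+ "<div>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ "  <thead>\n",
+ "    <tr style=\"text-align: right;\">\n",
+ "      <th></th>\n",
+ "      <th>Size (MB)</th>\n",
+ "      <th>Quantized Size (MB)</th>\n",
+ "      <th>Accuracy</th>\n",
+ "    </tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr>\n",
+ "      <th>bert</th>\n",
+ "      <td>423.4</td>\n",
+ "      <td>111.0</td>\n",
+ "      <td>0.885</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>tiny-bert</th>\n",
+ "      <td>56.6</td>\n",
+ "      <td>15.0</td>\n",
+ "      <td>0.873</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>albert</th>\n",
+ "      <td>48.3</td>\n",
+ "      <td>12.8</td>\n",
+ "      <td>0.873</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>tiny-albert</th>\n",
+ "      <td>21.9</td>\n",
+ "      <td>6.0</td>\n",
+ "      <td>0.824</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>xlnet</th>\n",
+ "      <td>448.7</td>\n",
+ "      <td>119.0</td>\n",
+ "      <td>0.784</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>alxlnet</th>\n",
+ "      <td>49.0</td>\n",
+ "      <td>13.9</td>\n",
+ "      <td>0.888</td>\n",
+ "    </tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"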
+ ],
+ "text/plain": [
+ " Size (MB) Quantized Size (MB) Accuracy\n",
+ "bert 423.4 111.0 0.885\n",
+ "tiny-bert 56.6 15.0 0.873\n",
+ "albert 48.3 12.8 0.873\n",
+ "tiny-albert 21.9 6.0 0.824\n",
+ "xlnet 448.7 119.0 0.784\n",
+ "alxlnet 49.0 13.9 0.888"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.similarity.available_transformer()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We trained on [Quora Question Pairs](https://github.com/huseinzol05/Malay-Dataset#quora), [translated SNLI](https://github.com/huseinzol05/Malay-Dataset#snli) and [translated MNLI](https://github.com/huseinzol05/Malay-Dataset#mnli).\n",
+ "\n",
+ "Check the accuracy chart before selecting a model: https://malaya.readthedocs.io/en/latest/Accuracy.html#similarity\n",
+ "\n",
+ "**You might want to use ALXLNET: it is very small, 49 MB, yet its accuracy is still top-notch.**"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load transformer model\n",
+ "\n",
+ "In this example I load `alxlnet`; feel free to use any of the available models listed above."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = malaya.similarity.transformer(model = 'alxlnet')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load Quantized model\n",
+ "\n",
+ "To load an 8-bit quantized model, pass `quantized = True`; the default is `False`.\n",
+ "\n",
+ "Expect a slight accuracy drop from the quantized model. It is not necessarily faster than the normal 32-bit float model; that depends entirely on the machine."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:root:Load quantized model will cause accuracy drop.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:74: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:74: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:76: The name tf.GraphDef is deprecated. Please use tf.compat.v1.GraphDef instead.\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:76: The name tf.GraphDef is deprecated. Please use tf.compat.v1.GraphDef instead.\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:69: The name tf.InteractiveSession is deprecated. Please use tf.compat.v1.InteractiveSession instead.\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:69: The name tf.InteractiveSession is deprecated. Please use tf.compat.v1.InteractiveSession instead.\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "quantized_model = malaya.similarity.transformer(model = 'alxlnet', quantized = True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### predict batch\n",
+ "\n",
+ "```python\n",
+ "def predict_proba(self, strings_left: List[str], strings_right: List[str]):\n",
+ " \"\"\"\n",
+ " calculate similarity for two different batches of texts.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " strings_left : List[str]\n",
+ " strings_right : List[str]\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " result : List[float]\n",
+ " \"\"\"\n",
+ "```\n",
+ "\n",
+ "You need to give a list of left strings and a list of right strings.\n",
+ "\n",
+ "The first left string is compared with the first right string, and so on.\n",
+ "\n",
+ "The similarity model only supports `predict_proba`.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([0.9986665 , 0.04221377, 0.7916767 , 0.98151684], dtype=float32)"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.predict_proba([string1, string2, news1, news1], [string3, string4, tweet1, string1])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([0.99855036, 0.06619915, 0.29902616, 0.98125756], dtype=float32)"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "quantized_model.predict_proba([string1, string2, news1, news1], [string3, string4, tweet1, string1])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### visualize heatmap"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "model.heatmap([string1, string2, string3, string4])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Vectorize\n",
+ "\n",
+ "If you want to visualize sentences in a lower dimension, you can use `model.vectorize`:\n",
+ "\n",
+ "```python\n",
+ "def vectorize(self, strings: List[str]):\n",
+ " \"\"\"\n",
+ " Vectorize list of strings.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " strings : List[str]\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " result: np.array\n",
+ " \"\"\"\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "texts = [string1, string2, string3, string4, news1, tweet1]\n",
+ "r = quantized_model.vectorize(texts)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(6, 2)"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from sklearn.manifold import TSNE\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "tsne = TSNE().fit_transform(r)\n",
+ "tsne.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "plt.figure(figsize = (7, 7))\n",
+ "plt.scatter(tsne[:, 0], tsne[:, 1])\n",
+ "labels = texts\n",
+ "for label, x, y in zip(\n",
+ " labels, tsne[:, 0], tsne[:, 1]\n",
+ "):\n",
+ " label = (\n",
+ " '%s, %.3f' % (label[0], label[1])\n",
+ " if isinstance(label, list)\n",
+ " else label\n",
+ " )\n",
+ " plt.annotate(\n",
+ " label,\n",
+ " xy = (x, y),\n",
+ " xytext = (0, 0),\n",
+ " textcoords = 'offset points',\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Stacking models\n",
+ "\n",
+ "For more information, see https://malaya.readthedocs.io/en/latest/Stack.html\n",
+ "\n",
+ "If you want to stack similarity models, you need to pass `strings_right` as a keyword argument:\n",
+ "\n",
+ "```python\n",
+ "malaya.stack.predict_stack([model1, model2], List[str], strings_right = List[str])\n",
+ "```\n",
+ "\n",
+ "`strings_right` is passed on as `**kwargs`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:54: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:55: The name tf.GraphDef is deprecated. Please use tf.compat.v1.GraphDef instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:49: The name tf.InteractiveSession is deprecated. Please use tf.compat.v1.InteractiveSession instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/albert/tokenization.py:240: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
+ "\n",
+ "INFO:tensorflow:loading sentence piece model\n"
+ ]
+ }
+ ],
+ "source": [
+ "alxlnet = malaya.similarity.transformer(model = 'alxlnet')\n",
+ "albert = malaya.similarity.transformer(model = 'albert')\n",
+ "tiny_bert = malaya.similarity.transformer(model = 'tiny-bert')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([0.99745977, 0.07261255, 0.16457608, 0.03985301], dtype=float32)"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.stack.predict_stack([alxlnet, albert, tiny_bert], [string1, string2, news1, news1], \n",
+ " strings_right = [string3, string4, tweet1, string1])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/load-zeroshot-classification.ipynb b/load-zeroshot-classification.ipynb
new file mode 100644
index 00000000..36e1c444
--- /dev/null
+++ b/load-zeroshot-classification.ipynb
@@ -0,0 +1,895 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Classification"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This tutorial is available as an IPython notebook at [Malaya/example/zeroshot-classification](https://github.com/huseinzol05/Malaya/tree/master/example/zeroshot-classification)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This module was trained on both standard and local (including social media) language structures, so it is safe to use for both."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 5.61 s, sys: 1.08 s, total: 6.69 s\n",
+ "Wall time: 6.75 s\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "import malaya"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### what is zero-shot classification\n",
+ "\n",
+ "Commonly we train a machine learning model in a supervised way on specific labels: negative / positive for sentiment, anger / happy / sadness for emotion, and so on. Such a model cannot tell us how much 'jealous' a text expresses, because the only supported labels are {anger, happy, sadness}; it cannot identify a label it has never seen during training. **Zero-shot classification tries to solve this problem.**\n",
+ "\n",
+ "Zero-shot learning refers to the process by which a machine learns to recognize objects (images, text, any features) without any labeled training data for the target classes to help in the classification.\n",
+ "\n",
+ "[Yin et al. (2019)](https://arxiv.org/abs/1909.00161) stated in their paper that any pretrained language model finetuned on text similarity can act as an out-of-the-box zero-shot text classifier."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "So, we are going to use the transformer models from `malaya.similarity.transformer` with a few tweaks."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### List available Transformer models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
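+ "<div>\n",
+ "<table>\n",
+ "  <thead>\n",
+ "    <tr>\n",
+ "      <th></th>\n",
+ "      <th>Size (MB)</th>\n",
+ "      <th>Quantized Size (MB)</th>\n",
+ "      <th>Accuracy</th>\n",
+ "    </tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr>\n",
+ "      <th>bert</th>\n",
+ "      <td>423.4</td>\n",
+ "      <td>111.0</td>\n",
+ "      <td>0.885</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>tiny-bert</th>\n",
+ "      <td>56.6</td>\n",
+ "      <td>15.0</td>\n",
+ "      <td>0.873</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>albert</th>\n",
+ "      <td>48.3</td>\n",
+ "      <td>12.8</td>\n",
+ "      <td>0.873</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>tiny-albert</th>\n",
+ "      <td>21.9</td>\n",
+ "      <td>6.0</td>\n",
+ "      <td>0.824</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>xlnet</th>\n",
+ "      <td>448.7</td>\n",
+ "      <td>119.0</td>\n",
+ "      <td>0.784</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>alxlnet</th>\n",
+ "      <td>49.0</td>\n",
+ "      <td>13.9</td>\n",
+ "      <td>0.888</td>\n",
+ "    </tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"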
+ ],
+ "text/plain": [
+ " Size (MB) Quantized Size (MB) Accuracy\n",
+ "bert 423.4 111.0 0.885\n",
+ "tiny-bert 56.6 15.0 0.873\n",
+ "albert 48.3 12.8 0.873\n",
+ "tiny-albert 21.9 6.0 0.824\n",
+ "xlnet 448.7 119.0 0.784\n",
+ "alxlnet 49.0 13.9 0.888"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "malaya.zero_shot.classification.available_transformer()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The models were trained on [Quora Question Pairs](https://github.com/huseinzol05/Malay-Dataset#quora), [translated SNLI](https://github.com/huseinzol05/Malay-Dataset#snli) and [translated MNLI](https://github.com/huseinzol05/Malay-Dataset#mnli).\n",
+ "\n",
+ "Check the accuracy chart before selecting a model: https://malaya.readthedocs.io/en/latest/Accuracy.html#similarity\n",
+ "\n",
+ "**You might want to use ALXLNET: it is very small (49 MB), yet its accuracy (0.888) is the highest in the table above.**"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load transformer model\n",
+ "\n",
+ "In this example I load `alxlnet`; feel free to use any of the available models listed above."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:54: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:55: The name tf.GraphDef is deprecated. Please use tf.compat.v1.GraphDef instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:49: The name tf.InteractiveSession is deprecated. Please use tf.compat.v1.InteractiveSession instead.\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = malaya.zero_shot.classification.transformer(model = 'alxlnet')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load Quantized model\n",
+ "\n",
+ "To load an 8-bit quantized model, pass `quantized = True`; the default is `False`.\n",
+ "\n",
+ "Expect a slight accuracy drop from the quantized model. It is not necessarily faster than the normal 32-bit float model; that depends entirely on the machine."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:root:Load quantized model will cause accuracy drop.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:74: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:74: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:76: The name tf.GraphDef is deprecated. Please use tf.compat.v1.GraphDef instead.\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:76: The name tf.GraphDef is deprecated. Please use tf.compat.v1.GraphDef instead.\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:69: The name tf.InteractiveSession is deprecated. Please use tf.compat.v1.InteractiveSession instead.\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/malaya/function/__init__.py:69: The name tf.InteractiveSession is deprecated. Please use tf.compat.v1.InteractiveSession instead.\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "quantized_model = malaya.zero_shot.classification.transformer(model = 'alxlnet', quantized = True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### predict batch\n",
+ "\n",
+ "```python\n",
+ "def predict_proba(self, strings: List[str], labels: List[str]):\n",
+ " \"\"\"\n",
+ " classify list of strings and return probability.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " strings : List[str]\n",
+ " labels : List[str]\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " list: list of float\n",
+ " \"\"\"\n",
+ "```\n",
+ "\n",
+ "Because this is zero-shot classification, we need to give the labels to the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# copied from Twitter\n",
+ "\n",
+ "string = 'gov macam bengong, kami nk pilihan raya, gov backdoor, sakai'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[{'najib razak': 0.011697772,\n",
+ " 'mahathir': 0.030579083,\n",
+ " 'kerajaan': 0.038274202,\n",
+ " 'PRU': 0.74709743,\n",
+ " 'anarki': 0.054001417}]"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.predict_proba([string], labels = ['najib razak', 'mahathir', 'kerajaan', 'PRU', 'anarki'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[{'najib razak': 0.020772826,\n",
+ " 'mahathir': 0.03612631,\n",
+ " 'kerajaan': 0.091763854,\n",
+ " 'PRU': 0.34365898,\n",
+ " 'anarki': 0.007840766}]"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "quantized_model.predict_proba([string], labels = ['najib razak', 'mahathir', 'kerajaan', 'PRU', 'anarki'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Quite good: both models rank `PRU` highest, although the quantized model is less confident (0.34 versus 0.75)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "string = 'tolong order foodpanda jab, lapar'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[{'makan': 0.4262973,\n",
+ " 'makanan': 0.94525576,\n",
+ " 'novel': 0.0016873145,\n",
+ " 'buku': 0.00282516,\n",
+ " 'kerajaan': 0.0013985565,\n",
+ " 'food delivery': 0.9190869}]"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The model understood that `order foodpanda` is closely related to `makan`, `makanan` and `food delivery`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "string = 'kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[{'makan': 0.0010322841,\n",
+ " 'makanan': 0.0059771817,\n",
+ " 'novel': 0.0068290858,\n",
+ " 'buku': 0.00083946186,\n",
+ " 'kerajaan': 0.9823078,\n",
+ " 'food delivery': 0.017137317,\n",
+ " 'kerajaan jahat': 0.4863779,\n",
+ " 'kerajaan prihatin': 0.96803045,\n",
+ " 'bantuan rakyat': 0.94919217}]"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery',\n",
+ " 'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Vectorize\n",
+ "\n",
+ "If you want to visualize sentence-level or word-level vectors in a lower dimension, you can use `model.vectorize`:\n",
+ "\n",
+ "```python\n",
+ "def vectorize(\n",
+ " self, strings: List[str], labels: List[str], method: str = 'first'\n",
+ "):\n",
+ " \"\"\"\n",
+ " vectorize a string.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " strings: List[str]\n",
+ " labels : List[str]\n",
+ " method : str, optional (default='first')\n",
+ " Vectorization layer supported. Allowed values:\n",
+ "\n",
+ " * ``'last'`` - vector from last sequence.\n",
+ " * ``'first'`` - vector from first sequence.\n",
+ " * ``'mean'`` - average vectors from all sequences.\n",
+ " * ``'word'`` - average vectors based on tokens.\n",
+ "\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " result: np.array\n",
+ " \"\"\"\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Sentence level"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "texts = ['kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan',\n",
+ " 'gov macam bengong, kami nk pilihan raya, gov backdoor, sakai',\n",
+ " 'tolong order foodpanda jab, lapar',\n",
+ " 'Hapuskan vernacular school first, only then we can talk about UiTM']\n",
+ "labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery',\n",
+ " 'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat']\n",
+ "r = quantized_model.vectorize(texts, labels, method = 'first')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
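+ "The `vectorize` method of the zero-shot classification model returns two values, `(combined, vector)`: `combined` is the list of `(string, label)` pairs and `vector` holds the corresponding vectors."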
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[('kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan',\n",
+ " 'makan'),\n",
+ " ('kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan',\n",
+ " 'makanan'),\n",
+ " ('kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan',\n",
+ " 'novel'),\n",
+ " ('kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan',\n",
+ " 'buku'),\n",
+ " ('kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan',\n",
+ " 'kerajaan')]"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "r[0][:5]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[-0.00587193, -0.7214614 , -0.7524409 , ..., 0.31107777,\n",
+ " 1.022762 , 0.28308758],\n",
+ " [ 0.63863456, 0.12698255, 0.67567766, ..., 0.7627216 ,\n",
+ " 0.56795114, -0.37056473],\n",
+ " [-0.90291303, 0.93581504, 0.05650915, ..., 0.5578094 ,\n",
+ " 1.1304276 , 0.5470246 ],\n",
+ " ...,\n",
+ " [-2.1161728 , -1.4592253 , 0.5284856 , ..., 0.28636536,\n",
+ " -0.36558965, -0.8226106 ],\n",
+ " [-2.2050292 , -0.14624506, 0.19812807, ..., 0.1307496 ,\n",
+ " -0.20792441, 0.18430969],\n",
+ " [-2.5969799 , 0.4205628 , 0.18376699, ..., 0.124988 ,\n",
+ " -0.9915105 , -0.10085672]], dtype=float32)"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "r[1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(36, 2)"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from sklearn.manifold import TSNE\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "tsne = TSNE().fit_transform(r[1])\n",
+ "tsne.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "unique_labels = list(set([i[1] for i in r[0]]))\n",
+ "palette = plt.cm.get_cmap('hsv', len(unique_labels))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "plt.figure(figsize = (7, 7))\n",
+ "\n",
+ "for label in unique_labels:\n",
+ " indices = [i for i in range(len(r[0])) if r[0][i][1] == label]\n",
+ " plt.scatter(tsne[indices, 0], tsne[indices, 1], color = palette(unique_labels.index(label)),\n",
+ " label = label)\n",
+ " \n",
+ "labels = [i[0] for i in r[0]]\n",
+ "for label, x, y in zip(\n",
+ " labels, tsne[:, 0], tsne[:, 1]\n",
+ "):\n",
+ " label = (\n",
+ " '%s, %.3f' % (label[0], label[1])\n",
+ " if isinstance(label, list)\n",
+ " else label\n",
+ " )\n",
+ " plt.annotate(\n",
+ " label,\n",
+ " xy = (x, y),\n",
+ " xytext = (0, 0),\n",
+ " textcoords = 'offset points',\n",
+ " )\n",
+ "plt.legend()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Word level"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "texts = ['kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan',\n",
+ " 'gov macam bengong, kami nk pilihan raya, gov backdoor, sakai',\n",
+ " 'tolong order foodpanda jab, lapar',\n",
+ " 'Hapuskan vernacular school first, only then we can talk about UiTM']\n",
+ "labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery',\n",
+ " 'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat']\n",
+ "r = quantized_model.vectorize(texts, labels, method = 'word')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x, y, labels = [], [], []\n",
+ "for no, row in enumerate(r[1]):\n",
+ " x.extend([i[0] for i in row])\n",
+ " y.extend([i[1] for i in row])\n",
+ " labels.extend([r[0][no][1]] * len(row))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(315, 2)"
+ ]
+ },
+ "execution_count": 30,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tsne = TSNE().fit_transform(y)\n",
+ "tsne.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "unique_labels = list(set(labels))\n",
+ "palette = plt.cm.get_cmap('hsv', len(unique_labels))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "plt.figure(figsize = (7, 7))\n",
+ "\n",
+ "for label in unique_labels:\n",
+ " indices = [i for i in range(len(labels)) if labels[i] == label]\n",
+ " plt.scatter(tsne[indices, 0], tsne[indices, 1], color = palette(unique_labels.index(label)),\n",
+ " label = label)\n",
+ " \n",
+ "labels = x\n",
+ "for label, x, y in zip(\n",
+ " labels, tsne[:, 0], tsne[:, 1]\n",
+ "):\n",
+ " label = (\n",
+ " '%s, %.3f' % (label[0], label[1])\n",
+ " if isinstance(label, list)\n",
+ " else label\n",
+ " )\n",
+ " plt.annotate(\n",
+ " label,\n",
+ " xy = (x, y),\n",
+ " xytext = (0, 0),\n",
+ " textcoords = 'offset points',\n",
+ " )\n",
+ "plt.legend()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Stacking models\n",
+ "\n",
+ "For more information, see https://malaya.readthedocs.io/en/latest/Stack.html\n",
+ "\n",
+ "If you want to stack zero-shot classification models, you need to pass `labels` as a keyword argument:\n",
+ "\n",
+ "```python\n",
+ "malaya.stack.predict_stack([model1, model2], List[str], labels = List[str])\n",
+ "```\n",
+ "\n",
+ "`labels` is passed on as `**kwargs`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/albert/tokenization.py:240: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
+ "\n",
+ "INFO:tensorflow:loading sentence piece model\n"
+ ]
+ }
+ ],
+ "source": [
+ "alxlnet = malaya.zero_shot.classification.transformer(model = 'alxlnet')\n",
+ "albert = malaya.zero_shot.classification.transformer(model = 'albert')\n",
+ "tiny_bert = malaya.zero_shot.classification.transformer(model = 'tiny-bert')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[{'makan': 0.0044827852,\n",
+ " 'makanan': 0.0027062024,\n",
+ " 'novel': 0.0020867025,\n",
+ " 'buku': 0.013082165,\n",
+ " 'kerajaan': 0.8859287,\n",
+ " 'food delivery': 0.0028363755,\n",
+ " 'kerajaan jahat': 0.018133936,\n",
+ " 'kerajaan prihatin': 0.9922408,\n",
+ " 'bantuan rakyat': 0.909674}]"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "string = 'kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan'\n",
+ "labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery', \n",
+ " 'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat']\n",
+ "malaya.stack.predict_stack([alxlnet, albert, tiny_bert], [string], \n",
+ " labels = labels)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/malaya/__init__.py b/malaya/__init__.py
index 87cc37dc..2a8eaee3 100644
--- a/malaya/__init__.py
+++ b/malaya/__init__.py
@@ -13,8 +13,8 @@
home = os.path.join(str(Path.home()), 'Malaya')
-version = '3.9'
-bump_version = '3.9.2'
+version = '4.0'
+bump_version = '4.0'
version_path = os.path.join(home, 'version')
__version__ = bump_version
path = os.path.dirname(__file__)
diff --git a/malaya/model/bert.py b/malaya/model/bert.py
index f325a209..5dd7740f 100644
--- a/malaya/model/bert.py
+++ b/malaya/model/bert.py
@@ -729,16 +729,16 @@ def __init__(
segment_ids = segment_ids,
input_masks = input_masks,
logits = logits,
+ vectorizer = vectorizer,
sess = sess,
tokenizer = tokenizer,
label = label,
)
- self._vectorizer = vectorizer
self._softmax = tf.nn.softmax(self._logits)
self._batch_size = 20
def _base(self, strings_left, strings_right):
- input_ids, input_masks, segment_ids = bert_tokenization_siamese(
+ input_ids, input_masks, segment_ids, _ = bert_tokenization_siamese(
self._tokenizer, strings_left, strings_right
)
@@ -1066,6 +1066,7 @@ def __init__(
segment_ids,
input_masks,
logits,
+ vectorizer,
sess,
tokenizer,
label = ['not similar', 'similar'],
@@ -1076,6 +1077,7 @@ def __init__(
segment_ids = segment_ids,
input_masks = input_masks,
logits = logits,
+ vectorizer = vectorizer,
sess = sess,
tokenizer = tokenizer,
label = label,
@@ -1092,7 +1094,7 @@ def _base(self, strings, labels):
mapping[no].append(index)
index += 1
- input_ids, input_masks, segment_ids = bert_tokenization_siamese(
+ input_ids, input_masks, segment_ids, _ = bert_tokenization_siamese(
self._tokenizer, strings_left, strings_right
)
@@ -1113,6 +1115,69 @@ def _base(self, strings, labels):
results.append(result)
return results
+ @check_type
+ def vectorize(
+ self, strings: List[str], labels: List[str], method: str = 'first'
+ ):
+ """
+ Vectorize every (string, label) pair.
+
+ Parameters
+ ----------
+ strings : List[str]
+ labels : List[str]
+ method : str, optional (default='first')
+ Pooling method over the sequence axis. Allowed values:
+
+ * ``'last'`` - vector at the last sequence position.
+ * ``'first'`` - vector at the first sequence position.
+ * ``'mean'`` - average of the vectors over all sequence positions.
+ * ``'word'`` - vectors averaged per word from its subword tokens.
+
+
+ Returns
+ -------
+ result: (combined, vectors) - list of (string, label) pairs and their vectors.
+ """
+ strings_left, strings_right, combined = [], [], []
+ for no, string in enumerate(strings):
+ for label in labels:
+ strings_left.append(string)
+ strings_right.append(f'teks ini adalah mengenai {label}')
+ combined.append((string, label))
+
+ input_ids, input_masks, segment_ids, s_tokens = bert_tokenization_siamese(
+ self._tokenizer, strings_left, strings_right
+ )
+
+ v = self._sess.run(
+ self._vectorizer,
+ feed_dict = {
+ self._X: input_ids,
+ self._segment_ids: segment_ids,
+ self._input_masks: input_masks,
+ },
+ )
+ if len(v.shape) == 2:
+ v = v.reshape((*np.array(input_ids).shape, -1))
+
+ if method == 'first':
+ v = v[:, 0]
+ elif method == 'last':
+ v = v[:, -1]
+ elif method == 'mean':
+ v = np.mean(v, axis = 1)
+ else:
+ v = [
+ merge_sentencepiece_tokens(
+ list(zip(s_tokens[i], v[i][: len(s_tokens[i])])),
+ weighted = False,
+ vectorize = True,
+ )
+ for i in range(len(v))
+ ]
+ return combined, v
+
@check_type
def predict_proba(self, strings: List[str], labels: List[str]):
"""
diff --git a/malaya/model/xlnet.py b/malaya/model/xlnet.py
index d7762d08..1a443787 100644
--- a/malaya/model/xlnet.py
+++ b/malaya/model/xlnet.py
@@ -752,24 +752,22 @@ def __init__(
X = X,
segment_ids = segment_ids,
input_masks = input_masks,
+ vectorizer = vectorizer,
logits = logits,
sess = sess,
tokenizer = tokenizer,
label = label,
)
- self._vectorizer = vectorizer
self._softmax = tf.nn.softmax(self._logits)
self._batch_size = 20
def _base(self, strings_left, strings_right):
- input_ids, input_masks, segment_ids = xlnet_tokenization_siamese(
+ input_ids, input_masks, segment_ids, _ = xlnet_tokenization_siamese(
self._tokenizer, strings_left, strings_right
)
- segment_ids = np.array(segment_ids)
- batch_segment[batch_segment == 0] = 1
return self._sess.run(
- self._vectorizer,
+ self._softmax,
feed_dict = {
self._X: input_ids,
self._segment_ids: segment_ids,
@@ -793,6 +791,16 @@ def vectorize(self, strings: List[str]):
input_ids, input_masks, segment_ids, _ = xlnet_tokenization(
self._tokenizer, strings
)
+ segment_ids = np.array(segment_ids)
+ segment_ids[segment_ids == 0] = 1
+ return self._sess.run(
+ self._vectorizer,
+ feed_dict = {
+ self._X: input_ids,
+ self._segment_ids: segment_ids,
+ self._input_masks: input_masks,
+ },
+ )
@check_type
def predict_proba(self, strings_left: List[str], strings_right: List[str]):
@@ -1111,6 +1119,7 @@ def __init__(
segment_ids,
input_masks,
logits,
+ vectorizer,
sess,
tokenizer,
label = ['not similar', 'similar'],
@@ -1121,6 +1130,7 @@ def __init__(
segment_ids = segment_ids,
input_masks = input_masks,
logits = logits,
+ vectorizer = vectorizer,
sess = sess,
tokenizer = tokenizer,
label = label,
@@ -1138,7 +1148,7 @@ def _base(self, strings, labels):
mapping[no].append(index)
index += 1
- input_ids, input_masks, segment_ids = xlnet_tokenization_siamese(
+ input_ids, input_masks, segment_ids, _ = xlnet_tokenization_siamese(
self._tokenizer, strings_left, strings_right
)
@@ -1159,6 +1169,70 @@ def _base(self, strings, labels):
results.append(result)
return results
+ @check_type
+ def vectorize(
+ self, strings: List[str], labels: List[str], method: str = 'first'
+ ):
+ """
+ Vectorize every (string, label) pair.
+
+ Parameters
+ ----------
+ strings : List[str]
+ labels : List[str]
+ method : str, optional (default='first')
+ Pooling method over the sequence axis. Allowed values:
+
+ * ``'last'`` - vector at the last sequence position.
+ * ``'first'`` - vector at the first sequence position.
+ * ``'mean'`` - average of the vectors over all sequence positions.
+ * ``'word'`` - vectors averaged per word from its subword tokens.
+
+
+ Returns
+ -------
+ result: (combined, vectors) - list of (string, label) pairs and their vectors.
+ """
+
+ strings_left, strings_right, combined = [], [], []
+ for no, string in enumerate(strings):
+ for label in labels:
+ strings_left.append(string)
+ strings_right.append(f'teks ini adalah mengenai {label}')
+ combined.append((string, label))
+
+ input_ids, input_masks, segment_ids, s_tokens = xlnet_tokenization_siamese(
+ self._tokenizer, strings_left, strings_right
+ )
+
+ v = self._sess.run(
+ self._vectorizer,
+ feed_dict = {
+ self._X: input_ids,
+ self._segment_ids: segment_ids,
+ self._input_masks: input_masks,
+ },
+ )
+ v = np.transpose(v, [1, 0, 2])
+
+ if method == 'first':
+ v = v[:, 0]
+ elif method == 'last':
+ v = v[:, -1]
+ elif method == 'mean':
+ v = np.mean(v, axis = 1)
+ else:
+ v = [
+ merge_sentencepiece_tokens(
+ list(zip(s_tokens[i], v[i][: len(s_tokens[i])])),
+ weighted = False,
+ vectorize = True,
+ model = 'xlnet',
+ )
+ for i in range(len(v))
+ ]
+ return combined, v
+
@check_type
def predict_proba(self, strings: List[str], labels: List[str]):
"""
diff --git a/malaya/path/__init__.py b/malaya/path/__init__.py
index a3fc7016..7de4c010 100644
--- a/malaya/path/__init__.py
+++ b/malaya/path/__init__.py
@@ -274,7 +274,7 @@
},
'alxlnet': {
'model': 'v34/emotion/alxlnet-base-emotion.pb',
- 'quantized': 'v34/emotion/alxlnet-base-emotion.pb.quantized',
+ 'quantized': 'v40/emotion/alxlnet-base-emotion.pb.quantized',
'vocab': 'tokenizer/sp10m.cased.v9.vocab',
'tokenizer': 'tokenizer/sp10m.cased.v9.model',
},
@@ -857,36 +857,42 @@
PATH_SIMILARITY = {
'bert': {
'model': home + '/similarity/bert/base/model.pb',
+ 'quantized': home + '/similarity/bert/base/quantized/model.pb',
'vocab': home + '/bert/sp10m.cased.bert.vocab',
'tokenizer': home + '/bert/sp10m.cased.bert.model',
'version': 'v36',
},
'tiny-bert': {
'model': home + '/similarity/bert/tiny/model.pb',
+ 'quantized': home + '/similarity/bert/tiny/quantized/model.pb',
'vocab': home + '/bert/sp10m.cased.bert.vocab',
'tokenizer': home + '/bert/sp10m.cased.bert.model',
'version': 'v36',
},
'albert': {
'model': home + '/similarity/albert/base/model.pb',
+ 'quantized': home + '/similarity/albert/base/quantized/model.pb',
'vocab': home + '/albert/sp10m.cased.v10.vocab',
'tokenizer': home + '/albert/sp10m.cased.v10.model',
'version': 'v36',
},
'tiny-albert': {
'model': home + '/similarity/albert/tiny/model.pb',
+ 'quantized': home + '/similarity/albert/tiny/quantized/model.pb',
'vocab': home + '/bert/sp10m.cased.bert.vocab',
'tokenizer': home + '/bert/sp10m.cased.bert.model',
'version': 'v36',
},
'xlnet': {
'model': home + '/similarity/xlnet/base/model.pb',
+ 'quantized': home + '/similarity/xlnet/base/quantized/model.pb',
'vocab': home + '/xlnet/sp10m.cased.v9.vocab',
'tokenizer': home + '/xlnet/sp10m.cased.v9.model',
'version': 'v36',
},
'alxlnet': {
'model': home + '/similarity/alxlnet/base/model.pb',
+ 'quantized': home + '/similarity/alxlnet/base/quantized/model.pb',
'vocab': home + '/xlnet/sp10m.cased.v9.vocab',
'tokenizer': home + '/xlnet/sp10m.cased.v9.model',
'version': 'v36',
@@ -896,31 +902,37 @@
S3_PATH_SIMILARITY = {
'bert': {
'model': 'v36/similarity/bert-base-similarity.pb',
+ 'quantized': 'v40/similarity/bert-base-similarity.pb.quantized',
'vocab': 'tokenizer/sp10m.cased.bert.vocab',
'tokenizer': 'tokenizer/sp10m.cased.bert.model',
},
'tiny-bert': {
'model': 'v36/similarity/tiny-bert-similarity.pb',
+ 'quantized': 'v40/similarity/tiny-bert-similarity.pb.quantized',
'vocab': 'tokenizer/sp10m.cased.bert.vocab',
'tokenizer': 'tokenizer/sp10m.cased.bert.model',
},
'albert': {
'model': 'v36/similarity/albert-base-similarity.pb',
+ 'quantized': 'v40/similarity/albert-base-similarity.pb.quantized',
'vocab': 'tokenizer/sp10m.cased.v10.vocab',
'tokenizer': 'tokenizer/sp10m.cased.v10.model',
},
'tiny-albert': {
'model': 'v36/similarity/albert-tiny-similarity.pb',
+ 'quantized': 'v40/similarity/albert-tiny-similarity.pb.quantized',
'vocab': 'tokenizer/sp10m.cased.v10.vocab',
'tokenizer': 'tokenizer/sp10m.cased.v10.model',
},
'xlnet': {
'model': 'v36/similarity/xlnet-base-similarity.pb',
+ 'quantized': 'v40/similarity/xlnet-base-similarity.pb.quantized',
'vocab': 'tokenizer/sp10m.cased.v9.vocab',
'tokenizer': 'tokenizer/sp10m.cased.v9.model',
},
'alxlnet': {
'model': 'v36/similarity/alxlnet-base-similarity.pb',
+ 'quantized': 'v40/similarity/alxlnet-base-similarity.pb.quantized',
'vocab': 'tokenizer/sp10m.cased.v9.vocab',
'tokenizer': 'tokenizer/sp10m.cased.v9.model',
},
diff --git a/malaya/relevancy.py b/malaya/relevancy.py
index 2bb710a9..245aa01e 100644
--- a/malaya/relevancy.py
+++ b/malaya/relevancy.py
@@ -5,12 +5,32 @@
label = ['not relevant', 'relevant']
_transformer_availability = {
- 'bert': {'Size (MB)': 425.6, 'Accuracy': 0.872},
- 'tiny-bert': {'Size (MB)': 57.4, 'Accuracy': 0.656},
- 'albert': {'Size (MB)': 48.6, 'Accuracy': 0.871},
- 'tiny-albert': {'Size (MB)': 22.4, 'Accuracy': 0.843},
- 'xlnet': {'Size (MB)': 446.6, 'Accuracy': 0.885},
- 'alxlnet': {'Size (MB)': 46.8, 'Accuracy': 0.874},
+ 'bert': {'Size (MB)': 425.6, 'Quantized Size (MB)': 111, 'Accuracy': 0.872},
+ 'tiny-bert': {
+ 'Size (MB)': 57.4,
+ 'Quantized Size (MB)': 15.4,
+ 'Accuracy': 0.656,
+ },
+ 'albert': {
+ 'Size (MB)': 48.6,
+ 'Quantized Size (MB)': 12.8,
+ 'Accuracy': 0.85265,
+ },
+ 'tiny-albert': {
+ 'Size (MB)': 22.4,
+ 'Quantized Size (MB)': 5.98,
+ 'Accuracy': 0.843,
+ },
+ 'xlnet': {
+ 'Size (MB)': 446.6,
+ 'Quantized Size (MB)': 118,
+ 'Accuracy': 0.885,
+ },
+ 'alxlnet': {
+ 'Size (MB)': 46.8,
+ 'Quantized Size (MB)': 13.3,
+ 'Accuracy': 0.9123,
+ },
}
@@ -26,7 +46,7 @@ def available_transformer():
@check_type
-def transformer(model: str = 'xlnet', **kwargs):
+def transformer(model: str = 'xlnet', quantized: bool = False, **kwargs):
"""
Load Transformer relevancy model.
@@ -41,6 +61,9 @@ def transformer(model: str = 'xlnet', **kwargs):
* ``'tiny-albert'`` - Google ALBERT TINY parameters.
* ``'xlnet'`` - Google XLNET BASE parameters.
* ``'alxlnet'`` - Malaya ALXLNET BASE parameters.
+ quantized : bool, optional (default=False)
+ if True, load the 8-bit quantized model.
+ A quantized model is not necessarily faster; it depends on the machine.
Returns
-------
@@ -58,5 +81,6 @@ def transformer(model: str = 'xlnet', **kwargs):
'relevancy',
label,
model = model,
+ quantized = quantized,
**kwargs
)
diff --git a/malaya/similarity.py b/malaya/similarity.py
index d02e54ea..0b0d2bc5 100644
--- a/malaya/similarity.py
+++ b/malaya/similarity.py
@@ -454,6 +454,15 @@ def encoder(vectorizer):
},
}
+_vectorizer_mapping = {
+ 'bert': 'import/bert/encoder/layer_11/output/LayerNorm/batchnorm/add_1:0',
+ 'tiny-bert': 'import/bert/encoder/layer_3/output/LayerNorm/batchnorm/add_1:0',
+ 'albert': 'import/bert/encoder/transformer/group_0_11/layer_11/inner_group_0/LayerNorm_1/batchnorm/add_1:0',
+ 'tiny-albert': 'import/bert/encoder/transformer/group_0_3/layer_3/inner_group_0/LayerNorm_1/batchnorm/add_1:0',
+ 'xlnet': 'import/model/transformer/layer_11/ff/LayerNorm/batchnorm/add_1:0',
+ 'alxlnet': 'import/model/transformer/layer_shared_11/ff/LayerNorm/batchnorm/add_1:0',
+}
+
def available_transformer():
"""
@@ -466,7 +475,9 @@ def available_transformer():
)
-def _transformer(model, bert_class, xlnet_class, quantized = False, **kwargs):
+def _transformer(
+ model, bert_class, xlnet_class, quantized = False, siamese = False, **kwargs
+):
model = model.lower()
if model not in _transformer_availability:
raise ValueError(
@@ -501,13 +512,18 @@ def _transformer(model, bert_class, xlnet_class, quantized = False, **kwargs):
)
selected_class = bert_class
- selected_node = 'import/bert/pooler/dense/BiasAdd:0'
+ if siamese:
+ selected_node = 'import/bert/pooler/dense/BiasAdd:0'
if model in ['xlnet', 'alxlnet']:
tokenizer = sentencepiece_tokenizer_xlnet(path[model]['tokenizer'])
selected_class = xlnet_class
- selected_node = 'import/model_1/sequnece_summary/summary/BiasAdd:0'
+ if siamese:
+ selected_node = 'import/model_1/sequnece_summary/summary/BiasAdd:0'
+
+ if not siamese:
+ selected_node = _vectorizer_mapping[model]
return selected_class(
X = g.get_tensor_by_name('import/Placeholder:0'),
@@ -552,5 +568,6 @@ def transformer(model: str = 'bert', quantized: bool = False, **kwargs):
bert_class = SIAMESE_BERT,
xlnet_class = SIAMESE_XLNET,
quantized = quantized,
+ siamese = True,
**kwargs
)
diff --git a/malaya/text/bpe.py b/malaya/text/bpe.py
index 68bb9580..7e732d9b 100644
--- a/malaya/text/bpe.py
+++ b/malaya/text/bpe.py
@@ -143,7 +143,7 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
def bert_tokenization_siamese(tokenizer, left, right):
- input_ids, input_masks, segment_ids = [], [], []
+ input_ids, input_masks, segment_ids, s_tokens = [], [], [], []
a, b = [], []
for i in range(len(left)):
tokens_a = tokenizer.tokenize(transformer_textcleaning(left[i]))
@@ -164,6 +164,7 @@ def bert_tokenization_siamese(tokenizer, left, right):
segment_id.append(0)
tokens.append('[SEP]')
+ s_tokens.append(tokens[:])
segment_id.append(0)
for token in tokens_b:
tokens.append(token)
@@ -182,7 +183,7 @@ def bert_tokenization_siamese(tokenizer, left, right):
input_masks = padding_sequence(input_masks, maxlen)
segment_ids = padding_sequence(segment_ids, maxlen)
- return input_ids, input_masks, segment_ids
+ return input_ids, input_masks, segment_ids, s_tokens
SEG_ID_A = 0
@@ -324,7 +325,7 @@ def tokenize_fn(text, sp_model):
def xlnet_tokenization_siamese(tokenizer, left, right):
- input_ids, input_mask, all_seg_ids = [], [], []
+ input_ids, input_mask, all_seg_ids, s_tokens = [], [], [], []
for i in range(len(left)):
tokens = tokenize_fn(transformer_textcleaning(left[i]), tokenizer)
tokens_right = tokenize_fn(
@@ -332,6 +333,7 @@ def xlnet_tokenization_siamese(tokenizer, left, right):
)
segment_ids = [SEG_ID_A] * len(tokens)
tokens.append(SEP_ID)
+ s_tokens.append([tokenizer.IdToPiece(i) for i in tokens])
segment_ids.append(SEG_ID_A)
tokens.extend(tokens_right)
@@ -354,7 +356,7 @@ def xlnet_tokenization_siamese(tokenizer, left, right):
input_ids = padding_sequence(input_ids, maxlen)
input_mask = padding_sequence(input_mask, maxlen, pad_int = 1)
all_seg_ids = padding_sequence(all_seg_ids, maxlen, pad_int = 4)
- return input_ids, input_mask, all_seg_ids
+ return input_ids, input_mask, all_seg_ids, s_tokens
def xlnet_tokenization(tokenizer, texts):
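Reviewer note on the `vectorize` methods added in `malaya/model/bert.py` and `malaya/model/xlnet.py`: the `s_tokens` list returned by the siamese tokenizers in this file is only consumed by the `'word'` method; the other three methods reduce the encoder output over the sequence axis. The following is a minimal NumPy sketch of those three reductions, assuming the output has already been arranged as `(batch, seq_len, hidden)`. The helper name `pool` and the `ValueError` for unknown methods are illustrative assumptions and are not part of this patch (the patch falls through to the `'word'` branch instead).

```python
import numpy as np


def pool(v, method='first'):
    """Reduce a (batch, seq_len, hidden) array over the sequence axis.

    Hypothetical helper mirroring the 'first' / 'last' / 'mean' branches
    of the patch; the 'word' branch is intentionally left out.
    """
    if method == 'first':
        return v[:, 0]             # vector at the first sequence position
    if method == 'last':
        return v[:, -1]            # vector at the last sequence position
    if method == 'mean':
        return np.mean(v, axis=1)  # average over all sequence positions
    # Assumption: reject unknown values instead of falling through.
    raise ValueError(f'unsupported method: {method}')


# Toy encoder output: 2 pairs, 3 positions, hidden size 4.
v = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
first = pool(v, 'first')
last = pool(v, 'last')
mean = pool(v, 'mean')
```

Because the patch applies `'last'` and `'mean'` to the padded batch, padded positions appear to be included in both reductions; masking them with `input_masks` may be worth considering in a follow-up.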
diff --git a/malaya/train/model/transformer/layer.py b/malaya/train/model/transformer/layer.py
new file mode 100644
index 00000000..e69de29b
diff --git a/malaya/zero_shot/classification.py b/malaya/zero_shot/classification.py
index a30171ea..b1b3926c 100644
--- a/malaya/zero_shot/classification.py
+++ b/malaya/zero_shot/classification.py
@@ -44,5 +44,6 @@ def transformer(model: str = 'bert', quantized: bool = False, **kwargs):
bert_class = ZEROSHOT_BERT,
xlnet_class = ZEROSHOT_XLNET,
quantized = quantized,
+ siamese = False,
**kwargs
)
diff --git a/session/emotion/quantize-emotion-model.ipynb b/session/emotion/quantize.ipynb
similarity index 67%
rename from session/emotion/quantize-emotion-model.ipynb
rename to session/emotion/quantize.ipynb
index 1484535b..14686d0b 100644
--- a/session/emotion/quantize-emotion-model.ipynb
+++ b/session/emotion/quantize.ipynb
@@ -120,13 +120,16 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
+ "WARNING:tensorflow:From :11: FastGFile.__init__ (from tensorflow.python.platform.gfile) is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "Use tf.gfile.GFile.\n",
"bert-base-emotion.pb ['Placeholder', 'Placeholder_1']\n",
"xlnet-base-emotion.pb ['Placeholder', 'Placeholder_1', 'Placeholder_2']\n",
"alxlnet-base-emotion.pb ['Placeholder', 'Placeholder_1', 'Placeholder_2']\n",
@@ -139,11 +142,9 @@
"source": [
"transforms = ['add_default_attributes',\n",
" 'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',\n",
- " 'fold_constants(ignore_errors=true)',\n",
" 'fold_batch_norms',\n",
" 'fold_old_batch_norms',\n",
" 'quantize_weights(fallback_min=-10, fallback_max=10)',\n",
- " 'quantize_nodes(fallback_min=-10, fallback_max=10)',\n",
" 'strip_unused_nodes',\n",
" 'sort_by_execution_order']\n",
"\n",
@@ -154,14 +155,17 @@
" \n",
" if 'bert' in pb:\n",
" inputs = ['Placeholder', 'Placeholder_1']\n",
+ " outputs = ['dense/BiasAdd']\n",
+ " \n",
" if 'xlnet'in pb:\n",
" inputs = ['Placeholder', 'Placeholder_1', 'Placeholder_2']\n",
+ " outputs = ['transpose_3']\n",
" \n",
" print(pb, inputs)\n",
" \n",
" transformed_graph_def = TransformGraph(input_graph_def, \n",
" inputs,\n",
- " ['logits', 'logits_seq'], transforms)\n",
+ " ['logits', 'logits_seq'] + outputs, transforms)\n",
" \n",
" with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:\n",
" f.write(transformed_graph_def.SerializeToString())"
@@ -173,58 +177,58 @@
"metadata": {},
"outputs": [],
"source": [
- "def load_graph(frozen_graph_filename, **kwargs):\n",
- " with tf.gfile.GFile(frozen_graph_filename, 'rb') as f:\n",
- " graph_def = tf.GraphDef()\n",
- " graph_def.ParseFromString(f.read())\n",
+ "# def load_graph(frozen_graph_filename, **kwargs):\n",
+ "# with tf.gfile.GFile(frozen_graph_filename, 'rb') as f:\n",
+ "# graph_def = tf.GraphDef()\n",
+ "# graph_def.ParseFromString(f.read())\n",
"\n",
- " # https://github.com/onnx/tensorflow-onnx/issues/77#issuecomment-445066091\n",
- " # to fix import T5\n",
- " for node in graph_def.node:\n",
- " if node.op == 'RefSwitch':\n",
- " node.op = 'Switch'\n",
- " for index in xrange(len(node.input)):\n",
- " if 'moving_' in node.input[index]:\n",
- " node.input[index] = node.input[index] + '/read'\n",
- " elif node.op == 'AssignSub':\n",
- " node.op = 'Sub'\n",
- " if 'use_locking' in node.attr:\n",
- " del node.attr['use_locking']\n",
- " elif node.op == 'AssignAdd':\n",
- " node.op = 'Add'\n",
- " if 'use_locking' in node.attr:\n",
- " del node.attr['use_locking']\n",
- " elif node.op == 'Assign':\n",
- " node.op = 'Identity'\n",
- " if 'use_locking' in node.attr:\n",
- " del node.attr['use_locking']\n",
- " if 'validate_shape' in node.attr:\n",
- " del node.attr['validate_shape']\n",
- " if len(node.input) == 2:\n",
- " node.input[0] = node.input[1]\n",
- " del node.input[1]\n",
+ "# # https://github.com/onnx/tensorflow-onnx/issues/77#issuecomment-445066091\n",
+ "# # to fix import T5\n",
+ "# for node in graph_def.node:\n",
+ "# if node.op == 'RefSwitch':\n",
+ "# node.op = 'Switch'\n",
+ "# for index in xrange(len(node.input)):\n",
+ "# if 'moving_' in node.input[index]:\n",
+ "# node.input[index] = node.input[index] + '/read'\n",
+ "# elif node.op == 'AssignSub':\n",
+ "# node.op = 'Sub'\n",
+ "# if 'use_locking' in node.attr:\n",
+ "# del node.attr['use_locking']\n",
+ "# elif node.op == 'AssignAdd':\n",
+ "# node.op = 'Add'\n",
+ "# if 'use_locking' in node.attr:\n",
+ "# del node.attr['use_locking']\n",
+ "# elif node.op == 'Assign':\n",
+ "# node.op = 'Identity'\n",
+ "# if 'use_locking' in node.attr:\n",
+ "# del node.attr['use_locking']\n",
+ "# if 'validate_shape' in node.attr:\n",
+ "# del node.attr['validate_shape']\n",
+ "# if len(node.input) == 2:\n",
+ "# node.input[0] = node.input[1]\n",
+ "# del node.input[1]\n",
"\n",
- " with tf.Graph().as_default() as graph:\n",
- " tf.import_graph_def(graph_def)\n",
- " return graph"
+ "# with tf.Graph().as_default() as graph:\n",
+ "# tf.import_graph_def(graph_def)\n",
+ "# return graph"
]
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
- "g = load_graph('xlnet-base-emotion.pb.quantized')\n",
- "x = g.get_tensor_by_name('import/Placeholder:0')\n",
- "x_len = g.get_tensor_by_name('import/Placeholder_1:0')\n",
- "x_len2 = g.get_tensor_by_name('import/Placeholder_2:0')\n",
- "logits = g.get_tensor_by_name('import/logits:0')"
+ "# g = load_graph('xlnet-base-emotion.pb.quantized')\n",
+ "# x = g.get_tensor_by_name('import/Placeholder:0')\n",
+ "# x_len = g.get_tensor_by_name('import/Placeholder_1:0')\n",
+ "# x_len2 = g.get_tensor_by_name('import/Placeholder_2:0')\n",
+ "# logits = g.get_tensor_by_name('import/logits:0')"
]
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -233,47 +237,27 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
- "test_sess = tf.InteractiveSession(graph = g)"
+ "# test_sess = tf.InteractiveSession(graph = g)"
]
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 11,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "CPU times: user 2.54 s, sys: 187 ms, total: 2.72 s\n",
- "Wall time: 1.8 s\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "array([[-1.9607849, -3.294118 , -2.1960788, 10.039215 , 5.1764708,\n",
- " 3.7647057]], dtype=float32)"
- ]
- },
- "execution_count": 22,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "%%time\n",
- "test_sess.run(logits, feed_dict = {x: [[1,2,3,3,4]], x_len: [[1,1,1,1,1]],\n",
- " x_len2: [[1,1,1,1,1]]})"
+ "# %%time\n",
+ "# test_sess.run(logits, feed_dict = {x: [[1,2,3,3,4]], x_len: [[1,1,1,1,1]],\n",
+ "# x_len2: [[1,1,1,1,1]]})"
]
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -283,7 +267,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -297,7 +281,7 @@
" 'albert-tiny-emotion.pb.quantized']"
]
},
- "execution_count": 24,
+ "execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -307,61 +291,6 @@
"quantized"
]
},
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {},
- "outputs": [],
- "source": [
- "from b2sdk.v1 import *\n",
- "info = InMemoryAccountInfo()\n",
- "b2_api = B2Api(info)\n",
- "application_key_id = 'd3c416cf4cb1'\n",
- "application_key = '0007c73b0ef09cbff76ebdd5b14f2e0044d6d44b74'\n",
- "b2_api.authorize_account(\"production\", application_key_id, application_key)\n",
- "file_info = {'how': 'good-file'}\n",
- "b2_bucket = b2_api.get_bucket_by_name('malaya-model')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "bert-base-emotion.pb.quantized\n",
- "albert-base-emotion.pb.quantized\n",
- "xlnet-base-emotion.pb.quantized\n",
- "tiny-bert-emotion.pb.quantized\n",
- "alxlnet-base-emotion.pb.quantized\n",
- "albert-tiny-emotion.pb.quantized\n"
- ]
- }
- ],
- "source": [
- "for file in quantized:\n",
- " print(file)\n",
- " key = file\n",
- " outPutname = f\"v40/emotion/{file}\"\n",
- " b2_bucket.upload_local_file(\n",
- " local_file=key,\n",
- " file_name=outPutname,\n",
- " file_infos=file_info,\n",
- " )"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {},
- "outputs": [],
- "source": [
- "!rm *.pb*"
- ]
- },
{
"cell_type": "code",
"execution_count": null,
diff --git a/session/entities/quantize-entity-model.ipynb b/session/entities/quantize.ipynb
similarity index 87%
rename from session/entities/quantize-entity-model.ipynb
rename to session/entities/quantize.ipynb
index aa67b9dc..7b2c81c6 100644
--- a/session/entities/quantize-entity-model.ipynb
+++ b/session/entities/quantize.ipynb
@@ -12,7 +12,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -58,7 +58,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -121,13 +121,16 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
+ "WARNING:tensorflow:From :11: FastGFile.__init__ (from tensorflow.python.platform.gfile) is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "Use tf.gfile.GFile.\n",
"xlnet-base-entity.pb ['Placeholder', 'Placeholder_1', 'Placeholder_2']\n",
"alxlnet-base-entity.pb ['Placeholder', 'Placeholder_1', 'Placeholder_2']\n",
"albert-tiny-entity.pb ['Placeholder', 'Placeholder_1']\n",
@@ -140,11 +143,9 @@
"source": [
"transforms = ['add_default_attributes',\n",
" 'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',\n",
- " 'fold_constants(ignore_errors=true)',\n",
" 'fold_batch_norms',\n",
" 'fold_old_batch_norms',\n",
" 'quantize_weights(fallback_min=-10, fallback_max=10)',\n",
- " 'quantize_nodes(fallback_min=-10, fallback_max=10)',\n",
" 'strip_unused_nodes',\n",
" 'sort_by_execution_order']\n",
"\n",
@@ -155,14 +156,16 @@
" \n",
" if 'bert' in pb:\n",
" inputs = ['Placeholder', 'Placeholder_1']\n",
+ " outputs = ['dense/BiasAdd']\n",
" if 'xlnet'in pb:\n",
" inputs = ['Placeholder', 'Placeholder_1', 'Placeholder_2']\n",
+ " outputs = ['transpose_3']\n",
" \n",
" print(pb, inputs)\n",
" \n",
" transformed_graph_def = TransformGraph(input_graph_def, \n",
" inputs,\n",
- " ['logits'], transforms)\n",
+ " ['logits'] + outputs, transforms)\n",
" \n",
" with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:\n",
" f.write(transformed_graph_def.SerializeToString())"
@@ -170,7 +173,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -212,7 +215,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -225,7 +228,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -234,7 +237,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -243,24 +246,24 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 2.78 s, sys: 179 ms, total: 2.96 s\n",
- "Wall time: 1.93 s\n"
+ "CPU times: user 2.58 s, sys: 615 ms, total: 3.19 s\n",
+ "Wall time: 2.68 s\n"
]
},
{
"data": {
"text/plain": [
- "array([[8, 7, 0, 0, 8]], dtype=int32)"
+ "array([[2, 2, 2, 0, 0]], dtype=int32)"
]
},
- "execution_count": 15,
+ "execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -273,7 +276,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -283,7 +286,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 15,
"metadata": {},
"outputs": [
{
@@ -297,7 +300,7 @@
" 'albert-base-entity.pb.quantized']"
]
},
- "execution_count": 17,
+ "execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
@@ -307,52 +310,6 @@
"quantized"
]
},
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [],
- "source": [
- "from b2sdk.v1 import *\n",
- "info = InMemoryAccountInfo()\n",
- "b2_api = B2Api(info)\n",
- "application_key_id = 'd3c416cf4cb1'\n",
- "application_key = '0007c73b0ef09cbff76ebdd5b14f2e0044d6d44b74'\n",
- "b2_api.authorize_account(\"production\", application_key_id, application_key)\n",
- "file_info = {'how': 'good-file'}\n",
- "b2_bucket = b2_api.get_bucket_by_name('malaya-model')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "bert-base-entity.pb.quantized\n",
- "tiny-bert-entity.pb.quantized\n",
- "alxlnet-base-entity.pb.quantized\n",
- "xlnet-base-entity.pb.quantized\n",
- "albert-tiny-entity.pb.quantized\n",
- "albert-base-entity.pb.quantized\n"
- ]
- }
- ],
- "source": [
- "for file in quantized:\n",
- " print(file)\n",
- " key = file\n",
- " outPutname = f\"v40/entity/{file}\"\n",
- " b2_bucket.upload_local_file(\n",
- " local_file=key,\n",
- " file_name=outPutname,\n",
- " file_infos=file_info,\n",
- " )"
- ]
- },
{
"cell_type": "code",
"execution_count": null,
diff --git a/session/sentiment/quantize.ipynb b/session/sentiment/quantize.ipynb
index c95c99a9..418ef0d5 100644
--- a/session/sentiment/quantize.ipynb
+++ b/session/sentiment/quantize.ipynb
@@ -58,7 +58,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -95,7 +95,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 12,
"metadata": {},
"outputs": [
{
@@ -109,7 +109,7 @@
" 'alxlnet-base-sentiment.pb']"
]
},
- "execution_count": 7,
+ "execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@@ -121,45 +121,77 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# with tf.gfile.GFile('alxlnet-base-sentiment.pb', \"rb\") as f:\n",
+ "# graph_def = tf.GraphDef()\n",
+ "# graph_def.ParseFromString(f.read())\n",
+ "\n",
+ "# with tf.Graph().as_default() as graph:\n",
+ "# tf.import_graph_def(graph_def)\n",
+ "\n",
+ "# op = graph.get_operations()\n",
+ "# x = []\n",
+ "# for i in op:\n",
+ "# try:\n",
+ "# #if 'pooler' in i.values()[0].name:\n",
+ "# x.append(i.values())\n",
+ "# except:\n",
+ "# pass\n",
+ " \n",
+ "# x[-100:]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "albert-tiny-sentiment.pb\n",
- "WARNING:tensorflow:From :14: FastGFile.__init__ (from tensorflow.python.platform.gfile) is deprecated and will be removed in a future version.\n",
+ "WARNING:tensorflow:From :11: FastGFile.__init__ (from tensorflow.python.platform.gfile) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.gfile.GFile.\n",
- "xlnet-base-sentiment.pb\n",
- "albert-base-sentiment.pb\n",
- "tiny-bert-sentiment.pb\n",
- "bert-base-sentiment.pb\n",
- "alxlnet-base-sentiment.pb\n"
+ "albert-tiny-sentiment.pb ['Placeholder', 'Placeholder_1']\n",
+ "xlnet-base-sentiment.pb ['Placeholder', 'Placeholder_1', 'Placeholder_2']\n",
+ "albert-base-sentiment.pb ['Placeholder', 'Placeholder_1']\n",
+ "tiny-bert-sentiment.pb ['Placeholder', 'Placeholder_1']\n",
+ "bert-base-sentiment.pb ['Placeholder', 'Placeholder_1']\n",
+ "alxlnet-base-sentiment.pb ['Placeholder', 'Placeholder_1', 'Placeholder_2']\n"
]
}
],
"source": [
"transforms = ['add_default_attributes',\n",
" 'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',\n",
- " 'fold_constants(ignore_errors=true)',\n",
" 'fold_batch_norms',\n",
" 'fold_old_batch_norms',\n",
" 'quantize_weights(fallback_min=-10, fallback_max=10)',\n",
- " 'quantize_nodes(fallback_min=-10, fallback_max=10)',\n",
" 'strip_unused_nodes',\n",
" 'sort_by_execution_order']\n",
"\n",
"for pb in pbs:\n",
- "    print(pb)\n",
"    input_graph_def = tf.GraphDef()\n",
"    with tf.gfile.FastGFile(pb, 'rb') as f:\n",
"        input_graph_def.ParseFromString(f.read())\n",
"    \n",
+ "    if 'bert' in pb:\n",
+ "        inputs = ['Placeholder', 'Placeholder_1']\n",
+ "        outputs = ['dense/BiasAdd']\n",
+ "    \n",
+ "    if 'xlnet' in pb:\n",
+ "        inputs = ['Placeholder', 'Placeholder_1', 'Placeholder_2']\n",
+ "        outputs = ['transpose_3']\n",
+ "    \n",
+ "    print(pb, inputs)\n",
+ "    \n",
"    transformed_graph_def = TransformGraph(input_graph_def, \n",
- "                                           ['Placeholder', 'Placeholder_1'],\n",
- "                                           ['logits', 'logits_seq'], transforms)\n",
+ "                                           inputs,\n",
+ "                                           ['logits', 'logits_seq'] + outputs, transforms)\n",
"    \n",
"    with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:\n",
"        f.write(transformed_graph_def.SerializeToString())"
@@ -167,7 +199,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -209,7 +241,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -221,7 +253,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -230,7 +262,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -239,7 +271,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -249,7 +281,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -259,7 +291,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
@@ -273,7 +305,7 @@
" 'tiny-bert-sentiment.pb.quantized']"
]
},
- "execution_count": 17,
+ "execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@@ -283,52 +315,6 @@
"quantized"
]
},
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [],
- "source": [
- "from b2sdk.v1 import *\n",
- "info = InMemoryAccountInfo()\n",
- "b2_api = B2Api(info)\n",
- "application_key_id = 'd3c416cf4cb1'\n",
- "application_key = '0007c73b0ef09cbff76ebdd5b14f2e0044d6d44b74'\n",
- "b2_api.authorize_account(\"production\", application_key_id, application_key)\n",
- "file_info = {'how': 'good-file'}\n",
- "b2_bucket = b2_api.get_bucket_by_name('malaya-model')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "albert-base-sentiment.pb.quantized\n",
- "xlnet-base-sentiment.pb.quantized\n",
- "albert-tiny-sentiment.pb.quantized\n",
- "bert-base-sentiment.pb.quantized\n",
- "alxlnet-base-sentiment.pb.quantized\n",
- "tiny-bert-sentiment.pb.quantized\n"
- ]
- }
- ],
- "source": [
- "for file in quantized:\n",
- "    print(file)\n",
- "    key = file\n",
- "    outPutname = f\"v40/sentiment/{file}\"\n",
- "    b2_bucket.upload_local_file(\n",
- "        local_file=key,\n",
- "        file_name=outPutname,\n",
- "        file_infos=file_info,\n",
- "    )"
- ]
- },
{
"cell_type": "code",
"execution_count": null,