From 7a6972bcd43c4f2ff9809f7760b3f774ab5bd612 Mon Sep 17 00:00:00 2001 From: dimazhylko Date: Sat, 12 Dec 2020 05:30:26 +0100 Subject: [PATCH] first nlp notebook --- NLP/nlp-v0.ipynb | 1546 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1546 insertions(+) create mode 100644 NLP/nlp-v0.ipynb diff --git a/NLP/nlp-v0.ipynb b/NLP/nlp-v0.ipynb new file mode 100644 index 0000000..f7423e6 --- /dev/null +++ b/NLP/nlp-v0.ipynb @@ -0,0 +1,1546 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install tensorflow==1.15" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from google.colab import files\n", + "files.upload()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!python -m pip install kaggle" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!mkdir ~/.kaggle\n", + "!mv kaggle.json ~/.kaggle" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!kaggle datasets download -d bittlingmayer/amazonreviews" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!unzip amazonreviews.zip " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!rm amazonreviews.zip\n", + "!ls" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import nltk\n", + "from nltk.tokenize import word_tokenize\n", + "from nltk.corpus import stopwords\n", + "from nltk.stem import SnowballStemmer\n", + "from collections import defaultdict\n", + "from nltk.corpus import wordnet as wn\n", + "from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer\n", + "from sklearn import model_selection, naive_bayes\n", + "from sklearn.metrics import accuracy_score\n", + "import bz2\n", + "import re\n", + "from tqdm.notebook import tqdm\n", + "import matplotlib.pyplot as plt\n", + "plt.style.use('ggplot')" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_text(text):\n", + " text = text.lower()\n", + " idx = text.find(':')\n", + " text = text[idx+1:]\n", + " \n", + " text = re.sub(r\"[^\\w\\s]+\", '', text)\n", + " text = re.sub(r\"\\s+\", ' ', text)\n", + " return ' '.join([word for word in text.strip().split() if len(word) > 1])\n", + "\n", + "def read_and_preprocess(file, total=1, sub_size=-1):\n", + " labels = []\n", + " texts = []\n", + " if sub_size != -1:\n", + " total = min(total, sub_size)\n", + " \n", + " for l in tqdm(bz2.BZ2File(file), total=total):\n", + " x = l.decode('utf-8')\n", + " label = int(x[9]) - 1\n", + " text = x[10:].strip()\n", + " text = prepare_text(text)\n", + " if text != ' ':\n", + " labels.append(label)\n", + " texts.append(text)\n", + " if len(texts) == sub_size:\n", + " break\n", + " \n", + " return np.array(labels), texts" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2f074774a07c4ae49929422093ede39c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, 
max=500000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "train_labels, train_texts = read_and_preprocess('archive-2/train.ft.txt.bz2', total=3600000, sub_size=500000)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0b53430bf1c74b09a712347241d39866", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=200000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "test_labels, test_texts = read_and_preprocess('archive-2/test.ft.txt.bz2', total=400000, sub_size=200000)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Naive Bayes" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "stemmer = SnowballStemmer(\"english\")\n", + "\n", + "def stem_text(text):\n", + " text_tokenized = word_tokenize(text)\n", + " return ' '.join([stemmer.stem(word) for word in text_tokenized])" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "96466f47db334716bf64ac523fc6d634", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=500000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "train_texts_stammed = [stem_text(text) for text in tqdm(train_texts)]" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b0585fc220cd4ced8514e0f06c02cb77", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=200000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "test_texts_stemmed = [stem_text(text) for text in tqdm(test_texts)]" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[nltk_data] Downloading package stopwords to\n", + "[nltk_data] /Users/dimazhylko/nltk_data...\n", + "[nltk_data] Package stopwords is already up-to-date!\n" + ] + }, + { + "data": { + "text/plain": [ + "['i',\n", + " 'me',\n", + " 'my',\n", + " 'myself',\n", + " 'we',\n", + " 'our',\n", + " 'ours',\n", + " 'ourselves',\n", + " 'you',\n", + " \"you're\",\n", + " \"you've\",\n", + " \"you'll\",\n", + " \"you'd\",\n", + " 'your',\n", + " 'yours',\n", + " 'yourself',\n", + " 'yourselves',\n", + " 'he',\n", + " 'him',\n", + " 'his',\n", + " 'himself',\n", + " 'she',\n", + " \"she's\",\n", + " 'her',\n", + " 'hers',\n", + " 'herself',\n", + " 'it',\n", + " \"it's\",\n", + " 'its',\n", + " 'itself',\n", + " 'they',\n", + " 'them',\n", + " 'their',\n", + " 'theirs',\n", + " 'themselves',\n", + " 'what',\n", + " 'which',\n", + " 'who',\n", + " 'whom',\n", + " 'this',\n", + " 'that',\n", + " 
\"that'll\",\n", + " 'these',\n", + " 'those',\n", + " 'am',\n", + " 'is',\n", + " 'are',\n", + " 'was',\n", + " 'were',\n", + " 'be',\n", + " 'been',\n", + " 'being',\n", + " 'have',\n", + " 'has',\n", + " 'had',\n", + " 'having',\n", + " 'do',\n", + " 'does',\n", + " 'did',\n", + " 'doing',\n", + " 'a',\n", + " 'an',\n", + " 'the',\n", + " 'and',\n", + " 'but',\n", + " 'if',\n", + " 'or',\n", + " 'because',\n", + " 'as',\n", + " 'until',\n", + " 'while',\n", + " 'of',\n", + " 'at',\n", + " 'by',\n", + " 'for',\n", + " 'with',\n", + " 'about',\n", + " 'against',\n", + " 'between',\n", + " 'into',\n", + " 'through',\n", + " 'during',\n", + " 'before',\n", + " 'after',\n", + " 'above',\n", + " 'below',\n", + " 'to',\n", + " 'from',\n", + " 'up',\n", + " 'down',\n", + " 'in',\n", + " 'out',\n", + " 'on',\n", + " 'off',\n", + " 'over',\n", + " 'under',\n", + " 'again',\n", + " 'further',\n", + " 'then',\n", + " 'once',\n", + " 'here',\n", + " 'there',\n", + " 'when',\n", + " 'where',\n", + " 'why',\n", + " 'how',\n", + " 'all',\n", + " 'any',\n", + " 'both',\n", + " 'each',\n", + " 'few',\n", + " 'more',\n", + " 'most',\n", + " 'other',\n", + " 'some',\n", + " 'such',\n", + " 'no',\n", + " 'nor',\n", + " 'not',\n", + " 'only',\n", + " 'own',\n", + " 'same',\n", + " 'so',\n", + " 'than',\n", + " 'too',\n", + " 'very',\n", + " 's',\n", + " 't',\n", + " 'can',\n", + " 'will',\n", + " 'just',\n", + " 'don',\n", + " \"don't\",\n", + " 'should',\n", + " \"should've\",\n", + " 'now',\n", + " 'd',\n", + " 'll',\n", + " 'm',\n", + " 'o',\n", + " 're',\n", + " 've',\n", + " 'y',\n", + " 'ain',\n", + " 'aren',\n", + " \"aren't\",\n", + " 'couldn',\n", + " \"couldn't\",\n", + " 'didn',\n", + " \"didn't\",\n", + " 'doesn',\n", + " \"doesn't\",\n", + " 'hadn',\n", + " \"hadn't\",\n", + " 'hasn',\n", + " \"hasn't\",\n", + " 'haven',\n", + " \"haven't\",\n", + " 'isn',\n", + " \"isn't\",\n", + " 'ma',\n", + " 'mightn',\n", + " \"mightn't\",\n", + " 'mustn',\n", + " \"mustn't\",\n", + " 'needn',\n", + " \"needn't\",\n", + " 'shan',\n", + " \"shan't\",\n", + " 'shouldn',\n", + " \"shouldn't\",\n", + " 'wasn',\n", + " \"wasn't\",\n", + " 'weren',\n", + " \"weren't\",\n", + " 'won',\n", + " \"won't\",\n", + " 'wouldn',\n", + " \"wouldn't\"]" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nltk.download('stopwords')\n", + "stopwords.words('english')" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "vectorizer = TfidfVectorizer(stop_words=stopwords.words('english'))" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2db651fa6f654be8b5262ea5d1e5abd4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=500000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "TfidfVectorizer(stop_words=['i', 'me', 'my', 'myself', 'we', 'our', 'ours',\n", + " 'ourselves', 'you', \"you're\", \"you've\", \"you'll\",\n", + " \"you'd\", 'your', 'yours', 'yourself', 'yourselves',\n", + " 'he', 'him', 'his', 'himself', 'she', \"she's\",\n", + " 'her', 'hers', 'herself', 'it', \"it's\", 'its',\n", + " 'itself', ...])" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": 
"execute_result" + } + ], + "source": [ + "vectorizer.fit(tqdm(train_texts_stammed))" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "464542" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(vectorizer.vocabulary_)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " (0, 459946)\t0.07996947534081235\n", + " (0, 459736)\t0.31538065409307176\n", + " (0, 457003)\t0.09742771221067453\n", + " (0, 456371)\t0.06520515383876084\n", + " (0, 456168)\t0.09713377583791682\n", + " (0, 427397)\t0.1522013109363848\n", + " (0, 414328)\t0.17111347538516142\n", + " (0, 400409)\t0.18128965197057867\n", + " (0, 385844)\t0.22411317316711604\n", + " (0, 381526)\t0.1483145901715332\n", + " (0, 371333)\t0.118999542170106\n", + " (0, 355789)\t0.1682200085112425\n", + " (0, 344880)\t0.1754696697183511\n", + " (0, 339375)\t0.16337151657839263\n", + " (0, 334608)\t0.06474427379750541\n", + " (0, 321766)\t0.10144302733177181\n", + " (0, 321490)\t0.10598393853075028\n", + " (0, 305987)\t0.17359537914160617\n", + " (0, 294015)\t0.3251993177406255\n", + " (0, 292838)\t0.0706608532769978\n", + " (0, 291694)\t0.05459682227165906\n", + " (0, 274154)\t0.09936518099451207\n", + " (0, 273387)\t0.08933815725521534\n", + " (0, 271816)\t0.07582024442678638\n", + " (0, 266313)\t0.09095363420801905\n", + " (0, 264547)\t0.2985955377215965\n", + " (0, 253314)\t0.14794290219685868\n", + " (0, 244473)\t0.09280938288968092\n", + " (0, 240953)\t0.09820270026942181\n", + " (0, 206669)\t0.17025931308755213\n", + " (0, 205099)\t0.11738729685773906\n", + " (0, 177644)\t0.07882700410235738\n", + " (0, 171721)\t0.1143947303147365\n", + " (0, 158388)\t0.1267594336413101\n", + " (0, 156602)\t0.09383720788048346\n", + " (0, 152480)\t0.3251993177406255\n", + " (0, 146695)\t0.0934623823077469\n", + " (0, 122173)\t0.17446662989786169\n", + " (0, 80727)\t0.08981696660906542\n", + " (0, 73620)\t0.0773700997164049\n", + " (0, 59883)\t0.10539847299062545\n", + " (0, 56872)\t0.08339506749009365\n", + " (0, 53599)\t0.11035509264729593\n", + " (0, 35041)\t0.08101940355208788\n" + ] + } + ], + "source": [ + "print(vectorizer.transform([train_texts_stammed[1]]))" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ed38ab126e354cf3ab6f67ac07bff0a9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=500000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "473544d2953e45d9a5d640c65d97c8c8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=200000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Naive Bayes Accuracy Score -> 75.4015\n" + ] + } + ], + "source": [ + "train_texts_tfidf = vectorizer.transform(tqdm(train_texts_stammed))\n", + "test_texts_tfidf = vectorizer.transform(tqdm(test_texts))\n", + "\n", + "bayes_model = naive_bayes.MultinomialNB()\n", + 
"bayes_model.fit(train_texts_tfidf, train_labels)\n", + "predictions_bayes = bayes_model.predict(test_texts_tfidf)\n", + "\n", + "print(\"Naive Bayes Accuracy Score -> \", accuracy_score(predictions_bayes, test_labels)*100)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAbtUlEQVR4nO3dbUxUZx738e+Zwa7gUToPsAaD2bVoE41k0DE+bCuuTvuiNk2jpllf1JTV1sbWDZJtUmv2IWu1dC3SRSBuDGm2jS9sjNomd7I2lAUaiVusC401rUtpN0vUIpwRmRUXhHO/8HbussIFMoBgf59XzHXOdeY/p3/7m/PAwXJd10VERGQQnntdgIiITGwKChERMVJQiIiIkYJCRESMFBQiImKkoBAREaOke13AWLh48eK9LsEoGAzS1tZ2r8sY0mSpEyZPrapzdE2WOmHi15qRkTHoMh1RiIiIkYJCRESMFBQiImKkoBARESMFhYiIGCkoRETESEEhIiJGCgoRETEa8hfu2traKCsr4+rVq1iWRSQS4YknnuD999/n448/ZsaMGQBs3LiRRYsWAXD8+HGqqqrweDzk5eURCoUAaG5upqysjO7ubnJycsjLy8OyLHp6eigtLaW5uZnp06eTn59Peno6ANXV1Rw7dgyAdevWsWrVqjHYDSIiMpghg8Lr9fLss88yZ84curq6ePXVV8nOzgZg7dq1PPXUU/3Wb2lpoa6ujv379xONRtm9ezd/+tOf8Hg8HDp0iK1btzJ37lzeeOMNGhoayMnJoaqqimnTpnHgwAFOnTrF4cOH2bFjB7FYjKNHj1JYWAjAq6++SjgcxrbtMdgVIpND7/NPDb1Sgr4b83cYHZOlThifWr2HPhyT7Q556snn8zFnzhwAkpOTmTVrFo7jDLp+fX09K1asYMqUKaSnpzNz5kyampqIRqN0dXUxb948LMti5cqV1NfXA3DmzJn4kcKyZcs4d+4cruvS0NBAdnY2tm1j2zbZ2dk0NDQk/qlFRGTY7upZT62trXzzzTdkZWXx5ZdfcvLkSWpra5kzZw6bNm3Ctm0cx2Hu3LnxOX6/H8dx8Hq9BAKB+HggEIgHjuM48WVer5eUlBQ6Ozv7jX9/W/+rsrKSyspKAAoLCwkGg3fzscZdUlLShK8RJk+dMHlqHY06J9O3aBlfY/VvYNhBcePGDYqKinjuuedISUnh8ccfZ8OGDQAcOXKEd999l23btjHYn+A2/WnugZZZljXgugONRyIRIpFI/PVEfvAWTPyHg902WeqEyVPrZKlTJqdEesv0UMBhBcXNmzcpKiri0UcfZenSpQA8+OCD8eVr1qzhzTffBG4dKbS3t8eXOY6D3++/Y7y9vR2/399vTiAQoLe3l+vXr2PbNn6/n/Pnz/fb1vz584dT8ojp/O//N1nqhMl9/ldkohvyGoXruhw8eJBZs2bx5JNPxsej0Wj8508//ZTMzEwAwuEwdXV19PT00NrayqVLl8jKysLn85GcnMyFCxdwXZfa2lrC4TAAixcvprq6GoDTp0+zYMECLMsiFArR2NhILBYjFovR2NgYv4NKRETGx5BHFF999RW1tbXMnj2bV155Bbh1K+ypU6f49ttvsSyLtLQ0XnjhBQAyMzNZvnw5BQUFeDweNm/ejMdzK4+2bNlCeXk53d3dhEIhcnJyAFi9ejWlpaVs374d27bJz88HwLZt1q9fz86dOwHYsGGD7ngSERlnlmu6eDBJJfKHi8bj1JNMTqNx6mk0rlGoR2UwifSo/nCRiIiMmIJCRESMFBQiImKkoBARESMFhYiIGCkoRETESEEhIiJGCgoRETFSUIiIiJGCQkREjBQUIiJipKAQEREjBYWIiBgpKERExEhBISIiRgoKERExUlCIiIiRgkJERIwUFCIiYqSgEBERIwWFiIgYKShERMRIQSEiIkYKChERMVJQiIiIkYJCRESMFBQiImKkoBARESMFhYiIGCkoRETESEEhIiJGCgoRETFKGmqFtrY2ysrKuHr1KpZlEYlEeOKJJ4jFYhQXF3PlyhXS0tLYsWMHtm0DcPz4caqqqvB4POTl5REKhQBobm6mrKyM7u5ucnJyyMvLw7Isenp6KC0tpbm5menTp5Ofn096ejoA1dXVHDt2DIB169axatWqsdkTIiIyoCGPKLxeL88++yzFxcXs2bOHkydP0tLSwokTJ1i4cCElJSUsXLiQEydOANDS0kJdXR379+9n165dVFRU0NfXB8ChQ4fYunUrJSUlXL58mYaGBgCqqqqYNm0aBw4cYO3atRw+fBiAWCzG0aNH2bt3L3v37uXo0aPEYrGx2RMiIjKgIYPC5/MxZ84cAJKTk5k1axaO41BfX09ubi4Aubm51NfXA1BfX8+KFSuYMmUK6enpzJw5k6amJqLRKF1dXcybNw/Lsli5cmV8zpkzZ+JHCsuWLePcuXO4rktDQwPZ2dnYto1t22RnZ8fDRURExseQp56+r7W1lW+++YasrCw6Ojrw+XzArTC5du0aAI7jMHfu3Pgcv9+P4zh4vV4CgUB8PBAI4DhOfM7tZV6vl5SUFDo7O/uNf39b/6uyspLKykoACgsLCQaDd/Ox+vluxDPlfpdIX92WlJSU8HbUozKY0ejRgQw7KG7cuEFRURHPPfccKSkpg67nuu5djQ+2zLKsAdcdaDwSiRCJROKv29raBn0vkZEajb4KBoPqTxkzifRWRkbGoMuGddfTzZs3KSoq4tFHH2Xp0qUApKamEo1GAYhGo8yYMQO4daTQ3t4en+s4Dn6//47x9vZ2/H7/HXN6e3u5fv06tm3j9/vv2NbtoxgRERkfQwaF67ocPHiQWbNm8eSTT8bHw+EwNTU1ANTU1LBkyZL4eF1dHT09PbS2tnLp0iWysrLw+XwkJydz4cIFXNeltraWcDgMwOLFi6murgbg9OnTLFiwAMuyCIVCNDY2EovFiMViNDY2xu+gEhGR8THkqaevvvqK2tpaZs+ezSuvvALAxo0befrppykuLqaqqopgMEhBQQEAmZmZLF++nIKCAjweD5s3b8bjuZVHW7Zsoby8nO
7ubkKhEDk5OQCsXr2a0tJStm/fjm3b5OfnA2DbNuvXr2fnzp0AbNiwIX4LroiIjA/LNV08mKQuXrw44rm9zz81ipXI/cR76MOEtzEa1yjUozKYRHo04WsUIiLyw6WgEBERIwWFiIgYKShERMRIQSEiIkYKChERMVJQiIiIkYJCRESMFBQiImKkoBARESMFhYiIGCkoRETESEEhIiJGCgoRETFSUIiIiJGCQkREjBQUIiJipKAQEREjBYWIiBgpKERExEhBISIiRgoKERExUlCIiIiRgkJERIwUFCIiYqSgEBERIwWFiIgYKShERMRIQSEiIkYKChERMVJQiIiIkYJCRESMkoZaoby8nLNnz5KamkpRUREA77//Ph9//DEzZswAYOPGjSxatAiA48ePU1VVhcfjIS8vj1AoBEBzczNlZWV0d3eTk5NDXl4elmXR09NDaWkpzc3NTJ8+nfz8fNLT0wGorq7m2LFjAKxbt45Vq1aN9ucXEZEhDHlEsWrVKl577bU7xteuXcu+ffvYt29fPCRaWlqoq6tj//797Nq1i4qKCvr6+gA4dOgQW7dupaSkhMuXL9PQ0ABAVVUV06ZN48CBA6xdu5bDhw8DEIvFOHr0KHv37mXv3r0cPXqUWCw2Wp9bRESGacigmD9/PrZtD2tj9fX1rFixgilTppCens7MmTNpamoiGo3S1dXFvHnzsCyLlStXUl9fD8CZM2fiRwrLli3j3LlzuK5LQ0MD2dnZ2LaNbdtkZ2fHw0VERMbPkKeeBnPy5Elqa2uZM2cOmzZtwrZtHMdh7ty58XX8fj+O4+D1egkEAvHxQCCA4zgAOI4TX+b1eklJSaGzs7Pf+Pe3NZDKykoqKysBKCwsJBgMjvRj8d2IZ8r9LpG+ui0pKSnh7ahHZTCj0aMDGVFQPP7442zYsAGAI0eO8O6777Jt2zZc1x1w/cHGB1tmWdaA6w42HolEiEQi8ddtbW2Dvp/ISI1GXwWDQfWnjJlEeisjI2PQZSO66+nBBx/E4/Hg8XhYs2YNX3/9NXDrSKG9vT2+nuM4+P3+O8bb29vx+/13zOnt7eX69evYto3f779jWz6fbyTliohIAkYUFNFoNP7zp59+SmZmJgDhcJi6ujp6enpobW3l0qVLZGVl4fP5SE5O5sKFC7iuS21tLeFwGIDFixdTXV0NwOnTp1mwYAGWZREKhWhsbCQWixGLxWhsbIzfQSUiIuNnyFNPb7/9NufPn6ezs5MXX3yRZ555hi+++IJvv/0Wy7JIS0vjhRdeACAzM5Ply5dTUFCAx+Nh8+bNeDy3smjLli2Ul5fT3d1NKBQiJycHgNWrV1NaWsr27duxbZv8/HwAbNtm/fr17Ny5E4ANGzYM+6K6iIiMHss1XUCYpC5evDjiub3PPzWKlcj9xHvow4S3MRrXKNSjMphEenTUr1GIiMgPh4JCRESMFBQiImKkoBARESMFhYiIGCkoRETESEEhIiJGCgoRETFSUIiIiJGCQkREjBQUIiJipKAQEREjBYWIiBgpKERExEhBISIiRgoKERExUlCIiIiRgkJERIwUFCIiYqSgEBERIwWFiIgYKShERMRIQSEiIkYKChERMVJQiIiIkYJCRESMFBQiImKkoBARESMFhYiIGCkoRETESEEhIiJGCgoRETFKGmqF8vJyzp49S2pqKkVFRQDEYjGKi4u5cuUKaWlp7NixA9u2ATh+/DhVVVV4PB7y8vIIhUIANDc3U1ZWRnd3Nzk5OeTl5WFZFj09PZSWltLc3Mz06dPJz88nPT0dgOrqao4dOwbAunXrWLVq1RjsAhERMRnyiGLVqlW89tpr/cZOnDjBwoULKSkpYeHChZw4cQKAlpYW6urq2L9/P7t27aKiooK+vj4ADh06xNatWykpKeHy5cs0NDQAUFVVxbRp0zhw4ABr167l8OHDwK0wOnr0KHv37mXv3r0cPXqUWCw2ih9dRESGY8igmD9/fvxo4bb6+npyc3MByM3Npb6+Pj6+YsUKpkyZQnp6OjNnzqSpqYloNEpXVxfz5s3DsixWrlwZn3PmzJn4kcKyZcs4d+4cruvS0NBAdnY2tm1j2zbZ2dnxcBERkfEz5KmngXR0dODz+QDw+Xxcu3YNAMdxmDt3bnw9v9+P4zh4vV4CgUB8PBAI4DhOfM7tZV6vl5SUFDo7O/uNf39bA6msrKSyshKAwsJCgsHgSD4WAN+NeKbc7xLpq9uSkpIS3o56VAYzGj06kBEFxWBc172r8cGWWZY14LqDjUciESKRSPx1W1ubqUyRERmNvgoGg+pPGTOJ9FZGRsagy0Z011NqairRaBSAaDTKjBkzgFtHCu3t7fH1HMfB7/ffMd7e3o7f779jTm9vL9evX8e2bfx+/x3bun0UIyIi42dEQREOh6mpqQGgpqaGJUuWxMfr6uro6emhtbWVS5cukZWVhc/nIzk5mQsXLuC6LrW1tYTDYQAWL15MdXU1AKdPn2bBggVYlkUoFKKxsZFYLEYsFqOxsTF+B5WIiIyfIU89vf3225w/f57Ozk5efPFFnnnmGZ5++mmKi4upqqoiGAxSUFAAQGZmJsuXL6egoACPx8PmzZvxeG5l0ZYtWygvL6e7u5tQKEROTg4Aq1evprS0lO3bt2PbNvn5+QDYts369evZuXMnABs2bLjjorqIiIw9yzVdQJikLl68OOK5vc8/NYqVyP3Ee+jDhLcxGtco1KMymER6dNSvUYiIyA+HgkJERIwUFCIiYqSgEBERIwWFiIgYKShERMRIQSEiIkYKChERMVJQiIiIkYJCRESMFBQiImKkoBARESMFhYiIGCkoRETESEEhIiJGCgoRETFSUIiIiJGCQkREjBQUIiJipKAQEREjBYWIiBgpKERExEhBISIiRgoKERExUlCIiIiRgkJERIwUFCIiYqSgEBERIwWFiIgYKShERMRIQSEiIkYKChERMUpKZPJLL73E1KlT8Xg8eL1eCgsLicViFBcXc+XKFdLS0tixYwe2bQNw/Phxqqqq8Hg85OXlEQqFAGhubqasrIzu7m5ycnLIy8vDsix6enooLS2lubmZ6dOnk5+fT3p6esIfWkREhi/hI4rf/e537Nu3j8LCQgBOnDjBwoULKSkpYeHChZw4cQKAlpYW6urq2L9/P7t27aKiooK+vj4ADh06xNatWykpKeHy5cs0NDQAUFVVxbRp0zhw4ABr167l8OHDiZYrIiJ3adRPPdXX15ObmwtAbm4u9fX18fEVK1YwZcoU0tPTmTlzJk1NTUSjUbq6upg3bx6WZbFy5cr4nDNnzrBq1SoAli1bxrlz53Bdd7RLFhERg4ROPQHs2bMHgMcee4xIJEJHRwc+nw8An8/HtWvXAHAch7lz58bn+f1+HMfB6/USCATi44FAAMdx4nNuL/N6vaSkpNDZ2cmMGTP61VBZWUllZSUAhYWFBIPBEX+e70Y8U+53ifTVbUlJSQlvRz0qgxmNHh1IQkGxe/du/H4/HR0dvP7662RkZAy67mBHAqYjhIGWWZZ1x1gkEiESicRft7W1m
coWGZHR6KtgMKj+lDGTSG+Z/v+d0Kknv98PQGpqKkuWLKGpqYnU1FSi0SgA0Wg0/u0/EAjQ3t4en+s4Dn6//47x9vb2+Ha/v6y3t5fr16/HL4yLiMj4GHFQ3Lhxg66urvjPn3/+ObNnzyYcDlNTUwNATU0NS5YsASAcDlNXV0dPTw+tra1cunSJrKwsfD4fycnJXLhwAdd1qa2tJRwOA7B48WKqq6sBOH36NAsWLBjwiEJERMbOiE89dXR08NZbbwG3vu0/8sgjhEIhHnroIYqLi6mqqiIYDFJQUABAZmYmy5cvp6CgAI/Hw+bNm/F4buXUli1bKC8vp7u7m1AoRE5ODgCrV6+mtLSU7du3Y9s2+fn5CX5cERG5W5Z7H95GdPHixRHP7X3+qVGsRO4n3kMfJryN0bhGoR6VwSTSo2N2jUJERO5/CgoRETFSUIiIiJGCQkREjBQUIiJipKAQEREjBYWIiBgpKERExEhBISIiRgoKERExUlCIiIiRgkJERIwUFCIiYqSgEBERIwWFiIgYKShERMRIQSEiIkYKChERMVJQiIiIkYJCRESMFBQiImKkoBARESMFhYiIGCkoRETESEEhIiJGCgoRETFSUIiIiJGCQkREjBQUIiJipKAQEREjBYWIiBgpKERExCjpXhcwHA0NDbzzzjv09fWxZs0ann766XtdkojID8aEP6Lo6+ujoqKC1157jeLiYk6dOkVLS8u9LktE5AdjwgdFU1MTM2fO5Mc//jFJSUmsWLGC+vr6e12WiMgPxoQ/9eQ4DoFAIP46EAjwz3/+s986lZWVVFZWAlBYWEhGRsbI3/D/nBn5XJFhSKg/QT0q427CH1G4rnvHmGVZ/V5HIhEKCwspLCwcr7IS8uqrr97rEoZlstQJk6dW1Tm6JkudMLlq/V8TPigCgQDt7e3x1+3t7fh8vntYkYjID8uED4qHHnqIS5cu0drays2bN6mrqyMcDt/rskREfjAm/DUKr9fLL3/5S/bs2UNfXx8///nPyczMvNdlJSQSidzrEoZlstQJk6dW1Tm6JkudMLlq/V+WO9BFABERkf9nwp96EhGRe0tBISIiRhP+GsVkFYvFKC4u5sqVK6SlpbFjxw5s2+63TltbG2VlZVy9ehXLsohEIjzxxBMAvP/++3z88cfMmDEDgI0bN7Jo0aJRq2+ox6K4rss777zDP/7xD370ox+xbds25syZM6y5o2mo9/rkk0/44IMPAJg6dSpbtmzhJz/5CQAvvfQSU6dOxePx4PV6x/T26aHq/OKLL/jjH/9Ieno6AEuXLmXDhg3DmjvetX744Yd88sknwK0nI7S0tFBRUYFt2+O2T8vLyzl79iypqakUFRXdsXyi9Odwap0oPZoQV8bEe++95x4/ftx1Xdc9fvy4+957792xjuM47tdff+26rutev37d/dWvfuX++9//dl3XdY8cOeJ+8MEHY1Jbb2+v+/LLL7uXL192e3p63F//+tfx973ts88+c/fs2eP29fW5X331lbtz585hzx3POr/88ku3s7PTdV3XPXv2bLxO13Xdbdu2uR0dHWNS293Wee7cOfeNN94Y0dzxrvX76uvr3d///vfx1+O1T7/44gv366+/dgsKCgZcPhH6c7i1ToQeTZROPY2R+vp6cnNzAcjNzR3wsSM+ny/+LSg5OZlZs2bhOM6Y1zacx6KcOXOGlStXYlkW8+bN4z//+Q/RaHRcH6kynPd6+OGH40dqc+fO7fc7N+MlkX0y3o+oudv3O3XqFD/72c/GrJ7BzJ8//44j8O+bCP053FonQo8mSqeexkhHR0f8FwN9Ph/Xrl0zrt/a2so333xDVlZWfOzkyZPU1tYyZ84cNm3aZGzGuzGcx6I4jkMwGOy3juM4w5o7Wu72vaqqqsjJyek3tmfPHgAee+yxMbs9cbh1XrhwgVdeeQWfz8ezzz5LZmbmuO7Pu6kV4L///S8NDQ1s3ry53/h47NOhTIT+HIl71aOJUlAkYPfu3Vy9evWO8V/84hd3tZ0bN25QVFTEc889R0pKCgCPP/54/Bz2kSNHePfdd9m2bVvCNcPwHosy2DrDmTta7ua9zp07x9/+9jf+8Ic/xMd2796N3++no6OD119/nYyMDObPn39P6vzpT39KeXk5U6dO5ezZs+zbt4+SkpJx3Z/DrfW2zz77rN+3YRi/fTqUidCfd+te9miiFBQJ+M1vfjPostTUVKLRKD6fj2g0Gr8o/b9u3rxJUVERjz76KEuXLo2PP/jgg/Gf16xZw5tvvjlqdQ/nsSiBQIC2trY71rl58+a4PVJluI9v+de//sWf//xndu7cyfTp0+Pjfr8fuPXfYsmSJTQ1NY3JP8Lh1Hn7CwDAokWLqKio4Nq1a+P+iJq7eb9Tp07xyCOP9Bsbr306lInQn3fjXvdoonSNYoyEw2FqamoAqKmpYcmSJXes47ouBw8eZNasWTz55JP9lkWj0fjPn3766aj+NvpwHosSDoepra3FdV0uXLhASkoKPp9vXB+pMpz3amtr46233uLll1/u91TWGzdu0NXVFf/5888/Z/bs2feszqtXr8a/7TY1NdHX18f06dPH/RE1w32/69evc/78+X7LxnOfDmUi9OdwTYQeTZR+M3uMdHZ2UlxcTFtbG8FgkIKCAmzbxnGc+DeLL7/8kt/+9rfMnj07fnh8+zbYAwcO8O2332JZFmlpabzwwguj+s3o7Nmz/OUvf4k/FmXdunV89NFHwK3TXq7rUlFRQWNjIw888ADbtm3joYceGnTuWBmqzoMHD/L3v/89fr769i2G3333HW+99RYAvb29PPLII/e0zr/+9a989NFHeL1eHnjgATZt2sTDDz886NyxNFStANXV1TQ0NJCfnx+fN5779O233+b8+fN0dnaSmprKM888w82bN+M1TpT+HE6tE6VHE6GgEBERI516EhERIwWFiIgYKShERMRIQSEiIkYKChERMVJQiIiIkYJCRESM/i82U10ZTsT6HgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.bar([0, 1], [len(train_labels[train_labels==0]), len(train_labels[train_labels==1])])" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAa4ElEQVR4nO3db0yd9f3/8ec5hzqhB/D8gTUwjKtQEzsU7CGlnZaunu2GdaZxzsxFjdhunZ01lMxMbbaYuTo2pcdQIF0I6zbTGy5moNl3m4adAUbS7WDFWZ0i0hlJWymcU/4I9VC4fjf668kY8CmcAxTq63GL87muzznv6/iur/O5rsOFzbIsCxERkRnYL3UBIiKytCkoRETESEEhIiJGCgoRETFSUIiIiJGCQkREjJIudQEL4cSJE5e6BCOv10tfX9+lLuOilkudsHxqVZ3za7nUCUu/1qysrBm3aUUhIiJGCgoRETFSUIiIiJGCQkREjBQUIiJipKAQEREjBYWIiBhd9PcoamtrOXr0KOnp6VRWVgIwPDxMIBDg9OnTZGRksGfPHpxOJwANDQ0Eg0HsdjulpaUUFBQA0N3dTU1NDdFolMLCQkpLS7HZbIyNjVFdXU13dzepqamUlZWRmZkJQHNzM3/84x8BuPPOO9m8efMCvAUiImJy0RXF5s2beeKJJyaNNTY2kp+fT1VVFfn5+TQ2NgLQ09NDW1sb+/fvZ+/evdTX1zMxMQFAXV0dO3fupKqqilOnTtHR0QFAMBhk5cqVHDhwgK1bt3L48GHgfBi9+OKLPP300zz99NO8+OKLDA8Pz+Ohi4jIbFx0RXH99dfT29s7aSwUCvHkk08CUFJSwpNPPsm9995LKBRi48aNrFixgszMTFatWkVXVxcZGRmMjo6yZs0aADZt2kQoFKKwsJD29na+/e1vA1BcXMxvfvMbLMuio6ODG264IbZSueGGG+jo6ODmm2+ez+MXWXbGv3fHgr/GJwv+CvNjudQJi1Oro+7lBXneuG7hMTAwgMvlAsDlcjE4OAhAOBwmLy8vtp/b7SYcDuNwOPB4PLFxj8dDOByOzbmwzeFwkJKSwtDQ0KTx/36u6TQ1NdHU1ARARUUFXq83nsNaNElJSUu+Rlg+dcLyqXU+6lxO/3OUxbVQ/wbm9V5PM/1VVdNfW51um81mm3bfmcb9fj9+vz/2OJH7qSzGpzVZnubj09pSv9+PLG+J9Na83+spPT2dSCQCQCQSIS0tDTi/Uujv74/tFw6HcbvdU8b7+/txu91T5oyPjzMyMoLT6cTtdk95rgurGBERWTxxBYXP56OlpQWAlpYWioqKYuNtbW2MjY3R29vLyZMnyc3NxeVykZycTGdnJ5Zl0drais/nA2DdunU0NzcDcOTIEdauXYvNZqOgoIC33nqL4eFhhoeHeeutt2LfoBIRkcVz0VNPzz33HO+++y5DQ0P84Ac/4O6772bbtm0EAgGCwSBer5fy8nIAcnJy2LBhA+Xl5djtdrZv347dfj6LduzYQW1tLdFolIKCAgoLCwHYsmUL1dXV7N69G6fTSVlZGQBOp5NvfetbPP744wDcddddsQvbIiKyeGyW6QLCMpXI36PQNQqZyVK5RqEelZkk0qP6exQiIhI3BYWIiBgpKERExEhBISIiRgoKERExUlCIiIiRgkJERIwUFCIiYqSgEBERIwWFiIgYKShERMRIQSEiIkYKChERMVJQiIiIkYJCRESMFBQiImKkoBARESMFhYiIGCkoRETESEEhIiJGCgoRETFSUIiIiJGCQkREjBQUIiJipKAQEREjBYWIiBgpKERExEhBISIiRgoKERExUlCIiIiRgkJERIwUFCIiYpSUyOQ//elPBINBbDYbOTk57Nq1i2g0SiAQ4PTp02RkZLBnzx6cTicADQ0NBINB7HY7paWlFBQUANDd3U1NTQ3RaJTCwkJKS0ux2WyMjY1RXV1Nd3c3qamplJWVkZmZmfBBi4jI7MW9ogiHw/zlL3+hoqKCyspKJiYmaGtro7Gxkfz8fKqqqsjPz6exsRGAnp4e2tra2L9/P3v37qW+vp6JiQkA6urq2LlzJ1VVVZw6dYqOjg4AgsEgK1eu5MCBA2zdupXDhw8nfMAiIjI3CZ16mpiYIBqNMj4+TjQaxeVyEQqFKCkpAaCkpIRQKARAKBRi48aNrFixgszMTFatWkVXVxeRSITR0VHWrFmDzWZj06ZNsTnt7e1s3rwZgOLiYo4dO4ZlWYmULCIicxT3qSe32803v/lNHnroIa644gpuvPFGbrzxRgYGBnC5XAC4XC4GBweB8yuQvLy8SfPD4TAOhwOPxxMb93g8hMPh2JwL2xwOBykpKQwNDZGWljaplqamJpqamgCoqKjA6/XGe1h8EvdMudwl0lcXJCUlJfw86lGZyXz06HTiDorh4WFCoRA1NTWkpKSwf/9+WltbZ9x/ppWAaYUw3TabzTZlzO/34/f7Y4/7+vpMpYvEZT76yuv1qj9lwSTSW1lZWTNui/vU09tvv01mZiZpaWkkJSWxfv16Ojs7SU9PJxKJABCJRGKf/j0eD/39/bH54XAYt9s9Zby/vx+32z1lzvj4OCMjI7EL4yIisjjiDgqv18sHH3zAZ599hmVZvP3222RnZ+Pz+WhpaQGgpaWFoqIiAHw+H21tbYyNjdHb28vJkyfJzc3F5XKRnJxMZ2cnlmXR2tqKz+cDYN26dTQ3NwNw5MgR1q5dO+2KQkREFk7cp57y8vIoLi7mxz/+MQ6Hg2uuuQa/38/Zs2cJBAIEg0G8Xi/l5eUA5OTksGHDBsrLy7Hb7Wzfvh27/XxO7dixg9raWqLRKAUFBRQWFgKwZcsWqqur2b17N06nk7KyssSPWERE5sRmXYZfIzpx4kTcc8e/d8c8ViKXE0fdywk/x3xco1CPykwS6dEFuUYhIiKfDwoKERExUlCIiIiRgkJERIwUFCIiYqSgEBERIwWFiIgYKShERMRIQSEiIkYKChERMVJQiIiIkYJCRESMFBQiImKkoBARESMFhYiIGCkoRETESEEhIiJGCgoRETFSUIiIiJGCQkREjBQUIiJipKAQEREjBYWIiBgpKERExEhBISIiRgoKERExU
lCIiIiRgkJERIwUFCIiYqSgEBERIwWFiIgYKShERMQoKZHJn376KQcPHuTjjz/GZrPx0EMPkZWVRSAQ4PTp02RkZLBnzx6cTicADQ0NBINB7HY7paWlFBQUANDd3U1NTQ3RaJTCwkJKS0ux2WyMjY1RXV1Nd3c3qamplJWVkZmZmfBBi4jI7CW0ojh06BAFBQU899xzPPPMM2RnZ9PY2Eh+fj5VVVXk5+fT2NgIQE9PD21tbezfv5+9e/dSX1/PxMQEAHV1dezcuZOqqipOnTpFR0cHAMFgkJUrV3LgwAG2bt3K4cOHEzpYERGZu7iDYmRkhH//+99s2bIFgKSkJFauXEkoFKKkpASAkpISQqEQAKFQiI0bN7JixQoyMzNZtWoVXV1dRCIRRkdHWbNmDTabjU2bNsXmtLe3s3nzZgCKi4s5duwYlmUlcrwiIjJHcZ966u3tJS0tjdraWj766CNWr17NAw88wMDAAC6XCwCXy8Xg4CAA4XCYvLy82Hy32004HMbhcODxeGLjHo+HcDgcm3Nhm8PhICUlhaGhIdLS0ibV0tTURFNTEwAVFRV4vd54D4tP4p4pl7tE+uqCpKSkhJ9HPSozmY8enU7cQTE+Ps7x48d58MEHycvL49ChQ7HTTNOZaSVgWiFMt81ms00Z8/v9+P3+2OO+vj5D5SLxmY++8nq96k9ZMIn0VlZW1ozb4j715PF48Hg8sVVCcXExx48fJz09nUgkAkAkEol9+vd4PPT398fmh8Nh3G73lPH+/n7cbveUOePj44yMjMQujIuIyOKIOyiuuuoqPB4PJ06cAODtt9/mS1/6Ej6fj5aWFgBaWlooKioCwOfz0dbWxtjYGL29vZw8eZLc3FxcLhfJycl0dnZiWRatra34fD4A1q1bR3NzMwBHjhxh7dq1064oRERk4ST09dgHH3yQqqoqzp07R2ZmJrt27cKyLAKBAMFgEK/XS3l5OQA5OTls2LCB8vJy7HY727dvx24/n1M7duygtraWaDRKQUEBhYWFAGzZsoXq6mp2796N0+mkrKwssaMVEZE5s1mX4deILqxy4jH+vTvmsRK5nDjqXk74OebjGoV6VGaSSI8uyDUKERH5fFBQiIiIkYJCRESMFBQiImKkoBARESMFhYiIGCkoRETESEEhIiJGCgoRETFSUIiIiJGCQkREjBQUIiJipKAQEREjBYWIiBgpKERExEhBISIiRgoKERExUlCIiIiRgkJERIwUFCIiYqSgEBERIwWFiIgYKShERMRIQSEiIkYKChERMVJQiIiIkYJCRESMFBQiImKkoBARESMFhYiIGCkoRETESEEhIiJGSYk+wcTEBI899hhut5vHHnuM4eFhAoEAp0+fJiMjgz179uB0OgFoaGggGAxit9spLS2loKAAgO7ubmpqaohGoxQWFlJaWorNZmNsbIzq6mq6u7tJTU2lrKyMzMzMREsWEZE5SHhF8ec//5ns7OzY48bGRvLz86mqqiI/P5/GxkYAenp6aGtrY//+/ezdu5f6+nomJiYAqKurY+fOnVRVVXHq1Ck6OjoACAaDrFy5kgMHDrB161YOHz6caLkiIjJHCQVFf38/R48e5dZbb42NhUIhSkpKACgpKSEUCsXGN27cyIoVK8jMzGTVqlV0dXURiUQYHR1lzZo12Gw2Nm3aFJvT3t7O5s2bASguLubYsWNYlpVIySIiMkcJnXr67W9/y7333svo6GhsbGBgAJfLBYDL5WJwcBCAcDhMXl5ebD+32004HMbhcODxeGLjHo+HcDgcm3Nhm8PhICUlhaGhIdLS0ibV0dTURFNTEwAVFRV4vd64j+mTuGfK5S6RvrogKSkp4edRj8pM5qNHpxN3ULzxxhukp6ezevVq3nnnnYvuP9NKwLRCmG6bzWabMub3+/H7/bHHfX19F61HZK7mo6+8Xq/6UxZMIr2VlZU147a4g+L999+nvb2dN998k2g0yujoKFVVVaSnpxOJRHC5XEQikdinf4/HQ39/f2x+OBzG7XZPGe/v78ftdk+a4/F4GB8fZ2RkJHZhXEREFkfc1yi++93vcvDgQWpqaigrK+MrX/kKjzzyCD6fj5aWFgBaWlooKioCwOfz0dbWxtjYGL29vZw8eZLc3FxcLhfJycl0dnZiWRatra34fD4A1q1bR3NzMwBHjhxh7dq1064oRERk4ST89dj/tW3bNgKBAMFgEK/XS3l5OQA5OTls2LCB8vJy7HY727dvx24/n1M7duygtraWaDRKQUEBhYWFAGzZsoXq6mp2796N0+mkrKxsvssVEZGLsFmX4deITpw4Effc8e/dMY+VyOXEUfdyws8xH9co1KMyk0R61HSNQr+ZLSIiRgoKERExUlCIiIiRgkJERIwUFCIiYqSgEBERIwWFiIgYKShERMRIQSEiIkYKChERMVJQiIiIkYJCRESMFBQiImKkoBARESMFhYiIGCkoRETESEEhIiJGCgoRETFSUIiIiJGCQkREjBQUIiJipKAQEREjBYWIiBgpKERExEhBISIiRgoKERExUlCIiIiRgkJERIwUFCIiYqSgEBERIwWFiIgYKShERMQoKd6JfX191NTUcObMGWw2G36/n9tuu43h4WECgQCnT58mIyODPXv24HQ6AWhoaCAYDGK32yktLaWgoACA7u5uampqiEajFBYWUlpais1mY2xsjOrqarq7u0lNTaWsrIzMzMx5OXAREZmduFcUDoeD++67j0AgwL59+3jllVfo6emhsbGR/Px8qqqqyM/Pp7GxEYCenh7a2trYv38/e/fupb6+nomJCQDq6urYuXMnVVVVnDp1io6ODgCCwSArV67kwIEDbN26lcOHDyd8wCIiMjdxB4XL5WL16tUAJCcnk52dTTgcJhQKUVJSAkBJSQmhUAiAUCjExo0bWbFiBZmZmaxatYquri4ikQijo6OsWbMGm83Gpk2bYnPa29vZvHkzAMXFxRw7dgzLshI5XhERmaN5uUbR29vL8ePHyc3NZWBgAJfLBZwPk8HBQQDC4TAejyc2x+12Ew6Hp4x7PB7C4fCUOQ6Hg5SUFIaGhuajZBERmaW4r1FccPbsWSorK3nggQdISUmZcb+ZVgKmFcJ022w225SxpqYmmpqaAKioqMDr9V6s7Bl9EvdMudwl0lcXJCUlJfw86lGZyXz06HQSCopz585RWVnJLbfcwvr16wFIT08nEongcrmIRCKkpaUB51cK/f39sbnhcBi32z1lvL+/H7fbPWmOx+NhfHyckZGR2IXx/+b3+/H7/bHHfX19iRyWyLTmo6+8Xq/6UxZMIr2VlZU147a4Tz1ZlsXBgwfJzs7m9ttvj437fD5aWloAaGlpoaioKDbe1tbG2NgYvb29nDx5ktzcXFwuF8nJyXR2dmJZFq2trfh8PgDWrVtHc3MzAEeOHGHt2rXTrihERGThxL2ieP/992ltbeXqq6/m0UcfBeCee+5h27ZtBAIBgsEgXq+X8vJyAHJyctiwYQPl5eXY7Xa2b9+O3X4+p3bs2EFtbS3RaJSCggIKCwsB2LJl
C9XV1ezevRun00lZWVmChysiInNlsy7DrxGdOHEi7rnj37tjHiuRy4mj7uWEn2M+Tj2pR2UmifTogpx6EhGRzwcFhYiIGCkoRETESEEhIiJGCgoRETFSUIiIiJGCQkREjBQUIiJipKAQEREjBYWIiBgpKERExEhBISIiRgoKERExUlCIiIiRgkJERIwUFCIiYqSgEBERIwWFiIgYKShERMRIQSEiIkYKChERMVJQiIiIkYJCRESMFBQiImKkoBARESMFhYiIGCkoRETESEEhIiJGCgoRETFSUIiIiJGCQkREjBQUIiJipKAQERGjpEtdwGx0dHRw6NAhJiYmuPXWW9m2bdulLklE5HNjya8oJiYmqK+v54knniAQCPD666/T09NzqcsSEfncWPJB0dXVxapVq/jiF79IUlISGzduJBQKXeqyREQ+N5b8qadwOIzH44k99ng8fPDBB5P2aWpqoqmpCYCKigqysrLif8H/a49/rsgsJNSfoB6VRbfkVxSWZU0Zs9lskx77/X4qKiqoqKhYrLIS8thjj13qEmZludQJy6dW1Tm/lkudsLxq/V9LPig8Hg/9/f2xx/39/bhcrktYkYjI58uSD4prr72WkydP0tvby7lz52hra8Pn813qskREPjeW/DUKh8PBgw8+yL59+5iYmOBrX/saOTk5l7qshPj9/ktdwqwslzph+dSqOufXcqkTllet/8tmTXcRQERE5P9b8qeeRETk0lJQiIiI0ZK/RrFcDQ8PEwgEOH36NBkZGezZswen0zlpn76+Pmpqajhz5gw2mw2/389tt90GwB/+8Af+9re/kZaWBsA999zDTTfdNG/1Xey2KJZlcejQId58802+8IUvsGvXLlavXj2rufPpYq/12muv8dJLLwFw5ZVXsmPHDq655hoAfvjDH3LllVdit9txOBwL+vXpi9X5zjvv8Ktf/YrMzEwA1q9fz1133TWruYtd68svv8xrr70GnL8zQk9PD/X19TidzkV7T2trazl69Cjp6elUVlZO2b5U+nM2tS6VHk2IJQvi+eeftxoaGizLsqyGhgbr+eefn7JPOBy2PvzwQ8uyLGtkZMR65JFHrI8//tiyLMt64YUXrJdeemlBahsfH7cefvhh69SpU9bY2Jj1ox/9KPa6F7zxxhvWvn37rImJCev999+3Hn/88VnPXcw633vvPWtoaMiyLMs6evRorE7Lsqxdu3ZZAwMDC1LbXOs8duyY9Ytf/CKuuYtd638LhULWk08+GXu8WO/pO++8Y3344YdWeXn5tNuXQn/Ottal0KOJ0qmnBRIKhSgpKQGgpKRk2tuOuFyu2Keg5ORksrOzCYfDC17bbG6L0t7ezqZNm7DZbKxZs4ZPP/2USCSyqLdUmc1rXXfddbGVWl5e3qTfuVksibwni32Lmrm+3uuvv85Xv/rVBatnJtdff/2UFfh/Wwr9Odtal0KPJkqnnhbIwMBA7BcDXS4Xg4ODxv17e3s5fvw4ubm5sbFXXnmF1tZWVq9ezf33329sxrmYzW1RwuEwXq930j7hcHhWc+fLXF8rGAxSWFg4aWzfvn0AfP3rX1+wryfOts7Ozk4effRRXC4X9913Hzk5OYv6fs6lVoDPPvuMjo4Otm/fPml8Md7Ti1kK/RmPS9WjiVJQJOCpp57izJkzU8a/853vzOl5zp49S2VlJQ888AApKSkAfOMb34idw37hhRf4/e9/z65duxKuGWZ3W5SZ9pnN3Pkyl9c6duwYf//73/nZz34WG3vqqadwu90MDAzw85//nKysLK6//vpLUueXv/xlamtrufLKKzl69CjPPPMMVVVVi/p+zrbWC954441Jn4Zh8d7Ti1kK/TlXl7JHE6WgSMBPfvKTGbelp6cTiURwuVxEIpHYRen/de7cOSorK7nllltYv359bPyqq66K/Xzrrbfyy1/+ct7qns1tUTweD319fVP2OXfu3KLdUmW2t2/56KOP+PWvf83jjz9OampqbNztdgPn/1sUFRXR1dW1IP8IZ1PnhQ8AADfddBP19fUMDg4u+i1q5vJ6r7/+OjfffPOkscV6Ty9mKfTnXFzqHk2UrlEsEJ/PR0tLCwAtLS0UFRVN2ceyLA4ePEh2dja33377pG2RSCT28z//+c95/W302dwWxefz0draimVZdHZ2kpKSgsvlWtRbqszmtfr6+nj22Wd5+OGHJ92V9ezZs4yOjsZ+/te//sXVV199yeo8c+ZM7NNuV1cXExMTpKamLvotamb7eiMjI7z77ruTti3me3oxS6E/Z2sp9Gii9JvZC2RoaIhAIEBfXx9er5fy8nKcTifhcDj2yeK9997jpz/9KVdffXVseXzha7AHDhzgP//5DzabjYyMDL7//e/P6yejo0eP8rvf/S52W5Q777yTV199FTh/2suyLOrr63nrrbe44oor2LVrF9dee+2McxfKxeo8ePAg//jHP2Lnqy98xfCTTz7h2WefBWB8fJybb775ktb517/+lVdffRWHw8EVV1zB/fffz3XXXTfj3IV0sVoBmpub6ejooKysLDZvMd/T5557jnfffZehoSHS09O5++67OXfuXKzGpdKfs6l1qfRoIhQUIiJipFNPIiJipKAQEREjBYWIiBgpKERExEhBISIiRgoKERExUlCIiIjR/wNdJQ5lDA6CbwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.bar([0, 1], [len(test_labels[test_labels==0]), len(test_labels[test_labels==1])])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Naive DL implementation" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_history(history):\n", + " acc = history.history['acc']\n", + " val_acc = history.history['val_acc']\n", + " loss = history.history['loss']\n", + " val_loss = history.history['val_loss']\n", + " x = range(1, len(acc) + 1)\n", + "\n", + " plt.figure(figsize=(12, 5))\n", + " plt.subplot(1, 2, 1)\n", + " plt.plot(x, acc, 'b', label='Training acc')\n", + " plt.plot(x, val_acc, 'r', label='Validation acc')\n", + " plt.title('Training and validation accuracy')\n", + " plt.legend()\n", + " plt.subplot(1, 2, 2)\n", + " plt.plot(x, loss, 'b', label='Training loss')\n", + " plt.plot(x, val_loss, 'r', label='Validation loss')\n", + " plt.title('Training and validation loss')\n", + " plt.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bb27037fede14648968b62c4e87937a5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=500000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7857701f5a424c5a8fdb89f115a51cca", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=500000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "53ac9841498f4e7fbaaf012990bd8911", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=200000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "464670\n", + "WARNING:tensorflow:From /usr/local/Caskroom/miniconda/base/envs/bertsum/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "If using Keras pass *_constraint arguments to layers.\n", + "WARNING:tensorflow:From /usr/local/Caskroom/miniconda/base/envs/bertsum/lib/python3.7/site-packages/tensorflow_core/python/ops/nn_impl.py:183: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Use tf.where in 2.0, which has the same broadcast rule as np.where\n", + "Model: \"sequential\"\n", + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "dense (Dense) (None, 64) 29738944 \n", + "_________________________________________________________________\n", + "dense_1 (Dense) (None, 1) 65 \n", + 
"=================================================================\n", + "Total params: 29,739,009\n", + "Trainable params: 29,739,009\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "import tensorflow as tf\n", + "from tensorflow import keras\n", + "\n", + "vectorizer = CountVectorizer()\n", + "vectorizer.fit(tqdm(train_texts_stammed))\n", + "train_X = vectorizer.transform(tqdm(train_texts_stammed))\n", + "test_X = vectorizer.transform(tqdm(test_texts_stemmed))\n", + "\n", + "input_dim = train_X.shape[1]\n", + "print(input_dim)\n", + "\n", + "model = keras.models.Sequential([\n", + " keras.layers.Dense(64, input_dim=input_dim, activation='relu'),\n", + " keras.layers.Dense(1, activation='sigmoid')\n", + "])\n", + "\n", + "model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n", + "\n", + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mvalidation_data\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_X\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_labels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m batch_size=32)\n\u001b[0m", + "\u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/bertsum/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mmax_queue_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_queue_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0mworkers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mworkers\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m use_multiprocessing=use_multiprocessing)\n\u001b[0m\u001b[1;32m 728\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 729\u001b[0m def evaluate(self,\n", + "\u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/bertsum/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, **kwargs)\u001b[0m\n\u001b[1;32m 641\u001b[0m \u001b[0msteps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msteps_per_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 642\u001b[0m 
\u001b[0mvalidation_split\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mvalidation_split\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 643\u001b[0;31m shuffle=shuffle)\n\u001b[0m\u001b[1;32m 644\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 645\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mvalidation_data\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/bertsum/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py\u001b[0m in \u001b[0;36m_standardize_user_data\u001b[0;34m(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle, extract_tensors_from_dataset)\u001b[0m\n\u001b[1;32m 2487\u001b[0m \u001b[0mconverted_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2488\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflat_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflat_expected_inputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2489\u001b[0;31m \u001b[0mconverted_x\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_convert_scipy_sparse_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2490\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpack_sequence_as\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconverted_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexpand_composites\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2491\u001b[0m \u001b[0mx_shapes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmap_structure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype_spec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtype_spec_from_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/bertsum/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py\u001b[0m in \u001b[0;36m_convert_scipy_sparse_tensor\u001b[0;34m(value, expected_input)\u001b[0m\n\u001b[1;32m 3233\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0missparse\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0missparse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3234\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_dense_tensor_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexpected_input\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3235\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
\u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtoarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3236\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3237\u001b[0m \u001b[0msparse_coo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtocoo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/bertsum/lib/python3.7/site-packages/scipy/sparse/compressed.py\u001b[0m in \u001b[0;36mtoarray\u001b[0;34m(self, order, out)\u001b[0m\n\u001b[1;32m 1034\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1035\u001b[0m \u001b[0mM\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mN\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_swap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1036\u001b[0;31m \u001b[0mcsr_todense\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mM\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mN\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindptr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1037\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1038\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "history = model.fit(train_X, train_labels, \n", + " validation_data=(test_X, test_labels),\n", + " epochs=3,\n", + " batch_size=32)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_history(history)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model with Embedding" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "555f313ae4bc45b88d702ad2c6061ece", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=1000000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "from tensorflow.keras.preprocessing.text import Tokenizer\n", + "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", + "\n", + "tokenizer = Tokenizer(num_words=50000)\n", + "tokenizer.fit_on_texts(tqdm(train_texts))\n", + "\n", + "vocab_size = len(tokenizer.word_index) + 1" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "metadata": {}, + "outputs": [], + "source": [ + "X_train = tokenizer.texts_to_sequences(tqdm(train_texts))\n", + "X_test = tokenizer.texts_to_sequences(tqdm(test_texts))" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "metadata": 
{}, + "outputs": [], + "source": [ + "max_len = 512\n", + "\n", + "X_train = pad_sequences(X_train, padding='post', maxlen=max_len)\n", + "X_test = pad_sequences(X_test, padding='post', maxlen=max_len)" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:tensorflow:From /usr/local/Caskroom/miniconda/base/envs/bertsum/lib/python3.7/site-packages/tensorflow_core/python/keras/initializers.py:119: calling RandomUniform.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Call initializer instance with the dtype argument instead of passing it to the constructor\n", + "Model: \"sequential_2\"\n", + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "embedding (Embedding) (None, 512, 50) 20122500 \n", + "_________________________________________________________________\n", + "flatten (Flatten) (None, 25600) 0 \n", + "_________________________________________________________________\n", + "dense_6 (Dense) (None, 128) 3276928 \n", + "_________________________________________________________________\n", + "dense_7 (Dense) (None, 128) 16512 \n", + "_________________________________________________________________\n", + "dense_8 (Dense) (None, 1) 129 \n", + "=================================================================\n", + "Total params: 23,416,069\n", + "Trainable params: 23,416,069\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "embedding_dim = 50\n", + "\n", + "model = keras.models.Sequential([\n", + " keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim, input_length=max_len),\n", + " keras.layers.Flatten(),\n", + " keras.layers.Dense(64, input_dim=input_dim, activation='relu'),\n", + " keras.layers.Dense(1, activation='sigmoid')\n", + "])\n", + "\n", + "model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n", + "\n", + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "history = model.fit(X_train, train_labels, \n", + " validation_data=(X_test, test_labels),\n", + " epochs=30,\n", + " batch_size=32)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_history(history)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "embedding_dim = 50\n", + "\n", + "model = keras.models.Sequential([\n", + " keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim, input_length=max_len),\n", + " keras.layers.GlobalAveragePooling1D(),\n", + " keras.layers.Dense(64, activation='relu'),\n", + " keras.layers.Dense(1, activation='sigmoid')\n", + "])\n", + "\n", + "model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n", + "\n", + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "history = model.fit(X_train, train_labels, \n", + " validation_data=(X_test, test_labels),\n", + " epochs=3,\n", + " batch_size=32)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + 
"outputs": [], + "source": [ + "plot_history(history)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model with pre-trained Embedding" + ] + }, + { + "cell_type": "code", + "execution_count": 141, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dyld: Library not loaded: /usr/local/opt/openssl/lib/libssl.1.0.0.dylib\r\n", + " Referenced from: /usr/local/bin/wget\r\n", + " Reason: image not found\r\n" + ] + } + ], + "source": [ + "!wget http://nlp.stanford.edu/data/glove.twitter.27B.zip" + ] + }, + { + "cell_type": "code", + "execution_count": 143, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Archive: glove.twitter.27B.zip\n", + " inflating: glove.twitter.27B.25d.txt \n", + " inflating: glove.twitter.27B.50d.txt \n", + " inflating: glove.twitter.27B.100d.txt \n", + " inflating: glove.twitter.27B.200d.txt \n", + "Untitled.ipynb glove.twitter.27B.200d.txt\n", + "\u001b[34marchive-2\u001b[m\u001b[m glove.twitter.27B.25d.txt\n", + "glove.twitter.27B.100d.txt glove.twitter.27B.50d.txt\n" + ] + } + ], + "source": [ + "!unzip glove.twitter.27B.zip\n", + "!rm glove.twitter.27B.zip\n", + "!ls" + ] + }, + { + "cell_type": "code", + "execution_count": 192, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2a992cc357ce4e71978f98466a85dc70", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=1000000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "713947" + ] + }, + "execution_count": 192, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer = Tokenizer(num_words=5000)\n", + "tokenizer.fit_on_texts(tqdm(train_texts))\n", + "\n", + "vocab_size = len(tokenizer.word_index) + 1\n", + "vocab_size" + ] + }, + { + "cell_type": "code", + "execution_count": 229, + "metadata": {}, + "outputs": [], + "source": [ + "vocabulary = set()\n", + "\n", + "def load_embeddings(file_name, total=1):\n", + " weights = np.zeros((vocab_size, embedding_dim))\n", + " \n", + " with open(file_name) as f:\n", + " for l in tqdm(f, total=total):\n", + " word, *vector = l.split()\n", + " word = word.lower()\n", + " vocabulary.add(word)\n", + " \n", + " if word in tokenizer.word_index:\n", + " weights[tokenizer.word_index[word]] = np.array(vector, dtype=np.float32)\n", + " return weights" + ] + }, + { + "cell_type": "code", + "execution_count": 231, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a5dd687b4c7d4a8d8c9c5322cfaf85ad", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=400000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "0.1925829228220022" + ] + }, + "execution_count": 231, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "w = load_embeddings('glove.6B.50d.txt', total=400000)\n", + "\n", + "nonzero_elements = np.count_nonzero(np.count_nonzero(w, axis=1))\n", + "nonzero_elements / vocab_size" + ] + }, + { + "cell_type": "code", + "execution_count": 228, + 
"metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"sequential_11\"\n", + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "embedding_8 (Embedding) (None, 512, 50) 35697350 \n", + "_________________________________________________________________\n", + "flatten_8 (Flatten) (None, 25600) 0 \n", + "_________________________________________________________________\n", + "dense_33 (Dense) (None, 64) 1638464 \n", + "_________________________________________________________________\n", + "dense_34 (Dense) (None, 1) 65 \n", + "=================================================================\n", + "Total params: 37,335,879\n", + "Trainable params: 1,638,529\n", + "Non-trainable params: 35,697,350\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "embedding_dim = 50\n", + "\n", + "model = keras.models.Sequential([\n", + " keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim, \n", + " input_length=max_len, weights=[w], trainable=False),\n", + " keras.layers.Flatten(),\n", + " keras.layers.Dense(64, activation='relu'),\n", + " keras.layers.Dense(1, activation='sigmoid')\n", + "])\n", + "\n", + "model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n", + "\n", + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Transformer (spoiler alert!)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install tqdm==4.47.0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install transformers==3.5.0 simpletransformers==0.49.3" + ] + }, + { + "cell_type": "code", + "execution_count": 170, + "metadata": {}, + "outputs": [], + "source": [ + "from simpletransformers.classification import ClassificationModel\n", + "import logging\n", + "\n", + "logging.basicConfig(level=logging.INFO)\n", + "transformers_logger = logging.getLogger(\"transformers\")\n", + "transformers_logger.setLevel(logging.WARNING)" + ] + }, + { + "cell_type": "code", + "execution_count": 171, + "metadata": {}, + "outputs": [], + "source": [ + "train_data_df = pd.DataFrame({'text': train_texts[:200000], 'labels': train_labels[:200000]})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "eval_data_df = pd.DataFrame({'text': train_texts[200000:250000], 'labels': train_labels[200000:250000]})" + ] + }, + { + "cell_type": "code", + "execution_count": 174, + "metadata": {}, + "outputs": [], + "source": [ + "test_data_df = pd.DataFrame({'text': test_texts, 'labels': test_labels})" + ] + }, + { + "cell_type": "code", + "execution_count": 177, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:filelock:Lock 6910369040 acquired on /Users/dimazhylko/.cache/torch/transformers/51ba668f7ff34e7cdfa9561e8361747738113878850a7d717dbc69de8683aaad.c7efaa30a0d80b2958b876969faa180e485944a849deee4ad482332de65365a7.lock\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ec33c6e53af7411aadc0fa2dc7964ce5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, 
description='Downloading', max=501200538.0, style=ProgressStyle(descri…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:filelock:Lock 6910369040 released on /Users/dimazhylko/.cache/torch/transformers/51ba668f7ff34e7cdfa9561e8361747738113878850a7d717dbc69de8683aaad.c7efaa30a0d80b2958b876969faa180e485944a849deee4ad482332de65365a7.lock\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']\n", + "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "INFO:filelock:Lock 5249639248 acquired on /Users/dimazhylko/.cache/torch/transformers/d3ccdbfeb9aaa747ef20432d4976c32ee3fa69663b379deb253ccfce2bb1fdc5.d67d6b367eb24ab43b08ad55e014cf254076934f71d832bbab9ad35644a375ab.lock\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a18b1229ff9c4637ae0ea7f94d24e8f5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898823.0, style=ProgressStyle(descripti…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:filelock:Lock 5249639248 released on /Users/dimazhylko/.cache/torch/transformers/d3ccdbfeb9aaa747ef20432d4976c32ee3fa69663b379deb253ccfce2bb1fdc5.d67d6b367eb24ab43b08ad55e014cf254076934f71d832bbab9ad35644a375ab.lock\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:filelock:Lock 5249686544 acquired on /Users/dimazhylko/.cache/torch/transformers/cafdecc90fcab17011e12ac813dd574b4b3fea39da6dd817813efa010262ff3f.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b.lock\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "36435f2eb8d84fad9779daab0af6d0e0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:filelock:Lock 5249686544 released on 
/Users/dimazhylko/.cache/torch/transformers/cafdecc90fcab17011e12ac813dd574b4b3fea39da6dd817813efa010262ff3f.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b.lock\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "model = ClassificationModel('roberta', 'roberta-base', use_cuda=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "del train_texts\n", + "del train_labels\n", + "del test_texts\n", + "del test_labels" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.train_model(train_data_df, eval_df=eval_data_df, \n", + " args={\"num_train_epochs\": 1, 'evaluate_during_training': True,\n", + " 'learning_rate': 5e-5, 'train_batch_size': 32, 'eval_batch_size': 32, 'gradient_accumulation_steps': 1, \n", + " 'use_multiprocessing': False, 'fp16': True, 'lazy_loading': False, 'reprocess_input_data': False,\n", + " 'save_steps': 7000}, acc=accuracy_score)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.args.reprocess_input_data = True\n", + "model.eval_model(test_data_df, acc=accuracy_score)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}