diff --git a/CHANGELOG.md b/CHANGELOG.md
index ac9a48cf..f38d4099 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [Unreleased]
+
+### Added
+- Tutorial for making custom model inherited from ModelBase ([#236](https://github.com/MobileTeleSystems/RecTools/pull/236))
## [0.9.0] - 11.12.2024
diff --git a/examples/10_custom_model_creation.ipynb b/examples/10_custom_model_creation.ipynb
new file mode 100644
index 00000000..7ad943f4
--- /dev/null
+++ b/examples/10_custom_model_creation.ipynb
@@ -0,0 +1,896 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "19d09e64-aa80-47e1-9c8b-a5a24564bee7",
+ "metadata": {},
+ "source": [
+ "# Example of building custom model with ModelBase class\n",
+ "\n",
+ "- Building custom model\n",
+ "- Visual recommendations checking"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "id": "dce07a5b-2716-41c5-8358-63f590dd69f0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from rectools.models.base import ModelBase, ModelConfig, Scores, ScoresArray\n",
+ "from rectools import Columns\n",
+ "from rectools.dataset import Dataset\n",
+ "from scipy.sparse import csr_matrix\n",
+ "from sklearn.neighbors import NearestNeighbors\n",
+ "import typing as tp\n",
+ "import typing_extensions as tpe\n",
+ "from rectools.models.base import InternalIdsArray\n",
+ "from rectools.types import *\n",
+ "from tqdm import tqdm\n",
+ "from rectools.utils import fast_isin_for_sorted_test_elements\n",
+ "from rectools.models.utils import get_viewed_item_ids\n",
+ "from rectools.visuals.visual_app import ItemToItemVisualApp, VisualApp\n",
+ "import random"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c05f581d-c60b-40c1-8e52-368e1d97ab36",
+ "metadata": {},
+ "source": [
+ "## Load data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "7814e510-3179-44f0-b116-a30b670fa72e",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Archive: ml-1m.zip\n",
+ " inflating: ml-1m/movies.dat \n",
+ " inflating: ml-1m/ratings.dat \n",
+ " inflating: ml-1m/README \n",
+ " inflating: ml-1m/users.dat \n",
+ "CPU times: user 27.2 ms, sys: 23 ms, total: 50.2 ms\n",
+ "Wall time: 3.21 s\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "!wget -q https://files.grouplens.org/datasets/movielens/ml-1m.zip -O ml-1m.zip\n",
+ "!unzip -o ml-1m.zip\n",
+ "!rm ml-1m.zip"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "884ebb14-fcae-4bbc-9221-299b613489f0",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(1000209, 4)\n",
+ "CPU times: user 2.04 s, sys: 93.3 ms, total: 2.13 s\n",
+ "Wall time: 2.13 s\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " item_id | \n",
+ " weight | \n",
+ " datetime | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 1193 | \n",
+ " 5 | \n",
+ " 978300760 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 661 | \n",
+ " 3 | \n",
+ " 978302109 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 1 | \n",
+ " 914 | \n",
+ " 3 | \n",
+ " 978301968 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 1 | \n",
+ " 3408 | \n",
+ " 4 | \n",
+ " 978300275 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 1 | \n",
+ " 2355 | \n",
+ " 5 | \n",
+ " 978824291 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " user_id item_id weight datetime\n",
+ "0 1 1193 5 978300760\n",
+ "1 1 661 3 978302109\n",
+ "2 1 914 3 978301968\n",
+ "3 1 3408 4 978300275\n",
+ "4 1 2355 5 978824291"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "ratings = pd.read_csv(\n",
+ " \"ml-1m/ratings.dat\", \n",
+ " sep=\"::\",\n",
+ " engine=\"python\", # Because of 2-chars separators\n",
+ " header=None,\n",
+ " names=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],\n",
+ ")\n",
+ "print(ratings.shape)\n",
+ "ratings.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "332cd6dd-d993-47b9-baec-c601b89fba12",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(3883, 3)\n",
+ "CPU times: user 4.71 ms, sys: 792 μs, total: 5.5 ms\n",
+ "Wall time: 4.99 ms\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " item_id | \n",
+ " title | \n",
+ " genres | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " Toy Story (1995) | \n",
+ " Animation|Children's|Comedy | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 2 | \n",
+ " Jumanji (1995) | \n",
+ " Adventure|Children's|Fantasy | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 3 | \n",
+ " Grumpier Old Men (1995) | \n",
+ " Comedy|Romance | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 4 | \n",
+ " Waiting to Exhale (1995) | \n",
+ " Comedy|Drama | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 5 | \n",
+ " Father of the Bride Part II (1995) | \n",
+ " Comedy | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " item_id title genres\n",
+ "0 1 Toy Story (1995) Animation|Children's|Comedy\n",
+ "1 2 Jumanji (1995) Adventure|Children's|Fantasy\n",
+ "2 3 Grumpier Old Men (1995) Comedy|Romance\n",
+ "3 4 Waiting to Exhale (1995) Comedy|Drama\n",
+ "4 5 Father of the Bride Part II (1995) Comedy"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "movies = pd.read_csv(\n",
+ " \"ml-1m/movies.dat\", \n",
+ " sep=\"::\",\n",
+ " engine=\"python\", # Because of 2-chars separators\n",
+ " header=None,\n",
+ " names=[Columns.Item, \"title\", \"genres\"],\n",
+ " encoding_errors=\"ignore\",\n",
+ ")\n",
+ "print(movies.shape)\n",
+ "movies.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3795f0e2-ac3e-4c89-a901-2988522d0629",
+ "metadata": {},
+ "source": [
+ "## Build model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d3f41d9e-b6aa-4a66-b681-2b504c80fbd0",
+ "metadata": {},
+ "source": [
+ "### Write a model config inherited from `ModelConfig`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "95f01cdf-b126-4035-a5b6-867d45aeaa99",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class MixedKnnRandomModelConfig(ModelConfig):\n",
+ " \"\"\"Config for `KNN` model.\"\"\"\n",
+ "\n",
+ " # KNN algorithm hyperparams\n",
+ " metric: tp.Optional[str] = None\n",
+ " algorithm: tp.Optional[str] = None\n",
+ " n_neighbors: tp.Optional[int] = None\n",
+ " n_jobs: tp.Optional[int] = None\n",
+ " random_state: tp.Optional[int] = None"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6eecbbd5",
+ "metadata": {},
+ "source": [
+ "### Define a `_RandomSampler` and `_RandomGen` class for random recommendations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "60255323",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class _RandomGen:\n",
+ " def __init__(self, random_state: tp.Optional[int] = None) -> None:\n",
+ " self.python_gen = random.Random(random_state) # nosec\n",
+ " self.np_gen = np.random.default_rng(random_state)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "e3283dfd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class _RandomSampler:\n",
+ " def __init__(self, values: np.ndarray, random_gen: _RandomGen) -> None:\n",
+ " self.python_gen = random_gen.python_gen\n",
+ " self.np_gen = random_gen.np_gen\n",
+ " self.values = values\n",
+ " self.values_list = list(values) # for random.sample\n",
+ "\n",
+ " def sample(self, n: int) -> np.ndarray:\n",
+ " if n < 25: # Empiric value, for optimization\n",
+ " sampled = np.asarray(self.python_gen.sample(self.values_list, n))\n",
+ " else:\n",
+ " sampled = self.np_gen.choice(self.values, n, replace=False)\n",
+ " return sampled\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9cf2011c-6b5a-4707-a640-333664cc85b5",
+ "metadata": {},
+ "source": [
+ "### Write a model logic in class inherited from `ModelBase`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "4f2cb8cf-17c3-4f0e-a979-4411a35af283",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class MixedKnnRandomModel(ModelBase[MixedKnnRandomModelConfig]):\n",
+ " # There is a sample recsys model inherited from ModelBase, model is mixed KNN wrapper (i2i) and random model (u2i)\n",
+ " # You can mix other models as well.\n",
+ " # Define is able to make cold and warm recommendations\n",
+ " # Set config class to defined above\n",
+ " recommends_for_warm = False\n",
+ " recommends_for_cold = False\n",
+ " config_class = MixedKnnRandomModelConfig\n",
+ "\n",
+ " # Set all hyperparams in __init__\n",
+ " def __init__(self,\n",
+ " metric: tp.Optional[str] = None,\n",
+ " algorithm: tp.Optional[str] = None,\n",
+ " n_neighbors: tp.Optional[int] = None,\n",
+ " n_jobs: tp.Optional[int] = None,\n",
+ " random_state: tp.Optional[int] = None,\n",
+ " verbose: int = 0):\n",
+ " super().__init__(verbose=verbose)\n",
+ " self.metric = metric\n",
+ " self.algorithm = algorithm\n",
+ " self.n_neighbors = n_neighbors\n",
+ " self.n_jobs = n_jobs\n",
+ " self.knn_model = NearestNeighbors(metric = self.metric,\n",
+ " algorithm = self.algorithm,\n",
+ " n_neighbors = self.n_neighbors,\n",
+ " n_jobs = self.n_jobs)\n",
+ " self.random_state = random_state\n",
+ " self.random_gen = _RandomGen(random_state)\n",
+ " self.all_item_ids: np.ndarray\n",
+ " self.ui_csr: csr_matrix\n",
+ "\n",
+ " # Method used to save hyperparams in config\n",
+ " def _get_config(self) -> MixedKnnRandomModelConfig:\n",
+ " return MixedKnnRandomModelConfig(metric=self.metric, algorithm=self.algorithm, n_neighbors=self.n_neighbors, random_state=self.random_state, verbose=self.verbose)\n",
+ "\n",
+ " # Method used to load model params from config\n",
+ " @classmethod\n",
+ " def _from_config(cls, config: MixedKnnRandomModelConfig) -> tpe.Self:\n",
+ " return cls(metric=config.metric, algorithm=config.algorithm, n_neighbors=config.n_neighbors, random_state=config.random_state, verbose=config.verbose)\n",
+ "\n",
+ " # Method used to fit model, there is a sklearn KNN wrapper, so we need to fit KNN model with dataset csr matrix\n",
+ " def _fit(self, dataset: Dataset) -> None: # type: ignore\n",
+ " self.all_item_ids = dataset.item_id_map.internal_ids\n",
+ " self.ui_csr = dataset.get_user_item_matrix(include_weights=False, dtype=np.float64)\n",
+ " self.knn_model.fit(self.ui_csr)\n",
+ "\n",
+ " # Method used to make item-item recommendations, not for direct invokation, used in recommend_to_items method of base class\n",
+ " # Params:\n",
+ " # target_ids - InternalIdsArray of item ids for which predictions need to be made\n",
+ " # dataset - instance of Dataset class\n",
+ " # k - maximum count of top rated elements presented in recommendations\n",
+ " # sorted_item_ids_to_recommend - optional InternalIdsArray of item ids from which predictions are made\n",
+ " # Returns:\n",
+ " # Equaly sized arrays of target ids, predictions ids, scores\n",
+ " # in this method you need to ensure, that your realization handles all parameters correctly i.e. \n",
+ " # it can limit k predictions and limit the set of allowed items.\n",
+ " def _recommend_i2i(self,\n",
+ " target_ids: InternalIdsArray,\n",
+ " dataset: Dataset,\n",
+ " k: int,\n",
+ " sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray]) -> tp.Tuple[InternalIds, InternalIds, Scores]:\n",
+ " sorted_item_ids_to_recommend = dataset.get_user_item_matrix(include_weights=False,\n",
+ " dtype=np.float64)[sorted_item_ids_to_recommend] if sorted_item_ids_to_recommend is not None else self.all_item_ids\n",
+ "\n",
+ " all_target_ids = []\n",
+ " all_reco_ids: tp.List[np.ndarray] = []\n",
+ " all_scores: tp.List[np.ndarray] = []\n",
+ " for target_id in tqdm(target_ids, disable=self.verbose == 0):\n",
+ " reco_scores, reco_ids = self.knn_model.kneighbors(self.ui_csr[target_id], n_neighbors = k + 1)\n",
+ " all_target_ids.extend([target_id] * len(reco_ids.tolist()[0]))\n",
+ " all_reco_ids.extend(reco_ids.tolist())\n",
+ " all_scores.extend(reco_scores.tolist())\n",
+ "\n",
+ " all_target_ids = np.array(all_target_ids) \n",
+ " all_reco_ids_arr = np.concatenate(all_reco_ids)\n",
+ " all_reco_scores_array = np.concatenate(all_scores)\n",
+ " valid_indices = all_reco_ids_arr < len(sorted_item_ids_to_recommend)\n",
+ "\n",
+ " all_reco_ids_arr = all_reco_ids_arr[valid_indices]\n",
+ " all_target_ids = all_target_ids[valid_indices]\n",
+ " all_reco_scores_array = all_reco_scores_array[valid_indices]\n",
+ " \n",
+ " if sorted_item_ids_to_recommend is not None:\n",
+ " items_indeces = np.isin(all_reco_ids_arr, sorted_item_ids_to_recommend)\n",
+ " all_reco_ids_arr = all_reco_ids_arr[items_indeces]\n",
+ " all_target_ids = all_target_ids[items_indeces]\n",
+ " all_reco_scores_array = all_reco_scores_array[items_indeces]\n",
+ "\n",
+ " return all_target_ids, all_reco_ids_arr, all_reco_scores_array\n",
+ " \n",
+ " # Method used to make user-item recommendations, not for direct invokation, used in recommend method of base class\n",
+ " # Params:\n",
+ " # target_ids - InternalIdsArray of user ids for which predictions need to be made\n",
+ " # dataset - instance of Dataset class\n",
+ " # k - maximum count of top rated elements presented in recommendations\n",
+ " # sorted_item_ids_to_recommend - optional InternalIdsArray of item ids from which predictions are made\n",
+ " # Returns:\n",
+ " # Equaly sized arrays of target ids, predictions ids, scores\n",
+ " # in this method you need to ensure, that your realization handles all parameters correctly i.e. \n",
+ " # it can limit k predictions and limit the set of allowed items.\n",
+ " def _recommend_u2i(\n",
+ " self,\n",
+ " user_ids: InternalIdsArray,\n",
+ " dataset: Dataset,\n",
+ " k: int,\n",
+ " filter_viewed: bool,\n",
+ " sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray],\n",
+ " ) -> tp.Tuple[InternalIds, InternalIds, Scores]:\n",
+ " if filter_viewed:\n",
+ " user_items = dataset.get_user_item_matrix(include_weights=False)\n",
+ "\n",
+ " item_ids = sorted_item_ids_to_recommend if sorted_item_ids_to_recommend is not None else self.all_item_ids\n",
+ " sampler = _RandomSampler(item_ids, self.random_gen)\n",
+ "\n",
+ " all_user_ids = []\n",
+ " all_reco_ids: tp.List[InternalId] = []\n",
+ " all_scores: tp.List[float] = []\n",
+ " for user_id in tqdm(user_ids, disable=self.verbose == 0):\n",
+ " if filter_viewed:\n",
+ " viewed_ids = get_viewed_item_ids(user_items, user_id) # sorted\n",
+ " n_reco = k + viewed_ids.size\n",
+ " else:\n",
+ " n_reco = k\n",
+ "\n",
+ " n_reco = min(n_reco, item_ids.size)\n",
+ " reco_ids = sampler.sample(n_reco)\n",
+ "\n",
+ " if filter_viewed:\n",
+ " reco_ids = reco_ids[fast_isin_for_sorted_test_elements(reco_ids, viewed_ids, invert=True)][:k]\n",
+ "\n",
+ " reco_scores = np.arange(reco_ids.size, 0, -1)\n",
+ "\n",
+ " all_user_ids.extend([user_id] * len(reco_ids))\n",
+ " all_reco_ids.extend(reco_ids.tolist())\n",
+ " all_scores.extend(reco_scores.tolist())\n",
+ "\n",
+ " return all_user_ids, all_reco_ids, all_scores"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "e8757369-bec2-48f2-a876-a359d444d706",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = MixedKnnRandomModel(metric=\"cosine\", algorithm=\"brute\", n_neighbors=20, n_jobs=-1, random_state=20)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "53750dfe-9b46-4b8d-9bfe-e0c9caa41770",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = Dataset.construct(ratings)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "0c279b6f-2287-4198-b958-f00729fabadd",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<__main__.MixedKnnRandomModel at 0x7b349d4f63e0>"
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.fit(dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ac06ff1f-d1c2-4fcd-b328-70197e2b02a5",
+ "metadata": {},
+ "source": [
+ "## Use model to recommend similar items"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "id": "5c373d0f-bf04-4256-81c6-35f1f124f118",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "reco = model.recommend_to_items([1,7,6,2,3,5], dataset, 10)\n",
+ "reco[Columns.Model] = \"KnnCustomModel\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "id": "591c59da",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " target_item_id | \n",
+ " item_id | \n",
+ " score | \n",
+ " rank | \n",
+ " model | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 1895 | \n",
+ " 0.749783 | \n",
+ " 1 | \n",
+ " KnnCustomModel | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 1989 | \n",
+ " 0.750385 | \n",
+ " 2 | \n",
+ " KnnCustomModel | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 1 | \n",
+ " 458 | \n",
+ " 0.759902 | \n",
+ " 3 | \n",
+ " KnnCustomModel | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 1 | \n",
+ " 1906 | \n",
+ " 0.759902 | \n",
+ " 4 | \n",
+ " KnnCustomModel | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 1 | \n",
+ " 877 | \n",
+ " 0.768270 | \n",
+ " 5 | \n",
+ " KnnCustomModel | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " target_item_id item_id score rank model\n",
+ "0 1 1895 0.749783 1 KnnCustomModel\n",
+ "1 1 1989 0.750385 2 KnnCustomModel\n",
+ "2 1 458 0.759902 3 KnnCustomModel\n",
+ "3 1 1906 0.759902 4 KnnCustomModel\n",
+ "4 1 877 0.768270 5 KnnCustomModel"
+ ]
+ },
+ "execution_count": 37,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "reco.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "id": "f7451d5d-5689-40ff-8f1f-fd9c163c6833",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "selected_items = {\"item_one\": 3}\n",
+ "formatters = {\"item_id\": lambda x: f\"{x}\"}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "id": "d34cd895-5957-4176-9e59-dc2b74f1079d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "57cba3f32d84411b8fa97416e828e1b8",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(ToggleButtons(button_style='warning', description='Target:', options=('item_one',), value='item…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "app = ItemToItemVisualApp.construct(\n",
+ " reco=reco,\n",
+ " item_data=movies,\n",
+ " selected_items=selected_items,\n",
+ " formatters=formatters,\n",
+ " auto_display=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f225a5c2",
+ "metadata": {},
+ "source": [
+ "# Use model to recommend movies for a specific users"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "id": "850b4dcb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "reco = model.recommend([1,7,6,2,3,5], dataset, 10, filter_viewed=True)\n",
+ "reco[Columns.Model] = \"KnnCustomModel\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "id": "d03c5383",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " item_id | \n",
+ " score | \n",
+ " rank | \n",
+ " model | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2138 | \n",
+ " 10.0 | \n",
+ " 1 | \n",
+ " KnnCustomModel | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 60 | \n",
+ " 9.0 | \n",
+ " 2 | \n",
+ " KnnCustomModel | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 1 | \n",
+ " 572 | \n",
+ " 8.0 | \n",
+ " 3 | \n",
+ " KnnCustomModel | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 1 | \n",
+ " 3155 | \n",
+ " 7.0 | \n",
+ " 4 | \n",
+ " KnnCustomModel | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 1 | \n",
+ " 1760 | \n",
+ " 6.0 | \n",
+ " 5 | \n",
+ " KnnCustomModel | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " user_id item_id score rank model\n",
+ "0 1 2138 10.0 1 KnnCustomModel\n",
+ "1 1 60 9.0 2 KnnCustomModel\n",
+ "2 1 572 8.0 3 KnnCustomModel\n",
+ "3 1 3155 7.0 4 KnnCustomModel\n",
+ "4 1 1760 6.0 5 KnnCustomModel"
+ ]
+ },
+ "execution_count": 51,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "reco.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "id": "b4825b91-eaa0-405e-9b9e-2523fc6ba66b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "467aa87b78d347288a7c093029d9fb82",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(ToggleButtons(button_style='warning', description='Target:', options=('user_one',), value='user…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "selected_users = {\"user_one\": 3}\n",
+ "app = VisualApp.construct(\n",
+ " reco=reco,\n",
+ " interactions=ratings,\n",
+ " item_data=movies,\n",
+ " selected_users=selected_users,\n",
+ " formatters=formatters \n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bfc37d9a",
+ "metadata": {},
+ "source": [
+ "# Conclusion\n",
+ "You can create custom models with any requirements by inheriting from BaseModel and implementing necessary methods."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "new",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}