Skip to content

Commit

Permalink
update test files of NCDM
Browse files Browse the repository at this point in the history
  • Loading branch information
LegionKing committed Nov 17, 2023
1 parent e7646b3 commit 8c6ee4e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 35 deletions.
50 changes: 20 additions & 30 deletions tests/ncdm/conftest.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,31 @@
# coding: utf-8
# 2021/4/6 @ WangFei
# 2023/11/17 @ WangFei

import random
import pytest
import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import pandas as pd
import random


@pytest.fixture(scope="package")
def conf():
user_num = 5
item_num = 2
knowledge_num = 4
return user_num, item_num, knowledge_num
def meta():
meta_data = {'userId': ['001', '002', '003'], 'itemId': ['adf', 'w5'], 'skill': ['skill1', 'skill2', 'skill3', 'skill4']}
return meta_data


@pytest.fixture(scope="package")
def data(conf):
user_num, item_num, knowledge_num = conf
knowledge_embs = np.zeros((item_num, knowledge_num))
for i in range(item_num):
for j in range(knowledge_num):
knowledge_embs[i][j] = random.randint(0, 1)
log = []
for i in range(user_num):
for j in range(item_num):
score = random.randint(0, 1)
log.append((i, j, knowledge_embs[j], score))

user_id, item_id, knowledge_emb, score = zip(*log)
batch_size = 4
meta_data = meta
item_skills = []
skll_n = len(meta_data['skill'])
for itemid in meta_data['itemId']:
item_skills.append(meta['skill'][random.randint(0, skll_n - 1)])
userIds, itemIds, skills, responses = []
for user in meta_data['userId']:
for i, item in enumerate(meta_data['itemId']):
userIds.append(user)
itemIds.append(item)
skills.append(item_skills[i])
responses.append(random.randint(0, 1))

dataset = TensorDataset(
torch.tensor(user_id, dtype=torch.int64),
torch.tensor(item_id, dtype=torch.int64),
torch.tensor(knowledge_emb, dtype=torch.int64),
torch.tensor(score, dtype=torch.float)
)
return DataLoader(dataset, batch_size=batch_size)
df_data = pd.DataFrame({'userId': userIds, 'itemId': itemIds, 'skill': skills, 'response': responses})
return df_data
12 changes: 7 additions & 5 deletions tests/ncdm/test_ncdm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# coding: utf-8
# 2021/4/6 @ WangFei
# 2023/11/17 @ WangFei
from EduCDM import NCDM


def test_train(data, conf, tmp_path):
user_num, item_num, knowledge_num = conf
cdm = NCDM(knowledge_num, item_num, user_num)
cdm.train(data, test_data=data, epoch=2)
def test_train(data, meta, tmp_path):
df_data = data
meta_data = meta
cdm = NCDM(meta_data)
cdm.fit(train_data=df_data, epoch=2, val_data=df_data)
filepath = tmp_path / "mcd.params"
cdm.save(filepath)
cdm.load(filepath)
cdm.eval(df_data)

0 comments on commit 8c6ee4e

Please sign in to comment.