Merge pull request #66 from microsoft/dev
merge dev into main
Micheallei authored Aug 24, 2024
2 parents 6968dcc + aa64d45 commit 757e4ca
Showing 37 changed files with 2,808 additions and 1,957 deletions.
4 changes: 3 additions & 1 deletion RecExplainer/.gitignore
@@ -10,4 +10,6 @@ data_*/
output/
output_*/

*.log
*.log

gradio_cached_examples/
69 changes: 41 additions & 28 deletions RecExplainer/README.md
@@ -1,5 +1,5 @@
# RecExplainer: Aligning Large Language Models for Explaining Recommendation Models
This is the Repo for [RecExplainer: Aligning Large Language Models for Recommendation Model Interpretability](https://arxiv.org/abs/2311.10947), which leverages LLMs as surrogate models for explaining black-box recommender models.
This is the Repo for our KDD2024 paper: [RecExplainer: Aligning Large Language Models for Explaining Recommendation Models](https://arxiv.org/abs/2311.10947), which leverages LLMs as surrogate models for explaining black-box recommender models.
![Figure Caption](framework.png)

## Introduction
@@ -18,36 +18,36 @@ Our evaluation of the RecExplainer framework is two-fold:
* Overall ratings: We use both GPT4 and human experts to annotate the quality of generated explanations.
* Distinction and Coherence: We also train a classifier and a score predictor to further verify whether RecExplainer is indeed explaining its own predictions.


## Environment Setting

### Environment
```bash
conda create -n recexplainer python=3.9
conda create -n recexplainer python==3.10.14
conda activate recexplainer
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=11.8 -c pytorch -c nvidia
pip install -r requirements.txt
```
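After installation, a quick sanity check (our suggestion, not part of the repo) can confirm that the CUDA build of PyTorch is active:

```python
# Hypothetical sanity check for the environment above; not part of RecExplainer.
import torch

print(torch.__version__)          # expect 2.3.1
print(torch.version.cuda)         # expect '11.8'
print(torch.cuda.is_available())  # True if a compatible GPU driver is present
```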

### Set OpenAI API Environment
If you want to use the OpenAI API, you first need to run the following scripts in your console. If you are not using the Azure OpenAI API (OPENAI_API_TYPE is not "azure"), you only need to specify OPENAI_API_KEY and ENGINE.
If you want to use the OpenAI API, you first need to run the following scripts in your console. If you are not using the Azure OpenAI API (OPENAI_API_TYPE is not "azure"), you only need to specify OPENAI_API_KEY and MODEL.

```bash
export OPENAI_API_KEY_01=xxx;
export OPENAI_API_BASE_01=https://xxx.openai.azure.com/;
export OPENAI_API_VERSION_01=2023-03-15-preview;
export OPENAI_API_TYPE_01=azure;
export ENGINE_01=xxx;

export OPENAI_API_KEY=xxx;
export OPENAI_API_BASE=https://xxx.openai.azure.com/;
export OPENAI_API_VERSION=2023-03-15-preview;
export OPENAI_API_TYPE=azure;
export MODEL=xxx;
```
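For illustration, one way these variables could be consumed with the `openai>=1.0` Python SDK (a hypothetical sketch; the repo's actual client code may differ):

```python
# Hypothetical sketch: build a client from the environment variables above.
import os
from openai import AzureOpenAI, OpenAI

if os.environ.get("OPENAI_API_TYPE") == "azure":
    client = AzureOpenAI(
        api_key=os.environ["OPENAI_API_KEY"],
        azure_endpoint=os.environ["OPENAI_API_BASE"],
        api_version=os.environ["OPENAI_API_VERSION"],
    )
else:
    client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

# MODEL names the Azure deployment (or the OpenAI model) to call.
response = client.chat.completions.create(
    model=os.environ["MODEL"],
    messages=[{"role": "user", "content": "ping"}],
)
print(response.choices[0].message.content)
```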

### If you want to use multiple keys at the same time to speed up data generation.
export OPENAI_API_KEY_02=xxx;
export OPENAI_API_BASE_02=https://xxx.openai.azure.com/;
export OPENAI_API_VERSION_02=2023-03-15-preview;
export OPENAI_API_TYPE_02=azure;
export ENGINE_02=xxx;
We also support AzureCliCredential login:
```bash
az login
export OPENAI_API_BASE=https://xxx.openai.azure.com/;
export OPENAI_API_VERSION=2023-03-15-preview;
export MODEL=xxx;
```
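A hypothetical sketch of the corresponding client construction, assuming the `azure-identity` package (the repo's actual wiring may differ):

```python
# Hypothetical sketch: token-based Azure OpenAI auth after `az login`.
import os
from azure.identity import AzureCliCredential, get_bearer_token_provider
from openai import AzureOpenAI

token_provider = get_bearer_token_provider(
    AzureCliCredential(), "https://cognitiveservices.azure.com/.default"
)
client = AzureOpenAI(
    azure_ad_token_provider=token_provider,  # no API key needed
    azure_endpoint=os.environ["OPENAI_API_BASE"],
    api_version=os.environ["OPENAI_API_VERSION"],
)
```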


## Dataset Preparation for Target Recommender Model
For data preparation, you need to download three raw files: Amazon reviews, Amazon metadata, and ShareGPT
* Amazon Video Games 5-core reviews: https://jmcauley.ucsd.edu/data/amazon_v2/categoryFilesSmall/Video_Games_5.json.gz
@@ -66,8 +66,10 @@ bash shell/unirec_prepare_data.sh

## Training and Inference with the Target Recommender Model
### Training
Currently we support both the SASRec and MF models; you can train each of them separately.
```bash
bash shell/unirec_train.sh
bash shell/unirec_sasrec_train.sh
bash shell/unirec_mf_train.sh
```

### Inference
@@ -77,16 +79,20 @@ cp preprocess/unirec_utils/data4Exp.py $HOME/UniRec/unirec/main
cp $HOME/RecExplainer/data/unirec_raw_data/amazon_video_games_v3/train_ids.csv $HOME/UniRec/data/amazon_video_games_v3
cp $HOME/RecExplainer/data/unirec_raw_data/amazon_video_games_v3/test_ids.csv $HOME/UniRec/data/amazon_video_games_v3
```

For the SASRec model:
```bash
bash shell/unirec_infer.sh
bash shell/unirec_sasrec_infer.sh
```
For the MF model:
```bash
bash shell/unirec_mf_infer.sh
```

## Dataset Preparation for RecExplainer Model
```bash
bash shell/recexplainer_data_pipeline.sh
```
After running the above script, you will get the following training and testing files:
After running the above script, you will get the following training and testing files for both the SASRec and MF models:

For alignment tasks
* `behaviour_train.json` & `behaviour_valid.json`: data for training and testing RecExplainer-B (behavior alignment), also used to test the alignment performance of open-source LLMs.
@@ -97,18 +103,24 @@ For explanation tasks
* `explan_behaviour_valid.json`: prompts for RecExplainer-B to generate explanations.
* `explan_intention_valid.json`: prompts for RecExplainer-I to generate explanations.
* `explan_both_valid.json`: prompts for RecExplainer-H to generate explanations.
* `explan_chatgpt.csv`: prompts for ChatGPT to generate explanations.
* `explan_chatgpt_valid.csv`: prompts for ChatGPT to generate explanations.

Note: Files such as `explan_behaviour_train.json` are also used to generate explanations. After generation, we'll use them to train our score predictors as well as classifiers. See Section [Distinction and Coherence](#custom-anchor) for more information.

## Train RecExplainer
```bash
bash shell/train.sh
```

Important Parameters:
- `--data_names`: the json file name with suffix ('_train.json' or '_valid.json') removed.
- `--task_type`: alignment method of the training stage. "behaviour" means behavior alignment, "intention" means intention alignment, "both" means hybrid alignment.
- `--rec_model_type`: The type of the target recommender model. We currently support "SASRec" and "MF".
- `--task_type`: Alignment method of the training stage. "behaviour" means behavior alignment, "intention" means intention alignment, "both" means hybrid alignment.
- `--template_name`: The chat template of the LLM. We currently support "mistral"/"vicuna"/"llama-2"/"llama-3"/"phi3".

After training, you need to merge the LoRA adapters into the base model.
```bash
bash shell/merge.sh
```
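A minimal sketch of what such a merge step typically does with the `peft` library (paths are placeholders; `merge.sh` may differ in detail):

```python
# Hypothetical LoRA-merge sketch; paths are placeholders, not repo defaults.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("path/to/base-llm")
model = PeftModel.from_pretrained(base, "path/to/lora-adapter")
merged = model.merge_and_unload()  # fold adapter weights into the base model

merged.save_pretrained("path/to/merged-model")
AutoTokenizer.from_pretrained("path/to/base-llm").save_pretrained("path/to/merged-model")
```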

## Evaluation
### Alignment Effect
@@ -117,9 +129,8 @@ bash shell/infer_alignment.sh
```

Parameters:
- `--data_names`: the json file name with suffix ('_train.json' or '_valid.json') removed.
- `--task_type`: alignment method used by the model. "behaviour" means behavior alignment, "intention" means intention alignment, "both" means hybrid alignment, "none" means using the LLM without alignment training.
- `--inference_mode`: the name of the inference task. "uid2hist" means history reconstruction task, "uid2next" means next item retrieval task, "uidiid2rank" means item ranking task, "uidiid2binary" means interest classification task.
- `--inference_mode`: the name of the inference task. "iid2title" means item recovery task, "uid2hist" means history reconstruction task, "uid2next" means next item retrieval task, "uidiid2rank" means item ranking task, "uidiid2binary" means interest classification task.

### Explanation Generation Ability

@@ -129,7 +140,6 @@ bash shell/infer_explan.sh
```

Parameters:
- `--data_names`: the json file name with suffix ('_train.json' or '_valid.json') removed.
- `--task_type`: alignment method used by the model. "behaviour" means behavior alignment, "intention" means intention alignment, "both" means hybrid alignment, "none" means using the LLM without alignment training.
- `--inference_mode`: the name of the inference task. "case study" means generating explanation texts.

@@ -150,7 +160,7 @@ Parameters:
- `--judge_response_file`: Output file of GPT-4


#### Distinction and Coherence
#### Distinction and Coherence <a id="custom-anchor"></a>
We train a classifier and a score predictor to further verify whether RecExplainer is indeed explaining its own predictions.
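For intuition, a minimal sketch of the two heads as we read `run_cls.sh`/`run_reg.sh` (both fine-tune `bert-base-uncased` via `run_cls.py`; this is not the repo's exact code):

```python
# Hypothetical sketch of the discriminator heads; run_cls.py likely wraps the
# standard Hugging Face sequence-classification setup.
from transformers import AutoModelForSequenceClassification

# Classifier over explanation texts (trained on classification_train.csv);
# the label count here is an assumption.
classifier = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)

# Score predictor: num_labels=1 switches the head to regression, so the
# model treats labels as continuous targets (MSE loss).
regressor = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=1
)
```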

**Data generation**
@@ -184,3 +194,6 @@ If you find this project useful in your research, please consider citing:
year={2023}
}
```

## Acknowledgments
Thanks to the open-source code of [UniRec](https://github.com/microsoft/UniRec).
12 changes: 6 additions & 6 deletions RecExplainer/discriminator/data_gen.sh
@@ -1,21 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

PROCESS_DATA_DIR="$HOME/RecExplainer/data/amazon_video_games_v3/process_data"
UNIREC_DATA_DIR="$HOME/UniRec/output/amazon_video_games_v3/SASRec/RecExplainer/xxx/"
PROCESS_DATA_DIR="$HOME/blob/RecExplainer/amazon_video_games_v3"
UNIREC_DATA_DIR="$HOME/blob/RecExplainer/amazon_video_games_v3"
DISCRIMINATOR_DATA_DIR="$PROCESS_DATA_DIR/discriminator"
EXPLAN_DIR="$HOME/RecExplainer/output/amazon_video_games_v3/explan"

cd $HOME/RecExplainer/discriminator

python data_process.py --top_file $UNIREC_DATA_DIR/train_top.txt --seqdata_file $PROCESS_DATA_DIR/sequential_data.txt --in_gpt_file $EXPLAN_DIR/discriminator_train/chatgpt_response.csv \
--in_vicuna_file $EXPLAN_DIR/discriminator_train/vicuna_response.csv --in_recexplainer_file $EXPLAN_DIR/discriminator_train/recexplainer-H_response.csv \
--in_vicuna_file $EXPLAN_DIR/discriminator_train/llama3_response.csv --in_recexplainer_file $EXPLAN_DIR/discriminator_train/recexplainer-H_response.csv \
--out_cls_file $DISCRIMINATOR_DATA_DIR/classification_train.csv --out_reg_gpt_file $DISCRIMINATOR_DATA_DIR/regression_chatgpt_train.csv \
--out_reg_vicuna_file $DISCRIMINATOR_DATA_DIR/regression_vicuna_train.csv --out_reg_recexplainer_file $DISCRIMINATOR_DATA_DIR/regression_recexplainer_train.csv \
--out_reg_vicuna_file $DISCRIMINATOR_DATA_DIR/regression_llama3_train.csv --out_reg_recexplainer_file $DISCRIMINATOR_DATA_DIR/regression_recexplainer_train.csv \
--split "train" --max_samples 2000

python data_process.py --top_file $UNIREC_DATA_DIR/test_top.txt --seqdata_file $PROCESS_DATA_DIR/sequential_data.txt --in_gpt_file $EXPLAN_DIR/chatgpt_response.csv \
--in_vicuna_file $EXPLAN_DIR/vicuna_response.csv --in_recexplainer_file $EXPLAN_DIR/recexplainer-H_response.csv \
--in_vicuna_file $EXPLAN_DIR/llama3_response.csv --in_recexplainer_file $EXPLAN_DIR/recexplainer-H_response.csv \
--out_cls_file $DISCRIMINATOR_DATA_DIR/classification_test.csv --out_reg_gpt_file $DISCRIMINATOR_DATA_DIR/regression_chatgpt_test.csv \
--out_reg_vicuna_file $DISCRIMINATOR_DATA_DIR/regression_vicuna_test.csv --out_reg_recexplainer_file $DISCRIMINATOR_DATA_DIR/regression_recexplainer_test.csv \
--out_reg_vicuna_file $DISCRIMINATOR_DATA_DIR/regression_llama3_test.csv --out_reg_recexplainer_file $DISCRIMINATOR_DATA_DIR/regression_recexplainer_test.csv \
--split "valid" --max_samples 500
4 changes: 2 additions & 2 deletions RecExplainer/discriminator/data_process.py
@@ -57,8 +57,8 @@ def parse_args():
user_items[user] = items

gpt_df = pd.read_csv(args.in_gpt_file)
gpt_df = gpt_df.drop(['question','response','target'], axis=1)
gpt_df = gpt_df.rename(columns={'response-chatGPT':'explan'})
gpt_df = gpt_df.drop(['question'], axis=1)
gpt_df = gpt_df.rename(columns={'answer':'explan'})
gpt_df['label'] = 0.0
gpt_df['user_id'] = 0
gpt_df['item_id'] = 0
2 changes: 1 addition & 1 deletion RecExplainer/discriminator/run_cls.sh
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

DISCRIMINATOR_DATA_DIR="$HOME/RecExplainer/data/amazon_video_games_v3/process_data/discriminator"
DISCRIMINATOR_DATA_DIR="$HOME/RecExplainer/data/amazon_video_games_v3/discriminator"
EX_DIR=$HOME/RecExplainer/discriminator
cd $EX_DIR

14 changes: 7 additions & 7 deletions RecExplainer/discriminator/run_reg.sh
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

DISCRIMINATOR_DATA_DIR="$HOME/RecExplainer/data/amazon_video_games_v3/process_data/discriminator"
DISCRIMINATOR_DATA_DIR="$HOME/RecExplainer/data/amazon_video_games_v3/discriminator"
EX_DIR=$HOME/RecExplainer/discriminator
cd $EX_DIR

@@ -29,7 +29,7 @@ CUDA_VISIBLE_DEVICES=0 python run_cls.py \
--evaluation_strategy steps \
--eval_steps 125 \
--warmup_ratio 0.1 \
--report_to none > $EX_DIR/output/discriminator/regression_chatgpt/regression_chatgpt.log 2>&1
--report_to none > $EX_DIR/output/regression_chatgpt.log 2>&1

CUDA_VISIBLE_DEVICES=0 python run_cls.py \
--num_labels 1 \
@@ -55,7 +55,7 @@ CUDA_VISIBLE_DEVICES=0 python run_cls.py \
--evaluation_strategy steps \
--eval_steps 125 \
--warmup_ratio 0.1 \
--report_to none > $EX_DIR/output/discriminator/regression_recexplainer/regression_recexplainer.log 2>&1
--report_to none > $EX_DIR/output/regression_recexplainer.log 2>&1


CUDA_VISIBLE_DEVICES=0 python run_cls.py \
@@ -64,11 +64,11 @@ CUDA_VISIBLE_DEVICES=0 python run_cls.py \
--do_train \
--do_eval \
--max_seq_length 512 \
--train_file $DISCRIMINATOR_DATA_DIR/regression_vicuna_train.csv \
--validation_file $DISCRIMINATOR_DATA_DIR/regression_vicuna_test.csv \
--train_file $DISCRIMINATOR_DATA_DIR/regression_llama3_train.csv \
--validation_file $DISCRIMINATOR_DATA_DIR/regression_llama3_test.csv \
--model_name_or_path bert-base-uncased \
--cache_dir $HOME/.cache/ \
--output_dir $EX_DIR/output/discriminator/regression_vicuna \
--output_dir $EX_DIR/output/discriminator/regression_llama3 \
--learning_rate 3e-5 \
--num_train_epochs 3 \
--per_device_train_batch_size 8 \
@@ -82,4 +82,4 @@ CUDA_VISIBLE_DEVICES=0 python run_cls.py \
--evaluation_strategy steps \
--eval_steps 125 \
--warmup_ratio 0.1 \
--report_to none > $EX_DIR/output/discriminator/regression_vicuna/regression_vicuna.log 2>&1
--report_to none > $EX_DIR/output/regression_llama3.log 2>&1
