
Policy-Retrieval-BERT

A Policy Retrieval and Recommend System Developed with PLM-BERT and Graph Inference
Explore the docs »

2023/4/15 - The project has just been deployed on a server: Link. 2023/5/19 - Results just came out; off to Nanjing for the final-round defense!

Table of Contents
  1. About The Project
  2. Getting Started
  3. License

About The Project

This system consists of two parts: a Retrieval System and a Recommend System.

Retrieval System

 As shown in the figure below, the Retrieval System is divided into two channels: Keyword Retrieval and Semantic Retrieval. We use Elasticsearch (ES) for keyword search. The Semantic Retrieval channel consists of four modules: an Encoding Module, a Community Embedding Module, a High-Dimensional Vector Retrieval Module, and a Triplet Loss Training Module. We then use XGBoost to fuse the results from the different channels.

(Figure: architecture of the Retrieval System)

Encoding Module

 The input data consists of the Title, Body, and other attributes of each policy. We use a PLM (e.g., BERT/RoBERTa) to extract semantic information from the policy's title and body, and we use one-hot embeddings to encode the other attributes of a policy.
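
A minimal sketch of this module is shown below, assuming the Hugging Face transformers API with a bert-base-chinese checkpoint, [CLS] pooling, and scikit-learn's OneHotEncoder for the categorical attributes; the repo's actual checkpoint, pooling strategy, and attribute schema may differ.

    # Sketch of the Encoding Module (bert-base-chinese, [CLS] pooling, and
    # scikit-learn one-hot encoding are assumptions, not the repo's exact setup).
    import numpy as np
    import torch
    from transformers import AutoTokenizer, AutoModel
    from sklearn.preprocessing import OneHotEncoder

    tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
    plm = AutoModel.from_pretrained("bert-base-chinese")

    def encode_text(texts, max_length=512):
        """Encode policy titles/bodies with the PLM and return [CLS] vectors."""
        batch = tokenizer(texts, padding=True, truncation=True,
                          max_length=max_length, return_tensors="pt")
        with torch.no_grad():
            out = plm(**batch)
        return out.last_hidden_state[:, 0, :].numpy()       # (N, hidden_size)

    def encode_attributes(attr_rows):
        """One-hot embed the other (categorical) attributes of each policy."""
        enc = OneHotEncoder(handle_unknown="ignore", sparse=False)
        return enc.fit_transform(attr_rows)                  # (N, n_categories)

    # A policy vector is the concatenation of semantic and attribute features.
    titles = ["关于促进科技创新的若干政策"]
    attrs = [["省级", "科技"]]
    policy_vec = np.concatenate([encode_text(titles), encode_attributes(attrs)], axis=1)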

The Community Embedding Module and the High-Dimensional Vector Retrieval Module are implemented by my teammates.

Triplet Loss Training Module

 From the Community Embedding Module, we can split all policies into different communities. To help the language model better distinguish policies from different communities, we build triplets of the form $(anchor, positive, negative)$ from the Community Graph and fine-tune the PLM with the triplet loss $$L = \max(d(a,p) - d(a,n) + margin, 0)$$ where $d$ is the Euclidean distance, $a$, $p$ and $n$ denote the anchor, positive and negative points respectively, and $margin$ is a hyper-parameter. We randomly select 10% of all policies from different communities as anchor points, and define points in the same community as an anchor as its positive points. Further, we select the semi-hard triplets (i.e., $0 < L < margin$: the positive is closer to the anchor than the negative, but the margin is not yet satisfied) as the training set.
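
A minimal sketch of this loss and the semi-hard filter in PyTorch follows; the margin value and batch handling here are illustrative assumptions, not the repo's exact settings.

    # Triplet loss with semi-hard filtering (illustrative; margin and batching
    # are assumptions, not the repo's exact settings).
    import torch
    import torch.nn.functional as F

    def triplet_loss(anchor, positive, negative, margin=1.0):
        """L = max(d(a, p) - d(a, n) + margin, 0) with Euclidean distance d."""
        d_ap = F.pairwise_distance(anchor, positive, p=2)
        d_an = F.pairwise_distance(anchor, negative, p=2)
        return torch.clamp(d_ap - d_an + margin, min=0.0)

    def select_semi_hard(anchor, positive, negative, margin=1.0):
        """Keep triplets with 0 < L < margin, i.e. d(a,p) < d(a,n) < d(a,p) + margin."""
        losses = triplet_loss(anchor, positive, negative, margin)
        mask = (losses > 0) & (losses < margin)
        return anchor[mask], positive[mask], negative[mask]

    # Embeddings produced by the PLM for a batch of candidate triplets.
    a, p, n = torch.randn(32, 768), torch.randn(32, 768), torch.randn(32, 768)
    a, p, n = select_semi_hard(a, p, n, margin=1.0)
    loss = triplet_loss(a, p, n, margin=1.0).mean()

PyTorch also ships torch.nn.TripletMarginLoss, which computes the same quantity once the semi-hard triplets have been selected.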

 Besides, considering the difference between a policy's title and its body, we treat the title as a special kind of keyword. Therefore, we split our Semantic Retrieval into two further channels: Title2Body and Body2Body. In the Title2Body channel, the concatenation of the policy title and the other attributes serves as the anchor point, while policy bodies serve as positive and negative points. In the Body2Body channel, policy bodies serve as anchor, positive and negative points.
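
As a small, purely hypothetical illustration of how the triplet texts differ between the two channels (the dict schema below is not the repo's actual data format):

    # Hypothetical illustration of per-channel triplet construction.
    def build_triplet(anchor_policy, pos_policy, neg_policy, channel="Title2Body"):
        """Each policy is a dict with 'title', 'body' and 'attrs' keys (hypothetical schema)."""
        if channel == "Title2Body":
            anchor = anchor_policy["title"] + " " + " ".join(anchor_policy["attrs"])
        else:  # Body2Body
            anchor = anchor_policy["body"]
        return anchor, pos_policy["body"], neg_policy["body"]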

 The two retrieval channels thus yield two models. We feed the original policy data into both models in parallel, and the output of each model is sent to the High-Dimensional Vector Retrieval Module to produce the semantic retrieval results.
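
The High-Dimensional Vector Retrieval Module is implemented by teammates; as a stand-in, the sketch below assumes a flat FAISS index over the encoded policy vectors.

    # Stand-in sketch of high-dimensional vector retrieval with FAISS
    # (the actual module is implemented by teammates and may differ).
    import faiss
    import numpy as np

    dim = 768
    policy_vectors = np.random.rand(10000, dim).astype("float32")  # encoded policy bodies

    index = faiss.IndexFlatL2(dim)                      # exact L2 search
    index.add(policy_vectors)

    query = np.random.rand(1, dim).astype("float32")    # encoded query (title or body)
    distances, ids = index.search(query, 10)            # top-10 nearest policies for one channel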

Fusion Layer

 To sum up, we obtain results from three channels: ES, Title2Body and Body2Body. The intersection of these results is the most reliable, but computing it is time-consuming. Thus, we use an XGBoost model (or another algorithm such as TA) to learn the ranking of the intersection. In detail, we calculate the average similarity of each policy in the intersection, set the similarity of policies not in the intersection to 0, and then fit XGBoost to this mean similarity.
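
A minimal sketch of this fusion step is given below, assuming per-channel similarity scores as features and the mean intersection similarity as the regression target; the feature layout and hyper-parameters are assumptions (the Getting Started section runs TA.py for this step).

    # Sketch of the fusion layer (feature layout and hyper-parameters are assumptions).
    import numpy as np
    import xgboost as xgb

    # Per-policy similarity scores from the three channels: ES, Title2Body, Body2Body.
    channel_scores = np.random.rand(5000, 3)
    in_intersection = np.random.rand(5000) > 0.7        # policies returned by all three channels

    # Target: mean similarity for intersection policies, 0 for everything else.
    target = np.where(in_intersection, channel_scores.mean(axis=1), 0.0)

    model = xgb.XGBRegressor(n_estimators=200, max_depth=4)
    model.fit(channel_scores, target)

    fused_ranking = np.argsort(-model.predict(channel_scores))   # final fused ranking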

Recommend System

 The traditional method in the recommendation field is collaborative filtering, which ignores the original features of users and items. Inspired by GCMC and LSR, we treat the recommendation task as link prediction, using a graph convolutional model to aggregate features between users and items, and further treat the structure of the graph as a learnable latent parameter.

(Figure: architecture of the Recommend System)
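
A minimal sketch of the link-prediction idea in plain PyTorch follows: one graph-convolution step over a learnable user-item adjacency. This only illustrates the approach and is not the repo's GCMC/LSR implementation.

    # Sketch of recommendation as link prediction with a learnable graph
    # (illustrative; not the repo's GCMC/LSR implementation).
    import torch
    import torch.nn as nn

    class LinkPredictor(nn.Module):
        def __init__(self, n_users, n_items, feat_dim, hidden_dim):
            super().__init__()
            self.user_feat = nn.Embedding(n_users, feat_dim)    # original user features
            self.item_feat = nn.Embedding(n_items, feat_dim)    # original item (policy) features
            self.gcn = nn.Linear(feat_dim, hidden_dim)
            # The user-item graph structure is a learnable latent parameter.
            self.adj_logits = nn.Parameter(torch.zeros(n_users, n_items))

        def forward(self):
            adj = torch.sigmoid(self.adj_logits)                # learned soft adjacency
            users, items = self.user_feat.weight, self.item_feat.weight
            # One propagation step: gather item features into user representations.
            user_h = torch.relu(self.gcn(users + adj @ items / adj.sum(1, keepdim=True)))
            item_h = torch.relu(self.gcn(items))
            return user_h @ item_h.T                            # predicted link scores

    scores = LinkPredictor(n_users=100, n_items=500, feat_dim=32, hidden_dim=64)()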

(back to top)

Getting Started

Requirements

My code works with the following environment.

  • python=3.7
  • pytorch=1.13.1+cu116
  • transformers=4.18.0
  • ubuntu 20.04
pip install -r requirements.txt  # install requirements

Dataset

Preprocess

  1. Put the downloaded data under ./init_data_process/data, then run:

    cd ./init_data_process
    python preprocess2DBLP.py
    
  2. The results used to build communities are stored under
    ./init_data_process/results

  3. Produce random anchors:

    python Random_sample.py
    
  4. The random sample results are stored under ./init_data_process/results/random_sample

  5. Produce triplets:

    python read_sample.py
    

Triplets Training

  1. Put triplets_body.csv, data_sample.csv and category_index.txt produced in Preprocess under ./train_model/data
  2. Put the original policy data policyinfo_new.tsv under ./train_model/Conver2vec/data
cd ../
sh run_BERT_MLP.sh gpu_id # train model
sh test_BERT_MLP.sh gpu_id # evaluate

Index_type: 'Title' selects the Title2Body channel; 'Body' selects the Body2Body channel.

  • Convert the original policy data into vectors:

    cd ./Conver2vec
    sh convert_data.sh gpu_id
    

Fusion Layer

  1. Put the results of the 3 channels under ./Fusion/data, then run:
cd ../../Fusion
python TA.py

Recommend

cd ../UserCF
python UserCF.py

(back to top)

License

This project is licensed under the MIT License - see the LICENSE file for details.
