Keep in mind: in our experiments, KANs may not be well suited to text classification (or possibly NLP in general); further experiments are needed to confirm this.
This repo applies Kolmogorov-Arnold Networks (KANs) to text classification over GLUE tasks (RTE, CoLA, MRPC, etc.). Our paper will be published on arXiv soon.
- Python >= 3.9.7
- Install pykan (https://github.com/KindXiaoming/pykan)
- The packages listed in requirements.txt
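For example, assuming pykan is available on PyPI (otherwise install it from the repository linked above):

pip install pykan
pip install -r requirements.txt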
We use bert-base-cased as the pre-trained model to produce embeddings (pooled outputs) during training. All models have an input size of 768, 64 hidden neurons, and 2 output classes (0 and 1). Training was performed on a Tesla V100 16GB for 10 epochs, with lr = 2e-5 for all transformer models and lr = 2e-3 for the other models.
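For reference, the sketch below shows one way to obtain 768-dimensional pooled embeddings with the Hugging Face transformers library; it is an illustration of the idea, not necessarily the exact loading code used in run_train.py.

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
bert = AutoModel.from_pretrained("bert-base-cased")

texts = ["The cat sat on the mat.", "KANs replace fixed activations with splines."]
batch = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")

with torch.no_grad():
    outputs = bert(**batch)

pooled = outputs.pooler_output  # shape: (batch_size, 768)
```

The training commands for the different network heads are: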
python run_train.py --mode "train" --network "trans_effi_kan" --em_model_name "bert-base-cased" --ds_name "mrpc" --epochs 10 --batch_size 4 --max_len 512 --n_size 1 --m_size 768 --n_hidden 64 --n_class 2
python run_train.py --mode "train" --network "trans_fast_kan" --em_model_name "bert-base-cased" --ds_name "mrpc" --epochs 10 --batch_size 4 --max_len 512 --n_size 1 --m_size 768 --n_hidden 64 --n_class 2
python run_train.py --mode "train" --network "trans_faster_kan" --em_model_name "bert-base-cased" --ds_name "mrpc" --epochs 10 --batch_size 4 --max_len 512 --n_size 1 --m_size 768 --n_hidden 64 --n_class 2
python run_train.py --mode "train" --network "mlp" --em_model_name "bert-base-cased" --ds_name "mrpc" --epochs 10 --batch_size 4 --max_len 512 --n_size 1 --m_size 768 --n_hidden 64 --n_class 2
python run_train.py --mode "train" --network "classifier" --em_model_name "bert-base-cased" --ds_name "mrpc" --epochs 10 --batch_size 4 --max_len 512 --n_size 1 --m_size 768 --n_hidden 64 --n_class 2
Training takes a very long time when the KAN forward pass receives inputs of size 768 (outputs = KAN(texts)). Therefore, we reduce the embedding size from 768 to 8 (n_size * m_size) using reduce_size() in utils.py; the smaller the input size, the faster the training.
import torch

def reduce_size(embeddings, n_size=1, m_size=8):
    # Reshape (batch, embed_dim) into (batch, embed_dim // group, group),
    # where group = n_size * m_size, then sum over the middle dimension.
    batch_dim, embed_dim = embeddings.shape
    group = n_size * m_size
    embeddings = torch.reshape(embeddings, (batch_dim, embed_dim // group, group))
    # Summing over dim 1 directly yields shape (batch, group) and avoids
    # squeezing away the batch dimension when batch_dim == 1.
    return torch.sum(embeddings, dim=1)
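A quick usage check (the batch size of 4 is hypothetical, for illustration):

```python
import torch

embeddings = torch.randn(4, 768)                       # a batch of 4 pooled BERT embeddings
reduced = reduce_size(embeddings, n_size=1, m_size=8)
print(reduced.shape)                                   # torch.Size([4, 8])
```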
Then, we can run the training:
python run_train.py --mode "train" --network "kan" --em_model_name "bert-base-cased" --ds_name "wnli" --epochs 10 --batch_size 4 --max_len 512 --n_size 1 --m_size 8 --n_hidden 64 --n_class 2 --device "cpu"
- mode: working mode ("train" or "test")
- network: the model type (e.g., kan, effi_kan, fast_kan, faster_kan, mlp, classifier, or their trans_* variants)
- em_model_name: the pre-trained model that provides the embeddings (e.g., bert-base-cased)
- ds_name: dataset name
- epochs: the number of epochs
- batch_size: the training batch size
- max_len: the maximum length of input text
- n_size, m_size: the input is treated as an n_size x m_size matrix. For example, BERT's 768-dimensional embedding corresponds to 1 x 768.
- n_hidden: the number of hidden neurons. We use only 1 hidden layer (see the sketch after this list); you can modify the code for more layers.
- n_class: the number of classes. For the GLUE tasks used here, there are only 2 classes (0 and 1).
- embed_type: the type of embeddings (pool, last hidden, or weight)
- device: use "cuda" or "cpu"
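As a point of comparison, a hypothetical PyTorch equivalent of the one-hidden-layer head implied by these arguments (the repo's actual layer definitions may differ):

```python
import torch.nn as nn

# 768-dimensional pooled embeddings -> 64 hidden neurons -> 2 classes
mlp_head = nn.Sequential(
    nn.Linear(768, 64),
    nn.ReLU(),
    nn.Linear(64, 2),
)
```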
Network | Training Accuracy | Val Accuracy | Training time (seconds) |
---|---|---|---|
trans_mlp | 0.9897 | 0.8282 | 2798 |
trans_classifier | 0.9619 | 0.8282 | 2802 |
trans_effi_kan | 0.9635 | 0.8292 | 2827 |
trans_fast_kan | 0.9949 | 0.8206 | 2831 |
trans_faster_kan | 0.9756 | 0.8215 | 2818 |
effi_kan | 0.749 | 0.7458 | 951 |
fast_kan | 0.7501 | 0.742 | 937 |
faster_kan | 0.7235 | 0.7315 | 924 |
Network | Training Accuracy | Val Accuracy | Training time (seconds) |
---|---|---|---|
trans_mlp | 0.7377 | 0.8603 | 1195 |
trans_classifier | 0.9866 | 0.8848 | 1204 |
trans_effi_kan | 0.9986 | 0.8676 | 1219 |
trans_fast_kan | 0.9422 | 0.8554 | 1214 |
trans_faster_kan | 0.9591 | 0.8701 | 1207 |
effi_kan | 0.6955 | 0.7255 | 407 |
fast_kan | 0.7009 | 0.7157 | 401 |
faster_kan | 0.6848 | 0.7059 | 395 |
Network | Training Accuracy | Val Accuracy | Training time (seconds) |
---|---|---|---|
trans_mlp | 0.9302 | 0.675 | 821 |
trans_classifier | 0.8475 | 0.625 | 818 |
trans_effi_kan | 0.9069 | 0.675 | 826 |
trans_fast_kan | 0.9394 | 0.6071 | 831 |
trans_faster_kan | 0.9639 | 0.6964 | 829 |
effi_kan | 0.5004 | 0.5214 | 277 |
fast_kan | 0.5269 | 0.5429 | 273 |
faster_kan | 0.496 | 0.5214 | 269 |
- https://github.com/Blealtan/efficient-kan
- https://github.com/KindXiaoming/pykan
- https://github.com/AthanasiosDelis/faster-kan
- https://github.com/ZiyaoLi/fast-kan/
If you have any questions, please contact [email protected]. To learn more about me, please visit my website: https://tahoangthang.com.