From 402787311f1bbf7059ea3911614730b7504cb12c Mon Sep 17 00:00:00 2001
From: Mostafa Moradian
Date: Sat, 9 Mar 2024 20:57:05 +0100
Subject: [PATCH] Parametrize and update tests

---
 training/test_train.py | 43 ++++++++++++++++++++++++------------------
 training/train.py      |  2 +-
 2 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/training/test_train.py b/training/test_train.py
index 01de758..dbb20da 100644
--- a/training/test_train.py
+++ b/training/test_train.py
@@ -6,9 +6,12 @@
 import numpy as np
 
 
-def test_sqli_model():
-    MAX_WORDS = 10000
-    MAX_LEN = 100
+MAX_WORDS = 10000
+MAX_LEN = 100
+
+
+@pytest.fixture
+def model():
 
     # Load dataset
     data = pd.read_csv("dataset/sqli_dataset.csv")
@@ -16,28 +19,32 @@ def test_sqli_model():
     # Load TF model from SavedModel
     sqli_model = load_model("sqli_model/1")
 
-    # Create a sample SQL injection data
-    sample = [
-        "select * from users where id='1' or 1=1--",
-        "select * from users",
-        "select * from users where id=10000",
-        (
-            "select * from test where id=1 UNION ALL "
-            "SELECT NULL FROM INFORMATION_SCHEMA.COLUMNS WHERE 1=0; --;"
-        ),
-    ]
-
     # Tokenize the sample
-    tokenizer = Tokenizer(num_words=MAX_WORDS)
+    tokenizer = Tokenizer(num_words=MAX_WORDS, filters="")
     tokenizer.fit_on_texts(data["Query"])
 
+    return {"tokenizer": tokenizer, "sqli_model": sqli_model}
+
+
+@pytest.mark.parametrize("sample",
+    [
+        ("select * from users where id='1' or 1=1--", [92.02]),
+        ("select * from users", [0.077]),
+        ("select * from users where id=10000", [14.83]),
+        ("select '1' union select 'a'; -- -'", [99.99]),
+        ("select '' union select 'malicious php code' \g /var/www/test.php; -- -';", [99.99]),
+        ("select '' || pg_sleep((ascii((select 'a' limit 1)) - 32) / 2); -- -';", [99.99])
+    ]
+)
+def test_sqli_model(model, sample):
     # Vectorize the sample
-    sample_seq = tokenizer.texts_to_sequences(sample)
+    sample_seq = model["tokenizer"].texts_to_sequences([sample[0]])
     sample_vec = pad_sequences(sample_seq, maxlen=MAX_LEN)
 
     # Predict sample
-    predictions = sqli_model.predict(sample_vec)
+    predictions = model["sqli_model"].predict(sample_vec)
 
+    # Scale up to 100
     assert predictions * 100 == pytest.approx(
-        np.array([[99.99], [0.005], [0.055], [99.99]]), 0.1
+        np.array([sample[1]]), 0.1
     )
diff --git a/training/train.py b/training/train.py
index 084d2f0..36603b0 100644
--- a/training/train.py
+++ b/training/train.py
@@ -25,7 +25,7 @@
 MAX_LEN = 100
 
 # Use Tokenizer to encode text
-tokenizer = Tokenizer(num_words=MAX_WORDS, filters='')
+tokenizer = Tokenizer(num_words=MAX_WORDS, filters="")
 tokenizer.fit_on_texts(data["Query"])
 sequences = tokenizer.texts_to_sequences(data["Query"])
 