Skip to content

Commit

Permalink
Parametrize and update tests
Browse files — browse the repository at this point in the history
  • Loading branch information
mostafa committed Mar 9, 2024
1 parent 0bdcdea commit 4027873
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 19 deletions.
43 changes: 25 additions & 18 deletions training/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,45 @@
import numpy as np


# Tokenizer/model hyper-parameters shared by all tests in this module.
# NOTE(review): these presumably must match the values used when the model
# in sqli_model/1 was trained (training/train.py uses the same numbers) —
# confirm they stay in sync if training is re-run.
MAX_WORDS = 10000  # tokenizer vocabulary size (most frequent words kept)
MAX_LEN = 100  # padded sequence length fed to the model


@pytest.fixture
def model():
    """Load the trained SQLi-detection model and a matching tokenizer.

    Returns a dict with:
      - "tokenizer": Keras ``Tokenizer`` fitted on the dataset's ``Query``
        column, with ``filters=""`` so SQL punctuation (quotes, ``=``,
        ``--``) is kept as part of tokens rather than stripped.
      - "sqli_model": the TF SavedModel loaded from ``sqli_model/1``.
    """
    # Fit the tokenizer on the same dataset used for training so its word
    # index matches what the model saw — TODO confirm against train.py.
    data = pd.read_csv("dataset/sqli_dataset.csv")

    # Load TF model from SavedModel
    sqli_model = load_model("sqli_model/1")

    tokenizer = Tokenizer(num_words=MAX_WORDS, filters="")
    tokenizer.fit_on_texts(data["Query"])

    return {"tokenizer": tokenizer, "sqli_model": sqli_model}


@pytest.mark.parametrize(
    "sample",
    [
        # (query, expected prediction scaled to a 0-100 score)
        ("select * from users where id='1' or 1=1--", [92.02]),
        ("select * from users", [0.077]),
        ("select * from users where id=10000", [14.83]),
        ("select '1' union select 'a'; -- -'", [99.99]),
        ("select '' union select 'malicious php code' \g /var/www/test.php; -- -';", [99.99]),
        ("select '' || pg_sleep((ascii((select 'a' limit 1)) - 32) / 2); -- -';", [99.99]),
    ],
)
def test_sqli_model(model, sample):
    """Check the model's SQL-injection score for one query per parameter set."""
    query, expected = sample

    # Vectorize the query with the tokenizer fitted in the `model` fixture;
    # texts_to_sequences expects a list of texts, hence the [query] wrapper.
    sample_seq = model["tokenizer"].texts_to_sequences([query])
    sample_vec = pad_sequences(sample_seq, maxlen=MAX_LEN)

    # Predict sample — presumably a probability in [0, 1]; TODO confirm.
    predictions = model["sqli_model"].predict(sample_vec)

    # Scale up to 100 and compare with 10% relative tolerance.
    assert predictions * 100 == pytest.approx(
        np.array([expected]), 0.1
    )
2 changes: 1 addition & 1 deletion training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
MAX_LEN = 100  # padded sequence length fed to the model

# Use Tokenizer to encode text. filters="" keeps SQL punctuation (quotes,
# '=', '--') inside tokens instead of stripping it, which matters for
# distinguishing injection payloads from benign queries.
tokenizer = Tokenizer(num_words=MAX_WORDS, filters="")
tokenizer.fit_on_texts(data["Query"])
sequences = tokenizer.texts_to_sequences(data["Query"])

Expand Down

0 comments on commit 4027873

Please sign in to comment.