Skip to content

Commit

Permalink
test(test-endpoints): adds tests for the endpoints and home view - re…
Browse files Browse the repository at this point in the history
…direct
  • Loading branch information
nifedara committed Aug 6, 2024
1 parent 4a87e73 commit b0b9441
Show file tree
Hide file tree
Showing 2 changed files with 313 additions and 43 deletions.
339 changes: 296 additions & 43 deletions backend/tests/test_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import json
import os
import shutil

import validators
from django.conf import settings
import validators
from rest_framework import status
from rest_framework.test import APILiveServerTestCase, RequestsClient
from .factories import (
OsmUserFactory,
TrainingFactory,
DatasetFactory,
AoiFactory,
LabelFactory,
ModelFactory,
FeedbackAoiFactory,
)

API_BASE = "http://testserver/api/v1"

Expand All @@ -19,10 +29,15 @@ class TaskApiTest(APILiveServerTestCase):
def setUp(self):
# Create a request factory instance
self.client = RequestsClient()
self.user = OsmUserFactory(osm_id=123)
self.dataset = DatasetFactory(created_by=self.user)
self.aoi = AoiFactory(dataset=self.dataset)
self.model = ModelFactory(dataset=self.dataset, created_by=self.user)
self.json_type_header = headersList.copy()
self.json_type_header["content-type"] = "application/json"

def test_auth_me(self):
res = self.client.get(f"{API_BASE}/auth/me/", headers=headersList)
print(res.json())
self.assertEqual(res.status_code, status.HTTP_201_CREATED)

def test_auth_login(self):
Expand All @@ -32,9 +47,11 @@ def test_auth_login(self):
self.assertEqual(validators.url(res_body["login_url"]), True)

def test_create_dataset(self):
# create dataset

payload = {
"name": "My test dataset",
"source_imagery": "https://tiles.openaerialmap.org/5ac4fc6f26964b0010033112/0/5ac4fc6f26964b0010033113/{z}/{x}/{y}",
"name": self.dataset.name,
"source_imagery": self.dataset.source_imagery,
}
# test without authentication should be forbidden
res = self.client.post(f"{API_BASE}/dataset/", payload)
Expand All @@ -43,55 +60,291 @@ def test_create_dataset(self):
res = self.client.post(f"{API_BASE}/dataset/", payload, headers=headersList)
self.assertEqual(res.status_code, status.HTTP_201_CREATED)

# now dataset is created , create first aoi inside it
payload_second = {
"geom": {
"type": "Polygon",
"coordinates": [
[
[32.588507094820351, 0.348666499011499],
[32.588517512656978, 0.348184682976698],
[32.588869114643053, 0.348171660921362],
[32.588840465592334, 0.348679521066151],
[32.588507094820351, 0.348666499011499],
]
],
},
"dataset": 1,
}
json_type_header = headersList
json_type_header["content-type"] = "application/json"
def test_create_training(self):
# now dataset is created, create first aoi inside it

payload_second = {"geom": self.aoi.geom.json, "dataset": self.dataset.id}

res = self.client.post(
f"{API_BASE}/aoi/", json.dumps(payload_second), headers=json_type_header
f"{API_BASE}/aoi/",
json.dumps(payload_second),
headers=self.json_type_header,
)
self.assertEqual(res.status_code, status.HTTP_201_CREATED)

# create second aoi too , to test multiple aois
# create second aoi too, to test multiple aois

payload_third = {
"geom": {
"type": "Polygon",
"coordinates": [
[
[32.588046105549715, 0.349843692679227],
[32.588225813231475, 0.349484284008701],
[32.588624295482369, 0.349734307433132],
[32.588371662944233, 0.350088507273009],
[32.588046105549715, 0.349843692679227],
]
],
},
"dataset": 1,
"geom": self.aoi.geom.json,
"dataset": self.dataset.id,
}
res = self.client.post(
f"{API_BASE}/aoi/", json.dumps(payload_third), headers=json_type_header
f"{API_BASE}/aoi/", json.dumps(payload_third), headers=self.json_type_header
)
self.assertEqual(res.status_code, status.HTTP_201_CREATED)

# create model

model_payload = {"name": self.model.name, "dataset": self.dataset.id}
res = self.client.post(
f"{API_BASE}/model/",
json.dumps(model_payload),
headers=self.json_type_header,
)
self.assertEqual(res.status_code, status.HTTP_201_CREATED)

# create training without label

training_payload = {
"description": "My very first training",
"epochs": 1,
"zoom_level": [20, 21],
"batch_size": 1,
"model": self.model.id,
}
res = self.client.post(
f"{API_BASE}/training/",
json.dumps(training_payload),
headers=self.json_type_header,
)
self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST)

# download labels from osm for 1

## Fetch AOI
res = self.client.post(
f"{API_BASE}/label/osm/fetch/{self.aoi.id}/", "", headers=headersList
)
self.assertEqual(res.status_code, status.HTTP_201_CREATED)

# download labels from osm for 2

res = self.client.post(
f"{API_BASE}/label/osm/fetch/{self.aoi.id}/", "", headers=headersList
)
self.assertEqual(res.status_code, status.HTTP_201_CREATED)

# create training with epochs greater than the limit

training_payload = {
"description": "My very first training",
"epochs": 31,
"zoom_level": [20, 21],
"batch_size": 1,
"model": self.model.id,
}
res = self.client.post(
f"{API_BASE}/training/",
json.dumps(training_payload),
headers=self.json_type_header,
)
self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST)

# create training with batch size greater than the limit

training_payload = {
"description": "My very first training",
"epochs": 1,
"zoom_level": [20, 21],
"batch_size": 9,
"model": self.model.id,
}
res = self.client.post(
f"{API_BASE}/training/",
json.dumps(training_payload),
headers=self.json_type_header,
)
self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST)

# create training inside model

training_payload = {
"description": "My very first training",
"epochs": 1,
"zoom_level": [20, 21],
"batch_size": 1,
"model": self.model.id,
}
res = self.client.post(
f"{API_BASE}/training/",
json.dumps(training_payload),
headers=self.json_type_header,
)
self.assertEqual(res.status_code, status.HTTP_201_CREATED)

# create another training for the same model

training_payload = {
"description": "My very first training",
"epochs": 1,
"zoom_level": [20, 21],
"batch_size": 1,
"model": self.model.id,
}
res = self.client.post(
f"{API_BASE}/training/",
json.dumps(training_payload),
headers=self.json_type_header,
)
self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST)

self.training = TrainingFactory(model=self.model, created_by=self.user)

def test_create_label(self):
self.label = LabelFactory(aoi=self.aoi)
self.training = TrainingFactory(model=self.model, created_by=self.user)

# create label

label_payload = {
"geom": self.label.geom.json,
"aoi": self.aoi.id,
}

res = self.client.post(
f"{API_BASE}/label/",
json.dumps(label_payload),
headers=self.json_type_header,
)
self.assertEqual(res.status_code, status.HTTP_200_OK) # 201- for create

# create another label with the same geom and aoi

label_payload2 = {
"geom": self.label.geom.json,
"aoi": self.aoi.id,
}

res = self.client.post(
f"{API_BASE}/label/",
json.dumps(label_payload2),
headers=self.json_type_header,
)
self.assertEqual(res.status_code, status.HTTP_200_OK) # 200- for update

# create another label with error

label_payload3 = {
"geom": self.label.geom.json,
"aoi": 40, # non-existent aoi
}
res = self.client.post(
f"{API_BASE}/label/",
json.dumps(label_payload3),
headers=self.json_type_header,
)
self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST)

def test_fetch_feedbackAoi_osm_label(self):
# create feedback aoi
training = TrainingFactory(model=self.model, created_by=self.user)
feedbackAoi = FeedbackAoiFactory(training=training, user=self.user)

# download available osm data as labels for the feedback aoi

res = self.client.post(
f"{API_BASE}/label/feedback/osm/fetch/{feedbackAoi.id}/",
"",
headers=headersList,
)
self.assertEqual(res.status_code, status.HTTP_201_CREATED)

def test_get_runStatus(self):
training = TrainingFactory(model=self.model, created_by=self.user)

# get running training status

res = self.client.get(
f"{API_BASE}/training/status/{training.id}/", headers=headersList
)
self.assertEqual(res.status_code, status.HTTP_200_OK)

def test_submit_training_feedback(self):
training = TrainingFactory(model=self.model, created_by=self.user)

# apply feedback to training published checkpoints

training_feedback_payload = {
"training_id": training.id,
"epochs": 20,
"batch_size": 8,
"zoom_level": [19, 20],
}
res = self.client.post(
f"{API_BASE}/feedback/training/submit/",
json.dumps(training_feedback_payload),
headers=self.json_type_header,
)
# submit unfinished/unpublished training feedback should not pass
self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST)

def test_publish_training(self):
training = TrainingFactory(model=self.model, created_by=self.user)

# publish an unfinished training should not pass

res = self.client.post(
f"{API_BASE}/training/publish/{training.id}/", headers=headersList
)
self.assertEqual(res.status_code, status.HTTP_404_NOT_FOUND)

def test_get_GpxView(self):
training = TrainingFactory(model=self.model, created_by=self.user)
feedbackAoi = FeedbackAoiFactory(training=training, user=self.user)

# generate aoi GPX view - aoi_id

res = self.client.get(f"{API_BASE}/aoi/gpx/{self.aoi.id}/", headers=headersList)
self.assertEqual(res.status_code, status.HTTP_200_OK)

# generate feedback aoi GPX view - feedback aoi_id

res = self.client.get(
f"{API_BASE}/feedback-aoi/gpx/{feedbackAoi.id}/", headers=headersList
)
self.assertEqual(res.status_code, status.HTTP_200_OK)

def test_get_workspace(self):
# get training workspace

res = self.client.get(f"{API_BASE}/workspace/", headers=headersList)
self.assertEqual(res.status_code, status.HTTP_201_CREATED)

def test_download_workspace(self):
try:
lookup_dir = "test_dir"

# download non-existent dir should fail
res = self.client.get(
f"{API_BASE}/workspace/download/{lookup_dir}", headers=headersList
)
self.assertEqual(res.status_code, status.HTTP_404_NOT_FOUND)

# test download workspace
base_dir = os.path.join(settings.TRAINING_WORKSPACE, lookup_dir)
os.makedirs(base_dir)

with open(os.path.join(base_dir, "file.txt"), "wb") as f:
f.write(b"Test file")

res = self.client.get(
f"{API_BASE}/workspace/download/{lookup_dir}",
headers=headersList,
)
self.assertEqual(res.status_code, status.HTTP_200_OK)

# test download file greater than the 200 mb limit

with open(os.path.join(base_dir, "large_file.txt"), "wb") as f:
f.seek(201 * 1024**2)
f.write(b"\0")

# download file size greater than limit should fail
res = self.client.get(
f"{API_BASE}/workspace/download/{lookup_dir}",
headers=headersList,
)
self.assertEqual(res.status_code, status.HTTP_403_FORBIDDEN)

aoi_res = self.client.get(f"{API_BASE}/aoi/?dataset=1")
self.assertEqual(aoi_res.status_code, 200)
aoi_res_json = aoi_res.json()
self.assertEqual(len(aoi_res_json["features"]), 2)
finally:
# clean up
shutil.rmtree(base_dir)
Loading

0 comments on commit b0b9441

Please sign in to comment.