Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MQTT predictions for amateur payloads + amateur predictor tests #124

Merged
merged 2 commits into from
Oct 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ham_predictor.tf
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ resource "aws_lambda_function" "ham_predict_updater" {
tags = {
Name = "ham_predict_updater"
}
lifecycle {
ignore_changes = [environment]
}
}


Expand Down
130 changes: 99 additions & 31 deletions lambda/ham_predict_updater/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import sys
sys.path.append("sns_to_mqtt/vendor")

import paho.mqtt.client as mqtt
import json
from datetime import datetime
from datetime import timedelta
Expand All @@ -8,7 +12,11 @@
import es
import asyncio
import functools
import os
import random
import time

TAWHIRI_SERVER = "tawhiri.v2.sondehub.org"

# FLIGHT PROFILE DEFAULTS
#
Expand Down Expand Up @@ -39,6 +47,49 @@
# Do not run predictions if the payload is below this altitude AMSL, and has an ascent rate below the above threshold.
ALTITUDE_AMSL_THRESHOLD = 1500.0


# Setup MQTT
client = mqtt.Client(transport="websockets")

connected_flag = False

import socket
socket.setdefaulttimeout(1)


## MQTT functions
def connect():
client.on_connect = on_connect
client.on_disconnect = on_disconnect
client.on_publish = on_publish
#client.tls_set()
client.username_pw_set(username=os.getenv("MQTT_USERNAME"), password=os.getenv("MQTT_PASSWORD"))
HOSTS = os.getenv("MQTT_HOST").split(",")
PORT = int(os.getenv("MQTT_PORT", default="8080"))
if PORT == 443:
client.tls_set()
HOST = random.choice(HOSTS)
print(f"Connecting to {HOST}")
client.connect(HOST, PORT, 5)
client.loop_start()
print("loop started")

def on_disconnect(client, userdata, rc):
global connected_flag
print("disconnected")
connected_flag=False #set flag

def on_connect(client, userdata, flags, rc):
global connected_flag
if rc==0:
print("connected")
connected_flag=True #set flag
else:
print("Bad connection Returned code")

def on_publish(client, userdata, mid):
pass

def get_flight_docs():
path = "flight-doc/_search"
payload = {
Expand Down Expand Up @@ -252,7 +303,7 @@ def get_float_prediction(timestamp, latitude, longitude, altitude, current_rate=
# Generate the prediction URL
url = f"/api/v1/?launch_altitude={altitude}&launch_latitude={latitude}&launch_longitude={longitude}&launch_datetime={timestamp}&float_altitude={burst_altitude:.2f}&stop_datetime={(datetime.now() + timedelta(days=1)).isoformat()}Z&ascent_rate={ascent_rate:.2f}&profile=float_profile"
logging.debug(url)
conn = http.client.HTTPSConnection("tawhiri.v2.sondehub.org")
conn = http.client.HTTPSConnection(TAWHIRI_SERVER)
conn.request("GET", url)
res = conn.getresponse()
data = res.read()
Expand Down Expand Up @@ -309,7 +360,7 @@ def get_standard_prediction(timestamp, latitude, longitude, altitude, current_ra
# Generate the prediction URL
url = f"/api/v1/?launch_latitude={latitude}&launch_longitude={longitude}&launch_datetime={timestamp}&launch_altitude={altitude:.2f}&ascent_rate={ascent_rate:.2f}&burst_altitude={burst_altitude:.2f}&descent_rate={descent_rate:.2f}"
logging.debug(url)
conn = http.client.HTTPSConnection("tawhiri.v2.sondehub.org")
conn = http.client.HTTPSConnection(TAWHIRI_SERVER)
conn.request("GET", url)
res = conn.getresponse()
data = res.read()
Expand Down Expand Up @@ -343,37 +394,38 @@ def get_standard_prediction(timestamp, latitude, longitude, altitude, current_ra
return None



def get_ruaumoko(latitude, longitude):
"""
Request the ground level from ruaumoko.

Returns 0.0 if the ground level could not be determined, effectively
defaulting to any checks based on this data being based on mean sea level.
"""

# Shift longitude into the appropriate range for Tawhiri
if longitude < 0:
longitude += 360.0

# Generate the prediction URL
url = f"/api/ruaumoko/?latitude={latitude}&longitude={longitude}"
logging.debug(url)
conn = http.client.HTTPSConnection("tawhiri.v2.sondehub.org")
conn.request("GET", url)
res = conn.getresponse()
data = res.read()

if res.code != 200:
logging.debug(data)
return None
# Need to mock this out if we ever use it again
#
# def get_ruaumoko(latitude, longitude):
# """
# Request the ground level from ruaumoko.

# Returns 0.0 if the ground level could not be determined, effectively
# defaulting to any checks based on this data being based on mean sea level.
# """

# # Shift longitude into the appropriate range for Tawhiri
# if longitude < 0:
# longitude += 360.0

# # Generate the prediction URL
# url = f"/api/ruaumoko/?latitude={latitude}&longitude={longitude}"
# logging.debug(url)
# conn = http.client.HTTPSConnection(TAWHIRI_SERVER)
# conn.request("GET", url)
# res = conn.getresponse()
# data = res.read()

# if res.code != 200:
# logging.debug(data)
# return None

resp_data = json.loads(data.decode("utf-8"))
# resp_data = json.loads(data.decode("utf-8"))

if 'altitude' in resp_data:
return resp_data['altitude']
else:
return 0.0
# if 'altitude' in resp_data:
# return resp_data['altitude']
# else:
# return 0.0


def bulk_upload_es(index_prefix,payloads):
Expand All @@ -392,8 +444,11 @@ def bulk_upload_es(index_prefix,payloads):
raise RuntimeError

def predict(event, context):
# Connect to MQTT
connect()
# Use asyncio.run to synchronously "await" an async function
result = asyncio.run(predict_async(event, context))
time.sleep(0.5) # give paho mqtt 500ms to send messages this could be improved on but paho mqtt is a pain to interface with
return result

async def predict_async(event, context):
Expand Down Expand Up @@ -593,6 +648,18 @@ async def predict_async(event, context):
if len(output) > 0:
bulk_upload_es("ham-predictions", output)

# upload to mqtt
while not connected_flag:
time.sleep(0.01) # wait until connected
for prediction in output:
logging.debug(f'Publishing prediction for {prediction["payload_callsign"]} to MQTT')
client.publish(
topic=f'amateur-prediction/{prediction["payload_callsign"]}',
payload=json.dumps(prediction),
qos=0,
retain=False
)
logging.debug(f'Published prediction for {prediction["payload_callsign"]} to MQTT')

logging.debug("Finished")
return
Expand Down Expand Up @@ -639,6 +706,7 @@ async def run_predictions_for_serial(sem, flight_docs, serial, value):
if (abs(value['rate']) <= ASCENT_RATE_THRESHOLD) and (value['alt'] < ALTITUDE_AMSL_THRESHOLD):
# Payload is 'floating' (e.g. is probably on the ground), and is below 1500m AMSL.
# Don't run a prediction in this case. It probably just hasn't been launched yet.
logging.debug(f"{serial} is floating and alt is low so not running prediction")
return None


Expand Down
91 changes: 57 additions & 34 deletions lambda/ham_predict_updater/__main__.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,66 @@
from . import *
from . import mock_values, test_values
import unittest
from unittest.mock import MagicMock, call, patch

# Predictor test
# conn = http.client.HTTPSConnection("tawhiri.v2.sondehub.org")
# _now = datetime.utcnow().isoformat() + "Z"

# _ascent = get_standard_prediction(conn, _now, -34.0, 138.0, 10.0, burst_altitude=26000)
# print(f"Got {len(_ascent)} data points for ascent prediction.")
# _descent = get_standard_prediction(conn, _now, -34.0, 138.0, 24000.0, burst_altitude=24000.5)
# print(f"Got {len(_descent)} data points for descent prediction.")

# test = predict(
# {},{}
# )
#print(get_launch_sites())
#print(get_reverse_predictions())
# for _serial in test:
# print(f"{_serial['serial']}: {len(_serial['data'])}")
# Mock OpenSearch requests
def mock_es_request(body, path, method):
if path.endswith("_bulk"): # handle when the upload happens
return {}
elif(path == "flight-doc/_search"): # handle flightdoc queries
return mock_values.flight_docs
elif(path == "ham-telm-*/_search"): # handle telm searches
return mock_values.ham_telm
else:
raise NotImplemented

# Mock out tawhiri
class MockResponse(object):
code = 200
def read(self):
return mock_values.tawhiri_respose # currently we only mock a float profile

class MockHTTPS(object):
logging.debug(object)
def __init__(self, url):
logging.debug(url)
def request(self,method, url):
pass
def getresponse(self):
return MockResponse()

http.client.HTTPSConnection = MockHTTPS

logging.basicConfig(
format="%(asctime)s %(levelname)s:%(message)s", level=logging.DEBUG
)

print(predict(
{},{}
))
# bulk_upload_es("reverse-prediction",[{
# "datetime" : "2021-10-04",
# "data" : { },
# "serial" : "R12341234",
# "station" : "-2",
# "subtype" : "RS41-SGM",
# "ascent_rate" : "5",
# "alt" : 1000,
# "position" : [
# 1,
# 2
# ],
# "type" : "RS41"
# }]
# )
class TestAmateurPrediction(unittest.TestCase):
def setUp(self):
es.request = MagicMock(side_effect=mock_es_request)
client.connect = MagicMock()
client.loop_start = MagicMock()
client.username_pw_set = MagicMock()
client.tls_set = MagicMock()
client.publish = MagicMock()
on_connect(client, "userdata", "flags", 0)

@patch("time.sleep")
def test_float_prediction(self, MockSleep):
predict({},{})
date_prefix = datetime.now().strftime("%Y-%m")
es.request.assert_has_calls(
[
call(json.dumps(test_values.flight_doc_search),"flight-doc/_search", "POST"),
call(json.dumps(test_values.ham_telm_search), "ham-telm-*/_search", "GET"),
call(test_values.es_bulk_upload,f"ham-predictions-{date_prefix}/_bulk","POST")
]
)
client.username_pw_set.assert_called()
client.loop_start.assert_called()
client.connect.assert_called()
client.publish.assert_has_calls([test_values.mqtt_publish_call])
time.sleep.assert_called_with(0.5) # make sure we sleep to let paho mqtt queue clear

if __name__ == '__main__':
unittest.main()
Loading
Loading