Skip to content

Commit

Permalink
speed up MOGPR test
Browse files Browse the repository at this point in the history
  • Loading branch information
Matic Lubej committed Nov 3, 2023
1 parent e811862 commit 8a8cb06
Showing 1 changed file with 21 additions and 25 deletions.
46 changes: 21 additions & 25 deletions tests/test_mogpr.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,37 @@
import io


import numpy as np
import numpy.testing
import pytest
import requests
import xarray

from fusets.mogpr import mogpr, MOGPRTransformer
from fusets.whittaker import whittaker
from fusets.mogpr import MOGPRTransformer


@pytest.fixture()
def data() -> xarray.Dataset:
bytes_data = requests.get("https://artifactory.vgt.vito.be/testdata-public/fusets/b4_b8_vv_vh/rape.nc", stream=True)
ds = xarray.load_dataset(io.BytesIO(bytes_data.content))
ds = ds.isel(t=ds.t.dt.year.isin([2019, 2020]), x=slice(9, 11), y=slice(9, 11))

ds["RVI"] = (ds.VH + ds.VH) / (ds.VV + ds.VH)
ds["NDVI"] = (ds.B08 - ds.B04) / (ds.B04 + ds.B08)
return ds[["NDVI", "RVI"]]

ds = xarray.load_dataset(io.BytesIO(requests.get("https://artifactory.vgt.vito.be/testdata-public/fusets/b4_b8_vv_vh/rape.nc",stream=True).content))

ds['RVI'] = (ds.VH + ds.VH) / (ds.VV + ds.VH)
ds['NDVI'] = (ds.B08 - ds.B04) / (ds.B04 + ds.B08)
vars = ds[['NDVI', 'RVI']]
def test_mogpr_udf(data):
from openeo.udf import XarrayDataCube

def test_mogpr_udf():
"""
Simple test to help debug udf
Returns:
from fusets.openeo.mogpr_udf import apply_datacube

"""
result = apply_datacube(XarrayDataCube(data.to_array(dim="bands")), context={})
assert result.array.dims == ("bands", "t", "y", "x")
assert result.array.shape == (2, 146, 2, 2)

from fusets.openeo.mogpr_udf import apply_datacube
from openeo.udf import XarrayDataCube
result = apply_datacube(XarrayDataCube(vars.to_array(dim="bands")),context={})
print(result)
assert result.array.dims == ("bands","t","y","x")
assert result.array.shape == (2, 374, 19, 21)

def test_mogpr_train_model():
def test_mogpr_train_model(data):
t = MOGPRTransformer()
t.fit(vars)
out = t.transform(vars)
t.fit(data)
out = t.transform(data)

print(out)
print(out.NDVI.mean(dim=("x","y")))
assert tuple(out.coords) == ("t", "y", "x")
assert out.NDVI.shape == (146, 2, 2)

0 comments on commit 8a8cb06

Please sign in to comment.