-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Matic Lubej
committed
Nov 3, 2023
1 parent
e811862
commit 8a8cb06
Showing
1 changed file
with
21 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,41 +1,37 @@ | ||
import io | ||
|
||
|
||
import numpy as np | ||
import numpy.testing | ||
import pytest | ||
import requests | ||
import xarray | ||
|
||
from fusets.mogpr import mogpr, MOGPRTransformer | ||
from fusets.whittaker import whittaker | ||
from fusets.mogpr import MOGPRTransformer | ||
|
||
|
||
@pytest.fixture() | ||
def data() -> xarray.Dataset: | ||
bytes_data = requests.get("https://artifactory.vgt.vito.be/testdata-public/fusets/b4_b8_vv_vh/rape.nc", stream=True) | ||
ds = xarray.load_dataset(io.BytesIO(bytes_data.content)) | ||
ds = ds.isel(t=ds.t.dt.year.isin([2019, 2020]), x=slice(9, 11), y=slice(9, 11)) | ||
|
||
ds["RVI"] = (ds.VH + ds.VH) / (ds.VV + ds.VH) | ||
ds["NDVI"] = (ds.B08 - ds.B04) / (ds.B04 + ds.B08) | ||
return ds[["NDVI", "RVI"]] | ||
|
||
ds = xarray.load_dataset(io.BytesIO(requests.get("https://artifactory.vgt.vito.be/testdata-public/fusets/b4_b8_vv_vh/rape.nc",stream=True).content)) | ||
|
||
ds['RVI'] = (ds.VH + ds.VH) / (ds.VV + ds.VH) | ||
ds['NDVI'] = (ds.B08 - ds.B04) / (ds.B04 + ds.B08) | ||
vars = ds[['NDVI', 'RVI']] | ||
def test_mogpr_udf(data): | ||
from openeo.udf import XarrayDataCube | ||
|
||
def test_mogpr_udf(): | ||
""" | ||
Simple test to help debug udf | ||
Returns: | ||
from fusets.openeo.mogpr_udf import apply_datacube | ||
|
||
""" | ||
result = apply_datacube(XarrayDataCube(data.to_array(dim="bands")), context={}) | ||
assert result.array.dims == ("bands", "t", "y", "x") | ||
assert result.array.shape == (2, 146, 2, 2) | ||
|
||
from fusets.openeo.mogpr_udf import apply_datacube | ||
from openeo.udf import XarrayDataCube | ||
result = apply_datacube(XarrayDataCube(vars.to_array(dim="bands")),context={}) | ||
print(result) | ||
assert result.array.dims == ("bands","t","y","x") | ||
assert result.array.shape == (2, 374, 19, 21) | ||
|
||
def test_mogpr_train_model(): | ||
def test_mogpr_train_model(data): | ||
t = MOGPRTransformer() | ||
t.fit(vars) | ||
out = t.transform(vars) | ||
t.fit(data) | ||
out = t.transform(data) | ||
|
||
print(out) | ||
print(out.NDVI.mean(dim=("x","y"))) | ||
assert tuple(out.coords) == ("t", "y", "x") | ||
assert out.NDVI.shape == (146, 2, 2) |