Skip to content

Commit

Permalink
build: add first/draft dsl tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nmvrs committed Apr 30, 2024
1 parent 268a0e1 commit c221fe4
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[pytest]
markers =
dsl: marks tests that verify moai DSL parsing (deselect with '-m "not dsl"')
Empty file added tests/dsl/__init__.py
Empty file.
28 changes: 28 additions & 0 deletions tests/dsl/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from hydra_plugins.moai_dsl_plugin.moai_dsl_plugin import __MOAI_GRAMMAR__
from lark import Lark

import pytest
import torch
import numpy as np

@pytest.fixture
def parser():
return Lark(__MOAI_GRAMMAR__, parser='earley')

@pytest.fixture
def various_tensors():
return {
'test': 1,
'test2': torch.scalar_tensor(2),
'test3': torch.scalar_tensor(2)[np.newaxis],
'another': {
'number': 10,
'number2': torch.scalar_tensor(1),
'number3': torch.scalar_tensor(1)[np.newaxis],
},
'add': {
'this': -1,
'this2': torch.scalar_tensor(15),
'this3': torch.scalar_tensor(15)[np.newaxis],
}
}
31 changes: 31 additions & 0 deletions tests/dsl/test_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from moai.core.execution.expression import TreeModule

import pytest
import torch
import lark

@pytest.mark.dsl
class TestDSL:
def _parse_and_run(self, parser: lark.Lark, expression: str, tensors) -> torch.Tensor:
tree = parser.parse(expression)
m = TreeModule('check', tree)
m(tensors)
return tensors['check']

def test_combined(self, parser, various_tensors):
expression = "test + ( another.number + add.this - 5) + stack(another.number2, test2, add.this2, 0) + cat(another.number3, test3, add.this3, 0)"
x = self._parse_and_run(parser, expression, various_tensors)
y = torch.tensor([7.0, 9.0, 35.])
assert torch.equal(x, y)

def test_add(self, parser, various_tensors):
expression = "test2 + add.this2"
x = self._parse_and_run(parser, expression, various_tensors)
y = torch.scalar_tensor(17.0)
assert torch.equal(x, y)

def test_adds(self, parser, various_tensors):
expression = "test2 + add.this2 + (another.number2 + another.number3 + add.this3)"
x = self._parse_and_run(parser, expression, various_tensors)
y = torch.scalar_tensor(34.0)
assert torch.equal(x, y)

0 comments on commit c221fe4

Please sign in to comment.