Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HPL v1.3 #11

Merged
merged 27 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,19 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

## v1.3.0 - 2023-10-04
### Added
- `split_and(p: HplPredicate | HplExpression)` function to `hpl.rewrite` module.
- `simplify(p: HplPredicate | HplExpression)` function to `hpl.rewrite` module.
- `is_inclusion` and `is_comparison` properties to `BinaryOperatorDefinition`.
- Factory functions `boolean`, `number` and `string` to `HplLiteral`.

### Changed
- Small optimization to `HplExpression` type checking system.

### Fixed
- `HplPredicateExpression.expression` is now cast to `DataType.BOOL` on construction.

## v1.2.0 - 2023-09-08
### Added
- `canonical_form(property: HplProperty)` function to `hpl.rewrite` module.
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def read(filename):
install_requires=[
'attrs~=23.0',
'lark~=1.0',
'typeguard~=4.1',
],
extras_require={
'dev': [
Expand Down
112 changes: 103 additions & 9 deletions src/hpl/ast/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from attrs import field, frozen
from attrs.validators import deep_iterable, instance_of
from typeguard import check_type

from hpl.ast.base import HplAstObject
from hpl.errors import HplSanityError, index_out_of_range, missing_field, type_error_in_expr
Expand Down Expand Up @@ -112,7 +113,7 @@ def is_fully_typed(self) -> bool:
def cast(self, t: DataType) -> 'HplExpression':
try:
r: DataType = self.data_type.cast(t)
return self.but(data_type=r)
return self if r == self.data_type else self.but(data_type=r)
except TypeError as e:
raise type_error_in_expr(e, self)

Expand Down Expand Up @@ -399,6 +400,21 @@ def true(cls) -> 'HplLiteral':
def false(cls) -> 'HplLiteral':
return cls(token='False', value=False)

@classmethod
def boolean(cls, value: bool) -> 'HplLiteral':
value = check_type(value, bool)
return cls(token=str(value), value=value)

@classmethod
def number(cls, value: Union[int, float]) -> 'HplLiteral':
value = check_type(value, Union[int, float])
return cls(token=str(value), value=value)

@classmethod
def string(cls, value: str) -> 'HplLiteral':
value = check_type(value, str)
return cls(token=value, value=value)

@property
def default_data_type(self) -> DataType:
return DataType.PRIMITIVE
Expand Down Expand Up @@ -795,11 +811,12 @@ class BinaryOperatorDefinition:
result: DataType
infix: bool = True
commutative: bool = False
associative: bool = False

@classmethod
def addition(cls) -> 'BinaryOperatorDefinition':
t = DataType.NUMBER
return cls('+', t, t, t, infix=True, commutative=True)
return cls('+', t, t, t, infix=True, commutative=True, associative=True)

@classmethod
def subtraction(cls) -> 'BinaryOperatorDefinition':
Expand All @@ -809,7 +826,7 @@ def subtraction(cls) -> 'BinaryOperatorDefinition':
@classmethod
def multiplication(cls) -> 'BinaryOperatorDefinition':
t = DataType.NUMBER
return cls('*', t, t, t, infix=True, commutative=True)
return cls('*', t, t, t, infix=True, commutative=True, associative=True)

@classmethod
def division(cls) -> 'BinaryOperatorDefinition':
Expand All @@ -819,7 +836,7 @@ def division(cls) -> 'BinaryOperatorDefinition':
@classmethod
def power(cls) -> 'BinaryOperatorDefinition':
t = DataType.NUMBER
return cls('**', t, t, t, infix=True, commutative=False)
return cls('**', t, t, t, infix=True, commutative=False, associative=True)

@classmethod
def implication(cls) -> 'BinaryOperatorDefinition':
Expand All @@ -829,27 +846,27 @@ def implication(cls) -> 'BinaryOperatorDefinition':
@classmethod
def equivalence(cls) -> 'BinaryOperatorDefinition':
t = DataType.BOOL
return cls(IFF_OPERATOR, t, t, t, infix=True, commutative=True)
return cls(IFF_OPERATOR, t, t, t, infix=True, commutative=True, associative=True)

@classmethod
def disjunction(cls) -> 'BinaryOperatorDefinition':
t = DataType.BOOL
return cls(OR_OPERATOR, t, t, t, infix=True, commutative=True)
return cls(OR_OPERATOR, t, t, t, infix=True, commutative=True, associative=True)

@classmethod
def conjunction(cls) -> 'BinaryOperatorDefinition':
t = DataType.BOOL
return cls(AND_OPERATOR, t, t, t, infix=True, commutative=True)
return cls(AND_OPERATOR, t, t, t, infix=True, commutative=True, associative=True)

@classmethod
def equality(cls) -> 'BinaryOperatorDefinition':
t = DataType.PRIMITIVE
return cls('=', t, t, DataType.BOOL, infix=True, commutative=True)
return cls('=', t, t, DataType.BOOL, infix=True, commutative=True, associative=True)

@classmethod
def inequality(cls) -> 'BinaryOperatorDefinition':
t = DataType.PRIMITIVE
return cls('!=', t, t, DataType.BOOL, infix=True, commutative=True)
return cls('!=', t, t, DataType.BOOL, infix=True, commutative=True, associative=True)

@classmethod
def less_than(cls) -> 'BinaryOperatorDefinition':
Expand Down Expand Up @@ -886,6 +903,10 @@ def name(self) -> str:
def parameters(self) -> Tuple[DataType, DataType]:
return (self.parameter1, self.parameter2)

@property
def is_arithmetic(self) -> bool:
return self.token in ('+', '-', '*', '/', '**')

@property
def is_plus(self) -> bool:
return self.token == '+'
Expand All @@ -902,6 +923,42 @@ def is_times(self) -> bool:
def is_division(self) -> bool:
return self.token == '/'

@property
def is_power(self) -> bool:
return self.token == '**'

@property
def is_inclusion(self) -> bool:
return self.token == IN_OPERATOR

@property
def is_comparison(self) -> bool:
return self.token in ('=', '!=', '<', '<=', '>', '>=')

@property
def is_equality(self) -> bool:
return self.token == '='

@property
def is_inequality(self) -> bool:
return self.token == '!='

@property
def is_less_than(self) -> bool:
return self.token == '<'

@property
def is_less_than_eq(self) -> bool:
return self.token == '<='

@property
def is_greater_than(self) -> bool:
return self.token == '>'

@property
def is_greater_than_eq(self) -> bool:
return self.token == '>='

@property
def is_and(self) -> bool:
return self.token == AND_OPERATOR
Expand All @@ -918,6 +975,10 @@ def is_implies(self) -> bool:
def is_iff(self) -> bool:
return self.token == IFF_OPERATOR

@property
def similar_parameter_types(self) -> bool:
return bool(self.parameter1 & self.parameter2)

def __str__(self) -> str:
return self.token

Expand Down Expand Up @@ -986,6 +1047,31 @@ def _check_operand2(self, _attribute, arg: HplExpression):

def __attrs_post_init__(self):
object.__setattr__(self, 'data_type', self.operator.result)
if self.operator.similar_parameter_types:
a: HplExpression = self.operand1.cast(self.operand2.data_type)
b: HplExpression = self.operand2.cast(a.data_type)
object.__setattr__(self, 'operand1', a)
object.__setattr__(self, 'operand2', b)

@classmethod
def addition(cls, a: HplExpression, b: HplExpression) -> 'BinaryOperatorDefinition':
return cls(operator=BuiltinBinaryOperator.ADD, operand1=a, operand2=b)

@classmethod
def subtraction(cls, a: HplExpression, b: HplExpression) -> 'BinaryOperatorDefinition':
return cls(operator=BuiltinBinaryOperator.SUB, operand1=a, operand2=b)

@classmethod
def multiplication(cls, a: HplExpression, b: HplExpression) -> 'BinaryOperatorDefinition':
return cls(operator=BuiltinBinaryOperator.MULT, operand1=a, operand2=b)

@classmethod
def division(cls, a: HplExpression, b: HplExpression) -> 'BinaryOperatorDefinition':
return cls(operator=BuiltinBinaryOperator.DIV, operand1=a, operand2=b)

@classmethod
def power(cls, a: HplExpression, b: HplExpression) -> 'BinaryOperatorDefinition':
return cls(operator=BuiltinBinaryOperator.POW, operand1=a, operand2=b)

@classmethod
def conjunction(cls, a: HplExpression, b: HplExpression) -> 'HplBinaryOperator':
Expand Down Expand Up @@ -1462,6 +1548,10 @@ class HplFieldAccess(HplDataAccess):
def is_field(self) -> bool:
return True

@property
def object(self) -> HplExpression:
return self.message

def children(self) -> Tuple[HplExpression]:
return (self.message,)

Expand Down Expand Up @@ -1513,6 +1603,10 @@ class HplArrayAccess(HplDataAccess):
def is_indexed(self) -> bool:
return True

@property
def object(self) -> HplExpression:
return self.array

def children(self) -> Tuple[HplExpression, HplExpression]:
return (self.array, self.index)

Expand Down
9 changes: 7 additions & 2 deletions src/hpl/ast/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Dict, List, Mapping, Optional, Set, Tuple

from attrs import field, frozen
from attrs.validators import instance_of
from typeguard import typechecked

from hpl.ast.base import HplAstObject
from hpl.ast.expressions import (
Expand Down Expand Up @@ -85,9 +85,14 @@ def type_check_references(
###############################################################################


@typechecked
def _cast_expr_to_bool(expr: HplExpression) -> HplExpression:
return expr.cast(DataType.BOOL)


@frozen
class HplPredicateExpression(HplPredicate):
expression: HplExpression = field(validator=instance_of(HplExpression))
expression: HplExpression = field(converter=_cast_expr_to_bool)

@expression.validator
def _check_expression(self, _attribute, expr: HplExpression):
Expand Down
Loading