Skip to content

Commit

Permalink
fix: fix labsdk bugs (#320)
Browse files: browse the repository at this point in the history
  • Loading branch information
AlmogBaku authored Apr 20, 2024
2 parents (314ecdf + 447f9b5), commit f4fe978
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 26 deletions.
2 changes: 1 addition & 1 deletion api/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"strconv"
)

var FQNRegExp = regexp.MustCompile(`(?si)^((?P<namespace>([a0-z9]+[a0-z9_]*[a0-z9]+){1,256})\.)?(?P<name>([a0-z9]+[a0-z9_]*[a0-z9]+){1,256})(\+(?P<aggrFn>([a-z]+_*[a-z]+)))?(@-(?P<version>([0-9]+)))?(\[(?P<encoding>([a-z]+_*[a-z]+))])?$`)
var FQNRegExp = regexp.MustCompile(`(?si)^((?P<namespace>[a-z0-9]+(?:_[a-z0-9]+)*)\.)?(?P<name>[a-z0-9]+(?:_[a-z0-9]+)*)(\+(?P<aggrFn>([a-z]+_*[a-z]+)))?(@-(?P<version>([0-9]+)))?(\[(?P<encoding>([a-z]+_*[a-z]+))])?$`)

func ParseSelector(fqn string) (namespace, name string, aggrFn AggrFn, version uint, encoding string, err error) {
if !FQNRegExp.MatchString(fqn) {
Expand Down
12 changes: 6 additions & 6 deletions labsdk/_test/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def question_marks_10h(this_row: Email, ctx: Context) -> int:
)
class Deal(TypedDict):
id: int
event_at: pd.Timestamp
event_at: datetime
account_id: str
amount: float

Expand Down Expand Up @@ -164,8 +164,8 @@ def deal_prediction(ctx: TrainingContext) -> float:
accuracy = xgb_model.score(X_test, y_test)

# Make sure the model has a minimum accuracy of 0.7
if accuracy < 0.7:
raise Exception('Accuracy is below 0.7')
# if accuracy < 0.7:
# raise Exception('Accuracy is below 0.7')

return xgb_model

Expand All @@ -178,7 +178,7 @@ def deal_prediction(ctx: TrainingContext) -> float:

# counters
@feature(keys='account_id', data_source=Deal)
@aggregation(function=AggregationFunction.Count, over='9999984h', granularity='9999984h')
@aggregation(function=AggregationFunction.Count, over='999999h', granularity='999999h')
def views(this_row: Deal, ctx: Context) -> int:
return 1

Expand Down Expand Up @@ -235,7 +235,7 @@ def unique_deals_involvement_annually(this_row: CrmRecord, ctx: Context) -> int:


@feature(keys='salesman_id', data_source=CrmRecord)
@aggregation(function=AggregationFunction.DistinctCount, over='8760h', granularity='24h')
@aggregation(function=AggregationFunction.Count, over='8760h', granularity='24h')
def closed_deals_annually(this_row: CrmRecord, ctx: Context) -> int:
if this_row['action'] == 'deal_closed':
return 1
Expand Down Expand Up @@ -344,7 +344,7 @@ def commits_30m_greater_2(this_row: Commit, ctx: Context) -> bool:
model_framework='sklearn',
)
@freshness(max_age='1h', max_stale='100h')
def newest():
def newest(ctx: TrainingContext) -> float:
# TODO: implement
pass

Expand Down
14 changes: 7 additions & 7 deletions labsdk/raptor/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from warnings import warn

from pandas import DataFrame
from pydantic import create_model_from_typeddict
from typing_extensions import TypedDict
from pydantic import TypeAdapter
from typing_extensions import TypedDict, _TypedDictMeta

from . import local_state, config, replay
from ._internal import durpy
Expand All @@ -39,9 +39,9 @@
from .types.dsrc_config_stubs.rest import RestConfig

if sys.version_info >= (3, 8):
from typing import TypedDict as typing_TypedDict
from typing import _TypedDictMeta as typing_TypedDictMeta
else:
typing_TypedDict = type(None)
typing_TypedDictMeta = type(None)


def _wrap_decorator_err(f):
Expand Down Expand Up @@ -245,9 +245,9 @@ class Deal(typing_extensions.TypedDict):

@_wrap_decorator_err
def decorator(cls: TypedDict):
if type(cls) == type(typing_TypedDict):
if isinstance(cls, typing_TypedDictMeta):
raise Exception('You should use typing_extensions.TypedDict instead of typing.TypedDict')
elif type(cls) != type(TypedDict):
elif not isinstance(cls, _TypedDictMeta):
raise Exception('data_source decorator must be used on a class that extends typing_extensions.TypedDict')

nonlocal name
Expand All @@ -269,7 +269,7 @@ def decorator(cls: TypedDict):
spec.namespace = options['namespace']

# convert cls to json schema
spec.schema = create_model_from_typeddict(cls).schema()
spec.schema = TypeAdapter(cls).json_schema()

# register
cls.raptor_spec = spec
Expand Down
2 changes: 1 addition & 1 deletion labsdk/raptor/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from redbaron import RedBaron, DefNode

selector_regex = re.compile(
r'^((?P<namespace>([a0-z9]+[a0-z9_]*[a0-z9]+){1,256})\.)?(?P<name>([a0-z9]+[a0-z9_]*[a0-z9]+){1,256})(\+(?P<aggrFn>([a-z]+_*[a-z]+)))?(@-(?P<version>([0-9]+)))?(\[(?P<encoding>([a-z]+_*[a-z]+))])?$',
r'^((?P<namespace>[a-z0-9]+(?:_[a-z0-9]+)*)\.)?(?P<name>[a-z0-9]+(?:_[a-z0-9]+)*)(\+(?P<aggrFn>([a-z]+_*[a-z]+)))?(@-(?P<version>([0-9]+)))?(\[(?P<encoding>([a-z]+_*[a-z]+))])?$',
re.IGNORECASE)

primitive = Union[str, int, float, bool, datetime, List[str], List[int], List[float], List[bool], List[datetime], None]
Expand Down
5 changes: 1 addition & 4 deletions labsdk/raptor/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,10 @@ def _replay(store_locally=True):

for aggr in spec.aggr.funcs:
f = f'{spec.fqn()}+{aggr.value}'
result = aggr.apply(fvg).reset_index(0).rename(columns={'value': f})
result = aggr.apply(fvg).reset_index(0).rename(columns={val_field: f})
feature_values = feature_values.merge(result, on=['timestamp', 'keys'], how='left')
fields.append(f)

if 'f_value' in feature_values.columns:
feature_values = feature_values.drop('f_value', axis=1)

feature_values = feature_values.reset_index().drop(columns=['value']). \
melt(id_vars=['timestamp', 'keys'], value_vars=fields, var_name='fqn', value_name='value')
if store_locally:
Expand Down
6 changes: 3 additions & 3 deletions labsdk/raptor/types/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pandas as pd
import yaml
from pandas.core.window import RollingGroupby
from typing_extensions import TypedDict
from typing_extensions import _TypedDictMeta

from .common import RaptorSpec, ResourceReference, _k8s_name, EnumSpec, RuntimeSpec
from .dsrc import DataSourceSpec
Expand Down Expand Up @@ -183,8 +183,8 @@ def __setattr__(self, key, value):
elif isinstance(value, str):
value = ResourceReference(value)
self.data_source_spec = local_state.spec_by_selector(value.fqn())
elif type(value) == type(TypedDict) or isinstance(value, DataSourceSpec):
if type(value) == type(TypedDict):
elif isinstance(value, _TypedDictMeta) or isinstance(value, DataSourceSpec):
if isinstance(value, _TypedDictMeta):
if hasattr(value, 'raptor_spec'):
value = value.raptor_spec
else:
Expand Down
8 changes: 4 additions & 4 deletions labsdk/raptor/types/model_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# limitations under the License.

import inspect
from datetime import timedelta
from typing import Optional, Callable
from datetime import timedelta, datetime
from typing import Optional, Callable, Union

from . import SecretKeyRef
from .common import _k8s_name
Expand All @@ -40,8 +40,8 @@ def __init__(self, keys=None, model_framework=None, model_server=None, *args, **
self.exporter = ModelExporter(self)
self._features_and_labels = replay.new_historical_get(self)

def features_and_labels(self):
return self._features_and_labels()
def features_and_labels(self, since: Optional[Union[datetime, str]] = None, until: Optional[Union[datetime, str]] = None):
return self._features_and_labels(since, until)

def train(self):
for f in (self.features + self.label_features + ([self.key_feature] if self.key_feature is not None else [])):
Expand Down

0 comments on commit f4fe978

Please sign in to comment.