From e311349d44c00f97860a64f750fa7bdbd8b950b7 Mon Sep 17 00:00:00 2001 From: almogbaku Date: Sat, 20 Apr 2024 22:27:20 +0300 Subject: [PATCH] fix: fix labsdk bugs --- api/common.go | 2 +- labsdk/_test/main.py | 12 ++++++------ labsdk/raptor/decorators.py | 14 +++++++------- labsdk/raptor/program.py | 2 +- labsdk/raptor/replay.py | 5 +---- labsdk/raptor/types/feature.py | 6 +++--- labsdk/raptor/types/model_impl.py | 8 ++++---- 7 files changed, 23 insertions(+), 26 deletions(-) diff --git a/api/common.go b/api/common.go index 5a72cd32..b49532e2 100644 --- a/api/common.go +++ b/api/common.go @@ -22,7 +22,7 @@ import ( "strconv" ) -var FQNRegExp = regexp.MustCompile(`(?si)^((?P([a0-z9]+[a0-z9_]*[a0-z9]+){1,256})\.)?(?P([a0-z9]+[a0-z9_]*[a0-z9]+){1,256})(\+(?P([a-z]+_*[a-z]+)))?(@-(?P([0-9]+)))?(\[(?P([a-z]+_*[a-z]+))])?$`) +var FQNRegExp = regexp.MustCompile(`(?si)^((?P[a-z0-9]+(?:_[a-z0-9]+)*)\.)?(?P[a-z0-9]+(?:_[a-z0-9]+)*)(\+(?P([a-z]+_*[a-z]+)))?(@-(?P([0-9]+)))?(\[(?P([a-z]+_*[a-z]+))])?$`) func ParseSelector(fqn string) (namespace, name string, aggrFn AggrFn, version uint, encoding string, err error) { if !FQNRegExp.MatchString(fqn) { diff --git a/labsdk/_test/main.py b/labsdk/_test/main.py index eadaafc7..7856fc3e 100644 --- a/labsdk/_test/main.py +++ b/labsdk/_test/main.py @@ -67,7 +67,7 @@ def question_marks_10h(this_row: Email, ctx: Context) -> int: ) class Deal(TypedDict): id: int - event_at: pd.Timestamp + event_at: datetime account_id: str amount: float @@ -164,8 +164,8 @@ def deal_prediction(ctx: TrainingContext) -> float: accuracy = xgb_model.score(X_test, y_test) # Make sure the model has a minimum accuracy of 0.7 - if accuracy < 0.7: - raise Exception('Accuracy is below 0.7') + # if accuracy < 0.7: + # raise Exception('Accuracy is below 0.7') return xgb_model @@ -178,7 +178,7 @@ def deal_prediction(ctx: TrainingContext) -> float: # counters @feature(keys='account_id', data_source=Deal) -@aggregation(function=AggregationFunction.Count, over='9999984h', granularity='9999984h') +@aggregation(function=AggregationFunction.Count, over='999999h', granularity='999999h') def views(this_row: Deal, ctx: Context) -> int: return 1 @@ -235,7 +235,7 @@ def unique_deals_involvement_annually(this_row: CrmRecord, ctx: Context) -> int: @feature(keys='salesman_id', data_source=CrmRecord) -@aggregation(function=AggregationFunction.DistinctCount, over='8760h', granularity='24h') +@aggregation(function=AggregationFunction.Count, over='8760h', granularity='24h') def closed_deals_annually(this_row: CrmRecord, ctx: Context) -> int: if this_row['action'] == 'deal_closed': return 1 @@ -344,7 +344,7 @@ def commits_30m_greater_2(this_row: Commit, ctx: Context) -> bool: model_framework='sklearn', ) @freshness(max_age='1h', max_stale='100h') -def newest(): +def newest(ctx: TrainingContext) -> float: # TODO: implement pass diff --git a/labsdk/raptor/decorators.py b/labsdk/raptor/decorators.py index 5a189274..9ac8a951 100644 --- a/labsdk/raptor/decorators.py +++ b/labsdk/raptor/decorators.py @@ -26,8 +26,8 @@ from warnings import warn from pandas import DataFrame -from pydantic import create_model_from_typeddict -from typing_extensions import TypedDict +from pydantic import TypeAdapter +from typing_extensions import TypedDict, _TypedDictMeta from . import local_state, config, replay from ._internal import durpy @@ -39,9 +39,9 @@ from .types.dsrc_config_stubs.rest import RestConfig if sys.version_info >= (3, 8): - from typing import TypedDict as typing_TypedDict + from typing import _TypedDictMeta as typing_TypedDictMeta else: - typing_TypedDict = type(None) + typing_TypedDictMeta = type(None) def _wrap_decorator_err(f): @@ -245,9 +245,9 @@ class Deal(typing_extensions.TypedDict): @_wrap_decorator_err def decorator(cls: TypedDict): - if type(cls) == type(typing_TypedDict): + if isinstance(cls, typing_TypedDictMeta): raise Exception('You should use typing_extensions.TypedDict instead of typing.TypedDict') - elif type(cls) != type(TypedDict): + elif not isinstance(cls, _TypedDictMeta): raise Exception('data_source decorator must be used on a class that extends typing_extensions.TypedDict') nonlocal name @@ -269,7 +269,7 @@ def decorator(cls: TypedDict): spec.namespace = options['namespace'] # convert cls to json schema - spec.schema = create_model_from_typeddict(cls).schema() + spec.schema = TypeAdapter(cls).json_schema() # register cls.raptor_spec = spec diff --git a/labsdk/raptor/program.py b/labsdk/raptor/program.py index 754a30fd..b101c6e4 100644 --- a/labsdk/raptor/program.py +++ b/labsdk/raptor/program.py @@ -43,7 +43,7 @@ from redbaron import RedBaron, DefNode selector_regex = re.compile( - r'^((?P([a0-z9]+[a0-z9_]*[a0-z9]+){1,256})\.)?(?P([a0-z9]+[a0-z9_]*[a0-z9]+){1,256})(\+(?P([a-z]+_*[a-z]+)))?(@-(?P([0-9]+)))?(\[(?P([a-z]+_*[a-z]+))])?$', + r'^((?P[a-z0-9]+(?:_[a-z0-9]+)*)\.)?(?P[a-z0-9]+(?:_[a-z0-9]+)*)(\+(?P([a-z]+_*[a-z]+)))?(@-(?P([0-9]+)))?(\[(?P([a-z]+_*[a-z]+))])?$', re.IGNORECASE) primitive = Union[str, int, float, bool, datetime, List[str], List[int], List[float], List[bool], List[datetime], None] diff --git a/labsdk/raptor/replay.py b/labsdk/raptor/replay.py index f71178d4..d9d8ad5f 100644 --- a/labsdk/raptor/replay.py +++ b/labsdk/raptor/replay.py @@ -141,13 +141,10 @@ def _replay(store_locally=True): for aggr in spec.aggr.funcs: f = f'{spec.fqn()}+{aggr.value}' - result = aggr.apply(fvg).reset_index(0).rename(columns={'value': f}) + result = aggr.apply(fvg).reset_index(0).rename(columns={val_field: f}) feature_values = feature_values.merge(result, on=['timestamp', 'keys'], how='left') fields.append(f) - if 'f_value' in feature_values.columns: - feature_values = feature_values.drop('f_value', axis=1) - feature_values = feature_values.reset_index().drop(columns=['value']). \ melt(id_vars=['timestamp', 'keys'], value_vars=fields, var_name='fqn', value_name='value') if store_locally: diff --git a/labsdk/raptor/types/feature.py b/labsdk/raptor/types/feature.py index 03c9724e..c6e46570 100644 --- a/labsdk/raptor/types/feature.py +++ b/labsdk/raptor/types/feature.py @@ -20,7 +20,7 @@ import pandas as pd import yaml from pandas.core.window import RollingGroupby -from typing_extensions import TypedDict +from typing_extensions import _TypedDictMeta from .common import RaptorSpec, ResourceReference, _k8s_name, EnumSpec, RuntimeSpec from .dsrc import DataSourceSpec @@ -183,8 +183,8 @@ def __setattr__(self, key, value): elif isinstance(value, str): value = ResourceReference(value) self.data_source_spec = local_state.spec_by_selector(value.fqn()) - elif type(value) == type(TypedDict) or isinstance(value, DataSourceSpec): - if type(value) == type(TypedDict): + elif isinstance(value, _TypedDictMeta) or isinstance(value, DataSourceSpec): + if isinstance(value, _TypedDictMeta): if hasattr(value, 'raptor_spec'): value = value.raptor_spec else: diff --git a/labsdk/raptor/types/model_impl.py b/labsdk/raptor/types/model_impl.py index 29b21e2f..c67d3e72 100644 --- a/labsdk/raptor/types/model_impl.py +++ b/labsdk/raptor/types/model_impl.py @@ -14,8 +14,8 @@ # limitations under the License. import inspect -from datetime import timedelta -from typing import Optional, Callable +from datetime import timedelta, datetime +from typing import Optional, Callable, Union from . import SecretKeyRef from .common import _k8s_name @@ -40,8 +40,8 @@ def __init__(self, keys=None, model_framework=None, model_server=None, *args, ** self.exporter = ModelExporter(self) self._features_and_labels = replay.new_historical_get(self) - def features_and_labels(self): - return self._features_and_labels() + def features_and_labels(self, since: Optional[Union[datetime, str]] = None, until: Optional[Union[datetime, str]] = None): + return self._features_and_labels(since, until) def train(self): for f in (self.features + self.label_features + ([self.key_feature] if self.key_feature is not None else [])):