Skip to content

Commit

Permalink
Merge branch 'master' into andrew/ruff7
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong authored Nov 5, 2024
2 parents ce9d6de + cb14f6b commit d7d3818
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 13 deletions.
23 changes: 20 additions & 3 deletions weave-js/src/components/Panel2/PanelTable/tableState.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import {
allObjPaths,
canSortType,
constFunction,
ConstNode,
constNodeUnsafe,
Expand Down Expand Up @@ -687,6 +688,16 @@ export async function disableGroupByCol(
) {
const colIds = _.isArray(colId) ? colId : [colId];
const groupBy = ts.groupBy;

// (WB-16067)
// We may try to sort on aggregated columns after ungrouping
// To prevent this, disable sorting on all the columns and re-enable
// after the ungroup
const initiallySortedCols = _.clone(ts.sort);
ts.sort.forEach(sortObj => {
ts = disableSortByCol(ts, sortObj.columnId);
});

ts = produce(ts, draft => {
draft.autoColumns = false;
for (const cid of colIds) {
Expand All @@ -701,9 +712,15 @@ export async function disableGroupByCol(
}
});
ts = await refreshSelectFunctions(ts, inputArrayNode, weave, stack);
if (ts.sort.find(s => s.columnId === colId) !== undefined) {
ts = disableSortByCol(ts, colId);
}

initiallySortedCols.forEach(sortObj => {
if (
sortObj.columnId !== colId &&
canSortType(ts.columnSelectFunctions[sortObj.columnId].type)
) {
ts = enableSortByCol(ts, sortObj.columnId, sortObj.dir === 'asc');
}
});
return ts;
}

Expand Down
10 changes: 5 additions & 5 deletions weave/integrations/cohere/cohere_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _accumulate_content(

def cohere_wrapper(name: str) -> Callable:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op = weave.op(fn)
op.name = name # type: ignore
return op

Expand Down Expand Up @@ -122,7 +122,7 @@ def _wrapper(*args: Any, **kwargs: Any) -> Any:

return _wrapper

op = weave.op()(_post_process_response(fn))
op = weave.op(_post_process_response(fn))
op.name = name # type: ignore
return op

Expand Down Expand Up @@ -156,7 +156,7 @@ async def _wrapper(*args: Any, **kwargs: Any) -> Any:

return _wrapper

op = weave.op()(_post_process_response(fn))
op = weave.op(_post_process_response(fn))
op.name = name # type: ignore
return op

Expand All @@ -165,7 +165,7 @@ async def _wrapper(*args: Any, **kwargs: Any) -> Any:

def cohere_stream_wrapper(name: str) -> Callable:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op = weave.op(fn)
op.name = name # type: ignore
return add_accumulator(op, lambda inputs: cohere_accumulator) # type: ignore

Expand All @@ -174,7 +174,7 @@ def wrapper(fn: Callable) -> Callable:

def cohere_stream_wrapper_v2(name: str) -> Callable:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op = weave.op(fn)
op.name = name # type: ignore
return add_accumulator(
op, make_accumulator=lambda inputs: cohere_accumulator_v2
Expand Down
4 changes: 2 additions & 2 deletions weave/integrations/instructor/instructor_iterable_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def should_accumulate_iterable(inputs: dict) -> bool:

def instructor_wrapper_sync(name: str) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op = weave.op(fn)
op.name = name # type: ignore
return add_accumulator(
op, # type: ignore
Expand All @@ -50,7 +50,7 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
return _async_wrapper

"We need to do this so we can check if `stream` is used"
op = weave.op()(_fn_wrapper(fn))
op = weave.op(_fn_wrapper(fn))
op.name = name # type: ignore
return add_accumulator(
op, # type: ignore
Expand Down
6 changes: 3 additions & 3 deletions weave/integrations/notdiamond/custom_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _get_model_results(provider_name: str) -> pd.DataFrame:
class _DummyEvalModel(weave.Model):
model_results: pd.DataFrame

@weave.op()
@weave.op
def predict(self, prompt: str) -> dict[str, Any]:
response, score = self.model_results[
self.model_results[prompt_column] == prompt
Expand All @@ -93,12 +93,12 @@ def predict(self, prompt: str) -> dict[str, Any]:
class BestRoutedModel(_DummyEvalModel):
model_name: str

@weave.op()
@weave.op
def predict(self, prompt: str) -> dict[str, Any]:
return super().predict(prompt)

class NotDiamondRoutedModel(_DummyEvalModel):
@weave.op()
@weave.op
def predict(self, prompt: str) -> dict[str, Any]:
return super().predict(prompt)

Expand Down

0 comments on commit d7d3818

Please sign in to comment.