Skip to content

Commit

Permalink
Merge pull request #180 from weaviate/hybrid/add-support-for-query-pr…
Browse files Browse the repository at this point in the history
…operty-weights

Add `Bm25QueryProperty` type for use in weighted `.bm25` and `.hybrid`
  • Loading branch information
tsmith023 authored Jul 30, 2024
2 parents 15a1206 + 8fb5445 commit b2dc636
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 5 deletions.
32 changes: 32 additions & 0 deletions src/collections/query/integration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,38 @@ describe('Testing of the collection.query methods with a simple collection', ()
expect(ret.objects[0].uuid).toEqual(id);
});

it('should query with bm25 and weighted query properties', async () => {
const ret = await collection.query.bm25('test', {
queryProperties: [
{
name: 'testProp',
weight: 2,
},
'testProp2',
],
});
expect(ret.objects.length).toEqual(1);
expect(ret.objects[0].properties.testProp).toEqual('test');
expect(ret.objects[0].properties.testProp2).toEqual('test2');
expect(ret.objects[0].uuid).toEqual(id);
});

it('should query with bm25 and weighted query properties with a non-generic collection', async () => {
const ret = await client.collections.get(collectionName).query.bm25('test', {
queryProperties: [
{
name: 'testProp',
weight: 2,
},
'testProp2',
],
});
expect(ret.objects.length).toEqual(1);
expect(ret.objects[0].properties.testProp).toEqual('test');
expect(ret.objects[0].properties.testProp2).toEqual('test2');
expect(ret.objects[0].uuid).toEqual(id);
});

it('should query with hybrid', async () => {
const ret = await collection.query.hybrid('test', { limit: 1 });
expect(ret.objects.length).toEqual(1);
Expand Down
12 changes: 10 additions & 2 deletions src/collections/query/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,18 @@ export type SearchOptions<T> = {
returnReferences?: QueryReference<T>[];
};

/** Which property of the collection to perform the keyword search on. */
export type Bm25QueryProperty<T> = {
/** The property name to search on. */
name: PrimitiveKeys<T>;
/** The weight to provide to the keyword search for this property. */
weight: number;
};

/** Base options available in the `query.bm25` method */
export type BaseBm25Options<T> = SearchOptions<T> & {
/** Which properties of the collection to perform the keyword search on. */
queryProperties?: PrimitiveKeys<T>[];
queryProperties?: (PrimitiveKeys<T> | Bm25QueryProperty<T>)[];
};

/** Options available in the `query.bm25` method when specifying the `groupBy` parameter. */
Expand All @@ -98,7 +106,7 @@ export type BaseHybridOptions<T> = SearchOptions<T> & {
/** The specific vector to search for or a specific vector subsearch. If not specified, the query is vectorized and used in the similarity search. */
vector?: NearVectorInputType | HybridNearTextSubSearch | HybridNearVectorSubSearch;
/** The properties to search in. If not specified, all properties are searched. */
queryProperties?: PrimitiveKeys<T>[];
queryProperties?: (PrimitiveKeys<T> | Bm25QueryProperty<T>)[];
/** The type of fusion to apply. If not specified, the default fusion type specified by the server is used. */
fusionType?: 'Ranked' | 'RelativeScore';
/** Specify which vector(s) to search on if using named vectors. */
Expand Down
19 changes: 16 additions & 3 deletions src/collections/serialize/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,12 @@ import {
PrimitiveFilterValueType,
PrimitiveListFilterValueType,
} from '../filters/types.js';
import { MultiTargetVectorJoin } from '../index.js';
import { MultiTargetVectorJoin, PrimitiveKeys } from '../index.js';
import {
BaseHybridOptions,
BaseNearOptions,
Bm25Options,
Bm25QueryProperty,
FetchObjectByIdOptions,
FetchObjectsOptions,
HybridNearTextSubSearch,
Expand Down Expand Up @@ -358,12 +359,24 @@ export class Serialize {
};
};

private static bm25QueryProperties = <T>(
properties?: (PrimitiveKeys<T> | Bm25QueryProperty<T>)[]
): string[] | undefined => {
return properties?.map((property) => {
if (typeof property === 'string') {
return property;
} else {
return `${property.name}^${property.weight}`;
}
});
};

public static bm25 = <T>(args: { query: string } & Bm25Options<T>): SearchBm25Args => {
return {
...Serialize.common(args),
bm25Search: BM25.fromPartial({
query: args.query,
properties: args.queryProperties,
properties: this.bm25QueryProperties(args.queryProperties),
}),
autocut: args.autoLimit,
};
Expand Down Expand Up @@ -448,7 +461,7 @@ export class Serialize {
hybridSearch: Hybrid.fromPartial({
query: args.query,
alpha: args.alpha ? args.alpha : 0.5,
properties: args.queryProperties,
properties: this.bm25QueryProperties(args.queryProperties),
vectorBytes: vectorBytes,
fusionType: fusionType(args.fusionType),
targetVectors,
Expand Down

0 comments on commit b2dc636

Please sign in to comment.