Skip to content

Commit

Permalink
Merge pull request #89 from weaviate/add-near-media-filters
Browse files Browse the repository at this point in the history
Add nearMedia filters for multi2vec-bind model
  • Loading branch information
tsmith023 authored Aug 22, 2023
2 parents 78cc66b + 8a4997b commit 85962c1
Show file tree
Hide file tree
Showing 8 changed files with 423 additions and 23 deletions.
61 changes: 54 additions & 7 deletions src/graphql/aggregator.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
import Where from './where';
import NearMedia, {
NearMediaArgs,
NearVideoArgs,
NearAudioArgs,
NearDepthArgs,
NearIMUArgs,
NearMediaBase,
NearMediaType,
} from './nearMedia';
import NearText, { NearTextArgs } from './nearText';
import NearVector, { NearVectorArgs } from './nearVector';
import NearObject, { NearObjectArgs } from './nearObject';
Expand All @@ -7,12 +16,18 @@ import Connection from '../connection';
import { CommandBase } from '../validation/commandBase';
import { WhereFilter } from '../openapi/types';

interface NearImageArgs extends NearMediaBase {
image: string;
}

export default class Aggregator extends CommandBase {
private className?: string;
private fields?: string;
private groupBy?: string[];
private includesNearMediaFilter: boolean;
private limit?: number;
private nearMediaString?: string;
private nearMediaType?: string;
private nearObjectString?: string;
private nearTextString?: string;
private nearVectorString?: string;
Expand Down Expand Up @@ -44,56 +59,84 @@ export default class Aggregator extends CommandBase {
return this;
};

withNearText = (args: NearTextArgs) => {
private withNearMedia = (args: NearMediaArgs) => {
if (this.includesNearMediaFilter) {
throw new Error('cannot use multiple near<Media> filters in a single query');
}
try {
this.nearMediaString = new NearMedia(args).toString();
this.nearMediaType = args.type;
this.includesNearMediaFilter = true;
} catch (e: any) {
this.addError(e.toString());
}

return this;
};

withNearImage = (args: NearImageArgs) => {
return this.withNearMedia({ ...args, media: args.image, type: NearMediaType.Image });
};

withNearAudio = (args: NearAudioArgs) => {
return this.withNearMedia({ ...args, media: args.audio, type: NearMediaType.Audio });
};

withNearVideo = (args: NearVideoArgs) => {
return this.withNearMedia({ ...args, media: args.video, type: NearMediaType.Video });
};

withNearDepth = (args: NearDepthArgs) => {
return this.withNearMedia({ ...args, media: args.depth, type: NearMediaType.Depth });
};

withNearIMU = (args: NearIMUArgs) => {
return this.withNearMedia({ ...args, media: args.imu, type: NearMediaType.IMU });
};

withNearText = (args: NearTextArgs) => {
if (this.includesNearMediaFilter) {
throw new Error('cannot use multiple near<Media> filters in a single query');
}
try {
this.nearTextString = new NearText(args).toString();
this.includesNearMediaFilter = true;
} catch (e: any) {
this.addError(e.toString());
}

return this;
};

withNearObject = (args: NearObjectArgs) => {
if (this.includesNearMediaFilter) {
throw new Error('cannot use multiple near<Media> filters in a single query');
}

try {
this.nearObjectString = new NearObject(args).toString();
this.includesNearMediaFilter = true;
} catch (e: any) {
this.addError(e.toString());
}

return this;
};

withNearVector = (args: NearVectorArgs) => {
if (this.includesNearMediaFilter) {
throw new Error('cannot use multiple near<Media> filters in a single query');
}

try {
this.nearVectorString = new NearVector(args).toString();
this.includesNearMediaFilter = true;
} catch (e: any) {
this.addError(e.toString());
}

return this;
};

withObjectLimit = (objectLimit: number) => {
if (!isValidPositiveIntProperty(objectLimit)) {
throw new Error('objectLimit must be a non-negative integer');
}

this.objectLimit = objectLimit;
return this;
};
Expand Down Expand Up @@ -171,6 +214,10 @@ export default class Aggregator extends CommandBase {
args = [...args, `nearVector:${this.nearVectorString}`];
}

if (this.nearMediaString) {
args = [...args, `${this.nearMediaType}:${this.nearMediaString}`];
}

if (this.groupBy) {
args = [...args, `groupBy:${JSON.stringify(this.groupBy)}`];
}
Expand Down
73 changes: 68 additions & 5 deletions src/graphql/explorer.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import NearText, { NearTextArgs } from './nearText';
import NearVector, { NearVectorArgs } from './nearVector';
import NearObject, { NearObjectArgs } from './nearObject';
import NearImage, { NearImageArgs } from './nearImage';
import NearObject, { NearObjectArgs } from './nearObject';
import NearMedia, {
NearAudioArgs,
NearDepthArgs,
NearIMUArgs,
NearMediaArgs,
NearMediaType,
NearThermalArgs,
NearVideoArgs,
} from './nearMedia';
import Ask, { AskArgs } from './ask';
import Connection from '../connection';
import { CommandBase } from '../validation/commandBase';
Expand All @@ -11,7 +20,9 @@ export default class Explorer extends CommandBase {
private fields?: string;
private group?: string[];
private limit?: number;
private nearImageString?: string;
private includesNearMediaFilter: boolean;
private nearMediaString?: string;
private nearMediaType?: NearMediaType;
private nearObjectString?: string;
private nearTextString?: string;
private nearVectorString?: string;
Expand All @@ -20,6 +31,7 @@ export default class Explorer extends CommandBase {
constructor(client: Connection) {
super(client);
this.params = {};
this.includesNearMediaFilter = false;
}

withFields = (fields: string) => {
Expand All @@ -33,6 +45,9 @@ export default class Explorer extends CommandBase {
};

withNearText = (args: NearTextArgs) => {
if (this.includesNearMediaFilter) {
throw new Error('cannot use multiple near<Media> filters in a single query');
}
try {
this.nearTextString = new NearText(args).toString();
} catch (e: any) {
Expand All @@ -42,6 +57,9 @@ export default class Explorer extends CommandBase {
};

withNearObject = (args: NearObjectArgs) => {
if (this.includesNearMediaFilter) {
throw new Error('cannot use multiple near<Media> filters in a single query');
}
try {
this.nearObjectString = new NearObject(args).toString();
} catch (e: any) {
Expand All @@ -51,6 +69,9 @@ export default class Explorer extends CommandBase {
};

withAsk = (args: AskArgs) => {
if (this.includesNearMediaFilter) {
throw new Error('cannot use multiple near<Media> filters in a single query');
}
try {
this.askString = new Ask(args).toString();
} catch (e: any) {
Expand All @@ -59,16 +80,58 @@ export default class Explorer extends CommandBase {
return this;
};

private withNearMedia = (args: NearMediaArgs) => {
if (this.includesNearMediaFilter) {
throw new Error('cannot use multiple near<Media> filters in a single query');
}
try {
this.nearMediaString = new NearMedia(args).toString();
this.nearMediaType = args.type;
this.includesNearMediaFilter = true;
} catch (e: any) {
this.addError(e.toString());
}
return this;
};

withNearImage = (args: NearImageArgs) => {
if (this.includesNearMediaFilter) {
throw new Error('cannot use multiple near<Media> filters in a single query');
}
try {
this.nearImageString = new NearImage(args).toString();
this.nearMediaString = new NearImage(args).toString();
this.nearMediaType = NearMediaType.Image;
this.includesNearMediaFilter = true;
} catch (e: any) {
this.addError(e.toString());
}
return this;
};

withNearAudio = (args: NearAudioArgs) => {
return this.withNearMedia({ ...args, media: args.audio, type: NearMediaType.Audio });
};

withNearVideo = (args: NearVideoArgs) => {
return this.withNearMedia({ ...args, media: args.video, type: NearMediaType.Video });
};

withNearDepth = (args: NearDepthArgs) => {
return this.withNearMedia({ ...args, media: args.depth, type: NearMediaType.Depth });
};

withNearThermal = (args: NearThermalArgs) => {
return this.withNearMedia({ ...args, media: args.thermal, type: NearMediaType.Thermal });
};

withNearIMU = (args: NearIMUArgs) => {
return this.withNearMedia({ ...args, media: args.imu, type: NearMediaType.IMU });
};

withNearVector = (args: NearVectorArgs) => {
if (this.includesNearMediaFilter) {
throw new Error('cannot use multiple near<Media> filters in a single query');
}
try {
this.nearVectorString = new NearVector(args).toString();
} catch (e: any) {
Expand Down Expand Up @@ -120,8 +183,8 @@ export default class Explorer extends CommandBase {
args = [...args, `ask:${this.askString}`];
}

if (this.nearImageString) {
args = [...args, `nearImage:${this.nearImageString}`];
if (this.nearMediaString) {
args = [...args, `${this.nearMediaType}:${this.nearMediaString}`];
}

if (this.nearVectorString) {
Expand Down
109 changes: 108 additions & 1 deletion src/graphql/getter.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ import Getter, { FusionType } from './getter';
import { WhereFilter } from '../openapi/types';
import { NearObjectArgs } from './nearObject';
import { AskArgs } from './ask';
import { NearImageArgs } from './nearImage';
import { SortArgs } from './sort';
import { NearTextArgs } from './nearText';
import { NearImageArgs } from './nearImage';

test('a simple query without params', () => {
const mockClient: any = {
Expand Down Expand Up @@ -1091,6 +1091,113 @@ describe('nearImage searchers', () => {
});
});

describe('nearMedia searchers', () => {
test('a query with a valid nearVideo', () => {
const mockClient: any = {
query: jest.fn(),
};

const subQuery = `(nearVideo:{video:"iVBORw0KGgoAAAANS"})`;
const expectedQuery = `{Get{Person` + subQuery + `{name}}}`;

new Getter(mockClient)
.withClassName('Person')
.withFields('name')
.withNearVideo({ video: 'iVBORw0KGgoAAAANS' })
.do();

expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
});

test('a query with a valid nearVideo with all params', () => {
const mockClient: any = {
query: jest.fn(),
};

const expectedQuery = `{Get{Person(nearVideo:{video:"iVBORw0KGgoAAAANS",certainty:0.8,distance:0.6}){name}}}`;

new Getter(mockClient)
.withClassName('Person')
.withFields('name')
.withNearVideo({
video: 'iVBORw0KGgoAAAANS',
certainty: 0.8,
distance: 0.6,
})
.do();

expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
});

test('a query with a valid nearAudio', () => {
const mockClient: any = {
query: jest.fn(),
};

const subQuery = `(nearAudio:{audio:"iVBORw0KGgoAAAANS"})`;
const expectedQuery = `{Get{Person` + subQuery + `{name}}}`;

new Getter(mockClient)
.withClassName('Person')
.withFields('name')
.withNearAudio({ audio: 'iVBORw0KGgoAAAANS' })
.do();

expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
});

test('a query with a valid nearThermal', () => {
const mockClient: any = {
query: jest.fn(),
};

const subQuery = `(nearThermal:{thermal:"iVBORw0KGgoAAAANS"})`;
const expectedQuery = `{Get{Person` + subQuery + `{name}}}`;

new Getter(mockClient)
.withClassName('Person')
.withFields('name')
.withNearThermal({ thermal: 'iVBORw0KGgoAAAANS' })
.do();

expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
});

test('a query with a valid nearDepth', () => {
const mockClient: any = {
query: jest.fn(),
};

const subQuery = `(nearDepth:{depth:"iVBORw0KGgoAAAANS"})`;
const expectedQuery = `{Get{Person` + subQuery + `{name}}}`;

new Getter(mockClient)
.withClassName('Person')
.withFields('name')
.withNearDepth({ depth: 'iVBORw0KGgoAAAANS' })
.do();

expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
});

test('a query with a valid nearIMU', () => {
const mockClient: any = {
query: jest.fn(),
};

const subQuery = `(nearIMU:{imu:"iVBORw0KGgoAAAANS"})`;
const expectedQuery = `{Get{Person` + subQuery + `{name}}}`;

new Getter(mockClient)
.withClassName('Person')
.withFields('name')
.withNearIMU({ imu: 'iVBORw0KGgoAAAANS' })
.do();

expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
});
});

describe('sort filters', () => {
test('a query with a valid sort filter', () => {
const mockClient: any = {
Expand Down
Loading

0 comments on commit 85962c1

Please sign in to comment.