Skip to content

Commit 85962c1

Browse files
authored
Merge pull request #89 from weaviate/add-near-media-filters
Add nearMedia filters for multi2vec-bind model
2 parents 78cc66b + 8a4997b commit 85962c1

File tree

8 files changed

+423
-23
lines changed

8 files changed

+423
-23
lines changed

src/graphql/aggregator.ts

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
11
import Where from './where';
2+
import NearMedia, {
3+
NearMediaArgs,
4+
NearVideoArgs,
5+
NearAudioArgs,
6+
NearDepthArgs,
7+
NearIMUArgs,
8+
NearMediaBase,
9+
NearMediaType,
10+
} from './nearMedia';
211
import NearText, { NearTextArgs } from './nearText';
312
import NearVector, { NearVectorArgs } from './nearVector';
413
import NearObject, { NearObjectArgs } from './nearObject';
@@ -7,12 +16,18 @@ import Connection from '../connection';
716
import { CommandBase } from '../validation/commandBase';
817
import { WhereFilter } from '../openapi/types';
918

19+
interface NearImageArgs extends NearMediaBase {
20+
image: string;
21+
}
22+
1023
export default class Aggregator extends CommandBase {
1124
private className?: string;
1225
private fields?: string;
1326
private groupBy?: string[];
1427
private includesNearMediaFilter: boolean;
1528
private limit?: number;
29+
private nearMediaString?: string;
30+
private nearMediaType?: string;
1631
private nearObjectString?: string;
1732
private nearTextString?: string;
1833
private nearVectorString?: string;
@@ -44,56 +59,84 @@ export default class Aggregator extends CommandBase {
4459
return this;
4560
};
4661

47-
withNearText = (args: NearTextArgs) => {
62+
private withNearMedia = (args: NearMediaArgs) => {
4863
if (this.includesNearMediaFilter) {
4964
throw new Error('cannot use multiple near<Media> filters in a single query');
5065
}
66+
try {
67+
this.nearMediaString = new NearMedia(args).toString();
68+
this.nearMediaType = args.type;
69+
this.includesNearMediaFilter = true;
70+
} catch (e: any) {
71+
this.addError(e.toString());
72+
}
73+
74+
return this;
75+
};
76+
77+
withNearImage = (args: NearImageArgs) => {
78+
return this.withNearMedia({ ...args, media: args.image, type: NearMediaType.Image });
79+
};
80+
81+
withNearAudio = (args: NearAudioArgs) => {
82+
return this.withNearMedia({ ...args, media: args.audio, type: NearMediaType.Audio });
83+
};
84+
85+
withNearVideo = (args: NearVideoArgs) => {
86+
return this.withNearMedia({ ...args, media: args.video, type: NearMediaType.Video });
87+
};
5188

89+
withNearDepth = (args: NearDepthArgs) => {
90+
return this.withNearMedia({ ...args, media: args.depth, type: NearMediaType.Depth });
91+
};
92+
93+
withNearIMU = (args: NearIMUArgs) => {
94+
return this.withNearMedia({ ...args, media: args.imu, type: NearMediaType.IMU });
95+
};
96+
97+
withNearText = (args: NearTextArgs) => {
98+
if (this.includesNearMediaFilter) {
99+
throw new Error('cannot use multiple near<Media> filters in a single query');
100+
}
52101
try {
53102
this.nearTextString = new NearText(args).toString();
54103
this.includesNearMediaFilter = true;
55104
} catch (e: any) {
56105
this.addError(e.toString());
57106
}
58-
59107
return this;
60108
};
61109

62110
withNearObject = (args: NearObjectArgs) => {
63111
if (this.includesNearMediaFilter) {
64112
throw new Error('cannot use multiple near<Media> filters in a single query');
65113
}
66-
67114
try {
68115
this.nearObjectString = new NearObject(args).toString();
69116
this.includesNearMediaFilter = true;
70117
} catch (e: any) {
71118
this.addError(e.toString());
72119
}
73-
74120
return this;
75121
};
76122

77123
withNearVector = (args: NearVectorArgs) => {
78124
if (this.includesNearMediaFilter) {
79125
throw new Error('cannot use multiple near<Media> filters in a single query');
80126
}
81-
82127
try {
83128
this.nearVectorString = new NearVector(args).toString();
84129
this.includesNearMediaFilter = true;
85130
} catch (e: any) {
86131
this.addError(e.toString());
87132
}
88-
89133
return this;
90134
};
91135

92136
withObjectLimit = (objectLimit: number) => {
93137
if (!isValidPositiveIntProperty(objectLimit)) {
94138
throw new Error('objectLimit must be a non-negative integer');
95139
}
96-
97140
this.objectLimit = objectLimit;
98141
return this;
99142
};
@@ -171,6 +214,10 @@ export default class Aggregator extends CommandBase {
171214
args = [...args, `nearVector:${this.nearVectorString}`];
172215
}
173216

217+
if (this.nearMediaString) {
218+
args = [...args, `${this.nearMediaType}:${this.nearMediaString}`];
219+
}
220+
174221
if (this.groupBy) {
175222
args = [...args, `groupBy:${JSON.stringify(this.groupBy)}`];
176223
}

src/graphql/explorer.ts

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
import NearText, { NearTextArgs } from './nearText';
22
import NearVector, { NearVectorArgs } from './nearVector';
3-
import NearObject, { NearObjectArgs } from './nearObject';
43
import NearImage, { NearImageArgs } from './nearImage';
4+
import NearObject, { NearObjectArgs } from './nearObject';
5+
import NearMedia, {
6+
NearAudioArgs,
7+
NearDepthArgs,
8+
NearIMUArgs,
9+
NearMediaArgs,
10+
NearMediaType,
11+
NearThermalArgs,
12+
NearVideoArgs,
13+
} from './nearMedia';
514
import Ask, { AskArgs } from './ask';
615
import Connection from '../connection';
716
import { CommandBase } from '../validation/commandBase';
@@ -11,7 +20,9 @@ export default class Explorer extends CommandBase {
1120
private fields?: string;
1221
private group?: string[];
1322
private limit?: number;
14-
private nearImageString?: string;
23+
private includesNearMediaFilter: boolean;
24+
private nearMediaString?: string;
25+
private nearMediaType?: NearMediaType;
1526
private nearObjectString?: string;
1627
private nearTextString?: string;
1728
private nearVectorString?: string;
@@ -20,6 +31,7 @@ export default class Explorer extends CommandBase {
2031
constructor(client: Connection) {
2132
super(client);
2233
this.params = {};
34+
this.includesNearMediaFilter = false;
2335
}
2436

2537
withFields = (fields: string) => {
@@ -33,6 +45,9 @@ export default class Explorer extends CommandBase {
3345
};
3446

3547
withNearText = (args: NearTextArgs) => {
48+
if (this.includesNearMediaFilter) {
49+
throw new Error('cannot use multiple near<Media> filters in a single query');
50+
}
3651
try {
3752
this.nearTextString = new NearText(args).toString();
3853
} catch (e: any) {
@@ -42,6 +57,9 @@ export default class Explorer extends CommandBase {
4257
};
4358

4459
withNearObject = (args: NearObjectArgs) => {
60+
if (this.includesNearMediaFilter) {
61+
throw new Error('cannot use multiple near<Media> filters in a single query');
62+
}
4563
try {
4664
this.nearObjectString = new NearObject(args).toString();
4765
} catch (e: any) {
@@ -51,6 +69,9 @@ export default class Explorer extends CommandBase {
5169
};
5270

5371
withAsk = (args: AskArgs) => {
72+
if (this.includesNearMediaFilter) {
73+
throw new Error('cannot use multiple near<Media> filters in a single query');
74+
}
5475
try {
5576
this.askString = new Ask(args).toString();
5677
} catch (e: any) {
@@ -59,16 +80,58 @@ export default class Explorer extends CommandBase {
5980
return this;
6081
};
6182

83+
private withNearMedia = (args: NearMediaArgs) => {
84+
if (this.includesNearMediaFilter) {
85+
throw new Error('cannot use multiple near<Media> filters in a single query');
86+
}
87+
try {
88+
this.nearMediaString = new NearMedia(args).toString();
89+
this.nearMediaType = args.type;
90+
this.includesNearMediaFilter = true;
91+
} catch (e: any) {
92+
this.addError(e.toString());
93+
}
94+
return this;
95+
};
96+
6297
withNearImage = (args: NearImageArgs) => {
98+
if (this.includesNearMediaFilter) {
99+
throw new Error('cannot use multiple near<Media> filters in a single query');
100+
}
63101
try {
64-
this.nearImageString = new NearImage(args).toString();
102+
this.nearMediaString = new NearImage(args).toString();
103+
this.nearMediaType = NearMediaType.Image;
104+
this.includesNearMediaFilter = true;
65105
} catch (e: any) {
66106
this.addError(e.toString());
67107
}
68108
return this;
69109
};
70110

111+
withNearAudio = (args: NearAudioArgs) => {
112+
return this.withNearMedia({ ...args, media: args.audio, type: NearMediaType.Audio });
113+
};
114+
115+
withNearVideo = (args: NearVideoArgs) => {
116+
return this.withNearMedia({ ...args, media: args.video, type: NearMediaType.Video });
117+
};
118+
119+
withNearDepth = (args: NearDepthArgs) => {
120+
return this.withNearMedia({ ...args, media: args.depth, type: NearMediaType.Depth });
121+
};
122+
123+
withNearThermal = (args: NearThermalArgs) => {
124+
return this.withNearMedia({ ...args, media: args.thermal, type: NearMediaType.Thermal });
125+
};
126+
127+
withNearIMU = (args: NearIMUArgs) => {
128+
return this.withNearMedia({ ...args, media: args.imu, type: NearMediaType.IMU });
129+
};
130+
71131
withNearVector = (args: NearVectorArgs) => {
132+
if (this.includesNearMediaFilter) {
133+
throw new Error('cannot use multiple near<Media> filters in a single query');
134+
}
72135
try {
73136
this.nearVectorString = new NearVector(args).toString();
74137
} catch (e: any) {
@@ -120,8 +183,8 @@ export default class Explorer extends CommandBase {
120183
args = [...args, `ask:${this.askString}`];
121184
}
122185

123-
if (this.nearImageString) {
124-
args = [...args, `nearImage:${this.nearImageString}`];
186+
if (this.nearMediaString) {
187+
args = [...args, `${this.nearMediaType}:${this.nearMediaString}`];
125188
}
126189

127190
if (this.nearVectorString) {

src/graphql/getter.test.ts

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ import Getter, { FusionType } from './getter';
22
import { WhereFilter } from '../openapi/types';
33
import { NearObjectArgs } from './nearObject';
44
import { AskArgs } from './ask';
5-
import { NearImageArgs } from './nearImage';
65
import { SortArgs } from './sort';
76
import { NearTextArgs } from './nearText';
7+
import { NearImageArgs } from './nearImage';
88

99
test('a simple query without params', () => {
1010
const mockClient: any = {
@@ -1091,6 +1091,113 @@ describe('nearImage searchers', () => {
10911091
});
10921092
});
10931093

1094+
describe('nearMedia searchers', () => {
1095+
test('a query with a valid nearVideo', () => {
1096+
const mockClient: any = {
1097+
query: jest.fn(),
1098+
};
1099+
1100+
const subQuery = `(nearVideo:{video:"iVBORw0KGgoAAAANS"})`;
1101+
const expectedQuery = `{Get{Person` + subQuery + `{name}}}`;
1102+
1103+
new Getter(mockClient)
1104+
.withClassName('Person')
1105+
.withFields('name')
1106+
.withNearVideo({ video: 'iVBORw0KGgoAAAANS' })
1107+
.do();
1108+
1109+
expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
1110+
});
1111+
1112+
test('a query with a valid nearVideo with all params', () => {
1113+
const mockClient: any = {
1114+
query: jest.fn(),
1115+
};
1116+
1117+
const expectedQuery = `{Get{Person(nearVideo:{video:"iVBORw0KGgoAAAANS",certainty:0.8,distance:0.6}){name}}}`;
1118+
1119+
new Getter(mockClient)
1120+
.withClassName('Person')
1121+
.withFields('name')
1122+
.withNearVideo({
1123+
video: 'iVBORw0KGgoAAAANS',
1124+
certainty: 0.8,
1125+
distance: 0.6,
1126+
})
1127+
.do();
1128+
1129+
expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
1130+
});
1131+
1132+
test('a query with a valid nearAudio', () => {
1133+
const mockClient: any = {
1134+
query: jest.fn(),
1135+
};
1136+
1137+
const subQuery = `(nearAudio:{audio:"iVBORw0KGgoAAAANS"})`;
1138+
const expectedQuery = `{Get{Person` + subQuery + `{name}}}`;
1139+
1140+
new Getter(mockClient)
1141+
.withClassName('Person')
1142+
.withFields('name')
1143+
.withNearAudio({ audio: 'iVBORw0KGgoAAAANS' })
1144+
.do();
1145+
1146+
expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
1147+
});
1148+
1149+
test('a query with a valid nearThermal', () => {
1150+
const mockClient: any = {
1151+
query: jest.fn(),
1152+
};
1153+
1154+
const subQuery = `(nearThermal:{thermal:"iVBORw0KGgoAAAANS"})`;
1155+
const expectedQuery = `{Get{Person` + subQuery + `{name}}}`;
1156+
1157+
new Getter(mockClient)
1158+
.withClassName('Person')
1159+
.withFields('name')
1160+
.withNearThermal({ thermal: 'iVBORw0KGgoAAAANS' })
1161+
.do();
1162+
1163+
expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
1164+
});
1165+
1166+
test('a query with a valid nearDepth', () => {
1167+
const mockClient: any = {
1168+
query: jest.fn(),
1169+
};
1170+
1171+
const subQuery = `(nearDepth:{depth:"iVBORw0KGgoAAAANS"})`;
1172+
const expectedQuery = `{Get{Person` + subQuery + `{name}}}`;
1173+
1174+
new Getter(mockClient)
1175+
.withClassName('Person')
1176+
.withFields('name')
1177+
.withNearDepth({ depth: 'iVBORw0KGgoAAAANS' })
1178+
.do();
1179+
1180+
expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
1181+
});
1182+
1183+
test('a query with a valid nearIMU', () => {
1184+
const mockClient: any = {
1185+
query: jest.fn(),
1186+
};
1187+
1188+
const subQuery = `(nearIMU:{imu:"iVBORw0KGgoAAAANS"})`;
1189+
const expectedQuery = `{Get{Person` + subQuery + `{name}}}`;
1190+
1191+
new Getter(mockClient)
1192+
.withClassName('Person')
1193+
.withFields('name')
1194+
.withNearIMU({ imu: 'iVBORw0KGgoAAAANS' })
1195+
.do();
1196+
1197+
expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
1198+
});
1199+
});
1200+
10941201
describe('sort filters', () => {
10951202
test('a query with a valid sort filter', () => {
10961203
const mockClient: any = {

0 commit comments

Comments
 (0)