From 3c7fcf904ab3b91df97d435c719653be5f3a0244 Mon Sep 17 00:00:00 2001 From: "bodong.ybd" Date: Fri, 21 Oct 2022 14:27:05 +0800 Subject: [PATCH] TairVector: refactor jedispool --- .../aliyun/tair/tairvector/TairVector.java | 272 +++++++++++++----- .../tests/tairvector/TairVectorTestBase.java | 21 +- 2 files changed, 200 insertions(+), 93 deletions(-) diff --git a/src/main/java/com/aliyun/tair/tairvector/TairVector.java b/src/main/java/com/aliyun/tair/tairvector/TairVector.java index 5b5ecf8..55e66fa 100644 --- a/src/main/java/com/aliyun/tair/tairvector/TairVector.java +++ b/src/main/java/com/aliyun/tair/tairvector/TairVector.java @@ -15,6 +15,7 @@ import com.aliyun.tair.util.JoinParameters; import redis.clients.jedis.BuilderFactory; import redis.clients.jedis.Jedis; +import redis.clients.jedis.JedisPool; import redis.clients.jedis.ScanResult; import redis.clients.jedis.util.SafeEncoder; @@ -22,17 +23,27 @@ public class TairVector { private Jedis jedis; + private JedisPool jedisPool; public TairVector(Jedis jedis) { this.jedis = jedis; } + public TairVector(JedisPool jedisPool) { + this.jedisPool = jedisPool; + } + private Jedis getJedis() { + if (jedisPool != null) { + return jedisPool.getResource(); + } return jedis; } - public void quit() { - jedis.quit(); + private void releaseJedis(Jedis jedis) { + if (jedisPool != null) { + jedis.close(); + } } /** @@ -51,13 +62,23 @@ public void quit() { * @return Success: +OK; Fail: error */ public String tvscreateindex(final String index, int dims, IndexAlgorithm algorithm, DistanceMethod method, final String... params) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSCREATEINDEX, JoinParameters.joinParameters(SafeEncoder.encode(index), toByteArray(dims), SafeEncoder.encode(algorithm.name()), SafeEncoder.encode(method.name()), SafeEncoder.encodeMany(params))); - return BuilderFactory.STRING.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSCREATEINDEX, JoinParameters.joinParameters(SafeEncoder.encode(index), toByteArray(dims), SafeEncoder.encode(algorithm.name()), SafeEncoder.encode(method.name()), SafeEncoder.encodeMany(params))); + return BuilderFactory.STRING.build(obj); + } finally { + releaseJedis(jedis); + } } public byte[] tvscreateindex(byte[] index, int dims, IndexAlgorithm algorithm, DistanceMethod method, final byte[]... params) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSCREATEINDEX, JoinParameters.joinParameters(index, toByteArray(dims), SafeEncoder.encode(algorithm.name()), SafeEncoder.encode(method.name()), params)); - return BuilderFactory.BYTE_ARRAY.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSCREATEINDEX, JoinParameters.joinParameters(index, toByteArray(dims), SafeEncoder.encode(algorithm.name()), SafeEncoder.encode(method.name()), params)); + return BuilderFactory.BYTE_ARRAY.build(obj); + } finally { + releaseJedis(jedis); + } } /** @@ -69,13 +90,23 @@ public byte[] tvscreateindex(byte[] index, int dims, IndexAlgorithm algorithm, D * @return Success: string_map, Fail: empty */ public Map tvsgetindex(final String index) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSGETINDEX, SafeEncoder.encode(index)); - return BuilderFactory.STRING_MAP.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSGETINDEX, SafeEncoder.encode(index)); + return BuilderFactory.STRING_MAP.build(obj); + } finally { + releaseJedis(jedis); + } } public Map tvsgetindex(byte[] index) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSGETINDEX, index); - return BuilderFactory.BYTE_ARRAY_MAP.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSGETINDEX, index); + return BuilderFactory.BYTE_ARRAY_MAP.build(obj); + } finally { + releaseJedis(jedis); + } } /** @@ -87,13 +118,23 @@ public Map tvsgetindex(byte[] index) { * @return Success: 1; Fail: 0 */ public Long tvsdelindex(final String index) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSDELINDEX, SafeEncoder.encode(index)); - return BuilderFactory.LONG.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSDELINDEX, SafeEncoder.encode(index)); + return BuilderFactory.LONG.build(obj); + } finally { + releaseJedis(jedis); + } } public Long tvsdelindex(byte[] index) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSDELINDEX, index); - return BuilderFactory.LONG.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSDELINDEX, index); + return BuilderFactory.LONG.build(obj); + } finally { + releaseJedis(jedis); + } } @@ -110,11 +151,16 @@ public Long tvsdelindex(byte[] index) { * @return A ScanResult. {@link VectorBuilderFactory#SCAN_CURSOR_STRING} */ public ScanResult tvsscanindex(Long cursor, HscanParams params) { - final List args = new ArrayList(); - args.add(toByteArray(cursor)); - args.addAll(params.getParams()); - Object obj = getJedis().sendCommand(ModuleCommand.TVSSCANINDEX, args.toArray(new byte[args.size()][])); - return VectorBuilderFactory.SCAN_CURSOR_STRING.build(obj); + Jedis jedis = getJedis(); + try { + final List args = new ArrayList(); + args.add(toByteArray(cursor)); + args.addAll(params.getParams()); + Object obj = jedis.sendCommand(ModuleCommand.TVSSCANINDEX, args.toArray(new byte[args.size()][])); + return VectorBuilderFactory.SCAN_CURSOR_STRING.build(obj); + } finally { + releaseJedis(jedis); + } } @@ -132,13 +178,23 @@ public ScanResult tvsscanindex(Long cursor, HscanParams params) { * throw error like "(error) Illegal vector dimensions" if error */ public Long tvshset(final String index, final String entityid, final String vector, final String... params) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSHSET, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), SafeEncoder.encode(vector), SafeEncoder.encodeMany(params))); - return BuilderFactory.LONG.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSHSET, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), SafeEncoder.encode(vector), SafeEncoder.encodeMany(params))); + return BuilderFactory.LONG.build(obj); + } finally { + releaseJedis(jedis); + } } public Long tvshset(byte[] index, byte[] entityid, byte[] vector, final byte[]... params) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSHSET, JoinParameters.joinParameters(index, entityid, SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), vector, params)); - return BuilderFactory.LONG.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSHSET, JoinParameters.joinParameters(index, entityid, SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), vector, params)); + return BuilderFactory.LONG.build(obj); + } finally { + releaseJedis(jedis); + } } /** @@ -151,13 +207,23 @@ public Long tvshset(byte[] index, byte[] entityid, byte[] vector, final byte[].. * @return Map, an empty list when {@code entityid} does not exist. */ public Map tvshgetall(final String index, final String entityid) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSHGETALL, SafeEncoder.encode(index), SafeEncoder.encode(entityid)); - return BuilderFactory.STRING_MAP.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSHGETALL, SafeEncoder.encode(index), SafeEncoder.encode(entityid)); + return BuilderFactory.STRING_MAP.build(obj); + } finally { + releaseJedis(jedis); + } } public Map tvshgetall(byte[] index, byte[] entityid) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSHGETALL, index, entityid); - return BuilderFactory.BYTE_ARRAY_MAP.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSHGETALL, index, entityid); + return BuilderFactory.BYTE_ARRAY_MAP.build(obj); + } finally { + releaseJedis(jedis); + } } /** @@ -171,13 +237,23 @@ public Map tvshgetall(byte[] index, byte[] entityid) { * @return List, an empty list when {@code entityid} or {@code attrs} does not exist . */ public List tvshmget(final String index, final String entityid, final String... attrs) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSHMGET, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encodeMany(attrs))); - return BuilderFactory.STRING_LIST.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSHMGET, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encodeMany(attrs))); + return BuilderFactory.STRING_LIST.build(obj); + } finally { + releaseJedis(jedis); + } } public List tvshmget(byte[] index, byte[] entityid, byte[]... attrs) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSHMGET, JoinParameters.joinParameters(index, entityid, attrs)); - return BuilderFactory.BYTE_ARRAY_LIST.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSHMGET, JoinParameters.joinParameters(index, entityid, attrs)); + return BuilderFactory.BYTE_ARRAY_LIST.build(obj); + } finally { + releaseJedis(jedis); + } } @@ -192,13 +268,23 @@ public List tvshmget(byte[] index, byte[] entityid, byte[]... attrs) { * not including specified but no existing fields. */ public Long tvsdel(final String index, final String entityid) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSDEL, SafeEncoder.encode(index), SafeEncoder.encode(entityid)); - return BuilderFactory.LONG.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSDEL, SafeEncoder.encode(index), SafeEncoder.encode(entityid)); + return BuilderFactory.LONG.build(obj); + } finally { + releaseJedis(jedis); + } } public Long tvsdel(byte[] index, byte[] entityid) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSDEL, index, entityid); - return BuilderFactory.LONG.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSDEL, index, entityid); + return BuilderFactory.LONG.build(obj); + } finally { + releaseJedis(jedis); + } } /** @@ -214,13 +300,23 @@ public Long tvsdel(byte[] index, byte[] entityid) { * not including specified but no existing fields. */ public Long tvshdel(final String index, final String entityid, final String attr, final String... attrs) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSHDEL, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encode(attr), SafeEncoder.encodeMany(attrs))); - return BuilderFactory.LONG.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSHDEL, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encode(attr), SafeEncoder.encodeMany(attrs))); + return BuilderFactory.LONG.build(obj); + } finally { + releaseJedis(jedis); + } } public Long tvshdel(byte[] index, byte[] entityid, byte[] attr, byte[]... attrs) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSHDEL, JoinParameters.joinParameters(index, entityid, attr, attrs)); - return BuilderFactory.LONG.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSHDEL, JoinParameters.joinParameters(index, entityid, attr, attrs)); + return BuilderFactory.LONG.build(obj); + } finally { + releaseJedis(jedis); + } } @@ -238,21 +334,31 @@ public Long tvshdel(byte[] index, byte[] entityid, byte[] attr, byte[]... attrs) * @return A ScanResult. */ public ScanResult tvsscan(final String index, Long cursor, HscanParams params) { - final List args = new ArrayList(); - args.add(SafeEncoder.encode(index)); - args.add(toByteArray(cursor)); - args.addAll(params.getParams()); - Object obj = getJedis().sendCommand(ModuleCommand.TVSSCAN, args.toArray(new byte[args.size()][])); - return VectorBuilderFactory.SCAN_CURSOR_STRING.build(obj); + Jedis jedis = getJedis(); + try { + final List args = new ArrayList(); + args.add(SafeEncoder.encode(index)); + args.add(toByteArray(cursor)); + args.addAll(params.getParams()); + Object obj = jedis.sendCommand(ModuleCommand.TVSSCAN, args.toArray(new byte[args.size()][])); + return VectorBuilderFactory.SCAN_CURSOR_STRING.build(obj); + } finally { + releaseJedis(jedis); + } } public ScanResult tvsscan(byte[] index, Long cursor, HscanParams params) { - final List args = new ArrayList(); - args.add(index); - args.add(toByteArray(cursor)); - args.addAll(params.getParams()); - Object obj = getJedis().sendCommand(ModuleCommand.TVSSCAN, args.toArray(new byte[args.size()][])); - return VectorBuilderFactory.SCAN_CURSOR_BYTE.build(obj); + Jedis jedis = getJedis(); + try { + final List args = new ArrayList(); + args.add(index); + args.add(toByteArray(cursor)); + args.addAll(params.getParams()); + Object obj = jedis.sendCommand(ModuleCommand.TVSSCAN, args.toArray(new byte[args.size()][])); + return VectorBuilderFactory.SCAN_CURSOR_BYTE.build(obj); + } finally { + releaseJedis(jedis); + } } /** @@ -289,14 +395,24 @@ public VectorBuilderFactory.Knn tvsknnsearch(byte[] index, Long topn, by * @return VectorBuilderFactory.Knn<> */ public VectorBuilderFactory.Knn tvsknnsearchfilter(final String index, Long topn, final String vector, final String pattern, final String... params) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSKNNSEARCH, JoinParameters.joinParameters(SafeEncoder.encode(index), + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSKNNSEARCH, JoinParameters.joinParameters(SafeEncoder.encode(index), toByteArray(topn), SafeEncoder.encode(vector), SafeEncoder.encode(pattern), SafeEncoder.encodeMany(params))); - return VectorBuilderFactory.STRING_KNN_RESULT.build(obj); + return VectorBuilderFactory.STRING_KNN_RESULT.build(obj); + } finally { + releaseJedis(jedis); + } } public VectorBuilderFactory.Knn tvsknnsearchfilter(byte[] index, Long topn, byte[] vector, byte[] pattern, final byte[]... params) { - Object obj = getJedis().sendCommand(ModuleCommand.TVSKNNSEARCH, JoinParameters.joinParameters(index, toByteArray(topn), vector, pattern, params)); - return VectorBuilderFactory.BYTE_KNN_RESULT.build(obj); + Jedis jedis = getJedis(); + try { + Object obj = jedis.sendCommand(ModuleCommand.TVSKNNSEARCH, JoinParameters.joinParameters(index, toByteArray(topn), vector, pattern, params)); + return VectorBuilderFactory.BYTE_KNN_RESULT.build(obj); + } finally { + releaseJedis(jedis); + } } @@ -330,26 +446,36 @@ public Collection> tvsmknnsearch(byte[] index, * @return Collection<> */ public Collection> tvsmknnsearchfilter(final String index, Long topn, Collection vectors, final String pattern, final String... params) { - final List args = new ArrayList(); - args.add(SafeEncoder.encode(index)); - args.add(toByteArray(topn)); - args.add(toByteArray(vectors.size())); - args.addAll(vectors.stream().map(vector -> SafeEncoder.encode(vector)).collect(Collectors.toList())); - args.add(SafeEncoder.encode(pattern)); - args.addAll(Arrays.stream(params).map(str -> SafeEncoder.encode(str)).collect(Collectors.toList())); - Object obj = getJedis().sendCommand(ModuleCommand.TVSMKNNSEARCH, args.toArray(new byte[args.size()][])); - return VectorBuilderFactory.STRING_KNN_BATCH_RESULT.build(obj); + Jedis jedis = getJedis(); + try { + final List args = new ArrayList(); + args.add(SafeEncoder.encode(index)); + args.add(toByteArray(topn)); + args.add(toByteArray(vectors.size())); + args.addAll(vectors.stream().map(vector -> SafeEncoder.encode(vector)).collect(Collectors.toList())); + args.add(SafeEncoder.encode(pattern)); + args.addAll(Arrays.stream(params).map(str -> SafeEncoder.encode(str)).collect(Collectors.toList())); + Object obj = jedis.sendCommand(ModuleCommand.TVSMKNNSEARCH, args.toArray(new byte[args.size()][])); + return VectorBuilderFactory.STRING_KNN_BATCH_RESULT.build(obj); + } finally { + releaseJedis(jedis); + } } public Collection> tvsmknnsearchfilter(byte[] index, Long topn, Collection vectors, byte[] pattern, final byte[]... params) { - final List args = new ArrayList(); - args.add(index); - args.add(toByteArray(topn)); - args.add(toByteArray(vectors.size())); - args.addAll(vectors); - args.add(pattern); - args.addAll(Arrays.stream(params).collect(Collectors.toList())); - Object obj = getJedis().sendCommand(ModuleCommand.TVSMKNNSEARCH, args.toArray(new byte[args.size()][])); - return VectorBuilderFactory.BYTE_KNN_BATCH_RESULT.build(obj); + Jedis jedis = getJedis(); + try { + final List args = new ArrayList(); + args.add(index); + args.add(toByteArray(topn)); + args.add(toByteArray(vectors.size())); + args.addAll(vectors); + args.add(pattern); + args.addAll(Arrays.stream(params).collect(Collectors.toList())); + Object obj = jedis.sendCommand(ModuleCommand.TVSMKNNSEARCH, args.toArray(new byte[args.size()][])); + return VectorBuilderFactory.BYTE_KNN_BATCH_RESULT.build(obj); + } finally { + releaseJedis(jedis); + } } } diff --git a/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorTestBase.java b/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorTestBase.java index 5a9027e..2dc2ad4 100644 --- a/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorTestBase.java +++ b/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorTestBase.java @@ -19,28 +19,9 @@ public class TairVectorTestBase extends TestBase { @BeforeClass public static void setUp() { - tairVector = new TairVector(jedis); + tairVector = new TairVector(jedisPool); tairVectorPipeline = new TairVectorPipeline(); tairVectorPipeline.setClient(jedis.getClient()); tairVectorCluster = new TairVectorCluster(jedisCluster); } - - @AfterClass - public static void closeDown() { - tairVector.quit(); - tairVectorCluster.quit(); - } - - public static void assertLongListEquals(List expected, List actual) { - assertEquals(expected.size(), actual.size()); - for (int n = 0; n < expected.size(); n++) { - assertEquals(expected.get(n), actual.get(n)); - } - } - - public static void assertScanResultEquals(List expected, ScanResult actual) { - for (int n = 0; n < expected.size(); n++) { - assertEquals(expected.get(n), actual.getResult().get(n)); - } - } }