diff --git a/src/main/java/com/bc/zarr/CompressorFactory.java b/src/main/java/com/bc/zarr/CompressorFactory.java index e9c3c76..d32a29a 100644 --- a/src/main/java/com/bc/zarr/CompressorFactory.java +++ b/src/main/java/com/bc/zarr/CompressorFactory.java @@ -208,7 +208,7 @@ public void uncompress(InputStream is, OutputStream os) throws IOException { } } - static class BloscCompressor extends Compressor { + public static class BloscCompressor extends Compressor { final static int AUTOSHUFFLE = -1; final static int NOSHUFFLE = 0; @@ -223,21 +223,24 @@ static class BloscCompressor extends Compressor { public final static int defaultShuffle = BYTESHUFFLE; public final static String keyBlocksize = "blocksize"; public final static int defaultBlocksize = 0; + public final static String keyNumThreads = "nthreads"; + public final static int defaultNumThreads = 1; public final static int[] supportedShuffle = new int[]{/*AUTOSHUFFLE, */NOSHUFFLE, BYTESHUFFLE, BITSHUFFLE}; public final static String[] supportedCnames = new String[]{"zstd", "blosclz", defaultCname, "lz4hc", "zlib"/*, "snappy"*/}; - public final static Map defaultProperties = Collections - .unmodifiableMap(new HashMap() {{ + public final static Map defaultProperties = new HashMap() {{ put(keyCname, defaultCname); put(keyClevel, defaultCLevel); put(keyShuffle, defaultShuffle); put(keyBlocksize, defaultBlocksize); - }}); + put(keyNumThreads, defaultNumThreads); + }}; private final int clevel; private final int blocksize; private final int shuffle; private final String cname; + private final int nthreads; private BloscCompressor(Map map) { final Object cnameObj = map.get(keyCname); @@ -285,6 +288,16 @@ private BloscCompressor(Map map) { } else { this.blocksize = ((Number) blocksizeObj).intValue(); } + + Object nthreadsObj = map.get(keyNumThreads); + if (nthreadsObj == null) { + nthreadsObj = defaultProperties.get(keyNumThreads); + } + if (nthreadsObj instanceof String) { + this.nthreads = Integer.parseInt((String) nthreadsObj); + } else { + this.nthreads = ((Number) nthreadsObj).intValue(); + } } @Override @@ -308,6 +321,10 @@ public String getCname() { return cname; } + public int getNumThreads() { + return nthreads; + } + @Override public String toString() { return "compressor=" + getId() @@ -324,7 +341,7 @@ public void compress(InputStream is, OutputStream os) throws IOException { final int outputSize = inputSize + JBlosc.OVERHEAD; final ByteBuffer inputBuffer = ByteBuffer.wrap(inputBytes); final ByteBuffer outBuffer = ByteBuffer.allocate(outputSize); - final int i = JBlosc.compressCtx(clevel, shuffle, 1, inputBuffer, inputSize, outBuffer, outputSize, cname, blocksize, 1); + final int i = JBlosc.compressCtx(clevel, shuffle, 1, inputBuffer, inputSize, outBuffer, outputSize, cname, blocksize, nthreads); final BufferSizes bs = cbufferSizes(outBuffer); byte[] compressedChunk = Arrays.copyOfRange(outBuffer.array(), 0, (int) bs.getCbytes()); os.write(compressedChunk); @@ -341,7 +358,7 @@ public void uncompress(InputStream is, OutputStream os) throws IOException { byte[] inBytes = Arrays.copyOf(header, compressedSize); di.readFully(inBytes, header.length, compressedSize - header.length); ByteBuffer outBuffer = ByteBuffer.allocate(uncompressedSize); - JBlosc.decompressCtx(ByteBuffer.wrap(inBytes), outBuffer, outBuffer.limit(), 1); + JBlosc.decompressCtx(ByteBuffer.wrap(inBytes), outBuffer, outBuffer.limit(), nthreads); os.write(outBuffer.array()); } diff --git a/src/test/java/com/bc/zarr/CompressorFactoryTest.java b/src/test/java/com/bc/zarr/CompressorFactoryTest.java index e03d275..fcca49e 100644 --- a/src/test/java/com/bc/zarr/CompressorFactoryTest.java +++ b/src/test/java/com/bc/zarr/CompressorFactoryTest.java @@ -42,18 +42,20 @@ public class CompressorFactoryTest { public void getDefaultCompressorProperties() { final Map map = CompressorFactory.getDefaultCompressorProperties(); assertNotNull(map); - assertEquals(5, map.size()); + assertEquals(6, map.size()); assertThat(map.containsKey("id"), is(true)); assertThat(map.containsKey("cname"), is(true)); assertThat(map.containsKey("clevel"), is(true)); assertThat(map.containsKey("blocksize"), is(true)); assertThat(map.containsKey("shuffle"), is(true)); + assertThat(map.containsKey("nthreads"), is(true)); assertThat(map.get("id"), is("blosc")); assertThat(map.get("cname"), is("lz4")); assertThat(map.get("clevel"), is(5)); assertThat(map.get("blocksize"), is(0)); assertThat(map.get("shuffle"), is(1)); + assertThat(map.get("nthreads"), is(1)); } @Test