diff --git a/inc/Kcounter.h b/inc/Kcounter.h index f87cef3..73ecc95 100644 --- a/inc/Kcounter.h +++ b/inc/Kcounter.h @@ -13,13 +13,13 @@ class Kcounter public: Kcounter( const int k ); ~Kcounter(); - void insert( char* kmer, int count ); + void insert( char* kmer, count_dtype count ); bool contains( char* kmer ); void clear(); uint64_t size(); void remove( char* kmer ); void add_seq( char* seq, uint32_t length ); - int get( char* kmer ); + count_dtype get( char* kmer ); int get_k() { return m_k; } Kcontainer* get_kc() { return kc; } void parallel_add_init(int threads) { diff --git a/inc/UContainer.h b/inc/UContainer.h index 9233384..6857b8e 100644 --- a/inc/UContainer.h +++ b/inc/UContainer.h @@ -15,7 +15,7 @@ typedef struct { #if KDICT py::handle* objs; #elif KCOUNTER - int* counts; + count_dtype* counts; #endif uint16_t size; } UC; diff --git a/inc/globals.h b/inc/globals.h index 5902462..dc56e25 100644 --- a/inc/globals.h +++ b/inc/globals.h @@ -1,7 +1,11 @@ #pragma once +#include <stdint.h> + #define CAPACITY 4096 #define NHASHES 12 #define HASHSIZE 512 // HASHSIZE % 32 must be 0 +typedef uint16_t count_dtype; +#define MAXCOUNT UINT16_MAX diff --git a/kcollections/src/Kcounter.cc b/kcollections/src/Kcounter.cc index 3603f6e..686c232 100644 --- a/kcollections/src/Kcounter.cc +++ b/kcollections/src/Kcounter.cc @@ -18,13 +18,13 @@ void Kcounter::clear() kc = create_kcontainer( m_k ); } -void Kcounter::insert( char* kmer, int count ) +void Kcounter::insert( char* kmer, count_dtype count ) { CHECK_KMER_LENGTH( kmer, m_k, "Kcounter" ); kcontainer_add( kc, kmer, count ); } -int Kcounter::get( char* kmer ) +count_dtype Kcounter::get( char* kmer ) { CHECK_KMER_LENGTH( kmer, m_k, "Kcounter" ); return kcontainer_get( kc, kmer ); diff --git a/kcollections/src/UContainer.cc b/kcollections/src/UContainer.cc index 968c18c..065d086 100644 --- a/kcollections/src/UContainer.cc +++ b/kcollections/src/UContainer.cc @@ -38,7 +38,7 @@ void uc_insert( UC* uc, uint8_t* bseq, int k, int depth, 
int idx, int count ) #if KDICT uc->objs = ( py::handle* ) calloc( len , sizeof( py::handle ) ); #elif KCOUNTER - uc->counts = ( int* ) calloc( len, sizeof( int ) ); + uc->counts = ( count_dtype* ) calloc( len, sizeof(count_dtype) ); #endif } else @@ -53,9 +53,9 @@ void uc_insert( UC* uc, uint8_t* bseq, int k, int depth, int idx, int count ) ( uc->size + 1 ) * sizeof( py::handle ) ); #elif KCOUNTER - uc->counts = ( int* ) realloc( + uc->counts = ( count_dtype* ) realloc( uc->counts, - ( uc->size + 1 ) * sizeof( int ) + ( uc->size + 1 ) * sizeof(count_dtype) ); #endif } @@ -80,7 +80,7 @@ void uc_insert( UC* uc, uint8_t* bseq, int k, int depth, int idx, int count ) bytes_to_move ); #elif KCOUNTER - bytes_to_move = ( uc->size - idx ) * sizeof( int ); + bytes_to_move = ( uc->size - idx ) * sizeof(count_dtype); std::memmove( &uc->counts[ idx + 1 ], &uc->counts[ idx ], @@ -90,10 +90,12 @@ void uc_insert( UC* uc, uint8_t* bseq, int k, int depth, int idx, int count ) } #if KDICT - std::memcpy( &uc->objs[ idx ], obj, sizeof( py::handle ) ); - uc->objs[ idx ].inc_ref(); + std::memcpy(&uc->objs[idx], obj, sizeof(py::handle)); + uc->objs[idx].inc_ref(); #elif KCOUNTER - std::memcpy( &uc->counts[ idx ], &count, sizeof( int ) ); + if(uc->counts[idx] < MAXCOUNT) { + std::memcpy(&uc->counts[idx], &count, sizeof(count_dtype)); + } #endif std::memcpy( &uc->suffixes[ suffix_idx ], bseq, len ); uc->size++; @@ -138,7 +140,7 @@ void uc_remove( UC* uc, int bk, int idx ) bytes_to_move ); #elif KCOUNTER - bytes_to_move = ( uc->size - ( idx + 1 ) ) * sizeof( int ); + bytes_to_move = ( uc->size - ( idx + 1 ) ) * sizeof( count_dtype ); std::memmove( &uc->counts[ idx ], &uc->counts[ idx + 1 ], diff --git a/kcollections/src/Vertex.cc b/kcollections/src/Vertex.cc index bba0a69..5be0a1b 100644 --- a/kcollections/src/Vertex.cc +++ b/kcollections/src/Vertex.cc @@ -139,7 +139,7 @@ void burst_uc( Vertex* v, int k, int depth ) #if KDICT py::handle* objs = v->uc.objs; #elif KCOUNTER - int* counts = 
v->uc.counts; + count_dtype* counts = v->uc.counts; #endif int idx; for( int i = 0; i < v->uc.size; i++ ) @@ -229,11 +229,13 @@ void vertex_insert( Vertex* v, uint8_t* bseq, int k, int depth, int count ) ); v->uc.objs[ uc_idx ].inc_ref(); #elif KCOUNTER - std::memcpy( - &v->uc.counts[ uc_idx ], - &count, - sizeof( int ) - ); + if(v->uc.counts[uc_idx] < MAXCOUNT) { + std::memcpy( + &v->uc.counts[ uc_idx ], + &count, + sizeof(count_dtype) + ); + } #endif return; } diff --git a/scripts/counter.bft.parallel.fast.py b/scripts/counter.bft.parallel.fast.py new file mode 100644 index 0000000..5711a9e --- /dev/null +++ b/scripts/counter.bft.parallel.fast.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python + +import kcollections, sys, time +from tqdm import tqdm + +k = int(sys.argv[1]) +threads = int(sys.argv[2]) +ks = kcollections.Kcounter(k) +ks.parallel_add_init(threads) + +seqs = [] +seq = '' +c = 0 + +start_time = time.time() + +with(open(sys.argv[3], 'r')) as fh: + for line in fh: + if line[0] == '>': + if len(seq) > 0: + #seqs.append(seq) + tstart_time = time.time() + ks.parallel_add_seq(seq, len(seq)) + telapsed_time = time.time() - tstart_time + print c, '\tadded seq of len', len(seq), telapsed_time + sys.stdout.flush() + c += 1 + seq = '' + else: + seq += line.strip() + if len(seq) > 0: + #seqs.append(seq) + tstart_time = time.time() + ks.parallel_add_seq(seq, len(seq)) + telapsed_time = time.time() - tstart_time + print c, '\tadded seq of len', len(seq), telapsed_time + +ks.parallel_add_join() + +elapsed_time = time.time() - start_time +print 'elapsed time:', elapsed_time +#print 'read', len(seqs), 'seqs, adding to ks...' + +#for seq in seqs: +# print '\tadding seq...' +# sys.stdout.flush() +# ks.parallel_add_seq(seq, len(seq)) + +print len(ks), 'kmers' +print 'done!' 
+print 'checking correctness' + +for kmer, count in ks.iteritems(): + print kmer, count + +c = 0 +if len(sys.argv) > 4: + with open(sys.argv[3], 'r') as fh: + seq = '' + for line in fh: + if line[0] == '>': + if len(seq) > 0: + for i in range(len(seq) - k + 1): + kmer = seq[i : i + k] + assert kmer in ks, "not find: " + kmer + c += 1 + seq = '' + else: + seq += line.strip() + for i in range(len(seq) - k + 1): + kmer = seq[i : i + k] + assert kmer in ks, "not find: " + kmer + c += 1 +print 'checked', c, 'kmers' + +del ks diff --git a/setup.py b/setup.py index 12be87a..40ae7bf 100755 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ from setuptools import setup, Extension, find_packages from setuptools.command.build_ext import build_ext -__version__ = '0.0.9' +__version__ = '0.1' class CMakeExtension(Extension): def __init__(self, name, sourcedir=''):