diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..9a1b586 --- /dev/null +++ b/.flake8 @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 120 +ignore = F401 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..5dc9c2d --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,17 @@ +name: Release +on: + release: + types: [published] +jobs: + release: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v1 + with: + python-version: '3.7' + architecture: x64 + - run: pip install tox poetry + - run: tox + - run: poetry build + - run: poetry publish --username=__token__ --password=${{ secrets.PYPI_TOKEN }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..9d68cfd --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,23 @@ +name: Test +on: [push] +jobs: + build: + runs-on: ubuntu-latest + strategy: + max-parallel: 8 + matrix: + python-version: [3.6, 3.7, 3.8] + steps: + - uses: actions/checkout@v1 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install tox tox-gh-actions poetry + - name: Test with tox + run: tox + env: + PLATFORM: ${{ matrix.platform }} diff --git a/.gitignore b/.gitignore index b6e4761..89c3923 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,9 @@ dmypy.json # Pyre type checker .pyre/ + +# VSCode +.vscode/ + +# InteliJ +.idea/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..78d2c2c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,21 @@ +repos: + - repo: git://github.com/pre-commit/pre-commit-hooks + rev: v2.1.0 + hooks: + - id: end-of-file-fixer + exclude: ^docs/.*$ + - id: trailing-whitespace + exclude: README.md + - id: flake8 + + - repo: https://github.com/pre-commit/mirrors-yapf + rev: v0.29.0 + hooks: + - id: yapf + args: ['--style=.style.yapf', '--parallel', '--in-place'] + + - repo: git@github.com:humitos/mirrors-autoflake.git + rev: v1.3 + hooks: + - id: autoflake + args: ['--in-place', '--remove-all-unused-imports', '--ignore-init-module-imports'] diff --git a/.style.yapf b/.style.yapf new file mode 100644 index 0000000..8357165 --- /dev/null +++ b/.style.yapf @@ -0,0 +1,5 @@ +[style] +indent_width = 4 +column_limit = 110 +allow_split_before_dict_value = false +split_before_expression_after_opening_paren = true diff --git a/README.md b/README.md index 7fd8c10..0ef46d5 100644 --- a/README.md +++ b/README.md @@ -1 +1,32 @@ -# dmatrix2np \ No newline at end of file +# dmatrix2np + +[![Tests](https://github.com/aporia-ai/dmatrix2np/workflows/Test/badge.svg)](https://github.com/aporia-ai/dmatrix2np/actions?workflow=Test) [![PyPI](https://img.shields.io/pypi/v/dmatrix2np.svg)](https://pypi.org/project/dmatrix2np/) + +Convert XGBoost's DMatrix format to np.array. + +## Usage + +To install the library, run: + + pip install dmatrix2np + +Then, you can call in your code: + + from dmatrix2np import dmatrix2np + + converted_np_array = dmatrix_to_numpy(dmatrix) + +## Development + +We use [poetry](https://python-poetry.org/) for development: + + pip install poetry + +To install all dependencies and run tests: + + poetry run pytest + +To run tests on the entire matrix (Python 3.6, 3.7, 3.8 + XGBoost 0.80, 0.90, 1.0): + + pip install tox + tox diff --git a/dmatrix2np/__init__.py b/dmatrix2np/__init__.py new file mode 100644 index 0000000..1d3c3ee --- /dev/null +++ b/dmatrix2np/__init__.py @@ -0,0 +1,12 @@ +from .dmatrix_to_numpy import dmatrix_to_numpy +from .exceptions import DMatrix2NpError, InvalidStructure, UnsupportedVersion, InvalidInput + +# Single source the package version +try: + from importlib.metadata import version, PackageNotFoundError # type: ignore +except ImportError: # pragma: no cover + from importlib_metadata import version, PackageNotFoundError # type: ignore +try: + __version__ = version(__name__) +except PackageNotFoundError: # pragma: no cover + __version__ = "unknown" diff --git a/dmatrix2np/common.py b/dmatrix2np/common.py new file mode 100644 index 0000000..d338cc5 --- /dev/null +++ b/dmatrix2np/common.py @@ -0,0 +1,41 @@ +from enum import Enum +import struct +import sys + + +class FieldDataType(Enum): + """ + This Enum provides an integer translation for the data type corresponding to the 'DataType' enum + on '/include/xgboost/data.h' file in the XGBoost project + """ + kFloat32 = 1 + kDouble = 2 + kUInt32 = 3 + kUInt64 = 4 + + +# Dictionary of data types size in bytes +data_type_sizes = { + 'kFloat32': 4, + 'kDouble': struct.calcsize("d"), + 'kUInt32': 4, + 'kUInt64': 8, + 'bool': 1, + 'uint8_t': 1, + 'int32_t': 4, + 'uint32_t': 4, + 'uint64_t': 8, + 'int': struct.calcsize("i"), + 'float': struct.calcsize("f"), + 'double': struct.calcsize("d"), + 'size_t': struct.calcsize("N"), +} + + +SIZE_T_DTYPE = f'{"<" if sys.byteorder == "little" else ">"}i{data_type_sizes["size_t"]}' +VERSION_STRUCT = struct.Struct('iii') +SIMPLE_VERSION_STRUCT = struct.Struct('I') +VECTOR_SIZE_STRUCT = struct.Struct('=Q') +FIELD_TYPE_STRUCT = struct.Struct('b') +FLAG_STRUCT = struct.Struct('?') +KMAGIC_STRUCT = struct.Struct('I') diff --git a/dmatrix2np/dmatrix_stream_parser.py b/dmatrix2np/dmatrix_stream_parser.py new file mode 100644 index 0000000..f3bab0b --- /dev/null +++ b/dmatrix2np/dmatrix_stream_parser.py @@ -0,0 +1,69 @@ +import abc +import numpy as np +from struct import Struct +from .common import SIZE_T_DTYPE, data_type_sizes, VECTOR_SIZE_STRUCT + + +class DMatrixStreamParser(metaclass=abc.ABCMeta): + """Abstract base class for DMatrix stream parser.""" + + def __init__(self, buffer_reader, num_row, num_col): + self._handle = buffer_reader + self.num_row = num_row + self.num_col = num_col + self._offset_vector = [] + self._data_vector = {} + + @abc.abstractmethod + def parse(self) -> np.ndarray: + """Parse DMatrix to numpy 2d array + + Returns + ------- + np.ndarray + DMatrix values in numpy 2d array + """ + pass + + def _read_struct(self, s: Struct): + return s.unpack(self._handle.read(s.size)) + + def _parse_offset_vector(self): + offset_vector_size, = self._read_struct(VECTOR_SIZE_STRUCT) + self._offset_vector = np.frombuffer(buffer=self._handle.read(offset_vector_size * data_type_sizes['size_t']), + dtype=SIZE_T_DTYPE) + + def _parse_data_vector(self): + data_vector_size, = self._read_struct(VECTOR_SIZE_STRUCT) + data_vector_entry_size = data_type_sizes['uint32_t'] + data_type_sizes['float'] + self._data_vector = np.frombuffer(buffer=self._handle.read(data_vector_size * data_vector_entry_size), + dtype=np.dtype([('column_index', 'i4'), ('data', 'float32')])) + + def _get_nparray(self) -> np.ndarray: + """Generate 2d numpy array + + Returns + ------- + np.ndarray + dmatrix converted to 2d numpy array + """ + # When the matrix is flat, there are no values and matrix could be generated immediately + if self.num_row == 0 or self.num_col == 0: + return np.empty((self.num_row, self.num_col)) + + # Create flat matrix filled with nan values + matrix = np.nan * np.empty((self.num_row * self.num_col)) + + # The offset vector contains the offsets of the values in the data vector for each row according to its index + # We create size vector that contains the size of each row according to its index + size_vector = self._offset_vector[1:] - self._offset_vector[:-1] + + # Since we work with flat matrix we want to convert 2d index (x, y) to 1d index (x*num_col + y) + # The data vector only keep column index and data, + # increase vector is the addition needed for the column index to be 1d index + increase_vector = np.repeat(np.arange(0, size_vector.size * self.num_col, self.num_col), size_vector) + flat_indexes = self._data_vector['column_index'] + increase_vector + + # Values assignment + matrix[flat_indexes] = self._data_vector['data'] + return matrix.reshape((self.num_row, self.num_col)) diff --git a/dmatrix2np/dmatrix_to_numpy.py b/dmatrix2np/dmatrix_to_numpy.py new file mode 100644 index 0000000..2972684 --- /dev/null +++ b/dmatrix2np/dmatrix_to_numpy.py @@ -0,0 +1,47 @@ +import tempfile +import os +import xgboost as xgb +import numpy as np +from .dmatrix_v_1_0_0_stream_parser import DMatrixStreamParserV1_0_0 +from .dmatrix_v_0_80_stream_parser import DMatrixStreamParserV0_80 +from .exceptions import InvalidInput +from packaging import version +from contextlib import suppress + + +def dmatrix_to_numpy(dmatrix: xgb.DMatrix) -> np.ndarray: + """Convert DMatrix to 2d numpy array + + Parameters + ---------- + dmatrix : xgb.DMatrix + DMatrix to convert + + Returns + ------- + np.ndarray + 2d numpy array with the corresponding DMatrix feature values + + Raises + ------ + InvalidInput + Input is not a valid DMatrix + """ + if not isinstance(dmatrix, xgb.DMatrix): + raise InvalidInput("Type error: input parameter is not DMatrix") + + stream_parser = DMatrixStreamParserV0_80 if version.parse(xgb.__version__) < version.parse('1.0.0') \ + else DMatrixStreamParserV1_0_0 + + # We set delete=False to avoid permissions error. This way, file can be accessed + # by XGBoost without being deleted while handle is closed + try: + with tempfile.NamedTemporaryFile(delete=False) as fp: + dmatrix.save_binary(fp.name) + result = stream_parser(fp, dmatrix.num_row(), dmatrix.num_col()).parse() + finally: + # We can safely remove the temp file now, parsing process finished + with suppress(OSError): + os.remove(fp.name) + + return result diff --git a/dmatrix2np/dmatrix_v_0_80_stream_parser.py b/dmatrix2np/dmatrix_v_0_80_stream_parser.py new file mode 100644 index 0000000..2c57e87 --- /dev/null +++ b/dmatrix2np/dmatrix_v_0_80_stream_parser.py @@ -0,0 +1,55 @@ +from .dmatrix_stream_parser import DMatrixStreamParser +from .exceptions import InvalidStructure +from .common import data_type_sizes, SIMPLE_VERSION_STRUCT, VECTOR_SIZE_STRUCT, KMAGIC_STRUCT +from os import SEEK_CUR + + +class DMatrixStreamParserV0_80(DMatrixStreamParser): + + NUM_OF_SCALAR_FIELDS = 3 + kMagic = 0xffffab01 + kVersion = 2 + + def __init__(self, buffer_reader, num_row, num_col): + self._handle = buffer_reader + self.num_row = num_row + self.num_col = num_col + + def parse(self): + self._handle.seek(0) + self._parse_magic() + self._parse_version() + self._skip_fields() + self._parse_offset_vector() + self._parse_data_vector() + return self._get_nparray() + + def _parse_magic(self): + kMagic, = self._read_struct(KMAGIC_STRUCT) + if kMagic != self.kMagic: + raise InvalidStructure('Invalid magic') + + def _parse_version(self): + version, = self._read_struct(SIMPLE_VERSION_STRUCT) + if version != self.kVersion: + raise InvalidStructure('Invalid version') + + def _skip_fields(self): + # Skip num_row_, num_col_, num_nonzero_ (all uint64_t) + self._handle.seek(self.NUM_OF_SCALAR_FIELDS * data_type_sizes['uint64_t'], SEEK_CUR) + + # skip info's vector fields (labels_, group_ptr_, qids_, weights_, root_index_, base_margin_) + vectors_entry_sizes = [ + data_type_sizes['float'], # labels_ + data_type_sizes['uint32_t'], # group_ptr_ + data_type_sizes['uint64_t'], # qids_ + data_type_sizes['float'], # weights_ + data_type_sizes['uint32_t'], # root_index_ + data_type_sizes['float'], # base_margin_ + ] + + # Each vector field starts with uint64_t size indicator + # followed by number of vector entries equals to indicated size + for vector_entry_size in vectors_entry_sizes: + vector_size, = self._read_struct(VECTOR_SIZE_STRUCT) + self._handle.read(vector_size * vector_entry_size) diff --git a/dmatrix2np/dmatrix_v_1_0_0_stream_parser.py b/dmatrix2np/dmatrix_v_1_0_0_stream_parser.py new file mode 100644 index 0000000..cec63de --- /dev/null +++ b/dmatrix2np/dmatrix_v_1_0_0_stream_parser.py @@ -0,0 +1,61 @@ +from .dmatrix_stream_parser import DMatrixStreamParser +from .exceptions import InvalidStructure +from .common import (FieldDataType, data_type_sizes, FLAG_STRUCT, VERSION_STRUCT, VECTOR_SIZE_STRUCT, + FIELD_TYPE_STRUCT, KMAGIC_STRUCT) +from os import SEEK_CUR + + +class DMatrixStreamParserV1_0_0(DMatrixStreamParser): + + kMagic = 0xffffab01 + verstr = 'version:' + + def __init__(self, buffer_reader, num_row, num_col): + self._handle = buffer_reader + self.num_row = num_row + self.num_col = num_col + + def parse(self): + self._handle.seek(0) + self._parse_magic() + self._parse_version() + self._skip_fields() + self._parse_offset_vector() + self._parse_data_vector() + return self._get_nparray() + + def _parse_magic(self): + kMagic, = self._read_struct(KMAGIC_STRUCT) + if kMagic != self.kMagic: + raise InvalidStructure('Invalid magic') + + def _parse_version(self): + verstr = self._handle.read(len(self.verstr.encode())) + if verstr != self.verstr.encode(): + raise InvalidStructure('Invalid verstr') + self._version = self._read_struct(VERSION_STRUCT) + + def _skip_fields(self): + fields_count, = self._read_struct(VECTOR_SIZE_STRUCT) + for _ in range(fields_count): + self._skip_field() + + def _skip_field(self): + # Skip field name (pascal string) + name_size, = self._read_struct(VECTOR_SIZE_STRUCT) + self._handle.seek(name_size, SEEK_CUR) + + # Find field type + field_type = FieldDataType(self._read_struct(FIELD_TYPE_STRUCT)[0]) + is_scalar, = self._read_struct(FLAG_STRUCT) + + if is_scalar: + self._handle.seek(data_type_sizes[field_type.name], SEEK_CUR) + else: + # Skip shape.first, shape.second + self._handle.seek(2 * data_type_sizes['uint64_t'], SEEK_CUR) + + vector_size, = self._read_struct(VECTOR_SIZE_STRUCT) + + # Skip vector + self._handle.seek(vector_size * data_type_sizes[field_type.name], SEEK_CUR) diff --git a/dmatrix2np/exceptions.py b/dmatrix2np/exceptions.py new file mode 100644 index 0000000..bb17408 --- /dev/null +++ b/dmatrix2np/exceptions.py @@ -0,0 +1,17 @@ +class DMatrix2NpError(Exception): + '''Basic exception for errors raised by dmatrix2np''' + pass + + +class InvalidStructure(DMatrix2NpError): + def __init__(self, message='Invalid structure'): + super().__init__(message) + + +class UnsupportedVersion(DMatrix2NpError): + pass + + +class InvalidInput(DMatrix2NpError): + def __init__(self, message='Invalid input'): + super().__init__(message) diff --git a/poetry.lock b/poetry.lock new file mode 100644 index 0000000..aa0f8aa --- /dev/null +++ b/poetry.lock @@ -0,0 +1,433 @@ +[[package]] +category = "dev" +description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +name = "appdirs" +optional = false +python-versions = "*" +version = "1.4.3" + +[[package]] +category = "dev" +description = "Atomic file writes." +marker = "sys_platform == \"win32\"" +name = "atomicwrites" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +version = "1.4.0" + +[[package]] +category = "dev" +description = "Classes Without Boilerplate" +name = "attrs" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +version = "19.3.0" + +[package.extras] +azure-pipelines = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "pytest-azurepipelines"] +dev = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "sphinx", "pre-commit"] +docs = ["sphinx", "zope.interface"] +tests = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"] + +[[package]] +category = "dev" +description = "Validate configuration and produce human readable error messages." +name = "cfgv" +optional = false +python-versions = ">=3.6.1" +version = "3.1.0" + +[[package]] +category = "dev" +description = "Cross-platform colored terminal text." +marker = "sys_platform == \"win32\"" +name = "colorama" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +version = "0.4.3" + +[[package]] +category = "dev" +description = "Distribution utilities" +name = "distlib" +optional = false +python-versions = "*" +version = "0.3.0" + +[[package]] +category = "dev" +description = "A platform independent file lock." +name = "filelock" +optional = false +python-versions = "*" +version = "3.0.12" + +[[package]] +category = "dev" +description = "File identification library for Python" +name = "identify" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" +version = "1.4.15" + +[package.extras] +license = ["editdistance"] + +[[package]] +category = "dev" +description = "Read metadata from Python packages" +marker = "python_version < \"3.8\"" +name = "importlib-metadata" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" +version = "1.6.0" + +[package.dependencies] +zipp = ">=0.5" + +[package.extras] +docs = ["sphinx", "rst.linker"] +testing = ["packaging", "importlib-resources"] + +[[package]] +category = "dev" +description = "Read resources from Python packages" +marker = "python_version < \"3.7\"" +name = "importlib-resources" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" +version = "1.5.0" + +[package.dependencies] +[package.dependencies.importlib-metadata] +python = "<3.8" +version = "*" + +[package.dependencies.zipp] +python = "<3.8" +version = ">=0.4" + +[package.extras] +docs = ["sphinx", "rst.linker", "jaraco.packaging"] + +[[package]] +category = "dev" +description = "More routines for operating on iterables, beyond itertools" +name = "more-itertools" +optional = false +python-versions = ">=3.5" +version = "8.2.0" + +[[package]] +category = "dev" +description = "Node.js virtual environment builder" +name = "nodeenv" +optional = false +python-versions = "*" +version = "1.3.5" + +[[package]] +category = "main" +description = "NumPy is the fundamental package for array computing with Python." +name = "numpy" +optional = false +python-versions = ">=3.5" +version = "1.18.4" + +[[package]] +category = "dev" +description = "Core utilities for Python packages" +name = "packaging" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +version = "20.3" + +[package.dependencies] +pyparsing = ">=2.0.2" +six = "*" + +[[package]] +category = "dev" +description = "plugin and hook calling mechanisms for python" +name = "pluggy" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +version = "0.13.1" + +[package.dependencies] +[package.dependencies.importlib-metadata] +python = "<3.8" +version = ">=0.12" + +[package.extras] +dev = ["pre-commit", "tox"] + +[[package]] +category = "dev" +description = "A framework for managing and maintaining multi-language pre-commit hooks." +name = "pre-commit" +optional = false +python-versions = ">=3.6.1" +version = "2.3.0" + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +toml = "*" +virtualenv = ">=15.2" + +[package.dependencies.importlib-metadata] +python = "<3.8" +version = "*" + +[package.dependencies.importlib-resources] +python = "<3.7" +version = "*" + +[[package]] +category = "dev" +description = "library with cross-python path, ini-parsing, io, code, log facilities" +name = "py" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +version = "1.8.1" + +[[package]] +category = "dev" +description = "Python parsing module" +name = "pyparsing" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" +version = "2.4.7" + +[[package]] +category = "dev" +description = "pytest: simple powerful testing with Python" +name = "pytest" +optional = false +python-versions = ">=3.5" +version = "5.4.1" + +[package.dependencies] +atomicwrites = ">=1.0" +attrs = ">=17.4.0" +colorama = "*" +more-itertools = ">=4.0.0" +packaging = "*" +pluggy = ">=0.12,<1.0" +py = ">=1.5.0" +wcwidth = "*" + +[package.dependencies.importlib-metadata] +python = "<3.8" +version = ">=0.12" + +[package.extras] +checkqa-mypy = ["mypy (v0.761)"] +testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] + +[[package]] +category = "dev" +description = "YAML parser and emitter for Python" +name = "pyyaml" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +version = "5.3.1" + +[[package]] +category = "dev" +description = "Python 2 and 3 compatibility utilities" +name = "six" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +version = "1.14.0" + +[[package]] +category = "dev" +description = "Python Library for Tom's Obvious, Minimal Language" +name = "toml" +optional = false +python-versions = "*" +version = "0.10.0" + +[[package]] +category = "dev" +description = "Virtual Python Environment builder" +name = "virtualenv" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" +version = "20.0.20" + +[package.dependencies] +appdirs = ">=1.4.3,<2" +distlib = ">=0.3.0,<1" +filelock = ">=3.0.0,<4" +six = ">=1.9.0,<2" + +[package.dependencies.importlib-metadata] +python = "<3.8" +version = ">=0.12,<2" + +[package.dependencies.importlib-resources] +python = "<3.7" +version = ">=1.0,<2" + +[package.extras] +docs = ["sphinx (>=3)", "sphinx-argparse (>=0.2.5)", "sphinx-rtd-theme (>=0.4.3)", "towncrier (>=19.9.0rc1)", "proselint (>=0.10.2)"] +testing = ["pytest (>=4)", "coverage (>=5)", "coverage-enable-subprocess (>=1)", "pytest-xdist (>=1.31.0)", "pytest-mock (>=2)", "pytest-env (>=0.6.2)", "pytest-randomly (>=1)", "pytest-timeout", "packaging (>=20.0)", "xonsh (>=0.9.16)"] + +[[package]] +category = "dev" +description = "Measures number of Terminal column cells of wide-character codes" +name = "wcwidth" +optional = false +python-versions = "*" +version = "0.1.9" + +[[package]] +category = "dev" +description = "Backport of pathlib-compatible object wrapper for zip files" +marker = "python_version < \"3.8\"" +name = "zipp" +optional = false +python-versions = ">=2.7" +version = "1.2.0" + +[package.extras] +docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] +testing = ["pathlib2", "unittest2", "jaraco.itertools", "func-timeout"] + +[metadata] +content-hash = "8f5b2ea1333a8f74014432148e08771d748e6f8fad277680cfffa24b213fb332" +python-versions = "^3.6.1" + +[metadata.files] +appdirs = [ + {file = "appdirs-1.4.3-py2.py3-none-any.whl", hash = "sha256:d8b24664561d0d34ddfaec54636d502d7cea6e29c3eaf68f3df6180863e2166e"}, + {file = "appdirs-1.4.3.tar.gz", hash = "sha256:9e5896d1372858f8dd3344faf4e5014d21849c756c8d5701f78f8a103b372d92"}, +] +atomicwrites = [ + {file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"}, + {file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, +] +attrs = [ + {file = "attrs-19.3.0-py2.py3-none-any.whl", hash = "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c"}, + {file = "attrs-19.3.0.tar.gz", hash = "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72"}, +] +cfgv = [ + {file = "cfgv-3.1.0-py2.py3-none-any.whl", hash = "sha256:1ccf53320421aeeb915275a196e23b3b8ae87dea8ac6698b1638001d4a486d53"}, + {file = "cfgv-3.1.0.tar.gz", hash = "sha256:c8e8f552ffcc6194f4e18dd4f68d9aef0c0d58ae7e7be8c82bee3c5e9edfa513"}, +] +colorama = [ + {file = "colorama-0.4.3-py2.py3-none-any.whl", hash = "sha256:7d73d2a99753107a36ac6b455ee49046802e59d9d076ef8e47b61499fa29afff"}, + {file = "colorama-0.4.3.tar.gz", hash = "sha256:e96da0d330793e2cb9485e9ddfd918d456036c7149416295932478192f4436a1"}, +] +distlib = [ + {file = "distlib-0.3.0.zip", hash = "sha256:2e166e231a26b36d6dfe35a48c4464346620f8645ed0ace01ee31822b288de21"}, +] +filelock = [ + {file = "filelock-3.0.12-py3-none-any.whl", hash = "sha256:929b7d63ec5b7d6b71b0fa5ac14e030b3f70b75747cef1b10da9b879fef15836"}, + {file = "filelock-3.0.12.tar.gz", hash = "sha256:18d82244ee114f543149c66a6e0c14e9c4f8a1044b5cdaadd0f82159d6a6ff59"}, +] +identify = [ + {file = "identify-1.4.15-py2.py3-none-any.whl", hash = "sha256:88ed90632023e52a6495749c6732e61e08ec9f4f04e95484a5c37b9caf40283c"}, + {file = "identify-1.4.15.tar.gz", hash = "sha256:23c18d97bb50e05be1a54917ee45cc61d57cb96aedc06aabb2b02331edf0dbf0"}, +] +importlib-metadata = [ + {file = "importlib_metadata-1.6.0-py2.py3-none-any.whl", hash = "sha256:2a688cbaa90e0cc587f1df48bdc97a6eadccdcd9c35fb3f976a09e3b5016d90f"}, + {file = "importlib_metadata-1.6.0.tar.gz", hash = "sha256:34513a8a0c4962bc66d35b359558fd8a5e10cd472d37aec5f66858addef32c1e"}, +] +importlib-resources = [ + {file = "importlib_resources-1.5.0-py2.py3-none-any.whl", hash = "sha256:85dc0b9b325ff78c8bef2e4ff42616094e16b98ebd5e3b50fe7e2f0bbcdcde49"}, + {file = "importlib_resources-1.5.0.tar.gz", hash = "sha256:6f87df66833e1942667108628ec48900e02a4ab4ad850e25fbf07cb17cf734ca"}, +] +more-itertools = [ + {file = "more-itertools-8.2.0.tar.gz", hash = "sha256:b1ddb932186d8a6ac451e1d95844b382f55e12686d51ca0c68b6f61f2ab7a507"}, + {file = "more_itertools-8.2.0-py3-none-any.whl", hash = "sha256:5dd8bcf33e5f9513ffa06d5ad33d78f31e1931ac9a18f33d37e77a180d393a7c"}, +] +nodeenv = [ + {file = "nodeenv-1.3.5-py2.py3-none-any.whl", hash = "sha256:5b2438f2e42af54ca968dd1b374d14a1194848955187b0e5e4be1f73813a5212"}, +] +numpy = [ + {file = "numpy-1.18.4-cp35-cp35m-macosx_10_9_intel.whl", hash = "sha256:efdba339fffb0e80fcc19524e4fdbda2e2b5772ea46720c44eaac28096d60720"}, + {file = "numpy-1.18.4-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:2b573fcf6f9863ce746e4ad00ac18a948978bb3781cffa4305134d31801f3e26"}, + {file = "numpy-1.18.4-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:3f0dae97e1126f529ebb66f3c63514a0f72a177b90d56e4bce8a0b5def34627a"}, + {file = "numpy-1.18.4-cp35-cp35m-win32.whl", hash = "sha256:dccd380d8e025c867ddcb2f84b439722cf1f23f3a319381eac45fd077dee7170"}, + {file = "numpy-1.18.4-cp35-cp35m-win_amd64.whl", hash = "sha256:02ec9582808c4e48be4e93cd629c855e644882faf704bc2bd6bbf58c08a2a897"}, + {file = "numpy-1.18.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:904b513ab8fbcbdb062bed1ce2f794ab20208a1b01ce9bd90776c6c7e7257032"}, + {file = "numpy-1.18.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:e22cd0f72fc931d6abc69dc7764484ee20c6a60b0d0fee9ce0426029b1c1bdae"}, + {file = "numpy-1.18.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:2466fbcf23711ebc5daa61d28ced319a6159b260a18839993d871096d66b93f7"}, + {file = "numpy-1.18.4-cp36-cp36m-win32.whl", hash = "sha256:00d7b54c025601e28f468953d065b9b121ddca7fff30bed7be082d3656dd798d"}, + {file = "numpy-1.18.4-cp36-cp36m-win_amd64.whl", hash = "sha256:7d59f21e43bbfd9a10953a7e26b35b6849d888fc5a331fa84a2d9c37bd9fe2a2"}, + {file = "numpy-1.18.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:efb7ac5572c9a57159cf92c508aad9f856f1cb8e8302d7fdb99061dbe52d712c"}, + {file = "numpy-1.18.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:0e6f72f7bb08f2f350ed4408bb7acdc0daba637e73bce9f5ea2b207039f3af88"}, + {file = "numpy-1.18.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:9933b81fecbe935e6a7dc89cbd2b99fea1bf362f2790daf9422a7bb1dc3c3085"}, + {file = "numpy-1.18.4-cp37-cp37m-win32.whl", hash = "sha256:96dd36f5cdde152fd6977d1bbc0f0561bccffecfde63cd397c8e6033eb66baba"}, + {file = "numpy-1.18.4-cp37-cp37m-win_amd64.whl", hash = "sha256:57aea170fb23b1fd54fa537359d90d383d9bf5937ee54ae8045a723caa5e0961"}, + {file = "numpy-1.18.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ed722aefb0ebffd10b32e67f48e8ac4c5c4cf5d3a785024fdf0e9eb17529cd9d"}, + {file = "numpy-1.18.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:50fb72bcbc2cf11e066579cb53c4ca8ac0227abb512b6cbc1faa02d1595a2a5d"}, + {file = "numpy-1.18.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:709c2999b6bd36cdaf85cf888d8512da7433529f14a3689d6e37ab5242e7add5"}, + {file = "numpy-1.18.4-cp38-cp38-win32.whl", hash = "sha256:f22273dd6a403ed870207b853a856ff6327d5cbce7a835dfa0645b3fc00273ec"}, + {file = "numpy-1.18.4-cp38-cp38-win_amd64.whl", hash = "sha256:1be2e96314a66f5f1ce7764274327fd4fb9da58584eaff00b5a5221edefee7d6"}, + {file = "numpy-1.18.4.zip", hash = "sha256:bbcc85aaf4cd84ba057decaead058f43191cc0e30d6bc5d44fe336dc3d3f4509"}, +] +packaging = [ + {file = "packaging-20.3-py2.py3-none-any.whl", hash = "sha256:82f77b9bee21c1bafbf35a84905d604d5d1223801d639cf3ed140bd651c08752"}, + {file = "packaging-20.3.tar.gz", hash = "sha256:3c292b474fda1671ec57d46d739d072bfd495a4f51ad01a055121d81e952b7a3"}, +] +pluggy = [ + {file = "pluggy-0.13.1-py2.py3-none-any.whl", hash = "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d"}, + {file = "pluggy-0.13.1.tar.gz", hash = "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0"}, +] +pre-commit = [ + {file = "pre_commit-2.3.0-py2.py3-none-any.whl", hash = "sha256:979b53dab1af35063a483bfe13b0fcbbf1a2cf8c46b60e0a9a8d08e8269647a1"}, + {file = "pre_commit-2.3.0.tar.gz", hash = "sha256:f3e85e68c6d1cbe7828d3471896f1b192cfcf1c4d83bf26e26beeb5941855257"}, +] +py = [ + {file = "py-1.8.1-py2.py3-none-any.whl", hash = "sha256:c20fdd83a5dbc0af9efd622bee9a5564e278f6380fffcacc43ba6f43db2813b0"}, + {file = "py-1.8.1.tar.gz", hash = "sha256:5e27081401262157467ad6e7f851b7aa402c5852dbcb3dae06768434de5752aa"}, +] +pyparsing = [ + {file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"}, + {file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"}, +] +pytest = [ + {file = "pytest-5.4.1-py3-none-any.whl", hash = "sha256:0e5b30f5cb04e887b91b1ee519fa3d89049595f428c1db76e73bd7f17b09b172"}, + {file = "pytest-5.4.1.tar.gz", hash = "sha256:84dde37075b8805f3d1f392cc47e38a0e59518fb46a431cfdaf7cf1ce805f970"}, +] +pyyaml = [ + {file = "PyYAML-5.3.1-cp27-cp27m-win32.whl", hash = "sha256:74809a57b329d6cc0fdccee6318f44b9b8649961fa73144a98735b0aaf029f1f"}, + {file = "PyYAML-5.3.1-cp27-cp27m-win_amd64.whl", hash = "sha256:240097ff019d7c70a4922b6869d8a86407758333f02203e0fc6ff79c5dcede76"}, + {file = "PyYAML-5.3.1-cp35-cp35m-win32.whl", hash = "sha256:4f4b913ca1a7319b33cfb1369e91e50354d6f07a135f3b901aca02aa95940bd2"}, + {file = "PyYAML-5.3.1-cp35-cp35m-win_amd64.whl", hash = "sha256:cc8955cfbfc7a115fa81d85284ee61147059a753344bc51098f3ccd69b0d7e0c"}, + {file = "PyYAML-5.3.1-cp36-cp36m-win32.whl", hash = "sha256:7739fc0fa8205b3ee8808aea45e968bc90082c10aef6ea95e855e10abf4a37b2"}, + {file = "PyYAML-5.3.1-cp36-cp36m-win_amd64.whl", hash = "sha256:69f00dca373f240f842b2931fb2c7e14ddbacd1397d57157a9b005a6a9942648"}, + {file = "PyYAML-5.3.1-cp37-cp37m-win32.whl", hash = "sha256:d13155f591e6fcc1ec3b30685d50bf0711574e2c0dfffd7644babf8b5102ca1a"}, + {file = "PyYAML-5.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:73f099454b799e05e5ab51423c7bcf361c58d3206fa7b0d555426b1f4d9a3eaf"}, + {file = "PyYAML-5.3.1-cp38-cp38-win32.whl", hash = "sha256:06a0d7ba600ce0b2d2fe2e78453a470b5a6e000a985dd4a4e54e436cc36b0e97"}, + {file = "PyYAML-5.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:95f71d2af0ff4227885f7a6605c37fd53d3a106fcab511b8860ecca9fcf400ee"}, + {file = "PyYAML-5.3.1.tar.gz", hash = "sha256:b8eac752c5e14d3eca0e6dd9199cd627518cb5ec06add0de9d32baeee6fe645d"}, +] +six = [ + {file = "six-1.14.0-py2.py3-none-any.whl", hash = "sha256:8f3cd2e254d8f793e7f3d6d9df77b92252b52637291d0f0da013c76ea2724b6c"}, + {file = "six-1.14.0.tar.gz", hash = "sha256:236bdbdce46e6e6a3d61a337c0f8b763ca1e8717c03b369e87a7ec7ce1319c0a"}, +] +toml = [ + {file = "toml-0.10.0-py2.7.egg", hash = "sha256:f1db651f9657708513243e61e6cc67d101a39bad662eaa9b5546f789338e07a3"}, + {file = "toml-0.10.0-py2.py3-none-any.whl", hash = "sha256:235682dd292d5899d361a811df37e04a8828a5b1da3115886b73cf81ebc9100e"}, + {file = "toml-0.10.0.tar.gz", hash = "sha256:229f81c57791a41d65e399fc06bf0848bab550a9dfd5ed66df18ce5f05e73d5c"}, +] +virtualenv = [ + {file = "virtualenv-20.0.20-py2.py3-none-any.whl", hash = "sha256:b4c14d4d73a0c23db267095383c4276ef60e161f94fde0427f2f21a0132dde74"}, + {file = "virtualenv-20.0.20.tar.gz", hash = "sha256:fd0e54dec8ac96c1c7c87daba85f0a59a7c37fe38748e154306ca21c73244637"}, +] +wcwidth = [ + {file = "wcwidth-0.1.9-py2.py3-none-any.whl", hash = "sha256:cafe2186b3c009a04067022ce1dcd79cb38d8d65ee4f4791b8888d6599d1bbe1"}, + {file = "wcwidth-0.1.9.tar.gz", hash = "sha256:ee73862862a156bf77ff92b09034fc4825dd3af9cf81bc5b360668d425f3c5f1"}, +] +zipp = [ + {file = "zipp-1.2.0-py2.py3-none-any.whl", hash = "sha256:e0d9e63797e483a30d27e09fffd308c59a700d365ec34e93cc100844168bf921"}, + {file = "zipp-1.2.0.tar.gz", hash = "sha256:c70410551488251b0fee67b460fb9a536af8d6f9f008ad10ac51f615b6a521b1"}, +] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..0cdd62e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,18 @@ +[tool.poetry] +name = "dmatrix2np" +version = "0.1.0" +description = "Convert XGBoost's DMatrix to numpy array" +authors = ["Aporia technologies ltd "] +license = "GPL-3.0" + +[tool.poetry.dependencies] +python = "^3.6.1" +numpy = "*" + +[tool.poetry.dev-dependencies] +pytest = "^5.4.1" +pre-commit = "^2.3.0" + +[build-system] +requires = ["poetry>=0.12"] +build-backend = "poetry.masonry.api" diff --git a/test/test_dmatrix2np.py b/test/test_dmatrix2np.py new file mode 100644 index 0000000..73c8ae9 --- /dev/null +++ b/test/test_dmatrix2np.py @@ -0,0 +1,100 @@ +import unittest +import numpy as np +import xgboost as xgb +from dmatrix2np import dmatrix_to_numpy, InvalidInput +from packaging import version + + +class TestDmatrix2Numpy(unittest.TestCase): + + def test_none(self): + with self.assertRaises(InvalidInput): + dmatrix_to_numpy(None) + + def test_simple_matrix(self): + ndarr = np.array([[1, 2], [3,4]]) + dmat = xgb.DMatrix(ndarr) + np.testing.assert_equal(dmatrix_to_numpy(dmat), ndarr) + + def test_empty_matrices(self): + ndarr = np.empty((0, 0)) + dmat = xgb.DMatrix(ndarr) + np.testing.assert_equal(dmatrix_to_numpy(dmat), ndarr) + + ndarr = np.empty((1, 0)) + dmat = xgb.DMatrix(ndarr) + np.testing.assert_equal(dmatrix_to_numpy(dmat), ndarr) + + ndarr = np.empty((0, 1)) + dmat = xgb.DMatrix(ndarr) + np.testing.assert_equal(dmatrix_to_numpy(dmat), ndarr) + + def test_nan_matrix(self): + ndarr = np.nan * np.empty((2, 2)) + dmat = xgb.DMatrix(ndarr) + np.testing.assert_equal(dmatrix_to_numpy(dmat), ndarr) + + def test_large_random_matrix(self): + ndarr = np.random.rand(1000, 1000).astype(np.float32) + dmat = xgb.DMatrix(ndarr) + np.testing.assert_equal(dmatrix_to_numpy(dmat), ndarr) + + def test_nan_row_matrix(self): + ndarr = np.random.rand(100, 100).astype(np.float32) + ndarr[10] = np.nan + dmat = xgb.DMatrix(ndarr) + np.testing.assert_equal(dmatrix_to_numpy(dmat), ndarr) + + def test_nan_col_matrix(self): + ndarr = np.random.rand(100, 100).astype(np.float32) + ndarr = ndarr.T + ndarr[10] = np.nan + ndarr = ndarr.T + dmat = xgb.DMatrix(ndarr) + np.testing.assert_equal(dmatrix_to_numpy(dmat), ndarr) + + def test_dmatrix_with_labels(self): + ndarr = np.random.rand(100, 100).astype(np.float32) + label = np.random.rand(100, 1).astype(np.float32) + dmat = xgb.DMatrix(ndarr, label=label) + np.testing.assert_equal(dmatrix_to_numpy(dmat), ndarr) + + def test_dmatrix_with_weight(self): + ndarr = np.random.rand(100, 100).astype(np.float32) + weight = np.random.rand(100, 1).astype(np.float32) + dmat = xgb.DMatrix(ndarr, weight=weight) + np.testing.assert_equal(dmatrix_to_numpy(dmat), ndarr) + + def test_dmatrix_with_base_margin(self): + if version.parse(xgb.__version__) < version.parse('1.0.0'): + return True + ndarr = np.random.rand(100, 100).astype(np.float32) + base_margin = np.random.rand(100, 1).astype(np.float32) + dmat = xgb.DMatrix(ndarr, base_margin=base_margin) + np.testing.assert_equal(dmatrix_to_numpy(dmat), ndarr) + + def test_complex_dmatrix(self): + ndarr = np.random.rand(100, 100).astype(np.float32) + ndarr[10] = np.nan + ndarr = ndarr.T + ndarr[10] = np.nan + ndarr = ndarr.T + label = np.random.rand(100, 1).astype(np.float32) + weight = np.random.rand(100, 1).astype(np.float32) + base_margin = np.random.rand(100, 1).astype(np.float32) + if version.parse(xgb.__version__) < version.parse('1.0.0'): + dmat = xgb.DMatrix(ndarr, label=label, weight=weight) + else: + dmat = xgb.DMatrix(ndarr, label=label, weight=weight, base_margin=base_margin) + np.testing.assert_equal(dmatrix_to_numpy(dmat), ndarr) + + def test_simple_vector(self): + vector = np.array([1, 2]) + with self.assertRaises(ValueError): + xgb.DMatrix(vector) + + def test_unsupported_shapes(self): + shapes = [0, (2), (2, 2, 2), (2, 2, 2, 2), (2, 2, 2, 2, 2)] + for shape in shapes: + with self.assertRaises(ValueError): + xgb.DMatrix(np.empty(shape)) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..321fff2 --- /dev/null +++ b/tox.ini @@ -0,0 +1,20 @@ +[tox] +isolated_build = true +envlist = py{36,37,38}-xgboost{08,09,10} + +[gh-actions] +python = + 3.6: py36 + 3.7: py37 + 3.8: py38 + +[testenv] +whitelist_externals = poetry +deps = + xgboost08: xgboost >=0.80, <0.90 + xgboost09: xgboost >=0.90, <1.0 + xgboost10: xgboost >=1.0, <1.1 + +commands = + poetry install --no-interaction --no-ansi + poetry run pytest