-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* init poetry package * create an empty __init__ * add tox and linter * lots of devops * remove x86 arch * update release yml * install poetry * separate dmatrix2np function to a different file * readme * readme * readme * Update README.md * rename function to dmatrix_to_numpy * Add support in v1.0.2 parsing * Support version 0.9.0 * Add versions support + minor fixes * Add tests * CR + Test fixes * CR Fixes * Add docstrings * Complete docstring * CR Fixes * Update git ignore file Co-authored-by: Alon Gubkin <[email protected]>
- Loading branch information
1 parent
237c8f1
commit c226f0c
Showing
18 changed files
with
980 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[flake8] | ||
max-line-length = 120 | ||
ignore = F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
name: Release | ||
on: | ||
release: | ||
types: [published] | ||
jobs: | ||
release: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v2 | ||
- uses: actions/setup-python@v1 | ||
with: | ||
python-version: '3.7' | ||
architecture: x64 | ||
- run: pip install tox poetry | ||
- run: tox | ||
- run: poetry build | ||
- run: poetry publish --username=__token__ --password=${{ secrets.PYPI_TOKEN }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
name: Test | ||
on: [push] | ||
jobs: | ||
build: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
max-parallel: 8 | ||
matrix: | ||
python-version: [3.6, 3.7, 3.8] | ||
steps: | ||
- uses: actions/checkout@v1 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v1 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install tox tox-gh-actions poetry | ||
- name: Test with tox | ||
run: tox | ||
env: | ||
PLATFORM: ${{ matrix.platform }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -127,3 +127,9 @@ dmypy.json | |
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
# VSCode | ||
.vscode/ | ||
|
||
# InteliJ | ||
.idea/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
repos: | ||
- repo: git://github.com/pre-commit/pre-commit-hooks | ||
rev: v2.1.0 | ||
hooks: | ||
- id: end-of-file-fixer | ||
exclude: ^docs/.*$ | ||
- id: trailing-whitespace | ||
exclude: README.md | ||
- id: flake8 | ||
|
||
- repo: https://github.com/pre-commit/mirrors-yapf | ||
rev: v0.29.0 | ||
hooks: | ||
- id: yapf | ||
args: ['--style=.style.yapf', '--parallel', '--in-place'] | ||
|
||
- repo: [email protected]:humitos/mirrors-autoflake.git | ||
rev: v1.3 | ||
hooks: | ||
- id: autoflake | ||
args: ['--in-place', '--remove-all-unused-imports', '--ignore-init-module-imports'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
[style] | ||
indent_width = 4 | ||
column_limit = 110 | ||
allow_split_before_dict_value = false | ||
split_before_expression_after_opening_paren = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,32 @@ | ||
# dmatrix2np | ||
# dmatrix2np | ||
|
||
[![Tests](https://github.com/aporia-ai/dmatrix2np/workflows/Test/badge.svg)](https://github.com/aporia-ai/dmatrix2np/actions?workflow=Test) [![PyPI](https://img.shields.io/pypi/v/dmatrix2np.svg)](https://pypi.org/project/dmatrix2np/) | ||
|
||
Convert XGBoost's DMatrix format to np.array. | ||
|
||
## Usage | ||
|
||
To install the library, run: | ||
|
||
pip install dmatrix2np | ||
|
||
Then, you can call in your code: | ||
|
||
from dmatrix2np import dmatrix2np | ||
|
||
converted_np_array = dmatrix_to_numpy(dmatrix) | ||
|
||
## Development | ||
|
||
We use [poetry](https://python-poetry.org/) for development: | ||
|
||
pip install poetry | ||
|
||
To install all dependencies and run tests: | ||
|
||
poetry run pytest | ||
|
||
To run tests on the entire matrix (Python 3.6, 3.7, 3.8 + XGBoost 0.80, 0.90, 1.0): | ||
|
||
pip install tox | ||
tox |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from .dmatrix_to_numpy import dmatrix_to_numpy | ||
from .exceptions import DMatrix2NpError, InvalidStructure, UnsupportedVersion, InvalidInput | ||
|
||
# Single source the package version | ||
try: | ||
from importlib.metadata import version, PackageNotFoundError # type: ignore | ||
except ImportError: # pragma: no cover | ||
from importlib_metadata import version, PackageNotFoundError # type: ignore | ||
try: | ||
__version__ = version(__name__) | ||
except PackageNotFoundError: # pragma: no cover | ||
__version__ = "unknown" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from enum import Enum | ||
import struct | ||
import sys | ||
|
||
|
||
class FieldDataType(Enum): | ||
""" | ||
This Enum provides an integer translation for the data type corresponding to the 'DataType' enum | ||
on '/include/xgboost/data.h' file in the XGBoost project | ||
""" | ||
kFloat32 = 1 | ||
kDouble = 2 | ||
kUInt32 = 3 | ||
kUInt64 = 4 | ||
|
||
|
||
# Dictionary of data types size in bytes | ||
data_type_sizes = { | ||
'kFloat32': 4, | ||
'kDouble': struct.calcsize("d"), | ||
'kUInt32': 4, | ||
'kUInt64': 8, | ||
'bool': 1, | ||
'uint8_t': 1, | ||
'int32_t': 4, | ||
'uint32_t': 4, | ||
'uint64_t': 8, | ||
'int': struct.calcsize("i"), | ||
'float': struct.calcsize("f"), | ||
'double': struct.calcsize("d"), | ||
'size_t': struct.calcsize("N"), | ||
} | ||
|
||
|
||
SIZE_T_DTYPE = f'{"<" if sys.byteorder == "little" else ">"}i{data_type_sizes["size_t"]}' | ||
VERSION_STRUCT = struct.Struct('iii') | ||
SIMPLE_VERSION_STRUCT = struct.Struct('I') | ||
VECTOR_SIZE_STRUCT = struct.Struct('=Q') | ||
FIELD_TYPE_STRUCT = struct.Struct('b') | ||
FLAG_STRUCT = struct.Struct('?') | ||
KMAGIC_STRUCT = struct.Struct('I') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import abc | ||
import numpy as np | ||
from struct import Struct | ||
from .common import SIZE_T_DTYPE, data_type_sizes, VECTOR_SIZE_STRUCT | ||
|
||
|
||
class DMatrixStreamParser(metaclass=abc.ABCMeta): | ||
"""Abstract base class for DMatrix stream parser.""" | ||
|
||
def __init__(self, buffer_reader, num_row, num_col): | ||
self._handle = buffer_reader | ||
self.num_row = num_row | ||
self.num_col = num_col | ||
self._offset_vector = [] | ||
self._data_vector = {} | ||
|
||
@abc.abstractmethod | ||
def parse(self) -> np.ndarray: | ||
"""Parse DMatrix to numpy 2d array | ||
Returns | ||
------- | ||
np.ndarray | ||
DMatrix values in numpy 2d array | ||
""" | ||
pass | ||
|
||
def _read_struct(self, s: Struct): | ||
return s.unpack(self._handle.read(s.size)) | ||
|
||
def _parse_offset_vector(self): | ||
offset_vector_size, = self._read_struct(VECTOR_SIZE_STRUCT) | ||
self._offset_vector = np.frombuffer(buffer=self._handle.read(offset_vector_size * data_type_sizes['size_t']), | ||
dtype=SIZE_T_DTYPE) | ||
|
||
def _parse_data_vector(self): | ||
data_vector_size, = self._read_struct(VECTOR_SIZE_STRUCT) | ||
data_vector_entry_size = data_type_sizes['uint32_t'] + data_type_sizes['float'] | ||
self._data_vector = np.frombuffer(buffer=self._handle.read(data_vector_size * data_vector_entry_size), | ||
dtype=np.dtype([('column_index', 'i4'), ('data', 'float32')])) | ||
|
||
def _get_nparray(self) -> np.ndarray: | ||
"""Generate 2d numpy array | ||
Returns | ||
------- | ||
np.ndarray | ||
dmatrix converted to 2d numpy array | ||
""" | ||
# When the matrix is flat, there are no values and matrix could be generated immediately | ||
if self.num_row == 0 or self.num_col == 0: | ||
return np.empty((self.num_row, self.num_col)) | ||
|
||
# Create flat matrix filled with nan values | ||
matrix = np.nan * np.empty((self.num_row * self.num_col)) | ||
|
||
# The offset vector contains the offsets of the values in the data vector for each row according to its index | ||
# We create size vector that contains the size of each row according to its index | ||
size_vector = self._offset_vector[1:] - self._offset_vector[:-1] | ||
|
||
# Since we work with flat matrix we want to convert 2d index (x, y) to 1d index (x*num_col + y) | ||
# The data vector only keep column index and data, | ||
# increase vector is the addition needed for the column index to be 1d index | ||
increase_vector = np.repeat(np.arange(0, size_vector.size * self.num_col, self.num_col), size_vector) | ||
flat_indexes = self._data_vector['column_index'] + increase_vector | ||
|
||
# Values assignment | ||
matrix[flat_indexes] = self._data_vector['data'] | ||
return matrix.reshape((self.num_row, self.num_col)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import tempfile | ||
import os | ||
import xgboost as xgb | ||
import numpy as np | ||
from .dmatrix_v_1_0_0_stream_parser import DMatrixStreamParserV1_0_0 | ||
from .dmatrix_v_0_80_stream_parser import DMatrixStreamParserV0_80 | ||
from .exceptions import InvalidInput | ||
from packaging import version | ||
from contextlib import suppress | ||
|
||
|
||
def dmatrix_to_numpy(dmatrix: xgb.DMatrix) -> np.ndarray: | ||
"""Convert DMatrix to 2d numpy array | ||
Parameters | ||
---------- | ||
dmatrix : xgb.DMatrix | ||
DMatrix to convert | ||
Returns | ||
------- | ||
np.ndarray | ||
2d numpy array with the corresponding DMatrix feature values | ||
Raises | ||
------ | ||
InvalidInput | ||
Input is not a valid DMatrix | ||
""" | ||
if not isinstance(dmatrix, xgb.DMatrix): | ||
raise InvalidInput("Type error: input parameter is not DMatrix") | ||
|
||
stream_parser = DMatrixStreamParserV0_80 if version.parse(xgb.__version__) < version.parse('1.0.0') \ | ||
else DMatrixStreamParserV1_0_0 | ||
|
||
# We set delete=False to avoid permissions error. This way, file can be accessed | ||
# by XGBoost without being deleted while handle is closed | ||
try: | ||
with tempfile.NamedTemporaryFile(delete=False) as fp: | ||
dmatrix.save_binary(fp.name) | ||
result = stream_parser(fp, dmatrix.num_row(), dmatrix.num_col()).parse() | ||
finally: | ||
# We can safely remove the temp file now, parsing process finished | ||
with suppress(OSError): | ||
os.remove(fp.name) | ||
|
||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from .dmatrix_stream_parser import DMatrixStreamParser | ||
from .exceptions import InvalidStructure | ||
from .common import data_type_sizes, SIMPLE_VERSION_STRUCT, VECTOR_SIZE_STRUCT, KMAGIC_STRUCT | ||
from os import SEEK_CUR | ||
|
||
|
||
class DMatrixStreamParserV0_80(DMatrixStreamParser): | ||
|
||
NUM_OF_SCALAR_FIELDS = 3 | ||
kMagic = 0xffffab01 | ||
kVersion = 2 | ||
|
||
def __init__(self, buffer_reader, num_row, num_col): | ||
self._handle = buffer_reader | ||
self.num_row = num_row | ||
self.num_col = num_col | ||
|
||
def parse(self): | ||
self._handle.seek(0) | ||
self._parse_magic() | ||
self._parse_version() | ||
self._skip_fields() | ||
self._parse_offset_vector() | ||
self._parse_data_vector() | ||
return self._get_nparray() | ||
|
||
def _parse_magic(self): | ||
kMagic, = self._read_struct(KMAGIC_STRUCT) | ||
if kMagic != self.kMagic: | ||
raise InvalidStructure('Invalid magic') | ||
|
||
def _parse_version(self): | ||
version, = self._read_struct(SIMPLE_VERSION_STRUCT) | ||
if version != self.kVersion: | ||
raise InvalidStructure('Invalid version') | ||
|
||
def _skip_fields(self): | ||
# Skip num_row_, num_col_, num_nonzero_ (all uint64_t) | ||
self._handle.seek(self.NUM_OF_SCALAR_FIELDS * data_type_sizes['uint64_t'], SEEK_CUR) | ||
|
||
# skip info's vector fields (labels_, group_ptr_, qids_, weights_, root_index_, base_margin_) | ||
vectors_entry_sizes = [ | ||
data_type_sizes['float'], # labels_ | ||
data_type_sizes['uint32_t'], # group_ptr_ | ||
data_type_sizes['uint64_t'], # qids_ | ||
data_type_sizes['float'], # weights_ | ||
data_type_sizes['uint32_t'], # root_index_ | ||
data_type_sizes['float'], # base_margin_ | ||
] | ||
|
||
# Each vector field starts with uint64_t size indicator | ||
# followed by number of vector entries equals to indicated size | ||
for vector_entry_size in vectors_entry_sizes: | ||
vector_size, = self._read_struct(VECTOR_SIZE_STRUCT) | ||
self._handle.read(vector_size * vector_entry_size) |
Oops, something went wrong.