-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from PySport/datasets
Add datasets loader and bump version
- Loading branch information
Showing
17 changed files
with
203 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from kloppy import datasets, to_pandas | ||
|
||
|
||
def main(): | ||
""" | ||
This example shows the use of Metrica datasets, and how we can pass argument | ||
to the dataset loader. | ||
""" | ||
|
||
# The metrica dataset loader loads by default the 'game1' dataset | ||
data_set = datasets.load("metrica_tracking", options={'sample_rate': 1./12, 'limit': 10}) | ||
print(len(data_set.frames)) | ||
|
||
# We can pass additional keyword arguments to the loaders to specify a different dataset | ||
data_set = datasets.load("metrica_tracking", options={'limit': 1000}, game='game2') | ||
|
||
data_frame = to_pandas(data_set) | ||
print(data_frame) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .infra.serializers import * | ||
from .helpers import * | ||
from .infra import datasets |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# import for registration | ||
from . import tracking | ||
|
||
from .core.loading import load | ||
|
||
__all__ = [ | ||
'load' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .builder import DatasetBuilder |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from abc import abstractmethod | ||
from typing import Dict, Type, Union | ||
|
||
from ...serializers.tracking import TrackingDataSerializer | ||
from .registered import RegisteredDataset | ||
|
||
|
||
class DatasetBuilder(metaclass=RegisteredDataset): | ||
@abstractmethod | ||
def get_data_set_files(self, **kwargs) -> Dict[str, Dict[str, str]]: | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def get_serializer_cls(self) -> Union[Type[TrackingDataSerializer]]: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import os | ||
|
||
import requests | ||
|
||
from typing import Dict, Union | ||
|
||
from kloppy.domain import DataSet, TrackingDataSet | ||
|
||
from .registered import _DATASET_REGISTRY | ||
|
||
|
||
def download_file(url, local_filename): | ||
with requests.get(url, stream=True) as r: | ||
r.raise_for_status() | ||
with open(local_filename, 'wb') as f: | ||
for chunk in r.iter_content(chunk_size=8192): | ||
f.write(chunk) | ||
|
||
|
||
def get_local_files(data_set_name: str, files: Dict[str, str]) -> Dict[str, str]: | ||
datasets_base_dir = os.environ.get('KLOPPY_BASE_DIR', None) | ||
if not datasets_base_dir: | ||
datasets_base_dir = os.path.expanduser('~/kloppy_datasets') | ||
|
||
dataset_base_dir = f'{datasets_base_dir}/{data_set_name}' | ||
if not os.path.exists(dataset_base_dir): | ||
os.makedirs(dataset_base_dir) | ||
|
||
local_files = {} | ||
for file_key, file_url in files.items(): | ||
filename = file_url.split('/')[-1] | ||
local_filename = f'{dataset_base_dir}/{filename}' | ||
if not os.path.exists(local_filename): | ||
print(f'Downloading {filename}...') | ||
download_file(file_url, local_filename) | ||
print('Done') | ||
local_files[file_key] = local_filename | ||
return local_files | ||
|
||
|
||
def load(data_set_name: str, options=None, **dataset_kwargs) -> Union[TrackingDataSet]: | ||
if data_set_name not in _DATASET_REGISTRY: | ||
raise ValueError(f"Dataset {data_set_name} not found") | ||
|
||
builder_cls = _DATASET_REGISTRY[data_set_name] | ||
builder = builder_cls() | ||
|
||
dataset_remote_files = builder.get_data_set_files(**dataset_kwargs) | ||
dataset_local_files = get_local_files(data_set_name, dataset_remote_files) | ||
|
||
file_handlers = { | ||
local_file_key: open(local_file_name, 'rb') | ||
for local_file_key, local_file_name | ||
in dataset_local_files.items() | ||
} | ||
|
||
try: | ||
serializer_cls = builder.get_serializer_cls() | ||
serializer = serializer_cls() | ||
data_set = serializer.deserialize( | ||
inputs=file_handlers, | ||
options=options | ||
) | ||
finally: | ||
for fp in file_handlers.values(): | ||
fp.close() | ||
return data_set |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import inspect | ||
import re | ||
import abc | ||
from typing import Type, Dict | ||
|
||
|
||
_first_cap_re = re.compile("(.)([A-Z][a-z0-9]+)") | ||
_all_cap_re = re.compile("([a-z0-9])([A-Z])") | ||
|
||
# from .builder import DatasetBuilder | ||
_DATASET_REGISTRY: Dict[str, Type['DatasetBuilder']] = {} | ||
|
||
|
||
def camelcase_to_snakecase(name): | ||
"""Convert camel-case string to snake-case.""" | ||
s1 = _first_cap_re.sub(r"\1_\2", name) | ||
return _all_cap_re.sub(r"\1_\2", s1).lower() | ||
|
||
|
||
class RegisteredDataset(abc.ABCMeta): | ||
def __new__(mcs, cls_name, bases, class_dict): | ||
name = camelcase_to_snakecase(cls_name) | ||
class_dict["name"] = name | ||
builder_cls = super(RegisteredDataset, mcs).__new__(mcs, cls_name, bases, class_dict) | ||
if not inspect.isabstract(builder_cls): | ||
_DATASET_REGISTRY[name] = builder_cls | ||
return builder_cls |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .metrica import MetricaTracking |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from typing import Dict, Type | ||
|
||
from ..core.builder import DatasetBuilder | ||
from ...serializers.tracking import TrackingDataSerializer, MetricaTrackingSerializer | ||
|
||
|
||
_DATASET_URLS = { | ||
'game1': { | ||
'raw_data_home': 'https://raw.githubusercontent.com/metrica-sports/sample-data/master/data/Sample_Game_1/Sample_Game_1_RawTrackingData_Home_Team.csv', | ||
'raw_data_away': 'https://raw.githubusercontent.com/metrica-sports/sample-data/master/data/Sample_Game_1/Sample_Game_1_RawTrackingData_Away_Team.csv' | ||
}, | ||
'game2': { | ||
'raw_data_home': 'https://raw.githubusercontent.com/metrica-sports/sample-data/master/data/Sample_Game_2/Sample_Game_2_RawTrackingData_Home_Team.csv', | ||
'raw_data_away': 'https://raw.githubusercontent.com/metrica-sports/sample-data/master/data/Sample_Game_2/Sample_Game_2_RawTrackingData_Away_Team.csv' | ||
} | ||
} | ||
|
||
|
||
class MetricaTracking(DatasetBuilder): | ||
def get_data_set_files(self,**kwargs) -> Dict[str, str]: | ||
game = kwargs.get('game', 'game1') | ||
return _DATASET_URLS[game] | ||
|
||
def get_serializer_cls(self) -> Type[TrackingDataSerializer]: | ||
return MetricaTrackingSerializer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
|
||
setup( | ||
name='kloppy', | ||
version='0.2.1', | ||
version='0.3.0', | ||
author='Koen Vossen', | ||
author_email='[email protected]', | ||
url="https://github.com/PySport/kloppy", | ||
|
@@ -26,7 +26,8 @@ | |
"Topic :: Scientific/Engineering" | ||
], | ||
install_requires=[ | ||
'lxml>=4.5.0' | ||
'lxml>=4.5.0', | ||
'requests>=2.0.0' | ||
], | ||
extras_require={ | ||
'test': [ | ||
|