Skip to content

Commit

Permalink
A tiny fix in PyDataProvider2
Browse files Browse the repository at this point in the history
* hidden decorator kwargs in DataProvider.__init__
* also add unit test for this.
  • Loading branch information
reyoung committed Dec 21, 2016
1 parent 2965df5 commit 4d81b36
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
2 changes: 1 addition & 1 deletion paddle/gserver/tests/test_PyDataProvider2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from paddle.trainer.PyDataProvider2 import *


@provider(input_types=[dense_vector(200, seq_type=SequenceType.NO_SEQUENCE)])
@provider(slots=[dense_vector(200, seq_type=SequenceType.NO_SEQUENCE)])
def test_dense_no_seq(setting, filename):
for i in xrange(200):
yield [(float(j - 100) * float(i + 1)) / 200.0 for j in xrange(200)]
Expand Down
19 changes: 12 additions & 7 deletions python/paddle/trainer/PyDataProvider2.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def provider(input_types=None,
check=False,
check_fail_continue=False,
init_hook=None,
**kwargs):
**outter_kwargs):
"""
Provider decorator. Use it to make a function into PyDataProvider2 object.
In this function, user only need to get each sample for some train/test
Expand Down Expand Up @@ -318,11 +318,6 @@ def __init__(self, file_list, **kwargs):
self.logger = logging.getLogger("")
self.logger.setLevel(logging.INFO)
self.input_types = None
if 'slots' in kwargs:
self.logger.warning('setting slots value is deprecated, '
'please use input_types instead.')
self.slots = kwargs['slots']
self.slots = input_types
self.should_shuffle = should_shuffle

true_table = [1, 't', 'true', 'on']
Expand Down Expand Up @@ -358,9 +353,19 @@ def __init__(self, file_list, **kwargs):
self.check = check
if init_hook is not None:
init_hook(self, file_list=file_list, **kwargs)

if 'slots' in outter_kwargs:
self.logger.warning('setting slots value is deprecated, '
'please use input_types instead.')
self.slots = outter_kwargs['slots']
if input_types is not None:
self.slots = input_types

if self.input_types is not None:
self.slots = self.input_types
assert self.slots is not None

assert self.slots is not None, \
"Data Provider's input_types must be set"
assert self.generator is not None

use_dynamic_order = False
Expand Down

0 comments on commit 4d81b36

Please sign in to comment.