-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft: Tobi Dance #22
base: master
Are you sure you want to change the base?
Changes from 7 commits
20df0c1
9e7d4fa
8c60986
4c0bd92
1681fdd
4776ed4
9e72b37
3c1895f
0c6ea2a
8d5eddd
1ddc244
d343f31
91a611e
a6fda56
7f7659d
4d34dde
76d455d
f13330c
5155fe9
f63ab96
ab87fec
3dd5102
89f4ead
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,7 +54,7 @@ class ComputationLibrary: | |
gather: Callable[[TensorType, TensorType, int], TensorType] = None | ||
gather_last: Callable[[TensorType, TensorType], TensorType] = None | ||
arange: Callable[[Optional[NumericType], NumericType, Optional[NumericType]], TensorType] = None | ||
zeros: Callable[["tuple[int]"], TensorType] = None | ||
zeros: Callable[["tuple[int,...]"], TensorType] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very nice! I had wrong assumptions about how to type subscript tuples. Learned something again. |
||
zeros_like: Callable[[TensorType], TensorType] = None | ||
ones: Callable[["tuple[int]"], TensorType] = None | ||
sign: Callable[[TensorType], TensorType] = None | ||
|
@@ -88,6 +88,9 @@ class ComputationLibrary: | |
dot: Callable[[TensorType, TensorType], TensorType] = None | ||
stop_gradient: Callable[[TensorType], TensorType] = None | ||
assign: Callable[[Union[TensorType, tf.Variable], TensorType], Union[TensorType, tf.Variable]] = None | ||
nan:TensorType=None | ||
isnan:Callable[[TensorType],bool]=None | ||
string = None | ||
|
||
|
||
class NumpyLibrary(ComputationLibrary): | ||
|
@@ -158,14 +161,16 @@ class NumpyLibrary(ComputationLibrary): | |
dot = np.dot | ||
stop_gradient = lambda x: x | ||
assign = LibraryHelperFunctions.set_to_value | ||
|
||
nan = np.nan | ||
isnan=np.isnan | ||
string=str | ||
|
||
class TensorFlowLibrary(ComputationLibrary): | ||
lib = 'TF' | ||
reshape = tf.reshape | ||
permute = tf.transpose | ||
newaxis = tf.newaxis | ||
shape = lambda x: x.get_shape() # .as_list() | ||
shape = tf.shape # tobi does not understand reason for this previous definition: # lambda x: x.get_shape() # .as_list() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No difference: If Using |
||
to_numpy = lambda x: x.numpy() | ||
to_variable = lambda x, dtype: tf.Variable(x, dtype=dtype) | ||
to_tensor = lambda x, dtype: tf.convert_to_tensor(x, dtype=dtype) | ||
|
@@ -228,6 +233,9 @@ class TensorFlowLibrary(ComputationLibrary): | |
dot = lambda a, b: tf.tensordot(a, b, 1) | ||
stop_gradient = tf.stop_gradient | ||
assign = LibraryHelperFunctions.set_to_variable | ||
nan=tf.constant(np.nan) | ||
isnan=tf.math.is_nan | ||
string=tf.string | ||
|
||
|
||
class PyTorchLibrary(ComputationLibrary): | ||
|
@@ -307,3 +315,6 @@ def gather_last_pytorch(a, index_vector): | |
dot = torch.dot | ||
stop_gradient = tf.stop_gradient # FIXME: How to imlement this in torch? | ||
assign = LibraryHelperFunctions.set_to_value | ||
nan=torch.nan | ||
isnan=torch.isnan | ||
string=lambda x: torch.ByteTensor(bytes(x,'utf8')) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.