Huge bugfix: return only desired base dtypes in data ops -> all ops
tostenzel committed Jan 4, 2024
1 parent cf08086 commit 7c3785b
Showing 2 changed files with 20 additions and 20 deletions.
13 changes: 6 additions & 7 deletions applications/learn_mnist.py
@@ -13,11 +13,11 @@ def parse(file):

# parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
BASE = os.path.dirname(__file__) + "/datasets"

X_train = parse(BASE + "/mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28 * 28)).astype(np.float32)
-Y_train = parse(BASE + "/mnist/train-labels-idx1-ubyte.gz")[8:]
+Y_train = parse(BASE + "/mnist/train-labels-idx1-ubyte.gz")[8:].astype(np.int32)
X_test = parse(BASE + "/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28 * 28)).astype(np.float32)
-Y_test = parse(BASE + "/mnist/t10k-labels-idx1-ubyte.gz")[8:]
+Y_test = parse(BASE + "/mnist/t10k-labels-idx1-ubyte.gz")[8:].astype(np.int32)
if for_convolution:
    X_train = X_train.reshape(-1, 1, 28, 28)
    X_test = X_test.reshape(-1, 1, 28, 28)
@@ -47,9 +47,9 @@ def train_and_evaluate_mnist(num_steps=100, batch_size=128, learning_rate=0.001)

with Tensor.train():
    for step in range(num_steps):
-        samp = np.random.randint(0, X_train.shape[0], size=(batch_size))
+        samp = np.random.randint(0, X_train.shape[0], size=(batch_size)).astype(np.int32)
        xb, yb = Tensor(X_train[samp], requires_grad=False), Tensor(Y_train[samp])

        out = model(xb)
        loss = out.sparse_categorical_crossentropy(yb)
        opt.zero_grad()
@@ -72,8 +72,7 @@ def train_and_evaluate_mnist(num_steps=100, batch_size=128, learning_rate=0.001)
    return test_accuracy


-
if __name__ == "__main__":
    # Only execute if this script is run directly
    test_accuracy = train_and_evaluate_mnist()
-    print(f"Test Acc: {test_accuracy:.3f}")
\ No newline at end of file
+    print(f"Test Acc: {test_accuracy:.3f}")
27 changes: 14 additions & 13 deletions edugrad/data.py
@@ -37,7 +37,7 @@ def shape(self) -> Tuple[int, ...]:
    def __repr__(self) -> str:
        """Return a string representation of the TensorData object."""
        return f"<TensorData shape={self.shape} dtype={self.dtype}>"

    def __call__(self):
        return self.data

@@ -80,22 +80,22 @@ def cast(self, dtype: DType, bitcast: bool = False) -> "TensorData":
    def elementwise(self, op, *srcs: "TensorData"):
        """Perform a unary, binary, or ternary elementwise operation on the data."""
        unary_ops = {
-            UnaryOps.NEG: np.negative,
-            UnaryOps.EXP2: np.exp2,
-            UnaryOps.LOG2: np.log2,
-            UnaryOps.SIN: np.sin,
-            UnaryOps.SQRT: np.sqrt,
+            UnaryOps.NEG: lambda x: np.negative(x).astype(np.float32),
+            UnaryOps.EXP2: lambda x: np.exp2(x).astype(np.float32),
+            UnaryOps.LOG2: lambda x: np.log2(x).astype(np.float32),
+            UnaryOps.SIN: lambda x: np.sin(x).astype(np.float32),
+            UnaryOps.SQRT: lambda x: np.sqrt(x).astype(np.float32),
        }
        binary_ops = {
-            BinaryOps.ADD: np.add,
-            BinaryOps.SUB: np.subtract,
-            BinaryOps.MUL: np.multiply,
-            BinaryOps.DIV: np.divide,
-            BinaryOps.MAX: np.maximum,
-            BinaryOps.CMPLT: np.less,
+            BinaryOps.ADD: lambda x, y: np.add(x, y).astype(np.float32),
+            BinaryOps.SUB: lambda x, y: np.subtract(x, y).astype(np.float32),
+            BinaryOps.MUL: lambda x, y: np.multiply(x, y).astype(np.float32),
+            BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(np.float32),
+            BinaryOps.MAX: lambda x, y: np.maximum(x, y).astype(np.float32),
+            BinaryOps.CMPLT: lambda x, y: np.less(x, y).astype(np.bool_),
        }
        ternary_ops = {
-            TernaryOps.WHERE: np.where,
+            TernaryOps.WHERE: lambda x, y, z: np.where(x, y, z).astype(np.float32),
        }

        if op in unary_ops:
@@ -107,6 +107,7 @@ def elementwise(self, op, *srcs: "TensorData"):
        else:
            raise NotImplementedError(f"Operation {op} not implemented or wrong number of sources")

+
    def reduce(self, op, new_shape):
        """Perform reduction operations on the data."""
        if DEBUG >= 1:
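
Why the ufuncs had to be wrapped: bare NumPy ufuncs let the result dtype drift with their inputs, so a pipeline meant to stay in float32 can silently promote to float64 or stay integer. A minimal sketch of the behavior the lambdas guard against (array values are illustrative):

import numpy as np

x32 = np.ones(3, dtype=np.float32)
i32 = np.arange(3, dtype=np.int32)

# Bare ufuncs: the result dtype follows the inputs.
print(np.add(x32, np.ones(3)).dtype)  # float64 -- np.ones defaults to float64
print(np.sqrt(i32).dtype)             # float64 -- integer input promotes
print(np.negative(i32).dtype)         # int32   -- stays integer

# Wrapped as in this commit: every result is pinned to a base dtype.
add = lambda x, y: np.add(x, y).astype(np.float32)
cmplt = lambda x, y: np.less(x, y).astype(np.bool_)
print(add(x32, np.ones(3)).dtype)     # float32
print(cmplt(x32, np.ones(3)).dtype)   # bool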
