CLANG=1 -> CPU=1 [pr]
chenyuxyz committed Feb 20, 2025
1 parent 3e22747 commit eebe660
Showing 48 changed files with 138 additions and 136 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/benchmark.yml
@@ -64,7 +64,7 @@ jobs:
run: METAL=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
- name: Test AMX tensor cores
run: |
- DEBUG=2 CLANG=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
+ DEBUG=2 CPU=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
DEBUG=2 LLVM=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
- name: Run Tensor Core GEMM (float)
run: DEBUG=2 python3.11 extra/gemm/simple_matmul.py | tee matmul.txt
32 changes: 16 additions & 16 deletions .github/workflows/test.yml
@@ -81,7 +81,7 @@ jobs:
run: DEBUG=100 python3 -c "from tinygrad import Tensor; N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N); c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2); print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
- name: Compile EfficientNet to C and test it
run: |
- CLANG=1 PYTHONPATH="." python examples/compile_efficientnet.py > recognize.c
+ CPU=1 PYTHONPATH="." python examples/compile_efficientnet.py > recognize.c
clang -O2 recognize.c -lm -o recognize
cat test/models/efficientnet/Chicken.jpg | ./recognize | grep cock
@@ -355,13 +355,13 @@ jobs:
llvm: 'true'
- name: Test ONNX (GPU)
run: GPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
- - name: Test ONNX (CLANG)
-   run: CLANG=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
+ - name: Test ONNX (CPU)
+   run: CPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
- name: Test ONNX (LLVM)
run: LLVM=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
- name: Run CLOUD=1 Test
run: |
- CLOUDDEV=CLANG CLOUD=1 python3 test/test_tiny.py
+ CLOUDDEV=CPU CLOUD=1 python3 test/test_tiny.py
CLOUDDEV=GPU CLOUD=1 python3 test/test_tiny.py
CLOUDDEV=GPU IMAGE=2 CLOUD=1 python3 test/test_tiny.py
- name: Test Optimization Helpers
@@ -378,7 +378,7 @@
uses: ./.github/actions/process-replay

testmodels:
- name: Models (llvm+clang+gpu)
+ name: Models (llvm+cpu+gpu)
runs-on: ubuntu-22.04
timeout-minutes: 10
steps:
@@ -395,8 +395,8 @@
run: LLVM=1 python -m pytest -n=auto test/models --durations=20
- name: Test models (gpu)
run: GPU=1 python -m pytest -n=auto test/models --durations=20
- - name: Test models (clang)
-   run: CLANG=1 python -m pytest -n=auto test/models --durations=20
+ - name: Test models (cpu)
+   run: CPU=1 python -m pytest -n=auto test/models --durations=20
- name: Run process replay tests
uses: ./.github/actions/process-replay

@@ -431,8 +431,8 @@ jobs:
run: PYTHONPATH="." DEBUG=2 DSP=1 python3 test/test_quantize_onnx.py
- name: Test LLVM=1 DEVECTORIZE=0
run: LLVM=1 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
- - name: Test CLANG=1 DEVECTORIZE=0
-   run: CLANG=1 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
+ - name: Test CPU=1 DEVECTORIZE=0
+   run: CPU=1 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"

testwebgpu:
name: Linux (WebGPU)
@@ -464,7 +464,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- backend: [llvm, clang, gpu, ptx, amd, nv] #, triton]
+ backend: [llvm, cpu, gpu, ptx, amd, nv] #, triton]

name: Linux (${{ matrix.backend }})
runs-on: ubuntu-22.04
@@ -482,10 +482,10 @@
amd: ${{ matrix.backend == 'amd' && 'true' }}
cuda: ${{ (matrix.backend == 'ptx' || matrix.backend == 'triton' || matrix.backend == 'nv') && 'true' }}
- name: Set env
- run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nPTX=1\nMOCKGPU=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nNV=1\nMOCKGPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas' || matrix.backend == 'amd' && 'AMD=1\nMOCKGPU=1\nFORWARD_ONLY=1' || matrix.backend == 'nv' && 'NV=1\nMOCKGPU=1\nFORWARD_ONLY=1' }}" >> $GITHUB_ENV
+ run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nPTX=1\nMOCKGPU=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nNV=1\nMOCKGPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas' || matrix.backend == 'amd' && 'AMD=1\nMOCKGPU=1\nFORWARD_ONLY=1' || matrix.backend == 'nv' && 'NV=1\nMOCKGPU=1\nFORWARD_ONLY=1' }}" >> $GITHUB_ENV
- name: Check Device.DEFAULT and print some source
run: |
- PYTHONPATH=${{ github.workspace }} python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU','AMD','NV'], Device.DEFAULT"
+ PYTHONPATH=${{ github.workspace }} python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['LLVM','CPU','CUDA','GPU','AMD','NV'], Device.DEFAULT"
DEBUG=5 PYTHONPATH=${{ github.workspace }} FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
- name: Run pytest (not cuda or amd)
if: matrix.backend!='ptx' && matrix.backend!='triton' && matrix.backend != 'amd' && matrix.backend != 'nv'
@@ -582,7 +582,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- backend: [metal, llvm, clang]
+ backend: [metal, llvm, cpu]
name: MacOS (${{ matrix.backend }})
runs-on: macos-15
timeout-minutes: 10
@@ -596,7 +596,7 @@
deps: testing_minimal
llvm: ${{ matrix.backend == 'llvm' && 'true' }}
- name: Set env
- run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'metal' && 'METAL=1\nJIT=2'}}" >> $GITHUB_ENV
+ run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1' || matrix.backend == 'metal' && 'METAL=1\nJIT=2'}}" >> $GITHUB_ENV
- name: Check Device.DEFAULT and print some source
run: |
python -c "from tinygrad import Device; assert Device.DEFAULT == '${{ matrix.backend }}'.upper(), Device.DEFAULT"
@@ -612,7 +612,7 @@
strategy:
fail-fast: false
matrix:
- backend: [llvm, clang]
+ backend: [llvm, cpu]

name: Windows (${{ matrix.backend }})
runs-on: windows-latest
@@ -627,7 +627,7 @@
deps: testing_unit
- name: Set env
shell: bash
- run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1'}}" >> $GITHUB_ENV
+ run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1'}}" >> $GITHUB_ENV
- name: Run unit tests
if: matrix.backend=='llvm'
run: python -m pytest -n=auto test/unit/ --ignore=test/unit/test_disk_tensor.py --ignore=test/unit/test_elf.py --ignore=test/unit/test_tar.py
2 changes: 1 addition & 1 deletion README.md
@@ -81,7 +81,7 @@ See [examples/beautiful_mnist.py](examples/beautiful_mnist.py) for the full vers
tinygrad already supports numerous accelerators, including:

- [x] [GPU (OpenCL)](tinygrad/runtime/ops_gpu.py)
- - [x] [CLANG (C Code)](tinygrad/runtime/ops_clang.py)
+ - [x] [CPU (C Code)](tinygrad/runtime/ops_cpu.py)
- [x] [LLVM](tinygrad/runtime/ops_llvm.py)
- [x] [METAL](tinygrad/runtime/ops_metal.py)
- [x] [CUDA](tinygrad/runtime/ops_cuda.py)
6 changes: 3 additions & 3 deletions docs/abstractions2.py
@@ -7,7 +7,7 @@

print("******** first, the runtime ***********")

- from tinygrad.runtime.ops_clang import ClangJITCompiler, MallocAllocator, CPUProgram
+ from tinygrad.runtime.ops_cpu import ClangJITCompiler, MallocAllocator, CPUProgram

# allocate some buffers
out = MallocAllocator.alloc(4)
@@ -34,7 +34,7 @@

print("******** second, the Device ***********")

DEVICE = "CLANG" # NOTE: you can change this!
DEVICE = "CPU" # NOTE: you can change this!

import struct
from tinygrad.dtype import dtypes
@@ -90,7 +90,7 @@

# schedule the computation as a list of kernels
sched, _, becomes_map = create_schedule_with_vars(out.sink())
- for si in sched: print(si.ast.op) # NOTE: the first two convert it to CLANG
+ for si in sched: print(si.ast.op) # NOTE: the first two convert it to CPU
# NOTE: UOps are no longer mutable, the scheduler gives you a map to lookup which BUFFER the result was written to
out = becomes_map[out]

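For reference, a minimal, hedged sketch that completes the abstractions2.py excerpt above under the new `tinygrad.runtime.ops_cpu` module path; the allocator, compiler, and program calls follow the ones shown in this diff and may differ in other tinygrad versions.

```python
# sketch: drive the CPU (formerly CLANG) runtime by hand
import struct
from tinygrad.runtime.ops_cpu import ClangJITCompiler, MallocAllocator, CPUProgram

# allocate three 4-byte buffers and copy the inputs in
out = MallocAllocator.alloc(4)
a = MallocAllocator.alloc(4)
b = MallocAllocator.alloc(4)
MallocAllocator._copyin(a, memoryview(bytearray(struct.pack("I", 2))))
MallocAllocator._copyin(b, memoryview(bytearray(struct.pack("I", 3))))

# compile one C kernel and wrap it in a runnable program
lib = ClangJITCompiler().compile("void add(int *out, int *a, int *b) { out[0] = a[0] + b[0]; }")
fxn = CPUProgram("add", lib)
fxn(out, a, b)

# read the result back: 2 + 3 = 5
print(MallocAllocator._as_buffer(out).cast("I").tolist()[0])
```
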
2 changes: 1 addition & 1 deletion docs/developer/runtime.md
@@ -38,7 +38,7 @@ The `Allocator` class is responsible for managing memory on the device. There is

The `Program` class is created for each loaded program. It is responsible for executing the program on the device. As an example, here is a `CPUProgram` implementation which loads program and runs it.

- ::: tinygrad.runtime.ops_clang.CPUProgram
+ ::: tinygrad.runtime.ops_cpu.CPUProgram
options:
members: true

2 changes: 1 addition & 1 deletion docs/env_vars.md
@@ -37,7 +37,7 @@ AMD | [1] | enable AMD backend
NV | [1] | enable NV backend
METAL | [1] | enable Metal backend (for Mac M1 and after)
METAL_XCODE | [1] | enable Metal using macOS Xcode SDK
- CLANG | [1] | enable Clang backend
+ CPU | [1] | enable CPU Clang backend
LLVM | [1] | enable LLVM backend
BEAM | [#] | number of beams in kernel beam search
DEFAULT_FLOAT | [HALF, ...]| specify the default float dtype (FLOAT32, HALF, BFLOAT16, FLOAT64, ...), default to FLOAT32
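Since `CPU=1` (like the old `CLANG=1`) is read at import time, it can also be set from Python before tinygrad is imported, as the examples/compile_tensorflow.py change further down does. A minimal sketch, assuming only the renamed variable:

```python
# sketch: force the CPU backend via the renamed env var, set before import
import os
os.environ["CPU"] = "1"  # was os.environ["CLANG"] = "1" before this commit

from tinygrad import Tensor, Device
print(Device.DEFAULT)                              # expected: CPU
print((Tensor([1, 2]) + Tensor([3, 4])).tolist())  # [4, 6], run through the C backend
```
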
2 changes: 1 addition & 1 deletion docs/mnist.md
@@ -17,7 +17,7 @@ from tinygrad import Device
print(Device.DEFAULT)
```

- You will see `CUDA` here on a GPU instance, or `CLANG` here on a CPU instance.
+ You will see `CUDA` here on a GPU instance, or `CPU` here on a CPU instance.

## A simple model

4 changes: 2 additions & 2 deletions docs/runtime.md
@@ -1,6 +1,6 @@
# Runtimes

- tinygrad supports various runtimes, enabling your code to scale across a wide range of devices. The default runtime can be automatically selected based on the available hardware, or you can force a specific runtime to be default using environment variables (e.g., `CLANG=1`).
+ tinygrad supports various runtimes, enabling your code to scale across a wide range of devices. The default runtime can be automatically selected based on the available hardware, or you can force a specific runtime to be default using environment variables (e.g., `CPU=1`).

| Runtime | Description | Requirements |
|---------|-------------|--------------|
@@ -10,6 +10,6 @@ tinygrad supports various runtimes, enabling your code to scale across a wide ra
| [METAL](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_metal.py) | Utilizes Metal for acceleration on Apple devices | M1+ Macs; Metal 3.0+ for `bfloat` support |
| [CUDA](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cuda.py) | Utilizes CUDA for acceleration on NVIDIA GPUs | NVIDIA GPU with CUDA support |
| [GPU (OpenCL)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_gpu.py) | Accelerates computations using OpenCL on GPUs | OpenCL 2.0 compatible device |
- | [CLANG (C Code)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_clang.py) | Runs on CPU using the clang compiler | `clang` compiler in system `PATH` |
+ | [CPU (C Code)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cpu.py) | Runs on CPU using the clang compiler | `clang` compiler in system `PATH` |
| [LLVM (LLVM IR)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_llvm.py) | Runs on CPU using the LLVM compiler infrastructure | llvm libraries installed and findable |
| [WEBGPU](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_webgpu.py) | Runs on GPU using the Dawn WebGPU engine (used in Google Chrome) | Dawn library installed and findable. Download binaries [here](https://github.com/wpmed92/pydawn/releases/tag/v0.1.6). |
6 changes: 3 additions & 3 deletions examples/compile_efficientnet.py
@@ -15,9 +15,9 @@
if getenv("WEBGPU"):
safe_save(get_state_dict(model), (dirname / "net.safetensors").as_posix())
load_state_dict(model, safe_load(str(dirname / "net.safetensors")))
mode = "clang" if getenv("CLANG", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else ""
mode = "clang" if getenv("CPU", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else ""
prg, inp_sizes, out_sizes, state = export_model(model, mode, Tensor.randn(1,3,224,224))
if getenv("CLANG", "") == "":
if getenv("CPU", "") == "":
ext = "js" if getenv("WEBGPU", "") != "" else "json"
with open(dirname / f"net.{ext}", "w") as text_file:
text_file.write(prg)
@@ -68,6 +68,6 @@
else printf("%s\\n", lbls[best_idx]);
}""")

- # CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg
+ # CPU=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg
# category : 281 (tabby, tabby cat) with 9.452788
print('\n'.join(cprog))
2 changes: 1 addition & 1 deletion examples/compile_tensorflow.py
@@ -1,7 +1,7 @@
# An example to compile a small Tensorflow model to extremely portable C code

import os, sys
os.environ["CLANG"] = '1'
os.environ["CPU"] = '1'
os.environ["JIT"] = '2'

import numpy as np
6 changes: 3 additions & 3 deletions examples/llm.c/export.py
@@ -2,7 +2,7 @@
import os
if "NOOPT" not in os.environ: os.environ["NOOPT"] = "1"
from tinygrad import Device, nn, Tensor, dtypes, Variable
Device.DEFAULT = "CLANG"
Device.DEFAULT = "CPU"
from train_gpt2 import GPT, GPTConfig
from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GlobalCounters, ansilen, to_function_name
from tinygrad.engine.realize import get_kernel, run_schedule
@@ -43,9 +43,9 @@
ast_dedup = dedup([si.ast for si in sched if si.ast.op is Ops.SINK])
srcs = {}
for ast in ast_dedup:
- k = get_kernel(Device["CLANG"].renderer, ast)
+ k = get_kernel(Device["CPU"].renderer, ast)
k.linearize()
src = Device["CLANG"].renderer.render(to_function_name(k.name), k.uops)
src = Device["CPU"].renderer.render(to_function_name(k.name), k.uops)
srcs[ast] = (k.name, src)
print("functions:", len(srcs))
used_buffers = dedup(flatten([si.bufs for si in sched]))
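The `Device["CLANG"]` to `Device["CPU"]` rename above applies anywhere a renderer is fetched by device key. A hedged sketch of rendering the C source for a single kernel, following the `get_kernel`/`render` calls in export.py; the schedule helper names are assumptions from this tinygrad version:

```python
# sketch: render C source for one kernel via the renamed CPU device key
from tinygrad import Tensor, Device
from tinygrad.engine.realize import get_kernel
from tinygrad.helpers import to_function_name

si = (Tensor([1.0, 2.0, 3.0]) + 1).schedule()[-1]  # last ScheduleItem is the add kernel
k = get_kernel(Device["CPU"].renderer, si.ast)     # lower the AST for the CPU renderer
k.linearize()
src = Device["CPU"].renderer.render(to_function_name(k.name), k.uops)
print(src)  # the generated C function
```
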
14 changes: 7 additions & 7 deletions examples/mlperf/dataloader.py
@@ -170,13 +170,13 @@ def receive_batch():

def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
return {
"input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
"input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
"segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
"masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
"masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
"masked_lm_weights": Tensor(np.concatenate([s["masked_lm_weights"] for s in data], axis=0), dtype=dtypes.float32, device="CLANG"),
"next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
"input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
"input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
"segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
"masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
"masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
"masked_lm_weights": Tensor(np.concatenate([s["masked_lm_weights"] for s in data], axis=0), dtype=dtypes.float32, device="CPU"),
"next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
}

def load_file(file: str):
14 changes: 7 additions & 7 deletions examples/mlperf/helpers.py
@@ -222,11 +222,11 @@ def get_mlperf_bert_model():

def get_fake_data_bert(BS:int):
return {
"input_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CLANG"),
"input_mask": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CLANG"),
"segment_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CLANG"),
"masked_lm_positions": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CLANG"),
"masked_lm_ids": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CLANG"),
"masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32, device="CLANG"),
"next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.int32, device="CLANG"),
"input_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CPU"),
"input_mask": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CPU"),
"segment_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CPU"),
"masked_lm_positions": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CPU"),
"masked_lm_ids": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CPU"),
"masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32, device="CPU"),
"next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.int32, device="CPU"),
}
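
As in dataloader.py above, the host-staging device string is now "CPU". A tiny sketch, assuming only the renamed device name:

```python
# sketch: host-side staging tensors now use device="CPU" (was "CLANG")
from tinygrad import Tensor, Device
from tinygrad.dtype import dtypes

batch = Tensor.empty((4, 512), dtype=dtypes.int32, device="CPU")
print(batch.device)               # CPU
accel = batch.to(Device.DEFAULT)  # move to the default accelerator when needed
```
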
4 changes: 2 additions & 2 deletions extra/backends/clang_graph.py
@@ -5,7 +5,7 @@
from tinygrad.device import Buffer, Device
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.ops import Variable
- from tinygrad.runtime.ops_clang import ClangProgram
+ from tinygrad.runtime.ops_cpu import ClangProgram
from tinygrad.renderer.cstyle import ClangRenderer
render_dtype = ClangRenderer().render_dtype

@@ -30,7 +30,7 @@ def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], va
code.append(f" {cast(CompiledRunner, ji.prg).p.function_name}({','.join(args)});")
code.append("}")
if DEBUG >= 4: print("\n".join(code))
compiler = Device["CLANG"].compiler
compiler = Device["CPU"].compiler
assert compiler is not None
self._prg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code))) # no point in caching the pointers

4 changes: 2 additions & 2 deletions extra/export_model.py
@@ -8,7 +8,7 @@
from tinygrad.dtype import dtypes
import json

- EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU"]
+ EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CPU", "CUDA", "GPU"]

def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
@@ -191,7 +191,7 @@ def export_model_webgpu(functions, statements, bufs, weight_names, input_names,
"""

def export_model(model, target:str, *inputs, model_name: Optional[str] = None):
- assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
+ assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CPU, CUDA, GPU, METAL are supported"
with Context(JIT=2): run,special_names = jit_model(model, *inputs)
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
state = get_state_dict(model)
2 changes: 1 addition & 1 deletion extra/gemm/tvm_gemm.py
@@ -42,6 +42,6 @@
lin = Kernel(sched[-1].ast, CompilerOptions(has_local=False, supports_float4=False))
#lin.hand_coded_optimizations()
lin.linearize()
- from tinygrad.runtime.ops_clang import renderer
+ from tinygrad.runtime.ops_cpu import renderer
src = renderer("mmult", lin.uops)
print(src)
(diff truncated: the remaining changed files are not shown)
