Skip to content

Commit

Permalink
remove unused memsize_to_str and minor cleanups [pr] (tinygrad#9211)
Browse files Browse the repository at this point in the history
* fix edge cases in memsize_to_str()

Inputs <= 1 now return "0.00 B" for 0 and "1.00 B" for 1, avoiding an
IndexError. Also, memsize_to_str(1000) now returns "1.00 KB" instead of
"1000.00 B".

Replaced the list comprehension with a next(...) generator for conciseness
and efficiency.

* simplify code using idiomatic python

- Remove the unused `memsize_to_str()` function in helpers.
- Use a tuple for checking multiple string prefixes/suffixes.
- Avoid unnecessary list construction by using iterables directly.
- Check None in @diskcache to ensure proper caching of falsy values.

* revert generators back to list comprehension

Sometimes building list first could be faster. Keep it as is.
  • Loading branch information
ShikChen authored Feb 23, 2025
1 parent 81a71ae commit 05e3202
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
2 changes: 1 addition & 1 deletion examples/llama3.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def convert(name) -> Tensor:
disk_tensors: List[Tensor] = [model[name] for model in models]
if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
return disk_tensors[0].to(device=device)
axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
axis = 1 if name.endswith((".attention.wo.weight", ".feed_forward.w2.weight")) else 0
lazy_tensors = [data.to(device=device) for data in disk_tensors]
return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
return {name: convert(name) for name in {name: None for model in models for name in model}}
Expand Down
8 changes: 3 additions & 5 deletions tinygrad/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def all_same(items:Union[tuple[T, ...], list[T]]): return all(x == items[0] for
def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t)
def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
def time_to_str(t:float, w=8) -> str: return next((f"{t * d:{w}.2f}{pr}" for d,pr in [(1, "s "),(1e3, "ms")] if t > 10/d), f"{t * 1e6:{w}.2f}us")
def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
def ansilen(s:str): return len(ansistrip(s))
Expand Down Expand Up @@ -191,8 +190,7 @@ def diskcache_clear():
def diskcache_get(table:str, key:Union[dict, str, int]) -> Any:
if CACHELEVEL < 1: return None
if isinstance(key, (str,int)): key = {"key": key}
conn = db_connection()
cur = conn.cursor()
cur = db_connection().cursor()
try:
res = cur.execute(f"SELECT val FROM '{table}_{VERSION}' WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
except sqlite3.OperationalError:
Expand All @@ -211,15 +209,15 @@ def diskcache_put(table:str, key:Union[dict, str, int], val:Any, prepickled=Fals
ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
_db_tables.add(table)
cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (val if prepickled else pickle.dumps(val), )) # noqa: E501
cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key))}, ?)", tuple(key.values()) + (val if prepickled else pickle.dumps(val), )) # noqa: E501
conn.commit()
cur.close()
return val

def diskcache(func):
def wrapper(*args, **kwargs) -> bytes:
table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
if (ret:=diskcache_get(table, key)): return ret
if (ret:=diskcache_get(table, key)) is not None: return ret
return diskcache_put(table, key, func(*args, **kwargs))
return wrapper

Expand Down
2 changes: 1 addition & 1 deletion tinygrad/viz/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def do_GET(self):
with open(os.path.join(os.path.dirname(__file__), "index.html"), "rb") as f: ret = f.read()
elif (url:=urlparse(self.path)).path == "/profiler":
with open(os.path.join(os.path.dirname(__file__), "perfetto.html"), "rb") as f: ret = f.read()
elif (self.path.startswith("/assets/") or self.path.startswith("/lib/")) and '/..' not in self.path:
elif self.path.startswith(("/assets/", "/lib/")) and '/..' not in self.path:
try:
with open(os.path.join(os.path.dirname(__file__), self.path.strip('/')), "rb") as f: ret = f.read()
if url.path.endswith(".js"): content_type = "application/javascript"
Expand Down

0 comments on commit 05e3202

Please sign in to comment.