AWFUTILS: AWF's Utilities

A few utilities for general Python programming.


This is very much code-in-progress. When I use it, I typically just put it as a submodule under whatever else I'm building, so I can easily bugfix awfutils as I do other work:

$ git submodule add


I love runtime type checkers, particularly with JAX, but by default they (OK, beartype doesn't) don't check statement-level annotations like these:

def foo(x : int, y : float):
  z : int = x * y # This should error, but doesn't
  w : float = z * 3.2
  return w

foo(3, 1.3)

With the awfutils typecheck decorator, they can...

@typecheck
def foo(x : int, y : float):
  z : int = x * y # Now it raises TypeError: z not of type int
  w : float = z * 3.2
  return w

foo(3, 1.3) # Error comes from this call

This works by AST transformation, replacing the function foo above with the function

def foo_typecheck_wrap(x: int, y: float):
    assert isinstance(x, int), 'x not of type int'
    assert isinstance(y, float), 'y not of type float'
    z: int = x * y
    assert isinstance(z, int), 'z not of type int'
    w: float = z * 3.2
    assert isinstance(w, float), 'w not of type float'
    return w

Because this is an AST transformation, the function that actually runs is essentially the code above. To see it, pass the optional argument show_src=True:

@functools.partial(typecheck, show_src=True)
def foo(x : int, y : float):
  z : int = x * y # Now it does
  w : float = z * 3.2
  return w
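The AST transformation described above can be sketched with the standard `ast` module. This is a minimal illustration, not the awfutils implementation: the names `AnnAssignChecker` and `add_statement_checks` are hypothetical, it works on a source string rather than a decorated function, it only handles annotated assignments to a bare name, and it skips the argument checks. It needs Python 3.9+ for `ast.unparse`.

```python
import ast


class AnnAssignChecker(ast.NodeTransformer):
    """Follow each `name: T = value` with `assert isinstance(name, T), ...`."""

    def visit_AnnAssign(self, node):
        # Only simple annotated assignments to a bare name are handled here.
        if not isinstance(node.target, ast.Name) or node.value is None:
            return node
        name = node.target.id
        type_src = ast.unparse(node.annotation)
        message = f"{name} not of type {type_src}"
        check = ast.parse(
            f"assert isinstance({name}, {type_src}), {message!r}"
        ).body[0]
        # Returning a list splices both statements into the enclosing body.
        return [node, check]


def add_statement_checks(src: str) -> str:
    """Return `src` with an isinstance assert after each annotated assignment."""
    tree = AnnAssignChecker().visit(ast.parse(src))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


SRC = '''
def foo(x: int, y: float):
    z: int = x * y
    w: float = z * 3.2
    return w
'''

checked_src = add_statement_checks(SRC)
print(checked_src)

# Compile the transformed source and call the checked function.
namespace = {}
exec(checked_src, namespace)
try:
    namespace["foo"](3, 1.3)
except AssertionError as err:
    print("caught:", err)
```

Because the check is a plain `isinstance`, this sketch only works for annotations that are ordinary classes; something like `list[int]` would need a real runtime type checker behind the assert.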

Tensor pretty-printing

It's handy to print a one-line summary of the contents of a tensor.

a = np.random.rand(2, 1, 3) - 0.2
print(au.ndarray_str(a**6 * 1e7))


f64[2x1x3] 10^5 x [0.181 5.555 1.721 1.462 0.001 0.000]

Easy to read, even with only 3 significant figures (see the leading 10^5x).

For larger tensors, show percentiles:

f32[22x11x33] 10^-7 x Percentiles{0.002|0.493|2.470|4.958|7.490|9.434|9.996}
              ^scale            0 (min)|   5%|  25%|  50%|  75%|  95%|100% (max)
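The idea behind both displays (factor out one shared power of ten, then print a few fixed-point digits) can be sketched as follows. This is an illustrative approximation only: `one_line_summary` and its `max_elements` threshold are hypothetical names, and the exact format and rules of `au.ndarray_str` may differ. It assumes a non-empty numeric array.

```python
import numpy as np


def one_line_summary(a, max_elements=10):
    """One-line summary: dtype, shape, shared scale, then values or percentiles."""
    a = np.asarray(a)
    dtype_tag = f"{a.dtype.kind}{a.dtype.itemsize * 8}"  # e.g. f64
    shape_tag = "x".join(str(n) for n in a.shape)

    if a.size <= max_elements:
        # Small tensor: show every element.
        values, opener, closer, sep = a.ravel(), "[", "]", " "
    else:
        # Large tensor: show min, 5%, 25%, 50%, 75%, 95%, max.
        values = np.percentile(a, [0, 5, 25, 50, 75, 95, 100])
        opener, closer, sep = "Percentiles{", "}", "|"

    # Shared scale: the power of ten of the largest magnitude shown.
    max_abs = float(np.max(np.abs(values)))
    exponent = int(np.floor(np.log10(max_abs))) if max_abs > 0 else 0
    body = sep.join(f"{v / 10.0**exponent:.3f}" for v in values)
    return f"{dtype_tag}[{shape_tag}] 10^{exponent} x {opener}{body}{closer}"


print(one_line_summary(np.random.rand(2, 1, 3) - 0.2))
print(one_line_summary(np.random.rand(22, 11, 33).astype(np.float32) * 1e-6))
```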

MkSweep: Simple sweeps via makefile

What is a sweep? Generally it boils down to a list of commands, perhaps the same core program with different command-line arguments:

python --lr=0.00003
python --lr=0.0001
python --lr=0.0003
python --lr=0.001

In managing such a list, a few properties we might like are:

  • Interruptible: if a job fails, or if a machine fails, we can easily resume the sweep without re-running already-finished jobs. Similarly, if we change the sweep definition slightly, we don't want to rerun jobs that were in previous sweeps.

  • Flexible: we can define complex combinations of configurations, rather than simple grids.

  • Parallel: we can easily run jobs in parallel, up to the resource limits of available hardware.

  • Portable: we can easily set up a sweep without installing a lot of infrastructure.

These properties are reminiscent of those one might want in a large software build system, so MkSweep simply puts the series of commands into a classic Makefile, which can be called in order to run them. Each command is given an output directory which is a hash of its command line, so that re-running the same command will re-use the outputs. The output's "done" marker is not updated until the command is successfully completed, so an interrupted command will re-run when the sweep is restarted until it completes successfully.

To specify which commands to run, we use a python script rather than any sort of YAML file, for reasons which, if not apparent immediately, should become reasonable when we see more complex examples. For the above simple learning rate sweep, the definition is:

from awfutils import MkSweep
with MkSweep("mytmp") as ms: # Sweep will write into directory 'mytmp'
  # The sweep definition begins here:
  for lr in (0.00003, 0.0001, 0.0003, 0.001):
      ms.add(f"python --lr={lr}")

which, when run, creates a Makefile in folder mytmp so that

make -f mytmp/Makefile

will execute any undone commands, leaving outputs and logs in the subfolders of mytmp.

If I want to run multiple commands in parallel, up to a maximum of N concurrent jobs, I can just use make -j<N>:

make -f mytmp/Makefile -j4 # Run up to 4 jobs in parallel

If we edit the sweep definition to include an extra lr, and a baseline run:

  ms.add(f"python --no-lr") # "No LR" is some baseline config
  for lr in (0.00003, 0.0001, 0.0003, 0.001, 0.003):
      ms.add(f"python --lr={lr}")

then re-running make -f mytmp/Makefile will just run the parts that have not been marked as done, which in this case would mean the new command for lr=0.003, and the new "No LR" command.

If we want to re-run everything (for example, because the code changed), we can just remove the mytmp folder, or make a new sweep folder, e.g. sweeps/run2.

Why python, not YAML?

Suppose you have a program with parameters "alpha" and "beta", and you're testing the idea that you should set beta to 1-alpha, while existing work sets it either to .99 or .999.

With a traditional sweeping infrastructure configured via YAML files, it might be hard to encode the special rule that you don't need a separate run for beta = 1-alpha when that value already equals one of the existing betas. In Python you can just write

  for alpha in [1e-4, 3e-4, 1e-3]:
    for beta in set([0.99, 0.999, 1-alpha]): # deduplicate betas using python `set`
      ms.add(f"python --alpha={alpha} --beta={beta}")

Now in this case, the command hashing would have avoided running the duplicate command anyway, but other constraints, e.g. $\alpha \le \beta \le \alpha^2$, are easily handled because the specification is all in Python:

  for alpha in np.arange(1.5, 2, 0.1):
    for beta in np.linspace(alpha, alpha** 2, 5):
      ms.add(f"python --alpha={alpha} --beta={beta}")
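One practical detail when generating commands from floating-point grids: since each command's output directory is a hash of its command line, it can help to format the numbers explicitly so the same configuration always produces the same string. The sketch below shows this for the constrained sweep in plain Python. The `add` function is a hypothetical stand-in for `ms.add` that just collects commands in a list, and the `.4g` format is an arbitrary choice.

```python
commands = []


def add(cmd):
    """Stand-in for ms.add: record the command instead of writing a Makefile."""
    commands.append(cmd)


alphas = [1.5 + 0.1 * i for i in range(5)]  # 1.5, 1.6, ..., 1.9
for alpha in alphas:
    for k in range(5):
        # Five evenly spaced betas with alpha <= beta <= alpha**2.
        beta = alpha + (alpha**2 - alpha) * k / 4
        # Fixed formatting keeps floating-point noise out of the command line.
        add(f"python --alpha={alpha:.4g} --beta={beta:.4g}")

print(len(commands))
print(commands[0])
print(commands[-1])
```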

I know makefiles, tell me more?

We use make, rather than any more sophisticated build system, because the complex logic of command assembly lives in Python, meaning the makefile can be very simple. Here's a snippet of the makefile that's generated.

mytmp/926fa3d0/done.txt: # if done.txt doesn't exist
	python --lr=0.00003 >& mytmp/926fa3d0/log.txt
	touch $@ # Create the "done" file

I'll probably switch it to Ninja if I find a use case that I can't hack in make.
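To make the mechanism concrete, here is a rough sketch of how a command could be turned into a rule like the one above. It is an assumption-laden illustration, not MkSweep's code: the function names are hypothetical, the hash (SHA-1 truncated to 8 hex digits) is a guess at the scheme, and the sketch uses `> log 2>&1` and an explicit `mkdir -p` instead of whatever the generated makefile really does.

```python
import hashlib


def out_dir_for(cmd: str, root: str) -> str:
    """Output directory named by a short hash of the command line (assumed scheme)."""
    digest = hashlib.sha1(cmd.encode("utf-8")).hexdigest()[:8]
    return f"{root}/{digest}"


def rule_for(cmd: str, root: str) -> str:
    """One make rule: run the command, then create done.txt only if it succeeded."""
    out_dir = out_dir_for(cmd, root)
    return (
        f"{out_dir}/done.txt:\n"
        f"\tmkdir -p {out_dir}\n"
        f"\t{cmd} > {out_dir}/log.txt 2>&1\n"
        f"\ttouch $@\n"
    )


def makefile_text(cmds, root: str) -> str:
    """Full Makefile: an 'all' target depending on every command's done.txt."""
    unique = list(dict.fromkeys(cmds))  # drop duplicate commands, keep order
    targets = " ".join(f"{out_dir_for(c, root)}/done.txt" for c in unique)
    rules = "\n".join(rule_for(c, root) for c in unique)
    return f"all: {targets}\n\n{rules}"


print(makefile_text([f"python --lr={lr}" for lr in (0.0001, 0.001)], "mytmp"))
```

Because make stops a recipe at the first failing line, `touch $@` runs only after the command exits successfully, which is what gives the resume-after-interruption behaviour. A real generator would also need to escape any `$` in the command for make.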


A distributed argument parser, like absl flags, but a little more convenient and less stringy. If you want a config value anywhere in your program, just declare an Arg nearby (at top level), and use it:

from jaxutils.Arg import Arg

tau = Arg("mqh-tau", 0.01, "Scale factor for softmax in my_quick_hack")

def my_quick_hack(xs, qs):
    return softmax(tau() * xs @ qs.T)

Now, even if my_quick_hack is far down the call tree from main, you can quickly try some values of tau by just running

$ python -mqh-tau 0.0001

More conventionally, Arg is also useful in main:

def main():
    param1 = Arg("p1", default=34, help="Set first parameter")
    switch = Arg("s", False, "Turn on the switch")

    if switch(): # This is where arg parsing first happens. See notes.

    # To see which args have been set:

It's all a thin wrapper around argparse, so it inherits a bunch of goodness from there, but avoids a lot of long-chain plumbing when my_quick_hack is far down the call tree from main.

And yes, these are global variables. This is absolutely reasonable, because the command line is a global resource. You'll see that a little bit of namespacing has been illustrated above, where tau was given the flag mqh-tau. Feel free to formalize that as much as you like.

Parsing happens the first time any of the Args is read, and is then cached.

There is a potential gotcha if you want to act on an arg during module load time, e.g. at the top level:

jit_foo = Arg("jit-foo", False, "Run JIT on foo")
if jit_foo():   # Prefer jit_foo.peek() for load-time checks
    foo = jit(foo)

The call to jit_foo() will know only about arguments declared before that point, so a call to --help will produce too short a list. This is remedied later but is better avoided:

jit_foo = Arg("jit-foo", False, "Run JIT on foo")
if jit_foo.peek(): # Just check for this arg in sys.argv
    foo = jit(foo)
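The parse-on-first-read behaviour and the peek shortcut can be sketched on top of argparse as below. This is a hypothetical stand-in, not the real Arg class: `LazyArg`, its `argv` attribute (used here to simulate a command line), the `--flag` spelling, and the choice to re-parse when an argument was declared after the last parse are all assumptions made for illustration.

```python
import argparse
import sys


class LazyArg:
    """Hypothetical stand-in for Arg: declare anywhere, parse on first read."""

    _parser = argparse.ArgumentParser()
    _parsed = None  # cached argparse.Namespace
    argv = None     # None means "use sys.argv[1:]"; a list simulates a command line

    def __init__(self, flag, default, help=""):
        self.flag = flag
        self.dest = flag.replace("-", "_")
        if isinstance(default, bool):
            kwargs = dict(action="store_true", default=default)
        else:
            kwargs = dict(type=type(default), default=default)
        LazyArg._parser.add_argument(f"--{flag}", dest=self.dest, help=help, **kwargs)

    def __call__(self):
        # Parse on first read, and again if this arg was declared after the last parse.
        if LazyArg._parsed is None or not hasattr(LazyArg._parsed, self.dest):
            LazyArg._parsed, _unknown = LazyArg._parser.parse_known_args(LazyArg.argv)
        return getattr(LazyArg._parsed, self.dest)

    def peek(self):
        # Look for the bare flag without triggering a parse.
        args = sys.argv[1:] if LazyArg.argv is None else LazyArg.argv
        return f"--{self.flag}" in args


# Simulated command line for the demonstration.
LazyArg.argv = ["--mqh-tau", "0.0001", "--jit-foo"]

jit_foo = LazyArg("jit-foo", False, "Run JIT on foo")
use_jit = jit_foo.peek()  # load-time check; nothing has been parsed yet

tau = LazyArg("mqh-tau", 0.01, "Scale factor for softmax in my_quick_hack")
print(use_jit, tau())  # the first call to tau() triggers the parse
```

Note that this `peek` only recognizes the bare `--flag` form, not `--flag=value`, so it is best suited to boolean switches.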

PyTree Utils

Various smoothers for torch.utils._pytree.

Given a nest of lists and tuples, perform various useful functions.

For example, given the object val as follows:

val = ( # tuple
        [ # list
          (np.random.rand(2, 3), "b", np.random.rand(12, 13)),

Then, val) will make these six calls to foo:

foo(np.random.rand(2, 3))
foo(np.random.rand(12, 13))

And given a numeric-only pytree, e.g.

val = ( # tuple
        [ # list
          (np.random.rand(2, 3), np.random.rand(12, 13))

Then arithmetic can be performed on PyTrees, e.g.

(PyTree(val) + val) * val == 2 * PyTree(val) * val
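For readers unfamiliar with the idea, mapping over a nest of lists and tuples can be sketched in a few lines of plain Python. The `tree_map` below is a hypothetical helper written for illustration; it is neither the awfutils API nor `torch.utils._pytree`, and it treats anything that is not a list or tuple (including strings) as a leaf.

```python
def tree_map(fn, tree, *rest):
    """Apply fn to corresponding leaves of one or more identically shaped nests."""
    if isinstance(tree, (list, tuple)):
        # Recurse into children, then rebuild the same container type.
        mapped = [tree_map(fn, *children) for children in zip(tree, *rest)]
        return type(tree)(mapped)
    return fn(tree, *rest)


val = ([(1.0, 2.0), 3.0], 4.0)

doubled = tree_map(lambda x: 2 * x, val)
summed = tree_map(lambda a, b: a + b, val, val)
print(doubled)
print(summed == doubled)
```

Elementwise arithmetic such as `PyTree(val) + val` can be thought of as this kind of multi-tree map with `+` as the leaf function. The sketch does not check that the nests have matching shapes; `zip` silently truncates to the shorter one.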

