diff --git a/tools/.gitignore b/tools/.gitignore index 90474d8..affe7c9 100644 --- a/tools/.gitignore +++ b/tools/.gitignore @@ -1,2 +1,3 @@ lib capture +strace-log-watch diff --git a/tools/Makefile b/tools/Makefile index 1cc34ac..b2281bd 100644 --- a/tools/Makefile +++ b/tools/Makefile @@ -1,3 +1,5 @@ +CC = gcc + CFLAGS = -Wall -Werror -O2 LDFLAGS = LOG_LEVEL ?= WARN @@ -5,7 +7,7 @@ EXTRA_CFLAGS ?= -DLOG_LEVEL=LOG_$(LOG_LEVEL) EXTRA_LDFLAGS ?= .PHONY: all -all: bpf lib capture +all: bpf lib capture strace-log-watch .PHONY: bpf bpf: @@ -14,10 +16,10 @@ bpf: lib: lib.c $(CC) $(CFLAGS) $(EXTRA_CFLAGS) -o "$@" "$<" $(LDFLAGS) -ldl $(EXTRA_LDFLAGS) -capture: capture.c +%: %.c $(CC) $(CFLAGS) $(EXTRA_CFLAGS) -o "$@" "$<" $(LDFLAGS) $(EXTRA_LDFLAGS) .PHONY: clean clean: - rm -f lib capture + rm -f lib capture strace-log-watch $(MAKE) -C bpf clean diff --git a/tools/capture.c b/tools/capture.c index 3f4e351..6ce729d 100644 --- a/tools/capture.c +++ b/tools/capture.c @@ -11,6 +11,8 @@ struct options { const char* stdout_fn; int silence_stderr; const char* stderr_fn; + + const char* returncode_fn; }; static void print_usage(int fd, const char* prog) @@ -22,6 +24,7 @@ static void print_usage(int fd, const char* prog) dprintf(fd, " -o FILE write stdout to FILE\n"); dprintf(fd, " -E do not output stderr\n"); dprintf(fd, " -e FILE write stderr to FILE\n"); + dprintf(fd, " -r FILE write returncode to FILE (-SIG if signaled)\n"); dprintf(fd, " -h print this message\n"); } @@ -33,7 +36,7 @@ static int parse_options(struct options* o, int argc, char* argv[]) o->stderr_fn = "/dev/null"; int res; - while((res = getopt(argc, argv, "Oo:Ee:h-")) != -1) { + while((res = getopt(argc, argv, "Oo:Ee:r:h-")) != -1) { switch(res) { case 'O': o->silence_stdout = 1; @@ -49,6 +52,10 @@ static int parse_options(struct options* o, int argc, char* argv[]) o->stderr_fn = strdup(optarg); CHECK_NOT(o->stderr_fn, NULL, "strdup(%s)", optarg); break; + case 'r': + o->returncode_fn = strdup(optarg); + 
// Redirect `fd` to /dev/null.
//
// The previous version passed `fd` as the open(2) flags and called
// dup2(fd, n), which overwrote the freshly opened descriptor with a copy of
// `fd` and left `fd` itself untouched.
//
// /dev/null is opened read-write so the same helper works for descriptors
// that are read from (stdin) as well as written to (stdout/stderr).
void devnull_to(int fd)
{
    int n = devnull(O_RDWR);

    if(n == fd) {
        // `fd` was closed beforehand and open() handed back the very same
        // number. devnull() always sets O_CLOEXEC, so clear it here to keep
        // the descriptor alive across exec, like a dup2() result would be.
        int r = fcntl(fd, F_SETFD, 0); CHECK(r, "fcntl(%d, F_SETFD)", fd);
        return;
    }

    // dup2(oldfd, newfd): make `fd` refer to /dev/null.
    int r = dup2(n, fd); CHECK(r, "dup2(%d, %d)", n, fd);

    // The temporary descriptor is no longer needed.
    r = close(n); CHECK(r, "close(%d)", n);
}
b) +{ + free(b->bs); + b->bs = NULL; + b->L = 0; + b->i = 0; +} + +int buffer_full(const struct buffer* b) +{ + return b->i == b->L; +} + +size_t buffer_available(const struct buffer* b) +{ + return b->L - b->i; +} + +int buffer_empty(const struct buffer* b) +{ + return b->i == 0; +} + +ssize_t buffer_find(const struct buffer* b, uint8_t c) +{ + for(size_t i = 0; i < b->i; i++) { + if(b->bs[i] == c) { + return i; + } + } + return -1; +} + +void buffer_read(struct buffer* b, int fd) +{ + while(1) { + if(buffer_full(b)) { + buffer_extend(b); + } + + const size_t l = b->L - b->i; + ssize_t s = read(fd, &b->bs[b->i], l); + if(s == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) { + return; + } + CHECK(s, "read(%d, .., %zu)", fd, l); + + debug("read(%d) = %zd", fd, s); + if(s == 0) { + break; + } + + b->i += s; + } +} + +void buffer_write(struct buffer* b, int fd) +{ + while(1) { + if(buffer_empty(b)) { + return; + } + + const size_t l = b->i; + ssize_t s = write(fd, b->bs, l); + if(s == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) { + return; + } + CHECK(s, "write(%d, .., %zu)", fd, l); + + debug("write(%d) = %zd", fd, s); + + size_t j = s - l; + memmove(b->bs, &b->bs[s], j); + b->i = j; + } +} + +void buffer_append(struct buffer* b, void* p, size_t len) +{ + while(buffer_available(b) < len) { + buffer_extend(b); + } + + memcpy(&b->bs[b->i], p, len); + b->i += len; +} + +void buffer_move(struct buffer* in, struct buffer* out, size_t len) +{ + if(in->i < len) { + failwith("asked to move more than available: %zu < %zu", in->i, len); + } + + while(buffer_available(out) < len) { + buffer_extend(out); + } + + memcpy(&out->bs[out->i], in->bs, len); + + out->i += len; + + size_t j = in->i - len; + memmove(in->bs, &in->bs[len], j); + in->i = j; +} + +static void split_path(const char* path, char** dir, char** base) +{ + char buf0[PATH_MAX]; + strncpy(buf0, path, sizeof(buf0)-1); + buf0[sizeof(buf0)-1] = '\0'; + *dir = strdup(dirname(buf0)); CHECK_MALLOC(*dir); + + 
strncpy(buf0, path, sizeof(buf0)-1); + buf0[sizeof(buf0)-1] = '\0'; + *base = strdup(basename(buf0)); CHECK_MALLOC(*base); +} + +struct options { + pid_t pid; + const char* pattern; +}; + +static void print_usage(int fd, const char* prog) +{ + dprintf(fd, "usage: %s [OPTION]... PATTERN\n", prog); + dprintf(fd, "\n"); + dprintf(fd, "options:\n"); + dprintf(fd, " -p PID terminate after PID dies\n"); + dprintf(fd, " -h print this message\n"); +} + +static void parse_options(struct options* o, int argc, char* argv[]) +{ + memset(o, 0, sizeof(*o)); + o->pid = -1; + + int res; + while((res = getopt(argc, argv, "p:h-")) != -1) { + switch(res) { + case 'p': { + int r = sscanf(optarg, "%d", &o->pid); + if(r != 1) { + dprintf(2, "error: unable to parse pid: %s", optarg); + exit(1); + } + break; + } + case '-': + goto opt_end; + case 'h': + default: + print_usage(res == 'h' ? 1 : 2, argv[0]); + exit(res == 'h' ? 0 : 1); + } + } +opt_end: + + if(optind < argc) { + o->pattern = argv[optind]; + debug("pattern: %s", o->pattern); + } else { + dprintf(2, "error: no input file specified\n"); + print_usage(2, argv[0]); + exit(1); + } +} + +struct trace { + char* path; + pid_t pid, tail; + int fd; + struct buffer buf; + struct trace* next; +}; + +struct { + sigset_t sm; + + char* dir; + char* prefix; + + int running; + + int wd; + + struct buffer out; + + struct pollfd* fds; + size_t fds_n, fds_N; + + struct trace* traces; +} state; + +static size_t trace_n(void) +{ + size_t i = 0; + struct trace* t = state.traces; + while(t != NULL) { + i += 1; + t = t->next; + } + return i; +} + +static struct trace* trace_add(void) +{ + struct trace** t = &state.traces; + while(*t != NULL) t = &(*t)->next; + + struct trace* s = calloc(1, sizeof(struct trace)); + CHECK_MALLOC(s); + *t = s; + return s; +} + +static void trace_init(struct trace* t, const char* dir, const char* name) +{ + debug("dir=%s name=%s", dir, name); + + char path[PATH_MAX]; + int r = snprintf(LIT(path), "%s/%s", dir, name); + 
if(r >= sizeof(path)) { + failwith("truncated path: %s/%s", dir, name); + } + + debug("path=%s", path); + t->path = strdup(path); CHECK_MALLOC(t->path); + + size_t l = strlen(name); + if(l == 0) { + failwith("unable to parse pid: empty filename"); + } + size_t i = l - 1; + while(1) { + char c = name[i]; + if(is_digit(c)) { + if(i == 0) { + failwith("unable to parse pid: %s (missing '.')", name); + } else { + i -= 1; + } + } else if(c == '.') { + break; + } else { + failwith("unable to parse pid: %s (expected digit or '.')", name); + } + } + r = sscanf(&name[i], ".%u", &t->pid); + if(r != 1) { + failwith("unable to parse pid: %s (unexpected format)", name); + } + debug("pid=%d", t->pid); + + buffer_init(&t->buf, 1<<8); + + int pipefd[2]; + r = pipe2(pipefd, O_CLOEXEC | O_NONBLOCK); CHECK(r, "pipe2"); + + t->tail = fork(); CHECK(t->tail, "fork"); + if(t->tail == 0) { + r = sigprocmask(SIG_UNBLOCK, &state.sm, NULL); + CHECK(r, "sigprocmask"); + + devnull_to(0); + r = dup2(pipefd[1], 1); CHECK(r, "dup2(.., 1)"); + + r = execlp("tail", "tail", "--lines=+1", "--follow", path, NULL); + CHECK(r, "execlp(tail -n+1 -f %s)", path); + } + + info("spawned %d: tail -n+1 -f %s", t->tail, path); + + r = close(pipefd[1]); CHECK(r, "close"); + t->fd = pipefd[0]; +} + +static void trace_close(struct trace* t) +{ + debug("pid=%d tail=%d", t->pid, t->tail); + + int r = close(t->fd); CHECK(r, "close"); + buffer_free(&t->buf); +} + +static void handle_inotifyfd(int ifd) +{ + char buf[sizeof(struct inotify_event) + NAME_MAX + 1]; + ssize_t s = read(ifd, &buf, sizeof(buf)); + if(s == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) { + return; + } + CHECK(s, "read"); + if(s < sizeof(struct inotify_event)) { + failwith("unexpected partial read"); + } + struct inotify_event* e = (struct inotify_event*)buf; + const size_t l = sizeof(struct inotify_event) + e->len; + if(s < l) { + failwith("unexpected read length: %zd != %zu", s, l); + } + + if(e->wd != state.wd) { + failwith("unexpected watch 
descriptor: %d != %d", e->wd, state.wd); + } + + if(e->len > 0) { + size_t L = strlen(state.prefix); + if(strncmp(state.prefix, e->name, L) == 0) { + info("matched file: %s/%s", state.dir, e->name); + + struct trace* t = trace_add(); + trace_init(t, state.dir, e->name); + } else { + debug("ignoring new file: %s/%s", state.dir, e->name); + } + } + + handle_inotifyfd(ifd); +} + +static void graceful_shutdown(const char* reason, int sig) +{ + debug("initiating graceful shutdown: %s", reason); + + struct trace* t = state.traces; + while(t) { + int r = kill(t->tail, sig); + CHECK(r, "kill(%d, %s)", t->tail, strsignal(sig)); + debug("kill(%d, %s)", t->tail, strsignal(sig)); + t = t->next; + } + + state.running = 0; +} + +static void handle_signalfd(int sfd) +{ + struct signalfd_siginfo si; + ssize_t s = read(sfd, &si, sizeof(si)); + if(s == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) { + return; + } + CHECK(s, "read"); + if(s != sizeof(si)) { + failwith("unexpected partial read"); + } + + if(si.ssi_signo == SIGCHLD) { + while(1) { + int ws; + pid_t r = waitpid(-1, &ws, WNOHANG); + if(r == -1 && (errno == ECHILD)) { + break; + } + CHECK(r, "wait"); + if(r == 0) { + break; + } + + if(WIFEXITED(ws)) { + info("child (%d) exited: %d", r, WEXITSTATUS(ws)); + } else if(WIFSIGNALED(ws)) { + info("child (%d) signaled: %s", r, strsignal(WTERMSIG(ws))); + } else { + failwith("unexpected waitpid (%d) status: %d", r, ws); + } + + struct trace** t = &state.traces; + int n = 0; + while(*t) { + if((*t)->tail == r) { + trace_close(*t); + n += 1; + *t = (*t)->next; + break; + } + + t = &(*t)->next; + } + + if(n == 0) { + failwith("unexpected pid: %d", r); + } + } + } else if(si.ssi_signo == SIGINT || si.ssi_signo == SIGTERM) { + int sig = si.ssi_signo; + const char* sigstr = strsignal(sig); + info("%s", sigstr); + graceful_shutdown(sigstr, sig); + } else { + failwith("unexpected signal: %s", strsignal(si.ssi_signo)); + } + + handle_signalfd(sfd); +} + +static struct trace* 
find_trace_from_fd(int fd) +{ + struct trace* p = state.traces; + while(p) { + if(p->fd == fd) { + return p; + } + p = p->next; + } + failwith("unable to resolve fd: %d", fd); +} + +static void handle_tail(int fd) +{ + struct trace* t = find_trace_from_fd(fd); + buffer_read(&t->buf, fd); + + ssize_t i = buffer_find(&t->buf, '\n'); + if(i >= 0) { + char buf[48]; + int r = snprintf(LIT(buf), "%d ", t->pid); + if(r >= sizeof(buf)) { + failwith("truncated"); + } + buffer_append(&state.out, buf, r); + + buffer_move(&t->buf, &state.out, i + 1); + } +} + +static void expand_fds(void) +{ + size_t N = state.fds_N << 1; + debug("expanding fds: %zu -> %zu", state.fds_N, N); + + struct pollfd* p = calloc(N, sizeof(struct pollfd)); + CHECK_MALLOC(p); + + memcpy(p, state.fds, sizeof(struct pollfd)*state.fds_n); + + free(state.fds); + + state.fds = p; + state.fds_N = N; +} + +int main(int argc, char* argv[]) +{ + struct options o; + parse_options(&o, argc, argv); + + memset(&state, 0, sizeof(state)); + + int wfd = -1; + if(o.pid >= 0) { + info("waiting for %d", o.pid); + wfd = pidfd_open(o.pid, O_NONBLOCK); // TODO: PIDFD_NONBLOCK + CHECK(wfd, "pidfd_open(%d)", o.pid); + } + + debug("pattern: %s", o.pattern); + split_path(o.pattern, &state.dir, &state.prefix); + debug("dir: %s", state.dir); + debug("prefix: %s", state.prefix); + + buffer_init(&state.out, 1024); + + int ifd = inotify_init1(IN_NONBLOCK | IN_CLOEXEC); + CHECK(ifd, "inotify_init1"); + + state.wd = inotify_add_watch(ifd, state.dir, IN_CREATE); + CHECK(state.wd, "inotify_add_watch(%s, IN_CREATE)", state.dir); + + sigemptyset(&state.sm); + sigaddset(&state.sm, SIGINT); + sigaddset(&state.sm, SIGTERM); + sigaddset(&state.sm, SIGCHLD); + int sfd = signalfd(-1, &state.sm, SFD_NONBLOCK | SFD_CLOEXEC); + CHECK(sfd, "signalfd"); + + int r = sigprocmask(SIG_BLOCK, &state.sm, NULL); + CHECK(r, "sigprocmask"); + + state.running = 1; + state.fds_N = 1<<4; + size_t M = 3 + (wfd >= 0); + assert(state.fds_N >= M); + state.fds = 
calloc(state.fds_N, sizeof(struct pollfd)); + CHECK_MALLOC(state.fds); + state.fds_n = M; + + size_t sfd_i = 0; + assert(sfd_i < M); + state.fds[sfd_i].fd = sfd; + state.fds[sfd_i].events = POLLIN; + + size_t ifd_i = 1; + assert(ifd_i < M); + state.fds[ifd_i].fd = ifd; + state.fds[ifd_i].events = POLLIN; + + size_t ofd_i = 2; + assert(ofd_i < M); + state.fds[ofd_i].fd = 1; + state.fds[ofd_i].events = 0; + + ssize_t wfd_i = wfd >= 0 ? 3 : -1; + assert((wfd < 0) || wfd_i < M); + if(wfd_i) { + state.fds[wfd_i].fd = wfd; + state.fds[wfd_i].events = POLLIN; + } + + while(state.running || state.traces != NULL) { + if(buffer_empty(&state.out)) { + state.fds[ofd_i].events = 0; + } else { + state.fds[ofd_i].events = POLLOUT; + } + + const size_t n = trace_n(); + while(n + M > state.fds_N) { + expand_fds(); + } + + struct trace* t = state.traces; + size_t i = M; + while(t != NULL) { + state.fds[i].fd = t->fd; + state.fds[i].events = POLLIN; + + i += 1; + t = t->next; + } + + debug("polling: %zu", i); + int r = poll(state.fds, i, -1); + CHECK(r, "poll"); + debug("poll: %d", r); + + for(size_t j = M; j < i; j++) { + handle_tail(state.fds[j].fd); + state.fds[j].revents ^= POLLIN; + } + + if(state.fds[sfd_i].revents & POLLIN) { + handle_signalfd(state.fds[sfd_i].fd); + state.fds[sfd_i].revents ^= POLLIN; + } + + if(state.fds[ifd_i].revents & POLLIN) { + handle_inotifyfd(state.fds[ifd_i].fd); + state.fds[ifd_i].revents ^= POLLIN; + } + + if(state.fds[ofd_i].revents & POLLOUT) { + buffer_write(&state.out, state.fds[ofd_i].fd); + state.fds[ofd_i].revents ^= POLLOUT; + } + + if(wfd_i && state.fds[wfd_i].revents & POLLIN) { + graceful_shutdown("waited for pid died", SIGINT); + state.fds[wfd_i].revents ^= POLLIN; + } + + for(size_t i = 0; i < state.fds_n; i++) { + if(state.fds[i].revents != 0) { + failwith("unhandled events: fds[%zu].revents = %d", + i, state.fds[i].revents); + } + } + } + + debug("bye"); + + return 0; +} diff --git a/tools/test-runner b/tools/test-runner index 
366078a..d7772be 100755 --- a/tools/test-runner +++ b/tools/test-runner @@ -10,6 +10,10 @@ import subprocess import sys import tempfile +TOOLS = os.environ.get("TOOLS", os.path.dirname(os.path.realpath(__file__))) +CAPTURE = os.environ.get("CAPTURE", os.path.join(TOOLS, "capture")) +WATCH = os.environ.get("WATCH", os.path.join(TOOLS, "strace-log-watch")) + def parse_args(): parser = argparse.ArgumentParser(description="Yet another test runner") @@ -90,60 +94,67 @@ if __name__ == "__main__": res["input"] = input_ logger.debug(f"test input: {input_}") - cmdline = [] - if args.trace: - trace = tempfile.NamedTemporaryFile() - cmdline += ["strace", "--follow-forks", "--output-append-mode", "--output", trace.name] - else: - trace = None - cmdline.append(sut) - cmdline += t.get("args", []) - cmdline.append(input_) - res["cmdline"] = cmdline - - res["when"] = datetime.datetime.now().astimezone().isoformat(timespec="seconds") - - res["result"] = True - res["messages"] = [] + with tempfile.TemporaryDirectory() as tmp: + stdout = os.path.join(tmp, "stdout") + stderr = os.path.join(tmp, "stderr") + returncode = os.path.join(tmp, "returncode") - try: + cmdline = [CAPTURE, "-o", stdout, "-e", stderr, "-r", returncode] if args.silent: - p = subprocess.run(cmdline, capture_output=True, cwd=root, timeout=args.timeout) - res["stdout"] = p.stdout.decode("UTF-8") - res["stderr"] = p.stderr.decode("UTF-8") + cmdline += ["-O", "-E"] + cmdline += ["--"] + + if args.trace: + trace = os.path.join(tmp, "trace") + cmdline += [ + "strace", + "--absolute-timestamps=precision:us", + "--follow-forks", "--output-separately", + "--output", trace, + ] else: - with tempfile.TemporaryDirectory() as tmp: - run = os.path.join(tmp, "run.sh") - with open(run, "w") as f: - f.write(""" -exec 1> >(tee "$1") -(tail -F "$3" 1>&2 2>/dev/null) & -exec 2> >(tee "$2" 1>&2) -shift 3 -exec "$@" -""" -) - stdout = os.path.join(tmp, "stdout") - stderr = os.path.join(tmp, "stderr") - p = subprocess.run(["bash", run, 
stdout, stderr, trace.name if trace else "/dev/null"] + cmdline, cwd=root, timeout=args.timeout) - with open(stdout, "rb") as f: - res["stdout"] = f.read().decode("UTF-8") - with open(stderr, "rb") as f: - res["stderr"] = f.read().decode("UTF-8") - - res["returncode"] = p.returncode - expected_exit = t.get("exit", 0) - if isinstance(expected_exit, str): - expected_exit = -int(signal.__dict__[expected_exit]) - res["expected_returncode"] = expected_exit - except subprocess.TimeoutExpired: - res["timeout"] = args.timeout - res["result"] = False - res["messages"].append("process timedout") + trace = None + + cmdline += [sut] + cmdline += t.get("args", []) + cmdline += [input_] - if trace: - trace.seek(0) - res["trace"] = trace.read().decode("UTF-8") + res["cmdline"] = cmdline + res["when"] = datetime.datetime.now().astimezone().isoformat(timespec="seconds") + p = subprocess.Popen(cmdline, cwd=root) + + if trace and not args.silent: + w = subprocess.Popen([WATCH, "-p", str(p.pid), trace]) + else: + w = None + + res["result"] = True + res["messages"] = [] + try: + p.wait(timeout=args.timeout) + + with open(returncode, "r") as f: + res["returncode"] = int(f.read()) + expected_exit = t.get("exit", 0) + if isinstance(expected_exit, str): + expected_exit = -int(signal.__dict__[expected_exit]) + res["expected_returncode"] = expected_exit + except subprocess.TimeoutExpired: + res["timeout"] = args.timeout + res["result"] = False + res["messages"].append("process timedout") + + if w is not None: + w.wait() + + with open(stdout, "rb") as f: + res["stdout"] = f.read().decode("UTF-8") + with open(stderr, "rb") as f: + res["stderr"] = f.read().decode("UTF-8") + + if trace: + q = subprocess.run(["strace-log-merge", trace], capture_output=True, check=True) + res["trace"] = q.stdout.decode("UTF-8") expected_stdout_fn = t.get("stdout") if expected_stdout_fn is None: