diff --git a/sqlelf/cli.py b/sqlelf/cli.py index e7d0f3c..97e0b9a 100644 --- a/sqlelf/cli.py +++ b/sqlelf/cli.py @@ -2,19 +2,19 @@ import os import os.path import sys +from dataclasses import dataclass, field from functools import reduce +from typing import TextIO import lief from sqlelf import sql as api_sql -from typing import TextIO -from dataclasses import dataclass @dataclass class ProgramArguments: - filenames: list[str] - sql: list[str] + filenames: list[str] = field(default_factory=list) + sql: list[str] = field(default_factory=list) recursive: bool = False @@ -46,7 +46,9 @@ def start(args: list[str] = sys.argv[1:], stdin: TextIO = sys.stdin): help="Load all shared libraries needed by each file using ldd", ) - args = parser.parse_args(args) + program_args: ProgramArguments = parser.parse_args( + args, namespace=ProgramArguments() + ) # Iterate through our arguments and if one of them is a directory explode it out filenames: list[str] = reduce( @@ -55,7 +57,7 @@ def start(args: list[str] = sys.argv[1:], stdin: TextIO = sys.stdin): lambda dir: [os.path.join(dir, f) for f in os.listdir(dir)] if os.path.isdir(dir) else [dir], - args.filenames, + program_args.filenames, ), ) # Filter the list of filenames to those that are ELF files only @@ -67,11 +69,11 @@ def start(args: list[str] = sys.argv[1:], stdin: TextIO = sys.stdin): binaries: list[lief.Binary] = [lief.parse(filename) for filename in filenames] - sql_engine = api_sql.make_sql_engine(binaries, recursive=args.recursive) + sql_engine = api_sql.make_sql_engine(binaries, recursive=program_args.recursive) shell = sql_engine.shell(stdin=stdin) - if args.sql: - for sql in args.sql: + if program_args.sql and len(program_args.filenames) > 0: + for sql in program_args.sql: shell.process_complete_line(sql) else: shell.cmdloop() diff --git a/sqlelf/elf.py b/sqlelf/elf.py index 3dda67c..7478a9d 100644 --- a/sqlelf/elf.py +++ b/sqlelf/elf.py @@ -244,7 +244,6 @@ def register_virtual_tables( (make_symbols_generator, "raw_elf_symbols"), ] for factory, name in factory_and_names: - print(name) generator = factory(binaries) # setup columns and access by providing an example of the first entry returned generator.columns, generator.column_access = apsw.ext.get_column_names(