@contextlib.contextmanager
def redirecting_print_to_tqdm(disable_progress: bool):
    """
    tqdm (progress bar) expects to have fairly tight control over console output.
    so calls to `print()` will break the progress bar and make things look bad.
    so, this context manager temporarily replaces the `print` implementation
    with one that is compatible with tqdm.

    args:
      disable_progress: when true, the caller is not showing a progress bar,
        so `print` is left completely untouched.

    the previous `print` implementation is always restored on exit,
    even if the managed block raises.

    via: https://stackoverflow.com/a/42424890/87207
    """
    if disable_progress:
        # nothing to protect: avoid the global monkeypatch entirely.
        yield
        return

    # local imports keep this helper self-contained;
    # `builtins` is the documented home of `print`.
    import sys
    import builtins

    old_print = builtins.print

    def new_print(*args, sep=" ", end="\n", file=None, flush=False):
        # `tqdm.tqdm.write` accepts a single string plus `file`/`end`,
        # so render the message the same way `print` would first.
        try:
            text = (" " if sep is None else sep).join(str(arg) for arg in args)
            tqdm.tqdm.write(text, file=file, end="\n" if end is None else end)
            if flush:
                (sys.stdout if file is None else file).flush()
        except Exception:
            # if tqdm cannot handle the write, fall back to the builtin print.
            # only `Exception` is caught, so Ctrl-C and SystemExit still propagate.
            old_print(*args, sep=sep, end=end, file=file, flush=flush)

    try:
        # globally replace print with new_print.
        builtins.print = new_print  # type: ignore
        yield
    finally:
        builtins.print = old_print  # type: ignore
library function 0x%x (%s)", f.address, function_name) - meta["library_functions"][f.address] = function_name - n_libs = len(meta["library_functions"]) - percentage = round(100 * (n_libs / n_funcs)) - if isinstance(pb, tqdm.tqdm): - pb.set_postfix_str(f"skipped {n_libs} library functions ({percentage}%)") - continue + pb = pbar(functions, desc="matching", unit=" functions", postfix="skipped 0 library functions") + for f in pb: + if extractor.is_library_function(f.address): + function_name = extractor.get_function_name(f.address) + logger.debug("skipping library function 0x%x (%s)", f.address, function_name) + meta["library_functions"][f.address] = function_name + n_libs = len(meta["library_functions"]) + percentage = round(100 * (n_libs / n_funcs)) + if isinstance(pb, tqdm.tqdm): + pb.set_postfix_str(f"skipped {n_libs} library functions ({percentage}%)") + continue - function_matches, bb_matches, insn_matches, feature_count = find_code_capabilities(ruleset, extractor, f) - meta["feature_counts"]["functions"][f.address] = feature_count - logger.debug("analyzed function 0x%x and extracted %d features", f.address, feature_count) + function_matches, bb_matches, insn_matches, feature_count = find_code_capabilities(ruleset, extractor, f) + meta["feature_counts"]["functions"][f.address] = feature_count + logger.debug("analyzed function 0x%x and extracted %d features", f.address, feature_count) - for rule_name, res in function_matches.items(): - all_function_matches[rule_name].extend(res) - for rule_name, res in bb_matches.items(): - all_bb_matches[rule_name].extend(res) - for rule_name, res in insn_matches.items(): - all_insn_matches[rule_name].extend(res) + for rule_name, res in function_matches.items(): + all_function_matches[rule_name].extend(res) + for rule_name, res in bb_matches.items(): + all_bb_matches[rule_name].extend(res) + for rule_name, res in insn_matches.items(): + all_insn_matches[rule_name].extend(res) # collection of features that captures the rule 
matches within function, BB, and instruction scopes. # mapping from feature (matched rule) to set of addresses at which it matched. diff --git a/scripts/lint.py b/scripts/lint.py index 92c7fbcf..a80d3e12 100644 --- a/scripts/lint.py +++ b/scripts/lint.py @@ -22,13 +22,11 @@ import time import string import difflib import hashlib -import inspect import logging import pathlib import argparse import itertools import posixpath -import contextlib from typing import Set, Dict, List from pathlib import Path from dataclasses import field, dataclass @@ -866,37 +864,6 @@ def width(s, count): return s.ljust(count) -@contextlib.contextmanager -def redirecting_print_to_tqdm(): - """ - tqdm (progress bar) expects to have fairly tight control over console output. - so calls to `print()` will break the progress bar and make things look bad. - so, this context manager temporarily replaces the `print` implementation - with one that is compatible with tqdm. - - via: https://stackoverflow.com/a/42424890/87207 - """ - old_print = print - - def new_print(*args, **kwargs): - # If tqdm.tqdm.write raises error, use builtin print - try: - tqdm.tqdm.write(*args, **kwargs) - except: - old_print(*args, **kwargs) - - try: - # Globally replace print with new_print. - # Verified this works manually on Python 3.11: - # >>> import inspect - # >>> inspect.builtins - # - inspect.builtins.print = new_print # type: ignore - yield - finally: - inspect.builtins.print = old_print # type: ignore - - def lint(ctx: Context): """ Returns: Dict[string, Tuple(int, int)] @@ -907,7 +874,7 @@ def lint(ctx: Context): source_rules = [rule for rule in ctx.rules.rules.values() if not rule.is_subscope_rule()] with tqdm.contrib.logging.tqdm_logging_redirect(source_rules, unit="rule") as pbar: - with redirecting_print_to_tqdm(): + with capa.helpers.redirecting_print_to_tqdm(False): for rule in pbar: name = rule.name pbar.set_description(width(f"linting rule: {name}", 48))