diff --git a/doc_comments_ai/treesitter/treesitter_hs.py b/doc_comments_ai/treesitter/treesitter_hs.py index a9644b7..b595d3b 100644 --- a/doc_comments_ai/treesitter/treesitter_hs.py +++ b/doc_comments_ai/treesitter/treesitter_hs.py @@ -1,6 +1,4 @@ import tree_sitter -from typing import List, Dict - from doc_comments_ai.constants import Language from doc_comments_ai.treesitter.treesitter import (Treesitter, @@ -10,9 +8,7 @@ class TreesitterHaskell(Treesitter): def __init__(self): - super().__init__( - Language.HASKELL, "function", "variable", "comment" - ) + super().__init__(Language.HASKELL, "function", "variable", "comment") def parse(self, file_bytes: bytes) -> list[TreesitterMethodNode]: self.tree = self.parser.parse(file_bytes) @@ -23,10 +19,15 @@ def parse(self, file_bytes: bytes) -> list[TreesitterMethodNode]: doc_comment = method["doc_comment"] source_code = None if method["method"].type == "signature": - sc = map(lambda x : "\n" + x.text.decode() if x.type == "function" else "", method["method"].children) + sc = map( + lambda x: "\n" + x.text.decode() if x.type == "function" else "", + method["method"].children, + ) source_code = method["method"].text.decode() + "".join(sc) result.append( - TreesitterMethodNode(method_name, doc_comment, source_code, method["method"]) + TreesitterMethodNode( + method_name, doc_comment, source_code, method["method"] + ) ) return result @@ -34,7 +35,7 @@ def _query_all_methods( self, node: tree_sitter.Node, ): - methods: List[Dict[tree_sitter.Node, tree_sitter.Node]] = [] + methods = [] if node.type == self.method_declaration_identifier: doc_comment_node = None if ( @@ -43,11 +44,15 @@ def _query_all_methods( ): doc_comment_node = node.prev_named_sibling.text.decode() else: - if node.prev_named_sibling.type == "signature": + if ( + node.prev_named_sibling + and node.prev_named_sibling.type == "signature" + ): prev_node = node.prev_named_sibling if ( prev_node.prev_named_sibling - and prev_node.prev_named_sibling.type == self.doc_comment_identifier + and prev_node.prev_named_sibling.type + == self.doc_comment_identifier ): doc_comment_node = prev_node.prev_named_sibling.text.decode() prev_node.children.append(node) @@ -58,8 +63,12 @@ def _query_all_methods( current = self._query_all_methods(child) if methods and current: previous = methods[-1] - if self._query_method_name(previous["method"]) == self._query_method_name(current[0]["method"]): - previous["method"].children.extend(map(lambda x: x["method"], current)) + if self._query_method_name( + previous["method"] + ) == self._query_method_name(current[0]["method"]): + previous["method"].children.extend( + map(lambda x: x["method"], current) + ) methods = methods[:-1] methods.append(previous) else: