Skip to content

Commit

Permalink
unit test for new iteration API by section
Browse files Browse the repository at this point in the history
code cleaning in iter_record_objects_by_section function
GitHub issue - #12
  • Loading branch information
mugdhadhole1 committed Jul 18, 2024
1 parent e916063 commit 4ac056a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
40 changes: 40 additions & 0 deletions tests-unit/test_ast_bysection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import unittest
from unittest.mock import patch, MagicMock
from trlc.ast import Symbol_Table

class TestRecordObject:
def __init__(self, location, section):
self.location = location
self.section = section

class TestSection:
def __init__(self, name):
self.name = name

class TestIterRecordObjectsBySection(unittest.TestCase):

@patch("trlc.ast.Symbol_Table.iter_record_objects")
def test_iter_record_objects_by_section(self, mock_iter_record_objects):
mock_location1 = MagicMock(file_name = 'file1')
mock_section1 = TestSection('section1')
mock_section2 = TestSection('section2')
mock_location2 = MagicMock(file_name = 'file2')
record1 = TestRecordObject(mock_location1, [mock_section1, mock_section2])
record2 = TestRecordObject(mock_location2, [])
mock_iter_record_objects.return_value = [record1, record2]

results = list(Symbol_Table().iter_record_objects_by_section())

expected_results = [
'file1',
('section1', 0),
('section2', 1),
(record1, 1),
'file2',
(record2, 0)
]

self.assertEqual(results, expected_results)

if __name__ == '__main__':
unittest.main()
4 changes: 1 addition & 3 deletions trlc/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3096,15 +3096,13 @@ def iter_record_objects_by_section(self):
yield location
if record_object.section:
object_level = len(record_object.section) - 1
else:
object_level = 0
if record_object.section:
for level, section in enumerate(record_object.section):
if section not in self.section_names:
self.section_names.append(section)
yield section.name, level
yield record_object, object_level
else:
object_level = 0
yield record_object, object_level

def iter_record_objects(self):
Expand Down

0 comments on commit 4ac056a

Please sign in to comment.