Skip to content

Commit

Permalink
fix codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
SamDanielThangarajan committed Jan 22, 2025
1 parent bfd67ea commit 1d3e414
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 58 deletions.
7 changes: 5 additions & 2 deletions src/nasdaq_protocols/fix/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
@click.option('--prefix', type=click.STRING, default='')
@click.option('--op-dir', type=click.Path(exists=True, writable=True))
@click.option('--init-file/--no-init-file', show_default=True, default=True)
def generate(spec_file, app_name, op_dir, prefix, init_file):
@click.option('--fix-version',
type=click.Choice(['4.2', '4.4', '5.0', '5.0SP2']),
default='5.0SP2')
def generate(spec_file, app_name, op_dir, prefix, init_file, fix_version):

try:
generator = Generator(
parse(spec_file),
parse(spec_file, fix_version),
app_name,
op_dir,
prefix,
Expand Down
8 changes: 3 additions & 5 deletions src/nasdaq_protocols/fix/parser/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
'Definitions'
]

from nasdaq_protocols.fix.parser.version_types import Version


@attrs.define
class FieldDef:
Expand Down Expand Up @@ -127,7 +125,7 @@ def get_codegen_context(self, definitions):

@attrs.define
class Definitions:
version: Version
version: str
fields: dict[str, FieldDef] = attrs.field(init=False, factory=dict)
components: dict[str, Component] = attrs.field(kw_only=True, factory=dict)
header: EntryContainer = attrs.field(kw_only=True, factory=EntryContainer)
Expand Down Expand Up @@ -157,8 +155,8 @@ def get_codegen_context(self):
}

def _client_session(self):
if self.version == Version.FIX_4_4:
if self.version == '4.4':
return 'Fix44Session'
if self.version in (Version.FIX_5_0, Version.FIX_5_0_2):
if self.version in ('5.0', '5.0SP2'):
return 'Fix50Session'
raise ValueError(f'Version {self.version} is not supported')
14 changes: 1 addition & 13 deletions src/nasdaq_protocols/fix/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
)
from .version_types import (
get_supported_types,
Version,
SupportedTypes
)

Expand All @@ -24,24 +23,13 @@
LOG = logging.getLogger(__name__)


def parse(file: str) -> Definitions:
def parse(file: str, version: str) -> Definitions:
tree = e_tree.parse(file)
root = tree.getroot()

if root.tag != 'fix':
raise ValueError('root tag is not fix')

version_str = f'{root.get("major")}{root.get("minor")}'
servicepack = int(root.get('servicepack', '0'))
if servicepack > 0:
version_str += f'{servicepack}'
version = int(version_str)

try:
version = Version(version)
except ValueError as v_error:
raise ValueError(f'Version {version} is not supported') from v_error

handlers = {
'fields': partial(_handle_fields, get_supported_types(version)),
'components': _handle_components,
Expand Down
34 changes: 12 additions & 22 deletions src/nasdaq_protocols/fix/parser/version_types.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,28 @@
import enum

from .. import types
from ...common import TypeDefinition


__all__ = [
'SupportedTypes',
'Version',
'get_supported_types'
]
SupportedTypes = dict[str, TypeDefinition]


class Version(enum.IntEnum):
FIX_4_2 = 42
FIX_4_4 = 44
FIX_5_0 = 50
FIX_5_0_2 = 502


def get_supported_types(version: Version) -> SupportedTypes:
def get_supported_types(version: str) -> SupportedTypes:
version_map = {
Version.FIX_4_2: fix_42_version_types,
Version.FIX_4_4: fix_44_version_types,
Version.FIX_5_0: fix_50_version_types,
Version.FIX_5_0_2: fix_502_version_types
'4.2': _fix_42_version_types,
'4.4': _fix_44_version_types,
'5.0': _fix_50_version_types,
'5.0SP2': _fix_502_version_types
}
try:
return version_map[version]()
except KeyError as k_error:
raise ValueError(f'Version {version} not supported') from k_error


def fix_42_version_types():
def _fix_42_version_types():
return {
'AMT': types.FixAmount,
'BOOLEAN': types.FixBool,
Expand Down Expand Up @@ -61,17 +51,17 @@ def fix_42_version_types():
}


def fix_44_version_types():
fix_44_types = fix_42_version_types()
def _fix_44_version_types():
fix_44_types = _fix_42_version_types()
fix_44_types.update({
'SEQNUM': types.FixInt,
'NUMINGROUP': types.FixInt,
})
return fix_44_types


def fix_50_version_types():
fix_50_types = fix_42_version_types()
def _fix_50_version_types():
fix_50_types = _fix_42_version_types()
fix_50_types.update({
'FIXSTRING': types.FixString,
'MULTIPLECHARVALUE': types.FixString,
Expand All @@ -81,8 +71,8 @@ def fix_50_version_types():
return fix_50_types


def fix_502_version_types():
fix_52_types = fix_50_version_types()
def _fix_502_version_types():
fix_52_types = _fix_50_version_types()
fix_52_types.update({
'LOCALMKTDATE': types.FixLocalMktDate,
'TZTIMEONLY': types.FixTzTimeonly,
Expand Down
6 changes: 4 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,23 @@ async def mock_server_session(unused_tcp_port):

@pytest.fixture(scope='function')
def codegen_invoker(capsys, tmp_path):
def generator(codegen, xml_content, app_name, generate_init_file, prefix, output_dir=None):
def generator(codegen, xml_content, app_name, generate_init_file, prefix, output_dir=None, extra_args=None):
runner = CliRunner()
with capsys.disabled(), runner.isolated_filesystem(temp_dir=tmp_path):
with open('spec.xml', 'w') as spec_file:
spec_file.write(xml_content)
output_dir = output_dir or 'output'
Path(output_dir).mkdir(parents=True, exist_ok=True)
extra_args = extra_args or []
result = runner.invoke(
codegen,
[
'--spec-file', 'spec.xml',
'--app-name', app_name,
'--op-dir', output_dir,
'--prefix', prefix,
'--init-file' if generate_init_file else '--no-init-file'
'--init-file' if generate_init_file else '--no-init-file',
*extra_args
]
)
assert result.exit_code == 0
Expand Down
8 changes: 5 additions & 3 deletions tests/test_fix_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
@pytest.fixture(scope='function')
def fix_44_definitions(tmp_file_writer):
file = tmp_file_writer(TEST_FIX_44_XML)
definitions = parse(file)
definitions = parse(file, '4.4')
assert definitions is not None
yield definitions

Expand All @@ -26,7 +26,8 @@ def test__no_init_file__no_prefix__code_generated(codegen_invoker):
TEST_FIX_44_XML,
app_name,
generate_init_file=False,
prefix=prefix
prefix=prefix,
extra_args=['--fix-version', '4.4']
)

assert len(generated_files) == 5
Expand All @@ -42,7 +43,8 @@ def test__init_file__no_prefix__code_generated(fix_44_definitions, codegen_invok
app_name,
generate_init_file=True,
prefix=prefix,
output_dir=output_dir
output_dir=output_dir,
extra_args=['--fix-version', '4.4']
)

assert len(generated_files) == 6
Expand Down
22 changes: 11 additions & 11 deletions tests/test_fix_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
@pytest.fixture(scope='function')
def fix_44_definitions(tmp_file_writer):
file = tmp_file_writer(TEST_FIX_44_XML)
definitions = parse(file)
definitions = parse(file, '4.4')
assert definitions is not None
yield definitions

Expand All @@ -21,7 +21,7 @@ def test__fix_parser__parse__invalid_root_tag(tmp_file_writer):
file = tmp_file_writer(invalid_xml)

with pytest.raises(ValueError) as e:
parse(file)
parse(file, '4.4')

assert str(e.value) == 'root tag is not fix'

Expand All @@ -37,7 +37,7 @@ def test__fix_parser__parse__invalid_tag_in_fields(tmp_file_writer):
file = tmp_file_writer(invalid_xml)

with pytest.raises(ValueError) as e:
parse(file)
parse(file, '4.4')

assert str(e.value) == 'expected field tag, got invalid'

Expand All @@ -53,9 +53,9 @@ def test__fix_parser__parse__unsupported_version(tmp_file_writer):
file = tmp_file_writer(invalid_xml)

with pytest.raises(ValueError) as e:
parse(file)
parse(file, '2.1')

assert str(e.value) == 'Version 21 is not supported'
assert str(e.value) == 'Version 2.1 not supported'


def test__fix_parser__parse__component_not_found(tmp_file_writer):
Expand All @@ -73,7 +73,7 @@ def test__fix_parser__parse__component_not_found(tmp_file_writer):
file = tmp_file_writer(invalid_xml)

with pytest.raises(ValueError) as e:
parse(file)
parse(file, '4.4')

assert str(e.value) == 'Component definition for NotFoundComponent not found'

Expand All @@ -91,7 +91,7 @@ def test__fix_parser__parse__field_not_found(tmp_file_writer):
file = tmp_file_writer(invalid_xml)

with pytest.raises(ValueError) as e:
parse(file)
parse(file, '4.4')

assert str(e.value) == 'Field definition for NotFound not found'

Expand All @@ -103,8 +103,8 @@ def test__fix_parser__parse__xml_with_service_pack(tmp_file_writer):
'''
file = tmp_file_writer(fix_502)

definitions = parse(file)
assert definitions.version == 502
definitions = parse(file, '5.0SP2')
assert definitions.version == '5.0SP2'


def test__fix_parser__parse__xml_with_keywords__keywords_are_transformed(tmp_file_writer):
Expand All @@ -120,8 +120,8 @@ def test__fix_parser__parse__xml_with_keywords__keywords_are_transformed(tmp_fil
'''
file = tmp_file_writer(fix_502)

definitions = parse(file)
assert definitions.version == 502
definitions = parse(file, '5.0SP2')
assert definitions.version == '5.0SP2'
context = definitions.fields['MsgType'].get_codegen_context(None)
assert context['values'][0]['f_value'] == 'None_'
assert context['values'][1]['f_value'] == 'if_'
Expand Down

0 comments on commit 1d3e414

Please sign in to comment.