Skip to content
This repository has been archived by the owner on Apr 16, 2024. It is now read-only.

Commit

Permalink
Apply black to the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kmmbvnr committed Nov 4, 2021
1 parent 5a4dba4 commit 8e40cf7
Show file tree
Hide file tree
Showing 23 changed files with 356 additions and 338 deletions.
121 changes: 71 additions & 50 deletions django_fsm/management/commands/graph_transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@

try:
from django.db.models import get_apps, get_app, get_models, get_model

NEW_META_API = False
except ImportError:
from django.apps import apps

NEW_META_API = True

from django import VERSION
Expand All @@ -22,20 +24,18 @@

def all_fsm_fields_data(model):
if NEW_META_API:
return [(field, model) for field in model._meta.get_fields()
if isinstance(field, FSMFieldMixin)]
return [(field, model) for field in model._meta.get_fields() if isinstance(field, FSMFieldMixin)]
else:
return [(field, model) for field in model._meta.fields
if isinstance(field, FSMFieldMixin)]
return [(field, model) for field in model._meta.fields if isinstance(field, FSMFieldMixin)]


def node_name(field, state):
opts = field.model._meta
return "%s.%s.%s.%s" % (opts.app_label, opts.verbose_name.replace(' ', '_'), field.name, state)
return "%s.%s.%s.%s" % (opts.app_label, opts.verbose_name.replace(" ", "_"), field.name, state)


def node_label(field, state):
if type(state) == int or (type(state) == bool and hasattr(field, 'choices')):
if type(state) == int or (type(state) == bool and hasattr(field, "choices")):
return force_text(dict(field.choices).get(state))
else:
return state
Expand All @@ -49,62 +49,63 @@ def generate_dot(fields_data):

# dump nodes and edges
for transition in field.get_all_transitions(model):
if transition.source == '*':
if transition.source == "*":
any_targets.add((transition.target, transition.name))
elif transition.source == '+':
elif transition.source == "+":
any_except_targets.add((transition.target, transition.name))
else:
_targets =\
(state for state in transition.target.allowed_states)\
if isinstance(transition.target, (GET_STATE, RETURN_VALUE))\
_targets = (
(state for state in transition.target.allowed_states)
if isinstance(transition.target, (GET_STATE, RETURN_VALUE))
else (transition.target,)
source_name_pair =\
((state, node_name(field, state)) for state in transition.source.allowed_states)\
if isinstance(transition.source, (GET_STATE, RETURN_VALUE))\
)
source_name_pair = (
((state, node_name(field, state)) for state in transition.source.allowed_states)
if isinstance(transition.source, (GET_STATE, RETURN_VALUE))
else ((transition.source, node_name(field, transition.source)),)
)
for source, source_name in source_name_pair:
if transition.on_error:
on_error_name = node_name(field, transition.on_error)
targets.add(
(on_error_name, node_label(field, transition.on_error))
)
edges.add((source_name, on_error_name, (('style', 'dotted'),)))
targets.add((on_error_name, node_label(field, transition.on_error)))
edges.add((source_name, on_error_name, (("style", "dotted"),)))
for target in _targets:
add_transition(source, target, transition.name,
source_name, field, sources, targets, edges)
add_transition(source, target, transition.name, source_name, field, sources, targets, edges)

targets.update(set((node_name(field, target), node_label(field, target))
for target, _ in chain(any_targets, any_except_targets)))
targets.update(
set((node_name(field, target), node_label(field, target)) for target, _ in chain(any_targets, any_except_targets))
)
for target, name in any_targets:
target_name = node_name(field, target)
all_nodes = sources | targets
for source_name, label in all_nodes:
sources.add((source_name, label))
edges.add((source_name, target_name, (('label', name),)))
edges.add((source_name, target_name, (("label", name),)))

for target, name in any_except_targets:
target_name = node_name(field, target)
all_nodes = sources | targets
all_nodes.remove(((target_name, node_label(field, target))))
for source_name, label in all_nodes:
sources.add((source_name, label))
edges.add((source_name, target_name, (('label', name),)))
edges.add((source_name, target_name, (("label", name),)))

# construct subgraph
opts = field.model._meta
subgraph = graphviz.Digraph(
name="cluster_%s_%s_%s" % (opts.app_label, opts.object_name, field.name),
graph_attr={'label': "%s.%s.%s" % (opts.app_label, opts.object_name, field.name)})
graph_attr={"label": "%s.%s.%s" % (opts.app_label, opts.object_name, field.name)},
)

final_states = targets - sources
for name, label in final_states:
subgraph.node(name, label=label, shape='doublecircle')
subgraph.node(name, label=label, shape="doublecircle")
for name, label in (sources | targets) - final_states:
subgraph.node(name, label=label, shape='circle')
subgraph.node(name, label=label, shape="circle")
if field.default: # Adding initial state notation
if label == field.default:
initial_name = node_name(field, '_initial')
subgraph.node(name=initial_name, label='', shape='point')
initial_name = node_name(field, "_initial")
subgraph.node(name=initial_name, label="", shape="point")
subgraph.edge(initial_name, name)
for source_name, target_name, attrs in edges:
subgraph.edge(source_name, target_name, **dict(attrs))
Expand All @@ -118,7 +119,7 @@ def add_transition(transition_source, transition_target, transition_name, source
target_name = node_name(field, transition_target)
sources.add((source_name, node_label(field, transition_source)))
targets.add((target_name, node_label(field, transition_target)))
edges.add((source_name, target_name, (('label', transition_name),)))
edges.add((source_name, target_name, (("label", transition_name),)))


def get_graphviz_layouts():
Expand All @@ -127,50 +128,70 @@ def get_graphviz_layouts():

return graphviz.backend.ENGINES
except Exception:
return {'sfdp', 'circo', 'twopi', 'dot', 'neato', 'fdp', 'osage', 'patchwork'}
return {"sfdp", "circo", "twopi", "dot", "neato", "fdp", "osage", "patchwork"}


class Command(BaseCommand):
requires_system_checks = True

if not HAS_ARGPARSE:
option_list = BaseCommand.option_list + (
make_option('--output', '-o', action='store', dest='outputfile',
help=('Render output file. Type of output dependent on file extensions. '
'Use png or jpg to render graph to image.')),
make_option(
"--output",
"-o",
action="store",
dest="outputfile",
help=(
"Render output file. Type of output dependent on file extensions. " "Use png or jpg to render graph to image."
),
),
# NOQA
make_option('--layout', '-l', action='store', dest='layout', default='dot',
help=('Layout to be used by GraphViz for visualization. '
'Layouts: %s.' % ' '.join(get_graphviz_layouts()))),
make_option(
"--layout",
"-l",
action="store",
dest="layout",
default="dot",
help=("Layout to be used by GraphViz for visualization. " "Layouts: %s." % " ".join(get_graphviz_layouts())),
),
)
args = "[appname[.model[.field]]]"
else:

def add_arguments(self, parser):
parser.add_argument(
'--output', '-o', action='store', dest='outputfile',
help=('Render output file. Type of output dependent on file extensions. '
'Use png or jpg to render graph to image.'))
"--output",
"-o",
action="store",
dest="outputfile",
help=(
"Render output file. Type of output dependent on file extensions. " "Use png or jpg to render graph to image."
),
)
parser.add_argument(
'--layout', '-l', action='store', dest='layout', default='dot',
help=('Layout to be used by GraphViz for visualization. '
'Layouts: %s.' % ' '.join(get_graphviz_layouts())))
parser.add_argument('args', nargs='*',
help=('[appname[.model[.field]]]'))
"--layout",
"-l",
action="store",
dest="layout",
default="dot",
help=("Layout to be used by GraphViz for visualization. " "Layouts: %s." % " ".join(get_graphviz_layouts())),
)
parser.add_argument("args", nargs="*", help=("[appname[.model[.field]]]"))

help = ("Creates a GraphViz dot file with transitions for selected fields")
help = "Creates a GraphViz dot file with transitions for selected fields"

def render_output(self, graph, **options):
filename, format = options['outputfile'].rsplit('.', 1)
filename, format = options["outputfile"].rsplit(".", 1)

graph.engine = options['layout']
graph.engine = options["layout"]
graph.format = format
graph.render(filename)

def handle(self, *args, **options):
fields_data = []
if len(args) != 0:
for arg in args:
field_spec = arg.split('.')
field_spec = arg.split(".")

if len(field_spec) == 1:
if NEW_META_API:
Expand Down Expand Up @@ -204,7 +225,7 @@ def handle(self, *args, **options):

dotdata = generate_dot(fields_data)

if options['outputfile']:
if options["outputfile"]:
self.render_output(dotdata, **options)
else:
print(dotdata)
20 changes: 10 additions & 10 deletions django_fsm/tests/test_abstract_inheritance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@


class BaseAbstractModel(models.Model):
state = FSMField(default='new')
state = FSMField(default="new")

class Meta:
abstract = True

@transition(field=state, source='new', target='published')
@transition(field=state, source="new", target="published")
def publish(self):
pass


class InheritedFromAbstractModel(BaseAbstractModel):
@transition(field='state', source='published', target='sticked')
@transition(field="state", source="published", target="sticked")
def stick(self):
pass

Expand All @@ -28,20 +28,20 @@ def setUp(self):
def test_known_transition_should_succeed(self):
self.assertTrue(can_proceed(self.model.publish))
self.model.publish()
self.assertEqual(self.model.state, 'published')
self.assertEqual(self.model.state, "published")

self.assertTrue(can_proceed(self.model.stick))
self.model.stick()
self.assertEqual(self.model.state, 'sticked')
self.assertEqual(self.model.state, "sticked")

def test_field_available_transitions_works(self):
self.model.publish()
self.assertEqual(self.model.state, 'published')
self.assertEqual(self.model.state, "published")
transitions = self.model.get_available_state_transitions()
self.assertEqual(['sticked'], [data.target for data in transitions])
self.assertEqual(["sticked"], [data.target for data in transitions])

def test_field_all_transitions_works(self):
transitions = self.model.get_all_state_transitions()
self.assertEqual(set([('new', 'published'),
('published', 'sticked')]),
set((data.source, data.target) for data in transitions))
self.assertEqual(
set([("new", "published"), ("published", "sticked")]), set((data.source, data.target) for data in transitions)
)
Loading

0 comments on commit 8e40cf7

Please sign in to comment.