diff --git a/src/Conjure/Language/RepresentationOf.hs b/src/Conjure/Language/RepresentationOf.hs index 01c5ea1f11..c42feae252 100644 --- a/src/Conjure/Language/RepresentationOf.hs +++ b/src/Conjure/Language/RepresentationOf.hs @@ -1,39 +1,42 @@ module Conjure.Language.RepresentationOf where -- conjure -import Conjure.Prelude -import Conjure.Language.Domain -import Conjure.Language.Type ( TypeCheckerMode ) +import Conjure.Language.Domain +import Conjure.Language.Type (TypeCheckerMode) +import Conjure.Prelude class RepresentationOf a where - representationTreeOf - :: (MonadFailDoc m, ?typeCheckerMode :: TypeCheckerMode) - => a -> m (Tree (Maybe HasRepresentation)) + representationTreeOf :: + (MonadFailDoc m, ?typeCheckerMode :: TypeCheckerMode) => + a -> + m (Tree (Maybe HasRepresentation)) representationOf :: (RepresentationOf a, MonadFailDoc m, ?typeCheckerMode :: TypeCheckerMode) => a -> m HasRepresentation -representationOf a = do - tree <- representationTreeOf a - case rootLabel tree of - Nothing -> failDoc "doesn't seem to have a representation" +representationOf a = + case representationTreeOf a of + Nothing -> failDoc "doesn't seem to have a representation tree" + Just tree -> + case rootLabel tree of + Nothing -> failDoc "doesn't seem to have a representation" Just NoRepresentation -> failDoc "doesn't seem to have a representation" Just r -> return r hasRepresentation :: (RepresentationOf a, MonadFailDoc m, ?typeCheckerMode :: TypeCheckerMode) => a -> m () hasRepresentation x = - case representationOf x of - Nothing -> failDoc "doesn't seem to have a representation" - Just _ -> return () + case representationTreeOf x of + Nothing -> failDoc "doesn't seem to have a representation" + Just _ -> return () sameRepresentation :: (RepresentationOf a, MonadFailDoc m, ?typeCheckerMode :: TypeCheckerMode) => a -> a -> m () sameRepresentation x y = - case (representationOf x, representationOf y) of - (Just rx, Just ry) | rx == ry -> return () - _ -> failDoc "doesn't seem to have the same representation" + case (representationOf x, representationOf y) of + (Just rx, Just ry) | rx == ry -> return () + _ -> failDoc "doesn't seem to have the same representation" sameRepresentationTree :: (RepresentationOf a, MonadFailDoc m, ?typeCheckerMode :: TypeCheckerMode) => a -> a -> m () sameRepresentationTree x y = do - xTree <- representationTreeOf x - yTree <- representationTreeOf y - unless (xTree == yTree) $ - failDoc "doesn't seem to have the same representation tree" + xTree <- representationTreeOf x + yTree <- representationTreeOf y + unless (xTree == yTree) + $ failDoc "doesn't seem to have the same representation tree" diff --git a/src/Conjure/Rules/Horizontal/Sequence.hs b/src/Conjure/Rules/Horizontal/Sequence.hs index 3175631510..b8fc6b1040 100644 --- a/src/Conjure/Rules/Horizontal/Sequence.hs +++ b/src/Conjure/Rules/Horizontal/Sequence.hs @@ -379,16 +379,10 @@ rule_Restrict_Comprehension = "sequence-restrict-comprehension" `namedRule` theR theRule _ = na "rule_Restrict_Comprehension" --- | image(f,x) can be nasty for non-total sequences. --- 1. if f is a total sequence, it can readily be replaced by a set expression. --- 2.1. if f isn't total, and if the return type is right, it will always end up as a generator for a comprehension. --- a vertical rule is needed for such cases. --- 2.2. if the return type is not "right", i.e. it is a bool or an int, i.e. sth we cannot quantify over, --- the vertical rule is harder. rule_Image_Bool :: Rule rule_Image_Bool = "sequence-image-bool" `namedRule` theRule where - theRule Reference{} = na "rule_Image_Int" + theRule Reference{} = na "rule_Image_Bool" theRule p = do let onChildren @@ -402,6 +396,9 @@ rule_Image_Bool = "sequence-image-bool" `namedRule` theRule where case match opRestrict func of Nothing -> return () Just{} -> na "rule_Image_Bool" -- do not use this rule for restricted sequences + case match opTransform func of + Nothing -> na "rule_Image_Bool" -- only use this rule for transformed sequences + Just{} -> return () TypeSequence TypeBool <- typeOf func return (func, arg) case try of @@ -448,6 +445,9 @@ rule_Image_Int = "sequence-image-int" `namedRule` theRule where case match opRestrict func of Nothing -> return () Just{} -> na "rule_Image_Int" -- do not use this rule for restricted sequences + case match opTransform func of + Nothing -> na "rule_Image_Int" -- only use this rule for transformed sequences + Just{} -> return () TypeSequence (TypeInt _) <- typeOf func return (func, arg) case try of diff --git a/tests/exhaustive/basic/sequence_subseq_dups/expected/model.eprime b/tests/exhaustive/basic/sequence_subseq_dups/expected/model.eprime index 7a0286b781..4f16192b1b 100644 --- a/tests/exhaustive/basic/sequence_subseq_dups/expected/model.eprime +++ b/tests/exhaustive/basic/sequence_subseq_dups/expected/model.eprime @@ -15,10 +15,11 @@ such that b_ExplicitBounded_Length = conjure_aux1_ExplicitBounded_Length, and([q6 <= b_ExplicitBounded_Length -> and([b_ExplicitBounded_Values[q6] = - sum([toInt(q8 = conjure_aux1_ExplicitBounded_Values[q6]) * catchUndef([1, 1, 2; int(1..3)][q8], 0) - | q8 : int(1..3)]), - or([q10 = conjure_aux1_ExplicitBounded_Values[q6] | q10 : int(1..3), q10 <= 3]), - q6 <= conjure_aux1_ExplicitBounded_Length; + sum([toInt(1 = conjure_aux1_ExplicitBounded_Values[q6]), + toInt(2 = conjure_aux1_ExplicitBounded_Values[q6]), + toInt(3 = conjure_aux1_ExplicitBounded_Values[q6]) * 2; + int(1..3)]), + conjure_aux1_ExplicitBounded_Values[q6] <= 3, q6 <= conjure_aux1_ExplicitBounded_Length; int(1..3)]) | q6 : int(1..2)]), and([q1 > b_ExplicitBounded_Length -> b_ExplicitBounded_Values[q1] = 1 | q1 : int(1..2)]), diff --git a/tests/exhaustive/basic/sequence_subseq_nodups/expected/model.eprime b/tests/exhaustive/basic/sequence_subseq_nodups/expected/model.eprime index 96f6872240..f2dda32fe8 100644 --- a/tests/exhaustive/basic/sequence_subseq_nodups/expected/model.eprime +++ b/tests/exhaustive/basic/sequence_subseq_nodups/expected/model.eprime @@ -15,10 +15,11 @@ such that b_ExplicitBounded_Length = conjure_aux1_ExplicitBounded_Length, and([q6 <= b_ExplicitBounded_Length -> and([b_ExplicitBounded_Values[q6] = - sum([toInt(q8 = conjure_aux1_ExplicitBounded_Values[q6]) * catchUndef([3, 1, 2; int(1..3)][q8], 0) - | q8 : int(1..3)]), - or([q10 = conjure_aux1_ExplicitBounded_Values[q6] | q10 : int(1..3), q10 <= 3]), - q6 <= conjure_aux1_ExplicitBounded_Length; + sum([toInt(1 = conjure_aux1_ExplicitBounded_Values[q6]) * 3, + toInt(2 = conjure_aux1_ExplicitBounded_Values[q6]), + toInt(3 = conjure_aux1_ExplicitBounded_Values[q6]) * 2; + int(1..3)]), + conjure_aux1_ExplicitBounded_Values[q6] <= 3, q6 <= conjure_aux1_ExplicitBounded_Length; int(1..3)]) | q6 : int(1..2)]), and([q1 > b_ExplicitBounded_Length -> b_ExplicitBounded_Values[q1] = 1 | q1 : int(1..2)]),