Skip to content

Commit

Permalink
Don't use float equality for AST comparisons (#1238)
Browse files Browse the repository at this point in the history
## Summary
Float equality lacks the substitution and reflexivity properties usually
expected from an equality operator, so it's not correct to use float
equality in AST comparisons. This PR changes AST comparisons so that
float literals are compared for bit equality.

## Details
*   `0.0`  and  `-0.0`  are no longer considered equal by the compiler;
this affects:
- static parameters in generic types and procedures, see
`tests/statictypes/tstatictypes.nim`
- default arguments for forward declarations, see
`tests/errmsgs/tforwarddecl_defaultparam.nim`
- ```macros.`==`(a, b: NimNode)```, see
`tests/lang_callable/macros/tmacros_various.nim`
- term-rewriting macros, see
`tests/lang_experimental/trmacros/trmacros_various2.nim`

*  `trees.exprStructuralEquivalentStrictSym`  and its only usage in
`sem/semfold.nim`  have been removed

---------

Co-authored-by: zerbina <[email protected]>
  • Loading branch information
Clyybber and zerbina authored Feb 8, 2025
1 parent 6766abe commit f313cfc
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 52 deletions.
28 changes: 9 additions & 19 deletions compiler/ast/trees.nim
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,16 @@ proc cyclicTree*(n: PNode): bool =
var visited: seq[PNode] = @[]
cyclicTreeAux(n, visited)

proc sameFloatIgnoreNan(a, b: BiggestFloat): bool {.inline.} =
## ignores NaN semantics, but ensures 0.0 == -0.0, see #13730
cast[uint64](a) == cast[uint64](b) or a == b
template cmpFloatRep*(a, b: BiggestFloat): bool =
## Compares the bit-representation of floats `a` and `b`
# Special handling for floats, so that floats that have the same
# value but different bit representations are treated as different constants
# Compared to float equality, this does not lack the substitution and
# reflexivity property, which the compiler relies on for correctness.
cast[uint64](a) == cast[uint64](b)

template makeTreeEquivalenceProc*(
name, relaxedKindCheck, symCheck, floatCheck, typeCheck, commentCheck) {.dirty.} =
name, relaxedKindCheck, symCheck, typeCheck, commentCheck) {.dirty.} =
## Defines a tree equivalence checking procedure.
## This skeleton is shared between all recursive
## `PNode` equivalence checks in the compiler code base
Expand All @@ -61,10 +65,7 @@ template makeTreeEquivalenceProc*(
of nkSym: result = symCheck
of nkIdent: result = a.ident.id == b.ident.id
of nkIntLiterals: result = a.intVal == b.intVal
of nkFloatLiterals: result = floatCheck
# XXX: Using float equality, even if partially tamed through
# sameFloatIgnoreNan, causes inconsistencies due to it
# lacking the substition and reflexivity property.
of nkFloatLiterals: result = cmpFloatRep(a.floatVal, b.floatVal)
of nkStrLiterals: result = a.strVal == b.strVal
of nkType: result = typeCheck
of nkCommentStmt: result = commentCheck
Expand All @@ -78,25 +79,14 @@ template makeTreeEquivalenceProc*(
makeTreeEquivalenceProc(exprStructuralEquivalent,
relaxedKindCheck = false,
symCheck = a.sym.name.id == b.sym.name.id, # same symbol as string is enough
floatCheck = sameFloatIgnoreNan(a.floatVal, b.floatVal),
typeCheck = true,
commentCheck = true
)
export exprStructuralEquivalent

makeTreeEquivalenceProc(exprStructuralEquivalentStrictSym,
relaxedKindCheck = false,
symCheck = a.sym == b.sym,
floatCheck = sameFloatIgnoreNan(a.floatVal, b.floatVal),
typeCheck = true,
commentCheck = true
)
export exprStructuralEquivalentStrictSym

makeTreeEquivalenceProc(exprStructuralEquivalentStrictSymAndComm,
relaxedKindCheck = false,
symCheck = a.sym == b.sym,
floatCheck = sameFloatIgnoreNan(a.floatVal, b.floatVal),
typeCheck = a.typ == b.typ,
commentCheck = a.comment == b.comment
)
Expand Down
3 changes: 2 additions & 1 deletion compiler/sem/guards.nim
Original file line number Diff line number Diff line change
Expand Up @@ -454,10 +454,11 @@ proc sameOpr(a, b: PSym): bool =
else: result = a == b

makeTreeEquivalenceProc(sameTree,
# XXX: This completely ignores that expressions might
# not be pure/deterministic.
relaxedKindCheck = false,
symCheck = sameOpr(a.sym, b.sym) or
(a.sym.magic != mNone and a.sym.magic == b.sym.magic),
floatCheck = a.floatVal == b.floatVal,
typeCheck = a.typ == b.typ,
commentCheck = true # ignore comments
)
Expand Down
4 changes: 1 addition & 3 deletions compiler/sem/patterns.nim
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,9 @@ proc sameKinds(a, b: PNode): bool {.inline.} =
makeTreeEquivalenceProc(sameTrees,
relaxedKindCheck = sameKinds(a, b),
symCheck = a.sym == b.sym,
floatCheck = a.floatVal == b.floatVal,
typeCheck = sameTypeOrNil(a.typ, b.typ),
commentCheck = true # Ignore comments
)
export sameTrees

proc inSymChoice(sc, x: PNode): bool =
if sc.kind == nkClosedSymChoice:
Expand Down Expand Up @@ -177,7 +175,7 @@ proc matches(c: PPatternContext, p, n: PNode): bool =
of nkSym: result = p.sym == n.sym
of nkIdent: result = p.ident.id == n.ident.id
of nkIntLiterals: result = p.intVal == n.intVal
of nkFloatLiterals: result = p.floatVal == n.floatVal
of nkFloatLiterals: result = cmpFloatRep(p.floatVal, n.floatVal)
of nkStrLiterals: result = p.strVal == n.strVal
of nkEmpty, nkNilLit, nkType, nkCommentStmt:
result = true # Ignore comments
Expand Down
11 changes: 9 additions & 2 deletions compiler/sem/semfold.nim
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import
],
compiler/front/[
options,
msgs,
],
compiler/utils/[
platform,
Expand Down Expand Up @@ -378,8 +379,14 @@ proc evalOp*(m: TMagic, n, a, b, c: PNode; idgen: IdGenerator; g: ModuleGraph):
result = copyTree(a)
result.typ = n.typ
of mEqProc:
result = newIntNodeT(toInt128(ord(
exprStructuralEquivalentStrictSym(a, b))), n, idgen, g)
g.config.internalAssert(a.kind in {nkSym, nkNilLit} and
b.kind in {nkSym, nkNilLit},
n.info, "mEqProc: invalid AST")
let isEqual =
if a.kind != b.kind: false
elif a.kind == nkSym: a.sym == b.sym # b.kind == nkSym
else: true # a.kind == b.kind == nkNilLit
result = newIntNodeT(toInt128(ord(isEqual)), n, idgen, g)
else: discard

proc getConstIfExpr(c: PSym, n: PNode; idgen: IdGenerator; g: ModuleGraph): PNode =
Expand Down
23 changes: 1 addition & 22 deletions compiler/vm/vmgen.nim
Original file line number Diff line number Diff line change
Expand Up @@ -722,27 +722,6 @@ proc rawGenLiteral(c: var TCtx, val: sink VmConstant): int =
internalAssert c.config, result < regBxMax, "Too many constants used"


template cmpFloatRep(a, b: BiggestFloat): bool =
## Compares the bit-representation of floats `a` and `b`
# Special handling for floats, so that floats that have the same
# value but different bit representations are treated as different constants
cast[uint64](a) == cast[uint64](b)
# refs bug #16469
# if we wanted to only distinguish 0.0 vs -0.0:
# if a.floatVal == 0.0: result = cast[uint64](a.floatVal) == cast[uint64](b.floatVal)
# else: result = a.floatVal == b.floatVal

# Compares two trees for structural equality, also taking the type of
# ``nkType`` nodes into account. This procedure is used to prevent the same
# AST from being added as a node constant more than once
makeTreeEquivalenceProc(cmpNodeCnst,
relaxedKindCheck = false,
symCheck = a.sym == b.sym,
floatCheck = cmpFloatRep(a.floatVal, b.floatVal),
typeCheck = a.typ == b.typ,
commentCheck = a.comment == b.comment
)

template makeCnstFunc(name, vType, aKind, valName, cmp) {.dirty.} =
proc name(c: var TCtx, val: vType): int =
for (i, cnst) in c.constants.pairs():
Expand All @@ -752,7 +731,7 @@ template makeCnstFunc(name, vType, aKind, valName, cmp) {.dirty.} =
c.rawGenLiteral: VmConstant(kind: aKind, valName: val)


makeCnstFunc(toNodeCnst, PNode, cnstNode, node, cmpNodeCnst)
makeCnstFunc(toNodeCnst, PNode, cnstNode, node, exprStructuralEquivalentStrictSymAndComm)

makeCnstFunc(toIntCnst, BiggestInt, cnstInt, intVal, `==`)

Expand Down
9 changes: 9 additions & 0 deletions tests/errmsgs/tforwarddecl_defaultparam.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
discard """
errormsg: "overloaded 'reciprocal' leads to ambiguous calls"
line: 9
"""

# Differing float literal default args must prevent forward declaration
# and the compiler must not compare them via float equality
proc reciprocal(f: float = 0.0): float
proc reciprocal(f: float = -0.0): float = 1 / f
2 changes: 1 addition & 1 deletion tests/lang_callable/generics/tgenerics_issues.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1084,4 +1084,4 @@ block typed_macro_in_generic_object_when:
var o1 = Object[0]()
doAssert not compiles(o1.val)
var o2 = Object[1](val: 2)
doAssert o2.val == 2
doAssert o2.val == 2
13 changes: 11 additions & 2 deletions tests/lang_callable/macros/tmacros_various.nim
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ CommentStmt "comment 1"
CommentStmt "comment 2"
false
false
false
true
'''
output: '''
Expand Down Expand Up @@ -334,8 +336,8 @@ block: # bug #15118
flop("b")

block:
# Ensure nkCommentStmt equality is not ignored when vmgen.cmpNodeCnst
# is used to deduplicate NimNode constants, so that `CommentStmt "comment 2"`
# Ensure nkCommentStmt equality is not ignored when vmgen.toNodeCnst
# deduplicates NimNode constants, so that `CommentStmt "comment 2"`
# is not counted as a duplicate of `CommentStmt "comment 1"` and
# incorrectly optimized to point at the `Comment "comment 1"` node

Expand Down Expand Up @@ -386,3 +388,10 @@ block:
except E:
discard
)

block:
# Ensure float equality semantics are not used when comparing AST for equality

static:
echo newLit(0.0) == newLit(-0.0) # false
echo newLit(NaN) == newLit(NaN) # true
13 changes: 13 additions & 0 deletions tests/lang_experimental/trmacros/trmacros_various2.nim
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ lo
my awesome concat
1
TRM
10000000000.0
-10000000000.0
'''
"""

Expand Down Expand Up @@ -99,3 +101,14 @@ echo u * 3'u # 1
template dontAppendE{`&`(s, 'E')}(s: string): string = s
var s = "T"
echo s & 'E' & 'R' & 'M'

# Floats must not be matched with float equality semantics
template capDivPos0{`/`(f, 0.0)}(f: float): float =
10000000000.float

template capDivNeg0{`/`(f, -0.0)}(f: float): float =
-10000000000.float

let f = 1.0
echo f / 0.0 # 10000000000.0
echo f / -0.0 # -10000000000.0
14 changes: 12 additions & 2 deletions tests/statictypes/tstatictypes.nim
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@ when true:

block: # issue #13730
type Foo[T: static[float]] = object
doAssert Foo[0.0] is Foo[-0.0]
# It should not actually be considered the same type, as
# float equality does not have the substitution property.
# For example: 1 / 0.0 = Inf != -Inf = 1 / -0.0
# even though 0.0 == -0.0 according to float semantics
doAssert Foo[0.0] isnot Foo[-0.0]

when true:
type
Expand Down Expand Up @@ -411,4 +415,10 @@ block coercion_to_static_type:
result = 2.1

# the call must be fully evaluated at compile-time
doAssert static[int](get()) == 1
doAssert static[int](get()) == 1

proc reciprocal(f: static float): float =
1 / f

doAssert reciprocal(-0.0) == -Inf
doAssert reciprocal(0.0) == Inf

0 comments on commit f313cfc

Please sign in to comment.