Skip to content

Commit

Permalink
Binary search through an array
Browse files Browse the repository at this point in the history
  • Loading branch information
nikswamy committed Sep 6, 2024
1 parent 90ce81a commit 2fb5240
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 1 deletion.
23 changes: 22 additions & 1 deletion lib/pulse/lib/Pulse.Lib.BoundedIntegers.fst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ class bounded_int (t:eqtype) = {
op_Subtraction : (x:t -> y:t -> Pure t (requires fits (v x - v y)) (ensures fun z -> v z == v x - v y));
( < ) : (x:t -> y:t -> b:bool { b = (v x < v y)});
( <= ) : (x:t -> y:t -> b:bool { b = (v x <= v y)});
( % ) : (x:t -> y:t -> Pure t (requires v y > 0 /\ fits (v x % v y)) (ensures fun z -> v z == v x % v y));
( > ) : (x:t -> y:t -> b:bool { b = (v x > v y)});
( >= ) : (x:t -> y:t -> b:bool { b = (v x >= v y)});
( % ) : (x:t -> y:t -> Pure t (requires v y `Prims.(op_GreaterThan)` 0 /\ fits (v x % v y)) (ensures fun z -> v z == v x % v y));
( / ) : (x:t -> y:t -> Pure t (requires v y <> 0 /\ fits (v x / v y)) (ensures fun z -> v z == v x / v y));
[@@@TC.no_method]
properties: squash (
(forall (x:t). {:pattern v x} fits (v x))
Expand All @@ -47,7 +50,10 @@ instance bounded_int_int : bounded_int int = {
op_Subtraction = (fun x y -> Prims.op_Subtraction x y);
( < ) = (fun x y -> Prims.op_LessThan x y);
( <= ) = (fun x y -> Prims.op_LessThanOrEqual x y);
( > ) = (fun x y -> Prims.op_GreaterThan x y);
( >= ) = (fun x y -> Prims.op_GreaterThanOrEqual x y);
( % ) = (fun x y -> Prims.op_Modulus x y);
( / ) = (fun x y -> Prims.op_Division x y);
properties = ()
}

Expand Down Expand Up @@ -128,7 +134,10 @@ instance bounded_int_u32 : bounded_int FStar.UInt32.t = {
op_Subtraction = (fun x y -> FStar.UInt32.sub x y);
( < ) = FStar.UInt32.(fun x y -> x <^ y);
( <= ) = FStar.UInt32.(fun x y -> x <=^ y);
( > ) = FStar.UInt32.(fun x y -> x >^ y);
( >= ) = FStar.UInt32.(fun x y -> x >=^ y);
( % ) = FStar.UInt32.(fun x y -> x %^ y);
( / ) = FStar.UInt32.(fun x y -> x `div` y);
properties = ()
}

Expand All @@ -149,7 +158,10 @@ instance bounded_int_u64 : bounded_int FStar.UInt64.t = {
op_Subtraction = (fun x y -> FStar.UInt64.sub x y);
( < ) = FStar.UInt64.(fun x y -> x <^ y);
( <= ) = FStar.UInt64.(fun x y -> x <=^ y);
( > ) = FStar.UInt64.(fun x y -> x >^ y);
( >= ) = FStar.UInt64.(fun x y -> x >=^ y);
( % ) = FStar.UInt64.(fun x y -> x %^ y);
( / ) = FStar.UInt64.(fun x y -> x `div` y);
properties = ()
}

Expand Down Expand Up @@ -185,7 +197,10 @@ instance bounded_int_nat : bounded_int nat = {
op_Subtraction = (fun x y -> Prims.op_Subtraction x y); //can't write ( - ), it doesn't parse
( < ) = (fun x y -> Prims.op_LessThan x y);
( <= ) = (fun x y -> Prims.op_LessThanOrEqual x y);
( > ) = (fun x y -> Prims.op_GreaterThan x y);
( >= ) = (fun x y -> Prims.op_GreaterThanOrEqual x y);
( % ) = (fun x y -> Prims.op_Modulus x y);
( / ) = (fun x y -> Prims.op_Division x y);
properties = ()
}
//with an instance for nat this works
Expand All @@ -202,7 +217,10 @@ instance bounded_int_pos : bounded_int pos = {
op_Subtraction = (fun x y -> Prims.op_Subtraction x y); //can't write ( - ), it doesn't parse
( < ) = (fun x y -> Prims.op_LessThan x y);
( <= ) = (fun x y -> Prims.op_LessThanOrEqual x y);
( > ) = (fun x y -> Prims.op_GreaterThan x y);
( >= ) = (fun x y -> Prims.op_GreaterThanOrEqual x y);
( % ) = (fun x y -> Prims.op_Modulus x y);
( / ) = (fun x y -> Prims.op_Division x y);
properties = ()
}

Expand All @@ -218,7 +236,10 @@ instance bounded_int_size_t : bounded_int FStar.SizeT.t = {
op_Subtraction = (fun x y -> FStar.SizeT.sub x y);
( < ) = (fun x y -> FStar.SizeT.(x <^ y));
( <= ) = (fun x y -> FStar.SizeT.(x <=^ y));
( > ) = (fun x y -> FStar.SizeT.(x >^ y));
( >= ) = (fun x y -> FStar.SizeT.(x >=^ y));
( % ) = (fun x y -> FStar.SizeT.(x %^ y));
( / ) = (fun x y -> FStar.SizeT.(x `div` y));
properties = ();
}

Expand Down
113 changes: 113 additions & 0 deletions share/pulse/examples/PulseExample.BinarySearch.fst
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
module PulseExample.BinarySearch
#lang-pulse
open Pulse.Lib.Pervasives
module A = Pulse.Lib.Array
module R = Pulse.Lib.Reference
module SZ = FStar.SizeT
module Seq = FStar.Seq
open FStar.Order
open Pulse.Lib.BoundedIntegers

let flip_order (o:order) : order =
match o with
| Lt -> Gt
| Eq -> Eq
| Gt -> Lt

class total_order (a:Type) = {
compare: a -> a -> order;
properties : squash (
(forall (x y:a). {:pattern compare x y} eq (compare x y) <==> x == y) /\
(forall (x y z:a). {:pattern compare x y; compare y z} lt (compare x y) /\ lt (compare y z) ==> lt (compare x z)) /\
(forall (x y:a). {:pattern compare x y} compare x y == flip_order (compare y x))
)
}

let ( <? ) (#t:Type) {| o:total_order t |} (x:t) (y:t) : bool = lt (o.compare x y)
let ( <=? ) (#t:Type) {| o:total_order t |} (x:t) (y:t) : bool = le (o.compare x y)

fn binary_search
(#t:Type)
{| total_order t |}
(a:A.array t)
(key:t)
(len:SZ.t)
(#s:erased (Seq.seq t) { Seq.length s == SZ.v len })
requires
A.pts_to a s **
pure ((forall (i j: SZ.t).
i <= j /\
j < len ==>
Seq.index s (SZ.v i) <=? Seq.index s (SZ.v j)) /\
(exists (k:SZ.t).
k < len /\
Seq.index s (SZ.v k) == key))
returns k:SZ.t
ensures
A.pts_to a s **
pure (k < len /\ eq (compare (Seq.index s (SZ.v k)) key))
{
let mut i1 : SZ.t = 0sz;
let mut i2 : SZ.t = len - 1sz;
while (
let v1 = !i1;
let v2 = !i2;
(v1 <> v2)
)
invariant b . (
exists* v1 v2.
pts_to i1 v1 **
pts_to i2 v2 **
A.pts_to a s **
pure (
(b == (v1 <> v2)) /\
v2 < len /\
(exists (i:SZ.t). {:pattern (Seq.index s (SZ.v i))} v1 <= i /\ i <= v2 /\ Seq.index s (SZ.v i) == key) /\
(forall (i j: SZ.t). {:pattern Seq.index s (SZ.v i); Seq.index s (SZ.v j)}
i <= j /\
j < len ==>
Seq.index s (SZ.v i) <=? Seq.index s (SZ.v j)))
)
{
let v1 = !i1;
let v2 = !i2;
let ix = v1 + (v2 - v1) / 2sz;
let a_ix = a.(ix);
if (a_ix <? key)
{
i1 := ix + 1sz;
}
else
{
i2 := ix;
}
};
!i1
}

instance total_order_int : total_order int = {
compare = compare_int;
properties = ()
}

fn binary_search_int
(a:A.array int)
(key:int)
(len:SZ.t)
(#s:erased (Seq.seq int) { Seq.length s == SZ.v len })
requires
A.pts_to a s **
pure ((forall (i j: SZ.t).
i <= j /\
j < len ==>
Seq.index s (SZ.v i) <= Seq.index s (SZ.v j)) /\
(exists (k:SZ.t).
k < len /\
Seq.index s (SZ.v k) == key))
returns k:SZ.t
ensures
A.pts_to a s **
pure (k < len /\ Seq.index s (SZ.v k) == key)
{
binary_search a key len
}

0 comments on commit 2fb5240

Please sign in to comment.