def tokuda_gaps(num_terms=float('inf'), upper_bound=float('inf')):
    """Generates the sequence of Tokuda gaps for Shellsort, in ascending order.

    The first 20 terms of the sequence are:
    [1, 4, 9, 20, 46, 103, 233, 525, 1182, 2660, 5985, 13467, 30301, 68178,
     153401, 345152, 776591, 1747331, 3931496, 8845866, ...]

    h_i = ceil( (9*(9/4)**i - 4)/5 ) for i>=0.

    The terms are computed with integer arithmetic only:
        h_i = ((9**(i+1) >> (i<<1)) - 4)//5 + 1   for i>=1.
    Why this matches the ceiling: (9*(9/4)**i - 4)/5 equals
    (9**(i+1) - 4**(i+1)) / (5 * 4**i). For i>=1 the numerator is odd and the
    denominator is even, so the value is never an integer and
    ceil(y) == floor(y) + 1. The shift computes floor(9**(i+1) / 4**i), and
    flooring before or after the integer division by 5 gives the same result.
    For i=0 the value is exactly 1 (an integer), so that term is yielded
    explicitly; the formula above would give 2.

    Parameters:
        num_terms:   maximum number of terms to yield (default: no limit).
        upper_bound: largest value that may be yielded, inclusive
                     (default: no limit).

    With both defaults the generator never terminates, so callers must supply
    at least one limit or stop consuming it themselves.
    """
    if num_terms >= 1 and upper_bound >= 1:
        yield 1  # first term (i = 0)

    i = 1     # index of the next term; i + 1 is the count once it is yielded
    term = 4  # second term (i = 1)

    while i + 1 <= num_terms and term <= upper_bound:
        yield term
        i += 1
        term = ((9**(i + 1) >> (i << 1)) - 4)//5 + 1


def ciura_gaps(num_terms=float('inf'), upper_bound=float('inf')):
    """Generates the Ciura gap sequence for Shellsort, in ascending order.

    The sequence is [1, 4, 10, 23, 57, 132, 301, 701, 1750, h_k, ...]

    The first nine terms are tabulated constants. No closed form is known
    for them, so beyond 1750 the sequence is extended with the commonly used
    rule h_k = floor(2.25 * h_{k-1}), giving 3937, 8858, 19930, ...
    NOTE(review): the extension rule is a convention, not part of the
    tabulated sequence -- confirm it is the intended definition of h_k.

    Parameters:
        num_terms:   maximum number of terms to yield (default: no limit).
        upper_bound: largest value that may be yielded, inclusive
                     (default: no limit).

    With both defaults the generator never terminates, so callers must supply
    at least one limit or stop consuming it themselves.
    """
    tabulated = (1, 4, 10, 23, 57, 132, 301, 701, 1750)

    count = 0  # number of terms yielded so far
    term = 0
    for term in tabulated:
        if count >= num_terms or term > upper_bound:
            return
        yield term
        count += 1

    # Past the table: floor(2.25 * term), kept in exact integer arithmetic.
    while count < num_terms:
        term = (9 * term) // 4
        if term > upper_bound:
            return
        yield term
        count += 1
- gaps = list(sedgewick_shell_gaps(upper_bound=len(a)-1))[::-1] + if gap_sequence == 'Sedgewick': + gaps = list(sedgewick_gaps(upper_bound=len(a)-1))[::-1] + elif gap_sequence == 'Tokuda': + gaps = list(tokuda_gaps(upper_bound=len(a)-1))[::-1] + else: + raise ValueError(f'Unknown gap sequence: {gap_sequence}') #Iterate through gaps h in descending order. for h in gaps: