1 module radixsort;
2 
3 import std.functional;
4 import std.traits;
5 import std.algorithm.mutation : swap;
6 
7 
/**
 * Sorts an array in ascending key order using a byte-wise (LSD) radix sort.
 *
 * Radix sort is a very fast algorithm for sorting small values by an integral
 * or floating point key. Keys of 4 or 8 bytes are ideal. Bigger keys get
 * progressively slower because every element may be moved once per key byte.
 * Passes in which all keys share the same byte value are skipped.
 *
 * The sort is stable except for elements with negative floating point keys:
 * those are written back to front in the final pass, so elements with equal
 * negative keys end up in reverse input order.
 *
 * Params:
 *   keyFun  = unary function (or string lambda) extracting the sort key from
 *             an element; the key must be an integral type, `float` or `double`
 *   arr     = the array to sort
 *   _tmparr = scratch buffer, at least as long as `arr`; its contents are
 *             overwritten. It must not overlap `arr`, because elements are
 *             scattered from one buffer into the other.
 *
 * Returns: a slice over the memory of `arr` containing the sorted elements.
 * The result is guaranteed to be written in the array passed as the first
 * argument.
 */
T[] radixsort(alias keyFun = "a", T)(T[] arr, T[] _tmparr) {
  assert(_tmparr.length >= arr.length, "The temporary array must be at least as long as the source array.");
  // Histogram counters and offsets below are plain ints.
  assert(arr.length <= int.max, "radixsort supports at most int.max elements.");

  auto sourcePtr = arr.ptr;
  auto tmparr = _tmparr[0 .. arr.length];

  alias key = unaryFun!keyFun;
  alias keyType = ReturnType!((T t) => key(t));
  enum byteLen = keyType.sizeof;

  static assert (isIntegral!keyType || isFloatingPoint!keyType, "radixsort can sort only integral and floating point types.");
  // Only float and double have a bit-reinterpretation below; other floating
  // point types (e.g. real) would otherwise fail later with an unrelated error.
  static assert (!isFloatingPoint!keyType || is(keyType == float) || is(keyType == double),
      "radixsort supports only float and double floating point keys.");

  union FloatBits { float f; uint i; }
  union DoubleBits { double f; ulong i; }

  // keyAsInt yields the key as an integer whose bytes are the radix digits.
  static if (is(keyType == float)) {
    alias keyAsInt = (x) => FloatBits(key(x)).i;
  } else static if (is(keyType == double)) {
    alias keyAsInt = (x) => DoubleBits(key(x)).i;
  } else {
    alias keyAsInt = key;
  }

  int[256][byteLen] counts;
  int[256] offsets = void; // fully rewritten at the start of every pass

  // Build one 256-bucket histogram per key byte in a single sweep,
  // extracting the key only once per element.
  foreach (x; arr) {
    const k = keyAsInt(x);
    static foreach (b; 0 .. byteLen) {{
      uint c = (k >> (b * 8)) & 0xff;
      counts[b][c] += 1;
    }}
  }


  static if (isIntegral!keyType) {

    foreach (b; 0 .. byteLen) {
      if (canSkip(counts[b], arr.length)) continue;

      // For the most significant byte of a signed key the buckets 128..255
      // hold the negative values, so the bucket order starts at 128 and
      // wraps around to 0..127.
      const int shift = (isSigned!keyType && b == byteLen-1) ? 128 : 0;

      // Exclusive prefix sum of the histogram in bucket order.
      int running = 0;
      foreach (i; 0 .. 256) {
        const bucket = (i + shift) & 0xff;
        offsets[bucket] = running;
        running += counts[b][bucket];
      }

      foreach (x; arr) {
        uint c = (keyAsInt(x) >> (b * 8)) & 0xff;
        tmparr.ptr[offsets[c]] = x;
        offsets[c]++;
      }

      swap(arr, tmparr);
    }


  } else static if (isFloatingPoint!keyType) {

    // All passes but the last one are plain ascending passes over the bits.
    foreach (b; 0 .. byteLen-1) {
      if (canSkip(counts[b], arr.length)) continue;

      int running = 0;
      foreach (i; 0 .. 256) {
        offsets[i] = running;
        running += counts[b][i];
      }

      foreach (x; arr) {
        uint c = (keyAsInt(x) >> (b * 8)) & 0xff;
        tmparr.ptr[offsets[c]] = x;
        offsets[c]++;
      }

      swap(arr, tmparr);
    }

    // The last pass handles the byte carrying the sign bit.
    {
      enum top = byteLen - 1;

      // Negative keys are ordered by descending bit pattern, so the previous
      // passes left them in reverse value order. This pass must therefore run
      // whenever any negative key exists, even if all keys share one top byte.
      bool hasNegative = false;
      foreach (i; 128 .. 256) {
        if (counts[top][i] != 0) {
          hasNegative = true;
          break;
        }
      }

      if (hasNegative || !canSkip(counts[top], arr.length)) {
        int running = 0;

        // Negative buckets come first, from 255 (most negative) down to 128.
        // Each one is filled from its last slot backwards to reverse the
        // order produced by the earlier passes.
        foreach_reverse (i; 128 .. 256) {
          running += counts[top][i];
          offsets[i] = running - 1;
        }

        // Non-negative buckets follow in ascending order, filled forwards.
        foreach (i; 0 .. 128) {
          offsets[i] = running;
          running += counts[top][i];
        }

        foreach (x; arr) {
          uint c = (keyAsInt(x) >> (top * 8)) & 0xff;
          tmparr.ptr[offsets[c]] = x;
          offsets[c] += (c >= 128) ? -1 : 1;
        }

        swap(arr, tmparr);
      }
    }

  } else assert(0);

  // After an odd number of passes the sorted data lives in the scratch
  // memory (now referenced by arr) and tmparr refers to the caller's array,
  // so copy the result back and return the caller's memory.
  if (arr.ptr != sourcePtr) {
    tmparr[] = arr[];
    return tmparr;
  }

  return arr;
}

/**
 * Tells whether a radix pass can be skipped because all `len` elements fall
 * into a single bucket of the histogram `cnts`.
 *
 * Only the first non-zero bucket is inspected: if it holds exactly `len`
 * elements the pass would not reorder anything. For `len == 0` the very first
 * (zero) bucket already matches, so empty input skips every pass.
 */
private bool canSkip(ref int[256] cnts, ulong len) {
  foreach (c; cnts) {
    if (c == len) return true;
    if (c != 0) return false;
  }
  return false;
}
142 
143 
144 
unittest {
  import std.algorithm;
  import std.random;
  import std.range;
  import std.conv;
  import std.meta;

  auto rnd = Random(1337);

  // Random is a struct, so `rnd.take(n)` would operate on a copy and every
  // iteration would see a prefix of the same sequence. Drawing through
  // refRange advances the shared generator, giving each iteration new data.

  // Plain arrays of every supported key type, sorted by the element itself.
  static foreach (T; AliasSeq!(ulong, long, uint, int, ushort, short, ubyte, byte, float, double)) {{
    foreach (i; 0 .. 10) {
      T[] arr = refRange(&rnd).take(100+i).map!(x => cast(T)x).array;
      auto sorted = radixsort(arr, new T[arr.length]);
      assert(isSorted(sorted), sorted.to!string);
      // the result must live in the caller's array, not in the scratch buffer
      assert(sorted.ptr == arr.ptr);
    }
  }}

  // Floating point values on both sides of zero (range [-0.5, 0.5)).
  static foreach (T; AliasSeq!(float, double)) {{
    foreach (i; 0 .. 10) {
      T[] arr = iota(100+i).map!(x => rnd.uniform01!T - T(0.5)).array;
      auto sorted = radixsort(arr, new T[arr.length]);
      assert(isSorted(sorted), T.stringof ~ " " ~ sorted.to!string);
      assert(sorted.ptr == arr.ptr);
    }
  }}

  struct S2 { short key; short val; }
  struct D2 { double key; double val; }

  // Structs sorted through a key-extracting string lambda.
  static foreach (T; AliasSeq!(S2, D2)) {{
    foreach (i; 0 .. 10) {
      T[] arr = refRange(&rnd).take(100+i).map!(x => T(cast(typeof(T.key))x, 0)).array;
      auto sorted = radixsort!"a.key"(arr, new T[arr.length]);
      assert(isSorted!"a.key < b.key"(sorted), sorted.to!string);
      assert(sorted.ptr == arr.ptr);
    }
  }}

  // Signed 64-bit keys including the extremes; the random data above only
  // yields 32-bit magnitudes for long.
  {
    long[] arr = [5, -3, long.min, 0, long.max, -1];
    auto sorted = radixsort(arr, new long[arr.length]);
    assert(sorted == [long.min, -3, -1, 0, 5, long.max]);
    assert(sorted.ptr == arr.ptr);
  }

  // Empty input is returned unchanged.
  {
    int[] none;
    assert(radixsort(none, none).length == 0);
  }
}