OSDN Git Service

Licensing changes to GPLv3 resp. GPLv3 with GCC Runtime Exception.
[pf3gnuchains/gcc-fork.git] / libstdc++-v3 / include / parallel / balanced_quicksort.h
1 // -*- C++ -*-
2
3 // Copyright (C) 2007, 2008, 2009 Free Software Foundation, Inc.
4 //
5 // This file is part of the GNU ISO C++ Library.  This library is free
6 // software; you can redistribute it and/or modify it under the terms
7 // of the GNU General Public License as published by the Free Software
8 // Foundation; either version 3, or (at your option) any later
9 // version.
10
11 // This library is distributed in the hope that it will be useful, but
12 // WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 // General Public License for more details.
15
16 // Under Section 7 of GPL version 3, you are granted additional
17 // permissions described in the GCC Runtime Library Exception, version
18 // 3.1, as published by the Free Software Foundation.
19
20 // You should have received a copy of the GNU General Public License and
21 // a copy of the GCC Runtime Library Exception along with this program;
22 // see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
23 // <http://www.gnu.org/licenses/>.
24
25 /** @file parallel/balanced_quicksort.h
26  *  @brief Implementation of a dynamically load-balanced parallel quicksort.
27  *
28  *  It works in-place and needs only logarithmic extra memory.
29  *  The algorithm is similar to the one proposed in
30  *
31  *  P. Tsigas and Y. Zhang.
32  *  A simple, fast parallel implementation of quicksort and
33  *  its performance evaluation on SUN enterprise 10000.
34  *  In 11th Euromicro Conference on Parallel, Distributed and
35  *  Network-Based Processing, page 372, 2003.
36  *
37  *  This file is a GNU parallel extension to the Standard C++ Library.
38  */
39
40 // Written by Johannes Singler.
41
42 #ifndef _GLIBCXX_PARALLEL_BALANCED_QUICKSORT_H
43 #define _GLIBCXX_PARALLEL_BALANCED_QUICKSORT_H 1
44
45 #include <parallel/basic_iterator.h>
46 #include <bits/stl_algo.h>
47
48 #include <parallel/settings.h>
49 #include <parallel/partition.h>
50 #include <parallel/random_number.h>
51 #include <parallel/queue.h>
52 #include <functional>
53
54 #if _GLIBCXX_ASSERTIONS
55 #include <parallel/checkers.h>
56 #endif
57
58 namespace __gnu_parallel
59 {
60 /** @brief Information local to one thread in the parallel quicksort run. */
61 template<typename RandomAccessIterator>
62   struct QSBThreadLocal
63   {
64     typedef std::iterator_traits<RandomAccessIterator> traits_type;
65     typedef typename traits_type::difference_type difference_type;
66
67     /** @brief Continuous part of the sequence, described by an
68     iterator pair. */
69     typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
70
71     /** @brief Initial piece to work on. */
72     Piece initial;
73
74     /** @brief Work-stealing queue. */
75     RestrictedBoundedConcurrentQueue<Piece> leftover_parts;
76
77     /** @brief Number of threads involved in this algorithm. */
78     thread_index_t num_threads;
79
80     /** @brief Pointer to a counter of elements left over to sort. */
81     volatile difference_type* elements_leftover;
82
83     /** @brief The complete sequence to sort. */
84     Piece global;
85
86     /** @brief Constructor.
87      *  @param queue_size Size of the work-stealing queue. */
88     QSBThreadLocal(int queue_size) : leftover_parts(queue_size) { }
89   };
90
91 /** @brief Balanced quicksort divide step.
92   *  @param begin Begin iterator of subsequence.
93   *  @param end End iterator of subsequence.
94   *  @param comp Comparator.
95   *  @param num_threads Number of threads that are allowed to work on
96   *  this part.
97   *  @pre @c (end-begin)>=1 */
98 template<typename RandomAccessIterator, typename Comparator>
99   typename std::iterator_traits<RandomAccessIterator>::difference_type
100   qsb_divide(RandomAccessIterator begin, RandomAccessIterator end,
101              Comparator comp, thread_index_t num_threads)
102   {
103     _GLIBCXX_PARALLEL_ASSERT(num_threads > 0);
104
105     typedef std::iterator_traits<RandomAccessIterator> traits_type;
106     typedef typename traits_type::value_type value_type;
107     typedef typename traits_type::difference_type difference_type;
108
109     RandomAccessIterator pivot_pos =
110       median_of_three_iterators(begin, begin + (end - begin) / 2,
111                                 end  - 1, comp);
112
113 #if defined(_GLIBCXX_ASSERTIONS)
114     // Must be in between somewhere.
115     difference_type n = end - begin;
116
117     _GLIBCXX_PARALLEL_ASSERT(
118            (!comp(*pivot_pos, *begin) && !comp(*(begin + n / 2), *pivot_pos))
119         || (!comp(*pivot_pos, *begin) && !comp(*(end - 1), *pivot_pos))
120         || (!comp(*pivot_pos, *(begin + n / 2)) && !comp(*begin, *pivot_pos))
121         || (!comp(*pivot_pos, *(begin + n / 2)) && !comp(*(end - 1), *pivot_pos))
122         || (!comp(*pivot_pos, *(end - 1)) && !comp(*begin, *pivot_pos))
123         || (!comp(*pivot_pos, *(end - 1)) && !comp(*(begin + n / 2), *pivot_pos)));
124 #endif
125
126     // Swap pivot value to end.
127     if (pivot_pos != (end - 1))
128       std::swap(*pivot_pos, *(end - 1));
129     pivot_pos = end - 1;
130
131     __gnu_parallel::binder2nd<Comparator, value_type, value_type, bool>
132         pred(comp, *pivot_pos);
133
134     // Divide, returning end - begin - 1 in the worst case.
135     difference_type split_pos = parallel_partition(
136         begin, end - 1, pred, num_threads);
137
138     // Swap back pivot to middle.
139     std::swap(*(begin + split_pos), *pivot_pos);
140     pivot_pos = begin + split_pos;
141
142 #if _GLIBCXX_ASSERTIONS
143     RandomAccessIterator r;
144     for (r = begin; r != pivot_pos; ++r)
145       _GLIBCXX_PARALLEL_ASSERT(comp(*r, *pivot_pos));
146     for (; r != end; ++r)
147       _GLIBCXX_PARALLEL_ASSERT(!comp(*r, *pivot_pos));
148 #endif
149
150     return split_pos;
151   }
152
153 /** @brief Quicksort conquer step.
154   *  @param tls Array of thread-local storages.
155   *  @param begin Begin iterator of subsequence.
156   *  @param end End iterator of subsequence.
157   *  @param comp Comparator.
158   *  @param iam Number of the thread processing this function.
159   *  @param num_threads
160   *          Number of threads that are allowed to work on this part. */
161 template<typename RandomAccessIterator, typename Comparator>
162   void
163   qsb_conquer(QSBThreadLocal<RandomAccessIterator>** tls,
164               RandomAccessIterator begin, RandomAccessIterator end,
165               Comparator comp,
166               thread_index_t iam, thread_index_t num_threads,
167               bool parent_wait)
168   {
169     typedef std::iterator_traits<RandomAccessIterator> traits_type;
170     typedef typename traits_type::value_type value_type;
171     typedef typename traits_type::difference_type difference_type;
172
173     difference_type n = end - begin;
174
175     if (num_threads <= 1 || n <= 1)
176       {
177         tls[iam]->initial.first  = begin;
178         tls[iam]->initial.second = end;
179
180         qsb_local_sort_with_helping(tls, comp, iam, parent_wait);
181
182         return;
183       }
184
185     // Divide step.
186     difference_type split_pos = qsb_divide(begin, end, comp, num_threads);
187
188 #if _GLIBCXX_ASSERTIONS
189     _GLIBCXX_PARALLEL_ASSERT(0 <= split_pos && split_pos < (end - begin));
190 #endif
191
192     thread_index_t num_threads_leftside =
193         std::max<thread_index_t>(1, std::min<thread_index_t>(
194                           num_threads - 1, split_pos * num_threads / n));
195
196 #   pragma omp atomic
197     *tls[iam]->elements_leftover -= (difference_type)1;
198
199     // Conquer step.
200 #   pragma omp parallel num_threads(2)
201     {
202       bool wait;
203       if(omp_get_num_threads() < 2)
204         wait = false;
205       else
206         wait = parent_wait;
207
208 #     pragma omp sections
209         {
210 #         pragma omp section
211             {
212               qsb_conquer(tls, begin, begin + split_pos, comp,
213                           iam,
214                           num_threads_leftside,
215                           wait);
216               wait = parent_wait;
217             }
218           // The pivot_pos is left in place, to ensure termination.
219 #         pragma omp section
220             {
221               qsb_conquer(tls, begin + split_pos + 1, end, comp,
222                           iam + num_threads_leftside,
223                           num_threads - num_threads_leftside,
224                           wait);
225               wait = parent_wait;
226             }
227         }
228     }
229   }
230
231 /**
232   *  @brief Quicksort step doing load-balanced local sort.
233   *  @param tls Array of thread-local storages.
234   *  @param comp Comparator.
235   *  @param iam Number of the thread processing this function.
236   */
237 template<typename RandomAccessIterator, typename Comparator>
238   void
239   qsb_local_sort_with_helping(QSBThreadLocal<RandomAccessIterator>** tls,
240                               Comparator& comp, int iam, bool wait)
241   {
242     typedef std::iterator_traits<RandomAccessIterator> traits_type;
243     typedef typename traits_type::value_type value_type;
244     typedef typename traits_type::difference_type difference_type;
245     typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
246
247     QSBThreadLocal<RandomAccessIterator>& tl = *tls[iam];
248
249     difference_type base_case_n =
250         _Settings::get().sort_qsb_base_case_maximal_n;
251     if (base_case_n < 2)
252       base_case_n = 2;
253     thread_index_t num_threads = tl.num_threads;
254
255     // Every thread has its own random number generator.
256     random_number rng(iam + 1);
257
258     Piece current = tl.initial;
259
260     difference_type elements_done = 0;
261 #if _GLIBCXX_ASSERTIONS
262     difference_type total_elements_done = 0;
263 #endif
264
265     for (;;)
266       {
267         // Invariant: current must be a valid (maybe empty) range.
268         RandomAccessIterator begin = current.first, end = current.second;
269         difference_type n = end - begin;
270
271         if (n > base_case_n)
272           {
273             // Divide.
274             RandomAccessIterator pivot_pos = begin +  rng(n);
275
276             // Swap pivot_pos value to end.
277             if (pivot_pos != (end - 1))
278               std::swap(*pivot_pos, *(end - 1));
279             pivot_pos = end - 1;
280
281             __gnu_parallel::binder2nd
282                 <Comparator, value_type, value_type, bool>
283                 pred(comp, *pivot_pos);
284
285             // Divide, leave pivot unchanged in last place.
286             RandomAccessIterator split_pos1, split_pos2;
287             split_pos1 = __gnu_sequential::partition(begin, end - 1, pred);
288
289             // Left side: < pivot_pos; right side: >= pivot_pos.
290 #if _GLIBCXX_ASSERTIONS
291             _GLIBCXX_PARALLEL_ASSERT(begin <= split_pos1 && split_pos1 < end);
292 #endif
293             // Swap pivot back to middle.
294             if (split_pos1 != pivot_pos)
295               std::swap(*split_pos1, *pivot_pos);
296             pivot_pos = split_pos1;
297
298             // In case all elements are equal, split_pos1 == 0.
299             if ((split_pos1 + 1 - begin) < (n >> 7)
300             || (end - split_pos1) < (n >> 7))
301               {
302                 // Very unequal split, one part smaller than one 128th
303                 // elements not strictly larger than the pivot.
304                 __gnu_parallel::unary_negate<__gnu_parallel::binder1st
305                   <Comparator, value_type, value_type, bool>, value_type>
306                   pred(__gnu_parallel::binder1st
307                        <Comparator, value_type, value_type, bool>(comp,
308                                                                   *pivot_pos));
309
310                 // Find other end of pivot-equal range.
311                 split_pos2 = __gnu_sequential::partition(split_pos1 + 1,
312                                                          end, pred);
313               }
314             else
315               // Only skip the pivot.
316               split_pos2 = split_pos1 + 1;
317
318             // Elements equal to pivot are done.
319             elements_done += (split_pos2 - split_pos1);
320 #if _GLIBCXX_ASSERTIONS
321             total_elements_done += (split_pos2 - split_pos1);
322 #endif
323             // Always push larger part onto stack.
324             if (((split_pos1 + 1) - begin) < (end - (split_pos2)))
325               {
326                 // Right side larger.
327                 if ((split_pos2) != end)
328                   tl.leftover_parts.push_front(std::make_pair(split_pos2,
329                                                               end));
330
331                 //current.first = begin;        //already set anyway
332                 current.second = split_pos1;
333                 continue;
334               }
335             else
336               {
337                 // Left side larger.
338                 if (begin != split_pos1)
339                   tl.leftover_parts.push_front(std::make_pair(begin,
340                                                               split_pos1));
341
342                 current.first = split_pos2;
343                 //current.second = end; //already set anyway
344                 continue;
345               }
346           }
347         else
348           {
349             __gnu_sequential::sort(begin, end, comp);
350             elements_done += n;
351 #if _GLIBCXX_ASSERTIONS
352             total_elements_done += n;
353 #endif
354
355             // Prefer own stack, small pieces.
356             if (tl.leftover_parts.pop_front(current))
357               continue;
358
359 #           pragma omp atomic
360             *tl.elements_leftover -= elements_done;
361
362             elements_done = 0;
363
364 #if _GLIBCXX_ASSERTIONS
365             double search_start = omp_get_wtime();
366 #endif
367
368             // Look for new work.
369             bool successfully_stolen = false;
370             while (wait && *tl.elements_leftover > 0 && !successfully_stolen
371 #if _GLIBCXX_ASSERTIONS
372               // Possible dead-lock.
373               && (omp_get_wtime() < (search_start + 1.0))
374 #endif
375               )
376               {
377                 thread_index_t victim;
378                 victim = rng(num_threads);
379
380                 // Large pieces.
381                 successfully_stolen = (victim != iam)
382                     && tls[victim]->leftover_parts.pop_back(current);
383                 if (!successfully_stolen)
384                   yield();
385 #if !defined(__ICC) && !defined(__ECC)
386 #               pragma omp flush
387 #endif
388               }
389
390 #if _GLIBCXX_ASSERTIONS
391             if (omp_get_wtime() >= (search_start + 1.0))
392               {
393                 sleep(1);
394                 _GLIBCXX_PARALLEL_ASSERT(omp_get_wtime()
395                                          < (search_start + 1.0));
396               }
397 #endif
398             if (!successfully_stolen)
399               {
400 #if _GLIBCXX_ASSERTIONS
401                 _GLIBCXX_PARALLEL_ASSERT(*tl.elements_leftover == 0);
402 #endif
403                 return;
404               }
405           }
406       }
407   }
408
409 /** @brief Top-level quicksort routine.
410   *  @param begin Begin iterator of sequence.
411   *  @param end End iterator of sequence.
412   *  @param comp Comparator.
413   *  @param num_threads Number of threads that are allowed to work on
414   *  this part.
415   */
416 template<typename RandomAccessIterator, typename Comparator>
417   void
418   parallel_sort_qsb(RandomAccessIterator begin, RandomAccessIterator end,
419                     Comparator comp,
420                     thread_index_t num_threads)
421   {
422     _GLIBCXX_CALL(end - begin)
423
424     typedef std::iterator_traits<RandomAccessIterator> traits_type;
425     typedef typename traits_type::value_type value_type;
426     typedef typename traits_type::difference_type difference_type;
427     typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
428
429     typedef QSBThreadLocal<RandomAccessIterator> tls_type;
430
431     difference_type n = end - begin;
432
433     if (n <= 1)
434       return;
435
436     // At least one element per processor.
437     if (num_threads > n)
438       num_threads = static_cast<thread_index_t>(n);
439
440     // Initialize thread local storage
441     tls_type** tls = new tls_type*[num_threads];
442     difference_type queue_size = num_threads * (thread_index_t)(log2(n) + 1);
443     for (thread_index_t t = 0; t < num_threads; ++t)
444       tls[t] = new QSBThreadLocal<RandomAccessIterator>(queue_size);
445
446     // There can never be more than ceil(log2(n)) ranges on the stack, because
447     // 1. Only one processor pushes onto the stack
448     // 2. The largest range has at most length n
449     // 3. Each range is larger than half of the range remaining
450     volatile difference_type elements_leftover = n;
451     for (int i = 0; i < num_threads; ++i)
452       {
453         tls[i]->elements_leftover = &elements_leftover;
454         tls[i]->num_threads = num_threads;
455         tls[i]->global = std::make_pair(begin, end);
456
457         // Just in case nothing is left to assign.
458         tls[i]->initial = std::make_pair(end, end);
459       }
460
461     // Main recursion call.
462     qsb_conquer(tls, begin, begin + n, comp, 0, num_threads, true);
463
464 #if _GLIBCXX_ASSERTIONS
465     // All stack must be empty.
466     Piece dummy;
467     for (int i = 1; i < num_threads; ++i)
468       _GLIBCXX_PARALLEL_ASSERT(!tls[i]->leftover_parts.pop_back(dummy));
469 #endif
470
471     for (int i = 0; i < num_threads; ++i)
472       delete tls[i];
473     delete[] tls;
474   }
475 } // namespace __gnu_parallel
476
477 #endif /* _GLIBCXX_PARALLEL_BALANCED_QUICKSORT_H */