OSDN Git Service

2009-09-23 Johannes Singler <singler@ira.uka.de>
[pf3gnuchains/gcc-fork.git] / libstdc++-v3 / include / parallel / multiway_mergesort.h
1 // -*- C++ -*-
2
3 // Copyright (C) 2007, 2008, 2009 Free Software Foundation, Inc.
4 //
5 // This file is part of the GNU ISO C++ Library.  This library is free
6 // software; you can redistribute it and/or modify it under the terms
7 // of the GNU General Public License as published by the Free Software
8 // Foundation; either version 3, or (at your option) any later
9 // version.
10
11 // This library is distributed in the hope that it will be useful, but
12 // WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 // General Public License for more details.
15
16 // Under Section 7 of GPL version 3, you are granted additional
17 // permissions described in the GCC Runtime Library Exception, version
18 // 3.1, as published by the Free Software Foundation.
19
20 // You should have received a copy of the GNU General Public License and
21 // a copy of the GCC Runtime Library Exception along with this program;
22 // see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
23 // <http://www.gnu.org/licenses/>.
24
25 /** @file parallel/multiway_mergesort.h
26  *  @brief Parallel multiway merge sort.
27  *  This file is a GNU parallel extension to the Standard C++ Library.
28  */
29
30 // Written by Johannes Singler.
31
32 #ifndef _GLIBCXX_PARALLEL_MULTIWAY_MERGESORT_H
33 #define _GLIBCXX_PARALLEL_MULTIWAY_MERGESORT_H 1
34
35 #include <vector>
36
37 #include <parallel/basic_iterator.h>
38 #include <bits/stl_algo.h>
39 #include <parallel/parallel.h>
40 #include <parallel/multiway_merge.h>
41
42 namespace __gnu_parallel
43 {
44
/** @brief Describes one subsequence by a pair of indices.
 *
 *  Code in this file computes the length of a piece as
 *  @c _M_end - @c _M_begin. */
template<typename _DifferenceTp>
  struct _Piece
  {
    typedef _DifferenceTp _DifferenceType;

    /// Index at which the subsequence begins.
    _DifferenceType _M_begin;

    /// Index at which the subsequence ends.
    _DifferenceType _M_end;
  };
57
58 /** @brief Data accessed by all threads.
59   *
60   *  PMWMS = parallel multiway mergesort */
61 template<typename _RAIter>
62   struct _PMWMSSortingData
63   {
64     typedef std::iterator_traits<_RAIter> _TraitsType;
65     typedef typename _TraitsType::value_type _ValueType;
66     typedef typename _TraitsType::difference_type _DifferenceType;
67
68     /** @brief Number of threads involved. */
69     _ThreadIndex _M_num_threads;
70
71     /** @brief Input __begin. */
72     _RAIter _M_source;
73
74     /** @brief Start indices, per thread. */
75     _DifferenceType* _M_starts;
76
77     /** @brief Storage in which to sort. */
78     _ValueType** _M_temporary;
79
80     /** @brief Samples. */
81     _ValueType* _M_samples;
82
83     /** @brief Offsets to add to the found positions. */
84     _DifferenceType* _M_offsets;
85
86     /** @brief Pieces of data to merge @__c [thread][__sequence] */
87     std::vector<_Piece<_DifferenceType> >* _M_pieces;
88 };
89
90 /**
91   *  @brief Select _M_samples from a sequence.
92   *  @param __sd Pointer to algorithm data. _Result will be placed in
93   *  @__c __sd->_M_samples.
94   *  @param __num_samples Number of _M_samples to select.
95   */
96 template<typename _RAIter, typename _DifferenceTp>
97   void 
98   __determine_samples(_PMWMSSortingData<_RAIter>* __sd,
99                     _DifferenceTp __num_samples)
100   {
101     typedef std::iterator_traits<_RAIter> _TraitsType;
102     typedef typename _TraitsType::value_type _ValueType;
103     typedef _DifferenceTp _DifferenceType;
104
105     _ThreadIndex __iam = omp_get_thread_num();
106
107     _DifferenceType* __es = new _DifferenceType[__num_samples + 2];
108
109     equally_split(__sd->_M_starts[__iam + 1] - __sd->_M_starts[__iam], 
110                   __num_samples + 1, __es);
111
112     for (_DifferenceType __i = 0; __i < __num_samples; ++__i)
113       ::new(&(__sd->_M_samples[__iam * __num_samples + __i]))
114           _ValueType(__sd->_M_source[__sd->_M_starts[__iam] + __es[__i + 1]]);
115
116     delete[] __es;
117   }
118
/** @brief Split consistently.
 *
 *  Primary template, deliberately empty: the splitting strategies are
 *  provided by specializations on @c __exact. */
template<bool __exact, typename _RAIter,
          typename _Compare, typename _SortingPlacesIterator>
  struct _SplitConsistently
  { };
125
126 /** @brief Split by exact splitting. */
127 template<typename _RAIter, typename _Compare,
128           typename _SortingPlacesIterator>
129   struct _SplitConsistently
130     <true, _RAIter, _Compare, _SortingPlacesIterator>
131   {
132     void operator()(
133       const _ThreadIndex __iam,
134       _PMWMSSortingData<_RAIter>* __sd,
135       _Compare& __comp,
136       const typename
137         std::iterator_traits<_RAIter>::difference_type
138           __num_samples)
139       const
140   {
141 #   pragma omp barrier
142
143     std::vector<std::pair<_SortingPlacesIterator, _SortingPlacesIterator> >
144         seqs(__sd->_M_num_threads);
145     for (_ThreadIndex __s = 0; __s < __sd->_M_num_threads; __s++)
146       seqs[__s] = std::make_pair(__sd->_M_temporary[__s],
147                                  __sd->_M_temporary[__s]
148                                  + (__sd->_M_starts[__s + 1]
149                                  - __sd->_M_starts[__s]));
150
151     std::vector<_SortingPlacesIterator> _M_offsets(__sd->_M_num_threads);
152
153     // if not last thread
154     if (__iam < __sd->_M_num_threads - 1)
155       multiseq_partition(seqs.begin(), seqs.end(),
156                          __sd->_M_starts[__iam + 1], _M_offsets.begin(),
157                          __comp);
158
159     for (int __seq = 0; __seq < __sd->_M_num_threads; __seq++)
160       {
161         // for each sequence
162         if (__iam < (__sd->_M_num_threads - 1))
163           __sd->_M_pieces[__iam][__seq]._M_end
164             = _M_offsets[__seq] - seqs[__seq].first;
165         else
166           // very end of this sequence
167           __sd->_M_pieces[__iam][__seq]._M_end =
168             __sd->_M_starts[__seq + 1] - __sd->_M_starts[__seq];
169       }
170
171 #   pragma omp barrier
172
173     for (_ThreadIndex __seq = 0; __seq < __sd->_M_num_threads; __seq++)
174       {
175         // For each sequence.
176         if (__iam > 0)
177           __sd->_M_pieces[__iam][__seq]._M_begin =
178             __sd->_M_pieces[__iam - 1][__seq]._M_end;
179         else
180           // Absolute beginning.
181           __sd->_M_pieces[__iam][__seq]._M_begin = 0;
182       }
183   }   
184   };
185
186 /** @brief Split by sampling. */ 
187 template<typename _RAIter, typename _Compare,
188           typename _SortingPlacesIterator>
189   struct _SplitConsistently<false, _RAIter, _Compare,
190                              _SortingPlacesIterator>
191   {
192     void operator()(
193         const _ThreadIndex __iam,
194         _PMWMSSortingData<_RAIter>* __sd,
195         _Compare& __comp,
196         const typename
197           std::iterator_traits<_RAIter>::difference_type
198             __num_samples)
199         const
200     {
201       typedef std::iterator_traits<_RAIter> _TraitsType;
202       typedef typename _TraitsType::value_type _ValueType;
203       typedef typename _TraitsType::difference_type _DifferenceType;
204
205       __determine_samples(__sd, __num_samples);
206
207 #     pragma omp barrier
208
209 #     pragma omp single
210       __gnu_sequential::sort(__sd->_M_samples,
211                              __sd->_M_samples
212                                 + (__num_samples * __sd->_M_num_threads),
213                              __comp);
214
215 #     pragma omp barrier
216
217       for (_ThreadIndex __s = 0; __s < __sd->_M_num_threads; ++__s)
218         {
219           // For each sequence.
220           if (__num_samples * __iam > 0)
221             __sd->_M_pieces[__iam][__s]._M_begin =
222                 std::lower_bound(__sd->_M_temporary[__s],
223                     __sd->_M_temporary[__s]
224                         + (__sd->_M_starts[__s + 1] - __sd->_M_starts[__s]),
225                     __sd->_M_samples[__num_samples * __iam],
226                     __comp)
227                 - __sd->_M_temporary[__s];
228           else
229             // Absolute beginning.
230             __sd->_M_pieces[__iam][__s]._M_begin = 0;
231
232           if ((__num_samples * (__iam + 1)) <
233                          (__num_samples * __sd->_M_num_threads))
234             __sd->_M_pieces[__iam][__s]._M_end =
235                 std::lower_bound(__sd->_M_temporary[__s],
236                         __sd->_M_temporary[__s]
237                           + (__sd->_M_starts[__s + 1] - __sd->_M_starts[__s]),
238                         __sd->_M_samples[__num_samples * (__iam + 1)],
239                         __comp)
240                 - __sd->_M_temporary[__s];
241           else
242             // Absolute end.
243             __sd->_M_pieces[__iam][__s]._M_end = __sd->_M_starts[__s + 1]
244                                                  - __sd->_M_starts[__s];
245         }
246     }
247   };
248   
249 template<bool __stable, typename _RAIter, typename _Compare>
250   struct __possibly_stable_sort
251   {
252   };
253
254 template<typename _RAIter, typename _Compare>
255   struct __possibly_stable_sort<true, _RAIter, _Compare>
256   {
257     void operator()(const _RAIter& __begin,
258                      const _RAIter& __end, _Compare& __comp) const
259     {
260       __gnu_sequential::stable_sort(__begin, __end, __comp); 
261     }
262   };
263
264 template<typename _RAIter, typename _Compare>
265   struct __possibly_stable_sort<false, _RAIter, _Compare>
266   {
267     void operator()(const _RAIter __begin,
268                      const _RAIter __end, _Compare& __comp) const
269     {
270       __gnu_sequential::sort(__begin, __end, __comp); 
271     }
272   };
273
274 template<bool __stable, typename Seq_RAIter,
275           typename _RAIter, typename _Compare,
276           typename DiffType>
277   struct __possibly_stable_multiway_merge
278   {
279   };
280
281 template<typename Seq_RAIter, typename _RAIter,
282           typename _Compare, typename DiffType>
283   struct __possibly_stable_multiway_merge
284     <true, Seq_RAIter, _RAIter, _Compare,
285     DiffType>
286   {
287     void operator()(const Seq_RAIter& __seqs_begin,
288                       const Seq_RAIter& __seqs_end,
289                       const _RAIter& __target,
290                       _Compare& __comp,
291                       DiffType __length_am) const
292     {
293       stable_multiway_merge(__seqs_begin, __seqs_end, __target, __length_am,
294                             __comp, sequential_tag());
295     }
296   };
297
298 template<typename Seq_RAIter, typename _RAIter,
299           typename _Compare, typename DiffType>
300   struct __possibly_stable_multiway_merge
301     <false, Seq_RAIter, _RAIter, _Compare,
302     DiffType>
303   {
304     void operator()(const Seq_RAIter& __seqs_begin,
305                       const Seq_RAIter& __seqs_end,
306                       const _RAIter& __target,
307                       _Compare& __comp,
308                       DiffType __length_am) const
309     {
310       multiway_merge(__seqs_begin, __seqs_end, __target, __length_am, __comp,
311                        sequential_tag());
312     }
313   };
314
315 /** @brief PMWMS code executed by each thread.
316   *  @param __sd Pointer to algorithm data.
317   *  @param __comp Comparator.
318   */
319 template<bool __stable, bool __exact, typename _RAIter,
320           typename _Compare>
321   void 
322   parallel_sort_mwms_pu(_PMWMSSortingData<_RAIter>* __sd,
323                         _Compare& __comp)
324   {
325     typedef std::iterator_traits<_RAIter> _TraitsType;
326     typedef typename _TraitsType::value_type _ValueType;
327     typedef typename _TraitsType::difference_type _DifferenceType;
328
329     _ThreadIndex __iam = omp_get_thread_num();
330
331     // Length of this thread's chunk, before merging.
332     _DifferenceType __length_local
333                         = __sd->_M_starts[__iam + 1] - __sd->_M_starts[__iam];
334
335     // Sort in temporary storage, leave space for sentinel.
336
337     typedef _ValueType* _SortingPlacesIterator;
338
339     __sd->_M_temporary[__iam] =
340         static_cast<_ValueType*>(
341         ::operator new(sizeof(_ValueType) * (__length_local + 1)));
342
343     // Copy there.
344     std::uninitialized_copy(
345                 __sd->_M_source + __sd->_M_starts[__iam],
346                 __sd->_M_source + __sd->_M_starts[__iam] + __length_local,
347                 __sd->_M_temporary[__iam]);
348
349     __possibly_stable_sort<__stable, _SortingPlacesIterator, _Compare>()
350         (__sd->_M_temporary[__iam],
351          __sd->_M_temporary[__iam] + __length_local,
352          __comp);
353
354     // Invariant: locally sorted subsequence in sd->_M_temporary[__iam],
355     // __sd->_M_temporary[__iam] + __length_local.
356
357     // No barrier here: Synchronization is done by the splitting routine.
358
359     _DifferenceType __num_samples =
360         _Settings::get().sort_mwms_oversampling * __sd->_M_num_threads - 1;
361     _SplitConsistently
362       <__exact, _RAIter, _Compare, _SortingPlacesIterator>()
363         (__iam, __sd, __comp, __num_samples);
364
365     // Offset from __target __begin, __length after merging.
366     _DifferenceType __offset = 0, __length_am = 0;
367     for (_ThreadIndex __s = 0; __s < __sd->_M_num_threads; __s++)
368       {
369         __length_am += __sd->_M_pieces[__iam][__s]._M_end
370                        - __sd->_M_pieces[__iam][__s]._M_begin;
371         __offset += __sd->_M_pieces[__iam][__s]._M_begin;
372       }
373
374     typedef std::vector<
375       std::pair<_SortingPlacesIterator, _SortingPlacesIterator> >
376         _SeqVector;
377     _SeqVector seqs(__sd->_M_num_threads);
378
379     for (int __s = 0; __s < __sd->_M_num_threads; ++__s)
380       {
381         seqs[__s] =
382           std::make_pair(
383             __sd->_M_temporary[__s] + __sd->_M_pieces[__iam][__s]._M_begin,
384             __sd->_M_temporary[__s] + __sd->_M_pieces[__iam][__s]._M_end);
385       }
386
387     __possibly_stable_multiway_merge<
388         __stable,
389         typename _SeqVector::iterator,
390         _RAIter,
391         _Compare, _DifferenceType>()
392           (seqs.begin(), seqs.end(),
393            __sd->_M_source + __offset, __comp,
394            __length_am);
395
396 #   pragma omp barrier
397
398     ::operator delete(__sd->_M_temporary[__iam]);
399   }
400
401 /** @brief PMWMS main call.
402   *  @param __begin Begin iterator of sequence.
403   *  @param __end End iterator of sequence.
404   *  @param __comp Comparator.
405   *  @param __n Length of sequence.
406   *  @param __num_threads Number of threads to use.
407   */
408 template<bool __stable, bool __exact, typename _RAIter,
409            typename _Compare>
410   void
411   parallel_sort_mwms(_RAIter __begin, _RAIter __end,
412                      _Compare __comp,
413                      _ThreadIndex __num_threads)
414   {
415     _GLIBCXX_CALL(__end - __begin)
416
417     typedef std::iterator_traits<_RAIter> _TraitsType;
418     typedef typename _TraitsType::value_type _ValueType;
419     typedef typename _TraitsType::difference_type _DifferenceType;
420
421     _DifferenceType __n = __end - __begin;
422
423     if (__n <= 1)
424       return;
425
426     // at least one element per thread
427     if (__num_threads > __n)
428       __num_threads = static_cast<_ThreadIndex>(__n);
429
430     // shared variables
431     _PMWMSSortingData<_RAIter> __sd;
432     _DifferenceType* _M_starts;
433
434 #   pragma omp parallel num_threads(__num_threads)
435       {
436         __num_threads = omp_get_num_threads(); //no more threads than requested
437
438 #       pragma omp single
439           {
440             __sd._M_num_threads = __num_threads;
441             __sd._M_source = __begin;
442
443             __sd._M_temporary = new _ValueType*[__num_threads];
444
445             if (!__exact)
446               {
447                 _DifferenceType __size =
448                   (_Settings::get().sort_mwms_oversampling * __num_threads - 1)
449                         * __num_threads;
450                 __sd._M_samples = static_cast<_ValueType*>(
451                               ::operator new(__size * sizeof(_ValueType)));
452               }
453             else
454               __sd._M_samples = NULL;
455
456             __sd._M_offsets = new _DifferenceType[__num_threads - 1];
457             __sd._M_pieces
458                 = new std::vector<_Piece<_DifferenceType> >[__num_threads];
459             for (int __s = 0; __s < __num_threads; ++__s)
460               __sd._M_pieces[__s].resize(__num_threads);
461             _M_starts = __sd._M_starts
462                 = new _DifferenceType[__num_threads + 1];
463
464             _DifferenceType __chunk_length = __n / __num_threads;
465             _DifferenceType __split = __n % __num_threads;
466             _DifferenceType __pos = 0;
467             for (int __i = 0; __i < __num_threads; ++__i)
468               {
469                 _M_starts[__i] = __pos;
470                 __pos += (__i < __split)
471                          ? (__chunk_length + 1) : __chunk_length;
472               }
473             _M_starts[__num_threads] = __pos;
474           } //single
475
476         // Now sort in parallel.
477         parallel_sort_mwms_pu<__stable, __exact>(&__sd, __comp);
478       } //parallel
479
480     delete[] _M_starts;
481     delete[] __sd._M_temporary;
482
483     if (!__exact)
484       ::operator delete(__sd._M_samples);
485
486     delete[] __sd._M_offsets;
487     delete[] __sd._M_pieces;
488   }
489 } //namespace __gnu_parallel
490
491 #endif /* _GLIBCXX_PARALLEL_MULTIWAY_MERGESORT_H */