2 * Copyright 2008-2013 NVIDIA Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <thrust/iterator/iterator_traits.h>
18 #include <thrust/detail/temporary_array.h>
19 #include <thrust/system/tbb/detail/execution_policy.h>
20 #include <thrust/system/detail/internal/scalar/merge.h>
21 #include <thrust/system/detail/internal/scalar/binary_search.h>
22 #include <tbb/parallel_for.h>
// Helpers for the TBB implementation of merge(): a TBB Range that bisects a pair
// of input sequences, and a TBB Body that merges one leaf Range sequentially.
17 namespace merge_detail
// TBB Range over two input sequences [first1, last1) and [first2, last2) plus the
// output position `result` at which their merge is written. The inputs are
// expected to be sorted with respect to `comp`: the splitting constructor
// binary-searches them with lower_bound / upper_bound.
35 template<typename InputIterator1,
36 typename InputIterator2,
37 typename OutputIterator,
38 typename StrictWeakOrdering>
41 InputIterator1 first1, last1;
42 InputIterator2 first2, last2;
// output position corresponding to the pair (first1, first2)
43 OutputIterator result;
44 StrictWeakOrdering comp;
// grain_size: is_divisible() reports true only while the combined input length
// exceeds this threshold (1024 unless the caller passes another value).
47 range(InputIterator1 first1, InputIterator1 last1,
48 InputIterator2 first2, InputIterator2 last2,
49 OutputIterator result,
50 StrictWeakOrdering comp,
51 size_t grain_size = 1024)
52 : first1(first1), last1(last1),
53 first2(first2), last2(last2),
54 result(result), comp(comp), grain_size(grain_size)
// TBB splitting constructor: starts as a copy of r, then r is narrowed to the
// first part of the work and *this becomes the second part (the exact
// sub-ranges are spelled out in the comments further down).
57 range(range& r, ::tbb::split)
58 : first1(r.first1), last1(r.last1),
59 first2(r.first2), last2(r.last2),
60 result(r.result), comp(r.comp), grain_size(r.grain_size)
62 // we can assume n1 and n2 are not both 0: TBB only splits a range whose is_divisible() is true, i.e. n1 + n2 > grain_size
63 size_t n1 = thrust::distance(first1, last1);
64 size_t n2 = thrust::distance(first2, last2);
66 InputIterator1 mid1 = first1;
67 InputIterator2 mid2 = first2;
// Cut range 2 at the first element not less than *mid1. Range-2 elements
// equivalent to *mid1 therefore land in the second part, after *mid1. This keeps
// range-1 elements ahead of equivalent range-2 elements, assuming
// scalar::lower_bound follows std::lower_bound semantics.
72 mid2 = thrust::system::detail::internal::scalar::lower_bound(first2, last2, raw_reference_cast(*mid1), comp);
// Cut range 1 after every element not greater than *mid2. Range-1 elements
// equivalent to *mid2 therefore stay in the first part, ahead of *mid2, assuming
// scalar::upper_bound follows std::upper_bound semantics.
77 mid1 = thrust::system::detail::internal::scalar::upper_bound(first1, last1, raw_reference_cast(*mid2), comp);
80 // set first range to [first1, mid1), [first2, mid2), result
84 // set second range to [mid1, last1), [mid2, last2), result + (mid1 - first1) + (mid2 - first2)
// The second part's output begins right after every element assigned to the
// first part. r.first1 / r.first2 are still the original starts of both inputs.
87 result += thrust::distance(r.first1, mid1) + thrust::distance(r.first2, mid2);
// True when both input sub-ranges are exhausted (nothing left to merge).
90 bool empty(void) const
92 return (first1 == last1) && (first2 == last2);
// True while the combined input length exceeds grain_size; once false, TBB stops
// splitting and hands the range to the Body.
95 bool is_divisible(void) const
97 return static_cast<size_t>(thrust::distance(first1, last1) + thrust::distance(first2, last2)) > grain_size;
// TBB Body: merges a single leaf Range sequentially via scalar::merge.
103 template <typename Range>
104 void operator()(Range& r) const
106 thrust::system::detail::internal::scalar::merge
114 } // end namespace merge_detail
// Helpers for the TBB implementation of merge_by_key(): a TBB Range that bisects
// two key sequences (carrying their value sequences along), and a TBB Body that
// merges one leaf Range sequentially.
116 namespace merge_by_key_detail
// TBB Range over key sequences [keys_first1, keys_last1) and
// [keys_first2, keys_last2), with their value sequences and output positions.
// The keys are expected to be sorted with respect to `comp`: the splitting
// constructor binary-searches them.
119 template<typename InputIterator1,
120 typename InputIterator2,
121 typename InputIterator3,
122 typename InputIterator4,
123 typename OutputIterator1,
124 typename OutputIterator2,
125 typename StrictWeakOrdering>
128 InputIterator1 keys_first1, keys_last1;
129 InputIterator2 keys_first2, keys_last2;
// Values paired with keys_first1 / keys_first2. Only start iterators are stored;
// their extents follow from the corresponding key ranges.
130 InputIterator3 values_first1;
131 InputIterator4 values_first2;
// output positions for merged keys and merged values
132 OutputIterator1 keys_result;
133 OutputIterator2 values_result;
134 StrictWeakOrdering comp;
// grain_size: is_divisible() reports true only while the combined key count
// exceeds this threshold (1024 unless the caller passes another value).
137 range(InputIterator1 keys_first1, InputIterator1 keys_last1,
138 InputIterator2 keys_first2, InputIterator2 keys_last2,
139 InputIterator3 values_first1,
140 InputIterator4 values_first2,
141 OutputIterator1 keys_result,
142 OutputIterator2 values_result,
143 StrictWeakOrdering comp,
144 size_t grain_size = 1024)
145 : keys_first1(keys_first1), keys_last1(keys_last1),
146 keys_first2(keys_first2), keys_last2(keys_last2),
147 values_first1(values_first1),
148 values_first2(values_first2),
149 keys_result(keys_result), values_result(values_result),
150 comp(comp), grain_size(grain_size)
// TBB splitting constructor: starts as a copy of r, then r is narrowed to the
// first part of the work and *this becomes the second part (the exact
// sub-ranges are spelled out in the comments further down).
153 range(range& r, ::tbb::split)
154 : keys_first1(r.keys_first1), keys_last1(r.keys_last1),
155 keys_first2(r.keys_first2), keys_last2(r.keys_last2),
156 values_first1(r.values_first1),
157 values_first2(r.values_first2),
158 keys_result(r.keys_result), values_result(r.values_result),
159 comp(r.comp), grain_size(r.grain_size)
161 // we can assume n1 and n2 are not both 0: TBB only splits a range whose is_divisible() is true, i.e. n1 + n2 > grain_size
162 size_t n1 = thrust::distance(keys_first1, keys_last1);
163 size_t n2 = thrust::distance(keys_first2, keys_last2);
165 InputIterator1 mid1 = keys_first1;
166 InputIterator2 mid2 = keys_first2;
// Cut key range 2 at the first key not less than *mid1. Equivalent range-2 keys
// land in the second part, after *mid1, so range-1 entries stay ahead of
// equivalent range-2 entries (assuming std::lower_bound semantics).
171 mid2 = thrust::system::detail::internal::scalar::lower_bound(keys_first2, keys_last2, raw_reference_cast(*mid1), comp);
// Cut key range 1 after every key not greater than *mid2. Equivalent range-1
// keys stay in the first part, ahead of *mid2 (assuming std::upper_bound
// semantics).
176 mid1 = thrust::system::detail::internal::scalar::upper_bound(keys_first1, keys_last1, raw_reference_cast(*mid2), comp);
179 // set first range to [keys_first1, mid1), [keys_first2, mid2), keys_result, values_result
183 // set second range to [mid1, keys_last1), [mid2, keys_last2), keys_result + (mid1 - keys_first1) + (mid2 - keys_first2), values_result + (mid1 - keys_first1) + (mid2 - keys_first2)
// Values move in lockstep with their keys. r.keys_first1 / r.keys_first2 are
// still the original starts, so each distance is the element count that stayed
// in the first part.
186 values_first1 += thrust::distance(r.keys_first1, mid1);
187 values_first2 += thrust::distance(r.keys_first2, mid2);
// Both outputs of the second part begin right after everything the first part writes.
188 keys_result += thrust::distance(r.keys_first1, mid1) + thrust::distance(r.keys_first2, mid2);
189 values_result += thrust::distance(r.keys_first1, mid1) + thrust::distance(r.keys_first2, mid2);
// True when both key sub-ranges are exhausted (nothing left to merge).
192 bool empty(void) const
194 return (keys_first1 == keys_last1) && (keys_first2 == keys_last2);
// True while the combined key count exceeds grain_size; once false, TBB stops
// splitting and hands the range to the Body.
197 bool is_divisible(void) const
199 return static_cast<size_t>(thrust::distance(keys_first1, keys_last1) + thrust::distance(keys_first2, keys_last2)) > grain_size;
// TBB Body: merges a single leaf Range sequentially via scalar::merge_by_key.
205 template <typename Range>
206 void operator()(Range& r) const
208 thrust::system::detail::internal::scalar::merge_by_key
209 (r.keys_first1, r.keys_last1,
210 r.keys_first2, r.keys_last2,
219 } // end namespace merge_by_key_detail
// TBB backend for merge: merges [first1, last1) and [first2, last2) under `comp`
// into the output beginning at `result`. The work is bisected recursively by
// merge_detail::range, and each leaf is merged sequentially by merge_detail::body.
// `exec` is not referenced by the statements shown here.
222 template<typename DerivedPolicy,
223 typename InputIterator1,
224 typename InputIterator2,
225 typename OutputIterator,
226 typename StrictWeakOrdering>
227 OutputIterator merge(execution_policy<DerivedPolicy> &exec,
228 InputIterator1 first1,
229 InputIterator1 last1,
230 InputIterator2 first2,
231 InputIterator2 last2,
232 OutputIterator result,
233 StrictWeakOrdering comp)
235 typedef typename merge_detail::range<InputIterator1,InputIterator2,OutputIterator,StrictWeakOrdering> Range;
236 typedef merge_detail::body Body;
// no grain size is passed, so the Range's default (1024) applies
237 Range range(first1, last1, first2, last2, result, comp);
// tbb::parallel_for returns only after every leaf range has been merged
240 ::tbb::parallel_for(range, body);
// advance result past all n1 + n2 merged elements, i.e. to the end of the output range
242 thrust::advance(result, thrust::distance(first1, last1) + thrust::distance(first2, last2));
// TBB backend for merge_by_key: merges key ranges [keys_first1, keys_last1) and
// [keys_first2, keys_last2) under `comp`, moving each key's value with it.
// - values_first3 holds the values for the first key range.
// - values_first4 holds the values for the second key range.
// - Merged keys are written from keys_result and merged values from values_result.
// Returns the pair of output iterators advanced past all n1 + n2 written elements.
// `exec` is not referenced by the statements shown here.
247 template <typename DerivedPolicy,
248 typename InputIterator1,
249 typename InputIterator2,
250 typename InputIterator3,
251 typename InputIterator4,
252 typename OutputIterator1,
253 typename OutputIterator2,
254 typename StrictWeakOrdering>
255 thrust::pair<OutputIterator1,OutputIterator2>
256 merge_by_key(execution_policy<DerivedPolicy> &exec,
257 InputIterator1 keys_first1,
258 InputIterator1 keys_last1,
259 InputIterator2 keys_first2,
260 InputIterator2 keys_last2,
261 InputIterator3 values_first3,
262 InputIterator4 values_first4,
263 OutputIterator1 keys_result,
264 OutputIterator2 values_result,
265 StrictWeakOrdering comp)
267 typedef typename merge_by_key_detail::range<InputIterator1,InputIterator2,InputIterator3,InputIterator4,OutputIterator1,OutputIterator2,StrictWeakOrdering> Range;
268 typedef merge_by_key_detail::body Body;
// values_first3 / values_first4 become the Range's values_first1 / values_first2;
// no grain size is passed, so the Range's default (1024) applies
270 Range range(keys_first1, keys_last1, keys_first2, keys_last2, values_first3, values_first4, keys_result, values_result, comp);
// tbb::parallel_for returns only after every leaf range has been merged
273 ::tbb::parallel_for(range, body);
// both outputs receive exactly n1 + n2 elements, so advance each by that count
275 thrust::advance(keys_result, thrust::distance(keys_first1, keys_last1) + thrust::distance(keys_first2, keys_last2));
276 thrust::advance(values_result, thrust::distance(keys_first1, keys_last1) + thrust::distance(keys_first2, keys_last2));
278 return thrust::make_pair(keys_result,values_result);
281 } // end namespace detail
282 } // end namespace tbb
283 } // end namespace system
284 } // end namespace thrust