2 * Copyright 2008-2013 NVIDIA Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
#include <thrust/detail/config.h>
#include <thrust/detail/raw_pointer_cast.h>
#include <thrust/detail/allocator/allocator_traits.h>
#include <thrust/detail/allocator/temporary_allocator.h>
#include <thrust/pair.h>
#include <cstddef>
#include <limits>
#include <map>
#include <utility>
38 template<typename DerivedPolicy, template<typename> class BasePolicy>
39 class cached_temporary_allocator
40 : public BasePolicy<cached_temporary_allocator<DerivedPolicy,BasePolicy> >
43 typedef thrust::detail::temporary_allocator<char,DerivedPolicy> base_allocator_type;
44 typedef thrust::detail::allocator_traits<base_allocator_type> traits;
45 typedef typename traits::pointer allocator_pointer;
46 typedef std::multimap<std::ptrdiff_t, void*> free_blocks_type;
47 typedef std::map<void *, std::ptrdiff_t> allocated_blocks_type;
49 base_allocator_type m_base_allocator;
50 free_blocks_type free_blocks;
51 allocated_blocks_type allocated_blocks;
55 // deallocate all outstanding blocks in both lists
56 for(free_blocks_type::iterator i = free_blocks.begin();
57 i != free_blocks.end();
60 // transform the pointer to allocator_pointer before calling deallocate
61 traits::deallocate(m_base_allocator, allocator_pointer(reinterpret_cast<char*>(i->second)), i->first);
64 for(allocated_blocks_type::iterator i = allocated_blocks.begin();
65 i != allocated_blocks.end();
68 // transform the pointer to allocator_pointer before calling deallocate
69 traits::deallocate(m_base_allocator, allocator_pointer(reinterpret_cast<char*>(i->first)), i->second);
74 cached_temporary_allocator(thrust::execution_policy<DerivedPolicy> &system)
75 : m_base_allocator(system)
78 ~cached_temporary_allocator()
80 // free all allocations when cached_allocator goes out of scope
84 void *allocate(std::ptrdiff_t num_bytes)
88 // search the cache for a free block
89 free_blocks_type::iterator free_block = free_blocks.find(num_bytes);
91 if(free_block != free_blocks.end())
94 result = free_block->second;
96 // erase from the free_blocks map
97 free_blocks.erase(free_block);
101 // no allocation of the right size exists
102 // create a new one with m_base_allocator
103 // allocate memory and convert to raw pointer
104 result = thrust::raw_pointer_cast(traits::allocate(m_base_allocator, num_bytes));
107 // insert the allocated pointer into the allocated_blocks map
108 allocated_blocks.insert(std::make_pair(result, num_bytes));
113 void deallocate(void *ptr)
115 // erase the allocated block from the allocated blocks map
116 allocated_blocks_type::iterator iter = allocated_blocks.find(ptr);
117 std::ptrdiff_t num_bytes = iter->second;
118 allocated_blocks.erase(iter);
120 // insert the block into the free blocks map
121 free_blocks.insert(std::make_pair(num_bytes, ptr));
126 // overload get_temporary_buffer on cached_temporary_allocator
127 // note that we take a reference to cached_temporary_allocator
128 template<typename T, typename DerivedPolicy, template<typename> class BasePolicy>
129 thrust::pair<T*, std::ptrdiff_t>
130 get_temporary_buffer(cached_temporary_allocator<DerivedPolicy,BasePolicy> &alloc, std::ptrdiff_t n)
132 // ask the allocator for sizeof(T) * n bytes
133 T* result = reinterpret_cast<T*>(alloc.allocate(sizeof(T) * n));
135 // return the pointer and the number of elements allocated
136 return thrust::make_pair(result,n);
140 // overload return_temporary_buffer on cached_temporary_allocator
141 // an overloaded return_temporary_buffer should always accompany
142 // an overloaded get_temporary_buffer
143 template<typename Pointer, typename DerivedPolicy, template<typename> class BasePolicy>
144 void return_temporary_buffer(cached_temporary_allocator<DerivedPolicy,BasePolicy> &alloc, Pointer p)
146 // return the pointer to the allocator
147 alloc.deallocate(thrust::raw_pointer_cast(p));