OSDN Git Service

2011-11-21 Arnaud Charlet <charlet@adacore.com>
[pf3gnuchains/gcc-fork.git] / libgfortran / m4 / matmul.m4
1 `/* Implementation of the MATMUL intrinsic
2    Copyright 2002, 2005, 2006, 2007, 2009 Free Software Foundation, Inc.
3    Contributed by Paul Brook <paul@nowt.org>
4
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
6
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
11
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 GNU General Public License for more details.
16
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
20
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
24 <http://www.gnu.org/licenses/>.  */
25
26 #include "libgfortran.h"
27 #include <stdlib.h>
28 #include <string.h>
29 #include <assert.h>'
30
31 include(iparm.m4)dnl
32
33 `#if defined (HAVE_'rtype_name`)
34
35 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
36    passed to us by the front-end, in which case we''`ll call it for large
37    matrices.  */
38
39 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
40                           const int *, const 'rtype_name` *, const 'rtype_name` *,
41                           const int *, const 'rtype_name` *, const int *,
42                           const 'rtype_name` *, 'rtype_name` *, const int *,
43                           int, int);
44
45 /* The order of loops is different in the case of plain matrix
46    multiplication C=MATMUL(A,B), and in the frequent special case where
47    the argument A is the temporary result of a TRANSPOSE intrinsic:
48    C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
49    looking at their strides.
50
51    The equivalent Fortran pseudo-code is:
52
53    DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
54    IF (.NOT.IS_TRANSPOSED(A)) THEN
55      C = 0
56      DO J=1,N
57        DO K=1,COUNT
58          DO I=1,M
59            C(I,J) = C(I,J)+A(I,K)*B(K,J)
60    ELSE
61      DO J=1,N
62        DO I=1,M
63          S = 0
64          DO K=1,COUNT
65            S = S+A(I,K)*B(K,J)
66          C(I,J) = S
67    ENDIF
68 */
69
70 /* If try_blas is set to a nonzero value, then the matmul function will
71    see if there is a way to perform the matrix multiplication by a call
72    to the BLAS gemm function.  */
73
74 extern void matmul_'rtype_code` ('rtype` * const restrict retarray, 
75         'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
76         int blas_limit, blas_call gemm);
77 export_proto(matmul_'rtype_code`);
78
79 void
80 matmul_'rtype_code` ('rtype` * const restrict retarray, 
81         'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
82         int blas_limit, blas_call gemm)
83 {
84   const 'rtype_name` * restrict abase;
85   const 'rtype_name` * restrict bbase;
86   'rtype_name` * restrict dest;
87
88   index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
89   index_type x, y, n, count, xcount, ycount;
90
91   assert (GFC_DESCRIPTOR_RANK (a) == 2
92           || GFC_DESCRIPTOR_RANK (b) == 2);
93
94 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
95
96    Either A or B (but not both) can be rank 1:
97
98    o One-dimensional argument A is implicitly treated as a row matrix
99      dimensioned [1,count], so xcount=1.
100
101    o One-dimensional argument B is implicitly treated as a column matrix
102      dimensioned [count, 1], so ycount=1.
103   */
104
105   if (retarray->data == NULL)
106     {
107       if (GFC_DESCRIPTOR_RANK (a) == 1)
108         {
109           GFC_DIMENSION_SET(retarray->dim[0], 0,
110                             GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
111         }
112       else if (GFC_DESCRIPTOR_RANK (b) == 1)
113         {
114           GFC_DIMENSION_SET(retarray->dim[0], 0,
115                             GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
116         }
117       else
118         {
119           GFC_DIMENSION_SET(retarray->dim[0], 0,
120                             GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
121
122           GFC_DIMENSION_SET(retarray->dim[1], 0,
123                             GFC_DESCRIPTOR_EXTENT(b,1) - 1,
124                             GFC_DESCRIPTOR_EXTENT(retarray,0));
125         }
126
127       retarray->data
128         = internal_malloc_size (sizeof ('rtype_name`) * size0 ((array_t *) retarray));
129       retarray->offset = 0;
130     }
131     else if (unlikely (compile_options.bounds_check))
132       {
133         index_type ret_extent, arg_extent;
134
135         if (GFC_DESCRIPTOR_RANK (a) == 1)
136           {
137             arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
138             ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
139             if (arg_extent != ret_extent)
140               runtime_error ("Incorrect extent in return array in"
141                              " MATMUL intrinsic: is %ld, should be %ld",
142                              (long int) ret_extent, (long int) arg_extent);
143           }
144         else if (GFC_DESCRIPTOR_RANK (b) == 1)
145           {
146             arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
147             ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
148             if (arg_extent != ret_extent)
149               runtime_error ("Incorrect extent in return array in"
150                              " MATMUL intrinsic: is %ld, should be %ld",
151                              (long int) ret_extent, (long int) arg_extent);         
152           }
153         else
154           {
155             arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
156             ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
157             if (arg_extent != ret_extent)
158               runtime_error ("Incorrect extent in return array in"
159                              " MATMUL intrinsic for dimension 1:"
160                              " is %ld, should be %ld",
161                              (long int) ret_extent, (long int) arg_extent);
162
163             arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
164             ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
165             if (arg_extent != ret_extent)
166               runtime_error ("Incorrect extent in return array in"
167                              " MATMUL intrinsic for dimension 2:"
168                              " is %ld, should be %ld",
169                              (long int) ret_extent, (long int) arg_extent);
170           }
171       }
172 '
173 sinclude(`matmul_asm_'rtype_code`.m4')dnl
174 `
175   if (GFC_DESCRIPTOR_RANK (retarray) == 1)
176     {
177       /* One-dimensional result may be addressed in the code below
178          either as a row or a column matrix. We want both cases to
179          work. */
180       rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
181     }
182   else
183     {
184       rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
185       rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
186     }
187
188
189   if (GFC_DESCRIPTOR_RANK (a) == 1)
190     {
191       /* Treat it as a a row matrix A[1,count]. */
192       axstride = GFC_DESCRIPTOR_STRIDE(a,0);
193       aystride = 1;
194
195       xcount = 1;
196       count = GFC_DESCRIPTOR_EXTENT(a,0);
197     }
198   else
199     {
200       axstride = GFC_DESCRIPTOR_STRIDE(a,0);
201       aystride = GFC_DESCRIPTOR_STRIDE(a,1);
202
203       count = GFC_DESCRIPTOR_EXTENT(a,1);
204       xcount = GFC_DESCRIPTOR_EXTENT(a,0);
205     }
206
207   if (count != GFC_DESCRIPTOR_EXTENT(b,0))
208     {
209       if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
210         runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
211     }
212
213   if (GFC_DESCRIPTOR_RANK (b) == 1)
214     {
215       /* Treat it as a column matrix B[count,1] */
216       bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
217
218       /* bystride should never be used for 1-dimensional b.
219          in case it is we want it to cause a segfault, rather than
220          an incorrect result. */
221       bystride = 0xDEADBEEF;
222       ycount = 1;
223     }
224   else
225     {
226       bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
227       bystride = GFC_DESCRIPTOR_STRIDE(b,1);
228       ycount = GFC_DESCRIPTOR_EXTENT(b,1);
229     }
230
231   abase = a->data;
232   bbase = b->data;
233   dest = retarray->data;
234
235
236   /* Now that everything is set up, we''`re performing the multiplication
237      itself.  */
238
239 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
240
241   if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
242       && (bxstride == 1 || bystride == 1)
243       && (((float) xcount) * ((float) ycount) * ((float) count)
244           > POW3(blas_limit)))
245   {
246     const int m = xcount, n = ycount, k = count, ldc = rystride;
247     const 'rtype_name` one = 1, zero = 0;
248     const int lda = (axstride == 1) ? aystride : axstride,
249               ldb = (bxstride == 1) ? bystride : bxstride;
250
251     if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
252       {
253         assert (gemm != NULL);
254         gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
255               &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
256         return;
257       }
258   }
259
260   if (rxstride == 1 && axstride == 1 && bxstride == 1)
261     {
262       const 'rtype_name` * restrict bbase_y;
263       'rtype_name` * restrict dest_y;
264       const 'rtype_name` * restrict abase_n;
265       'rtype_name` bbase_yn;
266
267       if (rystride == xcount)
268         memset (dest, 0, (sizeof ('rtype_name`) * xcount * ycount));
269       else
270         {
271           for (y = 0; y < ycount; y++)
272             for (x = 0; x < xcount; x++)
273               dest[x + y*rystride] = ('rtype_name`)0;
274         }
275
276       for (y = 0; y < ycount; y++)
277         {
278           bbase_y = bbase + y*bystride;
279           dest_y = dest + y*rystride;
280           for (n = 0; n < count; n++)
281             {
282               abase_n = abase + n*aystride;
283               bbase_yn = bbase_y[n];
284               for (x = 0; x < xcount; x++)
285                 {
286                   dest_y[x] += abase_n[x] * bbase_yn;
287                 }
288             }
289         }
290     }
291   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
292     {
293       if (GFC_DESCRIPTOR_RANK (a) != 1)
294         {
295           const 'rtype_name` *restrict abase_x;
296           const 'rtype_name` *restrict bbase_y;
297           'rtype_name` *restrict dest_y;
298           'rtype_name` s;
299
300           for (y = 0; y < ycount; y++)
301             {
302               bbase_y = &bbase[y*bystride];
303               dest_y = &dest[y*rystride];
304               for (x = 0; x < xcount; x++)
305                 {
306                   abase_x = &abase[x*axstride];
307                   s = ('rtype_name`) 0;
308                   for (n = 0; n < count; n++)
309                     s += abase_x[n] * bbase_y[n];
310                   dest_y[x] = s;
311                 }
312             }
313         }
314       else
315         {
316           const 'rtype_name` *restrict bbase_y;
317           'rtype_name` s;
318
319           for (y = 0; y < ycount; y++)
320             {
321               bbase_y = &bbase[y*bystride];
322               s = ('rtype_name`) 0;
323               for (n = 0; n < count; n++)
324                 s += abase[n*axstride] * bbase_y[n];
325               dest[y*rystride] = s;
326             }
327         }
328     }
329   else if (axstride < aystride)
330     {
331       for (y = 0; y < ycount; y++)
332         for (x = 0; x < xcount; x++)
333           dest[x*rxstride + y*rystride] = ('rtype_name`)0;
334
335       for (y = 0; y < ycount; y++)
336         for (n = 0; n < count; n++)
337           for (x = 0; x < xcount; x++)
338             /* dest[x,y] += a[x,n] * b[n,y] */
339             dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
340     }
341   else if (GFC_DESCRIPTOR_RANK (a) == 1)
342     {
343       const 'rtype_name` *restrict bbase_y;
344       'rtype_name` s;
345
346       for (y = 0; y < ycount; y++)
347         {
348           bbase_y = &bbase[y*bystride];
349           s = ('rtype_name`) 0;
350           for (n = 0; n < count; n++)
351             s += abase[n*axstride] * bbase_y[n*bxstride];
352           dest[y*rxstride] = s;
353         }
354     }
355   else
356     {
357       const 'rtype_name` *restrict abase_x;
358       const 'rtype_name` *restrict bbase_y;
359       'rtype_name` *restrict dest_y;
360       'rtype_name` s;
361
362       for (y = 0; y < ycount; y++)
363         {
364           bbase_y = &bbase[y*bystride];
365           dest_y = &dest[y*rystride];
366           for (x = 0; x < xcount; x++)
367             {
368               abase_x = &abase[x*axstride];
369               s = ('rtype_name`) 0;
370               for (n = 0; n < count; n++)
371                 s += abase_x[n*aystride] * bbase_y[n*bxstride];
372               dest_y[x*rxstride] = s;
373             }
374         }
375     }
376 }
377
378 #endif'