OSDN Git Service

gcc/ChangeLog
[pf3gnuchains/gcc-fork.git] / libgfortran / generated / matmul_c10.c
1 /* Implementation of the MATMUL intrinsic
2    Copyright 2002, 2005, 2006, 2007, 2009 Free Software Foundation, Inc.
3    Contributed by Paul Brook <paul@nowt.org>
4
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
6
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
11
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 GNU General Public License for more details.
16
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
20
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
24 <http://www.gnu.org/licenses/>.  */
25
26 #include "libgfortran.h"
27 #include <stdlib.h>
28 #include <string.h>
29 #include <assert.h>
30
31
32 #if defined (HAVE_GFC_COMPLEX_10)
33
34 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
35    passed to us by the front-end, in which case we'll call it for large
36    matrices.  */
37
38 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
39                           const int *, const GFC_COMPLEX_10 *, const GFC_COMPLEX_10 *,
40                           const int *, const GFC_COMPLEX_10 *, const int *,
41                           const GFC_COMPLEX_10 *, GFC_COMPLEX_10 *, const int *,
42                           int, int);
43
44 /* The order of loops is different in the case of plain matrix
45    multiplication C=MATMUL(A,B), and in the frequent special case where
46    the argument A is the temporary result of a TRANSPOSE intrinsic:
47    C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
48    looking at their strides.
49
50    The equivalent Fortran pseudo-code is:
51
52    DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
53    IF (.NOT.IS_TRANSPOSED(A)) THEN
54      C = 0
55      DO J=1,N
56        DO K=1,COUNT
57          DO I=1,M
58            C(I,J) = C(I,J)+A(I,K)*B(K,J)
59    ELSE
60      DO J=1,N
61        DO I=1,M
62          S = 0
63          DO K=1,COUNT
64            S = S+A(I,K)*B(K,J)
65          C(I,J) = S
66    ENDIF
67 */
68
69 /* If try_blas is set to a nonzero value, then the matmul function will
70    see if there is a way to perform the matrix multiplication by a call
71    to the BLAS gemm function.  */
72
73 extern void matmul_c10 (gfc_array_c10 * const restrict retarray, 
74         gfc_array_c10 * const restrict a, gfc_array_c10 * const restrict b, int try_blas,
75         int blas_limit, blas_call gemm);
76 export_proto(matmul_c10);
77
78 void
79 matmul_c10 (gfc_array_c10 * const restrict retarray, 
80         gfc_array_c10 * const restrict a, gfc_array_c10 * const restrict b, int try_blas,
81         int blas_limit, blas_call gemm)
82 {
83   const GFC_COMPLEX_10 * restrict abase;
84   const GFC_COMPLEX_10 * restrict bbase;
85   GFC_COMPLEX_10 * restrict dest;
86
87   index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
88   index_type x, y, n, count, xcount, ycount;
89
90   assert (GFC_DESCRIPTOR_RANK (a) == 2
91           || GFC_DESCRIPTOR_RANK (b) == 2);
92
93 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
94
95    Either A or B (but not both) can be rank 1:
96
97    o One-dimensional argument A is implicitly treated as a row matrix
98      dimensioned [1,count], so xcount=1.
99
100    o One-dimensional argument B is implicitly treated as a column matrix
101      dimensioned [count, 1], so ycount=1.
102   */
103
104   if (retarray->data == NULL)
105     {
106       if (GFC_DESCRIPTOR_RANK (a) == 1)
107         {
108           GFC_DIMENSION_SET(retarray->dim[0], 0,
109                             GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
110         }
111       else if (GFC_DESCRIPTOR_RANK (b) == 1)
112         {
113           GFC_DIMENSION_SET(retarray->dim[0], 0,
114                             GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
115         }
116       else
117         {
118           GFC_DIMENSION_SET(retarray->dim[0], 0,
119                             GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
120
121           GFC_DIMENSION_SET(retarray->dim[1], 0,
122                             GFC_DESCRIPTOR_EXTENT(b,1) - 1,
123                             GFC_DESCRIPTOR_EXTENT(retarray,0));
124         }
125
126       retarray->data
127         = internal_malloc_size (sizeof (GFC_COMPLEX_10) * size0 ((array_t *) retarray));
128       retarray->offset = 0;
129     }
130     else if (unlikely (compile_options.bounds_check))
131       {
132         index_type ret_extent, arg_extent;
133
134         if (GFC_DESCRIPTOR_RANK (a) == 1)
135           {
136             arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
137             ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
138             if (arg_extent != ret_extent)
139               runtime_error ("Incorrect extent in return array in"
140                              " MATMUL intrinsic: is %ld, should be %ld",
141                              (long int) ret_extent, (long int) arg_extent);
142           }
143         else if (GFC_DESCRIPTOR_RANK (b) == 1)
144           {
145             arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
146             ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
147             if (arg_extent != ret_extent)
148               runtime_error ("Incorrect extent in return array in"
149                              " MATMUL intrinsic: is %ld, should be %ld",
150                              (long int) ret_extent, (long int) arg_extent);         
151           }
152         else
153           {
154             arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
155             ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
156             if (arg_extent != ret_extent)
157               runtime_error ("Incorrect extent in return array in"
158                              " MATMUL intrinsic for dimension 1:"
159                              " is %ld, should be %ld",
160                              (long int) ret_extent, (long int) arg_extent);
161
162             arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
163             ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
164             if (arg_extent != ret_extent)
165               runtime_error ("Incorrect extent in return array in"
166                              " MATMUL intrinsic for dimension 2:"
167                              " is %ld, should be %ld",
168                              (long int) ret_extent, (long int) arg_extent);
169           }
170       }
171
172
173   if (GFC_DESCRIPTOR_RANK (retarray) == 1)
174     {
175       /* One-dimensional result may be addressed in the code below
176          either as a row or a column matrix. We want both cases to
177          work. */
178       rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
179     }
180   else
181     {
182       rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
183       rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
184     }
185
186
187   if (GFC_DESCRIPTOR_RANK (a) == 1)
188     {
189       /* Treat it as a a row matrix A[1,count]. */
190       axstride = GFC_DESCRIPTOR_STRIDE(a,0);
191       aystride = 1;
192
193       xcount = 1;
194       count = GFC_DESCRIPTOR_EXTENT(a,0);
195     }
196   else
197     {
198       axstride = GFC_DESCRIPTOR_STRIDE(a,0);
199       aystride = GFC_DESCRIPTOR_STRIDE(a,1);
200
201       count = GFC_DESCRIPTOR_EXTENT(a,1);
202       xcount = GFC_DESCRIPTOR_EXTENT(a,0);
203     }
204
205   if (count != GFC_DESCRIPTOR_EXTENT(b,0))
206     {
207       if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
208         runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
209     }
210
211   if (GFC_DESCRIPTOR_RANK (b) == 1)
212     {
213       /* Treat it as a column matrix B[count,1] */
214       bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
215
216       /* bystride should never be used for 1-dimensional b.
217          in case it is we want it to cause a segfault, rather than
218          an incorrect result. */
219       bystride = 0xDEADBEEF;
220       ycount = 1;
221     }
222   else
223     {
224       bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
225       bystride = GFC_DESCRIPTOR_STRIDE(b,1);
226       ycount = GFC_DESCRIPTOR_EXTENT(b,1);
227     }
228
229   abase = a->data;
230   bbase = b->data;
231   dest = retarray->data;
232
233
234   /* Now that everything is set up, we're performing the multiplication
235      itself.  */
236
237 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
238
239   if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
240       && (bxstride == 1 || bystride == 1)
241       && (((float) xcount) * ((float) ycount) * ((float) count)
242           > POW3(blas_limit)))
243   {
244     const int m = xcount, n = ycount, k = count, ldc = rystride;
245     const GFC_COMPLEX_10 one = 1, zero = 0;
246     const int lda = (axstride == 1) ? aystride : axstride,
247               ldb = (bxstride == 1) ? bystride : bxstride;
248
249     if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
250       {
251         assert (gemm != NULL);
252         gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
253               &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
254         return;
255       }
256   }
257
258   if (rxstride == 1 && axstride == 1 && bxstride == 1)
259     {
260       const GFC_COMPLEX_10 * restrict bbase_y;
261       GFC_COMPLEX_10 * restrict dest_y;
262       const GFC_COMPLEX_10 * restrict abase_n;
263       GFC_COMPLEX_10 bbase_yn;
264
265       if (rystride == xcount)
266         memset (dest, 0, (sizeof (GFC_COMPLEX_10) * xcount * ycount));
267       else
268         {
269           for (y = 0; y < ycount; y++)
270             for (x = 0; x < xcount; x++)
271               dest[x + y*rystride] = (GFC_COMPLEX_10)0;
272         }
273
274       for (y = 0; y < ycount; y++)
275         {
276           bbase_y = bbase + y*bystride;
277           dest_y = dest + y*rystride;
278           for (n = 0; n < count; n++)
279             {
280               abase_n = abase + n*aystride;
281               bbase_yn = bbase_y[n];
282               for (x = 0; x < xcount; x++)
283                 {
284                   dest_y[x] += abase_n[x] * bbase_yn;
285                 }
286             }
287         }
288     }
289   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
290     {
291       if (GFC_DESCRIPTOR_RANK (a) != 1)
292         {
293           const GFC_COMPLEX_10 *restrict abase_x;
294           const GFC_COMPLEX_10 *restrict bbase_y;
295           GFC_COMPLEX_10 *restrict dest_y;
296           GFC_COMPLEX_10 s;
297
298           for (y = 0; y < ycount; y++)
299             {
300               bbase_y = &bbase[y*bystride];
301               dest_y = &dest[y*rystride];
302               for (x = 0; x < xcount; x++)
303                 {
304                   abase_x = &abase[x*axstride];
305                   s = (GFC_COMPLEX_10) 0;
306                   for (n = 0; n < count; n++)
307                     s += abase_x[n] * bbase_y[n];
308                   dest_y[x] = s;
309                 }
310             }
311         }
312       else
313         {
314           const GFC_COMPLEX_10 *restrict bbase_y;
315           GFC_COMPLEX_10 s;
316
317           for (y = 0; y < ycount; y++)
318             {
319               bbase_y = &bbase[y*bystride];
320               s = (GFC_COMPLEX_10) 0;
321               for (n = 0; n < count; n++)
322                 s += abase[n*axstride] * bbase_y[n];
323               dest[y*rystride] = s;
324             }
325         }
326     }
327   else if (axstride < aystride)
328     {
329       for (y = 0; y < ycount; y++)
330         for (x = 0; x < xcount; x++)
331           dest[x*rxstride + y*rystride] = (GFC_COMPLEX_10)0;
332
333       for (y = 0; y < ycount; y++)
334         for (n = 0; n < count; n++)
335           for (x = 0; x < xcount; x++)
336             /* dest[x,y] += a[x,n] * b[n,y] */
337             dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
338     }
339   else if (GFC_DESCRIPTOR_RANK (a) == 1)
340     {
341       const GFC_COMPLEX_10 *restrict bbase_y;
342       GFC_COMPLEX_10 s;
343
344       for (y = 0; y < ycount; y++)
345         {
346           bbase_y = &bbase[y*bystride];
347           s = (GFC_COMPLEX_10) 0;
348           for (n = 0; n < count; n++)
349             s += abase[n*axstride] * bbase_y[n*bxstride];
350           dest[y*rxstride] = s;
351         }
352     }
353   else
354     {
355       const GFC_COMPLEX_10 *restrict abase_x;
356       const GFC_COMPLEX_10 *restrict bbase_y;
357       GFC_COMPLEX_10 *restrict dest_y;
358       GFC_COMPLEX_10 s;
359
360       for (y = 0; y < ycount; y++)
361         {
362           bbase_y = &bbase[y*bystride];
363           dest_y = &dest[y*rystride];
364           for (x = 0; x < xcount; x++)
365             {
366               abase_x = &abase[x*axstride];
367               s = (GFC_COMPLEX_10) 0;
368               for (n = 0; n < count; n++)
369                 s += abase_x[n*aystride] * bbase_y[n*bxstride];
370               dest_y[x*rxstride] = s;
371             }
372         }
373     }
374 }
375
376 #endif