OSDN Git Service

Licensing changes to GPLv3 resp. GPLv3 with GCC Runtime Exception.
[pf3gnuchains/gcc-fork.git] / libgfortran / generated / matmul_i4.c
1 /* Implementation of the MATMUL intrinsic
2    Copyright 2002, 2005, 2006, 2007, 2009 Free Software Foundation, Inc.
3    Contributed by Paul Brook <paul@nowt.org>
4
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
6
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
11
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 GNU General Public License for more details.
16
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
20
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
24 <http://www.gnu.org/licenses/>.  */
25
26 #include "libgfortran.h"
27 #include <stdlib.h>
28 #include <string.h>
29 #include <assert.h>
30
31
32 #if defined (HAVE_GFC_INTEGER_4)
33
34 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
35    passed to us by the front-end, in which case we'll call it for large
36    matrices.  */
37
38 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
39                           const int *, const GFC_INTEGER_4 *, const GFC_INTEGER_4 *,
40                           const int *, const GFC_INTEGER_4 *, const int *,
41                           const GFC_INTEGER_4 *, GFC_INTEGER_4 *, const int *,
42                           int, int);
43
44 /* The order of loops is different in the case of plain matrix
45    multiplication C=MATMUL(A,B), and in the frequent special case where
46    the argument A is the temporary result of a TRANSPOSE intrinsic:
47    C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
48    looking at their strides.
49
50    The equivalent Fortran pseudo-code is:
51
52    DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
53    IF (.NOT.IS_TRANSPOSED(A)) THEN
54      C = 0
55      DO J=1,N
56        DO K=1,COUNT
57          DO I=1,M
58            C(I,J) = C(I,J)+A(I,K)*B(K,J)
59    ELSE
60      DO J=1,N
61        DO I=1,M
62          S = 0
63          DO K=1,COUNT
64            S = S+A(I,K)*B(K,J)
65          C(I,J) = S
66    ENDIF
67 */
68
69 /* If try_blas is set to a nonzero value, then the matmul function will
70    see if there is a way to perform the matrix multiplication by a call
71    to the BLAS gemm function.  */
72
73 extern void matmul_i4 (gfc_array_i4 * const restrict retarray, 
74         gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
75         int blas_limit, blas_call gemm);
76 export_proto(matmul_i4);
77
78 void
79 matmul_i4 (gfc_array_i4 * const restrict retarray, 
80         gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
81         int blas_limit, blas_call gemm)
82 {
83   const GFC_INTEGER_4 * restrict abase;
84   const GFC_INTEGER_4 * restrict bbase;
85   GFC_INTEGER_4 * restrict dest;
86
87   index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
88   index_type x, y, n, count, xcount, ycount;
89
90   assert (GFC_DESCRIPTOR_RANK (a) == 2
91           || GFC_DESCRIPTOR_RANK (b) == 2);
92
93 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
94
95    Either A or B (but not both) can be rank 1:
96
97    o One-dimensional argument A is implicitly treated as a row matrix
98      dimensioned [1,count], so xcount=1.
99
100    o One-dimensional argument B is implicitly treated as a column matrix
101      dimensioned [count, 1], so ycount=1.
102   */
103
104   if (retarray->data == NULL)
105     {
106       if (GFC_DESCRIPTOR_RANK (a) == 1)
107         {
108           retarray->dim[0].lbound = 0;
109           retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
110           retarray->dim[0].stride = 1;
111         }
112       else if (GFC_DESCRIPTOR_RANK (b) == 1)
113         {
114           retarray->dim[0].lbound = 0;
115           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
116           retarray->dim[0].stride = 1;
117         }
118       else
119         {
120           retarray->dim[0].lbound = 0;
121           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
122           retarray->dim[0].stride = 1;
123
124           retarray->dim[1].lbound = 0;
125           retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
126           retarray->dim[1].stride = retarray->dim[0].ubound+1;
127         }
128
129       retarray->data
130         = internal_malloc_size (sizeof (GFC_INTEGER_4) * size0 ((array_t *) retarray));
131       retarray->offset = 0;
132     }
133     else if (unlikely (compile_options.bounds_check))
134       {
135         index_type ret_extent, arg_extent;
136
137         if (GFC_DESCRIPTOR_RANK (a) == 1)
138           {
139             arg_extent = b->dim[1].ubound + 1 - b->dim[1].lbound;
140             ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
141             if (arg_extent != ret_extent)
142               runtime_error ("Incorrect extent in return array in"
143                              " MATMUL intrinsic: is %ld, should be %ld",
144                              (long int) ret_extent, (long int) arg_extent);
145           }
146         else if (GFC_DESCRIPTOR_RANK (b) == 1)
147           {
148             arg_extent = a->dim[0].ubound + 1 - a->dim[0].lbound;
149             ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
150             if (arg_extent != ret_extent)
151               runtime_error ("Incorrect extent in return array in"
152                              " MATMUL intrinsic: is %ld, should be %ld",
153                              (long int) ret_extent, (long int) arg_extent);         
154           }
155         else
156           {
157             arg_extent = a->dim[0].ubound + 1 - a->dim[0].lbound;
158             ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
159             if (arg_extent != ret_extent)
160               runtime_error ("Incorrect extent in return array in"
161                              " MATMUL intrinsic for dimension 1:"
162                              " is %ld, should be %ld",
163                              (long int) ret_extent, (long int) arg_extent);
164
165             arg_extent = b->dim[1].ubound + 1 - b->dim[1].lbound;
166             ret_extent = retarray->dim[1].ubound + 1 - retarray->dim[1].lbound;
167             if (arg_extent != ret_extent)
168               runtime_error ("Incorrect extent in return array in"
169                              " MATMUL intrinsic for dimension 2:"
170                              " is %ld, should be %ld",
171                              (long int) ret_extent, (long int) arg_extent);
172           }
173       }
174
175
176   if (GFC_DESCRIPTOR_RANK (retarray) == 1)
177     {
178       /* One-dimensional result may be addressed in the code below
179          either as a row or a column matrix. We want both cases to
180          work. */
181       rxstride = rystride = retarray->dim[0].stride;
182     }
183   else
184     {
185       rxstride = retarray->dim[0].stride;
186       rystride = retarray->dim[1].stride;
187     }
188
189
190   if (GFC_DESCRIPTOR_RANK (a) == 1)
191     {
192       /* Treat it as a a row matrix A[1,count]. */
193       axstride = a->dim[0].stride;
194       aystride = 1;
195
196       xcount = 1;
197       count = a->dim[0].ubound + 1 - a->dim[0].lbound;
198     }
199   else
200     {
201       axstride = a->dim[0].stride;
202       aystride = a->dim[1].stride;
203
204       count = a->dim[1].ubound + 1 - a->dim[1].lbound;
205       xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
206     }
207
208   if (count != b->dim[0].ubound + 1 - b->dim[0].lbound)
209     {
210       if (count > 0 || b->dim[0].ubound + 1 - b->dim[0].lbound > 0)
211         runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
212     }
213
214   if (GFC_DESCRIPTOR_RANK (b) == 1)
215     {
216       /* Treat it as a column matrix B[count,1] */
217       bxstride = b->dim[0].stride;
218
219       /* bystride should never be used for 1-dimensional b.
220          in case it is we want it to cause a segfault, rather than
221          an incorrect result. */
222       bystride = 0xDEADBEEF;
223       ycount = 1;
224     }
225   else
226     {
227       bxstride = b->dim[0].stride;
228       bystride = b->dim[1].stride;
229       ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
230     }
231
232   abase = a->data;
233   bbase = b->data;
234   dest = retarray->data;
235
236
237   /* Now that everything is set up, we're performing the multiplication
238      itself.  */
239
240 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
241
242   if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
243       && (bxstride == 1 || bystride == 1)
244       && (((float) xcount) * ((float) ycount) * ((float) count)
245           > POW3(blas_limit)))
246   {
247     const int m = xcount, n = ycount, k = count, ldc = rystride;
248     const GFC_INTEGER_4 one = 1, zero = 0;
249     const int lda = (axstride == 1) ? aystride : axstride,
250               ldb = (bxstride == 1) ? bystride : bxstride;
251
252     if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
253       {
254         assert (gemm != NULL);
255         gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
256               &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
257         return;
258       }
259   }
260
261   if (rxstride == 1 && axstride == 1 && bxstride == 1)
262     {
263       const GFC_INTEGER_4 * restrict bbase_y;
264       GFC_INTEGER_4 * restrict dest_y;
265       const GFC_INTEGER_4 * restrict abase_n;
266       GFC_INTEGER_4 bbase_yn;
267
268       if (rystride == xcount)
269         memset (dest, 0, (sizeof (GFC_INTEGER_4) * xcount * ycount));
270       else
271         {
272           for (y = 0; y < ycount; y++)
273             for (x = 0; x < xcount; x++)
274               dest[x + y*rystride] = (GFC_INTEGER_4)0;
275         }
276
277       for (y = 0; y < ycount; y++)
278         {
279           bbase_y = bbase + y*bystride;
280           dest_y = dest + y*rystride;
281           for (n = 0; n < count; n++)
282             {
283               abase_n = abase + n*aystride;
284               bbase_yn = bbase_y[n];
285               for (x = 0; x < xcount; x++)
286                 {
287                   dest_y[x] += abase_n[x] * bbase_yn;
288                 }
289             }
290         }
291     }
292   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
293     {
294       if (GFC_DESCRIPTOR_RANK (a) != 1)
295         {
296           const GFC_INTEGER_4 *restrict abase_x;
297           const GFC_INTEGER_4 *restrict bbase_y;
298           GFC_INTEGER_4 *restrict dest_y;
299           GFC_INTEGER_4 s;
300
301           for (y = 0; y < ycount; y++)
302             {
303               bbase_y = &bbase[y*bystride];
304               dest_y = &dest[y*rystride];
305               for (x = 0; x < xcount; x++)
306                 {
307                   abase_x = &abase[x*axstride];
308                   s = (GFC_INTEGER_4) 0;
309                   for (n = 0; n < count; n++)
310                     s += abase_x[n] * bbase_y[n];
311                   dest_y[x] = s;
312                 }
313             }
314         }
315       else
316         {
317           const GFC_INTEGER_4 *restrict bbase_y;
318           GFC_INTEGER_4 s;
319
320           for (y = 0; y < ycount; y++)
321             {
322               bbase_y = &bbase[y*bystride];
323               s = (GFC_INTEGER_4) 0;
324               for (n = 0; n < count; n++)
325                 s += abase[n*axstride] * bbase_y[n];
326               dest[y*rystride] = s;
327             }
328         }
329     }
330   else if (axstride < aystride)
331     {
332       for (y = 0; y < ycount; y++)
333         for (x = 0; x < xcount; x++)
334           dest[x*rxstride + y*rystride] = (GFC_INTEGER_4)0;
335
336       for (y = 0; y < ycount; y++)
337         for (n = 0; n < count; n++)
338           for (x = 0; x < xcount; x++)
339             /* dest[x,y] += a[x,n] * b[n,y] */
340             dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
341     }
342   else if (GFC_DESCRIPTOR_RANK (a) == 1)
343     {
344       const GFC_INTEGER_4 *restrict bbase_y;
345       GFC_INTEGER_4 s;
346
347       for (y = 0; y < ycount; y++)
348         {
349           bbase_y = &bbase[y*bystride];
350           s = (GFC_INTEGER_4) 0;
351           for (n = 0; n < count; n++)
352             s += abase[n*axstride] * bbase_y[n*bxstride];
353           dest[y*rxstride] = s;
354         }
355     }
356   else
357     {
358       const GFC_INTEGER_4 *restrict abase_x;
359       const GFC_INTEGER_4 *restrict bbase_y;
360       GFC_INTEGER_4 *restrict dest_y;
361       GFC_INTEGER_4 s;
362
363       for (y = 0; y < ycount; y++)
364         {
365           bbase_y = &bbase[y*bystride];
366           dest_y = &dest[y*rystride];
367           for (x = 0; x < xcount; x++)
368             {
369               abase_x = &abase[x*axstride];
370               s = (GFC_INTEGER_4) 0;
371               for (n = 0; n < count; n++)
372                 s += abase_x[n*aystride] * bbase_y[n*bxstride];
373               dest_y[x*rxstride] = s;
374             }
375         }
376     }
377 }
378
379 #endif