OSDN Git Service

ef2f0fb88dc3ffc528e367e2864924380dfa69b8
[pf3gnuchains/gcc-fork.git] / libgfortran / m4 / matmul.m4
1 `/* Implementation of the MATMUL intrinsic
2    Copyright 2002, 2005, 2006 Free Software Foundation, Inc.
3    Contributed by Paul Brook <paul@nowt.org>
4
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
6
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 2 of the License, or (at your option) any later version.
11
12 In addition to the permissions in the GNU General Public License, the
13 Free Software Foundation gives you unlimited permission to link the
14 compiled version of this file into combinations with other programs,
15 and to distribute those combinations without any restriction coming
16 from the use of this file.  (The General Public License restrictions
17 do apply in other respects; for example, they cover modification of
18 the file, and distribution when not linked into a combine
19 executable.)
20
21 Libgfortran is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
24 GNU General Public License for more details.
25
26 You should have received a copy of the GNU General Public
27 License along with libgfortran; see the file COPYING.  If not,
28 write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
29 Boston, MA 02110-1301, USA.  */
30
31 #include "config.h"
32 #include <stdlib.h>
33 #include <string.h>
34 #include <assert.h>
35 #include "libgfortran.h"'
36 include(iparm.m4)dnl
37
38 `#if defined (HAVE_'rtype_name`)'
39
40 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
41    passed to us by the front-end, in which case we'll call it for large
42    matrices.  */
43
44 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
45                           const int *, const rtype_name *, const rtype_name *,
46                           const int *, const rtype_name *, const int *,
47                           const rtype_name *, rtype_name *, const int *,
48                           int, int);
49
50 /* The order of loops is different in the case of plain matrix
51    multiplication C=MATMUL(A,B), and in the frequent special case where
52    the argument A is the temporary result of a TRANSPOSE intrinsic:
53    C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
54    looking at their strides.
55
56    The equivalent Fortran pseudo-code is:
57
58    DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
59    IF (.NOT.IS_TRANSPOSED(A)) THEN
60      C = 0
61      DO J=1,N
62        DO K=1,COUNT
63          DO I=1,M
64            C(I,J) = C(I,J)+A(I,K)*B(K,J)
65    ELSE
66      DO J=1,N
67        DO I=1,M
68          S = 0
69          DO K=1,COUNT
70            S = S+A(I,K)*B(K,J)
71          C(I,J) = S
72    ENDIF
73 */
74
75 /* If try_blas is set to a nonzero value, then the matmul function will
76    see if there is a way to perform the matrix multiplication by a call
77    to the BLAS gemm function.  */
78
79 extern void matmul_`'rtype_code (rtype * const restrict retarray, 
80         rtype * const restrict a, rtype * const restrict b, int try_blas,
81         int blas_limit, blas_call gemm);
82 export_proto(matmul_`'rtype_code);
83
84 void
85 matmul_`'rtype_code (rtype * const restrict retarray, 
86         rtype * const restrict a, rtype * const restrict b, int try_blas,
87         int blas_limit, blas_call gemm)
88 {
89   const rtype_name * restrict abase;
90   const rtype_name * restrict bbase;
91   rtype_name * restrict dest;
92
93   index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
94   index_type x, y, n, count, xcount, ycount;
95
96   assert (GFC_DESCRIPTOR_RANK (a) == 2
97           || GFC_DESCRIPTOR_RANK (b) == 2);
98
99 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
100
101    Either A or B (but not both) can be rank 1:
102
103    o One-dimensional argument A is implicitly treated as a row matrix
104      dimensioned [1,count], so xcount=1.
105
106    o One-dimensional argument B is implicitly treated as a column matrix
107      dimensioned [count, 1], so ycount=1.
108   */
109
110   if (retarray->data == NULL)
111     {
112       if (GFC_DESCRIPTOR_RANK (a) == 1)
113         {
114           retarray->dim[0].lbound = 0;
115           retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
116           retarray->dim[0].stride = 1;
117         }
118       else if (GFC_DESCRIPTOR_RANK (b) == 1)
119         {
120           retarray->dim[0].lbound = 0;
121           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
122           retarray->dim[0].stride = 1;
123         }
124       else
125         {
126           retarray->dim[0].lbound = 0;
127           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
128           retarray->dim[0].stride = 1;
129
130           retarray->dim[1].lbound = 0;
131           retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
132           retarray->dim[1].stride = retarray->dim[0].ubound+1;
133         }
134
135       retarray->data
136         = internal_malloc_size (sizeof (rtype_name) * size0 ((array_t *) retarray));
137       retarray->offset = 0;
138     }
139
140 sinclude(`matmul_asm_'rtype_code`.m4')dnl
141
142   if (GFC_DESCRIPTOR_RANK (retarray) == 1)
143     {
144       /* One-dimensional result may be addressed in the code below
145          either as a row or a column matrix. We want both cases to
146          work. */
147       rxstride = rystride = retarray->dim[0].stride;
148     }
149   else
150     {
151       rxstride = retarray->dim[0].stride;
152       rystride = retarray->dim[1].stride;
153     }
154
155
156   if (GFC_DESCRIPTOR_RANK (a) == 1)
157     {
158       /* Treat it as a a row matrix A[1,count]. */
159       axstride = a->dim[0].stride;
160       aystride = 1;
161
162       xcount = 1;
163       count = a->dim[0].ubound + 1 - a->dim[0].lbound;
164     }
165   else
166     {
167       axstride = a->dim[0].stride;
168       aystride = a->dim[1].stride;
169
170       count = a->dim[1].ubound + 1 - a->dim[1].lbound;
171       xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
172     }
173
174   assert(count == b->dim[0].ubound + 1 - b->dim[0].lbound);
175
176   if (GFC_DESCRIPTOR_RANK (b) == 1)
177     {
178       /* Treat it as a column matrix B[count,1] */
179       bxstride = b->dim[0].stride;
180
181       /* bystride should never be used for 1-dimensional b.
182          in case it is we want it to cause a segfault, rather than
183          an incorrect result. */
184       bystride = 0xDEADBEEF;
185       ycount = 1;
186     }
187   else
188     {
189       bxstride = b->dim[0].stride;
190       bystride = b->dim[1].stride;
191       ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
192     }
193
194   abase = a->data;
195   bbase = b->data;
196   dest = retarray->data;
197
198
199   /* Now that everything is set up, we're performing the multiplication
200      itself.  */
201
202 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
203
204   if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
205       && (bxstride == 1 || bystride == 1)
206       && (((float) xcount) * ((float) ycount) * ((float) count)
207           > POW3(blas_limit)))
208   {
209     const int m = xcount, n = ycount, k = count, ldc = rystride;
210     const rtype_name one = 1, zero = 0;
211     const int lda = (axstride == 1) ? aystride : axstride,
212               ldb = (bxstride == 1) ? bystride : bxstride;
213
214     if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
215       {
216         assert (gemm != NULL);
217         gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
218               &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
219         return;
220       }
221   }
222
223   if (rxstride == 1 && axstride == 1 && bxstride == 1)
224     {
225       const rtype_name * restrict bbase_y;
226       rtype_name * restrict dest_y;
227       const rtype_name * restrict abase_n;
228       rtype_name bbase_yn;
229
230       if (rystride == xcount)
231         memset (dest, 0, (sizeof (rtype_name) * xcount * ycount));
232       else
233         {
234           for (y = 0; y < ycount; y++)
235             for (x = 0; x < xcount; x++)
236               dest[x + y*rystride] = (rtype_name)0;
237         }
238
239       for (y = 0; y < ycount; y++)
240         {
241           bbase_y = bbase + y*bystride;
242           dest_y = dest + y*rystride;
243           for (n = 0; n < count; n++)
244             {
245               abase_n = abase + n*aystride;
246               bbase_yn = bbase_y[n];
247               for (x = 0; x < xcount; x++)
248                 {
249                   dest_y[x] += abase_n[x] * bbase_yn;
250                 }
251             }
252         }
253     }
254   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
255     {
256       if (GFC_DESCRIPTOR_RANK (a) != 1)
257         {
258           const rtype_name *restrict abase_x;
259           const rtype_name *restrict bbase_y;
260           rtype_name *restrict dest_y;
261           rtype_name s;
262
263           for (y = 0; y < ycount; y++)
264             {
265               bbase_y = &bbase[y*bystride];
266               dest_y = &dest[y*rystride];
267               for (x = 0; x < xcount; x++)
268                 {
269                   abase_x = &abase[x*axstride];
270                   s = (rtype_name) 0;
271                   for (n = 0; n < count; n++)
272                     s += abase_x[n] * bbase_y[n];
273                   dest_y[x] = s;
274                 }
275             }
276         }
277       else
278         {
279           const rtype_name *restrict bbase_y;
280           rtype_name s;
281
282           for (y = 0; y < ycount; y++)
283             {
284               bbase_y = &bbase[y*bystride];
285               s = (rtype_name) 0;
286               for (n = 0; n < count; n++)
287                 s += abase[n*axstride] * bbase_y[n];
288               dest[y*rystride] = s;
289             }
290         }
291     }
292   else if (axstride < aystride)
293     {
294       for (y = 0; y < ycount; y++)
295         for (x = 0; x < xcount; x++)
296           dest[x*rxstride + y*rystride] = (rtype_name)0;
297
298       for (y = 0; y < ycount; y++)
299         for (n = 0; n < count; n++)
300           for (x = 0; x < xcount; x++)
301             /* dest[x,y] += a[x,n] * b[n,y] */
302             dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
303     }
304   else if (GFC_DESCRIPTOR_RANK (a) == 1)
305     {
306       const rtype_name *restrict bbase_y;
307       rtype_name s;
308
309       for (y = 0; y < ycount; y++)
310         {
311           bbase_y = &bbase[y*bystride];
312           s = (rtype_name) 0;
313           for (n = 0; n < count; n++)
314             s += abase[n*axstride] * bbase_y[n*bxstride];
315           dest[y*rxstride] = s;
316         }
317     }
318   else
319     {
320       const rtype_name *restrict abase_x;
321       const rtype_name *restrict bbase_y;
322       rtype_name *restrict dest_y;
323       rtype_name s;
324
325       for (y = 0; y < ycount; y++)
326         {
327           bbase_y = &bbase[y*bystride];
328           dest_y = &dest[y*rystride];
329           for (x = 0; x < xcount; x++)
330             {
331               abase_x = &abase[x*axstride];
332               s = (rtype_name) 0;
333               for (n = 0; n < count; n++)
334                 s += abase_x[n*aystride] * bbase_y[n*bxstride];
335               dest_y[x*rxstride] = s;
336             }
337         }
338     }
339 }
340
341 #endif