OSDN Git Service

2008-02-10 Benjamin Kosnik <bkoz@redhat.com>
[pf3gnuchains/gcc-fork.git] / libgfortran / m4 / matmul.m4
1 `/* Implementation of the MATMUL intrinsic
2    Copyright 2002, 2005, 2006, 2007 Free Software Foundation, Inc.
3    Contributed by Paul Brook <paul@nowt.org>
4
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
6
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 2 of the License, or (at your option) any later version.
11
12 In addition to the permissions in the GNU General Public License, the
13 Free Software Foundation gives you unlimited permission to link the
14 compiled version of this file into combinations with other programs,
15 and to distribute those combinations without any restriction coming
16 from the use of this file.  (The General Public License restrictions
17 do apply in other respects; for example, they cover modification of
18 the file, and distribution when not linked into a combine
19 executable.)
20
21 Libgfortran is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
24 GNU General Public License for more details.
25
26 You should have received a copy of the GNU General Public
27 License along with libgfortran; see the file COPYING.  If not,
28 write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
29 Boston, MA 02110-1301, USA.  */
30
31 #include "libgfortran.h"
32 #include <stdlib.h>
33 #include <string.h>
34 #include <assert.h>'
35
36 include(iparm.m4)dnl
37
38 `#if defined (HAVE_'rtype_name`)
39
40 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
41    passed to us by the front-end, in which case we''`ll call it for large
42    matrices.  */
43
44 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
45                           const int *, const 'rtype_name` *, const 'rtype_name` *,
46                           const int *, const 'rtype_name` *, const int *,
47                           const 'rtype_name` *, 'rtype_name` *, const int *,
48                           int, int);
49
50 /* The order of loops is different in the case of plain matrix
51    multiplication C=MATMUL(A,B), and in the frequent special case where
52    the argument A is the temporary result of a TRANSPOSE intrinsic:
53    C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
54    looking at their strides.
55
56    The equivalent Fortran pseudo-code is:
57
58    DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
59    IF (.NOT.IS_TRANSPOSED(A)) THEN
60      C = 0
61      DO J=1,N
62        DO K=1,COUNT
63          DO I=1,M
64            C(I,J) = C(I,J)+A(I,K)*B(K,J)
65    ELSE
66      DO J=1,N
67        DO I=1,M
68          S = 0
69          DO K=1,COUNT
70            S = S+A(I,K)*B(K,J)
71          C(I,J) = S
72    ENDIF
73 */
74
75 /* If try_blas is set to a nonzero value, then the matmul function will
76    see if there is a way to perform the matrix multiplication by a call
77    to the BLAS gemm function.  */
78
79 extern void matmul_'rtype_code` ('rtype` * const restrict retarray, 
80         'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
81         int blas_limit, blas_call gemm);
82 export_proto(matmul_'rtype_code`);
83
84 void
85 matmul_'rtype_code` ('rtype` * const restrict retarray, 
86         'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
87         int blas_limit, blas_call gemm)
88 {
89   const 'rtype_name` * restrict abase;
90   const 'rtype_name` * restrict bbase;
91   'rtype_name` * restrict dest;
92
93   index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
94   index_type x, y, n, count, xcount, ycount;
95
96   assert (GFC_DESCRIPTOR_RANK (a) == 2
97           || GFC_DESCRIPTOR_RANK (b) == 2);
98
99 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
100
101    Either A or B (but not both) can be rank 1:
102
103    o One-dimensional argument A is implicitly treated as a row matrix
104      dimensioned [1,count], so xcount=1.
105
106    o One-dimensional argument B is implicitly treated as a column matrix
107      dimensioned [count, 1], so ycount=1.
108   */
109
110   if (retarray->data == NULL)
111     {
112       if (GFC_DESCRIPTOR_RANK (a) == 1)
113         {
114           retarray->dim[0].lbound = 0;
115           retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
116           retarray->dim[0].stride = 1;
117         }
118       else if (GFC_DESCRIPTOR_RANK (b) == 1)
119         {
120           retarray->dim[0].lbound = 0;
121           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
122           retarray->dim[0].stride = 1;
123         }
124       else
125         {
126           retarray->dim[0].lbound = 0;
127           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
128           retarray->dim[0].stride = 1;
129
130           retarray->dim[1].lbound = 0;
131           retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
132           retarray->dim[1].stride = retarray->dim[0].ubound+1;
133         }
134
135       retarray->data
136         = internal_malloc_size (sizeof ('rtype_name`) * size0 ((array_t *) retarray));
137       retarray->offset = 0;
138     }
139 '
140 sinclude(`matmul_asm_'rtype_code`.m4')dnl
141 `
142   if (GFC_DESCRIPTOR_RANK (retarray) == 1)
143     {
144       /* One-dimensional result may be addressed in the code below
145          either as a row or a column matrix. We want both cases to
146          work. */
147       rxstride = rystride = retarray->dim[0].stride;
148     }
149   else
150     {
151       rxstride = retarray->dim[0].stride;
152       rystride = retarray->dim[1].stride;
153     }
154
155
156   if (GFC_DESCRIPTOR_RANK (a) == 1)
157     {
158       /* Treat it as a a row matrix A[1,count]. */
159       axstride = a->dim[0].stride;
160       aystride = 1;
161
162       xcount = 1;
163       count = a->dim[0].ubound + 1 - a->dim[0].lbound;
164     }
165   else
166     {
167       axstride = a->dim[0].stride;
168       aystride = a->dim[1].stride;
169
170       count = a->dim[1].ubound + 1 - a->dim[1].lbound;
171       xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
172     }
173
174   if (count != b->dim[0].ubound + 1 - b->dim[0].lbound)
175     runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
176
177   if (GFC_DESCRIPTOR_RANK (b) == 1)
178     {
179       /* Treat it as a column matrix B[count,1] */
180       bxstride = b->dim[0].stride;
181
182       /* bystride should never be used for 1-dimensional b.
183          in case it is we want it to cause a segfault, rather than
184          an incorrect result. */
185       bystride = 0xDEADBEEF;
186       ycount = 1;
187     }
188   else
189     {
190       bxstride = b->dim[0].stride;
191       bystride = b->dim[1].stride;
192       ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
193     }
194
195   abase = a->data;
196   bbase = b->data;
197   dest = retarray->data;
198
199
200   /* Now that everything is set up, we''`re performing the multiplication
201      itself.  */
202
203 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
204
205   if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
206       && (bxstride == 1 || bystride == 1)
207       && (((float) xcount) * ((float) ycount) * ((float) count)
208           > POW3(blas_limit)))
209   {
210     const int m = xcount, n = ycount, k = count, ldc = rystride;
211     const 'rtype_name` one = 1, zero = 0;
212     const int lda = (axstride == 1) ? aystride : axstride,
213               ldb = (bxstride == 1) ? bystride : bxstride;
214
215     if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
216       {
217         assert (gemm != NULL);
218         gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
219               &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
220         return;
221       }
222   }
223
224   if (rxstride == 1 && axstride == 1 && bxstride == 1)
225     {
226       const 'rtype_name` * restrict bbase_y;
227       'rtype_name` * restrict dest_y;
228       const 'rtype_name` * restrict abase_n;
229       'rtype_name` bbase_yn;
230
231       if (rystride == xcount)
232         memset (dest, 0, (sizeof ('rtype_name`) * xcount * ycount));
233       else
234         {
235           for (y = 0; y < ycount; y++)
236             for (x = 0; x < xcount; x++)
237               dest[x + y*rystride] = ('rtype_name`)0;
238         }
239
240       for (y = 0; y < ycount; y++)
241         {
242           bbase_y = bbase + y*bystride;
243           dest_y = dest + y*rystride;
244           for (n = 0; n < count; n++)
245             {
246               abase_n = abase + n*aystride;
247               bbase_yn = bbase_y[n];
248               for (x = 0; x < xcount; x++)
249                 {
250                   dest_y[x] += abase_n[x] * bbase_yn;
251                 }
252             }
253         }
254     }
255   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
256     {
257       if (GFC_DESCRIPTOR_RANK (a) != 1)
258         {
259           const 'rtype_name` *restrict abase_x;
260           const 'rtype_name` *restrict bbase_y;
261           'rtype_name` *restrict dest_y;
262           'rtype_name` s;
263
264           for (y = 0; y < ycount; y++)
265             {
266               bbase_y = &bbase[y*bystride];
267               dest_y = &dest[y*rystride];
268               for (x = 0; x < xcount; x++)
269                 {
270                   abase_x = &abase[x*axstride];
271                   s = ('rtype_name`) 0;
272                   for (n = 0; n < count; n++)
273                     s += abase_x[n] * bbase_y[n];
274                   dest_y[x] = s;
275                 }
276             }
277         }
278       else
279         {
280           const 'rtype_name` *restrict bbase_y;
281           'rtype_name` s;
282
283           for (y = 0; y < ycount; y++)
284             {
285               bbase_y = &bbase[y*bystride];
286               s = ('rtype_name`) 0;
287               for (n = 0; n < count; n++)
288                 s += abase[n*axstride] * bbase_y[n];
289               dest[y*rystride] = s;
290             }
291         }
292     }
293   else if (axstride < aystride)
294     {
295       for (y = 0; y < ycount; y++)
296         for (x = 0; x < xcount; x++)
297           dest[x*rxstride + y*rystride] = ('rtype_name`)0;
298
299       for (y = 0; y < ycount; y++)
300         for (n = 0; n < count; n++)
301           for (x = 0; x < xcount; x++)
302             /* dest[x,y] += a[x,n] * b[n,y] */
303             dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
304     }
305   else if (GFC_DESCRIPTOR_RANK (a) == 1)
306     {
307       const 'rtype_name` *restrict bbase_y;
308       'rtype_name` s;
309
310       for (y = 0; y < ycount; y++)
311         {
312           bbase_y = &bbase[y*bystride];
313           s = ('rtype_name`) 0;
314           for (n = 0; n < count; n++)
315             s += abase[n*axstride] * bbase_y[n*bxstride];
316           dest[y*rxstride] = s;
317         }
318     }
319   else
320     {
321       const 'rtype_name` *restrict abase_x;
322       const 'rtype_name` *restrict bbase_y;
323       'rtype_name` *restrict dest_y;
324       'rtype_name` s;
325
326       for (y = 0; y < ycount; y++)
327         {
328           bbase_y = &bbase[y*bystride];
329           dest_y = &dest[y*rystride];
330           for (x = 0; x < xcount; x++)
331             {
332               abase_x = &abase[x*axstride];
333               s = ('rtype_name`) 0;
334               for (n = 0; n < count; n++)
335                 s += abase_x[n*aystride] * bbase_y[n*bxstride];
336               dest_y[x*rxstride] = s;
337             }
338         }
339     }
340 }
341
342 #endif'