OSDN Git Service

3678c639f2a3c14ec8c42df57e0437c8ab352a12
[pf3gnuchains/gcc-fork.git] / libgfortran / m4 / matmul.m4
1 `/* Implementation of the MATMUL intrinsic
2    Copyright 2002, 2005, 2006 Free Software Foundation, Inc.
3    Contributed by Paul Brook <paul@nowt.org>
4
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
6
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 2 of the License, or (at your option) any later version.
11
12 In addition to the permissions in the GNU General Public License, the
13 Free Software Foundation gives you unlimited permission to link the
14 compiled version of this file into combinations with other programs,
15 and to distribute those combinations without any restriction coming
16 from the use of this file.  (The General Public License restrictions
17 do apply in other respects; for example, they cover modification of
18 the file, and distribution when not linked into a combine
19 executable.)
20
21 Libgfortran is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
24 GNU General Public License for more details.
25
26 You should have received a copy of the GNU General Public
27 License along with libgfortran; see the file COPYING.  If not,
28 write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
29 Boston, MA 02110-1301, USA.  */
30
31 #include "config.h"
32 #include <stdlib.h>
33 #include <string.h>
34 #include <assert.h>
35 #include "libgfortran.h"'
36 include(iparm.m4)dnl
37
38 `#if defined (HAVE_'rtype_name`)'
39
40 /* The order of loops is different in the case of plain matrix
41    multiplication C=MATMUL(A,B), and in the frequent special case where
42    the argument A is the temporary result of a TRANSPOSE intrinsic:
43    C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
44    looking at their strides.
45
46    The equivalent Fortran pseudo-code is:
47
48    DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
49    IF (.NOT.IS_TRANSPOSED(A)) THEN
50      C = 0
51      DO J=1,N
52        DO K=1,COUNT
53          DO I=1,M
54            C(I,J) = C(I,J)+A(I,K)*B(K,J)
55    ELSE
56      DO J=1,N
57        DO I=1,M
58          S = 0
59          DO K=1,COUNT
60            S = S+A(I,K)+B(K,J)
61          C(I,J) = S
62    ENDIF
63 */
64
65 extern void matmul_`'rtype_code (rtype * const restrict retarray, 
66         rtype * const restrict a, rtype * const restrict b);
67 export_proto(matmul_`'rtype_code);
68
69 void
70 matmul_`'rtype_code (rtype * const restrict retarray, 
71         rtype * const restrict a, rtype * const restrict b)
72 {
73   const rtype_name * restrict abase;
74   const rtype_name * restrict bbase;
75   rtype_name * restrict dest;
76
77   index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
78   index_type x, y, n, count, xcount, ycount;
79
80   assert (GFC_DESCRIPTOR_RANK (a) == 2
81           || GFC_DESCRIPTOR_RANK (b) == 2);
82
83 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
84
85    Either A or B (but not both) can be rank 1:
86
87    o One-dimensional argument A is implicitly treated as a row matrix
88      dimensioned [1,count], so xcount=1.
89
90    o One-dimensional argument B is implicitly treated as a column matrix
91      dimensioned [count, 1], so ycount=1.
92   */
93
94   if (retarray->data == NULL)
95     {
96       if (GFC_DESCRIPTOR_RANK (a) == 1)
97         {
98           retarray->dim[0].lbound = 0;
99           retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
100           retarray->dim[0].stride = 1;
101         }
102       else if (GFC_DESCRIPTOR_RANK (b) == 1)
103         {
104           retarray->dim[0].lbound = 0;
105           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
106           retarray->dim[0].stride = 1;
107         }
108       else
109         {
110           retarray->dim[0].lbound = 0;
111           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
112           retarray->dim[0].stride = 1;
113
114           retarray->dim[1].lbound = 0;
115           retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
116           retarray->dim[1].stride = retarray->dim[0].ubound+1;
117         }
118
119       retarray->data
120         = internal_malloc_size (sizeof (rtype_name) * size0 ((array_t *) retarray));
121       retarray->offset = 0;
122     }
123
124 sinclude(`matmul_asm_'rtype_code`.m4')dnl
125
126   if (GFC_DESCRIPTOR_RANK (retarray) == 1)
127     {
128       /* One-dimensional result may be addressed in the code below
129          either as a row or a column matrix. We want both cases to
130          work. */
131       rxstride = rystride = retarray->dim[0].stride;
132     }
133   else
134     {
135       rxstride = retarray->dim[0].stride;
136       rystride = retarray->dim[1].stride;
137     }
138
139
140   if (GFC_DESCRIPTOR_RANK (a) == 1)
141     {
142       /* Treat it as a a row matrix A[1,count]. */
143       axstride = a->dim[0].stride;
144       aystride = 1;
145
146       xcount = 1;
147       count = a->dim[0].ubound + 1 - a->dim[0].lbound;
148     }
149   else
150     {
151       axstride = a->dim[0].stride;
152       aystride = a->dim[1].stride;
153
154       count = a->dim[1].ubound + 1 - a->dim[1].lbound;
155       xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
156     }
157
158   assert(count == b->dim[0].ubound + 1 - b->dim[0].lbound);
159
160   if (GFC_DESCRIPTOR_RANK (b) == 1)
161     {
162       /* Treat it as a column matrix B[count,1] */
163       bxstride = b->dim[0].stride;
164
165       /* bystride should never be used for 1-dimensional b.
166          in case it is we want it to cause a segfault, rather than
167          an incorrect result. */
168       bystride = 0xDEADBEEF;
169       ycount = 1;
170     }
171   else
172     {
173       bxstride = b->dim[0].stride;
174       bystride = b->dim[1].stride;
175       ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
176     }
177
178   abase = a->data;
179   bbase = b->data;
180   dest = retarray->data;
181
182   if (rxstride == 1 && axstride == 1 && bxstride == 1)
183     {
184       const rtype_name * restrict bbase_y;
185       rtype_name * restrict dest_y;
186       const rtype_name * restrict abase_n;
187       rtype_name bbase_yn;
188
189       if (rystride == xcount)
190         memset (dest, 0, (sizeof (rtype_name) * xcount * ycount));
191       else
192         {
193           for (y = 0; y < ycount; y++)
194             for (x = 0; x < xcount; x++)
195               dest[x + y*rystride] = (rtype_name)0;
196         }
197
198       for (y = 0; y < ycount; y++)
199         {
200           bbase_y = bbase + y*bystride;
201           dest_y = dest + y*rystride;
202           for (n = 0; n < count; n++)
203             {
204               abase_n = abase + n*aystride;
205               bbase_yn = bbase_y[n];
206               for (x = 0; x < xcount; x++)
207                 {
208                   dest_y[x] += abase_n[x] * bbase_yn;
209                 }
210             }
211         }
212     }
213   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
214     {
215       if (GFC_DESCRIPTOR_RANK (a) != 1)
216         {
217           const rtype_name *restrict abase_x;
218           const rtype_name *restrict bbase_y;
219           rtype_name *restrict dest_y;
220           rtype_name s;
221
222           for (y = 0; y < ycount; y++)
223             {
224               bbase_y = &bbase[y*bystride];
225               dest_y = &dest[y*rystride];
226               for (x = 0; x < xcount; x++)
227                 {
228                   abase_x = &abase[x*axstride];
229                   s = (rtype_name) 0;
230                   for (n = 0; n < count; n++)
231                     s += abase_x[n] * bbase_y[n];
232                   dest_y[x] = s;
233                 }
234             }
235         }
236       else
237         {
238           const rtype_name *restrict bbase_y;
239           rtype_name s;
240
241           for (y = 0; y < ycount; y++)
242             {
243               bbase_y = &bbase[y*bystride];
244               s = (rtype_name) 0;
245               for (n = 0; n < count; n++)
246                 s += abase[n*axstride] * bbase_y[n];
247               dest[y*rystride] = s;
248             }
249         }
250     }
251   else if (axstride < aystride)
252     {
253       for (y = 0; y < ycount; y++)
254         for (x = 0; x < xcount; x++)
255           dest[x*rxstride + y*rystride] = (rtype_name)0;
256
257       for (y = 0; y < ycount; y++)
258         for (n = 0; n < count; n++)
259           for (x = 0; x < xcount; x++)
260             /* dest[x,y] += a[x,n] * b[n,y] */
261             dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
262     }
263   else if (GFC_DESCRIPTOR_RANK (a) == 1)
264     {
265       const rtype_name *restrict bbase_y;
266       rtype_name s;
267
268       for (y = 0; y < ycount; y++)
269         {
270           bbase_y = &bbase[y*bystride];
271           s = (rtype_name) 0;
272           for (n = 0; n < count; n++)
273             s += abase[n*axstride] * bbase_y[n*bxstride];
274           dest[y*rxstride] = s;
275         }
276     }
277   else
278     {
279       const rtype_name *restrict abase_x;
280       const rtype_name *restrict bbase_y;
281       rtype_name *restrict dest_y;
282       rtype_name s;
283
284       for (y = 0; y < ycount; y++)
285         {
286           bbase_y = &bbase[y*bystride];
287           dest_y = &dest[y*rystride];
288           for (x = 0; x < xcount; x++)
289             {
290               abase_x = &abase[x*axstride];
291               s = (rtype_name) 0;
292               for (n = 0; n < count; n++)
293                 s += abase_x[n*aystride] * bbase_y[n*bxstride];
294               dest_y[x*rxstride] = s;
295             }
296         }
297     }
298 }
299
300 #endif