OSDN Git Service

PR libfortran/26985
[pf3gnuchains/gcc-fork.git] / libgfortran / m4 / matmul.m4
1 `/* Implementation of the MATMUL intrinsic
2    Copyright 2002, 2005 Free Software Foundation, Inc.
3    Contributed by Paul Brook <paul@nowt.org>
4
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
6
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 2 of the License, or (at your option) any later version.
11
12 In addition to the permissions in the GNU General Public License, the
13 Free Software Foundation gives you unlimited permission to link the
14 compiled version of this file into combinations with other programs,
15 and to distribute those combinations without any restriction coming
16 from the use of this file.  (The General Public License restrictions
17 do apply in other respects; for example, they cover modification of
18 the file, and distribution when not linked into a combine
19 executable.)
20
21 Libgfortran is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
24 GNU General Public License for more details.
25
26 You should have received a copy of the GNU General Public
27 License along with libgfortran; see the file COPYING.  If not,
28 write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
29 Boston, MA 02110-1301, USA.  */
30
31 #include "config.h"
32 #include <stdlib.h>
33 #include <string.h>
34 #include <assert.h>
35 #include "libgfortran.h"'
36 include(iparm.m4)dnl
37
38 `#if defined (HAVE_'rtype_name`)'
39
40 /* The order of loops is different in the case of plain matrix
41    multiplication C=MATMUL(A,B), and in the frequent special case where
42    the argument A is the temporary result of a TRANSPOSE intrinsic:
43    C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
44    looking at their strides.
45
46    The equivalent Fortran pseudo-code is:
47
48    DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
49    IF (.NOT.IS_TRANSPOSED(A)) THEN
50      C = 0
51      DO J=1,N
52        DO K=1,COUNT
53          DO I=1,M
54            C(I,J) = C(I,J)+A(I,K)*B(K,J)
55    ELSE
56      DO J=1,N
57        DO I=1,M
58          S = 0
59          DO K=1,COUNT
60            S = S+A(I,K)+B(K,J)
61          C(I,J) = S
62    ENDIF
63 */
64
65 extern void matmul_`'rtype_code (rtype * const restrict retarray, 
66         rtype * const restrict a, rtype * const restrict b);
67 export_proto(matmul_`'rtype_code);
68
69 void
70 matmul_`'rtype_code (rtype * const restrict retarray, 
71         rtype * const restrict a, rtype * const restrict b)
72 {
73   const rtype_name * restrict abase;
74   const rtype_name * restrict bbase;
75   rtype_name * restrict dest;
76
77   index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
78   index_type x, y, n, count, xcount, ycount;
79
80   assert (GFC_DESCRIPTOR_RANK (a) == 2
81           || GFC_DESCRIPTOR_RANK (b) == 2);
82
83 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
84
85    Either A or B (but not both) can be rank 1:
86
87    o One-dimensional argument A is implicitly treated as a row matrix
88      dimensioned [1,count], so xcount=1.
89
90    o One-dimensional argument B is implicitly treated as a column matrix
91      dimensioned [count, 1], so ycount=1.
92   */
93
94   if (retarray->data == NULL)
95     {
96       if (GFC_DESCRIPTOR_RANK (a) == 1)
97         {
98           retarray->dim[0].lbound = 0;
99           retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
100           retarray->dim[0].stride = 1;
101         }
102       else if (GFC_DESCRIPTOR_RANK (b) == 1)
103         {
104           retarray->dim[0].lbound = 0;
105           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
106           retarray->dim[0].stride = 1;
107         }
108       else
109         {
110           retarray->dim[0].lbound = 0;
111           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
112           retarray->dim[0].stride = 1;
113
114           retarray->dim[1].lbound = 0;
115           retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
116           retarray->dim[1].stride = retarray->dim[0].ubound+1;
117         }
118
119       retarray->data
120         = internal_malloc_size (sizeof (rtype_name) * size0 ((array_t *) retarray));
121       retarray->offset = 0;
122     }
123
124   if (retarray->dim[0].stride == 0)
125     retarray->dim[0].stride = 1;
126
127   /* This prevents constifying the input arguments.  */
128   if (a->dim[0].stride == 0)
129     a->dim[0].stride = 1;
130   if (b->dim[0].stride == 0)
131     b->dim[0].stride = 1;
132
133 sinclude(`matmul_asm_'rtype_code`.m4')dnl
134
135   if (GFC_DESCRIPTOR_RANK (retarray) == 1)
136     {
137       /* One-dimensional result may be addressed in the code below
138          either as a row or a column matrix. We want both cases to
139          work. */
140       rxstride = rystride = retarray->dim[0].stride;
141     }
142   else
143     {
144       rxstride = retarray->dim[0].stride;
145       rystride = retarray->dim[1].stride;
146     }
147
148
149   if (GFC_DESCRIPTOR_RANK (a) == 1)
150     {
151       /* Treat it as a a row matrix A[1,count]. */
152       axstride = a->dim[0].stride;
153       aystride = 1;
154
155       xcount = 1;
156       count = a->dim[0].ubound + 1 - a->dim[0].lbound;
157     }
158   else
159     {
160       axstride = a->dim[0].stride;
161       aystride = a->dim[1].stride;
162
163       count = a->dim[1].ubound + 1 - a->dim[1].lbound;
164       xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
165     }
166
167   assert(count == b->dim[0].ubound + 1 - b->dim[0].lbound);
168
169   if (GFC_DESCRIPTOR_RANK (b) == 1)
170     {
171       /* Treat it as a column matrix B[count,1] */
172       bxstride = b->dim[0].stride;
173
174       /* bystride should never be used for 1-dimensional b.
175          in case it is we want it to cause a segfault, rather than
176          an incorrect result. */
177       bystride = 0xDEADBEEF;
178       ycount = 1;
179     }
180   else
181     {
182       bxstride = b->dim[0].stride;
183       bystride = b->dim[1].stride;
184       ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
185     }
186
187   abase = a->data;
188   bbase = b->data;
189   dest = retarray->data;
190
191   if (rxstride == 1 && axstride == 1 && bxstride == 1)
192     {
193       const rtype_name * restrict bbase_y;
194       rtype_name * restrict dest_y;
195       const rtype_name * restrict abase_n;
196       rtype_name bbase_yn;
197
198       if (rystride == xcount)
199         memset (dest, 0, (sizeof (rtype_name) * xcount * ycount));
200       else
201         {
202           for (y = 0; y < ycount; y++)
203             for (x = 0; x < xcount; x++)
204               dest[x + y*rystride] = (rtype_name)0;
205         }
206
207       for (y = 0; y < ycount; y++)
208         {
209           bbase_y = bbase + y*bystride;
210           dest_y = dest + y*rystride;
211           for (n = 0; n < count; n++)
212             {
213               abase_n = abase + n*aystride;
214               bbase_yn = bbase_y[n];
215               for (x = 0; x < xcount; x++)
216                 {
217                   dest_y[x] += abase_n[x] * bbase_yn;
218                 }
219             }
220         }
221     }
222   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
223     {
224       const rtype_name *restrict abase_x;
225       const rtype_name *restrict bbase_y;
226       rtype_name *restrict dest_y;
227       rtype_name s;
228
229       for (y = 0; y < ycount; y++)
230         {
231           bbase_y = &bbase[y*bystride];
232           dest_y = &dest[y*rystride];
233           for (x = 0; x < xcount; x++)
234             {
235               abase_x = &abase[x*axstride];
236               s = (rtype_name) 0;
237               for (n = 0; n < count; n++)
238                 s += abase_x[n] * bbase_y[n];
239               dest_y[x] = s;
240             }
241         }
242     }
243   else if (axstride < aystride)
244     {
245       for (y = 0; y < ycount; y++)
246         for (x = 0; x < xcount; x++)
247           dest[x*rxstride + y*rystride] = (rtype_name)0;
248
249       for (y = 0; y < ycount; y++)
250         for (n = 0; n < count; n++)
251           for (x = 0; x < xcount; x++)
252             /* dest[x,y] += a[x,n] * b[n,y] */
253             dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
254     }
255   else
256     {
257       const rtype_name *restrict abase_x;
258       const rtype_name *restrict bbase_y;
259       rtype_name *restrict dest_y;
260       rtype_name s;
261
262       for (y = 0; y < ycount; y++)
263         {
264           bbase_y = &bbase[y*bystride];
265           dest_y = &dest[y*rystride];
266           for (x = 0; x < xcount; x++)
267             {
268               abase_x = &abase[x*axstride];
269               s = (rtype_name) 0;
270               for (n = 0; n < count; n++)
271                 s += abase_x[n*aystride] * bbase_y[n*bxstride];
272               dest_y[x*rxstride] = s;
273             }
274         }
275     }
276 }
277
278 #endif