OSDN Git Service

gcc/fortran/
[pf3gnuchains/gcc-fork.git] / libgfortran / generated / matmul_r4.c
1 /* Implementation of the MATMUL intrinsic
2    Copyright 2002, 2005 Free Software Foundation, Inc.
3    Contributed by Paul Brook <paul@nowt.org>
4
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
6
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 2 of the License, or (at your option) any later version.
11
12 In addition to the permissions in the GNU General Public License, the
13 Free Software Foundation gives you unlimited permission to link the
14 compiled version of this file into combinations with other programs,
15 and to distribute those combinations without any restriction coming
16 from the use of this file.  (The General Public License restrictions
17 do apply in other respects; for example, they cover modification of
18 the file, and distribution when not linked into a combine
19 executable.)
20
21 Libgfortran is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
24 GNU General Public License for more details.
25
26 You should have received a copy of the GNU General Public
27 License along with libgfortran; see the file COPYING.  If not,
28 write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
29 Boston, MA 02110-1301, USA.  */
30
31 #include "config.h"
32 #include <stdlib.h>
33 #include <string.h>
34 #include <assert.h>
35 #include "libgfortran.h"
36
37 #if defined (HAVE_GFC_REAL_4)
38
39 /* The order of loops is different in the case of plain matrix
40    multiplication C=MATMUL(A,B), and in the frequent special case where
41    the argument A is the temporary result of a TRANSPOSE intrinsic:
42    C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
43    looking at their strides.
44
45    The equivalent Fortran pseudo-code is:
46
47    DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
48    IF (.NOT.IS_TRANSPOSED(A)) THEN
49      C = 0
50      DO J=1,N
51        DO K=1,COUNT
52          DO I=1,M
53            C(I,J) = C(I,J)+A(I,K)*B(K,J)
54    ELSE
55      DO J=1,N
56        DO I=1,M
57          S = 0
58          DO K=1,COUNT
59            S = S+A(I,K)+B(K,J)
60          C(I,J) = S
61    ENDIF
62 */
63
64 extern void matmul_r4 (gfc_array_r4 * const restrict retarray, 
65         gfc_array_r4 * const restrict a, gfc_array_r4 * const restrict b);
66 export_proto(matmul_r4);
67
68 void
69 matmul_r4 (gfc_array_r4 * const restrict retarray, 
70         gfc_array_r4 * const restrict a, gfc_array_r4 * const restrict b)
71 {
72   const GFC_REAL_4 * restrict abase;
73   const GFC_REAL_4 * restrict bbase;
74   GFC_REAL_4 * restrict dest;
75
76   index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
77   index_type x, y, n, count, xcount, ycount;
78
79   assert (GFC_DESCRIPTOR_RANK (a) == 2
80           || GFC_DESCRIPTOR_RANK (b) == 2);
81
82 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
83
84    Either A or B (but not both) can be rank 1:
85
86    o One-dimensional argument A is implicitly treated as a row matrix
87      dimensioned [1,count], so xcount=1.
88
89    o One-dimensional argument B is implicitly treated as a column matrix
90      dimensioned [count, 1], so ycount=1.
91   */
92
93   if (retarray->data == NULL)
94     {
95       if (GFC_DESCRIPTOR_RANK (a) == 1)
96         {
97           retarray->dim[0].lbound = 0;
98           retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
99           retarray->dim[0].stride = 1;
100         }
101       else if (GFC_DESCRIPTOR_RANK (b) == 1)
102         {
103           retarray->dim[0].lbound = 0;
104           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
105           retarray->dim[0].stride = 1;
106         }
107       else
108         {
109           retarray->dim[0].lbound = 0;
110           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
111           retarray->dim[0].stride = 1;
112
113           retarray->dim[1].lbound = 0;
114           retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
115           retarray->dim[1].stride = retarray->dim[0].ubound+1;
116         }
117
118       retarray->data
119         = internal_malloc_size (sizeof (GFC_REAL_4) * size0 ((array_t *) retarray));
120       retarray->offset = 0;
121     }
122
123   if (retarray->dim[0].stride == 0)
124     retarray->dim[0].stride = 1;
125
126   /* This prevents constifying the input arguments.  */
127   if (a->dim[0].stride == 0)
128     a->dim[0].stride = 1;
129   if (b->dim[0].stride == 0)
130     b->dim[0].stride = 1;
131
132
133   if (GFC_DESCRIPTOR_RANK (retarray) == 1)
134     {
135       /* One-dimensional result may be addressed in the code below
136          either as a row or a column matrix. We want both cases to
137          work. */
138       rxstride = rystride = retarray->dim[0].stride;
139     }
140   else
141     {
142       rxstride = retarray->dim[0].stride;
143       rystride = retarray->dim[1].stride;
144     }
145
146
147   if (GFC_DESCRIPTOR_RANK (a) == 1)
148     {
149       /* Treat it as a a row matrix A[1,count]. */
150       axstride = a->dim[0].stride;
151       aystride = 1;
152
153       xcount = 1;
154       count = a->dim[0].ubound + 1 - a->dim[0].lbound;
155     }
156   else
157     {
158       axstride = a->dim[0].stride;
159       aystride = a->dim[1].stride;
160
161       count = a->dim[1].ubound + 1 - a->dim[1].lbound;
162       xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
163     }
164
165   assert(count == b->dim[0].ubound + 1 - b->dim[0].lbound);
166
167   if (GFC_DESCRIPTOR_RANK (b) == 1)
168     {
169       /* Treat it as a column matrix B[count,1] */
170       bxstride = b->dim[0].stride;
171
172       /* bystride should never be used for 1-dimensional b.
173          in case it is we want it to cause a segfault, rather than
174          an incorrect result. */
175       bystride = 0xDEADBEEF;
176       ycount = 1;
177     }
178   else
179     {
180       bxstride = b->dim[0].stride;
181       bystride = b->dim[1].stride;
182       ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
183     }
184
185   abase = a->data;
186   bbase = b->data;
187   dest = retarray->data;
188
189   if (rxstride == 1 && axstride == 1 && bxstride == 1)
190     {
191       const GFC_REAL_4 * restrict bbase_y;
192       GFC_REAL_4 * restrict dest_y;
193       const GFC_REAL_4 * restrict abase_n;
194       GFC_REAL_4 bbase_yn;
195
196       if (rystride == ycount)
197         memset (dest, 0, (sizeof (GFC_REAL_4) * size0((array_t *) retarray)));
198       else
199         {
200           for (y = 0; y < ycount; y++)
201             for (x = 0; x < xcount; x++)
202               dest[x + y*rystride] = (GFC_REAL_4)0;
203         }
204
205       for (y = 0; y < ycount; y++)
206         {
207           bbase_y = bbase + y*bystride;
208           dest_y = dest + y*rystride;
209           for (n = 0; n < count; n++)
210             {
211               abase_n = abase + n*aystride;
212               bbase_yn = bbase_y[n];
213               for (x = 0; x < xcount; x++)
214                 {
215                   dest_y[x] += abase_n[x] * bbase_yn;
216                 }
217             }
218         }
219     }
220   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
221     {
222       const GFC_REAL_4 *restrict abase_x;
223       const GFC_REAL_4 *restrict bbase_y;
224       GFC_REAL_4 *restrict dest_y;
225       GFC_REAL_4 s;
226
227       for (y = 0; y < ycount; y++)
228         {
229           bbase_y = &bbase[y*bystride];
230           dest_y = &dest[y*rystride];
231           for (x = 0; x < xcount; x++)
232             {
233               abase_x = &abase[x*axstride];
234               s = (GFC_REAL_4) 0;
235               for (n = 0; n < count; n++)
236                 s += abase_x[n] * bbase_y[n];
237               dest_y[x] = s;
238             }
239         }
240     }
241   else if (axstride < aystride)
242     {
243       for (y = 0; y < ycount; y++)
244         for (x = 0; x < xcount; x++)
245           dest[x*rxstride + y*rystride] = (GFC_REAL_4)0;
246
247       for (y = 0; y < ycount; y++)
248         for (n = 0; n < count; n++)
249           for (x = 0; x < xcount; x++)
250             /* dest[x,y] += a[x,n] * b[n,y] */
251             dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
252     }
253   else
254     {
255       const GFC_REAL_4 *restrict abase_x;
256       const GFC_REAL_4 *restrict bbase_y;
257       GFC_REAL_4 *restrict dest_y;
258       GFC_REAL_4 s;
259
260       for (y = 0; y < ycount; y++)
261         {
262           bbase_y = &bbase[y*bystride];
263           dest_y = &dest[y*rystride];
264           for (x = 0; x < xcount; x++)
265             {
266               abase_x = &abase[x*axstride];
267               s = (GFC_REAL_4) 0;
268               for (n = 0; n < count; n++)
269                 s += abase_x[n*aystride] * bbase_y[n*bxstride];
270               dest_y[x*rxstride] = s;
271             }
272         }
273     }
274 }
275
276 #endif