1 /* Implementation of the MATMUL intrinsic
2 Copyright 2002, 2005, 2006 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 2 of the License, or (at your option) any later version.
12 In addition to the permissions in the GNU General Public License, the
13 Free Software Foundation gives you unlimited permission to link the
14 compiled version of this file into combinations with other programs,
15 and to distribute those combinations without any restriction coming
16 from the use of this file. (The General Public License restrictions
17 do apply in other respects; for example, they cover modification of
18 the file, and distribution when not linked into a combine
21 Libgfortran is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24 GNU General Public License for more details.
26 You should have received a copy of the GNU General Public
27 License along with libgfortran; see the file COPYING. If not,
28 write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
29 Boston, MA 02110-1301, USA. */
35 #include "libgfortran.h"
37 #if defined (HAVE_GFC_COMPLEX_4)
39 /* The order of loops is different in the case of plain matrix
40 multiplication C=MATMUL(A,B), and in the frequent special case where
41 the argument A is the temporary result of a TRANSPOSE intrinsic:
42 C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
43 looking at their strides.
45 The equivalent Fortran pseudo-code is:
47 DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
48 IF (.NOT.IS_TRANSPOSED(A)) THEN
53 C(I,J) = C(I,J)+A(I,K)*B(K,J)
/* matmul_c4: C = MATMUL(A, B) for GFC_COMPLEX_4 array descriptors.

   retarray  Result descriptor.  If its data pointer is NULL, the
             definition below fills in its bounds and strides (lower
             bounds 0, first-dimension stride 1) and allocates storage
             for the result.
   a, b      Operands, each of rank 1 or 2; the definition asserts that
             at least one of them has rank 2.  A rank-1 A is treated as a
             row matrix [1,count] and a rank-1 B as a column matrix
             [count,1].  The extent of A's last dimension must equal the
             extent of B's first dimension (checked with assert).  */
64 extern void matmul_c4 (gfc_array_c4 * const restrict retarray,
65 gfc_array_c4 * const restrict a, gfc_array_c4 * const restrict b);
/* NOTE(review): export_proto is a libgfortran macro defined elsewhere;
   presumably it marks matmul_c4 as an exported library symbol -- confirm
   against libgfortran.h.  */
66 export_proto(matmul_c4);
69 matmul_c4 (gfc_array_c4 * const restrict retarray,
70 gfc_array_c4 * const restrict a, gfc_array_c4 * const restrict b)
72 const GFC_COMPLEX_4 * restrict abase;
73 const GFC_COMPLEX_4 * restrict bbase;
74 GFC_COMPLEX_4 * restrict dest;
76 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
77 index_type x, y, n, count, xcount, ycount;
79 assert (GFC_DESCRIPTOR_RANK (a) == 2
80 || GFC_DESCRIPTOR_RANK (b) == 2);
82 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
84 Either A or B (but not both) can be rank 1:
86 o One-dimensional argument A is implicitly treated as a row matrix
87 dimensioned [1,count], so xcount=1.
89 o One-dimensional argument B is implicitly treated as a column matrix
90 dimensioned [count, 1], so ycount=1.
93 if (retarray->data == NULL)
95 if (GFC_DESCRIPTOR_RANK (a) == 1)
97 retarray->dim[0].lbound = 0;
98 retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
99 retarray->dim[0].stride = 1;
101 else if (GFC_DESCRIPTOR_RANK (b) == 1)
103 retarray->dim[0].lbound = 0;
104 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
105 retarray->dim[0].stride = 1;
109 retarray->dim[0].lbound = 0;
110 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
111 retarray->dim[0].stride = 1;
113 retarray->dim[1].lbound = 0;
114 retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
115 retarray->dim[1].stride = retarray->dim[0].ubound+1;
119 = internal_malloc_size (sizeof (GFC_COMPLEX_4) * size0 ((array_t *) retarray));
120 retarray->offset = 0;
124 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
126 /* One-dimensional result may be addressed in the code below
127 either as a row or a column matrix. We want both cases to
129 rxstride = rystride = retarray->dim[0].stride;
133 rxstride = retarray->dim[0].stride;
134 rystride = retarray->dim[1].stride;
138 if (GFC_DESCRIPTOR_RANK (a) == 1)
140 /* Treat it as a row matrix A[1,count]. */
141 axstride = a->dim[0].stride;
145 count = a->dim[0].ubound + 1 - a->dim[0].lbound;
149 axstride = a->dim[0].stride;
150 aystride = a->dim[1].stride;
152 count = a->dim[1].ubound + 1 - a->dim[1].lbound;
153 xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
156 assert(count == b->dim[0].ubound + 1 - b->dim[0].lbound);
158 if (GFC_DESCRIPTOR_RANK (b) == 1)
160 /* Treat it as a column matrix B[count,1] */
161 bxstride = b->dim[0].stride;
163 /* bystride should never be used for a one-dimensional b.
164 In case it is, we want it to cause a segfault rather than
165 an incorrect result. */
166 bystride = 0xDEADBEEF;
171 bxstride = b->dim[0].stride;
172 bystride = b->dim[1].stride;
173 ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
178 dest = retarray->data;
180 if (rxstride == 1 && axstride == 1 && bxstride == 1)
182 const GFC_COMPLEX_4 * restrict bbase_y;
183 GFC_COMPLEX_4 * restrict dest_y;
184 const GFC_COMPLEX_4 * restrict abase_n;
185 GFC_COMPLEX_4 bbase_yn;
187 if (rystride == xcount)
188 memset (dest, 0, (sizeof (GFC_COMPLEX_4) * xcount * ycount));
191 for (y = 0; y < ycount; y++)
192 for (x = 0; x < xcount; x++)
193 dest[x + y*rystride] = (GFC_COMPLEX_4)0;
196 for (y = 0; y < ycount; y++)
198 bbase_y = bbase + y*bystride;
199 dest_y = dest + y*rystride;
200 for (n = 0; n < count; n++)
202 abase_n = abase + n*aystride;
203 bbase_yn = bbase_y[n];
204 for (x = 0; x < xcount; x++)
206 dest_y[x] += abase_n[x] * bbase_yn;
211 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
213 const GFC_COMPLEX_4 *restrict abase_x;
214 const GFC_COMPLEX_4 *restrict bbase_y;
215 GFC_COMPLEX_4 *restrict dest_y;
218 for (y = 0; y < ycount; y++)
220 bbase_y = &bbase[y*bystride];
221 dest_y = &dest[y*rystride];
222 for (x = 0; x < xcount; x++)
224 abase_x = &abase[x*axstride];
225 s = (GFC_COMPLEX_4) 0;
226 for (n = 0; n < count; n++)
227 s += abase_x[n] * bbase_y[n];
232 else if (axstride < aystride)
234 for (y = 0; y < ycount; y++)
235 for (x = 0; x < xcount; x++)
236 dest[x*rxstride + y*rystride] = (GFC_COMPLEX_4)0;
238 for (y = 0; y < ycount; y++)
239 for (n = 0; n < count; n++)
240 for (x = 0; x < xcount; x++)
241 /* dest[x,y] += a[x,n] * b[n,y] */
242 dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
246 const GFC_COMPLEX_4 *restrict abase_x;
247 const GFC_COMPLEX_4 *restrict bbase_y;
248 GFC_COMPLEX_4 *restrict dest_y;
251 for (y = 0; y < ycount; y++)
253 bbase_y = &bbase[y*bystride];
254 dest_y = &dest[y*rystride];
255 for (x = 0; x < xcount; x++)
257 abase_x = &abase[x*axstride];
258 s = (GFC_COMPLEX_4) 0;
259 for (n = 0; n < count; n++)
260 s += abase_x[n*aystride] * bbase_y[n*bxstride];
261 dest_y[x*rxstride] = s;