OSDN Git Service

339c9c03554ee026b0dbc3f8cdc07039a06a08c9
[pf3gnuchains/gcc-fork.git] / libgfortran / generated / matmul_c4.c
1 /* Implementation of the MATMUL intrinsic
2    Copyright 2002, 2005, 2006 Free Software Foundation, Inc.
3    Contributed by Paul Brook <paul@nowt.org>
4
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
6
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 2 of the License, or (at your option) any later version.
11
12 In addition to the permissions in the GNU General Public License, the
13 Free Software Foundation gives you unlimited permission to link the
14 compiled version of this file into combinations with other programs,
15 and to distribute those combinations without any restriction coming
16 from the use of this file.  (The General Public License restrictions
17 do apply in other respects; for example, they cover modification of
18 the file, and distribution when not linked into a combine
19 executable.)
20
21 Libgfortran is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
24 GNU General Public License for more details.
25
26 You should have received a copy of the GNU General Public
27 License along with libgfortran; see the file COPYING.  If not,
28 write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
29 Boston, MA 02110-1301, USA.  */
30
31 #include "config.h"
32 #include <stdlib.h>
33 #include <string.h>
34 #include <assert.h>
35 #include "libgfortran.h"
36
37 #if defined (HAVE_GFC_COMPLEX_4)
38
39 /* The order of loops is different in the case of plain matrix
40    multiplication C=MATMUL(A,B), and in the frequent special case where
41    the argument A is the temporary result of a TRANSPOSE intrinsic:
42    C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
43    looking at their strides.
44
45    The equivalent Fortran pseudo-code is:
46
47    DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
48    IF (.NOT.IS_TRANSPOSED(A)) THEN
49      C = 0
50      DO J=1,N
51        DO K=1,COUNT
52          DO I=1,M
53            C(I,J) = C(I,J)+A(I,K)*B(K,J)
54    ELSE
55      DO J=1,N
56        DO I=1,M
57          S = 0
58          DO K=1,COUNT
59            S = S+A(I,K)+B(K,J)
60          C(I,J) = S
61    ENDIF
62 */
63
64 extern void matmul_c4 (gfc_array_c4 * const restrict retarray, 
65         gfc_array_c4 * const restrict a, gfc_array_c4 * const restrict b);
66 export_proto(matmul_c4);
67
68 void
69 matmul_c4 (gfc_array_c4 * const restrict retarray, 
70         gfc_array_c4 * const restrict a, gfc_array_c4 * const restrict b)
71 {
72   const GFC_COMPLEX_4 * restrict abase;
73   const GFC_COMPLEX_4 * restrict bbase;
74   GFC_COMPLEX_4 * restrict dest;
75
76   index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
77   index_type x, y, n, count, xcount, ycount;
78
79   assert (GFC_DESCRIPTOR_RANK (a) == 2
80           || GFC_DESCRIPTOR_RANK (b) == 2);
81
82 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
83
84    Either A or B (but not both) can be rank 1:
85
86    o One-dimensional argument A is implicitly treated as a row matrix
87      dimensioned [1,count], so xcount=1.
88
89    o One-dimensional argument B is implicitly treated as a column matrix
90      dimensioned [count, 1], so ycount=1.
91   */
92
93   if (retarray->data == NULL)
94     {
95       if (GFC_DESCRIPTOR_RANK (a) == 1)
96         {
97           retarray->dim[0].lbound = 0;
98           retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
99           retarray->dim[0].stride = 1;
100         }
101       else if (GFC_DESCRIPTOR_RANK (b) == 1)
102         {
103           retarray->dim[0].lbound = 0;
104           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
105           retarray->dim[0].stride = 1;
106         }
107       else
108         {
109           retarray->dim[0].lbound = 0;
110           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
111           retarray->dim[0].stride = 1;
112
113           retarray->dim[1].lbound = 0;
114           retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
115           retarray->dim[1].stride = retarray->dim[0].ubound+1;
116         }
117
118       retarray->data
119         = internal_malloc_size (sizeof (GFC_COMPLEX_4) * size0 ((array_t *) retarray));
120       retarray->offset = 0;
121     }
122
123
124   if (GFC_DESCRIPTOR_RANK (retarray) == 1)
125     {
126       /* One-dimensional result may be addressed in the code below
127          either as a row or a column matrix. We want both cases to
128          work. */
129       rxstride = rystride = retarray->dim[0].stride;
130     }
131   else
132     {
133       rxstride = retarray->dim[0].stride;
134       rystride = retarray->dim[1].stride;
135     }
136
137
138   if (GFC_DESCRIPTOR_RANK (a) == 1)
139     {
140       /* Treat it as a a row matrix A[1,count]. */
141       axstride = a->dim[0].stride;
142       aystride = 1;
143
144       xcount = 1;
145       count = a->dim[0].ubound + 1 - a->dim[0].lbound;
146     }
147   else
148     {
149       axstride = a->dim[0].stride;
150       aystride = a->dim[1].stride;
151
152       count = a->dim[1].ubound + 1 - a->dim[1].lbound;
153       xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
154     }
155
156   assert(count == b->dim[0].ubound + 1 - b->dim[0].lbound);
157
158   if (GFC_DESCRIPTOR_RANK (b) == 1)
159     {
160       /* Treat it as a column matrix B[count,1] */
161       bxstride = b->dim[0].stride;
162
163       /* bystride should never be used for 1-dimensional b.
164          in case it is we want it to cause a segfault, rather than
165          an incorrect result. */
166       bystride = 0xDEADBEEF;
167       ycount = 1;
168     }
169   else
170     {
171       bxstride = b->dim[0].stride;
172       bystride = b->dim[1].stride;
173       ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
174     }
175
176   abase = a->data;
177   bbase = b->data;
178   dest = retarray->data;
179
180   if (rxstride == 1 && axstride == 1 && bxstride == 1)
181     {
182       const GFC_COMPLEX_4 * restrict bbase_y;
183       GFC_COMPLEX_4 * restrict dest_y;
184       const GFC_COMPLEX_4 * restrict abase_n;
185       GFC_COMPLEX_4 bbase_yn;
186
187       if (rystride == xcount)
188         memset (dest, 0, (sizeof (GFC_COMPLEX_4) * xcount * ycount));
189       else
190         {
191           for (y = 0; y < ycount; y++)
192             for (x = 0; x < xcount; x++)
193               dest[x + y*rystride] = (GFC_COMPLEX_4)0;
194         }
195
196       for (y = 0; y < ycount; y++)
197         {
198           bbase_y = bbase + y*bystride;
199           dest_y = dest + y*rystride;
200           for (n = 0; n < count; n++)
201             {
202               abase_n = abase + n*aystride;
203               bbase_yn = bbase_y[n];
204               for (x = 0; x < xcount; x++)
205                 {
206                   dest_y[x] += abase_n[x] * bbase_yn;
207                 }
208             }
209         }
210     }
211   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
212     {
213       const GFC_COMPLEX_4 *restrict abase_x;
214       const GFC_COMPLEX_4 *restrict bbase_y;
215       GFC_COMPLEX_4 *restrict dest_y;
216       GFC_COMPLEX_4 s;
217
218       for (y = 0; y < ycount; y++)
219         {
220           bbase_y = &bbase[y*bystride];
221           dest_y = &dest[y*rystride];
222           for (x = 0; x < xcount; x++)
223             {
224               abase_x = &abase[x*axstride];
225               s = (GFC_COMPLEX_4) 0;
226               for (n = 0; n < count; n++)
227                 s += abase_x[n] * bbase_y[n];
228               dest_y[x] = s;
229             }
230         }
231     }
232   else if (axstride < aystride)
233     {
234       for (y = 0; y < ycount; y++)
235         for (x = 0; x < xcount; x++)
236           dest[x*rxstride + y*rystride] = (GFC_COMPLEX_4)0;
237
238       for (y = 0; y < ycount; y++)
239         for (n = 0; n < count; n++)
240           for (x = 0; x < xcount; x++)
241             /* dest[x,y] += a[x,n] * b[n,y] */
242             dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
243     }
244   else
245     {
246       const GFC_COMPLEX_4 *restrict abase_x;
247       const GFC_COMPLEX_4 *restrict bbase_y;
248       GFC_COMPLEX_4 *restrict dest_y;
249       GFC_COMPLEX_4 s;
250
251       for (y = 0; y < ycount; y++)
252         {
253           bbase_y = &bbase[y*bystride];
254           dest_y = &dest[y*rystride];
255           for (x = 0; x < xcount; x++)
256             {
257               abase_x = &abase[x*axstride];
258               s = (GFC_COMPLEX_4) 0;
259               for (n = 0; n < count; n++)
260                 s += abase_x[n*aystride] * bbase_y[n*bxstride];
261               dest_y[x*rxstride] = s;
262             }
263         }
264     }
265 }
266
267 #endif