/ani/mrses

To get this branch, use:
bzr branch http://darksoft.org/webbzr/ani/mrses
Revision 1 by Suren A. Chilingaryan: Initial import
1
#include <stdio.h>
2
#include <math.h>
3
4
#include <cblas.h>
5
#include "mrses.h"
6
7
/*
 * atlas_spotrf2 - unblocked Cholesky factorisation A = U^T * U.
 *
 * A is an N x N matrix stored column-major with leading dimension lda
 * (column c starts at A + c*lda).  Only the upper triangle is used; it is
 * overwritten with U, one column at a time, using a dot product for the
 * pivot and a gemv + scal to finish the corresponding row of U.
 *
 * Returns 0 on success.  If the reduced pivot of column j (0-based) is not
 * positive (this includes NaN), that reduced value is stored in A(j,j) and
 * j+1 is returned; the remaining columns are left partially updated.
 *
 * NOTE(review): the single-precision BLAS calls (cblas_sdot/sgemv/sscal)
 * imply MRSESDataType is float -- confirm against mrses.h.
 */
int atlas_spotrf2(const int N, MRSESDataType *A, const int lda)
{
   MRSESDataType *col = A;          /* column j of A          */
   MRSESDataType *next = A + lda;   /* column j+1 of A        */
   MRSESDataType pivot;
   int j, tail;

   for (j = 0; j != N; ++j)
   {
      /* A(j,j) minus the squared norm of the finished entries U(0:j-1, j). */
      pivot = col[j] - cblas_sdot(j, col, 1, col, 1);

      if (!(pivot > 0))
      {
         col[j] = pivot;
         return j + 1;
      }

      pivot = sqrt(pivot);
      col[j] = pivot;

      tail = N - j - 1;             /* columns to the right of j */
      if (tail == 0)
         continue;

      /* Row j of the trailing columns:
       *   A(j, j+1:N) -= U(0:j-1, j)^T * A(0:j-1, j+1:N)          */
      cblas_sgemv(CblasColMajor, CblasTrans, j, tail, -1,
                  next, lda, col, 1, 1, next + j, lda);
      /* ...then scale that row by 1/U(j,j). */
      cblas_sscal(tail, 1 / pivot, next + j, lda);

      col = next;
      next += lda;
   }
   return 0;
}
35
36
37
/*
 * Element type and scalar constants used by the ATLAS-derived kernels below.
 * The constants are plain integer literals; where they meet TYPE values the
 * usual arithmetic conversions give them the floating type.  Replacement
 * lists that are not a single token are parenthesised so that the expansion
 * cannot bind to neighbouring operators (CERT PRE02-C).
 */
#define TYPE MRSESDataType
#define ATL_rzero 0
#define ATL_rone 1
#define ATL_rnone (-1)
#define ONE ATL_rone
42
43
/*
 * ATL_potrfU_4 - fully unrolled Cholesky factorisation A = U^T * U of a 4x4
 * block.
 *
 * A is column-major with leading dimension lda (column c starts at
 * A + c*lda).  Only the upper triangle is read, and it is overwritten with U.
 * Local names follow ATLAS: Lij is element (i,j) of L = U^T, i.e. the stored
 * upper entry at 0-based row j-1, column i-1 (L21 is *(A + lda)).
 *
 * Returns 0 on success, or k (1..4) when the k-th reduced pivot is not
 * positive (NaN included).  On failure the first k-1 rows of U have been
 * written and the rest of the block keeps its input values.
 *
 * Declared static: a plain `inline` definition (no static/extern) is, in
 * C99/C11, only an inline definition that emits no external symbol, so any
 * call the compiler does not inline (e.g. at -O0) fails to link.
 */
static inline int ATL_potrfU_4(TYPE *A, const short int lda)
{
   TYPE *pA1 = A + lda, *pA2 = pA1 + lda, *pA3 = pA2 + lda;
   TYPE L11 = *A, L21 = *pA1, L31 = *pA2, L41 = *pA3;
   TYPE L22 = pA1[1], L32 = pA2[1], L42 = pA3[1];
   TYPE L33 = pA2[2], L43 = pA3[2];
   TYPE L44 = pA3[3];

   /* Row 0 of U. */
   if (!(L11 > ATL_rzero))
      return 1;
   *A = L11 = sqrt(L11);
   L11 = ATL_rone / L11;            /* multiply by the reciprocal pivot */
   L21 *= L11;
   L31 *= L11;
   L41 *= L11;
   *pA1 = L21; *pA2 = L31; *pA3 = L41;

   /* Row 1 of U. */
   L22 -= L21*L21;
   if (!(L22 > ATL_rzero))
      return 2;
   pA1[1] = L22 = sqrt(L22);
   L22 = ATL_rone / L22;
   L32 = (L32 - L31*L21) * L22;
   L42 = (L42 - L41*L21) * L22;
   pA2[1] = L32; pA3[1] = L42;

   /* Row 2 of U. */
   L33 -= L31*L31 + L32*L32;
   if (!(L33 > ATL_rzero))
      return 3;
   pA2[2] = L33 = sqrt(L33);
   L43 = (L43 - L41*L31 - L42*L32) / L33;
   pA3[2] = L43;

   /* Row 3 of U. */
   L44 -= L41*L41 + L42*L42 + L43*L43;
   if (!(L44 > ATL_rzero))
      return 4;
   pA3[3] = sqrt(L44);
   return 0;
}
89
90
/*
 * ATL_potrfU_3 - fully unrolled Cholesky factorisation A = U^T * U of a 3x3
 * block.
 *
 * A is column-major with leading dimension lda (column c starts at
 * A + c*lda).  Only the upper triangle is read, and it is overwritten with U.
 * Lij is element (i,j) of L = U^T, i.e. the stored upper entry at 0-based
 * row j-1, column i-1.
 *
 * Returns 0 on success, or k (1..3) when the k-th reduced pivot is not
 * positive (NaN included).  On failure the first k-1 rows of U have been
 * written and the rest of the block keeps its input values.
 *
 * Declared static: a plain `inline` definition (no static/extern) is, in
 * C99/C11, only an inline definition that emits no external symbol, so any
 * call the compiler does not inline (e.g. at -O0) fails to link.
 */
static inline int ATL_potrfU_3(TYPE *A, const short int lda)
{
   TYPE *pA1 = A + lda, *pA2 = pA1 + lda;
   TYPE L11 = *A, L21 = *pA1, L31 = *pA2;
   TYPE L22 = pA1[1], L32 = pA2[1];
   TYPE L33 = pA2[2];

   /* Row 0 of U. */
   if (!(L11 > ATL_rzero))
      return 1;
   *A = L11 = sqrt(L11);
   L11 = ATL_rone / L11;            /* multiply by the reciprocal pivot */
   L21 *= L11;
   L31 *= L11;
   *pA1 = L21; *pA2 = L31;

   /* Row 1 of U. */
   L22 -= L21*L21;
   if (!(L22 > ATL_rzero))
      return 2;
   L22 = sqrt(L22);
   L32 = (L32 - L31*L21) / L22;
   pA1[1] = L22; pA2[1] = L32;

   /* Row 2 of U. */
   L33 -= L31*L31 + L32*L32;
   if (!(L33 > ATL_rzero))
      return 3;
   pA2[2] = sqrt(L33);
   return 0;
}
124
125
/*
 * ATL_potrfU_2 - Cholesky factorisation A = U^T * U of a 2x2 block.
 *
 * A is column-major with leading dimension lda; the upper triangle
 * (A[0], A[lda], A[lda+1]) is read and overwritten with U.
 *
 * Returns 0 on success, 1 if A(0,0) is not positive (A untouched), or 2 if
 * the reduced A(1,1) is not positive (row 0 of U already written).
 *
 * Declared static: a plain `inline` definition (no static/extern) is, in
 * C99/C11, only an inline definition that emits no external symbol, so any
 * call the compiler does not inline (e.g. at -O0) fails to link.
 */
static inline int ATL_potrfU_2(TYPE *A, const short int lda)
{
   TYPE *pA1 = A + lda;
   TYPE L11 = *A, L21 = *pA1, L22 = pA1[1];

   /* Row 0 of U. */
   if (!(L11 > ATL_rzero))
      return 1;
   *A = L11 = sqrt(L11);
   *pA1 = L21 = L21 / L11;

   /* Row 1 of U. */
   L22 -= L21*L21;
   if (!(L22 > ATL_rzero))
      return 2;
   pA1[1] = sqrt(L22);
   return 0;
}
144
145
/*
 * Type aliases: BASE is the element type (float) and INDEX the index type
 * (short int).  NOTE(review): the helpers that follow spell out float and
 * short int directly rather than using these names.
 */
#define BASE   float
#define INDEX  short int
147
/*
 * gsl_cblas_strsm_clut_1 - solve U^T * X = B in place.
 *
 * U is the M x M upper triangle of A (column-major, leading dimension lda,
 * non-unit diagonal); entries below the diagonal are never read.  B holds N
 * right-hand-side columns of height M with the same leading dimension and
 * is overwritten with X.  The result matches
 *   cblas_strsm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans,
 *               CblasNonUnit, M, N, 1, A, lda, B, lda)
 * up to rounding order, computed by column-oriented forward substitution.
 * No check is made for a zero diagonal.
 *
 * Declared static: a plain `inline` definition (no static/extern) is, in
 * C99/C11, only an inline definition that emits no external symbol, so any
 * call the compiler does not inline (e.g. at -O0) fails to link.
 */
static inline void
gsl_cblas_strsm_clut_1 (const short int M, const short int N, const float *A, float *B, const short int lda)
{
    short int i, j, k;

    for (i = 0; i < N; i++) {
        float *Bi = B + lda * i;                    /* column i of B */

        for (j = 0; j < M; j++) {
            float Bij = Bi[j] / A[lda * j + j];     /* divide by U(j,j) */

            Bi[j] = Bij;
            /* Eliminate x_j from later equations; U^T(k,j) = U(j,k). */
            for (k = j + 1; k < M; k++)
                Bi[k] -= A[k * lda + j] * Bij;
        }
    }
}
165
166
/*
 * gsl_cblas_ssyrk_cut_m11 - upper-triangle downdate C := C - A^T * A.
 *
 * A is M x N and C is N x N, both column-major with the same leading
 * dimension lda.  Only entries C(j,i) with j <= i are modified.  The result
 * matches
 *   cblas_ssyrk(CblasColMajor, CblasUpper, CblasTrans, N, M,
 *               -1, A, lda, 1, C, lda)
 * up to rounding order.
 *
 * Declared static: a plain `inline` definition (no static/extern) is, in
 * C99/C11, only an inline definition that emits no external symbol, so any
 * call the compiler does not inline (e.g. at -O0) fails to link.
 */
static inline void
gsl_cblas_ssyrk_cut_m11 (const short int N, const short int M, const float *A, float *C, const short int lda) {
    short int i, j, k;

    for (i = 0; i < N; i++) {
        const float *Ai = A + i * lda;              /* column i of A */

        for (j = 0; j <= i; j++) {
            const float *Aj = A + j * lda;          /* column j of A */
            float temp = 0.0;

            for (k = 0; k < M; k++)
                temp += Ai[k] * Aj[k];
            C[i * lda + j] -= temp;
        }
    }
}
180
181
/*
 * gsl_spotrf_step - one blocked step of the upper Cholesky factorisation.
 *
 * A is column-major with leading dimension lda and is viewed as
 *     [ U11  A12 ]     U11: M x M, already factored (upper triangle),
 *     [  .   A22 ]     A12: M x N,  A22: N x N.
 * The step overwrites A12 with U12 = U11^-T * A12 (forward substitution) and
 * replaces the upper triangle of A22 with A22 - U12^T * U12.  The result is
 * that of cblas_strsm(Left, Upper, Trans, NonUnit) followed by
 * cblas_ssyrk(Upper, Trans, alpha = -1, beta = 1), up to rounding order.
 * The two are fused: column i of A22 is downdated as soon as column i of
 * U12 is known.  No check is made for a zero diagonal in U11.
 *
 * Declared static: a plain `inline` definition (no static/extern) is, in
 * C99/C11, only an inline definition that emits no external symbol, so any
 * call the compiler does not inline (e.g. at -O0) fails to link.
 */
static inline void gsl_spotrf_step(const short int M, const short int N, float *A, const short int lda) {
    float *B = A + M*lda;       /* A12: rows 0..M-1 of columns M..M+N-1 */
    float *C = B + M;           /* A22: rows and columns M..M+N-1       */
    short int i, j, k;

    for (i = 0; i < N; i++) {
        float *Bi = B + i * lda;
        float *Ci = C + i * lda;

        /* Solve U11^T * x = (column i of A12); overwrite it with x. */
        for (j = 0; j < M; j++) {
            const float *Uj = A + j * lda;          /* column j of U11 */
            float sum = 0;

            for (k = 0; k < j; k++)
                sum += Uj[k] * Bi[k];
            Bi[j] = (Bi[j] - sum) / Uj[j];
        }

        /* Column i of the A22 upper triangle: A22(j,i) -= U12(:,i) . U12(:,j).
           Columns j < i of U12 were finished in earlier iterations. */
        for (j = 0; j <= i; j++) {
            const float *Bj = B + j * lda;
            float temp = 0.0;

            for (k = 0; k < M; k++)
                temp += Bi[k] * Bj[k];
            Ci[j] -= temp;
        }
    }
}
209
210
211
/*
 * atlas_spotrf_u - recursive Cholesky factorisation A = U^T * U.
 *
 * A is an N x N column-major matrix with leading dimension lda; its upper
 * triangle is overwritten with U.  Orders 1..4 are handled by the unrolled
 * kernels.  Larger orders are split into n1 = N/2 and n2 = N - n1:
 *   1. factor the leading n1 x n1 block recursively;
 *   2. gsl_spotrf_step solves for U12 and downdates the trailing block
 *      (a fused triangular solve + symmetric rank-n1 update);
 *   3. factor the trailing n2 x n2 block recursively.
 *
 * Returns 0 on success (also for N <= 0, which leaves A untouched), or the
 * 1-based index of the first column whose pivot was not positive.
 *
 * NOTE(review): gsl_spotrf_step takes float *, so TYPE is presumably float
 * here -- confirm against MRSESDataType.
 */
int atlas_spotrf_u(const short int N, TYPE *A, const short int lda)
{
   short int n1, n2, info;
   TYPE *A22;

   switch (N)
   {
   case 4:
      return ATL_potrfU_4(A, lda);
   case 3:
      return ATL_potrfU_3(A, lda);
   case 2:
      return ATL_potrfU_2(A, lda);
   case 1:
      if (!(*A > ATL_rzero))
         return 1;
      *A = sqrt(*A);
      return 0;
   default:
      break;
   }

   if (N <= 0)
      return 0;

   n1 = N >> 1;
   n2 = N - n1;

   info = atlas_spotrf_u(n1, A, lda);
   if (info)
      return info;

   gsl_spotrf_step(n1, n2, A, lda);

   /* Trailing diagonal block starts at row n1, column n1. */
   A22 = A + lda * n1 + n1;
   info = atlas_spotrf_u(n2, A22, lda);
   return info ? info + n1 : 0;
}