Free Electron
BCGThreaded.h
Go to the documentation of this file.
1 /* Copyright (C) 2003-2021 Free Electron Organization
2  Any use of this software requires a license. If a valid license
3  was not distributed with this file, visit freeelectron.org. */
4 
5 /** @file */
6 
7 #ifndef __solve_BCGThreaded_h__
8 #define __solve_BCGThreaded_h__
9 
// Diagnostic switches for BCGThreaded.  Each may be predefined before this
// header is included (e.g. -DBCGT_TRACE=TRUE) to override the default
// without editing this file.
#ifndef BCGT_DEBUG
#define BCGT_DEBUG FALSE		// log inputs, trivial/early exits and the result
#endif
#ifndef BCGT_TRACE
#define BCGT_TRACE FALSE		// log solver state on every iteration
#endif
#ifndef BCGT_THREAD_DEBUG
#define BCGT_THREAD_DEBUG FALSE	// log job traffic between solve() and workers
#endif
#ifndef BCGT_VERIFY
// check x for invalid scalars and compare A*x against b after each solve;
// presumably enabled for debug code generation -- confirm FE_CODEGEN values
#define BCGT_VERIFY (FE_CODEGEN<=FE_DEBUG)
#endif
14 
15 namespace fe
16 {
17 namespace ext
18 {
19 
20 /**************************************************************************//**
21  @brief solve Ax=b for x
22 
23  @ingroup solve
24 
25  Uses Biconjugate-Gradient. The matrix must be positive-definite,
26  but not necessarily symmetric. For symmetric matrices, a regular
27  Conjugate-Gradient can be about twice as fast.
28 
29  The arguments are templated, so any argument types should work,
30  given that they have the appropriate methods and operators.
31 *//***************************************************************************/
32 template <typename MATRIX, typename VECTOR>
34 {
35  public:
36 
37  class PointerCounted: public Counted
38  {
39  public:
40  PointerCounted(BCGThreaded* pointer)
41  { m_pointer=pointer; }
42  BCGThreaded* m_pointer;
43  };
44 
45  class Worker
46  {
47  public:
48  Worker(U32 id,sp< JobQueue<I32> > spJobQueue):
49  m_id(id),
50  m_spJobQueue(spJobQueue) {}
51 
52  void operator()(void);
53 
54  void setObject(sp<Counted> spObject)
55  {
56  m_spPointerCounted=spObject;
57  FEASSERT(m_spPointerCounted.isValid());
58  m_pBCGThreaded=m_spPointerCounted->m_pointer;
59  }
60 
61  private:
62  U32 m_id;
63  sp< JobQueue<I32> > m_spJobQueue;
64  sp<PointerCounted> m_spPointerCounted;
65  BCGThreaded* m_pBCGThreaded;
66  };
67 
68  BCGThreaded(void):
69  m_A(NULL),
70  m_x(NULL),
71  m_b(NULL),
72  m_threshold(1e-6f)
73  {
74  sp<PointerCounted> spPointerCounted(
75  new PointerCounted(this));
76  m_spGang=new Gang<Worker,I32>();
77  m_spGang->start(2,spPointerCounted);
78  }
79 virtual ~BCGThreaded(void)
80  {
81  m_spGang->post(-1);
82  m_spGang->post(-1);
83  m_spGang->finish();
84  }
85 
86  void solve(VECTOR& x, const MATRIX& A, const VECTOR& b);
87  void setThreshold(F64 threshold) { m_threshold=threshold; }
88 
89  private:
90  void solve(U32 thread);
91  void solve(U32 thread, VECTOR& x, const MATRIX& A,const VECTOR& b);
92 
93  sp< Gang<Worker,I32> > m_spGang;
94 
95  const MATRIX* m_A;
96  VECTOR* m_x;
97  const VECTOR* m_b;
98 
99  VECTOR r,r_1,r_2; //* residual (at k, k-1, k-2)
100  VECTOR rb,rb_1,rb_2; //* second residual (r bar)
101  VECTOR p,p_1; //* direction
102  VECTOR pb,pb_1; //* second direction
103 
104  VECTOR temp[2]; //* persistent temporary
105 
106  VECTOR Ap;
107  F64 m_threshold;
108  F64 m_dot_r_1;
109  F64 m_alpha;
110  F64 m_beta;
111  U32 m_N;
112  BWORD m_break;
113 };
114 
115 template <typename MATRIX, typename VECTOR>
117 {
118  while(TRUE)
119  {
120  I32 job;
121 #if BCGT_THREAD_DEBUG
122  feLog("BCGThreaded<>::Worker::operator %p thread %d wait\n",
123  m_spJobQueue.raw(),m_id);
124 #endif
125  m_spJobQueue->waitForJob(job);
126  if(job<0)
127  {
128 #if BCGT_THREAD_DEBUG
129  feLog("BCGThreaded<>::Worker::operator %p thread %d break\n",
130  m_spJobQueue.raw(),m_id);
131 #endif
132  break;
133  }
134 #if BCGT_THREAD_DEBUG
135  feLog("BCGThreaded<>::Worker::operator %p thread %d solve %d\n",
136  m_spJobQueue.raw(),m_id,job);
137 #endif
138  m_pBCGThreaded->solve(job);
139 #if BCGT_THREAD_DEBUG
140  feLog("BCGThreaded<>::Worker::operator %p thread %d deliver %d\n",
141  m_spJobQueue.raw(),m_id,job);
142 #endif
143  m_spJobQueue->deliver(job);
144  }
145 }
146 
147 template <typename MATRIX, typename VECTOR>
148 inline void BCGThreaded<MATRIX,VECTOR>::solve(VECTOR& x,
149  const MATRIX& A, const VECTOR& b)
150 {
151  m_A=&A;
152  m_x=&x;
153  m_b=&b;
154 
155 #if BCGT_THREAD_DEBUG
156  feLog("BCGThreaded<>::solve %p post\n",m_spGang.raw());
157 #endif
158  m_spGang->post(0,1);
159  U32 jobs=2;
160  I32 job;
161  while(jobs)
162  {
163  if(m_spGang->acceptDelivery(job))
164  {
165 #if BCGT_THREAD_DEBUG
166  feLog("BCGThreaded<>::solve %p acceptDelivery %d\n",
167  m_spGang.raw(),job);
168 #endif
169  jobs--;
170  }
171  }
172 
173  m_A=NULL;
174  m_x=NULL;
175  m_b=NULL;
176 
177 #if BCGT_THREAD_DEBUG
178  feLog("BCGThreaded<>::solve %p complete\n",m_spGang.raw());
179 #endif
180 }
181 
182 template <typename MATRIX, typename VECTOR>
183 inline void BCGThreaded<MATRIX,VECTOR>::solve(U32 thread)
184 {
185  solve(thread,*m_x,*m_A,*m_b);
186 }
187 
/** @brief Per-thread body of the threaded BiCG solve.

	Both workers run this function at the same time, one with thread==0
	and one with thread!=0, and meet at m_spGang->synchronize(), which is
	used here as a two-thread barrier (NOTE(review): barrier semantics
	assumed from usage; confirm in Gang).  Thread 0 owns all recurrence
	state; the other thread only computes A^T*pb into temp[1] each
	iteration.  The number and order of synchronize() calls must stay
	matched between the two branches, or the threads will deadlock.

	@param thread	0 for the main recurrence, nonzero for the transpose role
	@param x		solution; resized to size(b) if needed, then cleared by
					set(x) (presumably to zero) before iterating
	@param A		system matrix
	@param b		right-hand side */
template <typename MATRIX, typename VECTOR>
inline void BCGThreaded<MATRIX,VECTOR>::solve(U32 thread,VECTOR& x,
		const MATRIX& A, const VECTOR& b)
{
	// setup, done by thread 0 while the other thread waits at the barrier
	if(!thread)
	{
		m_N=size(b);
		if(size(x)!=m_N)
		{
			x=b;	// adopt size
		}
		if(size(Ap)!=m_N)
		{
			Ap=b;	// adopt size
		}
		set(x);
		temp[1]=x;	// sizes temp[1], which the other thread writes A^T*pb into
		m_break=FALSE;

#if BCGT_DEBUG
		feLog("\nA\n%s\nb=<%s>\n",c_print(A),c_print(b));
#endif

		// tiny b: keep x as cleared above and skip iterating
		if(magnitudeSquared(b)<m_threshold)
		{
#if BCGT_DEBUG
			feLog("BCGThreaded::solve has trivial solution\n");
#endif
			m_break=TRUE;
		}
	}

	// both threads read m_break only after this barrier
	m_spGang->synchronize();
	if(m_break)
	{
		return;
	}

	// at most m_N iterations
	for(U32 k=1;k<=m_N;k++)
	{
		// thread 0: build the search directions p and pb for this iteration
		if(!thread)
		{
			if(k==1)
			{
				// first iteration: with x cleared, the residual is b itself,
				// and the shadow residual/direction also start from b
				p=b;
				p_1=p;

				set(r);
				r_1=p;
				r_2=r_1;

				pb=b;
				rb_1=pb;
				set(rb_2);
				set(pb_1);

				m_dot_r_1=dot(rb_1,r_1);
				m_beta=0.0f;
			}
			else
			{
				// shift history: *_1 <- current, *_2 <- previous
				r_2=r_1;
				r_1=r;
				p_1=p;

				rb_2=rb_1;
				rb_1=rb;
				pb_1=pb;

				// beta = (rb_1.r_1)/(rb_2.r_2)
				// NOTE(review): no guard against dot(rb_2,r_2)==0 (breakdown)
				m_dot_r_1=dot(rb_1,r_1);
				m_beta=m_dot_r_1/dot(rb_2,r_2);

//				p=r_1+m_beta*p_1;
				temp[0]=p_1;
				temp[0]*=m_beta;
				p=r_1;
				p+=temp[0];

//				pb=rb_1+m_beta*pb_1;
				temp[0]=pb_1;
				temp[0]*=m_beta;
				pb=rb_1;
				pb+=temp[0];
			}
		}

		// barrier 1 of 2 per iteration: p and pb are ready
		m_spGang->synchronize();
		if(!thread)
		{
			// alpha = (rb_1.r_1)/(pb.A*p)
			// NOTE(review): no guard against dot(pb,Ap)==0 (breakdown)
			transformVector(A,p,Ap);
			m_alpha=m_dot_r_1/dot(pb,Ap);

			BWORD zero_mag=(magnitudeSquared(p)==0.0f);

			// with tracing on, log state every iteration;
			// otherwise only when the direction has collapsed
#if BCGT_TRACE==FALSE
			if(zero_mag)
#endif
			{
				feLog("\n%d m_alpha=%.6G beta=%.6G\n",k,m_alpha,m_beta);
				feLog("x<%s>\n",c_print(x));
				feLog("r<%s>\n",c_print(r));
				feLog("r_1<%s>\n",c_print(r_1));
				feLog("r_2<%s>\n",c_print(r_2));
				feLog("rb_1<%s>\n",c_print(rb_1));
				feLog("rb_2<%s>\n",c_print(rb_2));
				feLog("p<%s>\n",c_print(p));
				feLog("p_1<%s>\n",c_print(p_1));
				feLog("A*p<%s>\n",c_print(A*p));
				feLog("pb<%s>\n",c_print(pb));
				feLog("pb_1<%s>\n",c_print(pb_1));
			}

			// NOTE(review): throwing here skips barrier 2, so the other
			// thread would block in synchronize() -- possible deadlock
			if(zero_mag)
			{
				feX("BCGThreaded::solve","direction lost its magnitude");
			}

//			x+=m_alpha*p;
			temp[0]=p;
			temp[0]*=m_alpha;
			x+=temp[0];

//			r=r_1-m_alpha*(Ap);
			temp[0]=Ap;
			temp[0]*=m_alpha;
			r=r_1;
			r-=temp[0];

			// converged: flag the other thread, then meet it at barrier 2
			// so it sees m_break and leaves its loop as well
			if(magnitudeSquared(r)<m_threshold)
			{
#if BCGT_DEBUG
				feLog("BCGThreaded::solve early solve %d/%d\n",k,m_N);
#endif
				m_break=TRUE;
				m_spGang->synchronize();
				break;
			}

			// last allowed iteration and still above threshold
			if(k==m_N)
			{
				feLog("BCGThreaded::solve ran %d/%d\n",k,m_N);
			}

			// barrier 2 of 2: temp[1] now holds A^T*pb from the other thread
			m_spGang->synchronize();

			// rb=rb_1-m_alpha*A^T*pb (temp[1] is overwritten next iteration)
			temp[1]*=m_alpha;
			rb=rb_1;
			rb-=temp[1];

#if BCGT_TRACE
			feLog("temp[1]<%s>\n",c_print(temp[1]));
			feLog("rb<%s>\n",c_print(rb));
#endif
		}
		else
		{
			// transpose role: compute A^T*pb while thread 0 works on A*p
//			rb=rb_1-m_alpha*transposeMultiply(A,pb);
			transposeTransformVector(A,pb,temp[1]);

			// barrier 2 of 2, matching either of thread 0's second syncs
			m_spGang->synchronize();

			if(m_break)
			{
				break;
			}
		}
	}

	// optional diagnostics on the result, by thread 0 only
	if(!thread)
	{
#if BCGT_DEBUG
		feLog("\nx=<%s>\nA*x=<%s>\n",c_print(x),c_print(A*x));
#endif

#if BCGT_VERIFY
		// report, but do not throw, on invalid entries or a large |A*x-b|
		BWORD invalid=FALSE;
		for(U32 k=0;k<m_N;k++)
		{
			if(FE_INVALID_SCALAR(x[k]))
			{
				invalid=TRUE;
			}
		}
		VECTOR Ax=A*x;
		F64 distance=magnitude(Ax-b);
		if(invalid || distance>1.0f)
		{
			feLog("BCGThreaded::solve failed to converge (dist=%.6G)\n",
					distance);
			if(size(x)<100)
			{
				feLog(" collecting state ...\n");
				feLog("A=\n%s\nx=<%s>\nA*x=<%s>\nb=<%s>\n",
						c_print(A),c_print(x),
						c_print(Ax),c_print(b));
			}
//			feX("BCGThreaded::solve","failed to converge");
		}
#endif
	}
}
389 
390 } /* namespace ext */
391 } /* namespace fe */
392 
393 #endif /* __solve_BCGThreaded_h__ */
Heap-based support for classes participating in fe::ptr <>
Definition: Counted.h:35
kernel
Definition: namespace.dox:3
Intrusive Smart Pointer.
Definition: src/core/ptr.h:53
solve Ax=b for x
Definition: BCGThreaded.h:33