I am trying to test how well Ceres Solver's automatic differentiation (autodiff) performs when fitting a simple logistic growth curve to actual observed data. To my surprise, the solver seems unable to find a solution. With other C++ solvers the result is obtained easily: k=9643.61, c=84.61 and b=3.8121. I am not sure whether my code has a problem or whether Ceres Solver's autodiff is simply not that robust. Any advice, please?
Below is the sample code:
#include "ceres/ceres.h"
#include "glog/logging.h"
#include <cmath>
#include <iostream>
#include <stdio.h>
using ceres::AutoDiffCostFunction;
using ceres::CauchyLoss;
using ceres::CostFunction;
using ceres::Problem;
using ceres::Solve;
using ceres::Solver;
// Residual functor for fitting the curve
//
//     y = k / (1 + (x / c)^b)
//
// to a single observation (x, y). The call operator is templated on the
// scalar type so it can be evaluated with plain doubles as well as with the
// scalar type used by Ceres' automatic differentiation.
struct ExponentialResidual
{
    // Stores the observation (x, y) this residual is evaluated against.
    ExponentialResidual(double x, double y)
        : x_(x), y_(y) {}

    // Computes residual[0] = y - k / (1 + (x / c)^b).
    //
    //   k, c, b   pointers to the three single-element parameter blocks
    //   residual  output array with one element
    //
    // Always returns true.
    template <typename T>
    bool operator()(const T *const k,
                    const T *const c,
                    const T *const b,
                    T *residual) const
    {
        // Make std::pow visible for built-in scalars; overloads for other
        // scalar types are still found by argument-dependent lookup.
        using std::pow;

        if (x_ == 0.0)
        {
            // At x == 0 the term (x / c)^b does not depend on c or b:
            // it is 0 for b > 0 and +inf for b < 0. Evaluating pow() with a
            // zero base under automatic differentiation produces non-finite
            // derivatives when the exponent is below 1 (see the notes on
            // pow in Ceres' jet.h), which invalidates the whole Jacobian.
            // Writing the limits out keeps the value unchanged and the
            // derivatives finite. c is not consulted in this branch.
            if (b[0] > T(0.0))
            {
                residual[0] = y_ - k[0]; // k / (1 + 0)
                return true;
            }
            if (b[0] < T(0.0))
            {
                residual[0] = T(y_); // k / (1 + inf) == 0
                return true;
            }
            // b == 0 (or NaN) falls through to the general expression.
        }

        residual[0] = y_ - (k[0] / (1.0 + pow(x_ / c[0], b[0])));
        return true;
    }

private:
    const double x_; // observed abscissa
    const double y_; // observed ordinate
};
// Observations stored as interleaved (x, y) pairs: data[2 * i] is the x value
// and data[2 * i + 1] the y value of observation i.
double data[] = {
0,3,
1,4,
2,4,
3,4,
4,7,
5,8,
6,8,
7,8,
8,8,
9,8,
10,10,
11,12,
12,12,
13,12,
14,16,
15,16,
16,18,
17,18,
18,18,
19,19,
20,19,
21,22,
22,22,
23,22,
24,22,
25,22,
26,22,
27,22,
28,22,
29,22,
30,22,
31,22,
32,22,
33,23,
34,23,
35,25,
36,29,
37,32,
38,36,
39,50,
40,55,
41,83,
42,93,
43,99,
44,117,
45,129,
46,149,
47,158,
48,197,
49,238,
50,428,
51,553,
52,673,
53,790,
54,900,
55,1030,
56,1183,
57,1306,
58,1518,
59,1624,
60,1796,
61,2031,
62,2161,
63,2320,
64,2470,
65,2626,
66,2766,
67,2908,
68,3116,
69,3333,
70,3483,
71,3662,
72,3793,
73,3963,
74,4119,
75,4228,
76,4346,
77,4530,
78,4683,
79,4817,
80,4987,
81,5072,
82,5182,
83,5251,
84,5305,
85,5389,
86,5425,
87,5482,
88,5532,
89,5603,
90,5691,
91,5742,
92,5780,
93,5820,
94,5851,
95,5945,
96,6002,
97,6071,
98,6176,
99,6298,
100,6353,
101,6383,
102,6428,
103,6467,
104,6535,
105,6589,
106,6656,
107,6726,
108,6742,
109,6779,
110,6819,
111,6855,
112,6872,
113,6894,
114,6941,
115,6978,
116,7009,
117,7059,
118,7137,
119,7185,
120,7245,
121,7417,
122,7604,
123,7619,
124,7629,
125,7732,
126,7762,
127,7819,
128,7857,
129,7877,
130,7970,
131,8247,
132,8266,
133,8303,
134,8322,
135,8329,
136,8336,
137,8338,
138,8369,
139,8402,
140,8445,
141,8453,
142,8494,
143,8505,
144,8515,
145,8529,
146,8535,
147,8556,
148,8572,
149,8587,
150,8590,
151,8596,
152,8600,
153,8606,
154,8616,
155,8634,
156,8637,
157,8639,
158,8640,
159,8643,
160,8648,
161,8658,
162,8663,
163,8668,
164,8674,
165,8677,
166,8683,
167,8696,
168,8704,
169,8718,
170,8725,
171,8729,
172,8734,
173,8737,
174,8755,
175,8764,
176,8779,
177,8800,
178,8815,
179,8831,
180,8840,
181,8861,
182,8884,
183,8897,
184,8904,
185,8943,
186,8956,
187,8964,
188,8976,
189,8985,
190,8999,
191,9001,
192,9002,
193,9023,
194,9038,
195,9063,
196,9070,
197,9083,
198,9094,
199,9103,
200,9114,
201,9129,
202,9149,
203,9175,
204,9200,
205,9212,
206,9219,
207,9235,
208,9240,
209,9249,
210,9257,
211,9267,
212,9274,
213,9285,
214,9291,
215,9296,
216,9306,
217,9317,
218,9334,
219,9340,
220,9354,
221,9360,
222,9374,
223,9385,
224,9391,
225,9397,
226,9459,
227,9559,
228,9583,
229,9628,
230,9810,
231,9868,
232,9915,
233,9946,
234,9969,
235,10031,
236,10052,
237,10147,
238,10167,
239,10219,
240,10276,
241,10358,
242,10505,
243,10576,
244,10687,
245,10769,
246,10919,
};
// Number of (x, y) pairs in data. Derived from the array size so the count
// cannot drift from the table when rows are added or removed.
const int kNumObservations = sizeof(data) / sizeof(data[0]) / 2;
// Fits the curve parameters k, c and b to the observations in `data` by
// building one residual block per observation, running the Ceres solver and
// printing the full report followed by the final parameter values.
//
// Returns 0 when the solver reports a usable solution and 1 otherwise.
int main(int argc, char const *argv[])
{
    google::InitGoogleLogging(argv[0]);

    // Starting values for the parameters being optimised.
    double k = 20000.0;
    //double c=0.5;
    double c = kNumObservations / 2.0;
    double b = 0.5;
    // Not attached to the problem; only echoed in the final printout.
    double g = 1.0;

    Problem problem;
    for (int i = 0; i < kNumObservations; i++)
    {
        // data holds interleaved (x, y) pairs. Per the Ceres documentation
        // the problem takes ownership of the cost and loss functions by
        // default, so they are not deleted here.
        problem.AddResidualBlock(
            new AutoDiffCostFunction<ExponentialResidual, 1, 1, 1, 1>(
                new ExponentialResidual(data[2 * i], data[2 * i + 1])),
            new CauchyLoss(0.5), &k, &c, &b);
    }

    Solver::Options options;
    options.max_num_iterations = 1000;
    options.linear_solver_type = ceres::DENSE_QR;
    //options.trust_region_strategy_type=ceres::DOGLEG;
    //options.gradient_tolerance=1e-8;
    //options.parameter_tolerance=1e-10;
    //options.function_tolerance=1e-8;
    options.minimizer_progress_to_stdout = true;

    Solver::Summary summary;
    Solve(options, &problem, &summary);

    //std::cout<<summary.BriefReport()<<std::endl;
    std::cout << summary.FullReport() << std::endl;
    std::cout << "Final k: " << k << " c: " << c << " b: " << b << " g: " << g << "\n";

    // Surface a failed solve through the exit status instead of always
    // reporting success.
    if (!summary.IsSolutionUsable())
    {
        std::cerr << "Solver did not produce a usable solution: "
                  << summary.message << "\n";
        return 1;
    }
    return 0;
}