00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033 #ifndef C_SUPERVISEDTRAINER__H
00034 #define C_SUPERVISEDTRAINER__H
00035
00036 #include "cparameters.h"
00037
// Forward declarations: every one of these types is used only through
// pointers in this header, so their full definitions are not needed here.
// Matrix / ColumnVector presumably come from the linear-algebra library the
// project links against -- TODO confirm which one.
00038 class Matrix;
00039 class ColumnVector;
00040
00041 class CGradientFunction;
00042 class CGradientUpdateFunction;
00043 class CDataSet;
00044 class CDataSet1D;
00045 class CAction;
00046 class CTorchGradientFunction;
00047 class CFeatureList;
00048 class CFeatureFunction;
00049
/// Abstract interface for supervised learners; derives virtually from
/// CParameterObject. Implementers must supply learnFA(). "FA" presumably
/// stands for "function approximator" -- TODO confirm.
00050 class CSupervisedLearner : virtual public CParameterObject
00051 {
00052 protected:
00053 public:
00054 CSupervisedLearner() {};
/// Virtual so derived learners can be destroyed through this interface.
00055 virtual ~CSupervisedLearner() {};
00056
00057
/// Pure virtual training entry point.
/// @param inputData  input data set
/// @param outputData one-dimensional output data set; presumably one target
///                   value per input sample -- TODO confirm
00058 virtual void learnFA(CDataSet *inputData, CDataSet1D *outputData) = 0;
00059
/// Reset hook; the default implementation does nothing.
00060 virtual void resetLearner() {};
00061 };
00062
/// Abstract interface for supervised learners that additionally receive a
/// one-dimensional weighting data set; derives virtually from
/// CParameterObject. Implementers must supply learnWeightedFA().
00063 class CSupervisedWeightedLearner : virtual public CParameterObject
00064 {
00065 protected:
00066
00067 public:
00068 CSupervisedWeightedLearner() {};
/// Virtual so derived learners can be destroyed through this interface.
00069 virtual ~CSupervisedWeightedLearner() {};
00070
/// Pure virtual weighted training entry point.
/// @param inputData  input data set
/// @param outputData one-dimensional output data set
/// @param weighting  one-dimensional weighting data set; presumably one
///                   weight per sample -- TODO confirm
00071 virtual void learnWeightedFA(CDataSet *inputData, CDataSet1D *outputData, CDataSet1D *weighting) = 0;
00072
/// Reset hook; the default implementation does nothing.
00073 virtual void resetLearner() {};
00074
00075 };
00076
/// Abstract interface for supervised learners whose training call is keyed
/// by an action; derives virtually from CParameterObject. Implementers must
/// supply the three-argument learnQFunction().
00077 class CSupervisedQFunctionLearner : virtual public CParameterObject
00078 {
00079 protected:
00080 public:
00081 CSupervisedQFunctionLearner() {};
/// Virtual so derived learners can be destroyed through this interface.
00082 virtual ~CSupervisedQFunctionLearner() {};
00083
/// Pure virtual training entry point.
/// @param action     action the data belongs to (presumably selects which
///                   part of the Q-function is trained -- TODO confirm)
/// @param inputData  input data set
/// @param outputData one-dimensional output data set
00084 virtual void learnQFunction(CAction *action, CDataSet *inputData, CDataSet1D *outputData) = 0;
00085
/// Reset hook; the default implementation does nothing.
00086 virtual void resetLearner() {};
00087 };
00088
00089
/// Abstract interface for action-keyed supervised learners that additionally
/// receive a one-dimensional weighting data set; derives virtually from
/// CParameterObject. Implementers must supply the four-argument
/// learnQFunction().
00090 class CSupervisedQFunctionWeightedLearner : virtual public CParameterObject
00091 {
00092 protected:
00093 public:
00094 CSupervisedQFunctionWeightedLearner() {};
/// Virtual so derived learners can be destroyed through this interface.
00095 virtual ~CSupervisedQFunctionWeightedLearner() {};
00096
/// Pure virtual weighted training entry point.
/// @param action        action the data belongs to
/// @param inputData     input data set
/// @param outputData    one-dimensional output data set
/// @param weightingData one-dimensional weighting data set; presumably one
///                      weight per sample -- TODO confirm
00097 virtual void learnQFunction(CAction *action, CDataSet *inputData, CDataSet1D *outputData, CDataSet1D *weightingData) = 0;
00098
/// Reset hook; the default implementation does nothing.
00099 virtual void resetLearner() {};
00100 };
00101
00102
/// Base class for least-squares style learners; derives virtually from
/// CParameterObject. It holds two matrices, one column vector and the
/// CGradientUpdateFunction passed to the constructor. How they are allocated,
/// sized and released is defined in the implementation file, not here.
00103 class CLeastSquaresLearner : virtual public CParameterObject
00104 {
00105 protected:
// NOTE(review): presumably the system matrix, its pseudo-inverse and the
// right-hand side of the least-squares problem -- confirm against the .cpp.
00106 Matrix *A;
00107 Matrix *A_pinv;
00108 ColumnVector *b;
00109
00110 CGradientUpdateFunction *featureFunc;
00111 public:
/// @param featureFunc function object stored in the member of the same name
///                    (presumably the function whose weights are fitted -- TODO confirm)
/// @param numData     presumably the number of data points used to size the
///                    matrices -- TODO confirm
00112 CLeastSquaresLearner(CGradientUpdateFunction *featureFunc, int numData);
00113 virtual ~CLeastSquaresLearner();
00114
/// Overridable optimization step operating on this object's state.
/// @return a double whose meaning is defined by the implementation
///         (presumably an error measure -- TODO confirm)
00115 virtual double doOptimization();
/// Static overload that works on caller-supplied matrices and vectors
/// instead of the members.
/// @param w      presumably receives the resulting weight vector -- TODO confirm
/// @param lambda presumably a regularization factor -- TODO confirm
00116 static double doOptimization(Matrix *A, Matrix *A_pinv, ColumnVector *b, ColumnVector *w, double lambda);
00117
00118 };
00119
00120
/// Abstract interface for objects that can report a gradient and a function
/// value; derives virtually from CParameterObject.
00121 class CGradientCalculator : virtual public CParameterObject
00122 {
00123 protected:
00124 public:
/// Virtual so calculators can be destroyed through this interface.
00125 virtual ~CGradientCalculator() {};
00126
/// Pure virtual.
/// @param gradient feature list handed in by the caller; presumably filled
///                 with the computed gradient -- TODO confirm
00127 virtual void getGradient(CFeatureList *gradient) = 0;
/// Pure virtual; returns the current function value as a double.
00128 virtual double getFunctionValue() = 0;
00129
/// Reset hook; the default implementation does nothing.
00130 virtual void resetGradientCalculator() {};
00131 };
00132
00133
/// CGradientCalculator for a CGradientFunction on a supervised data set.
/// It keeps pointers to the input data and to the output data, the latter
/// either as a CDataSet1D or as a general CDataSet (see the two setData()
/// overloads).
00134 class CSupervisedGradientCalculator : public CGradientCalculator
00135 {
00136 protected:
00137 CDataSet *inputData;
00138
// Two alternative representations of the targets; which one is active
// presumably depends on the setData() overload used last -- TODO confirm.
00139 CDataSet1D *outputData1D;
00140 CDataSet *outputData;
00141
00142 CGradientFunction *gradientFunction;
00143 public:
/// @param gradientFunction function whose gradient is calculated
/// @param inputData        input data set
/// @param outputData       output data set (general, non-1D form)
00144 CSupervisedGradientCalculator(CGradientFunction *gradientFunction, CDataSet *inputData, CDataSet *outputData);
00145 virtual ~CSupervisedGradientCalculator();
00146
/// Implements CGradientCalculator::getGradient().
00147 virtual void getGradient(CFeatureList *gradient);
/// Implements CGradientCalculator::getFunctionValue().
00148 virtual double getFunctionValue();
00149
/// Sets the data to work on, with one-dimensional outputs.
00150 virtual void setData(CDataSet *inputData, CDataSet1D *outputData1D);
/// Sets the data to work on, with outputs given as a general data set.
00151 virtual void setData(CDataSet *inputData, CDataSet *outputData);
00152 };
00153
00154
/// Specialization of CSupervisedGradientCalculator that works with a
/// CFeatureFunction and keeps a CFeatureList as working storage. It
/// overrides both getGradient() and getFunctionValue().
00155 class CSupervisedFeatureGradientCalculator : public CSupervisedGradientCalculator
00156 {
00157 protected:
00158 CFeatureFunction *featureFunction;
00159 CFeatureList *featureList;
00160
00161 public:
/// @param featureFunction feature function stored in the member of the same name
00162 CSupervisedFeatureGradientCalculator(CFeatureFunction *featureFunction);
00163 virtual ~CSupervisedFeatureGradientCalculator();
00164
/// Returns a feature list for the given input vector. Not virtual.
/// NOTE(review): presumably returns the internal featureList member after
/// filling it for `input`; ownership of the result is not visible here -- confirm.
00165 CFeatureList *getFeatureList(ColumnVector *input);
00166
00167 virtual void getGradient(CFeatureList *gradient);
00168 virtual double getFunctionValue();
00169
00170 };
00171
/// Abstract base for objects that apply a gradient to a
/// CGradientUpdateFunction; derives virtually from CParameterObject.
/// Subclasses define the update rule in updateWeights().
00172 class CGradientFunctionUpdater : virtual public CParameterObject
00173 {
00174 protected:
00175 CGradientUpdateFunction *updateFunction;
00176 public:
/// @param updateFunction function to be updated; returned by getUpdateFunction()
00177 CGradientFunctionUpdater(CGradientUpdateFunction *updateFunction);
00178 virtual ~CGradientFunctionUpdater() {};
00179
/// Pure virtual: apply the given gradient to the update function.
00180 virtual void updateWeights(CFeatureList *gradient) = 0;
/// Not virtual.
/// NOTE(review): presumably perturbs the function's parameters with random
/// values scaled by randSize -- confirm against the implementation.
00181 void addRandomParams(double randSize);
00182
/// Accessor for the stored update function pointer.
00183 CGradientUpdateFunction *getUpdateFunction() {return updateFunction;};
00184 };
00185
/// Concrete CGradientFunctionUpdater constructed with a learning rate.
/// The class declares no data member for the rate; presumably it is kept as
/// a named parameter through CParameterObject -- TODO confirm.
00186 class CConstantGradientFunctionUpdater : public CGradientFunctionUpdater
00187 {
00188 public:
/// @param updateFunction function to be updated
/// @param learningRate   step-size factor for the update
00189 CConstantGradientFunctionUpdater(CGradientUpdateFunction *updateFunction, double learningRate);
00190 virtual ~CConstantGradientFunctionUpdater() {};
00191
/// Implements CGradientFunctionUpdater::updateWeights().
00192 virtual void updateWeights(CFeatureList *gradient);
00193 };
00194
/// CGradientFunctionUpdater that uses a CGradientCalculator to evaluate the
/// function while choosing a step size. Judging by the helper names
/// (bracketMinimum, step sizes, a step limit) this is a line search along the
/// gradient direction -- the algorithm itself lives in the implementation file.
00195 class CLineSearchGradientFunctionUpdater : public CGradientFunctionUpdater
00196 {
00197 protected:
// Parameter buffers; presumably the weights before the search and a scratch
// copy for trial steps. Their sizes and ownership are not visible here.
00198 double *startParameters;
00199 double *workParameters;
00200
00201 int maxSteps;
00202
00203 CGradientCalculator *gradientCalculator;
00204
// (sic: "treshold") presumably the termination tolerance of the search -- TODO confirm.
00205 double precision_treshold;
00206
/// (sic: "Paramters") Helper taking a gradient, a step size and the two
/// parameter buffers; presumably writes start + stepSize * gradient into
/// workParameters -- TODO confirm.
00207 void setWorkingParamters(CFeatureList *gradient, double stepSize, double *startParameters, double *workParameters);
/// Function value for a trial step of size stepSize from startParameters
/// along gradient (as the parameter list suggests -- confirm in the .cpp).
00208 virtual double getFunctionValue(double *startParameters, CFeatureList *gradient, double stepSize);
00209
/// a, b and c are passed by reference and therefore can be modified by the
/// call; presumably they come back as a bracketing triple around a minimum,
/// with fa the function value at a -- TODO confirm.
00210 void bracketMinimum(double *startParameters, CFeatureList *gradient, double fa, double &a, double &b, double &c);
00211 public:
00212
/// @param gradientCalculator calculator used for function evaluations
/// @param updateFunction     function to be updated
/// @param maxSteps           step limit stored in the member of the same name
00213 CLineSearchGradientFunctionUpdater(CGradientCalculator *gradientCalculator, CGradientUpdateFunction *updateFunction, int maxSteps);
00214 virtual ~CLineSearchGradientFunctionUpdater();
00215
/// Implements CGradientFunctionUpdater::updateWeights().
00216 virtual void updateWeights(CFeatureList *gradient);
/// Overload that also takes a function value and an in/out step size.
/// @param fold presumably the function value before the update ("f old") -- TODO confirm
/// @param lmin passed by reference; presumably receives the step size at the
///             found minimum -- TODO confirm
/// @return a double, presumably the function value after the update -- TODO confirm
00217 virtual double updateWeights(CFeatureList *gradient, double fold, double &lmin);
00218 };
00219
/// Abstract base for optimizers driven by a CGradientCalculator; derives
/// virtually from CParameterObject. Subclasses implement doOptimization().
00220 class CGradientLearner : virtual public CParameterObject
00221 {
00222 protected:
00223 CGradientCalculator *gradientCalculator;
00224 public:
/// @param gradientCalculator calculator used by this learner; it is
///                           dereferenced by resetOptimization()
00225 CGradientLearner(CGradientCalculator *gradientCalculator);
00226 virtual ~CGradientLearner() {};
00227
/// Pure virtual optimization entry point.
/// @param maxSteps upper bound on the number of steps
/// @return a double whose meaning is defined by the subclass
///         (presumably the final function value -- TODO confirm)
00228 virtual double doOptimization(int maxSteps) = 0;
00229
/// Forwards to gradientCalculator->resetGradientCalculator(). There is no
/// null check, so gradientCalculator must point to a valid object.
00230 virtual void resetOptimization() {gradientCalculator->resetGradientCalculator();};
00231 };
00232
/// CSupervisedLearner implemented on top of a CGradientLearner and a
/// CSupervisedGradientCalculator. It overrides learnFA() and resetLearner().
00233 class CSupervisedGradientLearner : public CSupervisedLearner
00234 {
00235 protected:
00236 CGradientLearner *gradientLearner;
00237 CSupervisedGradientCalculator *gradientCalculator;
00238 public:
/// @param gradientLearner    optimizer used for training
/// @param gradientCalculator calculator that is given the training data
/// @param episodes           not stored in a data member; presumably kept as
///                           a named parameter and used as the step limit
///                           for the optimizer -- TODO confirm
00239 CSupervisedGradientLearner(CGradientLearner *gradientLearner, CSupervisedGradientCalculator *gradientCalculator, int episodes);
00240 virtual ~CSupervisedGradientLearner();
00241
00242
/// Implements CSupervisedLearner::learnFA().
00243 virtual void learnFA(CDataSet *inputData, CDataSet1D *outputData);
00244
/// Overrides CSupervisedLearner::resetLearner().
00245 virtual void resetLearner();
00246 };
00247
/// CSupervisedQFunctionLearner backed by a map from actions to
/// CSupervisedLearner objects. Presumably learnQFunction() forwards to the
/// learner registered for the given action -- TODO confirm in the .cpp.
///
/// NOTE(review): std::map is used here but this header does not include
/// <map> itself; presumably it arrives through "cparameters.h" -- confirm.
00248 class CSupervisedQFunctionLearnerFromLearners : public CSupervisedQFunctionLearner
00249 {
00250 protected:
// Pointer to the action -> learner map handed to the constructor; whether
// this class owns the map or the learners is not visible here.
00251 std::map<CAction *, CSupervisedLearner *> *learnerMap;
00252 public:
/// @param learnerMap map holding one learner per action
00253 CSupervisedQFunctionLearnerFromLearners(std::map<CAction *, CSupervisedLearner *> *learnerMap);
00254
00255 virtual ~CSupervisedQFunctionLearnerFromLearners();
00256
/// Implements CSupervisedQFunctionLearner::learnQFunction().
00257 virtual void learnQFunction(CAction *action, CDataSet *inputData, CDataSet1D *outputData);
00258
/// Overrides CSupervisedQFunctionLearner::resetLearner().
00259 virtual void resetLearner();
00260 };
00261
/// Weighted Q-function learner backed by a map from actions to
/// CSupervisedWeightedLearner objects. Presumably learnQFunction() forwards
/// to the learner registered for the given action -- TODO confirm in the .cpp.
///
/// The base class is CSupervisedQFunctionWeightedLearner because its pure
/// virtual learnQFunction() has the same four-argument signature as the
/// method declared below. Deriving from CSupervisedQFunctionLearner would
/// leave that interface's three-argument pure virtual hidden and
/// unimplemented, making this class abstract and impossible to instantiate.
///
/// NOTE(review): std::map is used here but this header does not include
/// <map> itself; presumably it arrives through "cparameters.h" -- confirm.
00262 class CSupervisedQFunctionWeightedLearnerFromLearners : public CSupervisedQFunctionWeightedLearner
00263 {
00264 protected:
// Pointer to the action -> weighted learner map handed to the constructor;
// whether this class owns the map or the learners is not visible here.
00265 std::map<CAction *, CSupervisedWeightedLearner *> *learnerMap;
00266 public:
/// @param learnerMap map holding one weighted learner per action
00267 CSupervisedQFunctionWeightedLearnerFromLearners(std::map<CAction *, CSupervisedWeightedLearner *> *learnerMap);
00268
00269 virtual ~CSupervisedQFunctionWeightedLearnerFromLearners();
00270
/// Implements CSupervisedQFunctionWeightedLearner::learnQFunction().
/// @param action     action the data belongs to
/// @param inputData  input data set
/// @param outputData one-dimensional output data set
/// @param weightData one-dimensional weighting data set
00271 virtual void learnQFunction(CAction *action, CDataSet *inputData, CDataSet1D *outputData, CDataSet1D *weightData);
00272
/// Overrides CSupervisedQFunctionWeightedLearner::resetLearner().
00273 virtual void resetLearner();
00274 };
00275
00276
/// CGradientLearner that combines the inherited gradient calculator with a
/// CGradientFunctionUpdater and keeps a CFeatureList for the gradient.
/// Presumably each step computes the gradient over the whole data set and
/// hands it to the updater -- TODO confirm in the .cpp.
00277 class CBatchGradientLearner : public CGradientLearner
00278 {
00279 protected:
00280 CGradientFunctionUpdater *updater;
00281 CFeatureList *gradient;
00282
// (sic: "treshold") presumably a tolerance on the function value used as a
// stopping criterion -- TODO confirm.
00283 double treshold_f;
00284 public:
/// @param gradientCalculator calculator passed on to CGradientLearner
/// @param updater            updater that applies the gradient
00285 CBatchGradientLearner(CGradientCalculator *gradientCalculator, CGradientFunctionUpdater *updater);
// Not spelled `virtual`, but implicitly virtual because
// CGradientLearner's destructor is virtual.
00286 ~CBatchGradientLearner();
00287
/// Implements CGradientLearner::doOptimization().
00288 virtual double doOptimization(int maxSteps);
00289 };
00290
/// CGradientLearner that works with a CLineSearchGradientFunctionUpdater and
/// keeps three feature lists plus some scalar state between calls. The class
/// name indicates a conjugate-gradient method; the algorithm itself lives in
/// the implementation file.
00291 class CConjugateGradientLearner : public CGradientLearner
00292 {
00293 protected:
00294 CLineSearchGradientFunctionUpdater *gradientUpdater;
00295
// NOTE(review): presumably the current gradient, the previous gradient and
// the current search direction -- confirm against the .cpp.
00296 CFeatureList *gradnew;
00297 CFeatureList *gradold;
00298 CFeatureList *d;
00299
// (sic: "treshold") presumably tolerances on the parameter change (x) and on
// the function-value change (f) -- TODO confirm.
00300 double treshold_x;
00301 double treshold_f;
00302
00303 double fnew;
00304
// Integer state flag; its exact meaning is defined in the implementation.
00305 int exiting;
00306
00307 public:
/// @param gradientCalculator calculator passed on to CGradientLearner
/// @param updater            line-search updater stored in gradientUpdater
00308 CConjugateGradientLearner(CGradientCalculator *gradientCalculator, CLineSearchGradientFunctionUpdater *updater);
00309 virtual ~CConjugateGradientLearner();
00310
/// Implements CGradientLearner::doOptimization(); the step limit is named
/// maxGradientUpdates here.
00311 virtual double doOptimization(int maxGradientUpdates);
00312
/// Overrides CGradientLearner::resetOptimization().
00313 virtual void resetOptimization();
00314 };
00315
00316
/// CSupervisedLearner that operates on a CTorchGradientFunction.
/// NOTE(review): the class name suggests training is delegated to Matlab;
/// nothing in this header shows how -- confirm against the implementation.
00317 class CSupervisedNeuralNetworkMatlabLearner : public CSupervisedLearner
00318 {
00319 protected:
00320 CTorchGradientFunction *mlpFunction;
00321
00322 public:
/// @param mlpFunction network function stored in the member of the same name
/// @param numHidden   not stored in a data member; presumably the number of
///                    hidden units, kept as a named parameter -- TODO confirm
00323 CSupervisedNeuralNetworkMatlabLearner(CTorchGradientFunction *mlpFunction, int numHidden);
00324 virtual ~CSupervisedNeuralNetworkMatlabLearner();
00325
00326
/// Implements CSupervisedLearner::learnFA().
00327 virtual void learnFA(CDataSet *inputData, CDataSet1D *outputData);
00328
/// Overrides CSupervisedLearner::resetLearner().
00329 virtual void resetLearner();
00330 };
00331
/// Learner operating on a CTorchGradientFunction that implements both the
/// unweighted (CSupervisedLearner) and the weighted
/// (CSupervisedWeightedLearner) interface. Both bases derive virtually from
/// CParameterObject, so this class contains a single CParameterObject
/// sub-object.
00332 class CSupervisedNeuralNetworkTorchLearner : public CSupervisedLearner, public CSupervisedWeightedLearner
00333 {
00334 protected:
00335 CTorchGradientFunction *mlpFunction;
00336
00337 public:
/// @param mlpFunction network function stored in the member of the same name
00338 CSupervisedNeuralNetworkTorchLearner(CTorchGradientFunction *mlpFunction);
00339 virtual ~CSupervisedNeuralNetworkTorchLearner();
00340
00341
/// Implements CSupervisedLearner::learnFA().
00342 virtual void learnFA(CDataSet *inputData, CDataSet1D *outputData);
/// Implements CSupervisedWeightedLearner::learnWeightedFA().
00343 virtual void learnWeightedFA(CDataSet *inputData, CDataSet1D *outputData, CDataSet1D *weighting);
00344
/// Both base interfaces declare a virtual resetLearner() with the same
/// signature; this single declaration overrides both.
00345 virtual void resetLearner();
00346 };
00347
00348
00349
00350 #endif
00351