00001 #ifndef C_TREEBATCHLEARNING__H
00002 #define C_TREEBATCHLEARNING__H
00003
#include <map>

#include "cparameters.h"
#include "csupervisedlearner.h"
#include "cqfunction.h"
#include "cstatemodifier.h"
00008
00009 class CDataSet;
00010 class CDataSet1D;
00011 class CDataSubset;
00012 class CDataPreprocessor;
00013 class CRegressionForest;
00014
00015 class CRegressionTreeFunction;
00016
00017 class CStateProperties;
00018 class CFeatureVFunction;
00019
00020 class ActionSet;
00021 class CAction;
00022 class CActionData;
00023
00024 class CEpisodeHistory;
00025
00026 class CStateProperties;
00027 class CState;
00028 class CStateCollection;
00029
00030 class CBatchQDataGenerator;
00031 class CKDTree;
00032 class CKNearestNeighbors;
00033
00034 class CRegressionTreeVFunction;
00035
00036
00037 class CExtraRegressionForestTrainer : virtual public CParameterObject
00038 {
00039 public:
00040 CExtraRegressionForestTrainer(int numTrees, int K, int n_min, double treshold);
00041 virtual ~CExtraRegressionForestTrainer();
00042
00043 virtual CRegressionForest * getNewTree(CDataSet *input, CDataSet1D *output, CDataSet1D *weightData);
00044 };
00045
00046
00047 class CExtraRegressionForestLearner : public CExtraRegressionForestTrainer, public CSupervisedLearner, public CSupervisedWeightedLearner
00048 {
00049 protected:
00050 CRegressionTreeFunction *treeFunction;
00051
00052 public:
00053 CExtraRegressionForestLearner(CRegressionTreeFunction *treeFunction, int numTrees, int K, int n_min, double treshold);
00054 virtual ~CExtraRegressionForestLearner();
00055
00056 virtual void learnFA(CDataSet *input, CDataSet1D *output);
00057 virtual void learnWeightedFA(CDataSet *input, CDataSet1D *output, CDataSet1D *weightData);
00058
00059 virtual void resetLearner();
00060 };
00061
00062 class CExtraRegressionForestFeatureLearner : public CNewFeatureCalculator, public CExtraRegressionForestTrainer
00063 {
00064 protected:
00065 CStateProperties *originalState;
00066
00067 public:
00068 CExtraRegressionForestFeatureLearner(CStateProperties *originalState, int numTrees, int K, int n_min, double treshold);
00069 virtual ~CExtraRegressionForestFeatureLearner();
00070
00071 virtual CFeatureCalculator * getFeatureCalculator(CFeatureVFunction *vFunction, CDataSet *inputData, CDataSet1D *outputData);
00072 };
00073
00074
00075 class CExtraLinearRegressionModelForestLearner : public CSupervisedLearner
00076 {
00077 protected:
00078 CRegressionTreeFunction *treeFunction;
00079
00080 public:
00081 CExtraLinearRegressionModelForestLearner(CRegressionTreeFunction *treeFunction, int numTrees, int K, int n_min, double treshold, int t1, int t2, int t3);
00082 virtual ~CExtraLinearRegressionModelForestLearner();
00083
00084 virtual void learnFA(CDataSet *input, CDataSet1D *output);
00085 };
00086
00087 class CRBFForestLearner : public CSupervisedLearner
00088 {
00089 protected:
00090 CRegressionTreeFunction *treeFunction;
00091
00092 public:
00093 CRBFForestLearner(CRegressionTreeFunction *treeFunction, int numTrees, int kNN, int K, int n_min, double treshold, double varMult, double minVar);
00094 virtual ~CRBFForestLearner();
00095
00096 virtual void learnFA(CDataSet *input, CDataSet1D *output);
00097 };
00098
00099
00100 class CLocalLinearLearner : public CSupervisedLearner
00101 {
00102 protected:
00103 CRegressionTreeFunction *treeFunction;
00104
00105 CDataSet *inputData;
00106 CDataSet1D *outputData;
00107
00108 CDataPreprocessor *preprocessor;
00109 public:
00110 CLocalLinearLearner(CRegressionTreeFunction *treeFunction, int kNN, int degree);
00111 virtual ~CLocalLinearLearner();
00112
00113 virtual void learnFA(CDataSet *input, CDataSet1D *output);
00114 };
00115
00116
00117 class CLocalRBFLearner : public CSupervisedLearner
00118 {
00119 protected:
00120 CRegressionTreeFunction *treeFunction;
00121
00122 CDataSet *inputData;
00123 CDataSet1D *outputData;
00124
00125 CDataPreprocessor *preprocessor;
00126 public:
00127 CLocalRBFLearner(CRegressionTreeFunction *treeFunction, int kNN, double varMult);
00128 virtual ~CLocalRBFLearner();
00129
00130 virtual void learnFA(CDataSet *input, CDataSet1D *output);
00131 };
00132
00133 class CUnknownDataQFunction : public CAbstractQFunction
00134 {
00135 protected:
00136 CStateProperties *properties;
00137 std::map<CAction *, CKDTree *> *treeMap;
00138 std::map<CAction *, CKNearestNeighbors *> *nnMap;
00139
00140
00141 std::map<CAction *, CDataPreprocessor *> *preMap;
00142
00143 CEpisodeHistory *logger;
00144
00145 CBatchQDataGenerator *dataGenerator;
00146 ColumnVector *distVector;
00147
00148 void clearMaps();
00149 public:
00150
00151 CUnknownDataQFunction(CActionSet *actions, CEpisodeHistory *logger, CStateProperties *properties, double factor);
00152
00153 virtual ~CUnknownDataQFunction();
00154
00155 virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL);
00156
00157 void recalculateTrees();
00158
00159 virtual double getUnknownDataValue(ColumnVector *distances);
00160
00161 virtual void onParametersChanged();
00162
00163 virtual void resetData();
00164 };
00165
00166
00167 class CUnknownDataQFunctionFromLocalRBFRegression : public CAbstractQFunction
00168 {
00169 protected:
00170 std::map<CAction *, CRegressionTreeVFunction *> *regressionMap;
00171 public:
00172 bool recalculateFactors;
00173
00174 CUnknownDataQFunctionFromLocalRBFRegression(CActionSet *actions, std::map<CAction *, CRegressionTreeVFunction *> *regressionMap, double factor);
00175
00176 virtual ~CUnknownDataQFunctionFromLocalRBFRegression();
00177
00178 virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL);
00179 };
00180
00181
00182
00183
00184
00185
00186
00187
00188
00189
00190
00191
00192
00193
00194
00195
00196
00197
00198
00199
00200
00201
00202
00203
00204
00205
00206
00207
00208
00209
00210
00211
00212
00213
00214
00215
00216
00217
00218
00219
00220
00221
00222
00223
00224
00225
00226
00227
00228
00229
00230
00231
00232
00233
00234
00235
00236
00237
00238
00239
00240
00241
00242
00243
00244
00245
00246
00247
00248
00249
00250
00251
00252
00253
00254
00255
00256 #endif
00257