Reinforcement Learning Toolbox 2.0
last updated:
General
Documentation
Manual
Tutorial
Class Reference
Master Thesis
Examples
Related Papers
Downloads
Links
News
mailto:webmaster
Main Page     Class Hierarchy   Compound List   File List   Compound Members   File Members

ctreebatchlearning.h

Go to the documentation of this file.
00001 #ifndef C_TREEBATCHLEARNING__H
00002 #define C_TREEBATCHLEARNING__H
00003 
00004 #include "cparameters.h"
00005 #include "csupervisedlearner.h"
00006 #include "cqfunction.h"
00007 #include "cstatemodifier.h"
00008 
// Forward declarations of project types. The live declarations in this header
// refer to these types only through pointers, so incomplete types are enough
// and the corresponding headers do not have to be included here.
00009 class CDataSet;
00010 class CDataSet1D;
00011 class CDataSubset;
00012 class CDataPreprocessor;
00013 class CRegressionForest;
00014 
00015 class CRegressionTreeFunction;
00016 
00017 class CStateProperties;
00018 class CFeatureVFunction;
00019 
      // Fixed: this line used to declare `ActionSet`, a name nothing in this
      // header uses. The constructors further down take `CActionSet *`.
00020 class CActionSet;
00021 class CAction;
00022 class CActionData;
00023 
00024 class CEpisodeHistory;
00025 
      // Redundant re-declaration of CStateProperties (already declared above);
      // harmless, kept so the listing's line numbering stays intact.
00026 class CStateProperties;
00027 class CState;
00028 class CStateCollection;
00029 
00030 class CBatchQDataGenerator;
00031 class CKDTree;
00032 class CKNearestNeighbors;
00033 
00034 class CRegressionTreeVFunction;
00035 
00036 
// CExtraRegressionForestTrainer: derives virtually from CParameterObject and
// declares getNewTree(), a virtual member that takes an input data set, a 1-D
// output data set and a 1-D weight data set and returns a CRegressionForest*.
// NOTE(review): the name suggests Extra-Trees (extremely randomized trees)
// training; the implementation is not visible in this header - confirm.
00037 class CExtraRegressionForestTrainer : virtual public CParameterObject
00038 {
00039         public:
                      // Takes the training settings numTrees, K, n_min and treshold
                      // (sic, "threshold"). How they are stored or used is defined
                      // outside this header.
                      // NOTE(review): presumably forest size, candidate splits per
                      // node, minimum samples per split and a split threshold -
                      // TODO confirm against the implementation.
00040                 CExtraRegressionForestTrainer(int numTrees, int K, int n_min, double treshold);
00041                 virtual ~CExtraRegressionForestTrainer();
00042 
                      // Returns a CRegressionForest pointer for the given data sets.
                      // Ownership of the returned pointer is not stated here - verify
                      // against callers before deleting it.
00043                 virtual CRegressionForest * getNewTree(CDataSet *input, CDataSet1D *output, CDataSet1D *weightData);
00044 };
00045 
00046 
// CExtraRegressionForestLearner: combines CExtraRegressionForestTrainer with the
// CSupervisedLearner and CSupervisedWeightedLearner interfaces and keeps a
// pointer to a CRegressionTreeFunction.
// NOTE(review): presumably learnFA()/learnWeightedFA() train a forest and store
// it in treeFunction - the definitions are not in this header; confirm.
00047 class CExtraRegressionForestLearner : public CExtraRegressionForestTrainer, public CSupervisedLearner, public CSupervisedWeightedLearner
00048 {
00049         protected:
                      // Tree function handed to the constructor; ownership is not
                      // stated in this header.
00050                 CRegressionTreeFunction *treeFunction;
00051 
00052         public:
                      // treeFunction plus the same four settings the trainer base
                      // constructor takes (numTrees, K, n_min, treshold).
00053                 CExtraRegressionForestLearner(CRegressionTreeFunction *treeFunction, int numTrees, int K, int n_min, double treshold);
00054                 virtual ~CExtraRegressionForestLearner();
00055 
                      // Unweighted learning entry point: input samples and 1-D targets.
00056                 virtual void learnFA(CDataSet *input, CDataSet1D *output);
                      // Weighted variant: additionally takes a 1-D weight data set.
00057                 virtual void learnWeightedFA(CDataSet *input, CDataSet1D *output, CDataSet1D *weightData);
00058 
                      // NOTE(review): presumably discards the learned forest - confirm.
00059                 virtual void resetLearner();
00060 };
00061 
// CExtraRegressionForestFeatureLearner: a CNewFeatureCalculator that is also a
// CExtraRegressionForestTrainer. It declares getFeatureCalculator(), which
// returns a CFeatureCalculator* for a V-function and an input/output data pair.
// NOTE(review): presumably the returned calculator is derived from a trained
// forest - the definition is not in this header; confirm.
00062 class CExtraRegressionForestFeatureLearner : public CNewFeatureCalculator, public CExtraRegressionForestTrainer
00063 {
00064         protected:
                      // State properties handed to the constructor; ownership is not
                      // stated in this header.
00065                 CStateProperties *originalState;
00066 
00067         public:
                      // originalState plus the four trainer settings
                      // (numTrees, K, n_min, treshold).
00068                 CExtraRegressionForestFeatureLearner(CStateProperties *originalState, int numTrees, int K, int n_min, double treshold);
00069                 virtual ~CExtraRegressionForestFeatureLearner();
00070 
                      // Returns a feature calculator pointer; who deletes it is not
                      // stated here - verify against callers.
00071                 virtual CFeatureCalculator * getFeatureCalculator(CFeatureVFunction *vFunction, CDataSet *inputData, CDataSet1D *outputData);
00072 };
00073 
00074 
// CExtraLinearRegressionModelForestLearner: a CSupervisedLearner that keeps a
// pointer to a CRegressionTreeFunction and declares learnFA().
// NOTE(review): the name suggests a forest whose leaves hold linear regression
// models; nothing in this header confirms that - check the implementation.
00075 class CExtraLinearRegressionModelForestLearner : public CSupervisedLearner
00076 {
00077         protected:
                      // Tree function handed to the constructor; ownership is not
                      // stated in this header.
00078                 CRegressionTreeFunction *treeFunction;
00079 
00080         public:
                      // Takes the same numTrees/K/n_min/treshold settings as the other
                      // forest learners, plus three integers t1, t2, t3.
                      // NOTE(review): the meaning of t1..t3 cannot be determined from
                      // this header - TODO document once confirmed.
00081                 CExtraLinearRegressionModelForestLearner(CRegressionTreeFunction *treeFunction, int numTrees, int K, int n_min, double treshold, int t1, int t2, int t3);
00082                 virtual ~CExtraLinearRegressionModelForestLearner();
00083 
                      // Learning entry point: input samples and 1-D targets.
00084                 virtual void learnFA(CDataSet *input, CDataSet1D *output);
00085 };
00086 
// CRBFForestLearner: a CSupervisedLearner that keeps a pointer to a
// CRegressionTreeFunction and declares learnFA().
// NOTE(review): the name and the kNN/varMult/minVar parameters suggest radial
// basis functions placed with a nearest-neighbour search - confirm in the
// implementation.
00087 class CRBFForestLearner : public CSupervisedLearner
00088 {
00089         protected:
                      // Tree function handed to the constructor; ownership is not
                      // stated in this header.
00090                 CRegressionTreeFunction *treeFunction;
00091 
00092         public:
                      // Forest settings (numTrees, K, n_min, treshold) plus kNN,
                      // varMult and minVar.
                      // NOTE(review): presumably neighbour count, variance multiplier
                      // and a lower bound on the variance - TODO confirm.
00093                 CRBFForestLearner(CRegressionTreeFunction *treeFunction, int numTrees, int kNN, int K, int n_min, double treshold, double varMult, double minVar);
00094                 virtual ~CRBFForestLearner();
00095 
                      // Learning entry point: input samples and 1-D targets.
00096                 virtual void learnFA(CDataSet *input, CDataSet1D *output);
00097 };
00098 
00099 
// CLocalLinearLearner: a CSupervisedLearner that keeps pointers to a
// CRegressionTreeFunction, an input data set, a 1-D output data set and a
// CDataPreprocessor, and declares learnFA().
// NOTE(review): presumably fits local linear (polynomial) models on the kNN
// nearest samples - inferred from the names only; confirm.
00100 class CLocalLinearLearner : public CSupervisedLearner
00101 {
00102         protected:
                      // Tree function handed to the constructor; ownership is not
                      // stated in this header.
00103                 CRegressionTreeFunction *treeFunction;
00104 
                      // NOTE(review): the constructor does not receive these, so they
                      // are presumably created internally or set in learnFA() - confirm
                      // who owns and frees them.
00105                 CDataSet *inputData;
00106                 CDataSet1D *outputData;
00107 
00108                 CDataPreprocessor *preprocessor;
00109         public:
                      // treeFunction plus kNN and degree.
                      // NOTE(review): presumably neighbour count and polynomial degree
                      // - TODO confirm.
00110                 CLocalLinearLearner(CRegressionTreeFunction *treeFunction, int kNN, int degree);
00111                 virtual ~CLocalLinearLearner();
00112 
                      // Learning entry point: input samples and 1-D targets.
00113                 virtual void learnFA(CDataSet *input, CDataSet1D *output);
00114 };
00115 
00116 
// CLocalRBFLearner: a CSupervisedLearner with the same data members as
// CLocalLinearLearner (tree function, input/output data sets, preprocessor);
// its constructor takes a double varMult where that class takes an int degree.
// NOTE(review): presumably local RBF regression over the kNN nearest samples -
// inferred from the names only; confirm.
00117 class CLocalRBFLearner : public CSupervisedLearner
00118 {
00119         protected:
                      // Tree function handed to the constructor; ownership is not
                      // stated in this header.
00120                 CRegressionTreeFunction *treeFunction;
00121 
                      // NOTE(review): the constructor does not receive these, so they
                      // are presumably created internally or set in learnFA() - confirm
                      // who owns and frees them.
00122                 CDataSet *inputData;
00123                 CDataSet1D *outputData;
00124 
00125                 CDataPreprocessor *preprocessor;
00126         public:
                      // treeFunction plus kNN and varMult.
                      // NOTE(review): presumably neighbour count and a variance
                      // multiplier for the basis functions - TODO confirm.
00127                 CLocalRBFLearner(CRegressionTreeFunction *treeFunction, int kNN, double varMult);
00128                 virtual ~CLocalRBFLearner();
00129 
                      // Learning entry point: input samples and 1-D targets.
00130                 virtual void learnFA(CDataSet *input, CDataSet1D *output);
00131 };
00132 
// CUnknownDataQFunction: a CAbstractQFunction that keeps three maps keyed by
// CAction* (k-d trees, nearest-neighbour objects, preprocessors) plus an episode
// history, a batch data generator and a distance vector.
// NOTE(review): presumably the value reflects how far a state lies from the
// logged data for an action - inferred from the names only; confirm.
// NOTE(review): std::map is used here but <map> is not included directly by
// this header; presumably it arrives via one of the project includes - confirm.
00133 class CUnknownDataQFunction : public CAbstractQFunction
00134 {
00135 protected:
              // State properties handed to the constructor.
00136         CStateProperties *properties;
              // Maps are held by pointer. NOTE(review): presumably one entry per
              // action, (re)built by recalculateTrees() - confirm.
00137         std::map<CAction *, CKDTree *> *treeMap;
00138         std::map<CAction *, CKNearestNeighbors *> *nnMap;
00139 
              // Disabled member, kept as found.
00140 //      std::map<CAction *, ColumnVector *> *bufferMap;
00141         std::map<CAction *, CDataPreprocessor *> *preMap;
00142 
              // Episode history handed to the constructor.
00143         CEpisodeHistory *logger;
00144 
00145         CBatchQDataGenerator *dataGenerator;
              // NOTE(review): presumably a scratch buffer for neighbour distances
              // passed to getUnknownDataValue() - confirm.
00146         ColumnVector *distVector;
00147 
              // NOTE(review): presumably releases the contents of the maps above -
              // confirm in the implementation.
00148         void clearMaps();
00149 public:
00150 
              // Takes the action set, an episode history, state properties and a
              // double `factor` whose role is not visible in this header.
00151         CUnknownDataQFunction(CActionSet *actions, CEpisodeHistory *logger, CStateProperties *properties, double factor);
00152 
00153         virtual ~CUnknownDataQFunction();
00154 
              // Value for a state/action pair; `data` defaults to NULL. Default
              // arguments of virtual functions are chosen by the static type at the
              // call site, so overrides should repeat the same default.
00155         virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL);
00156 
              // Non-virtual. NOTE(review): presumably rebuilds treeMap/nnMap/preMap
              // from the logged episodes - confirm.
00157         void recalculateTrees();
00158 
              // Virtual hook that maps a vector of distances to a double.
00159         virtual double getUnknownDataValue(ColumnVector *distances);
00160 
00161         virtual void onParametersChanged();
00162 
00163         virtual void resetData();
00164 };
00165 
00166 
// CUnknownDataQFunctionFromLocalRBFRegression: a CAbstractQFunction that keeps
// a pointer to a map from CAction* to CRegressionTreeVFunction*, passed in
// through the constructor.
// NOTE(review): presumably getValue() looks up the regression function of the
// given action - the definition is not in this header; confirm.
00167 class CUnknownDataQFunctionFromLocalRBFRegression : public CAbstractQFunction
00168 {
00169 protected:
              // Map handed to the constructor; ownership is not stated here.
00170         std::map<CAction *, CRegressionTreeVFunction *> *regressionMap;
00171 public:
              // Public flag; its effect is defined outside this header.
00172         bool recalculateFactors;
00173 
              // Takes the action set, the per-action regression map and a double
              // `factor` whose role is not visible in this header.
00174         CUnknownDataQFunctionFromLocalRBFRegression(CActionSet *actions, std::map<CAction *, CRegressionTreeVFunction *> *regressionMap, double factor);
00175 
00176         virtual ~CUnknownDataQFunctionFromLocalRBFRegression();
00177 
              // Value for a state/action pair; `data` defaults to NULL. Overrides
              // should repeat the same default, because default arguments are bound
              // by the static type at the call site.
00178         virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL);
00179 };
00180 
00181 
00182 
00183 /*
00184 class CTreeBatchPolicyEvaluation : public CPolicyEvaluation
00185 {
00186 protected:
00187        CAgentController *estimationPolicy;     
00188        
00189        CTreeTrainer *treeTrainer;
00190 
00191        CEpisodeHistory *episodeHistory;
00192        CRewardHistory *rewardLogger;
00193 
00194        CDataCollector *dataCollector;
00195 
00196        virtual double getValue(CStateCollection *state, CAction *action) = 0;
00197        virtual void addInput(CStateCollection *state, CAction *action, double output) = 0;
00198 
00199        virtual void trainTree() = 0;
00200        virtual void resetPolicyEvaluation() = 0;
00201 public:
00202        CTreeBatchPolicyEvaluation(CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CTreeTrainer *treeTrainer);
00203 
00204        virtual ~CTreeBatchPolicyEvaluation();
00205 
00206        virtual void doEvaluationTrial();
00207        virtual void evaluatePolicy(int trials);
00208 
00209        virtual void setDataCollector(CDataCollector *dataCollector);
00210 };
00211 
00212 class CCAQTreeBatchPolicyEvaluation : public CTreeBatchPolicyEvaluation
00213 {
00214 protected:
00215        CDataSet *inputData;
00216        CDataSet1D *outputData;
00217 
00218        ColumnVector *buffVector;
00219 
00220        CRegressionTreeQFunction *qFunction;
00221 
00222        virtual double getValue(CStateCollection *state, CAction *action);
00223        virtual void addInput(CStateCollection *state, CAction *action, double output);
00224 
00225        virtual void trainTree();
00226        virtual void resetPolicyEvaluation();
00227 public:
00228        CCAQTreeBatchPolicyEvaluation(CRegressionTreeQFunction *qFunction, CAgentController *estimationPolicy, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CTreeTrainer *treeTrainer);
00229 
00230        virtual ~CCAQTreeBatchPolicyEvaluation();
00231 };
00232 
00233 
00234 class CQTreeBatchPolicyEvaluation : public CTreeBatchPolicyEvaluation
00235 {
00236 protected:
00237        ColumnVector *buffVector;
00238 
00239        CQFunction *qFunction;
00240 
00241        std::map<CAction *, CRegressionTreeFunction *> *functionMap;
00242        std::map<CAction *, CDataSet *> *inputMap;
00243        std::map<CAction *, CDataSet1D *> *outputMap;
00244 
00245        virtual double getValue(CStateCollection *state, CAction *action);
00246        virtual void addInput(CStateCollection *state, CAction *action, double output);
00247 
00248        virtual void trainTree();
00249        virtual void resetPolicyEvaluation();
00250 public:
00251        CQTreeBatchPolicyEvaluation(CQFunction *qFunction, std::map<CAction *, CRegressionTreeFunction *> *functionMap, CAgentController *estimationPolicy, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CTreeTrainer *treeTrainer);
00252 
00253        virtual ~CQTreeBatchPolicyEvaluation();
00254 };
00255 */
00256 #endif
00257