Reinforcement Learning Toolbox 2.0
last updated:
General
Documentation
Manual
Tutorial
Class Reference
Master Thesis
Examples
Related Papers
Downloads
Links
News
mailto:webmaster
Main Page     Class Hierarchy   Compound List   File List   Compound Members   File Members

cbatchlearning.h

Go to the documentation of this file.
00001 // Copyright (C) 2003
00002 // Gerhard Neumann (gneumann@gmx.net)
00003 // Stephan Neumann (sneumann@gmx.net) 
00004 //                
00005 // This file is part of RL Toolbox.
00006 // http://www.igi.tugraz.at/ril_toolbox
00007 //
00008 // All rights reserved.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions
00012 // are met:
00013 // 1. Redistributions of source code must retain the above copyright
00014 //    notice, this list of conditions and the following disclaimer.
00015 // 2. Redistributions in binary form must reproduce the above copyright
00016 //    notice, this list of conditions and the following disclaimer in the
00017 //    documentation and/or other materials provided with the distribution.
00018 // 3. The name of the author may not be used to endorse or promote products
00019 //    derived from this software without specific prior written permission.
00020 // 
00021 // THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
00022 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
00023 // OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
00024 // IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
00025 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
00026 // NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
00027 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
00028 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
00030 // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031 
00032 
00033 #ifndef C_BATCHLEARNING__
00034 #define C_BATCHLEARNING__
00035 
00036 
00037 #include "cagentcontroller.h"
00038 #include "cparameters.h"
00039 #include "cagent.h"
00040 
00041 #include "newmat/newmat.h"
00042 
00043 class CRewardHistory;
00044 class CEpisodeHistory;
00045 class CRewardLogger;
00046 class CAgentLogger;
00047 
00048 class CContinuousActionQFunction;
00049 class CPolicySameStateEvaluator;
00050 class CGradientUpdateFunction;
00051 class CSemiMDPListener;
00052 class CLSTDLambda;
00053 class CSemiMDPRewardListener;
00054 
00055 class CFeatureCalculator;
00056 class CFeatureVFunction;
00057 class CDataSet;
00058 class CDataSet1D;
00059 class CDataSubset;
00060 
00061 class CAbstractVFunction;
00062 class CGradientVFunction;
00063 class CQFunction; 
00064         
00065 class CSupervisedLearner;
00066 class CSupervisedWeightedLearner;
00067 class CSupervisedQFunctionLearner;
00068 class CSupervisedQFunctionWeightedLearner;
00069 
00070 class CStochasticPolicy;
00071 
00072 class CLearnDataObject;
00073 
00074 class CContinuousActionQFunction;
00075 class CStateProperties;
00076 
00077 class CKDTree;
00078 class CKNearestNeighbors;
00079 class CDataPreprocessor;
00080 
00081 class CSamplingBasedTransitionModel;
00082 class CSamplingBasedGraph;
00083 class CActionDistribution;
00084 
00085 class CAction;
00086 class CActionSet;
00087 
00088 class CState;
00089 class CStateCollection;
00090 class CStateProperties;
00091 
00092 class CStep;
00093 class CNewFeatureCalculator;
00094 class CAbstractQFunction;
00095 
00096 
// CBatchLearningPolicy: a controller derived from CDeterministicController that
// keeps a stored action (nextAction) and an action data set.
// NOTE(review): presumably getNextAction() returns the action last handed to
// setAction() — confirm against the implementation file.
00097 class CBatchLearningPolicy : public CDeterministicController
00098 {
00099         protected:
                      // Raw pointers; ownership/lifetime is not visible from this header.
00100                 CActionDataSet *actionDataSet;
00101                 CAction *nextAction;
00102         public:
                      // actions: the action set this policy chooses from.
00103                 CBatchLearningPolicy(CActionSet *actions);
                      // NOTE(review): no 'virtual' keyword here, unlike most classes in this header;
                      // the destructor is virtual only if a base class declares it so — verify.
00104                 ~CBatchLearningPolicy();
00105                 
                      // actionData is optional (defaults to NULL).
00106                 virtual CAction *getNextAction(CStateCollection *state, CActionDataSet *actionData = NULL);
00107                 
                      // Supplies an action together with its action data.
00108                 virtual void setAction(CAction *action, CActionData *data);
00109 };
00110 
00111 
// CPolicyEvaluation: abstract base class for policy evaluation algorithms.
// Concrete classes must implement evaluatePolicy(int).
// Note on overloads: a derived class that declares only evaluatePolicy(int)
// hides the no-argument evaluatePolicy() for calls made through the derived
// type (C++ name hiding); call it through a CPolicyEvaluation pointer/reference
// or add a using-declaration in the derived class.
00112 class CPolicyEvaluation : public CParameterObject
00113 {
00114         protected:
00115                 
00116         public:
                      // maxEvaluations defaults to 100.
                      // NOTE(review): presumably stored as a parameter and used by the
                      // no-argument evaluatePolicy() — confirm in the .cpp.
00117                 CPolicyEvaluation(int maxEvaluations = 100);
00118                 virtual ~CPolicyEvaluation();
00119                 
                      // Convenience overload without an explicit evaluation count.
00120                 virtual void evaluatePolicy();
                      // Pure virtual: perform the evaluation with the given count.
00121                 virtual void evaluatePolicy(int numEvaluations) = 0;
00122 
                      // Default implementation is empty; override to discard learned data.
00123                 virtual void resetLearnData() {};
00124 };
00125 
// CPolicyEvaluationGradientFunction: policy evaluation that works on a
// CGradientUpdateFunction. Still abstract: evaluatePolicy(int) is re-declared
// pure virtual below.
00126 class CPolicyEvaluationGradientFunction : public CPolicyEvaluation
00127 {
00128 protected:
                      // The function being learned and a raw buffer of weights.
                      // NOTE(review): buffer size and ownership are not visible here — check the .cpp.
00129                 CGradientUpdateFunction *learnData;
00130                 double *oldWeights;
00131                 
00132                 
                      // Note: the parameter name shadows the member 'oldWeights'.
                      // NOTE(review): presumably returns a distance between the given weights and
                      // the current weights of learnData — confirm.
00133                 virtual double getWeightDifference(double *oldWeights);
00134         public:
                      // Public flag. NOTE(review): looks like it controls whether learnData is
                      // reset before evaluation — confirm.
00135                 bool resetData;
00136         
                      // 'treshold' (sic, "threshold") defaults to 0.1, maxEvaluations to 100.
00137                 CPolicyEvaluationGradientFunction(CGradientUpdateFunction *learnData, double treshold = 0.1, int maxEvaluations = 100);
00138                 virtual ~CPolicyEvaluationGradientFunction();
00139                 
00140                 
00141                 virtual void evaluatePolicy(int numEvaluations) = 0;
00142                 
00143                 virtual void resetLearnData();
00144 };
00145 
// COnlinePolicyEvaluation: concrete policy evaluation built from an agent, a
// CSemiMDPListener learner and an optional semi-MDP sender.
// NOTE(review): presumably runs the agent and lets 'learner' observe the steps
// — confirm in the .cpp.
// CSemiMDPSender is not forward-declared in this header; it must come from one
// of the included project headers.
00146 class COnlinePolicyEvaluation : public CPolicyEvaluationGradientFunction
00147 {
00148         protected:
00149                 CAgent *agent;
00150                 CSemiMDPListener *learner;
00151 
                      // Set through setSemiMDPSender(); not a constructor argument.
00152                 CSemiMDPSender *semiMDPSender;
00153         public:
00154         
                      // All three int arguments are required (no defaults).
00155                 COnlinePolicyEvaluation(CAgent *agent, CSemiMDPListener *learner, CGradientUpdateFunction *learnData, int maxEvaluationEpisodes, int numSteps, int checkWeightsPerEpisode);
00156                 virtual ~COnlinePolicyEvaluation();
00157                 
                      // Implements the pure virtual from the base class. Declaring only this
                      // overload hides the no-argument evaluatePolicy() on this type.
00158                 virtual void evaluatePolicy(int numEvaluations);
00159                 
00160                 void setSemiMDPSender(CSemiMDPSender *semiMDPSender);
00161 };
00162 
// CLSTDOnlinePolicyEvaluation: COnlinePolicyEvaluation specialised for a
// CLSTDLambda learner; it overrides the weight-difference computation and the
// reset of learned data.
00163 class CLSTDOnlinePolicyEvaluation : public COnlinePolicyEvaluation
00164 {
00165         protected:
                      // Typed pointer to the LSTD learner passed to the constructor.
00166                 CLSTDLambda *lstdLearner;
00167                 
00168                 virtual double getWeightDifference(double *oldWeights);
00169         public:
00170 
                      // Takes two int arguments where the base constructor takes three.
00171                 CLSTDOnlinePolicyEvaluation(CAgent *agent, CLSTDLambda *learner, CGradientUpdateFunction *learnData, int maxEvaluationSteps, int nSteps);
00172                 virtual ~CLSTDOnlinePolicyEvaluation();
00173 
00174                 virtual void resetLearnData();
00175                 
00176 };
00177 
00178 /*
00179 class COfflinePolicyEvaluation : public CPolicyEvaluation
00180 {
00181        protected:
00182                CSemiMDPListener *learner;
00183                CStepHistory *stepHistory;
00184                
00185                std::list<CStateModifier *> *modifiers;
00186        public:
00187 
00188                COfflinePolicyEvaluation(CStepHistory *stepHistory, CSemiMDPListener *learner, CGradientUpdateFunction *learnData, std::list<CStateModifier *> *l_modifiers, int maxEvaluationEpisodes);
00189                virtual ~COfflinePolicyEvaluation();
00190                
00191                
00192                virtual void evaluatePolicy(int numEvaluations);
00193 };
00194 
00195 class CLSTDOfflinePolicyEvaluation : public COfflinePolicyEvaluation
00196 {
00197        protected:
00198                CLSTDLambda *lstdLearner;
00199                
00200                virtual double getWeightDifference(double *oldWeights);
00201        public:
00202 
00203                CLSTDOfflinePolicyEvaluation(CStepHistory *stepHistory, CLSTDLambda *learner, CGradientVFunction *learnData, std::list<CStateModifier *> *l_modifiers);
00204                virtual ~CLSTDOfflinePolicyEvaluation();
00205 
00206                virtual void resetLearnData();
00207 };*/
00208 
// COfflineEpisodePolicyEvaluation: policy evaluation driven by a stored
// CEpisodeHistory (optionally with a CRewardHistory) instead of a live agent.
// Also derives from CSemiMDPSender.
// NOTE(review): presumably replays the stored episodes to 'learner' — confirm.
// NOTE(review): std::list is used but <list> is not included by this header;
// it relies on the included project headers providing it.
00209 class COfflineEpisodePolicyEvaluation : public CPolicyEvaluationGradientFunction, public CSemiMDPSender
00210 {
00211         protected:
00212                 CSemiMDPRewardListener *learner;
00213                 CEpisodeHistory *episodeHistory;
                      // Despite its name this member has type CRewardHistory*.
00214                 CRewardHistory *rewardLogger;
00215                 
00216                 std::list<CStateModifier *> *modifiers;
00217                 
                      // Set through setBatchLearningPolicy(); not a constructor argument.
00218                 CBatchLearningPolicy *policy;
00219         public:
00220 
00221 
                      // Two constructors: without and with an explicit reward history.
00222                 COfflineEpisodePolicyEvaluation(CEpisodeHistory *episodeHistory, CSemiMDPRewardListener *learner, CGradientUpdateFunction *learnData, std::list<CStateModifier *> *l_modifiers, int maxEvaluationEpisodes);
00223                 COfflineEpisodePolicyEvaluation(CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CSemiMDPRewardListener *learner, CGradientUpdateFunction *learnData, std::list<CStateModifier *> *l_modifiers, int maxEvaluationEpisodes);
00224                 virtual ~COfflineEpisodePolicyEvaluation();
00225 
00226                 void setBatchLearningPolicy(CBatchLearningPolicy *l_policy);
00227                 
00228                 virtual void evaluatePolicy(int numEvaluations);
00229 };
00230 
00231 
// CLSTDOfflineEpisodePolicyEvaluation: COfflineEpisodePolicyEvaluation
// specialised for a CLSTDLambda learner; overrides getWeightDifference() and
// resetLearnData().
00232 class CLSTDOfflineEpisodePolicyEvaluation : public COfflineEpisodePolicyEvaluation
00233 {
00234         protected:
00235                 CLSTDLambda *lstdLearner;
00236                 
00237                 virtual double getWeightDifference(double *oldWeights);
00238         public:
00239 
                      // learnData is a CGradientVFunction* here, whereas the base class
                      // constructors take a CGradientUpdateFunction*.
00240                 CLSTDOfflineEpisodePolicyEvaluation(CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CLSTDLambda *learner, CGradientVFunction *learnData, std::list<CStateModifier *> *l_modifiers, int episodes);
00241                 virtual ~CLSTDOfflineEpisodePolicyEvaluation();
00242 
00243                 virtual void resetLearnData();
00244 };
00245 
00246 
// CDataCollector: abstract interface with a single operation, collectData().
// Inherits parameter handling from CParameterObject.
00247 class CDataCollector : public CParameterObject
00248 {
00249         protected:
00250         
00251         public:
00252                 CDataCollector();
00253                 virtual ~CDataCollector();
00254                 
                      // Pure virtual: concrete collectors define how data is gathered.
00255                 virtual void collectData() = 0;
00256 
00257 };
00258 
00259 class CUnknownDataQFunction;
00260 
// CDataCollectorFromAgentLogger: CDataCollector implementation built from an
// agent, an agent logger and a reward logger.
// NOTE(review): presumably collectData() runs the agent for numEpisodes x
// numSteps while the loggers record — confirm in the .cpp.
00261 class CDataCollectorFromAgentLogger : public CDataCollector
00262 {
00263 protected:
00264         CAgentLogger *logger;
00265         CRewardLogger *rewardLogger;
00266         CAgent *agent;
00267 
00268         int numCollections;
              // Optional; set through setSemiMDPSender().
00269         CSemiMarkovDecisionProcess *sender;
00270 
              // Pointer to a list of Q-functions registered via addUnknownDataFunction().
00271         std::list<CUnknownDataQFunction *> *unknownDataQFunctions;
              // Optional; set through setController().
00272         CAgentController *controller;
00273 public:
00274         CDataCollectorFromAgentLogger(CAgent *agent, CAgentLogger *logger, CRewardLogger *rewardLogger, int numEpisodes, int numSteps);
00275         virtual ~CDataCollectorFromAgentLogger();
00276         
00277         virtual void collectData();
00278 
00279         virtual void setController(CAgentController *controller);
              // Takes a single Q-function; the parameter carries the same (plural)
              // name as the list member above.
00280         virtual void addUnknownDataFunction(CUnknownDataQFunction *unknownDataQFunctions);      
00281 
00282         void setSemiMDPSender(CSemiMarkovDecisionProcess *sender);
00283 };
00284 
00285 
00286 
// CPolicyIteration: combines a policy function, an evaluation function, a
// CPolicyEvaluation and an optional CDataCollector.
// Uses *virtual* inheritance from CParameterObject, unlike the other classes
// declared in this header.
// NOTE(review): presumably one step = (optional) data collection, policy
// evaluation, then transfer of the evaluation function to the policy function
// — confirm in the .cpp.
00287 class CPolicyIteration : virtual public CParameterObject
00288 {
00289 protected:
00290         CLearnDataObject *policyFunction;
00291         CLearnDataObject *evaluationFunction;
00292         
00293         CPolicyEvaluation *evaluation;
              // May be NULL (constructor default).
00294         CDataCollector *collector;
00295 public:
00296         
00297         CPolicyIteration(CLearnDataObject *policyFunction, CLearnDataObject *evaluationFunction, CPolicyEvaluation *evaluation, CDataCollector *collector = NULL);
00298         virtual ~CPolicyIteration();
00299 
00300         virtual void doPolicyIterationStep();
00301         virtual void initPolicyIteration();
00302 };
00303 
00304 
00305 
// CNewFeatureCalculatorDataGenerator: interface for objects that can compute
// new features and swap value functions. Has no data members and no base class.
00306 class CNewFeatureCalculatorDataGenerator 
00307 {
00308 protected:
00309 
00310 public:
00311         virtual ~CNewFeatureCalculatorDataGenerator() {}; 
00312         
              // Default initialisation simply delegates to calculateNewFeatures().
00313         virtual void initFeatures() {calculateNewFeatures();};
00314         virtual void calculateNewFeatures() = 0;
00315 
00316         virtual void swapValueFunctions() = 0;
00317 
              // Default implementation does nothing.
00318         virtual void resetData() {};
00319 };
00320 
// CPolicyIterationNewFeatures: CPolicyIteration variant that additionally
// holds a CNewFeatureCalculatorDataGenerator and overrides both iteration
// entry points.
// NOTE(review): presumably the overrides invoke the feature calculator around
// the base-class step — confirm in the .cpp.
00321 class CPolicyIterationNewFeatures : public CPolicyIteration
00322 {
00323 protected:
00324         CNewFeatureCalculatorDataGenerator *newFeatureCalculator;
00325 public:
00326 
              // Same arguments as CPolicyIteration plus newFeatureCalculator; collector stays optional.
00327         CPolicyIterationNewFeatures(CLearnDataObject *policyFunction, CLearnDataObject *evaluationFunction, CPolicyEvaluation *evaluation, CNewFeatureCalculatorDataGenerator *newFeatureCalculator, CDataCollector *collector = NULL);
00328         virtual ~CPolicyIterationNewFeatures();
00329 
00330         virtual void doPolicyIterationStep();
00331         virtual void initPolicyIteration();
00332 };
00333 
00334 /*
00335 class CResidualFunction;
00336 class CResidualGradientFunction;
00337 
00338 class CValueGradientCalculator : public CGradientCalculator
00339 {
00340        protected:
00341                CResidualFunction *residual;
00342                CResidualGradientFunction *residualGradientFunction;
00343                
00344                CEpisodeHistory *episodeHistory;
00345                CRewardHistory *rewardLogger;
00346                CRewardFunction *rewardFunction;
00347                
00348                virtual double getValue(CStateCollection *state, CAction *action) = 0;
00349                virtual void getValueGradient(CStateCollection *state, CAction *action, CFeatureList *gradient) = 0;
00350                        
00351                CBatchLearningPolicy *policy;
00352                CAgentController *estimationPolicy;
00353        public:
00354                CValueGradientCalculator(CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CResidualFunction *residual, CResidualGradientFunction *gradient);
00355                CValueGradientCalculator(CEpisodeHistory *episodeHistory, CRewardFunction *rewardFunction, CResidualFunction *residual, CResidualGradientFunction *gradient);
00356                virtual ~CValueGradientCalculator();
00357 
00358                virtual void getGradient(CFeatureList *gradient);
00359                virtual double getFunctionValue();
00360        
00361                virtual void resetGradientCalculator() {};
00362                
00363                virtual void setBatchPolicy(CBatchLearningPolicy *policy);
00364                virtual void setEstimationPolicy(CAgentController *estimationPolicy);
00365 };
00366 
00367 
00368 class CVResidualGradientCalculator : public CValueGradientCalculator
00369 {
00370        protected:
00371                CGradientVFunction *vFunction;
00372                
00373                virtual double getValue(CStateCollection *state, CAction *action);
00374                virtual void getValueGradient(CStateCollection *state, CAction *action, CFeatureList *gradient);
00375                
00376                
00377        public:
00378                CVResidualGradientCalculator(CGradientVFunction *vFunction, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CResidualFunction *residual, CResidualGradientFunction *gradient);
00379                CVResidualGradientCalculator(CGradientVFunction *vFunction, CEpisodeHistory *episodeHistory, CRewardFunction *rewardFunction, CResidualFunction *residual, CResidualGradientFunction *gradient);
00380                virtual ~CVResidualGradientCalculator();
00381 };
00382 
00383 
00384 class CQResidualGradientCalculator : public CValueGradientCalculator
00385 {
00386        protected:
00387                CGradientQFunction *qFunction;
00388                
00389                
00390                virtual double getValue(CStateCollection *state, CAction *action);
00391                virtual void getValueGradient(CStateCollection *state, CAction *action, CFeatureList *gradient);
00392                
00393                
00394        public:
00395                CQResidualGradientCalculator(CGradientQFunction *qFunction, CAgentController *estimationPolicy, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CResidualFunction *residual, CResidualGradientFunction *gradient);
00396                CQResidualGradientCalculator(CGradientQFunction *qFunction, CAgentController *estimationPolicy, CEpisodeHistory *episodeHistory, CRewardFunction *rewardFunction, CResidualFunction *residual, CResidualGradientFunction *gradient);
00397                virtual ~CQResidualGradientCalculator();
00398 };
00399 
00400 class CResidualGradientBatchLearner : public CPolicyEvaluationGradientFunction
00401 {
00402        protected:
00403                CGradientLearner *gradientLearner;
00404        public:
00405                CResidualGradientBatchLearner(CGradientLearner *gradientLearner, CGradientUpdateFunction *learnData, double treshold = 0.1, int maxEvaluations = 10);
00406                
00407                virtual ~CResidualGradientBatchLearner();
00408                
00409                virtual void evaluatePolicy(int numEvaluations);
00410                virtual void resetLearnData();
00411 };
00412 */
// CBatchDataGenerator: abstract interface for collecting (state, action,
// output, weighting) samples, training a function approximator from them and
// querying values.
00413 class CBatchDataGenerator : public CParameterObject
00414 {
00415         protected:
00416         public:
00417                 CBatchDataGenerator() {};
00418                 virtual ~CBatchDataGenerator() {};
00419         
                      // Adds one sample. Note: the default argument of a virtual function is
                      // taken from the static type of the call, so overriders should repeat
                      // 'weighting = 1.0' with the same value.
00420                 virtual void addInput(CStateCollection *state, CAction *action, double output, double weighting = 1.0) = 0;
00421 
00422                 virtual void trainFA() = 0;
00423                 virtual void resetPolicyEvaluation() = 0;
00424 
00425                 virtual double getValue(CStateCollection *state, CAction *action) = 0;
00426 
                      // Non-virtual helper. NOTE(review): presumably feeds the episodes of
                      // 'logger' through addInput() — confirm in the .cpp.
00427                 void generateInputData(CEpisodeHistory *logger);
00428 };
00429 
// CBatchVDataGenerator: CBatchDataGenerator implementation for a single
// CAbstractVFunction, holding one input data set, one output data set and one
// weighting data set.
00430 class CBatchVDataGenerator : public CBatchDataGenerator
00431 {
00432 protected:
00433         CDataSet *inputData;
00434         CDataSet1D *outputData;
00435         CDataSet1D *weightingData;
              // ColumnVector comes from the newmat library included by this header.
00436         ColumnVector *buffVector;
00437         
00438         CAbstractVFunction *vFunction;
00439         
              // Two learner pointers, matching the two public constructors below.
              // NOTE(review): presumably only one of them is non-NULL per instance — confirm.
00440         CSupervisedLearner *learner;
00441         CSupervisedWeightedLearner *weightedLearner;
00442 
              // Protected constructor without a V-function, available to subclasses.
00443         CBatchVDataGenerator(CSupervisedLearner *learner, int inputDim);
00444         
00445         
00446 public:
00447         CBatchVDataGenerator(CAbstractVFunction *vFunction, CSupervisedLearner *learner);
00448         CBatchVDataGenerator(CAbstractVFunction *vFunction, CSupervisedWeightedLearner *learner);
00449         virtual ~CBatchVDataGenerator();
00450 
00451         virtual void init(int numDim);  
00452 
00453         virtual void addInput(CStateCollection *state, CAction *action, double output, double weighting = 1.0);
00454 
00455         virtual void trainFA();
00456         virtual void resetPolicyEvaluation();   
00457 
00458         virtual double getValue(CStateCollection *state, CAction *action);
00459 
              // Accessors for the three data sets declared above.
00460         virtual CDataSet *getInputData();
00461         virtual CDataSet1D *getOutputData();
00462         virtual CDataSet1D *getWeighting();
00463 };
00464 
00465 
00466 
// CBatchCAQDataGenerator: CBatchVDataGenerator variant for a
// CContinuousActionQFunction together with the state properties it uses.
00467 class CBatchCAQDataGenerator : public CBatchVDataGenerator
00468 {
00469 protected:
00470         CContinuousActionQFunction *qFunction;
00471         CStateProperties *properties;
00472 
00473         
00474 public:
00475         CBatchCAQDataGenerator(CStateProperties *properties, CContinuousActionQFunction *qFunction, CSupervisedLearner *learner);
00476         virtual ~CBatchCAQDataGenerator();
00477         
              // NOTE(review): this 3-parameter signature does NOT override the
              // 4-parameter addInput(state, action, output, weighting) declared in the
              // base classes; it introduces a new virtual and hides the inherited one.
              // Calls made through a CBatchDataGenerator/CBatchVDataGenerator pointer
              // will therefore reach CBatchVDataGenerator::addInput. If an override was
              // intended, the declaration and its definition need the 'weighting'
              // parameter added together.
00478         virtual void addInput(CStateCollection *state, CAction *action, double output);
00479         
00480         virtual double getValue(CStateCollection *state, CAction *action);
00481 };
00482 
00483 
// CBatchQDataGenerator: CBatchDataGenerator implementation for a CQFunction
// that keeps separate input/output/weighting data sets per action (maps keyed
// by CAction*).
// NOTE(review): std::map is used but <map> is not included by this header; it
// relies on the included project headers providing it.
00484 class CBatchQDataGenerator : public CBatchDataGenerator
00485 {
00486 protected:
              // Per-action storage, keyed by action pointer (pointer identity).
00487         std::map<CAction *, CDataSet *> *inputMap;
00488         std::map<CAction *, CDataSet1D *> *outputMap;
00489         std::map<CAction *, ColumnVector *> *buffVectorMap;
00490         std::map<CAction *, CDataSet1D *> *weightedMap;
00491                 
00492 
00493         CQFunction *qFunction;
00494         CStateProperties *properties;
00495         CActionSet *actions;    
00496 
00497         CSupervisedQFunctionLearner *learner;
00498         CSupervisedQFunctionWeightedLearner *weightedLearner;
00499 public:
              // inputState is optional (defaults to NULL) in the first two constructors.
00500         CBatchQDataGenerator(CQFunction *qFunction, CSupervisedQFunctionLearner *learner, CStateProperties *inputState = NULL);
00501         CBatchQDataGenerator(CQFunction *qFunction, CSupervisedQFunctionWeightedLearner *learner, CStateProperties *inputState = NULL);
              // Constructor without Q-function or learner.
              // NOTE(review): trainFA()/getValue() behaviour for such an instance is not
              // visible here — check the .cpp before using it that way.
00502         CBatchQDataGenerator(CActionSet *actions, CStateProperties *properties);
00503 
              // Non-virtual initialisation helper taking the same kind of arguments as the constructors.
00504         void init(CQFunction *qFunction, CActionSet *actions, CStateProperties *properties);
00505 
00506         virtual ~CBatchQDataGenerator();
00507         
00508         virtual void addInput(CStateCollection *state, CAction *action, double output, double weighting = 1.0);
00509 
00510         virtual void trainFA();
00511         virtual void resetPolicyEvaluation();   
00512 
00513         virtual double getValue(CStateCollection *state, CAction *action);
00514 
              // Per-action accessors (non-virtual).
00515         CDataSet *getInputData(CAction *action);
00516         CDataSet1D *getOutputData(CAction *action);
00517 
00518         CStateProperties *getStateProperties(CAction *action);
00519 };
00520 
00521 
00522 
// CFittedIteration: CPolicyEvaluation implementation that works on a stored
// episode history and reward history and trains through a CBatchDataGenerator.
// NOTE(review): presumably each evaluation trial regenerates the training
// targets and calls the data generator's trainFA() — confirm in the .cpp.
00523 class CFittedIteration : public CPolicyEvaluation
00524 {
00525 protected:
              // Not a constructor argument of this class; subclasses may set it.
00526         CAgentController *estimationPolicy;
00527         CBatchDataGenerator *dataGenerator;
00528         
00529         CEpisodeHistory *episodeHistory;
              // Despite its name this member has type CRewardHistory*.
00530         CRewardHistory *rewardLogger;
00531 
              // Optional collaborators, set through the setters below.
00532         CDataCollector *dataCollector;
00533         CPolicyEvaluation *actorLearner;
00534 
              // 'nextHistoryActon' (sic, "Action") is optional, as is nextReward.
00535         virtual void addResidualInput(CStep *step, CAction *action, double oldV, double newV, double nearestNeighborDistance, CAction *nextHistoryActon = NULL, double nextReward = 0.0);
00536         
00537         CPolicyEvaluation *initialPolicyEvaluation;
00538 
00539         virtual double getWeighting(CStateCollection *state, CAction *action);
00540 
00541         virtual double getValue(CStateCollection *state, CAction *action);
00542 
              // Stored as int. NOTE(review): looks like a switch or mode selector —
              // confirm the accepted values in the .cpp.
00543         int useResidualAlgorithm;
00544 
00545         virtual void onParametersChanged();
00546 public:
00547         
00548 
00549         CFittedIteration(CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CBatchDataGenerator *dataGenerator);
00550 
00551         virtual ~CFittedIteration();
00552 
00553         
00554 
00555         virtual void doEvaluationTrial();
00556         virtual void evaluatePolicy(int trials);
00557 
00558         virtual CBatchDataGenerator *createTrainingsData();
00559 
00560         virtual void setDataCollector(CDataCollector *dataCollector);
00561 
00562         virtual void setInitialPolicyEvaluation(CPolicyEvaluation *initialPolicyEvaluation);
00563 
00564         virtual void resetLearnData();
00565 
00566         void setActorLearner(CPolicyEvaluation *actorLearner);
              // Re-declared here, so both evaluatePolicy overloads are visible on this type.
00567         virtual void evaluatePolicy();
00568 };
00569 
// CFittedCAQIteration: CFittedIteration variant constructed from a
// CContinuousActionQFunction; adds only a constructor and destructor.
// NOTE(review): presumably the constructor creates a CBatchCAQDataGenerator
// for the base class — confirm in the .cpp.
00570 class CFittedCAQIteration : public CFittedIteration
00571 {
00572 protected:
00573         
00574         
00575 public:
00576         CFittedCAQIteration(CContinuousActionQFunction *qFunction, CStateProperties *properties, CAgentController *estimationPolicy, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CSupervisedLearner *learner);
00577 
00578         virtual ~CFittedCAQIteration();
00579 };
00580 
00581 
// CFittedVIteration: CFittedIteration variant constructed from a
// CAbstractVFunction, with an optional stochastic estimation policy used for
// weighting.
00582 class CFittedVIteration : public CFittedIteration
00583 {
00584 protected:
              // NOTE(review): this member shadows CFittedIteration::estimationPolicy
              // (a CAgentController*). The two are distinct objects; base-class code
              // keeps using its own member, so make sure both are initialised as intended.
00585         CStochasticPolicy *estimationPolicy;
00586 
00587         virtual double getWeighting(CStateCollection *state, CAction *action);
00588 
              // Raw buffers. NOTE(review): size and ownership are not visible here — check the .cpp.
00589         double *actionProbabilities;
00590         CActionSet *availableActions;
00591 public:
              // Unweighted learner, no estimation policy.
00592         CFittedVIteration(CAbstractVFunction *vFunction, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CSupervisedLearner *learner);
00593 
00594 
              // Weighted learner together with the stochastic estimation policy.
00595         CFittedVIteration(CAbstractVFunction *vFunction, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CSupervisedWeightedLearner *learner, CStochasticPolicy *estimationPolicy);
00596 
00597         virtual ~CFittedVIteration();
00598 
00599         
00600 };
00601 
// CFittedQIteration: CFittedIteration variant for a CQFunction. Keeps
// per-action data sets, kd-trees, nearest-neighbour searchers and data
// preprocessors (maps keyed by CAction*), and overrides addResidualInput()
// and doEvaluationTrial().
00602 class CFittedQIteration : public CFittedIteration
00603 {
00604 protected:
00605         std::map<CAction *, CDataSet *> *inputDatas;
00606         std::map<CAction *, CDataSet1D *> *outputDatas;
00607         std::map<CAction *, CKDTree *> *kdTrees;
00608         std::map<CAction *, CKNearestNeighbors *> *nearestNeighbors;
00609         std::map<CAction *, CDataPreprocessor *> *dataPreProc;
00610 
00611         std::list<int> *neighborsList;
00612 
              // Optional (constructor default NULL).
00613         CStateProperties *residualProperties;
00614 
              // NOTE(review): presumably the number of neighbours used in the
              // nearest-neighbour queries — confirm.
00615         int kNN;
00616 
00617         CState *buffState;
00618 
00619         virtual void addResidualInput(CStep *step, CAction *action, double oldV, double newV, double nearestNeighborDistance, CAction *nextHistoryActon = NULL, double nextReward = 0.0);
00620         
00621 public:
              // Two constructors; the second additionally takes explicit input state properties.
00622         CFittedQIteration(CQFunction *qFunction, CAgentController *estimationPolicy, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CSupervisedQFunctionLearner *learner, CStateProperties *residualProperties = NULL);
00623 
00624         CFittedQIteration(CQFunction *qFunction, CStateProperties *inputState, CAgentController *estimationPolicy, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CSupervisedQFunctionLearner *learner, CStateProperties *residualProperties = NULL);
00625 
00626 
00627         virtual ~CFittedQIteration();
00628 
00629         virtual void doEvaluationTrial();
00630 };
00631 
00632 
00633 
// CFittedQNewFeatureCalculator: combines CFittedQIteration with the
// CNewFeatureCalculatorDataGenerator interface (multiple inheritance) and
// implements calculateNewFeatures(), swapValueFunctions() and resetData().
// Note that resetData() (from the feature-generator interface) and
// resetLearnData() (from the policy-evaluation side) are different functions.
00634 class CFittedQNewFeatureCalculator : public CFittedQIteration, public CNewFeatureCalculatorDataGenerator
00635 {
00636 protected:
              // Two Q-functions: one for estimation, one for the policy.
00637         CQFunction *qFunction;
00638         CQFunction *qFunctionPolicy;
00639 
              // Per-action feature calculators for each of the two Q-functions.
00640         std::map<CAction *, CFeatureCalculator *> *policyCalculator;
00641         std::map<CAction *, CFeatureCalculator *> *estimationCalculator;
00642 
00643         CNewFeatureCalculator *newFeatureCalc;
              // Named 'agent' but typed CStateModifiersObject*; set via setStateModifiersObject().
00644         CStateModifiersObject *agent;
00645 public:
              // No supervised learner argument here, unlike the CFittedQIteration constructors.
              // The parameter is 'newFeatCalc'; the member is 'newFeatureCalc'.
00646         CFittedQNewFeatureCalculator(CQFunction *qFunction, CQFunction *qFunctionPolicy, CStateProperties *inputState, CAgentController *estimationPolicy, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CNewFeatureCalculator *newFeatCalc);
00647         
00648         virtual ~CFittedQNewFeatureCalculator();
00649 
00650         virtual void calculateNewFeatures();
00651 
00652         virtual void swapValueFunctions();
00653 
00654         void clearCalculators();
00655         virtual void resetData();
00656 
00657         void setStateModifiersObject(CStateModifiersObject *agent);
00658 
00659 };
00660 
00661 class CActionDistribution;
00662 
// CContinuousDynamicProgramming: abstract CPolicyEvaluation built on a
// CSamplingBasedTransitionModel and an action set. Subclasses supply
// getValue(), updateOutputs(), learn() and resetDynamicProgramming().
00663 class CContinuousDynamicProgramming : public CPolicyEvaluation
00664 {
00665 protected:
00666         CSamplingBasedTransitionModel *transModel;
00667         
00668         CActionSet *availableActions;
00669         
00670 
              // Raw buffers. NOTE(review): presumably one entry per action — confirm
              // allocation and size in the .cpp.
00671         double *actionValues;
00672         double *actionProbabilities;
00673 
00674         int numIteration;
00675 public:
00676         CContinuousDynamicProgramming(CActionSet *allActions, CSamplingBasedTransitionModel *transModel);
00677 
00678         virtual ~CContinuousDynamicProgramming();
00679 
00680         virtual void evaluatePolicy(int numEvaluations);
00681 
              // The parameters 'availableActions' and 'actionValues' shadow the members
              // of the same name.
00682         virtual double getValueFromDistribution(CActionSet *availableActions, double *actionValues, CActionDistribution *distribution); 
00683 
00684         virtual double getValue(CState *state, CActionSet *availableActions) = 0;
00685         virtual void updateOutputs(int index, CActionSet *availableActions, double *actionValues) = 0;
00686 
00687         virtual void learn() = 0;
00688 
00689         virtual void resetLearnData();
00690         virtual void resetDynamicProgramming() = 0;
00691 };
00692 
// CContinuousDynamicVProgramming: concrete CContinuousDynamicProgramming for a
// CAbstractVFunction trained by a CSupervisedLearner; implements all four pure
// virtuals of the base class.
00693 class CContinuousDynamicVProgramming : public CContinuousDynamicProgramming
00694 {
00695 protected:
              // Single 1-D output data set.
00696         CDataSet1D *outputValues;
00697         CActionDistribution *distribution;
00698 
00699         CSupervisedLearner *learner;
00700         CAbstractVFunction *vFunction;
00701 public:
00702         CContinuousDynamicVProgramming(CActionSet *allActions, CActionDistribution *distribution, CSamplingBasedTransitionModel *transModel, CAbstractVFunction *vFunction, CSupervisedLearner *learner);
00703 
00704         virtual ~CContinuousDynamicVProgramming();
00705 
00706         virtual double getValue(CState *state, CActionSet *availableActions);
00707         virtual void updateOutputs(int index, CActionSet *availableActions, double *actionValues);
00708 
00709         virtual void learn();
00710 
00711         virtual void resetDynamicProgramming();
00712 };
00713 
// CContinuousDynamicQProgramming: concrete CContinuousDynamicProgramming for a
// CAbstractQFunction trained by a CSupervisedQFunctionLearner; keeps input and
// output data sets per action.
00714 class CContinuousDynamicQProgramming : public CContinuousDynamicProgramming
00715 {
00716 protected:
00717         std::map<CAction *, CDataSet1D *> *outputValues;
00718         std::map<CAction *, CDataSet *> *inputValues;
00719         
00720         CSupervisedQFunctionLearner *learner;
00721         CAbstractQFunction *qFunction;
00722 
00723         CActionDistribution *distribution;
              // Second raw value buffer in addition to the base class's actionValues.
              // NOTE(review): purpose and size are not visible here — check the .cpp.
00724         double *actionValues2;
00725 public:
              // The fourth parameter is named 'vFunction' but is a CAbstractQFunction*.
00726         CContinuousDynamicQProgramming(CActionSet *allActions, CActionDistribution *distribution, CSamplingBasedTransitionModel *transModel, CAbstractQFunction *vFunction, CSupervisedQFunctionLearner *learner);
00727 
00728         virtual ~CContinuousDynamicQProgramming();
00729 
00730         virtual double getValue(CState *state, CActionSet *availableActions);
00731         virtual void updateOutputs(int index, CActionSet *availableActions, double *actionValues);
00732 
00733         virtual void learn();
00734 
00735         virtual void resetDynamicProgramming();
00736 };
00737 
// CContinuousMCQEvaluation: CContinuousDynamicQProgramming variant that takes a
// CPolicySameStateEvaluator instead of a Q-function and overrides getValue().
// NOTE(review): presumably getValue() obtains the state value from 'evaluator'
// — confirm in the .cpp.
00738 class CContinuousMCQEvaluation : public CContinuousDynamicQProgramming
00739 {
00740 protected:
00741         CPolicySameStateEvaluator *evaluator;
00742 public:
00743         CContinuousMCQEvaluation(CActionSet *allActions, CActionDistribution *distribution, CSamplingBasedTransitionModel *transModel, CPolicySameStateEvaluator *evaluator, CSupervisedQFunctionLearner *learner);
00744 
00745         virtual ~CContinuousMCQEvaluation();
00746 
00747         virtual double getValue(CState *state, CActionSet *availableActions);
00748 };
00749 
00750 
00751 class CGraphTransition;
00752 class DataSubset;
00753 class CKDRectangle;
00754 
// CGraphDynamicProgramming: CPolicyEvaluation implementation operating on a
// CSamplingBasedGraph, with one 1-D data set of output values.
// NOTE(review): presumably outputValues holds one value per graph node —
// confirm in the .cpp.
00755 class CGraphDynamicProgramming : public CPolicyEvaluation
00756 {
00757 protected:
00758         CSamplingBasedGraph *transModel;
00759         CDataSet1D *outputValues;
00760         
00761         
00762 public:
              // Public flag. NOTE(review): looks like it controls whether the graph is
              // rebuilt on reset — confirm.
00763         bool resetGraph;
00764 
00765         CGraphDynamicProgramming(CSamplingBasedGraph *transModel);
00766 
00767         virtual ~CGraphDynamicProgramming();
00768 
00769         virtual void evaluatePolicy(int numEvaluations);
00770 
00771         virtual void resetLearnData();
00772 
              // Value lookup by node index or by input vector.
00773         virtual double getValue(int node);
00774         virtual double getValue(ColumnVector *input);
00775 
              // maxValue is an output parameter (non-const reference); range is optional.
00776         virtual CGraphTransition *getMaxTransition(int index, double &maxValue, CKDRectangle *range = NULL); 
00777 
              // node and distance are output parameters (non-const references).
00778         virtual void getNearestNode(ColumnVector *input, int &node, double &distance);
00779 
00780         CSamplingBasedGraph *getGraph();
00781 
00782         CDataSet1D *getOutputValues();
00783 
              // NOTE(review): 'string' is unqualified, so this only compiles when a
              // using-declaration/directive for std is in effect via the included
              // headers; std::string would make the header self-sufficient.
              // NOTE(review): the type is 'DataSubset' (forward-declared just above this
              // class) whereas the header also forward-declares 'CDataSubset' — verify
              // which one is intended.
00784         virtual void saveCSV(string filename, DataSubset *nodeSubset);
00785 };
00786 
00787 class CGraphTarget;
00788 class CAdaptiveTargetGraph;
00789 
// CGraphAdaptiveTargetDynamicProgramming: CGraphDynamicProgramming variant
// working on a CAdaptiveTargetGraph with several CGraphTarget objects, one of
// which is the current target; keeps a CDataSet1D per target.
00790 class CGraphAdaptiveTargetDynamicProgramming : public CGraphDynamicProgramming
00791 {
00792 protected:
              // Per-target value data sets and the list of registered targets.
00793         std::map<CGraphTarget *, CDataSet1D *> *targetMap;
00794         std::list<CGraphTarget *> *targets;     
00795 
00796         CGraphTarget *currentTarget;
00797         CAdaptiveTargetGraph *adaptiveTargetGraph;
00798 public:
00799         CGraphAdaptiveTargetDynamicProgramming(CAdaptiveTargetGraph *graph);
00800 
              // Declared without the 'virtual' keyword; it is still virtual because
              // the base class destructor is declared virtual.
00801         ~CGraphAdaptiveTargetDynamicProgramming();
00802 
00803         virtual CGraphTransition *getMaxTransition(int index, double &maxValue, CKDRectangle *range = NULL); 
00804 
00805         virtual void addTarget(CGraphTarget *target);
00806         virtual CGraphTarget *getTargetForState(CStateCollection *state);
00807         
00808         virtual void setCurrentTarget(CGraphTarget *target);
00809 
              // Indexed access to the registered targets.
00810         int getNumTargets();
00811         CGraphTarget *getTarget(int index);
00812 
00813         virtual void resetLearnData();
00814 };
00815 
00816 
00817 
00818 #endif
00819