Reinforcement Learning Toolbox 2.0
last updated:
General
Documentation
Manual
Tutorial
Class Reference
Master Thesis
Examples
Related Papers
Downloads
Links
News
mailto:webmaster
Main Page     Class Hierarchy   Compound List   File List   Compound Members   File Members

ctdlearner.h

Go to the documentation of this file.
00001 // Copyright (C) 2003
00002 // Gerhard Neumann (gneumann@gmx.net)
00003 // Stephan Neumann (sneumann@gmx.net) 
00004 //                
00005 // This file is part of RL Toolbox.
00006 // http://www.igi.tugraz.at/ril_toolbox
00007 //
00008 // All rights reserved.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions
00012 // are met:
00013 // 1. Redistributions of source code must retain the above copyright
00014 //    notice, this list of conditions and the following disclaimer.
00015 // 2. Redistributions in binary form must reproduce the above copyright
00016 //    notice, this list of conditions and the following disclaimer in the
00017 //    documentation and/or other materials provided with the distribution.
00018 // 3. The name of the author may not be used to endorse or promote products
00019 //    derived from this software without specific prior written permission.
00020 // 
00021 // THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
00022 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
00023 // OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
00024 // IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
00025 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
00026 // NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
00027 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
00028 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
00030 // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031 
00032 #ifndef CABSTRACTRILEARNER_H
00033 #define CABSTRACTRILEARNER_H
00034 
00035 #include "cerrorlistener.h"
00036 #include "cagentlistener.h"
00037 #include "cbaseobjects.h"
00038 
// Forward declarations. Each type named here is used in this header only
// through pointers (members, parameters, return types), so the complete
// class definitions are not needed at this point.
00039 class CAgentController;
00040 class CDeterministicController;
00041 class CAbstractQFunction;
00042 class CAbstractQETraces;
00043 class CActionDataSet;
00044 class CGradientQFunction;
00045 class CGradientQETraces;
00046 class CResidualFunction;
00047 class CResidualGradientFunction;
00048 class CFeatureList;
00049 class CAbstractBetaCalculator;
00050 class CFeatureQFunction;
00051 
00053 
// CTDLearner derives from CSemiMDPRewardListener and CErrorSender and is the
// base class of CQLearner, CSarsaLearner and CTDGradientLearner in this file.
// It holds pointers to a Q-function, an e-traces object, an estimation policy
// and an action data set. Only declarations appear here; every member
// function is defined outside this header.
// NOTE(review): the names suggest a temporal-difference learner for
// Q-functions - confirm against the implementation file.
00085 class CTDLearner : public CSemiMDPRewardListener, public CErrorSender
00086 {
00087   protected:
00088 
              // NOTE(review): presumably true when the e-traces object was passed
              // in through the four-argument constructor rather than created by
              // this class - confirm, and confirm who deletes `etraces`.
00090         bool externETraces;
00091 
              // NOTE(review): presumably the policy exposed by
              // getEstimationPolicy()/setEstimationPolicy() - confirm in the .cpp.
00093         CAgentController *estimationPolicy;
00094 
              // NOTE(review): presumably the action last chosen by
              // estimationPolicy - confirm.
00096         CAction *lastEstimatedAction;
00097 
              // Both constructors take a CAbstractQFunction* named `qfunction`;
              // presumably stored here (TODO confirm).
00098         CAbstractQFunction *qfunction;
00099 
              // E-traces object; only the four-argument constructor receives one
              // from the caller.
00100         CAbstractQETraces *etraces;
00101 
              // NOTE(review): purpose not visible from this header - see the .cpp.
00102         CActionDataSet *actionDataSet;
00103 
00105 
              // Protected virtual hook taking the old state, the action, the
              // reward and the next state. CTDResidualLearner overrides it.
00112         virtual void learnStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *nextState);
00113 
              // Returns a double for one (oldState, action, reward, nextState)
              // transition. NOTE(review): presumably the TD error - confirm.
00115         virtual double getTemporalDifference(CStateCollection *oldState, CAction *action, double reward, CStateCollection *nextState);
00116 
              // Virtual; CTDGradientLearner overrides it. Takes the old and new
              // Q-values, the reward and an integer duration.
              // NOTE(review): `duration` presumably is the step length of the
              // action - confirm.
00118         virtual double getResidual(double oldQ, double reward, int duration, double newQ);
00119 
              // Virtual; CTDGradientLearner overrides it with the same signature.
00121         virtual void addETraces(CStateCollection *oldState, CStateCollection *newState, CAction *action);
00122 
00123 public:
              // Constructor that takes the e-traces object from the caller.
00125     CTDLearner(CRewardFunction *rewardFunction, CAbstractQFunction *qfunction, CAbstractQETraces *etraces, CAgentController *estimationPolicy);         
00127 
              // Constructor without an e-traces argument.
              // NOTE(review): presumably obtains its e-traces some other way
              // (e.g. from the Q-function) - confirm.
00130         CTDLearner(CRewardFunction *rewardFunction, CAbstractQFunction *qfunction, CAgentController *estimationPolicy);         
00131                 
              // Virtual destructor, so destructors of derived classes are
              // virtual as well.
00132         virtual ~CTDLearner();
00133                 
              // Load/save overloads taking a file name.
00134         virtual void loadValues(char *filename);
00135         virtual void saveValues(char *filename);
00136 
              // Load/save overloads taking an already opened FILE stream.
              // NOTE(review): FILE is not declared by any header included
              // directly here; presumably it arrives through one of the project
              // headers - confirm.
00137         virtual void loadValues(FILE *stream);
00138         virtual void saveValues(FILE *stream);
00139 
              // NOTE(review): nextStep, intermediateStep and newEpisode are
              // presumably callbacks inherited from CSemiMDPRewardListener -
              // confirm in cagentlistener.h.
00141         virtual void nextStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *nextState);
00143 
00147         virtual void intermediateStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *nextState);
00148 
00150         virtual void newEpisode();
00151 
              // setGamma is commented out below; this class declares no gamma
              // setter.
00153 //      void setGamma(double gamma);
              // Non-virtual setters.
              // NOTE(review): presumably learning rate (alpha) and e-trace decay
              // (lambda) - confirm where the values are stored.
00155         void setAlpha(double alpha);
00157         void setLambda(double lambda);
00158 
              // Accessors for the estimation policy.
00159         CAgentController* getEstimationPolicy();
00160         void setEstimationPolicy(CAgentController * estimationPolicy);
00161 
              // Accessor returning a CAbstractQFunction pointer.
00162         CAbstractQFunction* getQFunction();
00163 
              // Accessor returning a CAbstractQETraces pointer.
00164         CAbstractQETraces *getETraces();
00165 };
00166 
00168 
// CQLearner: subclass of CTDLearner that declares only a constructor and a
// destructor. The constructor takes a reward function and a Q-function but no
// estimation policy.
// NOTE(review): presumably it passes a greedy policy to the base class
// (Q-learning) - confirm in the implementation file.
00174 class CQLearner : public CTDLearner
00175 {
00176 public:
00177         CQLearner(CRewardFunction *rewardFunction, CAbstractQFunction *qfunction);
              // Implicitly virtual because ~CTDLearner() is declared virtual.
00178         ~CQLearner();
00179 };
00180 
00182 
// CSarsaLearner: subclass of CTDLearner that declares only a constructor and
// a destructor. Unlike CQLearner, its constructor also takes a
// CDeterministicController.
// NOTE(review): presumably `agent` is forwarded to the base class as the
// estimation policy (on-policy SARSA) - confirm in the implementation file.
00191 class CSarsaLearner : public CTDLearner
00192 {
00193 public:
00194         CSarsaLearner(CRewardFunction *rewardFunction, CAbstractQFunction *qfunction, CDeterministicController *agent);
              // Implicitly virtual because ~CTDLearner() is declared virtual.
00195         ~CSarsaLearner();
00196 };
00197 
00198 
// CTDGradientLearner: subclass of CTDLearner whose constructor requires a
// CGradientQFunction plus a residual function and a residual-gradient
// function. It overrides the protected hooks getResidual() and addETraces().
// NOTE(review): how the residual objects and feature lists are used is not
// visible in this header - see the implementation file.
00199 class CTDGradientLearner : public CTDLearner
00200 {
00201 protected:
              // Same types and names as the last two constructor parameters.
00202         CResidualFunction *residual;
00203         CResidualGradientFunction *residualGradient;
              // NOTE(review): presumably typed views of the base-class
              // qfunction/etraces pointers - confirm.
00204         CGradientQFunction *gradientQFunction;
00205         CGradientQETraces *gradientQETraces;
00206 
              // NOTE(review): presumably scratch feature lists for gradients of
              // the old and new state; ownership is not visible here - confirm.
00207         CFeatureList *oldGradient;
00208         CFeatureList *newGradient;
00209         CFeatureList *residualGradientFeatures;
00210 
              // Overrides of the CTDLearner hooks with identical signatures.
00211         virtual double getResidual(double oldQ, double reward, int duration, double newQ);
00212         virtual void addETraces(CStateCollection *oldState, CStateCollection *newState, CAction *action);
00213 
00214 public:
00215         CTDGradientLearner(CRewardFunction *rewardFunction, CGradientQFunction *qfunction, CAgentController *agent, CResidualFunction *residual, CResidualGradientFunction *residualGradient);
00216 
              // Implicitly virtual because ~CTDLearner() is declared virtual.
00217         ~CTDGradientLearner();
00218 };
00219 
// CTDResidualLearner: subclass of CTDGradientLearner that adds three
// CGradientQETraces pointers and a beta calculator, overrides learnStep() and
// newEpisode(), and declares a four-argument addETraces() overload.
// NOTE(review): how beta and the separate trace objects are combined is not
// visible in this header - see the implementation file.
00220 class CTDResidualLearner : public CTDGradientLearner
00221 {
00222 protected:
00223         
              // NOTE(review): roles of the three trace objects are inferred from
              // their names only - confirm in the .cpp.
00224         CGradientQETraces *residualGradientTraces;
00225         CGradientQETraces *directGradientTraces;
00226 
              // Returned by getResidualETraces().
00227         CGradientQETraces *residualETraces;
00228 
              // Same type as the constructor's `betaCalc` parameter.
00229         CAbstractBetaCalculator *betaCalculator;
00230 
              // Overrides CTDLearner::learnStep (same signature).
00231         virtual void learnStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *nextState);
00232 
00233 public:
00234         CTDResidualLearner(CRewardFunction *rewardFunction, CGradientQFunction *qfunction, CAgentController *agent, CResidualFunction *residual, CResidualGradientFunction *residualGradient, CAbstractBetaCalculator *betaCalc);
00235 
              // Implicitly virtual because ~CTDLearner() is declared virtual.
00236         ~CTDResidualLearner();
00237 
              // Overrides the virtual CTDLearner::newEpisode() even though the
              // `virtual` keyword is not repeated here.
00238         void newEpisode();
00239 
              // The extra `double td` parameter makes this a new virtual
              // function, not an override of the three-argument addETraces()
              // declared in the base classes. Declaring the name here also hides
              // the three-argument version for lookups through this class.
00240         virtual void addETraces(CStateCollection *oldState, CStateCollection *newState, CAction *action, double td);
00241 
              // Inline accessor returning the residualETraces pointer.
00242         CGradientQETraces *getResidualETraces() {return residualETraces;};
00243 };
00244 
00245 
00246 
// CQAverageTDErrorLearner derives from CErrorListener and CStateObject. It
// holds an update rate and a CFeatureQFunction pointer, both named like the
// constructor parameters.
// NOTE(review): presumably maintains a running average of received errors in
// averageErrorFunction - confirm in the implementation file.
00247 class CQAverageTDErrorLearner : public CErrorListener, public CStateObject
00248 {
00249         protected:
                      // NOTE(review): presumably the step size of the averaging
                      // update - confirm.
00250                 double updateRate;
00251                         
00252                 CFeatureQFunction *averageErrorFunction;
00253         public:
00254                 CQAverageTDErrorLearner(CFeatureQFunction *averageErrorFunction, double updateRate);
00255                 virtual ~CQAverageTDErrorLearner();
00256                 
                      // NOTE(review): presumably a parameter-change callback
                      // inherited from a base class - confirm in cbaseobjects.h.
00257                 virtual void onParametersChanged();
00258                 
                      // Virtual error callback; `data` defaults to NULL.
                      // NOTE(review): presumably declared by CErrorListener -
                      // confirm in cerrorlistener.h.
00259                 virtual void receiveError(double error, CStateCollection *state, CAction *action, CActionData *data = NULL);    
00260 };
00261 
// CQAverageTDVarianceLearner: subclass of CQAverageTDErrorLearner with the
// same constructor parameters. It redeclares receiveError() with an identical
// signature and therefore overrides it.
// NOTE(review): presumably accumulates a variance-like quantity instead of
// the plain error - confirm in the implementation file.
00262 class CQAverageTDVarianceLearner : public CQAverageTDErrorLearner
00263 {
00264         public:
00265                 
00266                 CQAverageTDVarianceLearner(CFeatureQFunction *averageErrorFunction, double updateRate);
00267                 virtual ~CQAverageTDVarianceLearner();
00268                 
                      // Override of CQAverageTDErrorLearner::receiveError. The
                      // default argument is repeated with the same value (NULL)
                      // as in the base declaration.
00269                 virtual void receiveError(double error, CStateCollection *state, CAction *action, CActionData *data = NULL);    
00270 };
00271 
00272 #endif
00273