Reinforcement Learning Toolbox 2.0
last updated:
General
Documentation
Manual
Tutorial
Class Reference
Master Thesis
Examples
Related Papers
Downloads
Links
News
mailto:webmaster
Main Page     Class Hierarchy   Compound List   File List   Compound Members   File Members

cqfunction.h

Go to the documentation of this file.
00001 // Copyright (C) 2003
00002 // Gerhard Neumann (gneumann@gmx.net)
00003 // Stephan Neumann (sneumann@gmx.net) 
00004 //                
00005 // This file is part of RL Toolbox.
00006 // http://www.igi.tugraz.at/ril_toolbox
00007 //
00008 // All rights reserved.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions
00012 // are met:
00013 // 1. Redistributions of source code must retain the above copyright
00014 //    notice, this list of conditions and the following disclaimer.
00015 // 2. Redistributions in binary form must reproduce the above copyright
00016 //    notice, this list of conditions and the following disclaimer in the
00017 //    documentation and/or other materials provided with the distribution.
00018 // 3. The name of the author may not be used to endorse or promote products
00019 //    derived from this software without specific prior written permission.
00020 // 
00021 // THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
00022 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
00023 // OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
00024 // IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
00025 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
00026 // NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
00027 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
00028 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
00030 // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031 
00032 #ifndef CRIABSTRACTQFUNCTION_H
00033 #define CRIABSTRACTQFUNCTION_H
00034 
#include <stdio.h>

#include <list>
#include <map>
#include <string>

#include "clearndataobject.h"
#include "cbaseobjects.h"
#include "cmyexception.h"
#include "cgradientfunction.h"
00041 
// Forward declarations. Every type below is referred to in this header only
// through pointers (or not at all), so the full definitions are not needed.
class CAbstractQETraces;
class CGradientQETraces;
class CAbstractFeatureStochasticModel;

class CAbstractVFunction;
class CFeatureVFunction;
class CFeatureRewardFunction;

class CActionStatistics;
class CFeature;

// Q-function type flags.
// NOTE(review): presumably the values handed to CAbstractQFunction::addType()
// and tested with isType(); they are distinct powers of two, which suggests a
// bitmask — confirm against the implementation.
#define GRADIENTQFUNCTION 1
#define CONTINUOUSACTIONQFUNCTION 2
00054 
00055 
00057 
00070 class CAbstractQFunction : public CActionObject, virtual public CLearnDataObject
00071 {
00072 protected:
00073         int type;
00074 public:
00075         bool mayDiverge;
00076 
00077         int getType();
00078         bool isType(int type);
00079         void addType(int Type);
00080 
00081 
00083         CAbstractQFunction(CActionSet *actions);
00084         virtual ~CAbstractQFunction();
00085 
00086         virtual void saveData(FILE *file);
00087         virtual void loadData(FILE *file);
00088         virtual void printValues (){};
00089 
00090         virtual void resetData() {};
00091 
00093 
00094         void getActionValues(CStateCollection *state, CActionSet *actions, double *actionValues, CActionDataSet *data = NULL);
00095 
00097 
00099         virtual CAction* getMax(CStateCollection *state, CActionSet *availableActions, CActionDataSet *data = NULL);
00101 
00102         virtual double getMaxValue(CStateCollection *state, CActionSet *availableActions);
00104     virtual void getStatistics(CStateCollection *state, CAction *action, CActionSet *actions, CActionStatistics* statistics);
00105         
00107         virtual void updateValue(CStateCollection *, CAction *, double , CActionData * = NULL) {};
00109         virtual void setValue(CStateCollection *state, CAction *action, double qValue, CActionData *data = NULL); 
00111         virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL) = 0;
00112 
00113         virtual CAbstractQETraces *getStandardETraces() {return NULL;};
00114 
00115 protected:
00116 };
00117 
00118 
00119 class CQFunctionSum : public CAbstractQFunction
00120 {
00121 protected:
00122         std::map<CAbstractQFunction *, double> *qFunctions;
00123 public:
00124         CQFunctionSum(CActionSet *actions);
00125         virtual ~CQFunctionSum();
00126 
00127 
00129         virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL);
00130 
00131         virtual CAbstractQETraces *getStandardETraces() {return NULL;};
00132 
00133         double getQFunctionFactor(CAbstractQFunction *qFunction);
00134         void setQFunctionFactor(CAbstractQFunction *qFunction, double factor);
00135 
00136         void addQFunction(CAbstractQFunction *qFunction, double factor);
00137         void removeQFunction(CAbstractQFunction *qFunction);
00138 
00139 
00140         void normFactors(double factor);
00141 
00142 };
00143 
00145 
00148 class CDivergentQFunctionException : public CMyException
00149 {
00150 protected:
00151         virtual string getInnerErrorMsg();
00152 public:
00153         string qFunctionName;
00154         CAbstractQFunction *qFunction;
00155         CState *state;
00156         double value;
00157 
00158         CDivergentQFunctionException(string qFunctionName, CAbstractQFunction *qFunction, CState *state, double value);
00159         virtual ~CDivergentQFunctionException(){};
00160 };
00161 
00162 class CGradientQFunction : public CAbstractQFunction, virtual public CGradientUpdateFunction
00163 {
00164 protected:
00165         CFeatureList *localGradientQFunctionFeatures;
00166 
00167 public:
00168         CGradientQFunction(CActionSet *actions);
00169         virtual ~CGradientQFunction();
00170 
00171         virtual int getWeightsOffset(CAction *) {return 0;};
00172 
00173         virtual void getGradient(CStateCollection *state, CAction *action, CActionData *data, CFeatureList *gradient) = 0;
00174 
00176         virtual void updateValue(CStateCollection *state, CAction *action, double td, CActionData *data = NULL);
00177         
00178         virtual void resetData() {CAbstractQFunction::resetData();};
00179         virtual void loadData(FILE *stream) {CGradientUpdateFunction::loadData(stream);};
00180         virtual void saveData(FILE *stream) {CGradientUpdateFunction::saveData(stream);};
00181 
00182         virtual CAbstractQETraces *getStandardETraces();
00183 
00184         virtual void copy(CLearnDataObject *qFunction) {CGradientUpdateFunction::copy(qFunction);};
00185 };
00186 
00187 /*
00188 class CGradientDelayedUpdateQFunction : public CGradientQFunction, public CGradientDelayedUpdateFunction
00189 {
00190 protected:
00191        virtual void updateWeights(CFeatureList *dParams) {CGradientDelayedUpdateFunction::updateWeights(dParams);};
00192 
00193        CGradientQFunction *qFunction;
00194 public:
00196        CGradientDelayedUpdateQFunction(CGradientQFunction *qFunction);
00197        virtual ~CGradientDelayedUpdateQFunction() {};
00198 
00199        virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL);
00200        virtual void getGradient(CStateCollection *state, CAction *action, CActionData *data, CFeatureList *gradientFeatures);
00201 
00202        virtual void resetData() {CGradientDelayedUpdateFunction::resetData();};
00203 
00205        virtual int getNumWeights(){return CGradientDelayedUpdateFunction::getNumWeights();};
00206 
00207        virtual void getWeights(double *parameters) {CGradientDelayedUpdateFunction::getWeights(parameters);};
00208        virtual void setWeights(double *parameters) {CGradientDelayedUpdateFunction::setWeights(parameters);};
00209 
00210        virtual void loadData(FILE *stream) {CGradientQFunction::loadData(stream);};
00211        virtual void saveData(FILE *stream) {CGradientQFunction::saveData(stream);};
00212 
00213 };
00214 */
00216 
00227 class CQFunction : public CGradientQFunction
00228 {
00229 protected:
00231 
00233         std::map<CAction *, CAbstractVFunction *> *vFunctions;
00234 
00235         virtual int getWeightsOffset(CAction *action);
00236    
00237         virtual void updateWeights(CFeatureList *features);
00238 
00239 public:
00241 
00244         CQFunction(CActionSet *actions);
00245         virtual ~CQFunction();
00246 
00248 
00250         virtual void updateValue(CStateCollection *state, CAction *action, double td, CActionData *data = NULL);
00252 
00254         virtual void setValue(CStateCollection *state, CAction *action, double qValue, CActionData *data = NULL); 
00256 
00258         virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL);
00259 
00261 
00264         virtual void updateValue(CState *state, CAction *action, double td, CActionData *data = NULL);
00266 
00269         virtual void setValue(CState *state, CAction *action, double qValue, CActionData *data = NULL); 
00271 
00274         virtual double getValue(CState *state, CAction *action, CActionData *data = NULL);
00275 
00277 
00280         virtual void saveData(FILE *file);
00282 
00284         virtual void loadData(FILE *file);
00286         virtual void printValues();
00287 
00289         CAbstractVFunction *getVFunction(CAction *action);
00291         CAbstractVFunction *getVFunction(int index);
00293 
00296         void setVFunction(CAction *action, CAbstractVFunction *vfunction, bool bDeleteOld = true);
00298 
00301         void setVFunction(int index, CAbstractVFunction *vfunction, bool bDeleteOld = true);
00303         int getNumVFunctions();
00304 
00305         virtual CAbstractQETraces *getStandardETraces();
00306 
00307         //virtual CStateProperties *getGradientCalculator(CAction *action);
00308         virtual void getGradient(CStateCollection *state, CAction *action, CActionData *data, CFeatureList *gradient);
00309 
00310         virtual int getNumWeights();
00311 
00312         virtual void getWeights(double *weights);
00313         virtual void setWeights(double *weights);
00314 
00315         virtual void resetData();
00316         virtual void copy(CLearnDataObject *qFunction);
00317 };
00318 
00320 
00328 class CQFunctionFromStochasticModel :  public CAbstractQFunction, public CStateObject
00329 {
00330 protected:
00331 
00333         CFeatureVFunction *vfunction;
00335         CAbstractFeatureStochasticModel *model;
00337         CStateProperties *discretizer;
00339         CFeatureRewardFunction *rewardfunction;
00340 
00342         CState *discState;
00343 
00344 public:
00346         CQFunctionFromStochasticModel(CFeatureVFunction *vfunction, CAbstractFeatureStochasticModel *model, CFeatureRewardFunction *rewardfunction);
00347 
00348         virtual ~CQFunctionFromStochasticModel();
00349 
00350 // Writes the Action-Values in the actionValues Array.
00351 //      void getActionValues(CStateCollection *state, double *actionValues, CActionSet *actions);
00352 
00354         virtual void updateValue(CStateCollection *, CAction *, double , CActionData * = NULL) {};
00356         virtual void setValue(CStateCollection *, CAction *, double , CActionData * = NULL) {}; 
00357 
00359 
00360         virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL);
00361 
00363 
00368         virtual double getValue(CState *featState, CAction *action, CActionData *data = NULL);
00370 
00373         virtual double getValue(int feature, CAction *action, CActionData *data = NULL);
00374 
00375         virtual CAbstractQETraces *getStandardETraces() {return NULL;};
00376 };
00377 
00378 
00379 
00381 
00394 class CFeatureQFunction : public CQFunction
00395 {
00396 protected:
00398         CStateProperties *discretizer;
00400         unsigned int features;
00401 
00402         std::list<CFeatureVFunction *> *featureVFunctions;
00403 
00405         virtual void init();
00406 
00408 
00412         void initVFunctions(CFeatureVFunction *vfunction, CAbstractFeatureStochasticModel *model,  CFeatureRewardFunction *rewardFunction, double gamma);
00413 
00414 public:
00416         CFeatureQFunction(CActionSet *actions, CStateProperties *discretizer);
00418 
00422         CFeatureQFunction(CFeatureVFunction *vfunction, CAbstractFeatureStochasticModel *model,  CFeatureRewardFunction *rewardFunction,double gamma);
00423         
00424         virtual ~CFeatureQFunction();
00425         
00427 
00430         void updateValue(CFeature *state, CAction *action, double td, CActionData *data = NULL);
00432 
00435         void setValue(int state, CAction *action, double qValue, CActionData *data = NULL); 
00437 
00440         double getValue(int feature, CAction *action, CActionData *data = NULL);
00441 
00442         void setFeatureCalculator(CStateModifier *discretizer);
00443         CStateProperties *getFeatureCalculator();
00444 
00445 
00446         int getNumFeatures();
00447 
00449 
00450         void saveFeatureActionValueTable(FILE *stream);
00452 
00453         void saveFeatureActionTable(FILE *stream);
00454 };
00455 
00456 class CComposedQFunction : public  CGradientQFunction
00457 {
00458 protected:
00459         std::list<CAbstractQFunction *> *qFunctions;
00460 
00461         virtual int getWeightsOffset(CAction *action);
00462         virtual void updateWeights(CFeatureList *features);
00463 
00464 public:
00465         CComposedQFunction();
00466         virtual ~CComposedQFunction();
00467 
00468         virtual void saveData(FILE *file);
00469         virtual void loadData(FILE *file);
00470         virtual void printValues();
00471 
00472         virtual void getStatistics(CStateCollection *state, CAction *action, CActionSet *actions, CActionStatistics* statistics);
00473 
00475         virtual void updateValue(CStateCollection *state, CAction *action, double td, CActionData *data = NULL);
00477         virtual void setValue(CStateCollection *state, CAction *action, double qValue, CActionData *data = NULL); 
00479         virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL);
00480 
00481         void addQFunction(CAbstractQFunction *qFunction);
00482 
00483         std::list<CAbstractQFunction *> *getQFunctions();
00484         int getNumQFunctions();
00485 
00486         virtual CAbstractQETraces *getStandardETraces();
00487 
00488         //virtual CStateProperties *getGradientCalculator(CAction *action);
00489 
00490         virtual void getGradient(CStateCollection *state, CAction *action, CActionData *data, CFeatureList *gradient);
00491 
00492 
00493         virtual int getNumWeights();
00494         virtual void getWeights(double *weights);
00495         virtual void setWeights(double *weights);
00496 
00497         virtual void resetData();
00498 };
00499 
00500 /*
00501 class CQTable : public CFeatureQFunction
00502 {
00503        CAbstractStateDiscretizer *discretizer;
00504        virtual void init(int states);
00505 
00506 public:
00507        CQTable(CActionSet *actions, CAbstractStateDiscretizer *state);
00508        
00509        ~CQTable();
00510        
00511        void setDiscretizer(CAbstractStateDiscretizer *discretizer);
00512        CAbstractStateDiscretizer *getDiscretizer();
00513        
00514        int getNumStates();
00515 };*/
00516 #endif
00517 
00518 
00519