00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #ifndef __CCARTPOLE_H
00033 #define __CCARTPOLE_H
00034
00035 #include "cqtconfig.h"
00036
00037 #include "ctransitionfunction.h"
00038 #include "crewardfunction.h"
00039
00040
00041 #ifdef RL_TOOLBOX_USE_QT
00042 #include "cqtmodelvisualizer.h"
00043 #endif
00044
/// Cart-pole transition model: a cart on a finite track with a pole mounted on it.
/// Derives from CLinearActionContinuousTimeTransitionFunction. Going by that base-class
/// name, the dynamics are presumably of the form dx/dt = A(x) + B(x) * u, with getA() and
/// getB() supplying the two terms -- TODO confirm against ctransitionfunction.h.
00045 class CCartPoleModel : public CLinearActionContinuousTimeTransitionFunction
00046 {
00047 protected:
/// Advances `state` by `timestep` under `action` (`data` carries the action's parameters).
/// Presumably called by the base class's time-integration loop -- confirm in the base class.
00048 virtual void doSimulationStep(CState *state, double timestep, CAction *action, CActionData *data);
00049
00050
00051 public:
/// Physical parameters of the system (public, so they may be changed after construction).
/// The names and constructor defaults suggest SI units -- TODO confirm in the implementation.
00052 double uMax; ///< presumably the maximum magnitude of the control force (constructor default 10)
00053 double lengthTrack; ///< presumably the total length of the track (constructor default 4.8)
00054 double g; ///< presumably the gravitational acceleration (constructor default 9.8)
00055 double massCart; ///< presumably the mass of the cart (constructor default 1.0)
00056 double massPole; ///< presumably the mass of the pole (constructor default 0.5)
00057 double lengthPole; ///< presumably the pole length or half-length -- TODO confirm which (constructor default 0.5)
00058 double mu_c; ///< presumably the friction coefficient between cart and track (constructor default 1.0)
00059 double mu_p; ///< presumably the friction coefficient of the pole joint (constructor default 0.1)
00060
00061 bool endLeaveTrack; ///< presumably: leaving the track counts as a failed state (constructor default true) -- confirm in isFailedState()
00062 bool endOverRotate; ///< presumably: over-rotating the pole counts as a failed state (constructor default true) -- confirm in isFailedState()
00063
00064 double failedReward; ///< not set through any constructor parameter; its initial value and its consumer are not visible in this header -- TODO confirm
00065
/// Constructs the model with the given physical parameters and termination flags.
/// `dt` has no matching member in this class; it is presumably forwarded to the base class
/// as the simulation time step -- confirm in the .cpp.
/// NOTE(review): the positional order of the physical parameters here
/// (uMax, lengthTrack, lengthPole, massCart, massPole, mu_c, mu_p, g) differs from the
/// member declaration order above, so take care when passing arguments positionally.
00066 CCartPoleModel(double dt, double uMax = 10, double lengthTrack = 4.8, double lengthPole = 0.5, double massCart = 1.0, double massPole = 0.5, double mu_c = 1.0, double mu_p = 0.1, double g = 9.8, bool endLeaveTrack = true,bool endOverRotate = true);
00067 ~CCartPoleModel();
00068
/// Returns the input (action) matrix B for `state`; presumably the action-dependent part of
/// the dynamics. NOTE(review): ownership/lifetime of the returned pointer is not visible here.
00069 virtual Matrix *getB(CState *state);
/// Returns the drift vector A for `state`; presumably the action-independent part of the
/// dynamics. NOTE(review): ownership/lifetime of the returned pointer is not visible here.
00070 virtual ColumnVector *getA(CState *state);
00071
/// Returns whether `state` is a failed state; presumably governed by endLeaveTrack and
/// endOverRotate -- confirm in the implementation.
00072 virtual bool isFailedState(CState *state);
00073
/// Presumably fills `state` with the initial state used when an episode is reset -- TODO confirm.
00074 virtual void getResetState(CState *state);
00075 };
00076
/// State-based reward for the cart-pole task, computed with reference to a CCartPoleModel.
00077 class CCartPoleRewardFunction : public CStateReward
00078 {
00079 protected:
00080 CCartPoleModel *cartpoleModel; ///< presumably the model passed to the constructor; the (empty) destructor does not delete it, so it is presumably non-owning
00081 public:
00082 bool useHeighPeak; ///< reward-shaping switch; the name (sic) presumably means "use height peak" -- semantics live in the .cpp, TODO confirm
00083 bool punishOverRotate; ///< presumably adds a penalty when the pole over-rotates -- TODO confirm
/// Creates the reward function for `model`. The two flags above are not constructor
/// parameters; their initial values are not visible in this header.
00084 CCartPoleRewardFunction(CCartPoleModel *model);
00085 ~CCartPoleRewardFunction() {};
00086
/// Returns the scalar reward for `state`.
00087 virtual double getStateReward(CState *state);
/// Presumably writes the derivative of the reward with respect to the model state into
/// `targetState` -- TODO confirm in the implementation.
00088 virtual void getInputDerivation(CState *modelState, ColumnVector *targetState);
00089
00090 };
00091
/// Alternative state-based reward for the cart-pole task; the class name suggests the reward
/// is based on the height of the pole -- TODO confirm in the implementation.
00092 class CCartPoleHeightRewardFunction : public CStateReward
00093 {
00094 protected:
00095 CCartPoleModel *cartpoleModel; ///< presumably the model passed to the constructor; no destructor is declared here, so this class never deletes it
00096 public:
/// Creates the reward function for `model`.
00097 CCartPoleHeightRewardFunction(CCartPoleModel *model);
00098
/// Returns the scalar reward for `state`.
00099 virtual double getStateReward(CState *state);
/// Presumably writes the derivative of the reward with respect to the model state into
/// `targetState` -- TODO confirm in the implementation.
00100 virtual void getInputDerivation(CState *modelState, ColumnVector *targetState);
00101
00102 };
00103
00104 #ifdef RL_TOOLBOX_USE_QT
00105
/// Qt-based visualizer for the cart-pole model (declared only when RL_TOOLBOX_USE_QT is defined).
00106 class CQTCartPoleVisualizer : public CQTModelVisualizer
00107 {
00108 protected:
/// Cached state values used for drawing. The names presumably mean pole angle, pole angular
/// velocity, cart position and cart velocity -- TODO confirm in newDrawState().
00109 double phi;
00110 double dphi;
00111 double x;
00112 double dx;
00113
00114 CCartPoleModel *cartModel; ///< model supplied to the constructor (presumably read for track/pole dimensions); no destructor is declared here, so this class never deletes it
00115
/// Draws the cached state using `painter`; presumably invoked from the base class's paint logic.
00116 virtual void doDrawState( QPainter *painter);
00117
00118 public:
00119
/// @param model  cart-pole model whose state is visualized
/// @param parent parent QWidget (default NULL)
/// @param name   presumably the Qt object name (default NULL)
00120 CQTCartPoleVisualizer( CCartPoleModel *model, QWidget *parent = NULL, const char *name = NULL);
00121
/// Receives a new state to display; presumably copies the relevant values into the cached
/// fields above ahead of a repaint -- TODO confirm.
00122 virtual void newDrawState(CStateCollection *state);
00123 };
00124
00125 #endif
00126
00127 #endif
00128