RLPack
 
SumTree.h
 1
 2 #ifndef RLPACK_BINARIES_MEMORY_SUMTREE_SUMTREE_H_
 3 #define RLPACK_BINARIES_MEMORY_SUMTREE_SUMTREE_H_
 4
 5 #include <cassert>
 6 #include <deque>
 7 #include <iostream>
 8 #include <optional>
 9 #include <vector>
10
11 #include "../sumtree_node/SumTreeNode.h"
12
28 class SumTree {
29 public:
30     explicit SumTree(int32_t bufferSize);
33     void create_tree(std::deque<float_t> &priorities,
34                      std::optional<std::vector<SumTreeNode *>> &children);
35     void reset(int64_t parallelismSizeThreshold = 4096);
36     int64_t sample(float_t seedValue, int64_t currentSize);
37     void update(int64_t index, float_t value);
38     [[maybe_unused]] float_t get_cumulative_sum();
39     int64_t get_tree_height();
40
41 private:
43     std::vector<SumTreeNode *> sumTree_;
45     std::vector<SumTreeNode *> leaves_;
47     int64_t bufferSize_ = 32768;
49     int64_t treeHeight_ = 0;
50
51     void propagate_changes_upwards(SumTreeNode *node, float_t change);
52     SumTreeNode *traverse(SumTreeNode *node, float_t value);
53 };
60 #endif//RLPACK_BINARIES_MEMORY_SUMTREE_SUMTREE_H_
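The header declares `update`, `sample`, and `get_cumulative_sum` but the listing does not show how they work. As a rough illustration of the general sum-tree idea behind proportional prioritization, the sketch below uses a flat array instead of `SumTreeNode` pointers. It is not RLPack's implementation: the class name `ArraySumTree`, its layout, and the exact sampling convention (a value in `[0, total)`) are assumptions made for this example only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical array-backed sum tree (NOT RLPack's pointer-based SumTree).
// Node 1 is the root; leaf i lives at position capacity + i.
class ArraySumTree {
public:
    explicit ArraySumTree(std::size_t capacity)
        : capacity_(capacity), nodes_(2 * capacity, 0.0f) {
        assert(capacity > 0);
    }

    // Set the priority of leaf `index`, then recompute every ancestor's sum.
    void update(std::size_t index, float value) {
        assert(index < capacity_);
        std::size_t pos = index + capacity_;
        nodes_[pos] = value;
        for (pos /= 2; pos >= 1; pos /= 2) {
            nodes_[pos] = nodes_[2 * pos] + nodes_[2 * pos + 1];
        }
    }

    // Sum of all leaf priorities (stored at the root).
    float total() const { return nodes_[1]; }

    // Map `value` in [0, total()) to a leaf index. Each leaf owns an
    // interval whose length equals its priority, so a uniformly drawn
    // value selects leaves in proportion to their priorities.
    std::size_t sample(float value) const {
        assert(value >= 0.0f && value < total());
        std::size_t pos = 1;
        while (pos < capacity_) {
            const std::size_t left = 2 * pos;
            if (value < nodes_[left]) {
                pos = left;              // target lies in the left subtree
            } else {
                value -= nodes_[left];   // skip the left subtree's mass
                pos = left + 1;
            }
        }
        return pos - capacity_;
    }

private:
    std::size_t capacity_;
    std::vector<float> nodes_;
};
```

With priorities 1, 2, 3, 4 on four leaves, the root holds 10, and a value drawn uniformly from `[0, 10)` lands on leaf 3 four times as often as on leaf 0. Both `update` and `sample` touch one node per tree level, which is why the structure is preferred over rescanning the whole priority list.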
SumTree
The class SumTree is a class which represents the Sum-Tree which is used in proportional prioritizati...
Definition: SumTree.h:28

SumTreeNode *traverse(SumTreeNode *node, float_t value)
Definition: SumTree.cpp:185

void update(int64_t index, float_t value)
Definition: SumTree.cpp:134

int64_t get_tree_height()
Definition: SumTree.cpp:158

void create_tree(std::deque<float_t> &priorities, std::optional<std::vector<SumTreeNode *>> &children)
Definition: SumTree.cpp:27

int64_t sample(float_t seedValue, int64_t currentSize)
Definition: SumTree.cpp:115

float_t get_cumulative_sum()
Definition: SumTree.cpp:148

int64_t treeHeight_
Attribute to store the tree height.
Definition: SumTree.h:49

int64_t bufferSize_
Attribute to store the buffer size. If not initialised, when using SumTree::SumTree(),...
Definition: SumTree.h:47

void reset(int64_t parallelismSizeThreshold = 4096)
Definition: SumTree.cpp:97

std::vector<SumTreeNode *> sumTree_
A vector to store the pointers to dynamically allocated SumTreeNode nodes.
Definition: SumTree.h:43

void propagate_changes_upwards(SumTreeNode *node, float_t change)
Definition: SumTree.cpp:168

std::vector<SumTreeNode *> leaves_
A vector to store the pointers to dynamically allocated SumTreeNode leaves.
Definition: SumTree.h:45

SumTreeNode
The class SumTreeNode is a private class which represents a node in Sum-Tree. This is only used when ...
Definition: SumTreeNode.h:22
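
The private helpers `traverse` and `propagate_changes_upwards` take node pointers, which suggests a linked tree rather than a flat array. The sketch below shows what such helpers could look like on a pointer-based tree. The `Node` struct, its members, and the function names `propagateUpwards` and `descend` are hypothetical; the real `SumTreeNode` interface is not shown in this listing, and whether RLPack applies the change to the starting node itself is an assumption here.

```cpp
#include <cassert>

// Hypothetical node type for illustration only; SumTreeNode's actual
// members are not visible in this listing.
struct Node {
    float value = 0.0f;      // priority (leaf) or sum of children (internal)
    Node *parent = nullptr;
    Node *left = nullptr;
    Node *right = nullptr;
};

// Add `change` to `node` and to each of its ancestors up to the root.
// Assumption: the starting node is updated as well.
inline void propagateUpwards(Node *node, float change) {
    for (; node != nullptr; node = node->parent) {
        node->value += change;
    }
}

// Walk down from `node` to the leaf whose cumulative interval contains
// `value`. Assumes every internal node has both children.
inline Node *descend(Node *node, float value) {
    assert(node != nullptr);
    while (node->left != nullptr && node->right != nullptr) {
        if (value < node->left->value) {
            node = node->left;
        } else {
            value -= node->left->value;  // skip the left subtree's mass
            node = node->right;
        }
    }
    return node;
}
```

Passing a delta instead of a new absolute value means each ancestor is adjusted with a single addition and no sibling needs to be read on the way up.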