Skip to content
Snippets Groups Projects
Commit c547eb54 authored by Ian Resende da Cunha
Browse files

att

parent 76103d67
No related branches found
No related tags found
No related merge requests found
File added
File added
......@@ -9,7 +9,7 @@ from tf_agents.environments import py_environment
from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts
from xvfbwrapper import Xvfb
#from xvfbwrapper import Xvfb
#This environment calculates reward based on the ratio of samples below the upper bound
class CassandraEnv(py_environment.PyEnvironment):
......@@ -30,8 +30,8 @@ class CassandraEnv(py_environment.PyEnvironment):
self._observation_spec = {'observations': array_spec.BoundedArraySpec(shape=(6,), dtype=np.float32, minimum=[0, 3, -1, 0, 0, -100], maximum=[100000, 9, 1, 100, 100, 100], name='observation'),
'legal_moves': array_spec.ArraySpec(shape=(3,), dtype=np.bool_, name='legal_moves')}
self._vdisplay = Xvfb()
self._vdisplay.start()
#self._vdisplay = Xvfb()
#self._vdisplay.start()
#Connect to matlab and load the simulation
#self._eng=matlab.engine.connect_matlab()
......@@ -72,7 +72,8 @@ class CassandraEnv(py_environment.PyEnvironment):
data = np.array(data)
#Works around a bug that sometimes occurs when the file is not ready yet
if (data.shape[1] == 2):
print(data.shape)
if (len(data) > 1 and data.shape[1] == 2):
metrics = [np.mean(data[-200:,1]), np.median(data[-200:,1]), np.quantile(data[-200:,1], 0.95), np.quantile(data[-200:,1], 0.99), np.max(data[-200:,1])]
return metrics
......@@ -184,7 +185,7 @@ class CassandraEnv(py_environment.PyEnvironment):
#calculate reward
error = observations_and_legal_moves['observations'][5]
reward = (100 - self._active_nodes * 5) if error <= 0 else 0
reward = (100 - self._active_nodes * 10) if error <= 0 else 0
#Increment time counter
self._next_step += self._step_delay
......@@ -199,4 +200,4 @@ class CassandraEnv(py_environment.PyEnvironment):
def __del__(self):
self._eng.quit()
self._vdisplay.stop()
#self._vdisplay.stop()
......@@ -116,6 +116,19 @@ def observation_and_action_constraint_splitter(obs):
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
train_step_counter = tf.Variable(0)
train_episode_counter = tf.Variable(0)
#global_step = tf.compat.v1.train.get_or_create_global_step()
start_epsilon = 0.5
n_of_steps = 150
end_epsilon = 0.1
epsilon = tf.compat.v1.train.polynomial_decay(
start_epsilon,
train_episode_counter,
n_of_steps,
end_learning_rate=end_epsilon)
agent = dqn_agent.DqnAgent(
env.time_step_spec(),
......@@ -124,12 +137,15 @@ agent = dqn_agent.DqnAgent(
optimizer=optimizer,
target_update_tau = target_update_tau,
observation_and_action_constraint_splitter = observation_and_action_constraint_splitter,
epsilon_greedy = epsilon,
td_errors_loss_fn = common.element_wise_squared_loss,
train_step_counter=train_step_counter)
agent.initialize()
logging.info('Teste:')
logging.info(train_episode_counter.numpy())
logging.info('Agent created!')
#logging.info('Agent 0: %d',agent.optimizer.numpy())
###############################################################################
......@@ -169,18 +185,27 @@ def trainDqnAgent (agent, environment, buffer, batch_size=64, max_episodes=50, m
for i_episode in range(max_episodes):
logging.debug('New episode: %d', i_episode)
logging.info('Episode: %d',agent.train_step_counter.numpy())
# Initialize the environment
environment.reset()
#logging.info('Epsilon 2: %d',agent.epsilon_greedy.numpy())
episode_return = 0.0
# Reset the train step
agent.train_step_counter.assign(0)
for i in range(max_iterations):
logging.debug('step = {0}'.format(i))
train_episode_counter.assign(i)
logging.info(train_episode_counter.numpy())
logging.info('Epsilon:')
logging.info(epsilon)
#make a step in the environment
time_step, action_step, next_time_step = moveStep(environment, agent)
episode_return += next_time_step.reward
......
......@@ -3,17 +3,17 @@ Adaptor = cassandra_simulink_bound_env
SimulinkModel = cassandra_cluster_20_nodes_2020b
StepDelay = 10
SyncDelay = 10
BootDelay = 120
InitialNodes = 5
BootDelay = 80
InitialNodes = 6
MaxNodes = 9
MinNodes = 3
AverageTimeWindow = 10
ResponseTimeUpperBound = 0.2
ResponseTimeUpperBound = 0.3
TerminationMaxThreshold = 20
[Hyperparameters]
MaxIterations = 200
MaxEpisodes = 150
MaxIterations = 150
MaxEpisodes = 600
InitialCollectSteps = 100
CollectStepsPerIteration = 1
ReplayBufferMaxLength = 100000
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment