Skip to content
Snippets Groups Projects
Commit 8c5bdea8 authored by Tobias Döring's avatar Tobias Döring
Browse files

Added plotting of input weights

parent a68bbf2e
No related branches found
No related tags found
1 merge request!2Evo neuro
......@@ -35,6 +35,7 @@ if __name__ == '__main__':
if TEST_WALKER:
rewards = []
population.walker.plot_input_weights()
for i in range(10):
rewards.append(population.walker.get_reward(10000, True))
print("Reward: ", rewards[-1])
......@@ -56,6 +57,8 @@ if __name__ == '__main__':
avg_rewards.append(population.get_walker_stats())
with open(f'./models/{HIDDEN_LAYER}_{VERSION}_{POP_SIZE}_{LEARNING_RATE}_AvgRewards', 'wb') as fp:
pickle.dump(avg_rewards, fp)
if gen == 1000:
population.lr = 0.01
plot_reward(avg_rewards)
except KeyboardInterrupt:
......
......@@ -57,7 +57,7 @@ class Population:
for i in range(self.size):
for k in weights:
weights_change = np.dot(self.mutants[i].weights[k].T, A[i]).T
weights[k] = weights[k] + self.lr/(self.size*self.mutation_factor) * weights_change
weights[k] = weights[k] + self.lr/(self.size*self.lr) * weights_change
self.walker.set_weights(weights)
for mutant in self.mutants:
mutant.set_weights(weights)
......
......@@ -3,6 +3,7 @@ import numpy as np
import pickle
import copy
import os
import matplotlib.pyplot as plt
np.random.seed(42)
......@@ -53,6 +54,40 @@ class Walker:
def set_weights(self, weights):
self.weights = copy.deepcopy(weights)
def plot_input_weights(self):
weights = []
names = [
"hull_angle",
"hull_angularVelocity",
"vel_x",
"vel_y",
"hip_joint_1_angle",
"hip_joint_2_angle",
"knee_joint_1_angle",
"knee_joint_2_angle",
"leg_1_ground_contact_flag",
"hip_joint_2_angle",
"hip_joint_2_speed",
"knee_joint_2_angle",
"knee_joint_2_speed",
"leg_2_ground_contact_flag",
"lidar reading 1",
"lidar reading 2",
"lidar reading 3",
"lidar reading 4",
"lidar reading 5",
"lidar reading 6",
"lidar reading 7",
"lidar reading 8",
"lidar reading 9",
"lidar reading 10"
]
for i in range(24):
weights.append(sum(self.weights['W1'][i]))
plt.bar(names, weights)
plt.xticks(rotation = 45, ha = "right")
plt.show()
def save(self):
if not os.path.isdir('./models'):
os.mkdir('./models')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment