Commit 2b062f05 authored by Sascha Dewes's avatar Sascha Dewes
Browse files

bug fixes

parent 629d11dd
Pipeline #409058 failed with stage
in 5 minutes and 46 seconds
......@@ -170,6 +170,9 @@ class CNNCreator_defaultGAN_defaultGANConnector_predictor:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (100,)
......
......@@ -170,6 +170,9 @@ class CNNCreator_infoGAN_infoGANConnector_predictor:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (62,)
......
......@@ -170,6 +170,9 @@ class CNNCreator_infoGAN_infoGANQNetwork:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (512,4,4,)
......
......@@ -46,7 +46,9 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.GymEnvironment('CartPole-v0')
context = mx.cpu()
initializer = mx.init.Normal()
qnet_creator = CNNCreator_cartpole_master_dqn.CNNCreator_cartpole_master_dqn()
qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context)
agent_params = {
......
......@@ -47,9 +47,13 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.GymEnvironment('MountainCarContinuous-v0')
context = mx.cpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
actor_creator = CNNCreator_mountaincar_master_actor.CNNCreator_mountaincar_master_actor()
actor_creator.setWeightInitializer(initializer)
actor_creator.construct(context)
critic_creator = CNNCreator_mountaincar_agent_mountaincarCritic()
critic_creator.setWeightInitializer(critic_initializer)
critic_creator.construct(context)
agent_params = {
......
......@@ -170,6 +170,9 @@ class CNNCreator_mountaincar_agent_mountaincarCritic:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (2,)
......
......@@ -53,7 +53,9 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.RosEnvironment(**env_params)
context = mx.cpu()
initializer = mx.init.Normal()
qnet_creator = CNNCreator_torcs_agent_torcsAgent_dqn.CNNCreator_torcs_agent_torcsAgent_dqn()
qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context)
agent_params = {
......
......@@ -54,9 +54,13 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.RosEnvironment(**env_params)
context = mx.gpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
actor_creator = CNNCreator_torcs_agent_torcsAgent_actor.CNNCreator_torcs_agent_torcsAgent_actor()
actor_creator.setWeightInitializer(initializer)
actor_creator.construct(context)
critic_creator = CNNCreator_torcs_agent_network_torcsCritic()
critic_creator.setWeightInitializer(critic_initializer)
critic_creator.construct(context)
agent_params = {
......
......@@ -170,6 +170,9 @@ class CNNCreator_torcs_agent_network_torcsCritic:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (29,)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment