Commit 2a05ee34 authored by Sascha Dewes's avatar Sascha Dewes
Browse files

only use critic_initializer for ddpg and td3

parent e9a961af
Pipeline #403964 failed with stage
in 53 seconds
...@@ -93,6 +93,7 @@ if __name__ == "__main__": ...@@ -93,6 +93,7 @@ if __name__ == "__main__":
<#else> <#else>
initializer = mx.init.Normal() initializer = mx.init.Normal()
</#if> </#if>
<#if config.rlAlgorithm=="ddpg" || config.rlAlgorithm=="td3">
<#if (config.configuration.criticInitializer)??> <#if (config.configuration.criticInitializer)??>
critic_initializer_params = { critic_initializer_params = {
<#list config.criticInitializerParams?keys as param> <#list config.criticInitializerParams?keys as param>
...@@ -103,11 +104,12 @@ if __name__ == "__main__": ...@@ -103,11 +104,12 @@ if __name__ == "__main__":
<#else> <#else>
critic_initializer = mx.init.Normal() critic_initializer = mx.init.Normal()
</#if> </#if>
</#if>
<#if config.rlAlgorithm == "dqn"> <#if config.rlAlgorithm == "dqn">
qnet_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}() qnet_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
qnet_creator.setWeightInitializer(initializer) qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context) qnet_creator.construct(context)
<#else> <#elseif config.rlAlgorithm=="ddpg" || config.rlAlgorithm=="td3">
actor_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}() actor_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
actor_creator.setWeightInitializer(initializer) actor_creator.setWeightInitializer(initializer)
actor_creator.construct(context) actor_creator.construct(context)
......
...@@ -53,7 +53,6 @@ if __name__ == "__main__": ...@@ -53,7 +53,6 @@ if __name__ == "__main__":
context = mx.cpu() context = mx.cpu()
initializer = mx.init.Normal() initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
qnet_creator = CNNCreator_reinforcementConfig1.CNNCreator_reinforcementConfig1() qnet_creator = CNNCreator_reinforcementConfig1.CNNCreator_reinforcementConfig1()
qnet_creator.setWeightInitializer(initializer) qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context) qnet_creator.construct(context)
......
...@@ -47,7 +47,6 @@ if __name__ == "__main__": ...@@ -47,7 +47,6 @@ if __name__ == "__main__":
context = mx.cpu() context = mx.cpu()
initializer = mx.init.Normal() initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
qnet_creator = CNNCreator_reinforcementConfig2.CNNCreator_reinforcementConfig2() qnet_creator = CNNCreator_reinforcementConfig2.CNNCreator_reinforcementConfig2()
qnet_creator.setWeightInitializer(initializer) qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context) qnet_creator.construct(context)
......
...@@ -54,7 +54,6 @@ if __name__ == "__main__": ...@@ -54,7 +54,6 @@ if __name__ == "__main__":
context = mx.cpu() context = mx.cpu()
initializer = mx.init.Normal() initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
qnet_creator = CNNCreator_reinforcementConfig3.CNNCreator_reinforcementConfig3() qnet_creator = CNNCreator_reinforcementConfig3.CNNCreator_reinforcementConfig3()
qnet_creator.setWeightInitializer(initializer) qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context) qnet_creator.construct(context)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment