Commit e9c67a96 authored by Nicola Gatto's avatar Nicola Gatto

Set critic component name and adapt tests

parent cc9a61ce
Pipeline #162446 failed with stages
......@@ -554,6 +554,7 @@ public class EMADLGenerator {
+ fullCriticName + " does not have a CNN implementation but is required to have one");
System.exit(-1);
}
critic.get().setComponentName(fullCriticName);
configuration.setCriticNetwork(new ArchitectureAdapter(fullCriticName, critic.get()));
}
......
......@@ -267,8 +267,8 @@ public class GenerationTest extends AbstractSymtabTest {
"HelperA.h",
"start_training.sh",
"reinforcement_learning/__init__.py",
"reinforcement_learning/CNNCreator_MountaincarCritic.py",
"reinforcement_learning/CNNNet_MountaincarCritic.py",
"reinforcement_learning/CNNCreator_mountaincar_agent_mountaincarCritic.py",
"reinforcement_learning/CNNNet_mountaincar_agent_mountaincarCritic.py",
"reinforcement_learning/strategy.py",
"reinforcement_learning/agent.py",
"reinforcement_learning/environment.py",
......
......@@ -4,7 +4,7 @@ component MountaincarCritic {
ports
in Q^{2} state,
in Q^{2} action,
out Q(-1:1)^{1} qvalues;
out Q(-oo:oo)^{1} qvalues;
implementation CNN {
(
......
from reinforcement_learning.agent import DdpgAgent
from reinforcement_learning.util import AgentSignalHandler
from reinforcement_learning.cnnarch_logger import ArchLogger
from reinforcement_learning.CNNCreator_MountaincarCritic import CNNCreator_MountaincarCritic
from reinforcement_learning.CNNCreator_mountaincar_agent_mountaincarCritic import CNNCreator_mountaincar_agent_mountaincarCritic
import reinforcement_learning.environment
import CNNCreator_mountaincar_master_actor
......@@ -49,7 +49,7 @@ if __name__ == "__main__":
context = mx.cpu()
actor_creator = CNNCreator_mountaincar_master_actor.CNNCreator_mountaincar_master_actor()
actor_creator.construct(context)
critic_creator = CNNCreator_MountaincarCritic()
critic_creator = CNNCreator_mountaincar_agent_mountaincarCritic()
critic_creator.construct(context)
agent_params = {
......
import mxnet as mx
import logging
import os
from CNNNet_MountaincarCritic import Net_0
from CNNNet_mountaincar_agent_mountaincarCritic import Net_0
class CNNCreator_MountaincarCritic:
_model_dir_ = "model/MountaincarCritic/"
class CNNCreator_mountaincar_agent_mountaincarCritic:
_model_dir_ = "model/mountaincar.agent.MountaincarCritic/"
_model_prefix_ = "model"
def __init__(self):
......@@ -50,7 +50,7 @@ class CNNCreator_MountaincarCritic:
self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[0].hybridize()
self.networks[0](mx.nd.zeros((1, 2,), ctx=context), mx.nd.zeros((1, 1,), ctx=context))
self.networks[0](mx.nd.zeros((1, 2,), ctx=context), mx.nd.zeros((1, 2,), ctx=context))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment