Commit 10de417f authored by nilsfreyer's avatar nilsfreyer

changed directory dependencies to componentName for weightsharing

parent fc56ba27
Pipeline #101587 failed with stages
in 28 seconds
...@@ -127,6 +127,10 @@ public class CNNArchTemplateController { ...@@ -127,6 +127,10 @@ public class CNNArchTemplateController {
return list; return list;
} }
public String getComponentName(){
return getArchitecture().getComponentName();
}
public void include(String relativePath, String templateWithoutFileEnding, Writer writer){ public void include(String relativePath, String templateWithoutFileEnding, Writer writer){
String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING; String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING;
Map<String, Object> ftlContext = new HashMap<>(); Map<String, Object> ftlContext = new HashMap<>();
......
...@@ -19,8 +19,8 @@ class ${tc.fileNameWithoutEnding}: ...@@ -19,8 +19,8 @@ class ${tc.fileNameWithoutEnding}:
module = None module = None
_data_dir_ = "${tc.dataPath}/" _data_dir_ = "${tc.dataPath}/"
_model_dir_ = "model/${tc.fullArchitectureName}/" _model_dir_ = "model/${tc.componentName}/"
_model_prefix_ = "${tc.architectureName}" _model_prefix_ = "model"
_input_names_ = [${tc.join(tc.architectureInputs, ",", "'", "'")}] _input_names_ = [${tc.join(tc.architectureInputs, ",", "'", "'")}]
_input_shapes_ = [<#list tc.architecture.inputs as input>(${tc.join(input.definition.type.dimensions, ",")})</#list>] _input_shapes_ = [<#list tc.architecture.inputs as input>(${tc.join(input.definition.type.dimensions, ",")})</#list>]
_output_names_ = [${tc.join(tc.architectureOutputs, ",", "'", "_label'")}] _output_names_ = [${tc.join(tc.architectureOutputs, ",", "'", "_label'")}]
......
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
class ${tc.fileNameWithoutEnding}{ class ${tc.fileNameWithoutEnding}{
public: public:
const std::string json_file = "model/${tc.fullArchitectureName}/${tc.architectureName}_newest-symbol.json"; const std::string json_file = "model/${tc.componentName}/model_newest-symbol.json";
const std::string param_file = "model/${tc.fullArchitectureName}/${tc.architectureName}_newest-0000.params"; const std::string param_file = "model/${tc.componentName}/model_newest-0000.params";
//const std::vector<std::string> input_keys = {"data"}; //const std::vector<std::string> input_keys = {"data"};
const std::vector<std::string> input_keys = {${tc.join(tc.architectureInputs, ",", "\"", "\"")}}; const std::vector<std::string> input_keys = {${tc.join(tc.architectureInputs, ",", "\"", "\"")}};
const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.architecture.inputs as input>{1,${tc.join(input.definition.type.dimensions, ",")}}<#if input?has_next>,</#if></#list>}; const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.architecture.inputs as input>{1,${tc.join(input.definition.type.dimensions, ",")}}<#if input?has_next>,</#if></#list>};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment