Commit 8effddcb authored by Sebastian Nickels's avatar Sebastian Nickels

Implemented multiple inputs

parent 9d15e90e
......@@ -20,8 +20,22 @@
*/
package de.monticore.lang.monticar.emadl;
import de.monticore.lang.monticar.emadl.generator.EMADLGeneratorCli;
import de.se_rwth.commons.logging.Log;
import org.junit.Test;
import static junit.framework.TestCase.assertTrue;
public class IntegrationGluonTest extends IntegrationTest {
public IntegrationGluonTest() {
super("GLUON", "39253EC049D4A4E5FA0536AD34874B9D#1DBAEE1B1BD83FB7CB5F70AE91B29638#C4C23549E737A759721D6694C75D9771#5AF0CE68E408E8C1F000E49D72AC214A");
}
@Test
public void testMultipleInputs() {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "MultipleInputs", "-b", "GLUON"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
}
}
configuration MultipleInputs{
num_epoch:10
batch_size:5
context:cpu
optimizer:adam{
learning_rate:0.01
learning_rate_decay:0.8
step_size:1000
weight_decay:0.0001
}
}
component MultipleInputs{
ports in Z(0:255)^{3, 32, 32} data[2],
out Q(0:1)^{10} softmax;
implementation CNN {
(
data[0] ->
Convolution(kernel=(5,5), channels=8, padding="valid") ->
Convolution(kernel=(5,5), channels=8, padding="valid") ->
FullyConnected(units=128) ->
Dropout()
|
data[1] ->
Convolution(kernel=(5,5), channels=8, padding="valid") ->
Convolution(kernel=(5,5), channels=8, padding="valid") ->
FullyConnected(units=128) ->
Dropout()
) ->
Concatenate() ->
FullyConnected(units=10) ->
Softmax() ->
softmax;
}
}
\ No newline at end of file
......@@ -10,3 +10,4 @@ VGG16 data/VGG16
ResNeXt50 data/ResNeXt50
instanceTestCifar.CifarNetwork src/test/resources/training_data
mnist.LeNetNetwork data/mnist.LeNetNetwork
MultipleInputs src/test/resources/training_data/MultipleInputs
\ No newline at end of file
......@@ -13,9 +13,10 @@ class CNNPredictor_mnist_mnistClassifier_net{
public:
const std::string json_file = "model/mnist.LeNetNetwork/model_newest-symbol.json";
const std::string param_file = "model/mnist.LeNetNetwork/model_newest-0000.params";
//const std::vector<std::string> input_keys = {"data"};
const std::vector<std::string> input_keys = {"image"};
const std::vector<std::vector<mx_uint>> input_shapes = {{1,1,28,28}};
const std::vector<std::string> input_keys = {
"data"
};
const std::vector<std::vector<mx_uint>> input_shapes = {{1, 1, 28, 28}};
const bool use_gpu = false;
PredictorHandle handle;
......@@ -67,15 +68,17 @@ public:
const mx_uint num_input_nodes = input_keys.size();
const char* input_key[1] = { "data" };
const char** input_keys_ptr = input_key;
const char* input_keys_ptr[num_input_nodes];
for(mx_uint i = 0; i < num_input_nodes; i++){
input_keys_ptr[i] = input_keys[i].c_str();
}
mx_uint shape_data_size = 0;
mx_uint input_shape_indptr[input_shapes.size() + 1];
input_shape_indptr[0] = 0;
for(mx_uint i = 0; i < input_shapes.size(); i++){
input_shape_indptr[i+1] = input_shapes[i].size();
shape_data_size += input_shapes[i].size();
input_shape_indptr[i+1] = shape_data_size;
}
mx_uint input_shape_data[shape_data_size];
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment