Commit 133c0326 authored by Sebastian Nickels's avatar Sebastian Nickels

Fixed some tests and removed MultipleInput/MultipleOutput test since the...

Fixed some tests and removed MultipleInput/MultipleOutput test since the functionality is already covered in MultipleStreams test
parent 741f22c7
Pipeline #155898 failed with stages
in 17 seconds
...@@ -235,7 +235,6 @@ public class GenerationTest extends AbstractSymtabTest { ...@@ -235,7 +235,6 @@ public class GenerationTest extends AbstractSymtabTest {
} }
} }
@Ignore
@Test @Test
public void gluonDdpgTest() { public void gluonDdpgTest() {
Log.getFindings().clear(); Log.getFindings().clear();
......
...@@ -42,30 +42,6 @@ public class IntegrationGluonTest extends IntegrationTest { ...@@ -42,30 +42,6 @@ public class IntegrationGluonTest extends IntegrationTest {
super("GLUON", "39253EC049D4A4E5FA0536AD34874B9D#1DBAEE1B1BD83FB7CB5F70AE91B29638#C4C23549E737A759721D6694C75D9771#5AF0CE68E408E8C1F000E49D72AC214A"); super("GLUON", "39253EC049D4A4E5FA0536AD34874B9D#1DBAEE1B1BD83FB7CB5F70AE91B29638#C4C23549E737A759721D6694C75D9771#5AF0CE68E408E8C1F000E49D72AC214A");
} }
@Test
public void testMultipleInputs() {
Log.getFindings().clear();
deleteHashFile(multipleInputsHashFile);
String[] args = {"-m", "src/test/resources/models/", "-r", "MultipleInputs", "-b", "GLUON"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
}
@Test
public void testMultipleOutputs() {
Log.getFindings().clear();
deleteHashFile(multipleOutputsHashFile);
String[] args = {"-m", "src/test/resources/models/", "-r", "MultipleOutputs", "-b", "GLUON"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
}
@Test @Test
public void testMultipleStreams() { public void testMultipleStreams() {
Log.getFindings().clear(); Log.getFindings().clear();
......
...@@ -94,7 +94,8 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net: ...@@ -94,7 +94,8 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
with autograd.record(): with autograd.record():
predictions_output = self._networks[0](image_data) predictions_output = self._networks[0](image_data)
loss = loss_functions['predictions'](predictions_output, predictions_label) loss = \
loss_functions['predictions'](predictions_output, predictions_label)
loss.backward() loss.backward()
...@@ -125,7 +126,8 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net: ...@@ -125,7 +126,8 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
batch.label[0].as_in_context(mx_context) batch.label[0].as_in_context(mx_context)
] ]
predictions_output = self._networks[0](image_data) if True: # Fix indentation
predictions_output = self._networks[0](image_data)
predictions = [ predictions = [
mx.nd.argmax(predictions_output, axis=1) mx.nd.argmax(predictions_output, axis=1)
...@@ -143,7 +145,9 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net: ...@@ -143,7 +145,9 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
batch.label[0].as_in_context(mx_context) batch.label[0].as_in_context(mx_context)
] ]
predictions_output = self._networks[0](image_data) if True: # Fix indentation
predictions_output = self._networks[0](image_data)
predictions = [ predictions = [
mx.nd.argmax(predictions_output, axis=1) mx.nd.argmax(predictions_output, axis=1)
] ]
......
...@@ -96,7 +96,7 @@ if __name__ == "__main__": ...@@ -96,7 +96,7 @@ if __name__ == "__main__":
resume_agent_params = { resume_agent_params = {
'session_dir': resume_directory, 'session_dir': resume_directory,
'environment': env, 'environment': env,
'net': qnet_creator.net, 'net': qnet_creator.networks[0],
} }
agent = DqnAgent.resume_from_session(**resume_agent_params) agent = DqnAgent.resume_from_session(**resume_agent_params)
else: else:
...@@ -108,4 +108,4 @@ if __name__ == "__main__": ...@@ -108,4 +108,4 @@ if __name__ == "__main__":
train_successful = agent.train() train_successful = agent.train()
if train_successful: if train_successful:
agent.save_best_network(qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_newest', epoch=0) agent.save_best_network(qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_0_newest', epoch=0)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment