Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
C
CNNTrainLang
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
1
Issues
1
List
Boards
Labels
Service Desk
Milestones
Iterations
Merge Requests
0
Merge Requests
0
Requirements
Requirements
List
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Test Cases
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Operations
Operations
Incidents
Environments
Packages & Registries
Packages & Registries
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issue
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
monticore
EmbeddedMontiArc
languages
CNNTrainLang
Commits
d28d94d3
Commit
d28d94d3
authored
Apr 05, 2020
by
Julian Dierkes
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
introduced tests for inter Architecture GAN CoCos
parent
bbc114e1
Pipeline
#264734
failed with stages
Changes
8
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
439 additions
and
21 deletions
+439
-21
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java
...onticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java
+5
-1
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANDiscriminatorQNetworkDependency.java
...train/_cocos/CheckGANDiscriminatorQNetworkDependency.java
+1
-19
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorDiscriminatorDependency.java
...rain/_cocos/CheckGANGeneratorDiscriminatorDependency.java
+30
-0
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorHasOneOutput.java
...nticar/cnntrain/_cocos/CheckGANGeneratorHasOneOutput.java
+29
-0
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorQNetworkDependency.java
.../cnntrain/_cocos/CheckGANGeneratorQNetworkDependency.java
+31
-0
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANQNetworkhasOneInput.java
...monticar/cnntrain/_cocos/CheckGANQNetworkhasOneInput.java
+29
-0
src/test/java/de/monticore/lang/monticar/cnntrain/cocos/InterCocoTest.java
...monticore/lang/monticar/cnntrain/cocos/InterCocoTest.java
+148
-1
src/test/java/de/monticore/lang/monticar/cnntrain/cocos/NNArchitecturerBuilder.java
.../lang/monticar/cnntrain/cocos/NNArchitecturerBuilder.java
+166
-0
No files found.
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java
View file @
d28d94d3
...
@@ -59,7 +59,11 @@ public class CNNTrainCocos {
...
@@ -59,7 +59,11 @@ public class CNNTrainCocos {
public
static
void
checkGANCocos
(
final
ConfigurationSymbol
configurationSymbol
)
{
public
static
void
checkGANCocos
(
final
ConfigurationSymbol
configurationSymbol
)
{
CNNTrainConfigurationSymbolChecker
checker
=
new
CNNTrainConfigurationSymbolChecker
()
CNNTrainConfigurationSymbolChecker
checker
=
new
CNNTrainConfigurationSymbolChecker
()
.
addCoCo
(
new
CheckGANNetworkPorts
());
.
addCoCo
(
new
CheckGANDiscriminatorQNetworkDependency
())
.
addCoCo
(
new
CheckGANGeneratorDiscriminatorDependency
())
.
addCoCo
(
new
CheckGANGeneratorHasOneOutput
())
.
addCoCo
(
new
CheckGANGeneratorQNetworkDependency
())
.
addCoCo
(
new
CheckGANQNetworkhasOneInput
());
checker
.
checkAll
(
configurationSymbol
);
checker
.
checkAll
(
configurationSymbol
);
}
}
}
}
\ No newline at end of file
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGAN
NetworkPorts
.java
→
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGAN
DiscriminatorQNetworkDependency
.java
View file @
d28d94d3
...
@@ -13,25 +13,16 @@ import de.se_rwth.commons.logging.Log;
...
@@ -13,25 +13,16 @@ import de.se_rwth.commons.logging.Log;
import
java.util.Optional
;
import
java.util.Optional
;
public
class
CheckGAN
NetworkPorts
implements
CNNTrainConfigurationSymbolCoCo
{
public
class
CheckGAN
DiscriminatorQNetworkDependency
implements
CNNTrainConfigurationSymbolCoCo
{
public
void
CheckGANNetworkPorts
()
{
}
public
void
CheckGANNetworkPorts
()
{
}
@Override
@Override
public
void
check
(
ConfigurationSymbol
configurationSymbol
)
{
public
void
check
(
ConfigurationSymbol
configurationSymbol
)
{
NNArchitectureSymbol
gen
=
configurationSymbol
.
getTrainedArchitecture
().
get
();
NNArchitectureSymbol
dis
=
configurationSymbol
.
getDiscriminatorNetwork
().
get
();
NNArchitectureSymbol
dis
=
configurationSymbol
.
getDiscriminatorNetwork
().
get
();
Optional
<
NNArchitectureSymbol
>
qnet
=
configurationSymbol
.
getQNetwork
();
Optional
<
NNArchitectureSymbol
>
qnet
=
configurationSymbol
.
getQNetwork
();
if
(
gen
.
getOutputs
().
size
()
!=
1
)
Log
.
error
(
"0"
+
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
+
" Generator network has more then one output, "
+
"but is supposed to only have one"
);
if
(
qnet
.
isPresent
()
&&
qnet
.
get
().
getInputs
().
size
()
!=
1
)
Log
.
error
(
"0"
+
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
+
" Q-Network has more then one input, "
+
"but is supposed to only have one"
);
if
(
qnet
.
isPresent
()
&&
dis
.
getOutputs
().
size
()
!=
2
)
if
(
qnet
.
isPresent
()
&&
dis
.
getOutputs
().
size
()
!=
2
)
Log
.
error
(
"0"
+
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
+
" Discriminator needs exactly 2 output "
+
Log
.
error
(
"0"
+
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
+
" Discriminator needs exactly 2 output "
+
"ports when q-network is given"
);
"ports when q-network is given"
);
...
@@ -48,14 +39,5 @@ public class CheckGANNetworkPorts implements CNNTrainConfigurationSymbolCoCo {
...
@@ -48,14 +39,5 @@ public class CheckGANNetworkPorts implements CNNTrainConfigurationSymbolCoCo {
if
(
qnet
.
isPresent
()
&&
!
qnet
.
get
().
getInputs
().
get
(
0
).
equals
(
"features"
))
if
(
qnet
.
isPresent
()
&&
!
qnet
.
get
().
getInputs
().
get
(
0
).
equals
(
"features"
))
Log
.
error
(
"0"
+
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
+
" Input to q-network needs to be named features"
);
Log
.
error
(
"0"
+
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
+
" Input to q-network needs to be named features"
);
if
(!
gen
.
getOutputs
().
get
(
0
).
equals
(
dis
.
getInputs
().
get
(
0
)))
Log
.
error
(
"0"
+
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
+
" The generator networks output name does not "
+
"fit the first discriminators input name"
);
if
(
qnet
.
isPresent
())
if
(
gen
.
getInputs
().
contains
(
qnet
.
get
().
getOutputs
()))
Log
.
error
(
"0"
+
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
+
" Generator input does not contain all "
+
"latent-codes outputted by q-network"
);
}
}
}
}
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorDiscriminatorDependency.java
0 → 100644
View file @
d28d94d3
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package
de.monticore.lang.monticar.cnntrain._cocos
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol
;
import
de.monticore.lang.monticar.cnntrain.helper.ErrorCodes
;
import
de.se_rwth.commons.logging.Log
;
import
java.util.Optional
;
public
class
CheckGANGeneratorDiscriminatorDependency
implements
CNNTrainConfigurationSymbolCoCo
{
public
void
CheckGANNetworkPorts
()
{
}
@Override
public
void
check
(
ConfigurationSymbol
configurationSymbol
)
{
NNArchitectureSymbol
gen
=
configurationSymbol
.
getTrainedArchitecture
().
get
();
NNArchitectureSymbol
dis
=
configurationSymbol
.
getDiscriminatorNetwork
().
get
();
if
(!
gen
.
getOutputs
().
get
(
0
).
equals
(
dis
.
getInputs
().
get
(
0
)))
Log
.
error
(
"0"
+
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
+
" The generator networks output name does not "
+
"fit the first discriminators input name"
);
}
}
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorHasOneOutput.java
0 → 100644
View file @
d28d94d3
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package
de.monticore.lang.monticar.cnntrain._cocos
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol
;
import
de.monticore.lang.monticar.cnntrain.helper.ErrorCodes
;
import
de.se_rwth.commons.logging.Log
;
import
java.util.Optional
;
public
class
CheckGANGeneratorHasOneOutput
implements
CNNTrainConfigurationSymbolCoCo
{
public
void
CheckGANNetworkPorts
()
{
}
@Override
public
void
check
(
ConfigurationSymbol
configurationSymbol
)
{
NNArchitectureSymbol
gen
=
configurationSymbol
.
getTrainedArchitecture
().
get
();
if
(
gen
.
getOutputs
().
size
()
!=
1
)
Log
.
error
(
"0"
+
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
+
" Generator network has more then one output, "
+
"but is supposed to only have one"
);
}
}
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorQNetworkDependency.java
0 → 100644
View file @
d28d94d3
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package
de.monticore.lang.monticar.cnntrain._cocos
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol
;
import
de.monticore.lang.monticar.cnntrain.helper.ErrorCodes
;
import
de.se_rwth.commons.logging.Log
;
import
java.util.Optional
;
public
class
CheckGANGeneratorQNetworkDependency
implements
CNNTrainConfigurationSymbolCoCo
{
public
void
CheckGANNetworkPorts
()
{
}
@Override
public
void
check
(
ConfigurationSymbol
configurationSymbol
)
{
NNArchitectureSymbol
gen
=
configurationSymbol
.
getTrainedArchitecture
().
get
();
Optional
<
NNArchitectureSymbol
>
qnet
=
configurationSymbol
.
getQNetwork
();
if
(
qnet
.
isPresent
())
if
(!
gen
.
getInputs
().
containsAll
(
qnet
.
get
().
getOutputs
()))
Log
.
error
(
"0"
+
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
+
" Generator input does not contain all "
+
"latent-codes outputted by q-network"
);
}
}
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANQNetworkhasOneInput.java
0 → 100644
View file @
d28d94d3
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package
de.monticore.lang.monticar.cnntrain._cocos
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol
;
import
de.monticore.lang.monticar.cnntrain.helper.ErrorCodes
;
import
de.se_rwth.commons.logging.Log
;
import
java.util.Optional
;
public
class
CheckGANQNetworkhasOneInput
implements
CNNTrainConfigurationSymbolCoCo
{
public
void
CheckGANNetworkPorts
()
{
}
@Override
public
void
check
(
ConfigurationSymbol
configurationSymbol
)
{
Optional
<
NNArchitectureSymbol
>
qnet
=
configurationSymbol
.
getQNetwork
();
if
(
qnet
.
isPresent
()
&&
qnet
.
get
().
getInputs
().
size
()
!=
1
)
Log
.
error
(
"0"
+
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
+
" Q-Network has more then one input, "
+
"but is supposed to only have one"
);
}
}
src/test/java/de/monticore/lang/monticar/cnntrain/cocos/InterCocoTest.java
View file @
d28d94d3
...
@@ -43,6 +43,31 @@ public class InterCocoTest extends AbstractCoCoTest {
...
@@ -43,6 +43,31 @@ public class InterCocoTest extends AbstractCoCoTest {
checkValidCriticArchitecture
(
configurationSymbol
);
checkValidCriticArchitecture
(
configurationSymbol
);
}
}
@Test
public
void
testValidDefaultGAN
()
{
// given
final
NNArchitectureSymbol
validGenerator
=
NNBuilder
.
getValidGenerator
();
final
NNArchitectureSymbol
validDiscriminator
=
NNBuilder
.
getValidDiscriminator
();
ConfigurationSymbol
configurationSymbol
=
getDefaultGANConfigurationSymbolFrom
(
"valid_tests"
,
"DefaultGANConfig"
,
validGenerator
,
validDiscriminator
);
// when
checkValidGANArchitecture
(
configurationSymbol
);
}
@Test
public
void
testValidInfoGAN
()
{
// given
final
NNArchitectureSymbol
validGenerator
=
NNBuilder
.
getValidInfoGANGenerator
();
final
NNArchitectureSymbol
validDiscriminator
=
NNBuilder
.
getValidDiscriminatorWithQNet
();
final
NNArchitectureSymbol
validQNetwork
=
NNBuilder
.
getValidQNetwork
();
ConfigurationSymbol
configurationSymbol
=
getInfoGANConfigurationSymbolFrom
(
"valid_tests"
,
"InfoGANConfig"
,
validGenerator
,
validDiscriminator
,
validQNetwork
);
// when
checkValidGANArchitecture
(
configurationSymbol
);
}
@Test
@Test
public
void
testInvalidTrainingArchitectureWithTwoInputs
()
{
public
void
testInvalidTrainingArchitectureWithTwoInputs
()
{
// given
// given
...
@@ -223,6 +248,79 @@ public class InterCocoTest extends AbstractCoCoTest {
...
@@ -223,6 +248,79 @@ public class InterCocoTest extends AbstractCoCoTest {
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
TRAINED_ARCHITECTURE_ERROR
));
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
TRAINED_ARCHITECTURE_ERROR
));
}
}
@Test
public
void
testInvalidDiscriminatorQNetworkDependency
()
{
//given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckGANDiscriminatorQNetworkDependency
();
NNArchitectureSymbol
gen
=
NNBuilder
.
getValidGenerator
();
NNArchitectureSymbol
dis
=
NNBuilder
.
getValidDiscriminator
();
NNArchitectureSymbol
qnet
=
NNBuilder
.
getValidQNetwork
();
ConfigurationSymbol
configurationSymbol
=
getInfoGANConfigurationSymbolFrom
(
"valid_tests"
,
"InfoGANConfig"
,
gen
,
dis
,
qnet
);
// when
checkInvalidGANArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
));
}
@Test
public
void
testInvalidGeneratorQNetworkDependency
()
{
//given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckGANGeneratorQNetworkDependency
();
NNArchitectureSymbol
gen
=
NNBuilder
.
getValidGenerator
();
NNArchitectureSymbol
dis
=
NNBuilder
.
getValidDiscriminatorWithQNet
();
NNArchitectureSymbol
qnet
=
NNBuilder
.
getValidQNetwork
();
ConfigurationSymbol
configurationSymbol
=
getInfoGANConfigurationSymbolFrom
(
"valid_tests"
,
"InfoGANConfig"
,
gen
,
dis
,
qnet
);
// when
checkInvalidGANArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
));
}
@Test
public
void
testInvalidGeneratorHasMultipleOutputs
()
{
//given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckGANGeneratorHasOneOutput
();
NNArchitectureSymbol
gen
=
NNBuilder
.
getInvalidGeneratorMultipleOutputs
();
NNArchitectureSymbol
dis
=
NNBuilder
.
getValidDiscriminator
();
ConfigurationSymbol
configurationSymbol
=
getDefaultGANConfigurationSymbolFrom
(
"valid_tests"
,
"DefaultGANConfig"
,
gen
,
dis
);
// when
checkInvalidGANArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
));
}
@Test
public
void
testInvalidGeneratorDiscriminatorDependency
()
{
//given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckGANGeneratorDiscriminatorDependency
();
NNArchitectureSymbol
gen
=
NNBuilder
.
getValidGenerator
();
NNArchitectureSymbol
dis
=
NNBuilder
.
getValidDiscriminatorDifferentInput
();
ConfigurationSymbol
configurationSymbol
=
getDefaultGANConfigurationSymbolFrom
(
"valid_tests"
,
"DefaultGANConfig"
,
gen
,
dis
);
// when
checkInvalidGANArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
));
}
@Test
public
void
testInvalidQNetworkMultipleInputs
()
{
//given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckGANQNetworkhasOneInput
();
NNArchitectureSymbol
gen
=
NNBuilder
.
getValidGenerator
();
NNArchitectureSymbol
dis
=
NNBuilder
.
getValidDiscriminatorWithQNet
();
NNArchitectureSymbol
qnet
=
NNBuilder
.
getInvalidQNetworkMultipleInputs
();
ConfigurationSymbol
configurationSymbol
=
getInfoGANConfigurationSymbolFrom
(
"valid_tests"
,
"InfoGANConfig"
,
gen
,
dis
,
qnet
);
// when
checkInvalidGANArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
GAN_ARCHITECTURE_ERROR
));
}
private
ConfigurationSymbol
getConfigurationSymbolFrom
(
final
String
modelPath
,
final
String
model
,
private
ConfigurationSymbol
getConfigurationSymbolFrom
(
final
String
modelPath
,
final
String
model
,
final
NNArchitectureSymbol
actorArchitecture
,
final
NNArchitectureSymbol
criticArchitecture
)
{
final
NNArchitectureSymbol
actorArchitecture
,
final
NNArchitectureSymbol
criticArchitecture
)
{
final
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolByPath
(
modelPath
,
model
);
final
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolByPath
(
modelPath
,
model
);
...
@@ -231,6 +329,25 @@ public class InterCocoTest extends AbstractCoCoTest {
...
@@ -231,6 +329,25 @@ public class InterCocoTest extends AbstractCoCoTest {
return
configurationSymbol
;
return
configurationSymbol
;
}
}
private
ConfigurationSymbol
getDefaultGANConfigurationSymbolFrom
(
final
String
modelPath
,
final
String
model
,
final
NNArchitectureSymbol
genArchitecture
,
final
NNArchitectureSymbol
disArchitecture
)
{
final
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolByPath
(
modelPath
,
model
);
configurationSymbol
.
setTrainedArchitecture
(
genArchitecture
);
configurationSymbol
.
setDiscriminatorNetwork
(
disArchitecture
);
return
configurationSymbol
;
}
private
ConfigurationSymbol
getInfoGANConfigurationSymbolFrom
(
final
String
modelPath
,
final
String
model
,
final
NNArchitectureSymbol
genArchitecture
,
final
NNArchitectureSymbol
disArchitecture
,
final
NNArchitectureSymbol
qnetArchitecture
)
{
final
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolByPath
(
modelPath
,
model
);
configurationSymbol
.
setTrainedArchitecture
(
genArchitecture
);
configurationSymbol
.
setDiscriminatorNetwork
(
disArchitecture
);
configurationSymbol
.
setQNetwork
(
qnetArchitecture
);
return
configurationSymbol
;
}
private
ConfigurationSymbol
getConfigurationSymbolByPath
(
final
String
modelPath
,
final
String
model
)
{
private
ConfigurationSymbol
getConfigurationSymbolByPath
(
final
String
modelPath
,
final
String
model
)
{
return
getCompilationUnitSymbol
(
modelPath
,
model
).
getConfiguration
();
return
getCompilationUnitSymbol
(
modelPath
,
model
).
getConfiguration
();
}
}
...
@@ -238,6 +355,7 @@ public class InterCocoTest extends AbstractCoCoTest {
...
@@ -238,6 +355,7 @@ public class InterCocoTest extends AbstractCoCoTest {
private
enum
CheckOption
{
private
enum
CheckOption
{
TRAINED_ARCHITECTURE_COCOS
,
TRAINED_ARCHITECTURE_COCOS
,
CRITIC_ARCHITECTURE_COCOS
,
CRITIC_ARCHITECTURE_COCOS
,
GAN_ARCHITECTURE_COCOS
,
}
}
private
void
checkInvalidArchitecture
(
private
void
checkInvalidArchitecture
(
...
@@ -249,10 +367,13 @@ public class InterCocoTest extends AbstractCoCoTest {
...
@@ -249,10 +367,13 @@ public class InterCocoTest extends AbstractCoCoTest {
if
(
checkOption
.
equals
(
CheckOption
.
TRAINED_ARCHITECTURE_COCOS
))
{
if
(
checkOption
.
equals
(
CheckOption
.
TRAINED_ARCHITECTURE_COCOS
))
{
CNNTrainCocos
.
checkTrainedArchitectureCoCos
(
configurationSymbol
);
CNNTrainCocos
.
checkTrainedArchitectureCoCos
(
configurationSymbol
);
}
else
{
}
else
if
(
checkOption
.
equals
(
CheckOption
.
CRITIC_ARCHITECTURE_COCOS
))
{
CNNTrainCocos
.
checkCriticCocos
(
configurationSymbol
);
CNNTrainCocos
.
checkCriticCocos
(
configurationSymbol
);
}
else
if
(
checkOption
.
equals
(
CheckOption
.
GAN_ARCHITECTURE_COCOS
))
{
CNNTrainCocos
.
checkGANCocos
(
configurationSymbol
);
}
}
expectedErrors
.
checkExpectedPresent
(
Log
.
getFindings
(),
"Got no findings when checking all "
expectedErrors
.
checkExpectedPresent
(
Log
.
getFindings
(),
"Got no findings when checking all "
+
"cocos. Did you forget to add the new coco to MontiArcCocos?"
);
+
"cocos. Did you forget to add the new coco to MontiArcCocos?"
);
Log
.
getFindings
().
clear
();
Log
.
getFindings
().
clear
();
...
@@ -262,6 +383,19 @@ public class InterCocoTest extends AbstractCoCoTest {
...
@@ -262,6 +383,19 @@ public class InterCocoTest extends AbstractCoCoTest {
+
"the given coco. Did you pass an empty coco checker?"
);
+
"the given coco. Did you pass an empty coco checker?"
);
}
}
private
void
checkInvalidArchitectureOnlyCoCo
(
final
ConfigurationSymbol
configurationSymbol
,
final
CNNTrainConfigurationSymbolCoCo
cocoUUT
,
final
ExpectedErrorInfo
expectedErrors
)
{
Log
.
getFindings
().
clear
();
Log
.
getFindings
().
clear
();
CNNTrainConfigurationSymbolChecker
checker
=
new
CNNTrainConfigurationSymbolChecker
().
addCoCo
(
cocoUUT
);
checker
.
checkAll
(
configurationSymbol
);
expectedErrors
.
checkOnlyExpectedPresent
(
Log
.
getFindings
(),
"Got no findings when checking only "
+
"the given coco. Did you pass an empty coco checker?"
);
}
private
void
checkInvalidTrainedArchitecture
(
private
void
checkInvalidTrainedArchitecture
(
final
ConfigurationSymbol
configurationSymbol
,
final
ConfigurationSymbol
configurationSymbol
,
final
CNNTrainConfigurationSymbolCoCo
cocoUUT
,
final
CNNTrainConfigurationSymbolCoCo
cocoUUT
,
...
@@ -287,4 +421,17 @@ public class InterCocoTest extends AbstractCoCoTest {
...
@@ -287,4 +421,17 @@ public class InterCocoTest extends AbstractCoCoTest {
CNNTrainCocos
.
checkCriticCocos
(
configurationSymbol
);
CNNTrainCocos
.
checkCriticCocos
(
configurationSymbol
);
new
ExpectedErrorInfo
().
checkOnlyExpectedPresent
(
Log
.
getFindings
());
new
ExpectedErrorInfo
().
checkOnlyExpectedPresent
(
Log
.
getFindings
());
}
}
private
void
checkValidGANArchitecture
(
final
ConfigurationSymbol
configurationSymbol
)
{
Log
.
getFindings
().
clear
();
CNNTrainCocos
.
checkGANCocos
(
configurationSymbol
);
new
ExpectedErrorInfo
().
checkOnlyExpectedPresent
(
Log
.
getFindings
());
}
private
void
checkInvalidGANArchitecture
(
final
ConfigurationSymbol
configurationSymbol
,
final
CNNTrainConfigurationSymbolCoCo
cocoUUT
,
ExpectedErrorInfo
expectedErrors
)
{
checkInvalidArchitectureOnlyCoCo
(
configurationSymbol
,
cocoUUT
,
expectedErrors
);
}
}
}
src/test/java/de/monticore/lang/monticar/cnntrain/cocos/NNArchitecturerBuilder.java
View file @
d28d94d3
...
@@ -13,6 +13,8 @@ import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
...
@@ -13,6 +13,8 @@ import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import
de.monticore.lang.monticar.cnntrain.annotations.Range
;
import
de.monticore.lang.monticar.cnntrain.annotations.Range
;
import
java.awt.*
;
import
java.awt.*
;
import
java.util.ArrayList
;
import
java.util.HashMap
;
import
java.util.List
;
import
java.util.List
;
import
java.util.Map
;
import
java.util.Map
;
...
@@ -228,4 +230,168 @@ public class NNArchitecturerBuilder {
...
@@ -228,4 +230,168 @@ public class NNArchitecturerBuilder {
dimensions
,
getValidCriticTypes
(),
getValidCriticRanges
());
dimensions
,
getValidCriticTypes
(),
getValidCriticRanges
());
}
}
public
NNArchitectureSymbol
getValidGenerator
()
{
ArrayList
input
=
Lists
.
newArrayList
(
"noise"
);
ArrayList
output
=
Lists
.
newArrayList
(
"data"
);
HashMap
dims
=
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
List
<
Integer
>>
builder
()
.
put
(
"noise"
,
Lists
.
newArrayList
(
100
))
.
put
(
"data"
,
Lists
.
newArrayList
(
3
,
28
,
28
))
.
build
());
HashMap
types
=
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
String
>
builder
()
.
put
(
"noise"
,
"Q"
)
.
put
(
"data"
,
"Q"
)
.
build
());
HashMap
ranges
=
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
Range
>
builder
()
.
put
(
"noise"
,
Range
.
withInfinityLimits
()
)
.
put
(
"data"
,
Range
.
withLimits
(-
1
,
1
))
.
build
());
return
getNNArchitectureSymbolFrom
(
"GeneratorValid"
,
input
,
output
,
dims
,
types
,
ranges
);
}
public
NNArchitectureSymbol
getValidInfoGANGenerator
()
{
ArrayList
input
=
Lists
.
newArrayList
(
"noise"
,
"c1"
);
ArrayList
output
=
Lists
.
newArrayList
(
"data"
);
HashMap
dims
=
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
List
<
Integer
>>
builder
()
.
put
(
"noise"
,
Lists
.
newArrayList
(
100
))
.
put
(
"data"
,
Lists
.
newArrayList
(
3
,
28
,
28
))
.
put
(
"c1"
,
Lists
.
newArrayList
(
10
))
.
build
());
HashMap
types
=
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
String
>
builder
()
.
put
(
"noise"
,
"Q"
)
.
put
(
"data"
,
"Q"
)
.
put
(
"c1"
,
"Q"
)
.
build
());
HashMap
ranges
=
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
Range
>
builder
()
.
put
(
"noise"
,
Range
.
withInfinityLimits
()
)
.
put
(
"data"
,
Range
.
withLimits
(-
1
,
1
))
.
put
(
"c1"
,
Range
.
withLimits
(
0
,
1
))
.
build
());
return
getNNArchitectureSymbolFrom
(
"GeneratorValid"
,
input
,
output
,
dims
,
types
,
ranges
);
}
public
NNArchitectureSymbol
getInvalidGeneratorMultipleOutputs
()
{
ArrayList
input
=
Lists
.
newArrayList
(
"noise"
);
ArrayList
output
=
Lists
.
newArrayList
(
"data1"
,
"data2"
);
HashMap
dims
=
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
List
<
Integer
>>
builder
()
.
put
(
"noise"
,
Lists
.
newArrayList
(
100
))
.
put
(
"data1"
,
Lists
.
newArrayList
(
3
,
28
,
28
))
.
put
(
"data2"
,
Lists
.
newArrayList
(
10
))
.
build
());
HashMap
types
=
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
String
>
builder
()
.
put
(
"noise"
,
"Q"
)
.
put
(
"data1"
,
"Q"
)
.
put
(
"data2"
,
"Q"
)
.
build
());
HashMap
ranges
=
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
Range
>
builder
()
.
put
(
"noise"
,
Range
.
withInfinityLimits
()
)
.
put
(
"data1"
,
Range
.
withLimits
(-
1
,
1
))
.
put
(
"data2"
,
Range
.
withLimits
(
0
,
1
))
.
build
());
return
getNNArchitectureSymbolFrom
(
"GeneratorInvalidGeneratorMultipleOutputs"
,
input
,
output
,
dims
,
types
,
ranges
);
}
public
NNArchitectureSymbol
getValidDiscriminator
()
{
ArrayList
input
=
Lists
.
newArrayList
(
"data"
);
ArrayList
output
=
Lists
.
newArrayList
(
"dis"
);
HashMap
dims
=
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
List
<
Integer
>>
builder
()
.
put
(
"data"
,
Lists
.
newArrayList
(
3
,
28
,
28
))
.
put
(
"dis"
,
Lists
.
newArrayList
(
1
))
.
build
());
HashMap
types
=
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
String
>
builder
()
.
put
(
"data"
,
"Q"
)
.
put
(
"dis"
,
"Q"
)
.
build
());
HashMap
ranges
=
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
Range
>
builder
()
.
put
(
"data"
,
Range
.
withInfinityLimits
()
)
.
put
(
"dis"
,
Range
.
withLimits
(
0
,
1
))
.
build
());
return
getNNArchitectureSymbolFrom
(
"DiscriminatorValid"
,
input
,
output
,
dims
,
types
,
ranges
);
}
public
NNArchitectureSymbol
getValidDiscriminatorWithQNet
()
{
ArrayList
input
=
Lists
.
newArrayList
(
"data"
);
ArrayList
output
=
Lists
.
newArrayList
(
"dis"
,
"features"
);
HashMap
dims
=
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
List
<
Integer
>>
builder
()
.
put
(
"data"
,
Lists
.
newArrayList
(
3
,
28
,
28
))
.
put
(
"dis"
,
Lists
.
newArrayList
(
1
))
.
put
(
"features"
,
Lists
.
newArrayList
(
1024
))
.
build
());
HashMap
types
=
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
String
>
builder
()
.
put
(
"data"
,
"Q"
)
.
put
(
"dis"
,
"Q"
)
.
put
(
"features"
,
"Q"
)
.
build
());
HashMap
ranges
=
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
Range
>
builder
()
.
put
(
"data"
,
Range
.
withInfinityLimits
()
)
.
put
(
"dis"
,
Range
.
withLimits
(
0
,
1
))
.
put
(
"features"
,
Range
.
withInfinityLimits
())
.
build
());
return
getNNArchitectureSymbolFrom
(
"DiscriminatorValidQNet"
,
input
,
output
,
dims
,
types
,
ranges
);
}
public
NNArchitectureSymbol
getValidDiscriminatorDifferentInput
()
{
ArrayList
input
=
Lists
.
newArrayList
(
"data2"
);
ArrayList
output
=
Lists
.
newArrayList
(
"dis"
);
HashMap
dims
=
Maps
<