Skip to content

Commit

Permalink
[NOID] Fixes #1596: Change key/secret to optional in apoc.nlp calls f…
Browse files Browse the repository at this point in the history
…or AWS (#4062)
  • Loading branch information
vga91 committed Jul 3, 2024
1 parent 1cdba92 commit 0493c03
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 36 deletions.
13 changes: 0 additions & 13 deletions full/src/main/kotlin/apoc/nlp/aws/AWSProcedures.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import apoc.nlp.NLPHelperFunctions
import apoc.nlp.NLPHelperFunctions.getNodeProperty
import apoc.nlp.NLPHelperFunctions.keyPhraseRelationshipType
import apoc.nlp.NLPHelperFunctions.partition
import apoc.nlp.NLPHelperFunctions.verifyKey
import apoc.nlp.NLPHelperFunctions.verifyNodeProperty
import apoc.nlp.NLPHelperFunctions.verifySource
import apoc.result.NodeWithMapResult
Expand Down Expand Up @@ -59,8 +58,6 @@ class AWSProcedures {
verifySource(source)
val nodeProperty = getNodeProperty(config)
verifyNodeProperty(source, nodeProperty)
verifyKey(config, "key")
verifyKey(config, "secret")

val client: AWSClient = awsClient(config)

Expand All @@ -78,8 +75,6 @@ class AWSProcedures {
verifySource(source)
val nodeProperty = getNodeProperty(config)
verifyNodeProperty(source, nodeProperty)
verifyKey(config, "key")
verifyKey(config, "secret")

val client = awsClient(config)
val relationshipType = NLPHelperFunctions.entityRelationshipType(config)
Expand All @@ -103,8 +98,6 @@ class AWSProcedures {
verifySource(source)
val nodeProperty = getNodeProperty(config)
verifyNodeProperty(source, nodeProperty)
verifyKey(config, "key")
verifyKey(config, "secret")

val client: AWSClient = awsClient(config)

Expand All @@ -124,8 +117,6 @@ class AWSProcedures {
verifySource(source)
val nodeProperty = getNodeProperty(config)
verifyNodeProperty(source, nodeProperty)
verifyKey(config, "key")
verifyKey(config, "secret")

val client = awsClient(config)
val relationshipType = keyPhraseRelationshipType(config)
Expand All @@ -149,8 +140,6 @@ class AWSProcedures {
verifySource(source)
val nodeProperty = getNodeProperty(config)
verifyNodeProperty(source, nodeProperty)
verifyKey(config, "key")
verifyKey(config, "secret")

val client: AWSClient = awsClient(config)

Expand All @@ -170,8 +159,6 @@ class AWSProcedures {
verifySource(source)
val nodeProperty = getNodeProperty(config)
verifyNodeProperty(source, nodeProperty)
verifyKey(config, "key")
verifyKey(config, "secret")

val client = awsClient(config)
val storeGraph: Boolean = config.getOrDefault("write", false) as Boolean
Expand Down
71 changes: 50 additions & 21 deletions full/src/main/kotlin/apoc/nlp/aws/RealAWSClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,46 +20,75 @@ package apoc.nlp.aws

import apoc.result.MapResult
import apoc.util.JsonUtil
import com.amazonaws.auth.AWSStaticCredentialsProvider
import com.amazonaws.auth.BasicAWSCredentials
import com.amazonaws.auth.*
import com.amazonaws.services.comprehend.AmazonComprehendClientBuilder
import com.amazonaws.services.comprehend.model.BatchDetectEntitiesRequest
import com.amazonaws.services.comprehend.model.BatchDetectEntitiesResult
import com.amazonaws.services.comprehend.model.BatchDetectKeyPhrasesRequest
import com.amazonaws.services.comprehend.model.BatchDetectKeyPhrasesResult
import com.amazonaws.services.comprehend.model.BatchDetectSentimentRequest
import com.amazonaws.services.comprehend.model.BatchDetectSentimentResult
import com.amazonaws.services.comprehend.model.*
import org.neo4j.graphdb.Node
import org.neo4j.logging.Log

class RealAWSClient(config: Map<String, Any>, private val log: Log) : AWSClient {
private val apiKey = config["key"].toString()
private val apiSecret = config["secret"].toString()
companion object {
val missingCredentialError = """
Error during AWS credentials retrieving.
Make sure the key ID and the Secret Key are defined via `key` and `secret` parameters
or via one of these ways: https://docs.aws.amazon.com/AWSJavaSDK/latest/javadoc/com/amazonaws/auth/DefaultAWSCredentialsProviderChain.html:
"""
}
private val apiKey = config["key"]?.toString()
private val apiSecret = config["secret"]?.toString()
private val apiSessionToken = config["token"].toString()
private val region = config.getOrDefault("region", "us-east-1").toString()
private val language = config.getOrDefault("language", "en").toString()
private val nodeProperty = config.getOrDefault("nodeProperty", "text").toString()

private val awsClient = AmazonComprehendClientBuilder.standard()
.withCredentials(AWSStaticCredentialsProvider(BasicAWSCredentials(apiKey, apiSecret)))
.withCredentials(awsStaticCredentialsProvider())
.withRegion(region)
.build()

override fun entities(data: List<Node>, batchId: Int): BatchDetectEntitiesResult? {
val convertedData = convertInput(data)
val batch = BatchDetectEntitiesRequest().withTextList(convertedData).withLanguageCode(language)
return awsClient.batchDetectEntities(batch)
private fun awsStaticCredentialsProvider(): AWSCredentialsProvider {
return if (!apiKey.isNullOrEmpty() && !apiSecret.isNullOrEmpty()) {
AWSStaticCredentialsProvider(getAwsBasicCredentials())
} else {
DefaultAWSCredentialsProviderChain()
}
}

private fun getAwsBasicCredentials() = if (apiSessionToken.isEmpty()) {
BasicAWSCredentials(apiKey, apiSecret)
} else {
BasicSessionCredentials(apiKey, apiSecret, apiSessionToken)
}


override fun entities(data: List<Node>, batchId: Int): BatchDetectEntitiesResult? {
try {
val convertedData = convertInput(data)
val batch = BatchDetectEntitiesRequest().withTextList(convertedData).withLanguageCode(language)
return awsClient.batchDetectEntities(batch)
} catch (e: Exception) {
throw RuntimeException(missingCredentialError + e)
}
}

override fun keyPhrases(data: List<Node>, batchId: Int): BatchDetectKeyPhrasesResult? {
val convertedData = convertInput(data)
val batch = BatchDetectKeyPhrasesRequest().withTextList(convertedData).withLanguageCode(language)
return awsClient.batchDetectKeyPhrases(batch)
try {
val convertedData = convertInput(data)
val batch = BatchDetectKeyPhrasesRequest().withTextList(convertedData).withLanguageCode(language)
return awsClient.batchDetectKeyPhrases(batch)
} catch (e: Exception) {
throw RuntimeException(missingCredentialError + e)
}
}

override fun sentiment(data: List<Node>, batchId: Int): BatchDetectSentimentResult? {
val convertedData = convertInput(data)
val batch = BatchDetectSentimentRequest().withTextList(convertedData).withLanguageCode(language)
return awsClient.batchDetectSentiment(batch)
try {
val convertedData = convertInput(data)
val batch = BatchDetectSentimentRequest().withTextList(convertedData).withLanguageCode(language)
return awsClient.batchDetectSentiment(batch)
} catch (e: Exception) {
throw RuntimeException(missingCredentialError + e)
}
}

fun sentiment(data: List<Node>, config: Map<String, Any?>): List<MapResult> {
Expand Down
171 changes: 171 additions & 0 deletions full/src/test/kotlin/apoc/nlp/aws/AWSProceduresAPIWithEnvVarsTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
package apoc.nlp.aws

import apoc.util.TestUtil
import com.amazonaws.SDKGlobalConfiguration.ACCESS_KEY_ENV_VAR
import com.amazonaws.SDKGlobalConfiguration.SECRET_KEY_ENV_VAR
import org.junit.Assert.assertTrue
import org.junit.Assume.assumeNotNull
import org.junit.BeforeClass
import org.junit.ClassRule
import org.junit.Test
import org.neo4j.graphdb.Result
import org.neo4j.test.rule.ImpermanentDbmsRule


/**
* To execute tests, set these environment variables:
* AWS_ACCESS_KEY_ID=<apiKey>;AWS_SECRET_KEY=<secretKey>
*/
class AWSProceduresAPIWithEnvVarsTest {
companion object {
private val apiKey: String? = System.getenv(ACCESS_KEY_ENV_VAR)
private val apiSecret: String? = System.getenv(SECRET_KEY_ENV_VAR)

@ClassRule
@JvmField
val neo4j = ImpermanentDbmsRule()

@BeforeClass
@JvmStatic
fun beforeClass() {
neo4j.executeTransactionally("""
CREATE (:Article {
uri: "https://neo4j.com/blog/pokegraph-gotta-graph-em-all/",
body: "These days I’m rarely more than a few feet away from my Nintendo Switch and I play board games, card games and role playing games with friends at least once or twice a week. I’ve even organised lunch-time Mario Kart 8 tournaments between the Neo4j European offices!"
});""")

neo4j.executeTransactionally("""
CREATE (:Article {
uri: "https://en.wikipedia.org/wiki/Nintendo_Switch",
body: "The Nintendo Switch is a video game console developed by Nintendo, released worldwide in most regions on March 3, 2017. It is a hybrid console that can be used as a home console and portable device. The Nintendo Switch was unveiled on October 20, 2016. Nintendo offers a Joy-Con Wheel, a small steering wheel-like unit that a Joy-Con can slot into, allowing it to be used for racing games such as Mario Kart 8."
});
""")

assumeNotNull(apiKey, apiSecret)
TestUtil.registerProcedure(neo4j, AWSProcedures::class.java)
}
}

@Test
fun `should extract entities in stream mode`() {
neo4j.executeTransactionally("""
MATCH (a:Article {uri: "https://neo4j.com/blog/pokegraph-gotta-graph-em-all/"})
CALL apoc.nlp.aws.entities.stream(a, {
nodeProperty: "body"
})
YIELD value
UNWIND value.entities AS result
RETURN result;
""", mapOf()) {
assertStreamWithScoreResult(it)
}
}

@Test
fun `should extract entities in graph mode`() {
neo4j.executeTransactionally("""
MATCH (a:Article {uri: "https://neo4j.com/blog/pokegraph-gotta-graph-em-all/"})
CALL apoc.nlp.aws.entities.graph(a, {
nodeProperty: "body",
writeRelationshipType: "ENTITY"
})
YIELD graph AS g
RETURN g;
""", mapOf()) {
assertGraphResult(it)
}
}

@Test
fun `should extract key phrases in stream mode`() {
neo4j.executeTransactionally("""
MATCH (a:Article {uri: "https://neo4j.com/blog/pokegraph-gotta-graph-em-all/"})
CALL apoc.nlp.aws.keyPhrases.stream(a, {
nodeProperty: "body"
})
YIELD value
UNWIND value.keyPhrases AS result
RETURN result
""", mapOf()) {
assertStreamWithScoreResult(it)
}
}

@Test
fun `should extract key phrases in graph mode`() {
neo4j.executeTransactionally("""
MATCH (a:Article {uri: "https://neo4j.com/blog/pokegraph-gotta-graph-em-all/"})
CALL apoc.nlp.aws.keyPhrases.graph(a, {
nodeProperty: "body",
writeRelationshipType: "KEY_PHRASE",
write: true
})
YIELD graph AS g
RETURN g;
""", mapOf()) {
assertGraphResult(it)
}
}

@Test
fun `should extract sentiment in stream mode`() {
neo4j.executeTransactionally("""
MATCH (a:Article {uri: "https://neo4j.com/blog/pokegraph-gotta-graph-em-all/"})
CALL apoc.nlp.aws.sentiment.stream(a, {
nodeProperty: "body"
})
YIELD value
RETURN value AS result;
""", mapOf()) {
assertSentimentScoreResult(it)
}
}

@Test
fun `should extract sentiment in graph mode`() {
neo4j.executeTransactionally("""
MATCH (a:Article {uri: "https://neo4j.com/blog/pokegraph-gotta-graph-em-all/"})
CALL apoc.nlp.aws.sentiment.graph(a, {
nodeProperty: "body",
write: true
})
YIELD graph AS g
UNWIND g.nodes AS node
RETURN node {.uri, .sentiment, .sentimentScore} AS result;
""", mapOf()) {
assertSentimentScoreResult(it)
}
}

private fun assertStreamWithScoreResult(it: Result) {
val asSequence = it.asSequence().toList()
assertTrue(asSequence.isNotEmpty())

asSequence.forEach {
val entity: Map<String, Any> = it["result"] as Map<String, Any>
assertTrue(entity.containsKey("score"))
}
}

private fun assertGraphResult(it: Result) {
val asSequence = it.asSequence().toList()
assertTrue(asSequence.isNotEmpty())

asSequence.forEach {
val entity: Map<String, Any> = it["g"] as Map<String, Any>
assertTrue(entity.containsKey("nodes"))
assertTrue(entity.containsKey("relationships"))
}
}

private fun assertSentimentScoreResult(it: Result) {
val asSequence = it.asSequence().toList()
assertTrue(asSequence.isNotEmpty())

asSequence.forEach {
val entity: Map<String, Any> = it["result"] as Map<String, Any>
assertTrue(entity.containsKey("sentimentScore"))
}
}
}

5 changes: 3 additions & 2 deletions full/src/test/kotlin/apoc/nlp/aws/AWSProceduresErrorsTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package apoc.nlp.aws

import apoc.nlp.aws.RealAWSClient.Companion.missingCredentialError
import apoc.util.TestUtil
import org.hamcrest.CoreMatchers.containsString
import org.junit.AfterClass
Expand Down Expand Up @@ -94,7 +95,7 @@ class AWSProceduresErrorsTest {
println(it.resultAsString())
}
}
assertThat(exception.message, containsString("java.lang.IllegalArgumentException: Missing parameter `key`"))
assertThat(exception.message, containsString(missingCredentialError))
}

@Test
Expand All @@ -111,6 +112,6 @@ class AWSProceduresErrorsTest {
println(it.resultAsString())
}
}
assertThat(exception.message, containsString("java.lang.IllegalArgumentException: Missing parameter `secret`"))
assertThat(exception.message, containsString(missingCredentialError))
}
}

0 comments on commit 0493c03

Please sign in to comment.