Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

In Expr bindings. #10

Merged
merged 5 commits into from
Jun 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions cpp/src/gandiva/jni/jni_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,45 @@ NodePtr ProtoTypeToOrNode(const types::OrNode& node) {
return TreeExprBuilder::MakeOr(children);
}

NodePtr ProtoTypeToInNode(const types::InNode& node) {
NodePtr field = ProtoTypeToFieldNode(node.field());

if (node.has_intvalues()) {
std::unordered_set<int32_t> int_values;
for (int i = 0; i < node.intvalues().intvalues_size(); i++) {
int_values.insert(node.intvalues().intvalues(i).value());
}
return TreeExprBuilder::MakeInExpressionInt32(field, int_values);
}

if (node.has_longvalues()) {
std::unordered_set<int64_t> long_values;
for (int i = 0; i < node.longvalues().longvalues_size(); i++) {
long_values.insert(node.longvalues().longvalues(i).value());
}
return TreeExprBuilder::MakeInExpressionInt64(field, long_values);
}

if (node.has_stringvalues()) {
std::unordered_set<std::string> stringvalues;
for (int i = 0; i < node.stringvalues().stringvalues_size(); i++) {
stringvalues.insert(node.stringvalues().stringvalues(i).value());
}
return TreeExprBuilder::MakeInExpressionString(field, stringvalues);
}

if (node.has_binaryvalues()) {
std::unordered_set<std::string> stringvalues;
for (int i = 0; i < node.binaryvalues().binaryvalues_size(); i++) {
stringvalues.insert(node.binaryvalues().binaryvalues(i).value());
}
return TreeExprBuilder::MakeInExpressionBinary(field, stringvalues);
}
// not supported yet.
std::cerr << "Unknown constant type for in expression.\n";
return nullptr;
}

NodePtr ProtoTypeToNullNode(const types::NullNode& node) {
DataTypePtr data_type = ProtoTypeToDataType(node.type());
if (data_type == nullptr) {
Expand Down Expand Up @@ -344,6 +383,10 @@ NodePtr ProtoTypeToNode(const types::TreeNode& node) {
return ProtoTypeToOrNode(node.ornode());
}

if (node.has_innode()) {
return ProtoTypeToInNode(node.innode());
}

if (node.has_nullnode()) {
return ProtoTypeToNullNode(node.nullnode());
}
Expand Down
27 changes: 27 additions & 0 deletions cpp/src/gandiva/proto/Types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ message TreeNode {
optional StringNode stringNode = 17;
optional BinaryNode binaryNode = 18;
optional DecimalNode decimalNode = 19;

// in expr
optional InNode inNode = 21;
}

message ExpressionRoot {
Expand Down Expand Up @@ -205,3 +208,27 @@ message FunctionSignature {
optional ExtGandivaType returnType = 2;
repeated ExtGandivaType paramTypes = 3;
}

message InNode {
optional FieldNode field = 1;
optional IntConstants intValues = 2;
optional LongConstants longValues = 3;
optional StringConstants stringValues = 4;
optional BinaryConstants binaryValues = 5;
}

message IntConstants {
repeated IntNode intValues = 1;
}

message LongConstants {
repeated LongNode longValues = 1;
}

message StringConstants {
repeated StringNode stringValues = 1;
}

message BinaryConstants {
repeated BinaryNode binaryValues = 1;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.gandiva.expression;

import java.nio.charset.Charset;
import java.util.Set;

import org.apache.arrow.gandiva.exceptions.GandivaException;
import org.apache.arrow.gandiva.ipc.GandivaTypes;
import org.apache.arrow.vector.types.pojo.Field;

import com.google.protobuf.ByteString;

/**
* In Node representation in java.
*/
public class InNode implements TreeNode {
private static final Charset charset = Charset.forName("UTF-8");

private final Set<Integer> intValues;
private final Set<Long> longValues;
private final Set<String> stringValues;
private final Set<byte[]> binaryValues;
private final Field field;

private InNode(Set<Integer> values, Set<Long> longValues, Set<String> stringValues, Set<byte[]>
binaryValues, Field field) {
this.intValues = values;
this.longValues = longValues;
this.stringValues = stringValues;
this.binaryValues = binaryValues;
this.field = field;
}

public static InNode makeIntInExpr(Field field, Set<Integer> intValues) {
return new InNode(intValues, null, null, null ,field);
}

public static InNode makeLongInExpr(Field field, Set<Long> longValues) {
return new InNode(null, longValues, null, null ,field);
}

public static InNode makeStringInExpr(Field field, Set<String> stringValues) {
return new InNode(null, null, stringValues, null ,field);
}

public static InNode makeBinaryInExpr(Field field, Set<byte[]> binaryValues) {
return new InNode(null, null, null, binaryValues ,field);
}

@Override
public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
GandivaTypes.InNode.Builder inNode = GandivaTypes.InNode.newBuilder();

GandivaTypes.FieldNode.Builder fieldNode = GandivaTypes.FieldNode.newBuilder();
fieldNode.setField(ArrowTypeHelper.arrowFieldToProtobuf(field));
inNode.setField(fieldNode);

if (intValues != null) {
GandivaTypes.IntConstants.Builder intConstants = GandivaTypes.IntConstants.newBuilder();
intValues.stream().forEach(val -> intConstants.addIntValues(GandivaTypes.IntNode.newBuilder()
.setValue(val).build()));
inNode.setIntValues(intConstants.build());
} else if (longValues != null) {
GandivaTypes.LongConstants.Builder longConstants = GandivaTypes.LongConstants.newBuilder();
longValues.stream().forEach(val -> longConstants.addLongValues(GandivaTypes.LongNode.newBuilder()
.setValue(val).build()));
inNode.setLongValues(longConstants.build());
} else if (stringValues != null) {
GandivaTypes.StringConstants.Builder stringConstants = GandivaTypes.StringConstants
.newBuilder();
stringValues.stream().forEach(val -> stringConstants.addStringValues(GandivaTypes.StringNode
.newBuilder().setValue(ByteString.copyFrom(val.getBytes(charset))).build()));
inNode.setStringValues(stringConstants.build());
} else if (binaryValues != null) {
GandivaTypes.BinaryConstants.Builder binaryConstants = GandivaTypes.BinaryConstants
.newBuilder();
binaryValues.stream().forEach(val -> binaryConstants.addBinaryValues(GandivaTypes.BinaryNode
.newBuilder().setValue(ByteString.copyFrom(val)).build()));
inNode.setBinaryValues(binaryConstants.build());
}
GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder();
builder.setInNode(inNode.build());
return builder.build();

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Set;

import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
Expand Down Expand Up @@ -185,4 +186,24 @@ public static Condition makeCondition(String function,
TreeNode root = makeFunction(function, children, new ArrowType.Bool());
return makeCondition(root);
}

public static TreeNode makeInExpressionInt32(Field resultField,
Set<Integer> intValues) {
return InNode.makeIntInExpr(resultField, intValues);
}

public static TreeNode makeInExpressionBigInt(Field resultField,
Set<Long> longValues) {
return InNode.makeLongInExpr(resultField, longValues);
}

public static TreeNode makeInExpressionString(Field resultField,
Set<String> stringValues) {
return InNode.makeStringInExpr(resultField, stringValues);
}

public static TreeNode makeInExpressionBinary(Field resultField,
Set<byte[]> binaryValues) {
return InNode.makeBinaryInExpr(resultField, binaryValues);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.junit.Test;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import io.netty.buffer.ArrowBuf;

Expand Down Expand Up @@ -1047,6 +1048,96 @@ public void testEquals() throws GandivaException, Exception {
eval.close();
}

@Test
public void testInExpr() throws GandivaException, Exception {
Field c1 = Field.nullable("c1", int32);

TreeNode inExpr =
TreeBuilder.makeInExpressionInt32(c1, Sets.newHashSet(1,2,3,4,5,15,16));
ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType));
Schema schema = new Schema(Lists.newArrayList(c1));
Projector eval = Projector.make(schema, Lists.newArrayList(expr));

int numRows = 16;
byte[] validity = new byte[]{(byte) 255, 0};
int[] c1Values = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

ArrowBuf c1Validity = buf(validity);
ArrowBuf c1Data = intBuf(c1Values);
ArrowBuf c2Validity = buf(validity);

ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
ArrowRecordBatch batch =
new ArrowRecordBatch(
numRows,
Lists.newArrayList(fieldNode, fieldNode),
Lists.newArrayList(c1Validity, c1Data, c2Validity));

BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator);
bitVector.allocateNew(numRows);

List<ValueVector> output = new ArrayList<ValueVector>();
output.add(bitVector);
eval.evaluate(batch, output);

for (int i = 0; i < 5; i++) {
assertTrue(bitVector.getObject(i).booleanValue());
}
for (int i = 5; i < 16; i++) {
assertFalse(bitVector.getObject(i).booleanValue());
}

releaseRecordBatch(batch);
releaseValueVectors(output);
eval.close();
}

@Test
public void testInExprStrings() throws GandivaException, Exception {
Field c1 = Field.nullable("c1", new ArrowType.Utf8());

TreeNode inExpr =
TreeBuilder.makeInExpressionString(c1, Sets.newHashSet("one", "two", "three", "four"));
ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType));
Schema schema = new Schema(Lists.newArrayList(c1));
Projector eval = Projector.make(schema, Lists.newArrayList(expr));

int numRows = 16;
byte[] validity = new byte[]{(byte) 255, 0};
String[] c1Values = new String[]{"one", "two", "three", "four", "five", "six", "seven",
"eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
"sixteen"};

ArrowBuf c1Validity = buf(validity);
List<ArrowBuf> dataBufsX = stringBufs(c1Values);
ArrowBuf c2Validity = buf(validity);

ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
ArrowRecordBatch batch =
new ArrowRecordBatch(
numRows,
Lists.newArrayList(fieldNode, fieldNode),
Lists.newArrayList(c1Validity, dataBufsX.get(0),dataBufsX.get(1), c2Validity));

BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator);
bitVector.allocateNew(numRows);

List<ValueVector> output = new ArrayList<ValueVector>();
output.add(bitVector);
eval.evaluate(batch, output);

for (int i = 0; i < 4; i++) {
assertTrue(bitVector.getObject(i).booleanValue());
}
for (int i = 5; i < 16; i++) {
assertFalse(bitVector.getObject(i).booleanValue());
}

releaseRecordBatch(batch);
releaseValueVectors(output);
eval.close();
}

@Test
public void testSmallOutputVectors() throws GandivaException, Exception {
Field a = Field.nullable("a", int32);
Expand Down