Skip to content

Commit

Permalink
Merge pull request #28936 from IMS94/create-fn-code-action
Browse files Browse the repository at this point in the history
Overload NonTerminalNode.findNode() to include start offset and improve related create function code action usage
  • Loading branch information
nadeeshaan authored Mar 3, 2021
2 parents 9e966aa + 180add4 commit 1fc25b3
Show file tree
Hide file tree
Showing 49 changed files with 784 additions and 204 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
import io.ballerina.compiler.syntax.tree.RequiredParameterNode;
import io.ballerina.compiler.syntax.tree.RestParameterNode;
import io.ballerina.compiler.syntax.tree.SimpleNameReferenceNode;
import io.ballerina.compiler.syntax.tree.SpecificFieldNode;
import io.ballerina.compiler.syntax.tree.TableConstructorExpressionNode;
import io.ballerina.compiler.syntax.tree.TemplateExpressionNode;
import io.ballerina.compiler.syntax.tree.Token;
Expand Down Expand Up @@ -462,6 +463,11 @@ public Optional<Location> transform(PositionalArgumentNode positionalArgumentNod
return Optional.of(positionalArgumentNode.location());
}

@Override
public Optional<Location> transform(SpecificFieldNode specificFieldNode) {
return Optional.of(specificFieldNode.location());
}

@Override
protected Optional<Location> transformSyntaxNode(Node node) {
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,24 @@ public Token findToken(int position, boolean insideMinutiae) {

/**
* Find the inner most node encapsulating the a text range.
* Note: When evaluating the position of a node to check the range this will include the start offset while
* Note: When evaluating the position of a node to check the range this will not include the start offset while
* excluding the end offset
*
* @param textRange to evaluate and find the node
* @return {@link NonTerminalNode} which is the inner most non terminal node, encapsulating the given position
*/
public NonTerminalNode findNode(TextRange textRange) {
public NonTerminalNode findNode(TextRange textRange) {
return findNode(textRange, false);
}

/**
* Find the inner most node encapsulating the a text range.
*
* @param textRange text range to evaluate
* @param includeStartOffset whether to include start offset when checking text range
* @return Innermost {@link NonTerminalNode} encapsulation given text range
*/
public NonTerminalNode findNode(TextRange textRange, boolean includeStartOffset) {
TextRange textRangeWithMinutiae = textRangeWithMinutiae();
if (!(this instanceof ModulePartNode)
&& (!textRangeWithMinutiae.contains(textRange.startOffset())
Expand All @@ -143,7 +154,7 @@ public NonTerminalNode findNode(TextRange textRange) {
Optional<Node> temp = Optional.of(this);
while (temp.isPresent() && SyntaxUtils.isNonTerminalNode(temp.get())) {
foundNode = (NonTerminalNode) temp.get();
temp = ((NonTerminalNode) temp.get()).findChildNode(textRange);
temp = ((NonTerminalNode) temp.get()).findChildNode(textRange, includeStartOffset);
}

return foundNode;
Expand Down Expand Up @@ -246,21 +257,31 @@ private Node findChildNode(int position) {
* Find a child node enclosing the given text range.
* If there is no child node which can wrap the given range, this method will return empty
*
* @param textRange text range to evaluate
* @param textRange text range to evaluate
* @param includeStartOffset whether to include start offset when checking textRange
* @return {@link Optional} node found, which is enclosing the given range
*/
private Optional<Node> findChildNode(TextRange textRange) {
private Optional<Node> findChildNode(TextRange textRange, boolean includeStartOffset) {
int offset = textRangeWithMinutiae().startOffset();
for (int bucket = 0; bucket < internalNode.bucketCount(); bucket++) {
STNode internalChildNode = internalNode.childInBucket(bucket);
if (!isSTNodePresent(internalChildNode)) {
continue;
}
int offsetWithMinutiae = offset + internalChildNode.widthWithMinutiae();
if (textRange.startOffset() > offset && textRange.endOffset() <= offsetWithMinutiae) {
// Populate the external node.
return Optional.ofNullable(this.childInBucket(bucket));

if (includeStartOffset) {
if (textRange.startOffset() >= offset && textRange.endOffset() <= offsetWithMinutiae) {
// Populate the external node.
return Optional.ofNullable(this.childInBucket(bucket));
}
} else {
if (textRange.startOffset() > offset && textRange.endOffset() <= offsetWithMinutiae) {
// Populate the external node.
return Optional.ofNullable(this.childInBucket(bucket));
}
}

offset += internalChildNode.widthWithMinutiae();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,9 @@
*/
package org.ballerinalang.langserver.codeaction.providers;

import io.ballerina.compiler.syntax.tree.AssignmentStatementNode;
import io.ballerina.compiler.syntax.tree.ExpressionNode;
import io.ballerina.compiler.syntax.tree.FunctionCallExpressionNode;
import io.ballerina.compiler.syntax.tree.Node;
import io.ballerina.compiler.syntax.tree.NonTerminalNode;
import io.ballerina.compiler.syntax.tree.SyntaxKind;
import io.ballerina.compiler.syntax.tree.VariableDeclarationNode;
import io.ballerina.tools.diagnostics.Diagnostic;
import org.ballerinalang.annotation.JavaSPIService;
import org.ballerinalang.langserver.codeaction.CodeActionUtil;
Expand All @@ -33,7 +30,7 @@
import org.eclipse.lsp4j.CodeAction;
import org.eclipse.lsp4j.CodeActionKind;
import org.eclipse.lsp4j.Command;
import org.eclipse.lsp4j.Position;
import org.eclipse.lsp4j.Range;

import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -62,26 +59,25 @@ public List<CodeAction> getDiagBasedCodeActions(Diagnostic diagnostic,
return Collections.emptyList();
}

if (positionDetails.matchedNode() == null) {
return Collections.emptyList();
}

Optional<FunctionCallExpressionNode> callExpr =
checkAndGetFunctionCallExpressionNode(positionDetails.matchedNode());
if (callExpr.isEmpty()) {
return Collections.emptyList();
}

String diagnosticMessage = diagnostic.message();
Position position = CommonUtil.toRange(diagnostic.location().lineRange()).getStart();
Range range = CommonUtil.toRange(diagnostic.location().lineRange());
String uri = context.fileUri();
CommandArgument posArg = CommandArgument.from(CommandConstants.ARG_KEY_NODE_POS, position);
CommandArgument posArg = CommandArgument.from(CommandConstants.ARG_KEY_NODE_RANGE, range);
CommandArgument uriArg = CommandArgument.from(CommandConstants.ARG_KEY_DOC_URI, uri);

List<Object> args = Arrays.asList(posArg, uriArg);
Matcher matcher = CommandConstants.UNDEFINED_FUNCTION_PATTERN.matcher(diagnosticMessage);
String functionName = (matcher.find() && matcher.groupCount() > 0) ? matcher.group(1) + "(...)" : "";
Node cursorNode = positionDetails.matchedNode();

if (cursorNode == null) {
return Collections.emptyList();
}

Optional<FunctionCallExpressionNode> callExpr = getFunctionCallExpressionNodeAtCursor(cursorNode);

if (callExpr.isEmpty()) {
return Collections.emptyList();
}

boolean isWithinFile = callExpr.get().functionName().kind() == SyntaxKind.SIMPLE_NAME_REFERENCE;
if (isWithinFile) {
Expand All @@ -96,39 +92,13 @@ public List<CodeAction> getDiagBasedCodeActions(Diagnostic diagnostic,
return Collections.emptyList();
}

/**
* Tries to get the function call expression at the cursor.
*
* @param cursorNode Node at the cursor
* @return Optional function call expression at the cursor
*/
public static Optional<FunctionCallExpressionNode> getFunctionCallExpressionNodeAtCursor(Node cursorNode) {
Optional<FunctionCallExpressionNode> fnCallExprNode = checkAndGetFunctionCallExpressionNode(cursorNode);
if (fnCallExprNode.isEmpty()) {
if (cursorNode.kind() == SyntaxKind.LOCAL_VAR_DECL) {
VariableDeclarationNode varNode = (VariableDeclarationNode) cursorNode;
Optional<ExpressionNode> initializer = varNode.initializer();
if (initializer.isPresent()) {
fnCallExprNode = checkAndGetFunctionCallExpressionNode(initializer.get());
}
} else if (cursorNode.kind() == SyntaxKind.ASSIGNMENT_STATEMENT) {
AssignmentStatementNode assignmentNode = (AssignmentStatementNode) cursorNode;
fnCallExprNode = checkAndGetFunctionCallExpressionNode(assignmentNode.expression());
} else if (cursorNode.kind() == SyntaxKind.SIMPLE_NAME_REFERENCE) {
fnCallExprNode = checkAndGetFunctionCallExpressionNode(cursorNode.parent());
}
}

return fnCallExprNode;
}

/**
* Get the function call expression node if the provided node is a function call.
*
* @param node Node to be checked if it's a function call
* @return Optional function call expression node
*/
public static Optional<FunctionCallExpressionNode> checkAndGetFunctionCallExpressionNode(Node node) {
public static Optional<FunctionCallExpressionNode> checkAndGetFunctionCallExpressionNode(NonTerminalNode node) {
FunctionCallExpressionNode functionCallExpressionNode = null;
if (node.kind() == SyntaxKind.FUNCTION_CALL) {
functionCallExpressionNode = (FunctionCallExpressionNode) node;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public List<CodeAction> getDiagBasedCodeActions(Diagnostic diagnostic,
}

// Skip, non-local var declarations
if (positionDetails.matchedNode().kind() != SyntaxKind.LOCAL_VAR_DECL) {
VariableDeclarationNode localVarNode = getVariableDeclarationNode(positionDetails.matchedNode());
if (localVarNode == null) {
return Collections.emptyList();
}

Expand All @@ -77,7 +78,6 @@ public List<CodeAction> getDiagBasedCodeActions(Diagnostic diagnostic,
}

// Skip, variable declarations with non-initializers
VariableDeclarationNode localVarNode = (VariableDeclarationNode) positionDetails.matchedNode();
Optional<ExpressionNode> initializer = localVarNode.initializer();
if (initializer.isEmpty()) {
return Collections.emptyList();
Expand Down Expand Up @@ -123,6 +123,14 @@ public List<CodeAction> getDiagBasedCodeActions(Diagnostic diagnostic,
return actions;
}

private VariableDeclarationNode getVariableDeclarationNode(NonTerminalNode node) {
while (node != null && node.kind() != SyntaxKind.LOCAL_VAR_DECL) {
node = node.parent();
}

return node != null ? (VariableDeclarationNode) node : null;
}

private Optional<Range> getParameterTypeRange(NonTerminalNode parameterNode) {
if (parameterNode == null) {
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode;
import io.ballerina.compiler.syntax.tree.ModulePartNode;
import io.ballerina.compiler.syntax.tree.Node;
import io.ballerina.compiler.syntax.tree.NonTerminalNode;
import io.ballerina.compiler.syntax.tree.ReturnStatementNode;
import io.ballerina.compiler.syntax.tree.ReturnTypeDescriptorNode;
import io.ballerina.compiler.syntax.tree.SyntaxKind;
import io.ballerina.compiler.syntax.tree.SyntaxTree;
Expand Down Expand Up @@ -64,9 +66,11 @@ public List<CodeAction> getDiagBasedCodeActions(Diagnostic diagnostic,
return Collections.emptyList();
}

if (positionDetails.matchedNode().kind() != SyntaxKind.RETURN_STATEMENT) {
ReturnStatementNode returnStatementNode = getReturnStatement(positionDetails.matchedNode());
if (returnStatementNode == null) {
return Collections.emptyList();
}

Matcher matcher = CommandConstants.INCOMPATIBLE_TYPE_PATTERN.matcher(diagnostic.message());
if (matcher.find() && matcher.groupCount() > 1) {
String foundType = matcher.group(2);
Expand Down Expand Up @@ -106,6 +110,14 @@ public List<CodeAction> getDiagBasedCodeActions(Diagnostic diagnostic,
return Collections.emptyList();
}

private ReturnStatementNode getReturnStatement(NonTerminalNode node) {
while (node != null && node.kind() != SyntaxKind.RETURN_STATEMENT) {
node = node.parent();
}

return node != null ? (ReturnStatementNode) node : null;
}

private FunctionDefinitionNode getFunctionNode(DiagBasedPositionDetails positionDetails) {
Node parent = positionDetails.matchedNode();
while (parent.kind() != SyntaxKind.FUNCTION_DEFINITION) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.ballerina.compiler.syntax.tree.ExpressionNode;
import io.ballerina.compiler.syntax.tree.ModuleVariableDeclarationNode;
import io.ballerina.compiler.syntax.tree.Node;
import io.ballerina.compiler.syntax.tree.NonTerminalNode;
import io.ballerina.compiler.syntax.tree.SyntaxKind;
import io.ballerina.compiler.syntax.tree.VariableDeclarationNode;
import io.ballerina.projects.Document;
Expand All @@ -43,6 +44,7 @@
import org.eclipse.lsp4j.TextEdit;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
Expand All @@ -54,6 +56,7 @@
*/
@JavaSPIService("org.ballerinalang.langserver.commons.codeaction.spi.LSCodeActionProvider")
public class TypeCastCodeAction extends AbstractCodeActionProvider {

/**
* {@inheritDoc}
*/
Expand All @@ -64,10 +67,8 @@ public List<CodeAction> getDiagBasedCodeActions(Diagnostic diagnostic,
if (!(diagnostic.message().contains(CommandConstants.INCOMPATIBLE_TYPES))) {
return Collections.emptyList();
}
Node matchedNode = positionDetails.matchedNode();
if (matchedNode.kind() != SyntaxKind.LOCAL_VAR_DECL &&
matchedNode.kind() != SyntaxKind.MODULE_VAR_DECL &&
matchedNode.kind() != SyntaxKind.ASSIGNMENT_STATEMENT) {
Node matchedNode = getMatchedNode(positionDetails.matchedNode());
if (matchedNode == null) {
return Collections.emptyList();
}

Expand Down Expand Up @@ -103,6 +104,16 @@ public List<CodeAction> getDiagBasedCodeActions(Diagnostic diagnostic,
return Collections.singletonList(createQuickFixCodeAction(commandTitle, edits, context.fileUri()));
}

private NonTerminalNode getMatchedNode(NonTerminalNode node) {
List<SyntaxKind> syntaxKinds = Arrays.asList(SyntaxKind.LOCAL_VAR_DECL,
SyntaxKind.MODULE_VAR_DECL, SyntaxKind.ASSIGNMENT_STATEMENT);
while (node != null && !syntaxKinds.contains(node.kind())) {
node = node.parent();
}

return node;
}

private Optional<ExpressionNode> getExpression(Node node) {
if (node.kind() == SyntaxKind.LOCAL_VAR_DECL) {
return ((VariableDeclarationNode) node).initializer();
Expand Down Expand Up @@ -138,7 +149,7 @@ protected Optional<VariableSymbol> getVariableSymbol(CodeActionContext context,
SemanticModel semanticModel = context.currentSemanticModel().orElseThrow();
Document srcFile = context.currentDocument().orElseThrow();
Optional<Symbol> symbol = semanticModel.symbol(srcFile,
assignmentStmtNode.varRef().lineRange().startLine());
assignmentStmtNode.varRef().lineRange().startLine());
if (symbol.isEmpty() || symbol.get().kind() != SymbolKind.VARIABLE) {
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.ballerinalang.langserver.commons.command.LSCommandExecutorException;
import org.ballerinalang.langserver.commons.command.spi.LSCommandExecutor;
import org.ballerinalang.langserver.contexts.ContextBuilder;
import org.ballerinalang.langserver.exception.UserErrorException;
import org.eclipse.lsp4j.CodeActionParams;
import org.eclipse.lsp4j.Position;
import org.eclipse.lsp4j.Range;
Expand Down Expand Up @@ -74,35 +75,35 @@ public class CreateFunctionExecutor implements LSCommandExecutor {
@Override
public Object execute(ExecuteCommandContext context) throws LSCommandExecutorException {
String uri = null;
Position position = null;
Range range = null;
for (CommandArgument arg : context.getArguments()) {
switch (arg.key()) {
case CommandConstants.ARG_KEY_DOC_URI:
uri = arg.valueAs(String.class);
break;
case CommandConstants.ARG_KEY_NODE_POS:
position = arg.valueAs(Position.class);
case CommandConstants.ARG_KEY_NODE_RANGE:
range = arg.valueAs(Range.class);
break;
default:
}
}

Optional<Path> filePath = CommonUtil.getPathFromURI(uri);
if (position == null || filePath.isEmpty()) {
throw new LSCommandExecutorException("Invalid parameters received for the create function command!");
if (range == null || filePath.isEmpty()) {
throw new UserErrorException("Invalid parameters received for the create function command!");
}

SyntaxTree syntaxTree = context.workspace().syntaxTree(filePath.get()).orElseThrow();
NonTerminalNode cursorNode = CommonUtil.findNode(new Range(position, position), syntaxTree);
NonTerminalNode cursorNode = CommonUtil.findNode(range, syntaxTree);

if (cursorNode == null) {
return Collections.emptyList();
}

Optional<FunctionCallExpressionNode> fnCallExprNode =
CreateFunctionCodeAction.getFunctionCallExpressionNodeAtCursor(cursorNode);
CreateFunctionCodeAction.checkAndGetFunctionCallExpressionNode(cursorNode);
if (fnCallExprNode.isEmpty()) {
return new LSCommandExecutorException("Couldn't find a matching node");
return new UserErrorException("Couldn't find a matching node");
}

SemanticModel semanticModel = context.workspace().semanticModel(filePath.get()).orElseThrow();
Expand Down Expand Up @@ -162,8 +163,8 @@ public Object execute(ExecuteCommandContext context) throws LSCommandExecutorExc

LanguageClient client = context.getLanguageClient();
List<TextEdit> edits = new ArrayList<>();
Range range = new Range(new Position(endLine, endCol), new Position(endLine, endCol));
edits.add(new TextEdit(range, function));
Range insertRange = new Range(new Position(endLine, endCol), new Position(endLine, endCol));
edits.add(new TextEdit(insertRange, function));
TextDocumentEdit textDocumentEdit = new TextDocumentEdit(new VersionedTextDocumentIdentifier(uri, null), edits);
return CommandUtil.applyWorkspaceEdit(Collections.singletonList(Either.forLeft(textDocumentEdit)), client);
}
Expand Down
Loading

0 comments on commit 1fc25b3

Please sign in to comment.