forked from nikitabobko/AeroSpace
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
3/3 Implement nice API on top of AeroShellParser generated code
- Loading branch information
1 parent
098eba8
commit 8ff619f
Showing
3 changed files
with
335 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
import AeroShellParserGenerated | ||
import Antlr4 | ||
import Common | ||
|
||
/// Use the following technique for quick grammar testing: | ||
/// source .deps/python-venv/bin/activate.fish | ||
/// echo "foo bar" | antlr4-parse ./grammar/AeroShellLexer.g4 ./grammar/AeroShellParser.g4 root -gui | ||
extension String { | ||
func parseShell() -> Result<RawShell, String> { | ||
let stream = ANTLRInputStream(self) | ||
let lexer = AeroShellLexer(stream) | ||
let errorsCollector = ErrorListenerCollector() | ||
lexer.addErrorListener(errorsCollector) | ||
let tokenStream = CommonTokenStream(lexer) | ||
let parser: AeroShellParser | ||
switch Result(catching: { try AeroShellParser(tokenStream) }) { | ||
case .success(let x): parser = x | ||
case .failure(let msg): | ||
return .failure(msg.localizedDescription) | ||
} | ||
parser.addErrorListener(errorsCollector) | ||
let root: AeroShellParser.RootContext | ||
switch Result(catching: { try parser.root() }) { | ||
case .success(let x): root = x | ||
case .failure(let msg): | ||
return .failure(msg.localizedDescription) | ||
} | ||
if !errorsCollector.errors.isEmpty { | ||
return .failure(errorsCollector.errors.joinErrors()) | ||
} | ||
return root.program().map { $0.toTyped() } ?? .success(.empty) | ||
} | ||
} | ||
|
||
class ErrorListenerCollector: BaseErrorListener { | ||
var errors: [String] = [] | ||
override func syntaxError<T>( | ||
_ recognizer: Recognizer<T>, | ||
_ offendingSymbol: AnyObject?, | ||
_ line: Int, | ||
_ charPositionInLine: Int, | ||
_ msg: String, | ||
_ e: AnyObject? | ||
) { | ||
errors.append("Syntax error at \(line):\(charPositionInLine) \(msg)") | ||
} | ||
} | ||
|
||
extension AeroShellParser.ProgramContext { | ||
func toTyped() -> Result<RawShell, String> { | ||
if let x = self as? AeroShellParser.NotContext { | ||
return x.program().toTyped("not node: nil child") | ||
} | ||
if let x = self as? AeroShellParser.PipeContext { | ||
return binaryNode(Shell.pipe, x.program(0), x.program(1)) | ||
} | ||
if let x = self as? AeroShellParser.AndContext { | ||
return binaryNode(Shell.and, x.program(0), x.program(1)) | ||
} | ||
if let x = self as? AeroShellParser.OrContext { | ||
return binaryNode(Shell.or, x.program(0), x.program(1)) | ||
} | ||
if let x = self as? AeroShellParser.SeqContext { | ||
let seq = x.program() | ||
return switch seq.count { | ||
case 0: .failure("seq node: 0 children") | ||
case 1: seq.first!.toTyped() | ||
default: seq.mapAllOrFailures { $0.toTyped() }.mapError { $0.joinErrors() }.map(Shell.seq) | ||
} | ||
} | ||
if let x = self as? AeroShellParser.ParensContext { | ||
return x.program().toTyped("parens node: nil childe") | ||
} | ||
if let x = self as? AeroShellParser.ArgsContext { | ||
return x.arg().mapAllOrFailures { $0.toTyped() }.mapError { $0.joinErrors() }.map(Shell.args) | ||
} | ||
error("Unknown node type: \(self)") | ||
} | ||
} | ||
|
||
|
||
extension AeroShellParser.ArgContext { | ||
func toTyped() -> Result<ShellString<String>, String> { | ||
if let x = self as? AeroShellParser.WordContext { | ||
return .success(.text(x.getText())) | ||
} | ||
if let x = self as? AeroShellParser.DQuotedStringContext { | ||
let seq = x.dStringFragment() | ||
return switch seq.count { | ||
case 1: seq.first!.toTyped() | ||
default: | ||
seq.mapAllOrFailures { $0.toTyped() }.mapError { $0.joinErrors() }.map(ShellString.concatOptimized) | ||
} | ||
} | ||
if let x = self as? AeroShellParser.SQuotedStringContext { | ||
return .success(.text(String(x.getText().dropFirst(1).dropLast(1)))) | ||
} | ||
if let x = self as? AeroShellParser.SubstitutionContext { | ||
return x.program().toTyped("substitution node: nil child").map(ShellString.interpolation) | ||
} | ||
error("Unknown node type: \(self)") | ||
} | ||
} | ||
|
||
extension AeroShellParser.DStringFragmentContext { | ||
func toTyped() -> Result<ShellString<String>, String> { | ||
if let x = ESCAPE_SEQUENCE() { | ||
return switch x.getText() { | ||
case "\\n": .success(.text("\n")) | ||
case "\\t": .success(.text("\t")) | ||
case "\\$": .success(.text("$")) | ||
case "\\\"": .success(.text("\"")) | ||
case "\\\\": .success(.text("\\")) | ||
default: .failure("Unknown ESCAPE_SEQUENCE '\(x.getText())'") | ||
} | ||
} | ||
if let x = TEXT() { | ||
return .success(.text(x.getText())) | ||
} | ||
if let x = program() { | ||
return x.toTyped().map(ShellString.interpolation) | ||
} | ||
error("Unknown node type: \(self)") | ||
} | ||
} | ||
|
||
private func binaryNode( | ||
_ op: (RawShell, RawShell) -> RawShell, | ||
_ a: AeroShellParser.ProgramContext?, | ||
_ b: AeroShellParser.ProgramContext? | ||
) -> Result<RawShell, String> { | ||
a.toTyped("binary node: nil child 0").combine { b.toTyped("binary node: nil child 1") }.map(op) | ||
} | ||
|
||
extension Result { | ||
func combine<T>(_ other: () -> Result<T, Failure>) -> Result<(Success, T), Failure> { | ||
flatMap { a in | ||
other().flatMap { b in | ||
.success((a, b)) | ||
} | ||
} | ||
} | ||
} | ||
|
||
extension Result where Success == AeroShellParser.ProgramContext, Failure == String { | ||
func toTyped() -> Result<RawShell, String> { flatMap { $0.toTyped() } } | ||
} | ||
|
||
private extension Optional where Wrapped == AeroShellParser.ProgramContext { | ||
func toTyped(_ msg: String) -> Result<RawShell, String> { orFailure(msg).toTyped() } | ||
} | ||
|
||
class CmdMutableState { | ||
var stdin: String | ||
var env: [String: String] | ||
|
||
init(stdin: String, pwd: String) { | ||
self.stdin = stdin | ||
self.env = config.execConfig.envVariables | ||
self.env["PWD"] = pwd | ||
} | ||
} | ||
|
||
struct CmdOut { | ||
let stdout: [String] | ||
let exitCode: Int | ||
|
||
static func succ(_ stdout: [String]) -> CmdOut { CmdOut(stdout: stdout, exitCode: 0) } | ||
static func fail(_ stdout: [String]) -> CmdOut { CmdOut(stdout: stdout, exitCode: 1) } | ||
} | ||
|
||
// protocol AeroShell { | ||
// func run(_ state: CmdMutableState) -> CmdOut | ||
// } | ||
// extension [String] : AeroShell { | ||
// func run(_ state: CmdMutableState) -> CmdOut { .succ(self) } | ||
// } | ||
|
||
extension Shell: Equatable where T: Equatable {} | ||
typealias AeroShell = Shell<any Command> | ||
typealias RawShell = Shell<String> | ||
indirect enum Shell<T> { | ||
case args([ShellString<T>]) | ||
case empty | ||
|
||
// Listed in precedence order | ||
case not(Shell<T>) | ||
case pipe(Shell<T>, Shell<T>) | ||
case and(Shell<T>, Shell<T>) | ||
case or(Shell<T>, Shell<T>) | ||
case seq([Shell<T>]) | ||
} | ||
|
||
extension ShellString: Equatable where T: Equatable {} | ||
enum ShellString<T> { | ||
case text(String) | ||
case interpolation(Shell<T>) | ||
case concat([ShellString<T>]) | ||
|
||
static func concatOptimized(_ fragments: [ShellString<T>]) -> ShellString<T> { | ||
var result: [ShellString<T>] = [] | ||
var current: String = "" | ||
_concatOptimized(fragments, &result, ¤t) | ||
if !current.isEmpty { | ||
result.append(.text(current)) | ||
} | ||
return result.singleOrNil() ?? .concat(result) | ||
} | ||
|
||
private static func _concatOptimized( | ||
_ fragments: [ShellString<T>], | ||
_ result: inout [ShellString<T>], | ||
_ current: inout String | ||
) { | ||
for fragment in fragments { | ||
switch fragment { | ||
case .text(let text): current += text | ||
case .concat(let newFragments): _concatOptimized(newFragments, &result, ¤t) | ||
case .interpolation: | ||
if !current.isEmpty { | ||
result.append(.text(current)) | ||
current = "" | ||
} | ||
result.append(fragment) | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import XCTest | ||
import Common | ||
|
||
// Because XCTAssertEqual default messages are unreadable! | ||
func assertFailure<T, F>(_ r: Result<T, F>, file: String = #file, line: Int = #line) { | ||
switch r { | ||
case .success: failExpectedActual("Result.failure", r, file: file, line: line) | ||
case .failure: break | ||
} | ||
} | ||
|
||
func assertEquals<T>( _ actual: T, _ expected: T, file: String = #file, line: Int = #line) where T: Equatable { | ||
if actual != expected { | ||
failExpectedActual(expected, actual, file: file, line: line) | ||
} | ||
} | ||
|
||
private func failExpectedActual( _ expected: Any, _ actual: Any, file: String = #file, line: Int = #line) { | ||
XCTFail( | ||
""" | ||
Assertion failed at \(file):\(line) | ||
Expected: | ||
\(expected) | ||
Actual: | ||
\(actual) | ||
""" | ||
) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import XCTest | ||
import Common | ||
@testable import AppBundle | ||
|
||
final class AeroShellTest: XCTestCase { | ||
func testParse() { | ||
let a = cmd("a") | ||
let b = cmd("b") | ||
let c = cmd("c") | ||
let d = cmd("d") | ||
let e = cmd("e") | ||
let f = cmd("f") | ||
let backslash = "\\" | ||
let space = " " | ||
|
||
assertEquals("\"foo \(backslash)\" bar \(backslash)\(backslash)\(backslash)\(backslash)\" bar".parseShell().getOrThrow(), cmd("foo \" bar \(backslash)\(backslash)", "bar")) | ||
assertEquals(" ".parseShell().getOrThrow(), .empty) | ||
assertEquals("a | b && c | d".parseShell().getOrThrow(), .and(.pipe(a, b), .pipe(c, d))) | ||
assertEquals("foo && bar || a && baz".parseShell().getOrThrow(), .or(.and(cmd("foo"), cmd("bar")), .and(cmd("a"), cmd("baz")))) | ||
assertEquals("foo a b; bar duh\n baz bro".parseShell().getOrThrow(), .seqV(cmd("foo", "a", "b"), cmd("bar", "duh"), cmd("baz", "bro"))) | ||
assertEquals("(a || b) && (c || d)".parseShell().getOrThrow(), .and(.or(a, b), .or(c, d))) | ||
assertEquals(""" | ||
a # comment 1 | ||
b && c # comment 2 | ||
d; # comment 3 | ||
""".parseShell().getOrThrow(), .seqV(a, .and(b, c), d)) | ||
assertEquals(""" | ||
a && b # comment 1 | ||
# comment 2 | ||
|| c && d | ||
""".parseShell().getOrThrow(), .or(.and(a, b), .and(c, d))) | ||
assertEquals(""" | ||
a \(backslash)\(space) | ||
b c \(backslash) # comment 2 | ||
d && e \(backslash) | ||
&& f | ||
""".parseShell().getOrThrow(), .and(.and(cmd("a", "b", "c", "d"), e), f)) | ||
assertEquals(""" | ||
echo "hi $(foo bar)" | ||
""".parseShell().getOrThrow(), | ||
.args([.text("echo"), .concatV(.text("hi "), .interpolation(cmd("foo", "bar")))]) | ||
) | ||
|
||
assertFailure("echo \"\"\"\"".parseShell()) | ||
assertFailure("echo \"foo \(backslash)\"".parseShell()) | ||
assertFailure("|| foo".parseShell()) | ||
assertFailure("a && (b || c) foo".parseShell()) | ||
} | ||
} | ||
|
||
// extension Shell: ExpressibleByUnicodeScalarLiteral where T == String { // Please Swift | ||
// public init(unicodeScalarLiteral value: UnicodeScalarLiteralType) { error("Unused") } | ||
// } | ||
// extension Shell: ExpressibleByExtendedGraphemeClusterLiteral where T == String { // Please Swift | ||
// public init(extendedGraphemeClusterLiteral value: ExtendedGraphemeClusterLiteralType) { error("Unused") } | ||
// } | ||
// extension Shell: ExpressibleByStringLiteral where T == String { | ||
// public typealias StringLiteralType = String | ||
// public init(stringLiteral: String) { | ||
// self = .args([.text(stringLiteral)]) | ||
// } | ||
// } | ||
|
||
// extension Shell: ExpressibleByArrayLiteral where T == String { | ||
// public typealias ArrayLiteralElement = String | ||
// public init(arrayLiteral elements: ArrayLiteralElement...) { | ||
// self = .args(elements.map(ShellString.text)) | ||
// } | ||
// } | ||
|
||
func cmd(_ args: String...) -> Shell<String> { .args(args.map(ShellString.text)) } | ||
extension Shell { | ||
static func seqV(_ seq: Shell<T>...) -> Shell<T> { .seq(seq) } | ||
} | ||
|
||
extension ShellString { | ||
static func concatV(_ fragments: ShellString<T>...) -> ShellString<T> { .concat(fragments) } | ||
} |