Skip to content

Commit

Permalink
Implemented recursive schema checker
Browse files Browse the repository at this point in the history
Enables nested structures to be compared regardless of order.
  • Loading branch information
Stephen Kestle committed May 14, 2021
1 parent 7513117 commit 791606f
Showing 1 changed file with 31 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,42 @@
package com.github.mrpowers.spark.daria.sql

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{StructField, StructType}

import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}

case class InvalidDataFrameSchemaException(smth: String) extends Exception(smth)

private[sql] class DataFrameSchemaChecker(df: DataFrame, requiredSchema: StructType) {
private def diff(required: Seq[StructField], schema: StructType): Seq[StructField] = {
required.filterNot(isPresentIn(schema))
}

private def isPresentIn(schema: StructType)(reqField: StructField): Boolean = {
Try(schema(reqField.name)) match {
case Success(namedField) =>
val basicMatch =
namedField.name == reqField.name &&
namedField.nullable == reqField.nullable &&
namedField.metadata == reqField.metadata

val contentMatch = reqField.dataType match {
case reqSchema: StructType =>
namedField.dataType match {
case fieldSchema: StructType =>
diff(reqSchema, fieldSchema).isEmpty
case _ => false
}
case _ => reqField == namedField
}

basicMatch && contentMatch
case Failure(_) => false
}
}

val missingStructFields = requiredSchema.diff(df.schema)
val missingStructFields: Seq[StructField] = diff(requiredSchema, df.schema)

def missingStructFieldsMessage(): String = {
s"The [${missingStructFields.mkString(", ")}] StructFields are not included in the DataFrame with the following StructFields [${df.schema.toString()}]"
Expand Down

0 comments on commit 791606f

Please sign in to comment.