diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 69eeab426ed01..209f369f009be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -529,7 +529,6 @@ trait ShuffleSpec { * clustering expressions. * * This will only be called when: - * - [[canCreatePartitioning]] returns true. * - [[isCompatibleWith]] returns false on the side where the `clustering` is from. */ def createPartitioning(clustering: Seq[Expression]): Partitioning = @@ -542,7 +541,7 @@ case object SinglePartitionShuffleSpec extends ShuffleSpec { other.numPartitions == 1 } - override def canCreatePartitioning: Boolean = true + override def canCreatePartitioning: Boolean = false override def createPartitioning(clustering: Seq[Expression]): Partitioning = SinglePartition diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala index 7e11d4f68392f..51e7688732265 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala @@ -367,7 +367,7 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper { assert(HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), distribution) .canCreatePartitioning) } - assert(SinglePartitionShuffleSpec.canCreatePartitioning) + assert(!SinglePartitionShuffleSpec.canCreatePartitioning) withSQLConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false") { assert(ShuffleSpecCollection(Seq( HashShuffleSpec(HashPartitioning(Seq($"a"), 10), distribution), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 67a58da89625e..581fa1475b8a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -73,9 +73,13 @@ case class EnsureRequirements( case _ => false }.map(_._2) + // Special case: if all sides of the join are single partition + val allSinglePartition = + childrenIndexes.forall(children(_).outputPartitioning == SinglePartition) + // If there are more than one children, we'll need to check partitioning & distribution of them // and see if extra shuffles are necessary. - if (childrenIndexes.length > 1) { + if (childrenIndexes.length > 1 && !allSinglePartition) { val specs = childrenIndexes.map(i => { val requiredDist = requiredChildDistributions(i) assert(requiredDist.isInstanceOf[ClusteredDistribution], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index c3c8959d6e1ca..000bd8c84f64a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -262,7 +262,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { val numExchanges = collect(plan) { case exchange: ShuffleExchangeExec => exchange }.length - assert(numExchanges === 3) + assert(numExchanges === 5) } { @@ -278,7 +278,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { val numExchanges = collect(plan) { case exchange: ShuffleExchangeExec => exchange }.length - assert(numExchanges === 3) + assert(numExchanges === 5) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 7237cc5f0fa51..d692ba5b17073 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -433,8 +433,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { exprA :: exprB :: Nil, exprC :: exprD :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, SinglePartition, _, _), _), - SortExec(_, _, ShuffleExchangeExec(SinglePartition, _, _), _), _) => + SortExec(_, _, ShuffleExchangeExec(left: HashPartitioning, _, _), _), + SortExec(_, _, ShuffleExchangeExec(right: HashPartitioning, _, _), _), _) => + assert(left.numPartitions == 5) + assert(right.numPartitions == 5) case other => fail(other.toString) } @@ -690,6 +692,45 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } + test("SPARK-40703: shuffle for SinglePartitionShuffleSpec") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> 20.toString) { + // We should re-shuffle the side with single partition when the other side is + // `HashPartitioning` with shuffle node, and respect the minimum parallelism. + var plan1: SparkPlan = ShuffleExchangeExec( + outputPartitioning = HashPartitioning(exprA :: Nil, 10), + DummySparkPlan()) + var plan2 = DummySparkPlan(outputPartitioning = SinglePartition) + var smjExec = SortMergeJoinExec(exprA :: Nil, exprC :: Nil, Inner, None, plan1, plan2) + EnsureRequirements.apply(smjExec) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, ShuffleExchangeExec(left: HashPartitioning, _, _), _), + SortExec(_, _, ShuffleExchangeExec(right: HashPartitioning, _, _), _), _) => + assert(leftKeys === Seq(exprA)) + assert(rightKeys === Seq(exprC)) + assert(left.numPartitions == 20) + assert(right.numPartitions == 20) + case other => fail(other.toString) + } + + // We should also re-shuffle the side with only a single partition even the other side does + // not have `ShuffleExchange`, but just `HashPartitioning`. However in this case the minimum + // shuffle parallelism will be ignored since we don't want to introduce extra shuffle. + plan1 = DummySparkPlan( + outputPartitioning = HashPartitioning(exprA :: Nil, 10)) + plan2 = DummySparkPlan(outputPartitioning = SinglePartition) + smjExec = SortMergeJoinExec(exprA :: Nil, exprC :: Nil, Inner, None, plan1, plan2) + EnsureRequirements.apply(smjExec) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), + SortExec(_, _, ShuffleExchangeExec(right: HashPartitioning, _, _), _), _) => + assert(leftKeys === Seq(exprA)) + assert(rightKeys === Seq(exprC)) + assert(right.numPartitions == 10) + case other => fail(other.toString) + } + } + } + test("Check with KeyGroupedPartitioning") { // simplest case: identity transforms var plan1 = DummySparkPlan(