diff --git a/amf-shapes/shared/src/main/scala/amf/shapes/client/scala/model/domain/ShapeHelpers.scala b/amf-shapes/shared/src/main/scala/amf/shapes/client/scala/model/domain/ShapeHelpers.scala index 9990df240c..04baf46714 100644 --- a/amf-shapes/shared/src/main/scala/amf/shapes/client/scala/model/domain/ShapeHelpers.scala +++ b/amf-shapes/shared/src/main/scala/amf/shapes/client/scala/model/domain/ShapeHelpers.scala @@ -2,7 +2,7 @@ package amf.shapes.client.scala.model.domain import amf.core.client.scala.errorhandling.AMFErrorHandler import amf.core.client.scala.model.domain.{Linkable, RecursiveShape, Shape} -import amf.core.client.scala.traversal.ModelTraversalRegistry +import amf.core.client.scala.traversal.{ModelTraversalRegistry, ShapeTraversalRegistry} import amf.core.internal.annotations.ExplicitField import amf.core.internal.validation.CoreValidations.RecursiveShapeSpecification import amf.shapes.internal.annotations.ParsedFromTypeExpression @@ -32,7 +32,7 @@ trait ShapeHelpers { this: Shape => def cloneShape(recursionErrorHandler: Option[AMFErrorHandler], withRecursionBase: Option[String] = None, - traversal: ModelTraversalRegistry = ModelTraversalRegistry(), + traversal: ShapeTraversalRegistry = ShapeTraversalRegistry(), cloneExamples: Boolean = false): this.type = { if (traversal.isInCurrentPath(this.id)) { buildFixPoint(withRecursionBase, this.name.value(), this, recursionErrorHandler).asInstanceOf[this.type] diff --git a/amf-shapes/shared/src/main/scala/amf/shapes/internal/domain/resolution/recursion/RecursionErrorRegister.scala b/amf-shapes/shared/src/main/scala/amf/shapes/internal/domain/resolution/recursion/RecursionErrorRegister.scala index e047d27816..54263c42db 100644 --- a/amf-shapes/shared/src/main/scala/amf/shapes/internal/domain/resolution/recursion/RecursionErrorRegister.scala +++ b/amf-shapes/shared/src/main/scala/amf/shapes/internal/domain/resolution/recursion/RecursionErrorRegister.scala @@ -2,7 +2,7 @@ package amf.shapes.internal.domain.resolution.recursion import amf.core.client.scala.errorhandling.AMFErrorHandler import amf.core.client.scala.model.domain.{RecursiveShape, Shape} -import amf.core.client.scala.traversal.ModelTraversalRegistry +import amf.core.client.scala.traversal.{ModelTraversalRegistry, ShapeTraversalRegistry} import amf.core.internal.validation.CoreValidations.RecursiveShapeSpecification import scala.collection.mutable.ListBuffer @@ -10,53 +10,59 @@ import scala.collection.mutable.ListBuffer class RecursionErrorRegister(errorHandler: AMFErrorHandler) { private val errorRegister = ListBuffer[String]() - private def buildRecursion(base: Option[String], s: Shape): RecursiveShape = { + def buildRecursion(base: Option[String], s: Shape): RecursiveShape = { val fixPointId = base.getOrElse(s.id) val r = RecursiveShape(s).withFixPoint(fixPointId) r } - def recursionAndError(root: Shape, - base: Option[String], - s: Shape, - traversal: ModelTraversalRegistry, - criteria: RegisterCriteria = DefaultRegisterCriteria()): RecursiveShape = { - val recursion = buildRecursion(base, s) - recursionError(root, recursion, traversal: ModelTraversalRegistry, Some(root.id), criteria) + def allowedInTraversal(traversal: ShapeTraversalRegistry, + r: RecursiveShape, + checkId: Option[String] = None): Boolean = { + val recursiveShapeIsAllowListed = traversal.isAllowListed(r.id) + val fixpointIsAllowListed = r.fixpoint.option().exists(traversal.isAllowListed) + /*** + * TODO (Refactor needed) + * When calling ShapeExpander `checkId` some times gets set to the root shape ID from where the traversal started. + * Why do we need to opiotnally check if this root id is allow listed? Doesn't it suffice with checking the + * recursive shape ID or its fixpoint? + */ + val checkIdIsAllowListed = checkId.exists(traversal.isAllowListed) + recursiveShapeIsAllowListed || fixpointIsAllowListed || checkIdIsAllowListed } - def recursionError(original: Shape, - r: RecursiveShape, - traversal: ModelTraversalRegistry, - checkId: Option[String] = None, - criteria: RegisterCriteria = DefaultRegisterCriteria()): RecursiveShape = { + def checkRecursionError(root: Shape, + r: RecursiveShape, + traversal: ShapeTraversalRegistry, + checkId: Option[String] = None, + criteria: ThrowRecursionValidationCriteria = DefaultCriteria()): RecursiveShape = { val hasNotRegisteredItYet = !errorRegister.contains(r.id) - if (criteria.decide(r) && !traversal.avoidError(r, checkId) && hasNotRegisteredItYet) { + if (criteria.shouldThrowFor(r) && !allowedInTraversal(traversal, r, checkId) && hasNotRegisteredItYet) { errorHandler.violation( RecursiveShapeSpecification, - original.id, + root.id, None, "Error recursive shape", - original.position(), - original.location() + root.position(), + root.location() ) errorRegister += r.id - } else if (traversal.avoidError(r, checkId)) r.withSupportsRecursion(true) + } else if (allowedInTraversal(traversal, r, checkId)) r.withSupportsRecursion(true) r } } -trait RegisterCriteria { - def decide(r: RecursiveShape): Boolean +trait ThrowRecursionValidationCriteria { + def shouldThrowFor(r: RecursiveShape): Boolean } -case class DefaultRegisterCriteria() extends RegisterCriteria { - override def decide(r: RecursiveShape): Boolean = !r.supportsRecursion.option().getOrElse(false) +case class DefaultCriteria() extends ThrowRecursionValidationCriteria { + override def shouldThrowFor(r: RecursiveShape): Boolean = !r.supportsRecursion.option().getOrElse(false) } -case class LinkableRegisterCriteria(root: Shape, linkable: Shape) extends RegisterCriteria { - override def decide(r: RecursiveShape): Boolean = linkable.linkTarget match { +case class LinkableCriteria(root: Shape, linkable: Shape) extends ThrowRecursionValidationCriteria { + override def shouldThrowFor(r: RecursiveShape): Boolean = linkable.linkTarget match { case Some(element) => element.id.equals(root.id) case None => false } diff --git a/amf-shapes/shared/src/main/scala/amf/shapes/internal/domain/resolution/shape_normalization/ShapeExpander.scala b/amf-shapes/shared/src/main/scala/amf/shapes/internal/domain/resolution/shape_normalization/ShapeExpander.scala index decde1ade2..83d98c29cb 100644 --- a/amf-shapes/shared/src/main/scala/amf/shapes/internal/domain/resolution/shape_normalization/ShapeExpander.scala +++ b/amf-shapes/shared/src/main/scala/amf/shapes/internal/domain/resolution/shape_normalization/ShapeExpander.scala @@ -2,27 +2,15 @@ package amf.shapes.internal.domain.resolution.shape_normalization import amf.core.client.scala.model.domain._ import amf.core.client.scala.model.domain.extensions.PropertyShape -import amf.core.client.scala.traversal.ModelTraversalRegistry +import amf.core.client.scala.traversal.ShapeTraversalRegistry import amf.core.internal.annotations.ExplicitField import amf.core.internal.metamodel.domain.ShapeModel import amf.core.internal.metamodel.domain.extensions.PropertyShapeModel import amf.core.internal.parser.domain.Annotations import amf.core.internal.validation.CoreValidations.TransformationValidation +import amf.shapes.client.scala.model.domain._ import amf.shapes.internal.domain.metamodel._ -import amf.shapes.internal.domain.resolution.recursion.{LinkableRegisterCriteria, RecursionErrorRegister} -import amf.shapes.client.scala.model.domain.UnresolvedShape -import amf.shapes.client.scala.model.domain.{ - AnyShape, - ArrayShape, - FileShape, - MatrixShape, - NilShape, - NodeShape, - ScalarShape, - TupleShape, - UnionShape, - UnresolvedShape -} +import amf.shapes.internal.domain.resolution.recursion.{LinkableCriteria, RecursionErrorRegister} private[resolution] object ShapeExpander { def apply(s: Shape, context: NormalizationContext, recursionRegister: RecursionErrorRegister): Shape = @@ -35,8 +23,8 @@ sealed case class ShapeExpander(root: Shape, recursionRegister: RecursionErrorRe def normalize(): Shape = normalize(root) - protected val traversal: ModelTraversalRegistry = - ModelTraversalRegistry().withAllowedCyclesInstances(Seq(classOf[UnresolvedShape])) + protected val traversal: ShapeTraversalRegistry = + ShapeTraversalRegistry().withAllowedCyclesInstances(Seq(classOf[UnresolvedShape])) protected def ensureHasId(shape: Shape): Unit = { if (Option(shape.id).isEmpty) { @@ -54,14 +42,25 @@ sealed case class ShapeExpander(root: Shape, recursionRegister: RecursionErrorRe override def normalizeAction(shape: Shape): Shape = { shape match { case l: Linkable if l.isLink => - recursionRegister.recursionAndError(root, - Some(root.id), - shape, - traversal, - LinkableRegisterCriteria(root, shape)) - - case _ if traversal.shouldFailIfRecursive(root, shape) && !shape.isInstanceOf[RecursiveShape] => - recursionRegister.recursionAndError(root, None, shape, traversal) + /*** + * TODO: (Refactor needed) + * Why do we create a recursive shape when we find a linkable? Shouldn't this be subject only to traversals? + * The motivation is not explicit in the code. There is for sure some corner case where this case is needed. + * After finding the cocrete case please extract this to a function and make explicit the conditions where + * this is needed, otherwise delete this code. + */ + val recursiveShape = recursionRegister.buildRecursion(Some(root.id), shape) + recursionRegister.checkRecursionError(root, + recursiveShape, + traversal, + Some(root.id), + LinkableCriteria(root, shape)) + recursiveShape + + case _ if traversal.foundRecursion(root, shape) && !shape.isInstanceOf[RecursiveShape] => + val recursiveShape = recursionRegister.buildRecursion(None, shape) + recursionRegister.checkRecursionError(root, recursiveShape, traversal, Some(root.id)) + recursiveShape case _ if traversal.wasVisited(shape.id) => shape @@ -79,7 +78,7 @@ sealed case class ShapeExpander(root: Shape, recursionRegister: RecursionErrorRe case fileShape: FileShape => expandAny(fileShape) case nil: NilShape => nil case node: NodeShape => expandNode(node) - case recursive: RecursiveShape => recursionRegister.recursionError(recursive, recursive, traversal) + case recursive: RecursiveShape => recursionRegister.checkRecursionError(recursive, recursive, traversal) case any: AnyShape => expandAny(any) } }) @@ -92,7 +91,7 @@ sealed case class ShapeExpander(root: Shape, recursionRegister: RecursionErrorRe // in this case i use the father shape id and position, because the inheritance could be a recursive shape already val newInherits = shape.inherits.map { case r: RecursiveShape if r.fixpoint.option().exists(_.equals(shape.id)) => - recursionRegister.recursionError(shape, r, traversal) // direct recursion + recursionRegister.checkRecursionError(shape, r, traversal) // direct recursion case r: RecursiveShape => r case parent => @@ -152,7 +151,7 @@ sealed case class ShapeExpander(root: Shape, recursionRegister: RecursionErrorRe if (mandatory) array.inherits.collect({ case arr: ArrayShape if arr.items.isInstanceOf[RecursiveShape] => arr }).foreach { f => val r = f.items.asInstanceOf[RecursiveShape] - recursionRegister.recursionError(array, r, traversal, Some(array.id)) + recursionRegister.checkRecursionError(array, r, traversal, Some(array.id)) } if (Option(oldItems).isDefined) { val newItems = if (mandatory) { @@ -250,9 +249,9 @@ sealed case class ShapeExpander(root: Shape, recursionRegister: RecursionErrorRe private def traverseOptionalShapeFacet(shape: Shape, from: Shape) = shape match { case _ if shape.inherits.nonEmpty => - traversal.runWithIgnoredIds(() => normalize(shape), shape.inherits.map(_.id).toSet + root.id) + traversal.allow(shape.inherits.map(_.id).toSet + root.id)(() => normalize(shape)) case _: RecursiveShape => shape - case _ => traversal.recursionAllowed(() => normalize(shape), from.id) + case _ => traversal.allow(traversal.currentPath + from.id)(() => normalize(shape)) } protected def expandUnion(union: UnionShape): Shape = { @@ -260,7 +259,7 @@ sealed case class ShapeExpander(root: Shape, recursionRegister: RecursionErrorRe val oldAnyOf = union.fields.getValue(UnionShapeModel.AnyOf) if (Option(oldAnyOf).isDefined) { val newAnyOf = union.anyOf.map { u => - val unionMember = traversal.recursionAllowed(() => recursiveNormalization(u), u.id) + val unionMember = traversal.allow(traversal.currentPath + u.id)(() => recursiveNormalization(u)) unionMember } union.setArrayWithoutId(UnionShapeModel.AnyOf, newAnyOf, oldAnyOf.annotations)