internal/core/eval: recursive disjunctions

Change-Id: I4d3c3951ef11f2e74ee9d31d61c7b4489d637e13
Reviewed-on: https://cue-review.googlesource.com/c/cue/+/8043
Reviewed-by: Marcel van Lohuizen <mpvl@golang.org>
diff --git a/internal/core/eval/disjunct.go b/internal/core/eval/disjunct.go
index 0fb57b1..b02d425 100644
--- a/internal/core/eval/disjunct.go
+++ b/internal/core/eval/disjunct.go
@@ -128,7 +128,7 @@
 		envDisjunct{env, a, x.NumDefaults, cloneID})
 }
 
-func (n *nodeContext) updateResult(state adt.VertexStatus) (isFinal bool) {
+func (n *nodeContext) updateResult(state adt.VertexStatus) {
 	n.postDisjunct(state)
 
 	if n.hasErr() {
@@ -146,7 +146,7 @@
 		if err != nil {
 			n.disjunctErrs = append(n.disjunctErrs, err)
 		}
-		return n.isFinal
+		return
 	}
 
 	n.touched = true
@@ -159,7 +159,7 @@
 
 	for _, v := range d.Values {
 		if adt.Equal(n.ctx, v, &result) {
-			return isFinal
+			return
 		}
 	}
 
@@ -179,34 +179,28 @@
 		d.NumDefaults = j
 	}
 
-	// return n.isFinal
-
 	switch {
 	case !n.nodeShared.hasResult():
 
 	case n.nodeShared.isDefault() && n.defaultMode != isDefault:
-		return n.isFinal
+		return
 
 	case !n.nodeShared.isDefault() && n.defaultMode == isDefault:
 
 	default:
-		return n.isFinal // n.defaultMode == isDefault
+		return // n.defaultMode == isDefault
 	}
 
 	n.nodeShared.setResult(n.node)
 
-	return n.isFinal
+	return
 }
 
-func (n *nodeContext) tryDisjuncts(state adt.VertexStatus) (finished bool) {
-	if !n.insertDisjuncts() || !n.updateResult(state) {
-		if !n.isFinal {
-			return false // More iterations to do.
-		}
-	}
+func (n *nodeContext) processDisjuncts(state adt.VertexStatus) {
+	n.processDisjunct(state, 0, len(n.disjunctions))
 
 	if n.nodeShared.hasResult() {
-		return true // found something
+		return // found something
 	}
 
 	if len(n.disjunctions) > 0 {
@@ -227,103 +221,123 @@
 		}
 		n.node.SetValue(n.ctx, adt.Finalized, b)
 	}
-	return true
 }
 
-// TODO: add proper conjuncts for the ones used by the disjunctions to replace
-// the original source.
-//
-func (n *nodeContext) insertDisjuncts() (inserted bool) {
-	p := 0
-	inserted = true
-
-	n.subDisjunctions = n.subDisjunctions[:0]
-
-	// fmt.Println("----", debug.NodeString(n.ctx, n.node, nil))
-	for _, d := range n.disjunctions {
-		n.subDisjunctions = append(n.subDisjunctions, d)
-
-		sub := len(n.disjunctions)
-		defMode, ok := n.insertSingleDisjunct(p, d, false)
-		p++
-		if !ok {
-			inserted = false
-			break
-		}
-
-		subMode := maybeDefault
-		for ; sub < len(n.disjunctions); sub++ {
-			d := n.disjunctions[sub]
-
-			// TODO: HACK ALERT: we ignore the default tags of the subexpression
-			// if we already have a scalar value and can no longer change the
-			// outcome.
-			// This is not conform the spec, but mimics the old implementation.
-			// It also results in nicer default semantics. Changing this will
-			// break existing CUE code in awkward ways.
-			// We probably should address this when we figure out how to change
-			// the spec to accommodate for this. For instance, we could say
-			// that if a disjunction only contributes a single disjunct to an
-			// end result, default information is ignored. Not the greatest
-			// definition, though.
-			// Another alternative might be to have a special builtin that
-			// mimics the good behavior.
-			// Note that the same result can be obtained in CUE by adding
-			// 0 to a referenced number (forces the default to be discarded).
-			wasScalar := n.scalar != nil // Hack line 1
-
-			n.subDisjunctions = append(n.subDisjunctions, d)
-			mode, ok := n.insertSingleDisjunct(p, d, true)
-			p++
-			if !ok {
-				inserted = false
-				break
-			}
-
-			if !wasScalar { // Hack line 2.
-				subMode = combineDefault(subMode, mode)
-			}
-		}
-		defMode = combineDefault(defMode, subMode)
-
-		n.defaultMode = combineDefault(n.defaultMode, defMode)
-	}
-
-	// Find last disjunction at which there is no overflow.
-	for ; p > 0 && n.stack[p-1]+1 >= len(n.subDisjunctions[p-1].values); p-- {
-	}
-	if p > 0 {
-		// Increment a valid position and set all subsequent entries to 0.
-		n.stack[p-1]++
-		n.stack = n.stack[:p]
-	}
-	return inserted
-}
-
-func (n *nodeContext) insertSingleDisjunct(p int, d envDisjunct, isSub bool) (mode defaultMode, ok bool) {
-	if p >= len(n.stack) {
-		n.stack = append(n.stack, 0)
-	}
-
-	k := n.stack[p]
-	v := d.values[k]
-	n.isFinal = n.isFinal && k == len(d.values)-1
-	c := adt.MakeConjunct(d.env, v.expr, d.cloneID)
-	n.addExprConjunct(c)
-
-	for n.expandOne() {
-	}
-
+// TODO: move state to nodeShared.
+func (n *nodeContext) processDisjunct(state adt.VertexStatus, k, sub int) {
+	isSub := false
+	var d envDisjunct
 	switch {
-	case d.numDefaults == 0:
-		mode = maybeDefault
-	case v.isDefault:
-		mode = isDefault
+	case sub < len(n.disjunctions):
+		d = n.disjunctions[sub]
+		sub++
+		isSub = true
+
+	case k < len(n.disjunctions):
+		d = n.disjunctions[k]
+		k++
+
 	default:
-		mode = notDefault
+		n.updateResult(state)
+		return
 	}
 
-	return mode, !n.hasErr()
+	// save current state of node and nodeContext
+	nSaved := snapshotVertex(n.node)
+	saved := *n
+
+	for i, v := range d.values {
+		n.eval.stats.DisjunctCount++
+
+		if i > 0 {
+			*n = saved
+			*(n.node) = nSaved
+			// restore state
+		}
+
+		// TODO: HACK ALERT: we ignore the default tags of the subexpression
+		// if we already have a scalar value and can no longer change the
+		// outcome.
+		// This is not conform the spec, but mimics the old implementation.
+		// It also results in nicer default semantics. Changing this will
+		// break existing CUE code in awkward ways.
+		// We probably should address this when we figure out how to change
+		// the spec to accommodate for this. For instance, we could say
+		// that if a disjunction only contributes a single disjunct to an
+		// end result, default information is ignored. Not the greatest
+		// definition, though.
+		// Another alternative might be to have a special builtin that
+		// mimics the good behavior.
+		// Note that the same result can be obtained in CUE by adding
+		// 0 to a referenced number (forces the default to be discarded).
+		wasScalar := n.scalar != nil // Hack line 1
+
+		c := adt.MakeConjunct(d.env, v.expr, d.cloneID)
+		n.addExprConjunct(c)
+
+		for n.expandOne() {
+		}
+
+		if n.hasErr() {
+			continue
+		}
+
+		var mode defaultMode
+		switch {
+		case d.numDefaults == 0:
+			mode = maybeDefault
+		case v.isDefault:
+			mode = isDefault
+		default:
+			mode = notDefault
+		}
+
+		if isSub {
+			if !wasScalar { // Hack line 2.
+				n.subMode = combineDefault(n.subMode, mode)
+			}
+		} else if sub == len(n.disjunctions) {
+			n.defaultMode = combineDefault(n.defaultMode, n.subMode)
+			n.defaultMode = combineDefault(n.defaultMode, mode)
+			n.subMode = maybeDefault
+		}
+
+		n.processDisjunct(state, k, sub)
+	}
+}
+
+// Clone makes a shallow copy of a Vertex. The purpose is to create different
+// disjuncts from the same Vertex under computation. This allows the conjuncts
+// of an arc to be reset to a previous position and the reuse of earlier
+// computations.
+//
+// Notes: only Arcs need to be cloned recursively. Structs is assumed to not yet
+// be computed at the time that a Clone is needed and must be nil. Conjuncts no
+// longer needed and can become nil. All other fields can be copied shallowly.
+//
+// USE TO SAVE NODE BRANCH FOR DISJUNCTION, BUT BEFORE POSTDIJSUNCT.
+func snapshotVertex(v *adt.Vertex) adt.Vertex {
+	c := *v
+
+	if len(v.Arcs) > 0 {
+		c.Arcs = make([]*adt.Vertex, len(v.Arcs))
+		for i, arc := range v.Arcs {
+			// For child arcs, only Conjuncts are set and Arcs and
+			// Structs will be nil.
+			a := *arc
+			c.Arcs[i] = &a
+
+			a.Conjuncts = make([]adt.Conjunct, len(arc.Conjuncts))
+			copy(a.Conjuncts, arc.Conjuncts)
+		}
+	}
+
+	if len(v.Structs) > 0 {
+		c.Structs = make([]*adt.StructInfo, len(v.Structs))
+		copy(c.Structs, v.Structs)
+	}
+
+	return c
 }
 
 // Default rules from spec:
diff --git a/internal/core/eval/eval.go b/internal/core/eval/eval.go
index 495fa22..74ce4d5 100644
--- a/internal/core/eval/eval.go
+++ b/internal/core/eval/eval.go
@@ -336,7 +336,6 @@
 			v.Closed = true
 		}
 	}
-	saved := *v
 
 	if !v.Label.IsInt() && v.Parent != nil && !ignore {
 		// Visit arcs recursively to validate and compute error.
@@ -353,74 +352,66 @@
 	defer c.PopArc(c.PushArc(v))
 
 	e.stats.UnifyCount++
-	for i := 0; ; i++ {
-		e.stats.DisjunctCount++
 
-		// Clear any remaining error.
-		if err := c.Err(); err != nil {
-			panic("uncaught error")
-		}
-
-		// Set the cache to a cycle error to ensure a cyclic reference will result
-		// in an error if applicable. A cyclic error may be ignored for
-		// non-expression references. The cycle error may also be removed as soon
-		// as there is evidence what a correct value must be, but before all
-		// validation has taken place.
-		*v = saved
-		v.BaseValue = cycle
-
-		v.UpdateStatus(adt.Evaluating)
-
-		// If the result is a struct, it needs to be closed if:
-		//   1) this node introduces a definition
-		//   2) this node is a child of a node that introduces a definition,
-		//      recursively.
-		//   3) this node embeds a closed struct.
-		n := e.newNodeContext(shared)
-
-		for _, x := range v.Conjuncts {
-			// TODO: needed for reentrancy. Investigate usefulness for cycle
-			// detection.
-			n.addExprConjunct(x)
-		}
-
-		if i == 0 {
-			// Use maybeSetCache for cycle breaking
-			for n.maybeSetCache(); n.expandOne(); n.maybeSetCache() {
-			}
-			if v.Status() > adt.Evaluating && state <= adt.Partial {
-				// We have found a partial result. There may still be errors
-				// down the line which may result from further evaluating this
-				// field, but that will be caught when evaluating this field
-				// for real.
-				shared.setResult(v)
-				e.freeNodeContext(n)
-				return shared
-			}
-			if !n.done() && len(n.disjunctions) > 0 && isEvaluating(v) {
-				// We disallow entering computations of disjunctions with
-				// incomplete data.
-				b := c.NewErrf("incomplete cause disjunction")
-				b.Code = adt.IncompleteError
-				v.SetValue(n.ctx, adt.Finalized, b)
-				shared.setResult(v)
-				e.freeNodeContext(n)
-				return shared
-			}
-		}
-
-		// Handle disjunctions. If there are no disjunctions, this call is
-		// equivalent to calling n.postDisjunct.
-		if n.tryDisjuncts(state) {
-			if v.BaseValue == nil {
-				v.BaseValue = n.getValidators()
-			}
-
-			e.freeNodeContext(n)
-			break
-		}
+	// Clear any remaining error.
+	if err := c.Err(); err != nil {
+		panic("uncaught error")
 	}
 
+	// Set the cache to a cycle error to ensure a cyclic reference will result
+	// in an error if applicable. A cyclic error may be ignored for
+	// non-expression references. The cycle error may also be removed as soon
+	// as there is evidence what a correct value must be, but before all
+	// validation has taken place.
+	v.BaseValue = cycle
+
+	v.UpdateStatus(adt.Evaluating)
+
+	// If the result is a struct, it needs to be closed if:
+	//   1) this node introduces a definition
+	//   2) this node is a child of a node that introduces a definition,
+	//      recursively.
+	//   3) this node embeds a closed struct.
+	n := e.newNodeContext(shared)
+
+	for _, x := range v.Conjuncts {
+		// TODO: needed for reentrancy. Investigate usefulness for cycle
+		// detection.
+		n.addExprConjunct(x)
+	}
+
+	// Use maybeSetCache for cycle breaking
+	for n.maybeSetCache(); n.expandOne(); n.maybeSetCache() {
+	}
+	if v.Status() > adt.Evaluating && state <= adt.Partial {
+		// We have found a partial result. There may still be errors
+		// down the line which may result from further evaluating this
+		// field, but that will be caught when evaluating this field
+		// for real.
+		shared.setResult(v)
+		e.freeNodeContext(n)
+		return shared
+	}
+	if !n.done() && len(n.disjunctions) > 0 && isEvaluating(v) {
+		// We disallow entering computations of disjunctions with
+		// incomplete data.
+		b := c.NewErrf("incomplete cause disjunction")
+		b.Code = adt.IncompleteError
+		v.SetValue(n.ctx, adt.Finalized, b)
+		shared.setResult(v)
+		e.freeNodeContext(n)
+		return shared
+	}
+
+	n.processDisjuncts(state)
+
+	// Handle disjunctions. If there are no disjunctions, this call is
+	// equivalent to calling n.postDisjunct.
+	if v.BaseValue == nil {
+		v.BaseValue = n.getValidators()
+	}
+
+	e.freeNodeContext(n)
 	return shared
 }
 
@@ -639,7 +630,6 @@
 
 	result_ adt.Vertex
 	isDone  bool
-	stack   []int
 }
 
 func (e *Evaluator) newSharedNode(ctx *adt.OpContext, node *adt.Vertex) *nodeShared {
@@ -652,7 +642,6 @@
 			ctx:  ctx,
 			node: node,
 
-			stack:        n.stack[:0],
 			disjunct:     adt.Disjunction{Values: n.disjunct.Values[:0]},
 			disjunctErrs: n.disjunctErrs[:0],
 		}
@@ -758,10 +747,9 @@
 	hasNonCycle bool // has conjunct without structural cycle
 
 	// Disjunction handling
-	disjunctions    []envDisjunct
-	subDisjunctions []envDisjunct
-	defaultMode     defaultMode
-	isFinal         bool
+	disjunctions []envDisjunct
+	defaultMode  defaultMode
+	subMode      defaultMode
 }
 
 func (e *Evaluator) newNodeContext(shared *nodeShared) *nodeContext {
@@ -770,20 +758,17 @@
 		e.freeListNode = n.nextFree
 
 		*n = nodeContext{
-			kind:       adt.TopKind,
-			nodeShared: shared,
-			isFinal:    true,
-
-			arcMap:          n.arcMap[:0],
-			checks:          n.checks[:0],
-			dynamicFields:   n.dynamicFields[:0],
-			ifClauses:       n.ifClauses[:0],
-			forClauses:      n.forClauses[:0],
-			lists:           n.lists[:0],
-			vLists:          n.vLists[:0],
-			exprs:           n.exprs[:0],
-			disjunctions:    n.disjunctions[:0],
-			subDisjunctions: n.subDisjunctions[:0],
+			kind:          adt.TopKind,
+			nodeShared:    shared,
+			arcMap:        n.arcMap[:0],
+			checks:        n.checks[:0],
+			dynamicFields: n.dynamicFields[:0],
+			ifClauses:     n.ifClauses[:0],
+			forClauses:    n.forClauses[:0],
+			lists:         n.lists[:0],
+			vLists:        n.vLists[:0],
+			exprs:         n.exprs[:0],
+			disjunctions:  n.disjunctions[:0],
 		}
 
 		return n
@@ -793,9 +778,6 @@
 	return &nodeContext{
 		kind:       adt.TopKind,
 		nodeShared: shared,
-
-		// These get cleared upon proof to the contrary.
-		isFinal: true,
 	}
 }