internal/core/adt: "fix" disjunction resolution

The recent performance improvement eliminated disjuncts
too aggressively, causing some disjuncts to disappear.
Early elimination prevents exponential blowup, but of course
it must not make the result incorrect.

Note that for most common uses, like protobuf and K8s
structs with discriminator fields, special optimizations are possible
that would make this linear, but these have not been implemented yet.

This fix, in turn, exposed another bug. That bug was deliberate: it
worked around a limitation of default values when representing
oneOf fields. We planned to introduce a "required field" annotation
which would also allow representing this kind of semantics
much more elegantly.

For now, though, we are stuck in a position where more than one
oneOf field cannot be represented properly with defaults.
We fix it for now in an even more AWFUL way, while working
towards a proper solution based on required fields.

Issue #726

Change-Id: Id4bcd0445612e12fba48a744654de4e1852c552e
Reviewed-on: https://cue-review.googlesource.com/c/cue/+/8641
Reviewed-by: Paul Jolly <paul@myitcv.org.uk>
Reviewed-by: Marcel van Lohuizen <mpvl@golang.org>
diff --git a/cue/testdata/disjunctions/elimination.txtar b/cue/testdata/disjunctions/elimination.txtar
index b33f099..60bc35d 100644
--- a/cue/testdata/disjunctions/elimination.txtar
+++ b/cue/testdata/disjunctions/elimination.txtar
@@ -142,6 +142,46 @@
     x: { a: b: struct.MinFields(3) } | null
 }
 
+preserveClosedness: small: p1: {
+    #A: #B & {a: string}
+    #B: {
+      *{} | {a: string}
+      *{} | {b: int}
+    }
+}
+
+preserveClosedness: small: p2: {
+    #A: #B & {a: string}
+    #B: {
+      {a: string} | *{}
+      *{} | {b: int}
+    }
+}
+
+preserveClosedness: medium: p1: {
+    #A: #B & {a: string}
+    #B: {
+      *{} | {a: string} | {b: string}
+      *{} | {c: int} | {d: string}
+    }
+}
+
+preserveClosedness: medium: p2: {
+    #A: #B & {a: string}
+    #B: {
+      {a: string} | *{} | {b: string}
+      *{} | {c: int} | {d: string}
+    }
+}
+
+preserveClosedness: medium: p3: {
+    #A: #B & {a: string}
+    #B: {
+      {a: string} | {b: string} | *{}
+      *{} | {c: int} | {d: string}
+    }
+}
+
 -- out/eval --
 (struct){
   disambiguateClosed: (struct){
@@ -321,6 +361,145 @@
       }
     }
   }
+  preserveClosedness: (struct){
+    small: (struct){
+      p1: (struct){
+        #A: (struct){ |(*(#struct){
+            a: (string){ string }
+          }, (#struct){
+            a: (string){ string }
+            b: (int){ int }
+          }) }
+        #B: (struct){ |(*(#struct){
+          }, (#struct){
+            b: (int){ int }
+          }, (#struct){
+            a: (string){ string }
+          }, (#struct){
+            a: (string){ string }
+            b: (int){ int }
+          }) }
+      }
+      p2: (struct){
+        #A: (struct){ |(*(#struct){
+            a: (string){ string }
+          }, (#struct){
+            a: (string){ string }
+            b: (int){ int }
+          }) }
+        #B: (struct){ |(*(#struct){
+          }, (#struct){
+            a: (string){ string }
+            b: (int){ int }
+          }, (#struct){
+            a: (string){ string }
+          }, (#struct){
+            b: (int){ int }
+          }) }
+      }
+    }
+    medium: (struct){
+      p1: (struct){
+        #A: (struct){ |(*(#struct){
+            a: (string){ string }
+          }, (#struct){
+            a: (string){ string }
+            c: (int){ int }
+          }, (#struct){
+            a: (string){ string }
+            d: (string){ string }
+          }) }
+        #B: (struct){ |(*(#struct){
+          }, (#struct){
+            c: (int){ int }
+          }, (#struct){
+            d: (string){ string }
+          }, (#struct){
+            a: (string){ string }
+          }, (#struct){
+            a: (string){ string }
+            c: (int){ int }
+          }, (#struct){
+            a: (string){ string }
+            d: (string){ string }
+          }, (#struct){
+            b: (string){ string }
+          }, (#struct){
+            b: (string){ string }
+            c: (int){ int }
+          }, (#struct){
+            b: (string){ string }
+            d: (string){ string }
+          }) }
+      }
+      p2: (struct){
+        #A: (struct){ |(*(#struct){
+            a: (string){ string }
+          }, (#struct){
+            a: (string){ string }
+            c: (int){ int }
+          }, (#struct){
+            a: (string){ string }
+            d: (string){ string }
+          }) }
+        #B: (struct){ |(*(#struct){
+          }, (#struct){
+            a: (string){ string }
+            c: (int){ int }
+          }, (#struct){
+            a: (string){ string }
+            d: (string){ string }
+          }, (#struct){
+            a: (string){ string }
+          }, (#struct){
+            c: (int){ int }
+          }, (#struct){
+            d: (string){ string }
+          }, (#struct){
+            b: (string){ string }
+          }, (#struct){
+            b: (string){ string }
+            c: (int){ int }
+          }, (#struct){
+            b: (string){ string }
+            d: (string){ string }
+          }) }
+      }
+      p3: (struct){
+        #A: (struct){ |(*(#struct){
+            a: (string){ string }
+          }, (#struct){
+            a: (string){ string }
+            c: (int){ int }
+          }, (#struct){
+            a: (string){ string }
+            d: (string){ string }
+          }) }
+        #B: (struct){ |(*(#struct){
+          }, (#struct){
+            a: (string){ string }
+            c: (int){ int }
+          }, (#struct){
+            a: (string){ string }
+            d: (string){ string }
+          }, (#struct){
+            b: (string){ string }
+          }, (#struct){
+            b: (string){ string }
+            c: (int){ int }
+          }, (#struct){
+            b: (string){ string }
+            d: (string){ string }
+          }, (#struct){
+            a: (string){ string }
+          }, (#struct){
+            c: (int){ int }
+          }, (#struct){
+            d: (string){ string }
+          }) }
+      }
+    }
+  }
 }
 -- out/compile --
 --- in.cue
@@ -785,4 +964,101 @@
       }
     }
   }
+  preserveClosedness: {
+    small: {
+      p1: {
+        #A: (〈0;#B〉 & {
+          a: string
+        })
+        #B: {
+          (*{}|{
+            a: string
+          })
+          (*{}|{
+            b: int
+          })
+        }
+      }
+    }
+  }
+  preserveClosedness: {
+    small: {
+      p2: {
+        #A: (〈0;#B〉 & {
+          a: string
+        })
+        #B: {
+          ({
+            a: string
+          }|*{})
+          (*{}|{
+            b: int
+          })
+        }
+      }
+    }
+  }
+  preserveClosedness: {
+    medium: {
+      p1: {
+        #A: (〈0;#B〉 & {
+          a: string
+        })
+        #B: {
+          (*{}|{
+            a: string
+          }|{
+            b: string
+          })
+          (*{}|{
+            c: int
+          }|{
+            d: string
+          })
+        }
+      }
+    }
+  }
+  preserveClosedness: {
+    medium: {
+      p2: {
+        #A: (〈0;#B〉 & {
+          a: string
+        })
+        #B: {
+          ({
+            a: string
+          }|*{}|{
+            b: string
+          })
+          (*{}|{
+            c: int
+          }|{
+            d: string
+          })
+        }
+      }
+    }
+  }
+  preserveClosedness: {
+    medium: {
+      p3: {
+        #A: (〈0;#B〉 & {
+          a: string
+        })
+        #B: {
+          ({
+            a: string
+          }|{
+            b: string
+          }|*{})
+          (*{}|{
+            c: int
+          }|{
+            d: string
+          })
+        }
+      }
+    }
+  }
 }
diff --git a/cue/types.go b/cue/types.go
index e140bee..b85cfbf 100644
--- a/cue/types.go
+++ b/cue/types.go
@@ -1831,7 +1831,7 @@
 	if v.v == nil || other.v == nil {
 		return false
 	}
-	return adt.Equal(v.ctx().opCtx, v.v, other.v)
+	return adt.Equal(v.ctx().opCtx, v.v, other.v, false)
 }
 
 // Format prints a debug version of a value.
diff --git a/internal/core/adt/disjunct.go b/internal/core/adt/disjunct.go
index e7db3cd..bb31c03 100644
--- a/internal/core/adt/disjunct.go
+++ b/internal/core/adt/disjunct.go
@@ -187,6 +187,34 @@
 
 		n.disjuncts = append(n.disjuncts, n)
 
+		// HACK: this is an AWFUL, AWFUL HACK to work around a limitation of
+		// using defaults for marking oneOfs in protobuffers. Previously this
+		// was worked around by another hack that (deliberately erroneously)
+		// would move the default status of a disjunction of which only one
+		// disjunct remained from "not default" to "maybe default". For
+		// protobuf oneOfs this would mean that only the "intended" struct would
+		// get the default value. It also worked around various other
+		// limitations.
+		//
+		// With the latest performance enhancements this old hack still worked,
+		// but only because it introduced a bug that this hack relied on. Fixing
+		// this bug now causes this hack to no longer work.
+		//
+		// Ultimately, the correct way to address the issue for the protobuf
+		// representation is not to use defaults at all. Instead we should use
+		// the required annotator (which fixes a whole load of other issues):
+		//
+		//    {} | {a!: int} | {b!: int}
+		//
+		// This would force that only one of these can be true independent of
+		// default magic. Aside from fixing this issue, it also moves to a model
+		// that is consistent with the recommendation to not use defaults in
+		// a top-level API specification.
+		//
+		// The hack we use now recognizes the oneOf patterns and then sets the
+		// default for the "smallest" element.
+		protoForm := true
+
 		for i, d := range n.disjunctions {
 			a := n.disjuncts
 			n.disjuncts = n.buffer[:0]
@@ -197,6 +225,23 @@
 				n.ctx.inDisjunct++
 			}
 
+			// HACK: see above
+			defaultCount := 0
+			override := true
+			if d.expr == nil {
+				override = false
+			} else {
+				for _, v := range d.expr.Values {
+					if !d.hasDefaults || v.Default {
+						defaultCount++
+						if s, ok := v.Val.(*StructLit); !ok || len(s.Decls) > 0 {
+							protoForm = false
+							override = false
+						}
+					}
+				}
+			}
+
 			for _, dn := range a {
 				switch {
 				case d.expr != nil:
@@ -208,6 +253,12 @@
 						c := MakeConjunct(d.env, v.Val, d.cloneID)
 						cn.addExprConjunct(c)
 
+						if override {
+							if s, ok := v.Val.(*StructLit); ok && len(s.Decls) == 0 {
+								cn.protoCount++
+							}
+						}
+
 						newMode := mode(d.hasDefaults, v.Default)
 						cn.defaultMode = combineDefault(dn.defaultMode, newMode)
 
@@ -249,6 +300,25 @@
 			}
 		}
 
+		// HACK: see above
+		if protoForm {
+			min := int32(0)
+			minPos := 0
+			for i, d := range n.disjuncts {
+				if d.defaultMode == isDefault {
+					min = 0
+					break
+				}
+				if d.protoCount > min {
+					min = d.protoCount
+					minPos = i
+				}
+			}
+			if min > 0 {
+				n.disjuncts[minPos].defaultMode = isDefault
+			}
+		}
+
 		// HACK alert: this replaces the hack of the previous algorithm with a
 		// slightly less worse hack: instead of dropping the default info when
 		// the value was scalar before, we drop this information when there
@@ -271,7 +341,7 @@
 	outer:
 		for _, d := range n.disjuncts {
 			for _, v := range p.disjuncts {
-				if Equal(n.ctx, &v.result, &d.result) {
+				if Equal(n.ctx, &v.result, &d.result, true) {
 					if d.defaultMode == isDefault {
 						v.defaultMode = isDefault
 					}
diff --git a/internal/core/adt/equality.go b/internal/core/adt/equality.go
index ca7b41f..de57df9 100644
--- a/internal/core/adt/equality.go
+++ b/internal/core/adt/equality.go
@@ -14,17 +14,17 @@
 
 package adt
 
-func Equal(ctx *OpContext, v, w Value) bool {
+func Equal(ctx *OpContext, v, w Value, optional bool) bool {
 	if x, ok := v.(*Vertex); ok {
-		return equalVertex(ctx, x, w)
+		return equalVertex(ctx, x, w, optional)
 	}
 	if y, ok := w.(*Vertex); ok {
-		return equalVertex(ctx, y, v)
+		return equalVertex(ctx, y, v, optional)
 	}
-	return equalTerminal(ctx, v, w)
+	return equalTerminal(ctx, v, w, optional)
 }
 
-func equalVertex(ctx *OpContext, x *Vertex, v Value) bool {
+func equalVertex(ctx *OpContext, x *Vertex, v Value, opt bool) bool {
 	y, ok := v.(*Vertex)
 	if !ok {
 		return false
@@ -47,7 +47,7 @@
 	if x.IsClosed(ctx) != y.IsClosed(ctx) {
 		return false
 	}
-	if !equalOptional(ctx, x, y) {
+	if opt && !equalClosed(ctx, x, y) {
 		return false
 	}
 
@@ -55,7 +55,7 @@
 	for _, a := range x.Arcs {
 		for _, b := range y.Arcs {
 			if a.Label == b.Label {
-				if !Equal(ctx, a, b) {
+				if !Equal(ctx, a, b, opt) {
 					return false
 				}
 				continue loop1
@@ -81,12 +81,10 @@
 		return true // both are struct or list.
 	}
 
-	return equalTerminal(ctx, v, w)
+	return equalTerminal(ctx, v, w, opt)
 }
 
-// equalOptional tests if x and y have the same set of close information.
-// Right now this just checks if it has the same source structs that
-// define optional fields.
+// equalClosed tests if x and y have the same set of close information.
 // TODO: the following refinements are possible:
 // - unify optional fields and equate the optional fields
 // - do the same for pattern constraints, where the pattern constraints
@@ -95,14 +93,14 @@
 //
 // For all these refinements it would be necessary to have well-working
 // structure sharing so as to not repeatedly recompute optional arcs.
-func equalOptional(ctx *OpContext, x, y *Vertex) bool {
+func equalClosed(ctx *OpContext, x, y *Vertex) bool {
 	return verifyStructs(x, y) && verifyStructs(y, x)
 }
 
 func verifyStructs(x, y *Vertex) bool {
 outer:
 	for _, s := range x.Structs {
-		if !s.StructLit.HasOptional() {
+		if s.closeInfo == nil || s.closeInfo.span|DefinitionSpan == 0 {
 			continue
 		}
 		for _, t := range y.Structs {
@@ -115,7 +113,7 @@
 	return true
 }
 
-func equalTerminal(ctx *OpContext, v, w Value) bool {
+func equalTerminal(ctx *OpContext, v, w Value, opt bool) bool {
 	if v == w {
 		return true
 	}
@@ -132,7 +130,7 @@
 
 	case *BoundValue:
 		if y, ok := w.(*BoundValue); ok {
-			return x.Op == y.Op && Equal(ctx, x.Value, y.Value)
+			return x.Op == y.Op && Equal(ctx, x.Value, y.Value, opt)
 		}
 
 	case *BasicType:
@@ -147,7 +145,7 @@
 		}
 		// always ordered the same
 		for i, xe := range x.Values {
-			if !Equal(ctx, xe, y.Values[i]) {
+			if !Equal(ctx, xe, y.Values[i], opt) {
 				return false
 			}
 		}
@@ -161,7 +159,7 @@
 			return false
 		}
 		for i, xe := range x.Values {
-			if !Equal(ctx, xe, y.Values[i]) {
+			if !Equal(ctx, xe, y.Values[i], opt) {
 				return false
 			}
 		}
diff --git a/internal/core/adt/eval.go b/internal/core/adt/eval.go
index 38c0c94..487b405 100644
--- a/internal/core/adt/eval.go
+++ b/internal/core/adt/eval.go
@@ -772,6 +772,7 @@
 	hasTop      bool
 	hasCycle    bool // has conjunct with structural cycle
 	hasNonCycle bool // has conjunct without structural cycle
+	protoCount  int32
 
 	// Disjunction handling
 	disjunctions []envDisjunct
diff --git a/internal/core/adt/simplify.go b/internal/core/adt/simplify.go
index c58ad47..dea3066 100644
--- a/internal/core/adt/simplify.go
+++ b/internal/core/adt/simplify.go
@@ -214,7 +214,7 @@
 				return nil
 			}
 			for i, a := range x.Args {
-				if !Equal(ctx, a, y.Args[i]) {
+				if !Equal(ctx, a, y.Args[i], false) {
 					return nil
 				}
 			}