internal/core/adt: add memory management

Change-Id: I9241dabf3da0581dca6521c538207880fa6d0497
Reviewed-on: https://cue-review.googlesource.com/c/cue/+/8123
Reviewed-by: Marcel van Lohuizen <mpvl@golang.org>
Reviewed-by: CUE cueckoo <cueckoo@gmail.com>
diff --git a/internal/core/adt/default.go b/internal/core/adt/default.go
index 3140845..31e137f 100644
--- a/internal/core/adt/default.go
+++ b/internal/core/adt/default.go
@@ -70,6 +70,7 @@
 			w = d.Values[0]
 		default:
 			x := *v
+			x.state = nil
 			x.BaseValue = &Disjunction{
 				Src:         d.Src,
 				Values:      d.Values[:d.NumDefaults],
@@ -92,6 +93,7 @@
 
 		w := *v
 		w.BaseValue = &m
+		w.state = nil
 		return &w
 	}
 }
diff --git a/internal/core/adt/disjunct.go b/internal/core/adt/disjunct.go
index 77cc0a1..a4c3cc2 100644
--- a/internal/core/adt/disjunct.go
+++ b/internal/core/adt/disjunct.go
@@ -122,11 +122,6 @@
 	for n.expandOne() {
 	}
 
-	errNode := n
-	if parent != nil {
-		errNode = parent
-	}
-
 	// save node to snapShot in nodeContex
 	// save nodeContext.
 
@@ -142,6 +137,10 @@
 		n.postDisjunct(state)
 
 		if n.hasErr() {
+			// TODO: consider finalizing the node thusly:
+			// if recursive {
+			// 	n.node.Finalize(n.ctx)
+			// }
 			x := n.node
 			err, ok := x.BaseValue.(*Bottom)
 			if !ok {
@@ -154,25 +153,27 @@
 				err = x.ChildErrors
 			}
 			if err != nil {
-				errNode.disjunctErrs = append(errNode.disjunctErrs, err)
+				parent.disjunctErrs = append(parent.disjunctErrs, err)
 			}
-			if recursive || len(n.disjunctions) > 0 {
-				n.ctx.Unifier.freeNodeContext(n)
+			if recursive {
+				n.free()
 			}
 			return
 		}
+		if n.node.BaseValue == nil {
+			n.node.BaseValue = n.getValidators()
+		}
+
 		// TODO: clean up this mess:
 		result := *n.node // XXX: n.result = snapshotVertex(n.node)?
 
-		if result.BaseValue == nil {
-			result.BaseValue = n.getValidators()
-		}
-
-		if state < Finalized {
+		if recursive && state < Finalized {
 			*n = m
 		}
 		n.result = result
-		n.disjuncts = append(n.disjuncts, n)
+		if recursive {
+			n.disjuncts = append(n.disjuncts, n)
+		}
 
 	case len(n.disjunctions) > 0:
 		// Process full disjuncts to ensure that erroneous disjuncts are
@@ -201,6 +202,7 @@
 					for _, v := range d.expr.Values {
 						cn := dn.clone()
 						*cn.node = snapshotVertex(dn.snapshot)
+						cn.node.state = cn
 
 						c := MakeConjunct(d.env, v.Val, d.cloneID)
 						cn.addExprConjunct(c)
@@ -215,6 +217,7 @@
 					for i, v := range d.value.Values {
 						cn := dn.clone()
 						*cn.node = snapshotVertex(dn.snapshot)
+						cn.node.state = cn
 
 						cn.addValueConjunct(d.env, v, d.cloneID)
 
@@ -226,14 +229,17 @@
 				}
 			}
 
-			if i > 0 {
-				for _, d := range a {
-					n.ctx.freeNodeContext(d)
+			if len(n.disjuncts) == 0 {
+				n.makeError()
+			}
+
+			if recursive || i > 0 {
+				for _, x := range a {
+					x.free()
 				}
 			}
 
 			if len(n.disjuncts) == 0 {
-				n.makeError()
 				break
 			}
 		}
@@ -253,33 +259,27 @@
 	// Compare to root, but add to this one.
 	// TODO: if only one value is left, set to maybeDefault.
 	switch p := parent; {
-	case p != nil:
+	case p != n:
 		p.disjunctErrs = append(p.disjunctErrs, n.disjunctErrs...)
 		n.disjunctErrs = n.disjunctErrs[:0]
 
-		k := 0
 	outer:
 		for _, d := range n.disjuncts {
 			for _, v := range p.disjuncts {
 				if Equal(n.ctx, &v.result, &d.result) {
-					n.ctx.Unifier.freeNodeContext(n)
 					if d.defaultMode == isDefault {
 						v.defaultMode = isDefault
 					}
+					d.free()
 					continue outer
 				}
 			}
-			n.disjuncts[k] = d
-			k++
 
 			d.defaultMode = combineDefault(m, d.defaultMode)
+			p.disjuncts = append(p.disjuncts, d)
 		}
 
-		p.disjuncts = append(p.disjuncts, n.disjuncts[:k]...)
 		n.disjuncts = n.disjuncts[:0]
-
-	case n.done():
-		n.isDone = true
 	}
 }
 
diff --git a/internal/core/adt/eval.go b/internal/core/adt/eval.go
index 0365249..5bc75d0 100644
--- a/internal/core/adt/eval.go
+++ b/internal/core/adt/eval.go
@@ -61,6 +61,7 @@
 Freed:  {{.Freed}}
 Reused: {{.Reused}}
 Allocs: {{.Allocs}}
+Retain: {{.Retained}}
 
 Unifications: {{.UnifyCount}}
 Disjuncts:    {{.DisjunctCount}}`))
@@ -200,7 +201,8 @@
 		}
 	}
 
-	n := v.state
+	n := v.getNodeContext(c)
+	defer v.freeNode(n)
 
 	switch v.Status() {
 	case Evaluating:
@@ -210,11 +212,6 @@
 		return
 
 	case 0:
-		// from state 0
-		n = e.newNodeContext(c, v)
-
-		v.state = n
-
 		if v.Label.IsDef() {
 			v.Closed = true
 		}
@@ -310,12 +307,22 @@
 			return
 		}
 
-		n.expandDisjuncts(state, nil, maybeDefault, false)
+		n.expandDisjuncts(state, n, maybeDefault, false)
+
+		// If the state has changed, it is because a disjunct has been run. In this case, our node will have completed, and it will
+		// set a value soon.
+		v.state = n // alternatively, set to nil
+
+		for _, d := range n.disjuncts {
+			d.free()
+		}
 
 		switch len(n.disjuncts) {
 		case 0:
 		case 1:
-			*v = n.disjuncts[0].result
+			x := n.disjuncts[0].result
+			x.state = nil
+			*v = x
 
 		default:
 			d := n.createDisjunct()
@@ -573,6 +580,7 @@
 	for i, x := range n.disjuncts {
 		v := new(Vertex)
 		*v = x.result
+		v.state = nil
 		switch x.defaultMode {
 		case isDefault:
 			a[i] = a[p]
@@ -605,6 +613,7 @@
 // checks should only be performed once the full value is known.
 type nodeContext struct {
 	nextFree *nodeContext
+	refCount int
 
 	ctx  *OpContext
 	node *Vertex
@@ -653,7 +662,6 @@
 	hasTop      bool
 	hasCycle    bool // has conjunct with structural cycle
 	hasNonCycle bool // has conjunct without structural cycle
-	isDone      bool
 
 	// Disjunction handling
 	disjunctions []envDisjunct
@@ -666,6 +674,8 @@
 func (n *nodeContext) clone() *nodeContext {
 	d := n.ctx.Unifier.newNodeContext(n.ctx, n.node)
 
+	d.refCount++
+
 	d.ctx = n.ctx
 	d.node = n.node
 
@@ -733,21 +743,59 @@
 }
 
 func (v *Vertex) getNodeContext(c *OpContext) *nodeContext {
-	if v.state != nil {
+	if v.state == nil {
 		if v.status == Finalized {
-			panic("dangling node")
+			return nil
 		}
-		return v.state
+		v.state = c.Unifier.newNodeContext(c, v)
+	} else if v.state.node != v {
+		panic("getNodeContext: nodeContext out of sync")
 	}
-	v.state = c.Unifier.newNodeContext(c, v)
+	v.state.refCount++
 	return v.state
 }
 
+func (v *Vertex) freeNode(n *nodeContext) {
+	if n == nil {
+		return
+	}
+	if n.node != v {
+		panic("freeNode: unpaired free")
+	}
+	if v.state != nil && v.state != n {
+		panic("freeNode: nodeContext out of sync")
+	}
+	if n.refCount--; n.refCount == 0 {
+		if v.status == Finalized {
+			v.freeNodeState()
+		} else {
+			n.ctx.Unifier.stats.Retained++
+		}
+	}
+}
+
+func (v *Vertex) freeNodeState() {
+	if v.state == nil {
+		return
+	}
+	state := v.state
+	v.state = nil
+
+	state.ctx.Unifier.freeNodeContext(state)
+}
+
+func (n *nodeContext) free() {
+	if n.refCount--; n.refCount == 0 {
+		n.ctx.Unifier.freeNodeContext(n)
+	}
+}
+
 func (e *Unifier) freeNodeContext(n *nodeContext) {
-	// TODO: re-enable memory management.
-	// e.stats.Freed++
-	// n.nextFree = e.freeListNode
-	// e.freeListNode = n
+	e.stats.Freed++
+	n.nextFree = e.freeListNode
+	e.freeListNode = n
+	n.node = nil
+	n.refCount = 0
 }
 
 // TODO(perf): return a dedicated ConflictError that can track original