cue/ast/astutil: implement Apply

Change-Id: Iddf1aafae8760281c25c78f0beafc7f6883e6c86
Reviewed-on: https://cue-review.googlesource.com/c/cue/+/3189
Reviewed-by: Marcel van Lohuizen <mpvl@golang.org>
diff --git a/cue/ast/astutil/apply.go b/cue/ast/astutil/apply.go
new file mode 100644
index 0000000..4e747f5
--- /dev/null
+++ b/cue/ast/astutil/apply.go
@@ -0,0 +1,406 @@
+// Copyright 2018 The CUE Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package astutil
+
+import (
+	"fmt"
+	"reflect"
+
+	"cuelang.org/go/cue/ast"
+)
+
+// A Cursor describes a node encountered during Apply.
+// Information about the node and its parent is available
+// from the Node, Parent, and Index methods.
+//
+// The methods Replace, Delete, InsertBefore, and InsertAfter
+// can be used to change the AST without disrupting Apply.
+// Delete, InsertBefore, and InsertAfter are only defined for modifying
+// a StructLit and will panic in any other context.
+type Cursor interface {
+	// Node returns the current Node.
+	Node() ast.Node
+
+	// Parent returns the parent of the current Node.
+	Parent() Cursor
+
+	// Index reports the index >= 0 of the current Node in the slice of Nodes
+	// that contains it, or a value < 0 if the current Node is not part of a
+	// list.
+	Index() int
+
+	// Replace replaces the current Node with n.
+	// The replacement node is not walked by Apply.
+	Replace(n ast.Node)
+
+	// Delete deletes the current Node from its containing struct.
+	// If the current Node is not part of a struct, Delete panics.
+	Delete()
+
+	// InsertAfter inserts n after the current Node in its containing struct.
+	// If the current Node is not part of a struct, InsertAfter panics.
+	// Apply does not walk n.
+	InsertAfter(n ast.Node)
+
+	// InsertBefore inserts n before the current Node in its containing struct.
+	// If the current Node is not part of a struct, InsertBefore panics.
+	// Apply will not walk n.
+	InsertBefore(n ast.Node)
+}
+
+type cursor struct {
+	parent   Cursor
+	node     ast.Node
+	typ      interface{} // the type of the node
+	index    int         // position of any of the sub types.
+	replaced bool
+}
+
+func newCursor(parent Cursor, n ast.Node, typ interface{}) *cursor {
+	return &cursor{
+		parent: parent,
+		typ:    typ,
+		node:   n,
+		index:  -1,
+	}
+}
+
+func (c *cursor) Parent() Cursor { return c.parent }
+func (c *cursor) Index() int     { return c.index }
+func (c *cursor) Node() ast.Node { return c.node }
+
+func (c *cursor) Replace(n ast.Node) {
+	// panic if the value cannot convert to the original type.
+	reflect.ValueOf(n).Convert(reflect.TypeOf(c.typ).Elem())
+	c.node = n
+	c.replaced = true
+}
+
+func (c *cursor) InsertAfter(n ast.Node)  { panic("unsupported") }
+func (c *cursor) InsertBefore(n ast.Node) { panic("unsupported") }
+func (c *cursor) Delete()                 { panic("unsupported") }
+
+// Apply traverses a syntax tree recursively, starting with root,
+// and calling pre and post for each node as described below.
+// Apply returns the syntax tree, possibly modified.
+//
+// If pre is not nil, it is called for each node before the node's
+// children are traversed (pre-order). If pre returns false, no
+// children are traversed, and post is not called for that node.
+//
+// If post is not nil, and a prior call of pre didn't return false,
+// post is called for each node after its children are traversed
+// (post-order). If post returns false, traversal is terminated and
+// Apply returns immediately.
+//
+// Only fields that refer to AST nodes are considered children;
+// i.e., token.Pos, Scopes, Objects, and fields of basic types
+// (strings, etc.) are ignored.
+//
+// Children are traversed in the order in which they appear in the
+// respective node's struct definition.
+//
+func Apply(node ast.Node, before, after func(Cursor) bool) ast.Node {
+	apply(&applier{before: before, after: after}, nil, &node)
+	return node
+}
+
+// A applyVisitor's before method is invoked for each node encountered by Walk.
+// If the result applyVisitor w is true, Walk visits each of the children
+// of node with the applyVisitor w, followed by a call of w.After.
+type applyVisitor interface {
+	Before(Cursor) applyVisitor
+	After(Cursor) bool
+}
+
+// Helper functions for common node lists. They may be empty.
+
+func applyExprList(v applyVisitor, parent Cursor, ptr interface{}, list []ast.Expr) {
+	c := newCursor(parent, nil, nil)
+	for i, x := range list {
+		c.index = i
+		c.node = x
+		c.typ = &list[i]
+		applyCursor(v, c)
+		if x != c.node {
+			list[i] = c.node.(ast.Expr)
+		}
+	}
+}
+
+type declsCursor struct {
+	*cursor
+	decls, after []ast.Decl
+	delete       bool
+}
+
+func (c *declsCursor) InsertAfter(n ast.Node) {
+	c.after = append(c.after, n.(ast.Decl))
+}
+
+func (c *declsCursor) InsertBefore(n ast.Node) {
+	c.decls = append(c.decls, n.(ast.Decl))
+}
+
+func (c *declsCursor) Delete() { c.delete = true }
+
+func applyDeclList(v applyVisitor, parent Cursor, list []ast.Decl) []ast.Decl {
+	c := &declsCursor{
+		cursor: newCursor(parent, nil, nil),
+		decls:  make([]ast.Decl, 0, len(list)),
+	}
+	for i, x := range list {
+		c.node = x
+		c.typ = &list[i]
+		applyCursor(v, c)
+		if !c.delete {
+			c.decls = append(c.decls, c.node.(ast.Decl))
+		}
+		c.delete = false
+		c.decls = append(c.decls, c.after...)
+		c.after = c.after[:0]
+	}
+	return c.decls
+}
+
+func apply(v applyVisitor, parent Cursor, nodePtr interface{}) {
+	res := reflect.Indirect(reflect.ValueOf(nodePtr))
+	n := res.Interface()
+	node := n.(ast.Node)
+	c := newCursor(parent, node, nodePtr)
+	applyCursor(v, c)
+	if node != c.node {
+		res.Set(reflect.ValueOf(c.node))
+	}
+}
+
+// applyCursor traverses an AST in depth-first order: It starts by calling
+// v.Visit(node); node must not be nil. If the visitor w returned by
+// v.Visit(node) is not nil, apply is invoked recursively with visitor
+// w for each of the non-nil children of node, followed by a call of
+// w.Visit(nil).
+//
+func applyCursor(v applyVisitor, c Cursor) {
+	if v = v.Before(c); v == nil {
+		return
+	}
+
+	node := c.Node()
+
+	// TODO: record the comment groups and interleave with the values like for
+	// parsing and printing?
+	comments := node.Comments()
+	for _, cm := range comments {
+		apply(v, c, &cm)
+	}
+
+	// apply children
+	// (the order of the cases matches the order
+	// of the corresponding node types in go)
+	switch n := node.(type) {
+	// Comments and fields
+	case *ast.Comment:
+		// nothing to do
+
+	case *ast.CommentGroup:
+		for _, cg := range n.List {
+			apply(v, c, &cg)
+		}
+
+	case *ast.Attribute:
+		// nothing to do
+
+	case *ast.Field:
+		apply(v, c, &n.Label)
+		if n.Value != nil {
+			apply(v, c, &n.Value)
+		}
+		for _, a := range n.Attrs {
+			apply(v, c, &a)
+		}
+
+	case *ast.StructLit:
+		n.Elts = applyDeclList(v, c, n.Elts)
+
+	// Expressions
+	case *ast.BottomLit, *ast.BadExpr, *ast.Ident, *ast.BasicLit:
+		// nothing to do
+
+	case *ast.TemplateLabel:
+		apply(v, c, &n.Ident)
+
+	case *ast.Interpolation:
+		applyExprList(v, c, &n, n.Elts)
+
+	case *ast.ListLit:
+		applyExprList(v, c, &n, n.Elts)
+
+	case *ast.Ellipsis:
+		if n.Type != nil {
+			apply(v, c, &n.Type)
+		}
+
+	case *ast.ParenExpr:
+		apply(v, c, &n.X)
+
+	case *ast.SelectorExpr:
+		apply(v, c, &n.X)
+		apply(v, c, &n.Sel)
+
+	case *ast.IndexExpr:
+		apply(v, c, &n.X)
+		apply(v, c, &n.Index)
+
+	case *ast.SliceExpr:
+		apply(v, c, &n.X)
+		if n.Low != nil {
+			apply(v, c, &n.Low)
+		}
+		if n.High != nil {
+			apply(v, c, &n.High)
+		}
+
+	case *ast.CallExpr:
+		apply(v, c, &n.Fun)
+		applyExprList(v, c, &n, n.Args)
+
+	case *ast.UnaryExpr:
+		apply(v, c, &n.X)
+
+	case *ast.BinaryExpr:
+		apply(v, c, &n.X)
+		apply(v, c, &n.Y)
+
+	// Declarations
+	case *ast.ImportSpec:
+		if n.Name != nil {
+			apply(v, c, &n.Name)
+		}
+		apply(v, c, &n.Path)
+
+	case *ast.BadDecl:
+		// nothing to do
+
+	case *ast.ImportDecl:
+		for _, s := range n.Specs {
+			apply(v, c, &s)
+		}
+
+	case *ast.EmbedDecl:
+		apply(v, c, &n.Expr)
+
+	case *ast.Alias:
+		apply(v, c, &n.Ident)
+		apply(v, c, &n.Expr)
+
+	case *ast.Comprehension:
+		clauses := n.Clauses
+		for i := range n.Clauses {
+			apply(v, c, &clauses[i])
+		}
+		apply(v, c, &n.Value)
+
+	// Files and packages
+	case *ast.File:
+		n.Decls = applyDeclList(v, c, n.Decls)
+
+	case *ast.Package:
+		apply(v, c, &n.Name)
+
+	case *ast.ListComprehension:
+		apply(v, c, &n.Expr)
+		clauses := n.Clauses
+		for i := range clauses {
+			apply(v, c, &clauses[i])
+		}
+
+	case *ast.ForClause:
+		if n.Key != nil {
+			apply(v, c, &n.Key)
+		}
+		apply(v, c, &n.Value)
+		apply(v, c, &n.Source)
+
+	case *ast.IfClause:
+		apply(v, c, &n.Condition)
+
+	default:
+		panic(fmt.Sprintf("Walk: unexpected node type %T", n))
+	}
+
+	v.After(c)
+}
+
+type applier struct {
+	before func(Cursor) bool
+	after  func(Cursor) bool
+
+	commentStack []commentFrame
+	current      commentFrame
+}
+
+type commentFrame struct {
+	cg  []*ast.CommentGroup
+	pos int8
+}
+
+func (f *applier) Before(c Cursor) applyVisitor {
+	node := c.Node()
+	if f.before == nil || (f.before(c) && node == c.Node()) {
+		f.commentStack = append(f.commentStack, f.current)
+		f.current = commentFrame{cg: node.Comments()}
+		f.visitComments(c, f.current.pos)
+		return f
+	}
+	return nil
+}
+
+func (f *applier) After(c Cursor) bool {
+	f.visitComments(c, 127)
+	p := len(f.commentStack) - 1
+	f.current = f.commentStack[p]
+	f.commentStack = f.commentStack[:p]
+	f.current.pos++
+	if f.after != nil {
+		f.after(c)
+	}
+	return true
+}
+
+func (f *applier) visitComments(p Cursor, pos int8) {
+	c := &f.current
+	for i := 0; i < len(c.cg); i++ {
+		cg := c.cg[i]
+		if cg.Position == pos {
+			continue
+		}
+		cursor := newCursor(p, cg, cg)
+		if f.before == nil || (f.before(cursor) && !cursor.replaced) {
+			for j, c := range cg.List {
+				cursor := newCursor(p, c, &c)
+				if f.before == nil || (f.before(cursor) && !cursor.replaced) {
+					if f.after != nil {
+						f.after(cursor)
+					}
+				}
+				cg.List[j] = cursor.node.(*ast.Comment)
+			}
+			if f.after != nil {
+				f.after(cursor)
+			}
+		}
+		c.cg[i] = cursor.node.(*ast.CommentGroup)
+	}
+}
diff --git a/cue/ast/astutil/apply_test.go b/cue/ast/astutil/apply_test.go
new file mode 100644
index 0000000..4a2eebf
--- /dev/null
+++ b/cue/ast/astutil/apply_test.go
@@ -0,0 +1,222 @@
+// Copyright 2019 CUE Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package astutil_test
+
+import (
+	"strings"
+	"testing"
+
+	"cuelang.org/go/cue/ast"
+	"cuelang.org/go/cue/ast/astutil"
+	"cuelang.org/go/cue/format"
+	"cuelang.org/go/cue/parser"
+	"cuelang.org/go/cue/token"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestApply(t *testing.T) {
+	testCases := []struct {
+		name   string
+		in     string
+		out    string
+		before func(astutil.Cursor) bool
+		after  func(astutil.Cursor) bool
+	}{{
+		// This should pass
+	}, {
+		name: "insert before",
+		in: `
+		// foo is a
+		foo: {
+			a: 3
+		}
+		`,
+		out: `
+iam: new
+// foo is a
+foo: {
+	iam: new
+	a:   3
+}
+`,
+		before: func(c astutil.Cursor) bool {
+			switch c.Node().(type) {
+			case *ast.Field:
+				c.InsertBefore(&ast.Field{
+					Label: ast.NewIdent("iam"),
+					Value: ast.NewIdent("new"),
+				})
+			}
+			return true
+		},
+	}, {
+		name: "insert after",
+		in: `
+			foo: {
+				a: 3 @test()
+			}
+			`,
+		out: `
+foo: {
+	a:   3 @test()
+	iam: new
+}
+iam: new
+`,
+		before: func(c astutil.Cursor) bool {
+			switch c.Node().(type) {
+			case *ast.Field:
+				c.InsertAfter(&ast.Field{
+					Label: ast.NewIdent("iam"),
+					Value: ast.NewIdent("new"),
+				})
+			}
+			return true
+		},
+	}, {
+		name: "templates",
+		in: `
+				foo: {
+					a <b> c: 3
+				}
+				`,
+		out: `
+foo: {
+	a <b>: {
+		c:   3
+		iam: new
+	}
+}
+	`,
+		before: func(c astutil.Cursor) bool {
+			switch x := c.Node().(type) {
+			case *ast.Field:
+				if _, ok := x.Value.(*ast.StructLit); !ok {
+					c.InsertAfter(&ast.Field{
+						Label: ast.NewIdent("iam"),
+						Value: ast.NewIdent("new"),
+					})
+				}
+			}
+			return true
+		},
+	}, {
+		name: "replace",
+		in: `
+		a: "string"
+		b: 3
+		c: [ 1, 2, 8, 4 ]
+		d: "\(foo) is \(0)"
+		`,
+		out: `
+a: s
+b: 4
+c: [4, 4, 4, 4]
+d: "\(foo) is \(4)"
+`,
+		before: func(c astutil.Cursor) bool {
+			switch x := c.Node().(type) {
+			case *ast.BasicLit:
+				switch x.Kind {
+				case token.STRING:
+					if c.Index() < 0 {
+						c.Replace(ast.NewIdent("s"))
+					}
+				case token.INT:
+					c.Replace(&ast.BasicLit{Kind: token.INT, Value: "4"})
+				}
+			}
+			return true
+		},
+	}, {
+		name: "delete",
+		in: `
+		z: 0
+		a: "foo"
+		b: 3
+		b: "bar"
+		c: 2
+		`,
+		out: `
+a: "foo"
+b: "bar"
+	`,
+		before: func(c astutil.Cursor) bool {
+			f, ok := c.Node().(*ast.Field)
+			if !ok {
+				return true
+			}
+			switch x := f.Value.(type) {
+			case *ast.BasicLit:
+				switch x.Kind {
+				case token.INT:
+					c.Delete()
+				}
+			}
+			return true
+		},
+	}, {
+		name: "comments",
+		in: `
+		// test
+		a: "string"
+		`,
+		out: `
+// 1, 2, 3
+a: "string"
+	`,
+		before: func(c astutil.Cursor) bool {
+			switch c.Node().(type) {
+			case *ast.Comment:
+				c.Replace(&ast.Comment{Text: "// 1, 2, 3"})
+			}
+			return true
+		},
+	}, {
+		name: "comments after",
+		in: `
+	// test
+	a: "string"
+			`,
+		out: `
+// 1, 2, 3
+a: "string"
+		`,
+		after: func(c astutil.Cursor) bool {
+			switch c.Node().(type) {
+			case *ast.Comment:
+				c.Replace(&ast.Comment{Text: "// 1, 2, 3"})
+			}
+			return true
+		},
+	}}
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			f, err := parser.ParseFile(tc.name, tc.in, parser.ParseComments)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			n := astutil.Apply(f, tc.before, tc.after)
+
+			b, err := format.Node(n)
+			require.NoError(t, err)
+			got := strings.TrimSpace(string(b))
+			want := strings.TrimSpace(tc.out)
+			assert.Equal(t, want, got)
+		})
+	}
+}
diff --git a/cue/ast/walk.go b/cue/ast/walk.go
index 6e66754..1dfd1a0 100644
--- a/cue/ast/walk.go
+++ b/cue/ast/walk.go
@@ -30,7 +30,7 @@
 }
 
 // A visitor's before method is invoked for each node encountered by Walk.
-// If the result visitor w is not nil, Walk visits each of the children
+// If the result visitor w is true, Walk visits each of the children
 // of node with the visitor w, followed by a call of w.After.
 type visitor interface {
 	Before(node Node) (w visitor)