cue/ast/astutil: support adding imports
Also implements manual resolving of nodes. Use cases:
- bypass shadowing issues when modifying AST (works)
- allow unshadowing when formatting (not implemented)
Change-Id: I774678c86161cda46d07421e81c839d39da85b26
Reviewed-on: https://cue-review.googlesource.com/c/cue/+/3225
Reviewed-by: Marcel van Lohuizen <mpvl@golang.org>
diff --git a/cue/ast/astutil/apply.go b/cue/ast/astutil/apply.go
index 22deb16..42a70f4 100644
--- a/cue/ast/astutil/apply.go
+++ b/cue/ast/astutil/apply.go
@@ -16,9 +16,13 @@
import (
"fmt"
+ "path"
"reflect"
+ "strconv"
+ "strings"
"cuelang.org/go/cue/ast"
+ "cuelang.org/go/cue/token"
)
// A Cursor describes a node encountered during Apply.
@@ -41,6 +45,11 @@
// list.
Index() int
+ // Import reports an opaque identifier that refers to the given package. It
+ // may only be called if the input to apply was an ast.File. If the import
+ // does not exist, it will be added.
+ Import(path string) *ast.Ident
+
// Replace replaces the current Node with n.
// The replacement node is not walked by Apply. Comments of the old node
// are copied to the new node if it has not yet an comments associated
@@ -60,9 +69,19 @@
// If the current Node is not part of a struct, InsertBefore panics.
// Apply will not walk n.
InsertBefore(n ast.Node)
+
+ self() *cursor
+}
+
+type info struct {
+ f *ast.File
+ current *declsCursor
+
+ importPatch []*ast.Ident
}
type cursor struct {
+ file *info
parent Cursor
node ast.Node
typ interface{} // the type of the node
@@ -79,10 +98,78 @@
}
}
+func fileInfo(c Cursor) (info *info) {
+ for ; c != nil; c = c.Parent() {
+ if i := c.self().file; i != nil {
+ return i
+ }
+ }
+ return nil
+}
+
+func (c *cursor) self() *cursor { return c }
func (c *cursor) Parent() Cursor { return c.parent }
func (c *cursor) Index() int { return c.index }
func (c *cursor) Node() ast.Node { return c.node }
+func (c *cursor) Import(importPath string) *ast.Ident {
+ info := fileInfo(c)
+ if info == nil {
+ return nil
+ }
+
+ quoted := strconv.Quote(importPath)
+
+ var imports *ast.ImportDecl
+ var spec *ast.ImportSpec
+ decls := info.current.decls
+ i := 0
+outer:
+ for ; i < len(decls); i++ {
+ d := decls[i]
+ switch t := d.(type) {
+ default:
+ break outer
+
+ case *ast.Package:
+ case *ast.CommentGroup:
+ case *ast.ImportDecl:
+ imports = t
+ for _, s := range t.Specs {
+ if s.Path.Value == quoted {
+ spec = s
+ break
+ }
+ }
+ }
+ }
+
+ if spec == nil {
+ // Import not found, add one.
+ if imports == nil {
+ imports = &ast.ImportDecl{}
+ a := append(append(decls[:i], imports), decls[i:]...)
+ decls = a
+ info.current.decls = decls
+ }
+ path := &ast.BasicLit{Kind: token.STRING, Value: quoted}
+ spec = &ast.ImportSpec{Path: path}
+ imports.Specs = append(imports.Specs, spec)
+ ast.SetRelPos(imports.Specs[0], token.NoRelPos)
+ }
+
+ ident := &ast.Ident{Node: spec} // Name is set later.
+ info.importPatch = append(info.importPatch, ident)
+
+ name := path.Base(importPath)
+ if p := strings.LastIndexByte(name, ':'); p > 0 {
+ name = name[p+1:]
+ }
+ ident.Name = name
+
+ return ident
+}
+
func (c *cursor) Replace(n ast.Node) {
// panic if the value cannot convert to the original type.
reflect.ValueOf(n).Convert(reflect.TypeOf(c.typ).Elem())
@@ -166,6 +253,9 @@
cursor: newCursor(parent, nil, nil),
decls: make([]ast.Decl, 0, len(list)),
}
+ if file, ok := parent.Node().(*ast.File); ok {
+ c.cursor.file = &info{f: file, current: c}
+ }
for i, x := range list {
c.node = x
c.typ = &list[i]
@@ -177,6 +267,24 @@
c.decls = append(c.decls, c.after...)
c.after = c.after[:0]
}
+
+ // TODO: ultimately, programmatically linked nodes have to be resolved
+ // at the end.
+ // if info := c.cursor.file; info != nil {
+ // done := map[*ast.ImportSpec]bool{}
+ // for _, ident := range info.importPatch {
+ // spec := ident.Node.(*ast.ImportSpec)
+ // if done[spec] {
+ // continue
+ // }
+ // done[spec] = true
+
+ // path, _ := strconv.Unquote(spec.Path)
+
+ // ident.Name =
+ // }
+ // }
+
return c.decls
}
diff --git a/cue/ast/astutil/apply_test.go b/cue/ast/astutil/apply_test.go
index 7ed8f08..899e225 100644
--- a/cue/ast/astutil/apply_test.go
+++ b/cue/ast/astutil/apply_test.go
@@ -204,6 +204,49 @@
}
return true
},
+ }, {
+ name: "imports",
+ in: `
+a: "string"
+ `,
+ out: `
+import "list"
+
+a: list
+ `,
+ after: func(c astutil.Cursor) bool {
+ switch c.Node().(type) {
+ case *ast.BasicLit:
+ c.Replace(c.Import("list"))
+ }
+ return true
+ },
+ }, {
+ name: "imports",
+ in: `package foo
+
+import "math"
+
+a: 3
+ `,
+ out: `package foo
+
+import (
+ "math"
+ "list"
+)
+
+a: list
+ `,
+ after: func(c astutil.Cursor) bool {
+ switch x := c.Node().(type) {
+ case *ast.BasicLit:
+ if x.Kind == token.INT {
+ c.Replace(c.Import("list"))
+ }
+ }
+ return true
+ },
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
diff --git a/cue/ast/astutil/resolve.go b/cue/ast/astutil/resolve.go
index 3fc661b..2ebfe2f 100644
--- a/cue/ast/astutil/resolve.go
+++ b/cue/ast/astutil/resolve.go
@@ -29,11 +29,13 @@
type ErrFunc func(pos token.Pos, msg string, args ...interface{})
// Resolve resolves all identifiers in a file. Unresolved identifiers are
-// recorded in Unresolved.
+// recorded in Unresolved. It will not overwrite already resolved values.
func Resolve(f *ast.File, errFn ErrFunc) {
walk(&scope{errFn: errFn}, f)
}
+// Resolve resolves all identifiers in an expression.
+// It will not overwrite already resolved values.
func ResolveExpr(e ast.Expr, errFn ErrFunc) {
f := &ast.File{}
walk(&scope{file: f, errFn: errFn}, e)
@@ -93,6 +95,20 @@
s.index[name] = n
}
+func (s *scope) resolveScope(name string, node ast.Node) (scope ast.Node, ok bool) {
+ last := s
+ for s != nil {
+ if n, ok := s.index[name]; ok && node == n {
+ if last.node == n {
+ return nil, true
+ }
+ return s.node, true
+ }
+ s, last = s.outer, s
+ }
+ return nil, false
+}
+
func (s *scope) lookup(name string) (obj, node ast.Node) {
last := s
for s != nil {
@@ -169,8 +185,21 @@
break
}
if obj, node := s.lookup(name); node != nil {
- x.Node = node
- x.Scope = obj
+ switch {
+ case x.Node == nil:
+ x.Node = node
+ x.Scope = obj
+
+ case x.Node == node:
+ x.Scope = obj
+
+ default: // x.Node != node
+ scope, ok := s.resolveScope(name, x.Node)
+ if !ok {
+ s.file.Unresolved = append(s.file.Unresolved, x)
+ }
+ x.Scope = scope
+ }
} else {
s.file.Unresolved = append(s.file.Unresolved, x)
}