cue: support custom validators

Any function in the standard lib that returns
a boolean and has at least one arg.

Change-Id: I80f5432863b4aeb1941006ff847a26f75d59f41c
Reviewed-on: https://cue-review.googlesource.com/c/cue/+/2125
Reviewed-by: Marcel van Lohuizen <mpvl@google.com>
diff --git a/cue/binop.go b/cue/binop.go
index 11cb05a..663bd68 100644
--- a/cue/binop.go
+++ b/cue/binop.go
@@ -473,6 +473,75 @@
 	return ctx.mkIncompatible(src, op, x, other)
 }
 
+func (x *customValidator) binOp(ctx *context, src source, op op, other evaluated) evaluated {
+	newSrc := binSrc(src.Pos(), op, x, other)
+	switch op {
+	case opUnify:
+		k, _ := matchBinOpKind(opUnify, x.kind(), other.kind())
+		if k == bottomKind {
+			break
+		}
+		switch y := other.(type) {
+		case *basicType:
+			k := unifyType(x.kind(), y.kind())
+			if k == x.kind() {
+				return x
+			}
+			return &unification{newSrc, []evaluated{x, y}}
+
+		case *bound:
+			return &unification{newSrc, []evaluated{x, y}}
+
+		case *numLit:
+			if err := x.check(ctx, y); err != nil {
+				return err
+			}
+			// Narrow down number type.
+			if y.k != k {
+				n := *y
+				n.k = k
+				return &n
+			}
+			return other
+
+		case *nullLit, *boolLit, *durationLit, *list, *structLit, *stringLit, *bytesLit:
+			// All remaining concrete types. This includes non-comparable types
+			// for comparison to null.
+			if err := x.check(ctx, y); err != nil {
+				return err
+			}
+			return y
+		}
+	}
+	return ctx.mkIncompatible(src, op, x, other)
+}
+
+func (x *customValidator) check(ctx *context, v evaluated) evaluated {
+	args := make([]evaluated, 1+len(x.args))
+	args[0] = v
+	for i, v := range x.args {
+		args[1+i] = v.(evaluated)
+	}
+	res := x.call.call(ctx, x, args...)
+	if isBottom(res) {
+		return res.(evaluated)
+	}
+	if b, ok := res.(*boolLit); !ok {
+		// should never reach here
+		return ctx.mkErr(x, "invalid custom validator")
+	} else if !b.b {
+		var buf bytes.Buffer
+		buf.WriteString(x.call.Name)
+		buf.WriteString("(")
+		for _, a := range x.args {
+			buf.WriteString(debugStr(ctx, a))
+		}
+		buf.WriteString(")")
+		return ctx.mkErr(x, "value %v not in %v", debugStr(ctx, v), buf.String())
+	}
+	return nil
+}
+
 func evalLambda(ctx *context, a value) (l *lambdaExpr, err evaluated) {
 	if a == nil {
 		return nil, nil
@@ -783,7 +852,7 @@
 
 func (x *numLit) binOp(ctx *context, src source, op op, other evaluated) evaluated {
 	switch y := other.(type) {
-	case *basicType, *bound: // for better error reporting
+	case *basicType, *bound, *customValidator: // for better error reporting
 		if op == opUnify {
 			return y.binOp(ctx, src, op, x)
 		}
diff --git a/cue/builtin.go b/cue/builtin.go
index 13708a4..8751ecb 100644
--- a/cue/builtin.go
+++ b/cue/builtin.go
@@ -51,6 +51,7 @@
 type builtin struct {
 	baseValue
 	Name   string
+	pkg    label
 	Params []kind
 	Result kind
 	Func   func(c *callCtxt)
@@ -63,9 +64,12 @@
 	cue    string
 }
 
-func mustCompileBuiltins(ctx *context, p *builtinPkg, name string) *structLit {
+func mustCompileBuiltins(ctx *context, p *builtinPkg, pkgName string) *structLit {
 	obj := &structLit{}
+	pkgLabel := ctx.label(pkgName, false)
 	for _, b := range p.native {
+		b.pkg = pkgLabel
+
 		f := ctx.label(b.Name, false) // never starts with _
 		// n := &node{baseValue: newBase(imp.Path)}
 		var v evaluated = b
@@ -78,10 +82,9 @@
 
 	// Parse builtin CUE
 	if p.cue != "" {
-		expr, err := parser.ParseExpr(name, p.cue)
+		expr, err := parser.ParseExpr(pkgName, p.cue)
 		if err != nil {
-			fmt.Println(p.cue)
-			panic(err)
+			panic(fmt.Errorf("could not parse %v: %v", p.cue, err))
 		}
 		pkg := evalExpr(ctx.index, obj, expr).(*structLit)
 		for _, a := range pkg.arcs {
@@ -188,6 +191,10 @@
 	if x.Func == nil {
 		return ctx.mkErr(x, "Builtin %q is not a function", x.Name)
 	}
+	if len(x.Params)-1 == len(args) && x.Result == boolKind {
+		// We have a custom builtin
+		return &customValidator{src.base(), args, x}
+	}
 	if len(x.Params) != len(args) {
 		return ctx.mkErr(src, x, "number of arguments does not match (%d vs %d)",
 			len(x.Params), len(args))
diff --git a/cue/debug.go b/cue/debug.go
index 9863631..148941c 100644
--- a/cue/debug.go
+++ b/cue/debug.go
@@ -195,6 +195,16 @@
 			}
 		}
 		write(")")
+	case *customValidator:
+		p.debugStr(x.call)
+		write(" (")
+		for i, a := range x.args {
+			p.debugStr(a)
+			if i < len(x.args)-1 {
+				write(",")
+			}
+		}
+		write(")")
 	case *unaryExpr:
 		write(x.op)
 		p.debugStr(x.x)
diff --git a/cue/eval.go b/cue/eval.go
index 68056ab..aa91722 100644
--- a/cue/eval.go
+++ b/cue/eval.go
@@ -169,6 +169,14 @@
 	return e.err(err)
 }
 
+func (x *customValidator) evalPartial(ctx *context) (result evaluated) {
+	if ctx.trace {
+		defer uni(indent(ctx, "custom", x))
+		defer func() { ctx.debugPrint("result:", result) }()
+	}
+	return x
+}
+
 func (x *bound) evalPartial(ctx *context) (result evaluated) {
 	if ctx.trace {
 		defer uni(indent(ctx, "bound", x))
diff --git a/cue/export.go b/cue/export.go
index 6adf44e..40c58d0 100644
--- a/cue/export.go
+++ b/cue/export.go
@@ -17,6 +17,7 @@
 import (
 	"fmt"
 	"math/rand"
+	"sort"
 	"strconv"
 	"strings"
 	"unicode"
@@ -30,15 +31,70 @@
 	return !m.raw
 }
 
-func export(ctx *context, v value, m options) ast.Expr {
-	e := exporter{ctx, m, nil}
-	return e.expr(v)
+func export(ctx *context, v value, m options) ast.Node {
+	e := exporter{ctx, m, nil, map[label]bool{}, map[string]importInfo{}}
+	top, ok := v.evalPartial(ctx).(*structLit)
+	if ok {
+		top = top.expandFields(ctx)
+		for _, a := range top.arcs {
+			e.top[a.feature] = true
+		}
+	}
+
+	value := e.expr(v)
+	if len(e.imports) == 0 {
+		return value
+	}
+	imports := make([]string, 0, len(e.imports))
+	for k := range e.imports {
+		imports = append(imports, k)
+	}
+	sort.Strings(imports)
+
+	importDecl := &ast.ImportDecl{}
+	file := &ast.File{Decls: []ast.Decl{importDecl}}
+
+	for _, k := range imports {
+		info := e.imports[k]
+		ident := (*ast.Ident)(nil)
+		if info.name != "" {
+			ident = ast.NewIdent(info.name)
+		}
+		if info.alias != "" {
+			file.Decls = append(file.Decls, &ast.Alias{
+				Ident: ast.NewIdent(info.alias),
+				Expr:  ast.NewIdent(info.short),
+			})
+		}
+		importDecl.Specs = append(importDecl.Specs, &ast.ImportSpec{
+			Name: ident,
+			Path: &ast.BasicLit{Kind: token.STRING, Value: quote(k, '"')},
+		})
+	}
+
+	// TODO: should we unwrap structs?
+	if obj, ok := value.(*ast.StructLit); ok {
+		file.Decls = append(file.Decls, obj.Elts...)
+	} else {
+		file.Decls = append(file.Decls, &ast.EmitDecl{Expr: value})
+	}
+
+	// resolve the file.
+	return file
 }
 
 type exporter struct {
-	ctx   *context
-	mode  options
-	stack []remap
+	ctx     *context
+	mode    options
+	stack   []remap
+	top     map[label]bool        // label to alias or ""
+	imports map[string]importInfo // pkg path to info
+}
+
+type importInfo struct {
+	name  string
+	short string
+	alias string
 }
 
 type remap struct {
@@ -111,7 +167,7 @@
 	if doEval(p.mode) {
 		x := p.ctx.manifest(v)
 		if isIncomplete(x) {
-			p = &exporter{p.ctx, options{raw: true}, p.stack}
+			p = &exporter{p.ctx, options{raw: true}, p.stack, p.top, p.imports}
 			return p.expr(v)
 		}
 		v = x
@@ -123,7 +179,42 @@
 	// TODO: also add position information.
 	switch x := v.(type) {
 	case *builtin:
-		return &ast.Ident{Name: x.Name}
+		name := ast.NewIdent(x.Name)
+		if x.pkg == 0 {
+			return name
+		}
+		pkg := p.ctx.labelStr(x.pkg)
+		info, ok := p.imports[pkg]
+		short := info.short
+		if !ok {
+			info.short = ""
+			short = pkg
+			if i := strings.LastIndexByte(pkg, '.'); i >= 0 {
+				short = pkg[i+1:]
+			}
+			for {
+				if _, ok := p.top[p.ctx.label(short, true)]; !ok {
+					break
+				}
+				short += "x"
+				info.name = short
+			}
+			info.short = short
+			p.top[p.ctx.label(short, true)] = true
+			p.imports[pkg] = info
+		}
+		f := p.ctx.label(short, true)
+		for _, e := range p.stack {
+			if e.from == f {
+				if info.alias == "" {
+					info.alias = p.unique(short)
+					p.imports[pkg] = info
+				}
+				short = info.alias
+				break
+			}
+		}
+		return &ast.SelectorExpr{X: ast.NewIdent(short), Sel: name}
 
 	case *nodeRef:
 		return nil
@@ -139,6 +230,7 @@
 			// TODO: should not happen: report error
 			return ident
 		}
+		// TODO: nodes may have changed. Use different algorithm.
 		conflict := false
 		for i := len(p.stack) - 1; i >= 0; i-- {
 			e := &p.stack[i]
@@ -183,6 +275,13 @@
 		}
 		return call
 
+	case *customValidator:
+		call := &ast.CallExpr{Fun: p.expr(x.call)}
+		for _, a := range x.args {
+			call.Args = append(call.Args, p.expr(a))
+		}
+		return call
+
 	case *unaryExpr:
 		return &ast.UnaryExpr{Op: opMap[x.op], X: p.expr(x.x)}
 
@@ -225,6 +324,7 @@
 	case *structLit:
 		obj := &ast.StructLit{}
 		if doEval(p.mode) {
+			x = x.expandFields(p.ctx)
 			for _, a := range x.arcs {
 				p.stack = append(p.stack, remap{
 					key:  x,
@@ -233,7 +333,6 @@
 					syn:  obj,
 				})
 			}
-			x = x.expandFields(p.ctx)
 		}
 		if x.emit != nil {
 			obj.Elts = append(obj.Elts, &ast.EmitDecl{Expr: p.expr(x.emit)})
@@ -271,7 +370,7 @@
 			if !doEval(p.mode) {
 				f.Value = p.expr(a.v)
 			} else if v := p.ctx.manifest(x.at(p.ctx, i)); isIncomplete(v) && !p.mode.concrete {
-				p := &exporter{p.ctx, options{raw: true}, p.stack}
+				p := &exporter{p.ctx, options{raw: true}, p.stack, p.top, p.imports}
 				f.Value = p.expr(a.v)
 			} else {
 				f.Value = p.expr(v)
diff --git a/cue/export_test.go b/cue/export_test.go
index c4289e6..6835f60 100644
--- a/cue/export_test.go
+++ b/cue/export_test.go
@@ -288,6 +288,59 @@
 	}
 }
 
+func TestExportFile(t *testing.T) {
+	testCases := []struct {
+		eval    bool // evaluate the full export
+		in, out string
+	}{{
+		in: `
+		import "strings"
+
+		a: strings.ContainsAny("c")
+		`,
+		out: unindent(`
+		import "strings"
+
+		a: strings.ContainsAny("c")`),
+	}, {
+		in: `
+		import "strings"
+
+		stringsx = strings
+
+		a: {
+			strings: stringsx.ContainsAny("c")
+		}
+		`,
+		out: unindent(`
+		import "strings"
+
+		STRINGS = strings
+		a strings: STRINGS.ContainsAny("c")`),
+	}}
+	for _, tc := range testCases {
+		t.Run("", func(t *testing.T) {
+			var r Runtime
+			inst, err := r.Parse("test", tc.in)
+			if err != nil {
+				t.Fatal(err)
+			}
+			v := inst.Value()
+			ctx := r.index().newContext()
+
+			buf := &bytes.Buffer{}
+			opts := options{raw: false}
+			err = format.Node(buf, export(ctx, v.eval(ctx), opts))
+			if err != nil {
+				log.Fatal(err)
+			}
+			if got := strings.TrimSpace(buf.String()); got != tc.out {
+				t.Errorf("\ngot:\n%v\nwant:\n%v", got, tc.out)
+			}
+		})
+	}
+}
+
 func unindent(s string) string {
 	lines := strings.Split(s, "\n")[1:]
 	ws := lines[0][:len(lines[0])-len(strings.TrimLeft(lines[0], " \t"))]
diff --git a/cue/resolve_test.go b/cue/resolve_test.go
index 75ea63c..d54784b 100644
--- a/cue/resolve_test.go
+++ b/cue/resolve_test.go
@@ -19,7 +19,6 @@
 	"testing"
 
 	"cuelang.org/go/cue/errors"
-	"cuelang.org/go/cue/parser"
 )
 
 var traceOn = flag.Bool("debug", false, "enable tracing")
@@ -42,14 +41,19 @@
 func compileInstance(t *testing.T, body string) (*context, *Instance, errors.List) {
 	t.Helper()
 
-	x := newIndex().NewInstance(nil)
-	f, err := parser.ParseFile("test", body)
-	ctx := x.newContext()
+	var r Runtime
+	x, err := r.Parse("test", body)
+	ctx := r.index().newContext()
 
 	switch errs := err.(type) {
 	case nil:
-		x.insertFile(f)
+		var r Runtime
+		inst, _ := r.Parse("test", body)
+		return r.index().newContext(), inst, nil
 	case errors.List:
+		x := newIndex().NewInstance(nil)
+		ctx := x.newContext()
+
 		return ctx, x, errs
 	default:
 		t.Fatal(err)
@@ -735,11 +739,26 @@
 			`e8: _|_(incompatible bounds >11 and <=11), ` +
 			`e9: _|_((>"a" & <1):unsupported op &((string)*, (number)*))}`,
 	}, {
+		desc: "custom validators",
+		in: `
+		import "strings"
+
+		a: strings.ContainsAny("ab")
+		a: "after"
+
+		b: strings.ContainsAny("c")
+		b: "dog"
+		`,
+		out: `<0>{` +
+			`a: "after", ` +
+			`b: _|_(builtin:ContainsAny ("c"):value "dog" not in ContainsAny("c"))` +
+			`}`,
+	}, {
 		desc: "null coalescing",
 		in: `
-				a: null
-				b: a.x | "b"
-				c: a["x"] | "c"
+			a: null
+			b: a.x | "b"
+			c: a["x"] | "c"
 			`,
 		out: `<1>{a: null, b: "b", c: "c"}`,
 	}, {
diff --git a/cue/rewrite.go b/cue/rewrite.go
index 44dc982..50c2522 100644
--- a/cue/rewrite.go
+++ b/cue/rewrite.go
@@ -83,6 +83,20 @@
 func (x *numLit) rewrite(ctx *context, fn rewriteFunc) value      { return x }
 func (x *durationLit) rewrite(ctx *context, fn rewriteFunc) value { return x }
 
+func (x *customValidator) rewrite(ctx *context, fn rewriteFunc) value {
+	args := make([]evaluated, len(x.args))
+	changed := false
+	for i, a := range x.args {
+		v := rewrite(ctx, a, fn)
+		args[i] = v.(evaluated)
+		changed = changed || v != a
+	}
+	if !changed {
+		return x
+	}
+	return &customValidator{baseValue: x.baseValue, args: args, call: x.call}
+}
+
 func (x *bound) rewrite(ctx *context, fn rewriteFunc) value {
 	v := rewrite(ctx, x.value, fn)
 	if v == x.value {
diff --git a/cue/subsume.go b/cue/subsume.go
index 860d630..892742c 100644
--- a/cue/subsume.go
+++ b/cue/subsume.go
@@ -405,6 +405,23 @@
 }
 
 // structural equivalence
+func (x *customValidator) subsumesImpl(ctx *context, v value, mode subsumeMode) bool {
+	y, ok := v.(*customValidator)
+	if !ok {
+		return isBottom(v)
+	}
+	if x.call != y.call {
+		return false
+	}
+	for i, v := range x.args {
+		if !subsumes(ctx, v, y.args[i], mode) {
+			return false
+		}
+	}
+	return true
+}
+
+// structural equivalence
 func (x *callExpr) subsumesImpl(ctx *context, v value, mode subsumeMode) bool {
 	if c, ok := v.(*callExpr); ok {
 		if len(x.args) != len(c.args) {
diff --git a/cue/types.go b/cue/types.go
index e7a87c4..a4a5c85 100644
--- a/cue/types.go
+++ b/cue/types.go
@@ -660,7 +660,7 @@
 
 // Syntax converts the possibly partially evaluated value into syntax. This
 // can use used to print the value with package format.
-func (v Value) Syntax(opts ...Option) ast.Expr {
+func (v Value) Syntax(opts ...Option) ast.Node {
 	if v.path == nil || v.path.cache == nil {
 		return nil
 	}
diff --git a/cue/value.go b/cue/value.go
index dfa5a10..f862ca1 100644
--- a/cue/value.go
+++ b/cue/value.go
@@ -967,12 +967,37 @@
 
 func (x *callExpr) kind() kind {
 	// TODO: could this be narrowed down?
-	if l, ok := x.x.(*lambdaExpr); ok {
-		return l.returnKind() | nonGround
+	switch c := x.x.(type) {
+	case *lambdaExpr:
+		return c.returnKind() | nonGround
+	case *builtin:
+		switch len(x.args) {
+		case len(c.Params):
+			return c.Result
+		case len(c.Params) - 1:
+			if len(c.Params) == 0 || c.Result&boolKind == 0 {
+				return bottomKind
+			}
+			return c.Params[0]
+		}
 	}
 	return topKind | referenceKind
 }
 
+type customValidator struct {
+	baseValue
+
+	args []evaluated // any but the first value
+	call *builtin    // function must return a bool
+}
+
+func (x *customValidator) kind() kind {
+	if len(x.call.Params) == 0 {
+		return bottomKind
+	}
+	return x.call.Params[0] | nonGround
+}
+
 type params struct {
 	arcs []arc
 }