cue: support custom validators
Any function in the standard lib that returns
a boolean and has at least one arg.
Change-Id: I80f5432863b4aeb1941006ff847a26f75d59f41c
Reviewed-on: https://cue-review.googlesource.com/c/cue/+/2125
Reviewed-by: Marcel van Lohuizen <mpvl@google.com>
diff --git a/cue/binop.go b/cue/binop.go
index 11cb05a..663bd68 100644
--- a/cue/binop.go
+++ b/cue/binop.go
@@ -473,6 +473,75 @@
return ctx.mkIncompatible(src, op, x, other)
}
+func (x *customValidator) binOp(ctx *context, src source, op op, other evaluated) evaluated {
+ newSrc := binSrc(src.Pos(), op, x, other)
+ switch op {
+ case opUnify:
+ k, _ := matchBinOpKind(opUnify, x.kind(), other.kind())
+ if k == bottomKind {
+ break
+ }
+ switch y := other.(type) {
+ case *basicType:
+ k := unifyType(x.kind(), y.kind())
+ if k == x.kind() {
+ return x
+ }
+ return &unification{newSrc, []evaluated{x, y}}
+
+ case *bound:
+ return &unification{newSrc, []evaluated{x, y}}
+
+ case *numLit:
+ if err := x.check(ctx, y); err != nil {
+ return err
+ }
+ // Narrow down number type.
+ if y.k != k {
+ n := *y
+ n.k = k
+ return &n
+ }
+ return other
+
+ case *nullLit, *boolLit, *durationLit, *list, *structLit, *stringLit, *bytesLit:
+ // All remaining concrete types. This includes non-comparable types
+ // for comparison to null.
+ if err := x.check(ctx, y); err != nil {
+ return err
+ }
+ return y
+ }
+ }
+ return ctx.mkIncompatible(src, op, x, other)
+}
+
+func (x *customValidator) check(ctx *context, v evaluated) evaluated {
+ args := make([]evaluated, 1+len(x.args))
+ args[0] = v
+ for i, v := range x.args {
+ args[1+i] = v.(evaluated)
+ }
+ res := x.call.call(ctx, x, args...)
+ if isBottom(res) {
+ return res.(evaluated)
+ }
+ if b, ok := res.(*boolLit); !ok {
+ // should never reach here
+ return ctx.mkErr(x, "invalid custom validator")
+ } else if !b.b {
+ var buf bytes.Buffer
+ buf.WriteString(x.call.Name)
+ buf.WriteString("(")
+ for _, a := range x.args {
+ buf.WriteString(debugStr(ctx, a))
+ }
+ buf.WriteString(")")
+ return ctx.mkErr(x, "value %v not in %v", debugStr(ctx, v), buf.String())
+ }
+ return nil
+}
+
func evalLambda(ctx *context, a value) (l *lambdaExpr, err evaluated) {
if a == nil {
return nil, nil
@@ -783,7 +852,7 @@
func (x *numLit) binOp(ctx *context, src source, op op, other evaluated) evaluated {
switch y := other.(type) {
- case *basicType, *bound: // for better error reporting
+ case *basicType, *bound, *customValidator: // for better error reporting
if op == opUnify {
return y.binOp(ctx, src, op, x)
}
diff --git a/cue/builtin.go b/cue/builtin.go
index 13708a4..8751ecb 100644
--- a/cue/builtin.go
+++ b/cue/builtin.go
@@ -51,6 +51,7 @@
type builtin struct {
baseValue
Name string
+ pkg label
Params []kind
Result kind
Func func(c *callCtxt)
@@ -63,9 +64,12 @@
cue string
}
-func mustCompileBuiltins(ctx *context, p *builtinPkg, name string) *structLit {
+func mustCompileBuiltins(ctx *context, p *builtinPkg, pkgName string) *structLit {
obj := &structLit{}
+ pkgLabel := ctx.label(pkgName, false)
for _, b := range p.native {
+ b.pkg = pkgLabel
+
f := ctx.label(b.Name, false) // never starts with _
// n := &node{baseValue: newBase(imp.Path)}
var v evaluated = b
@@ -78,10 +82,9 @@
// Parse builtin CUE
if p.cue != "" {
- expr, err := parser.ParseExpr(name, p.cue)
+ expr, err := parser.ParseExpr(pkgName, p.cue)
if err != nil {
- fmt.Println(p.cue)
- panic(err)
+ panic(fmt.Errorf("could not parse %v: %v", p.cue, err))
}
pkg := evalExpr(ctx.index, obj, expr).(*structLit)
for _, a := range pkg.arcs {
@@ -188,6 +191,10 @@
if x.Func == nil {
return ctx.mkErr(x, "Builtin %q is not a function", x.Name)
}
+ if len(x.Params)-1 == len(args) && x.Result == boolKind {
+ // We have a custom builtin
+ return &customValidator{src.base(), args, x}
+ }
if len(x.Params) != len(args) {
return ctx.mkErr(src, x, "number of arguments does not match (%d vs %d)",
len(x.Params), len(args))
diff --git a/cue/debug.go b/cue/debug.go
index 9863631..148941c 100644
--- a/cue/debug.go
+++ b/cue/debug.go
@@ -195,6 +195,16 @@
}
}
write(")")
+ case *customValidator:
+ p.debugStr(x.call)
+ write(" (")
+ for i, a := range x.args {
+ p.debugStr(a)
+ if i < len(x.args)-1 {
+ write(",")
+ }
+ }
+ write(")")
case *unaryExpr:
write(x.op)
p.debugStr(x.x)
diff --git a/cue/eval.go b/cue/eval.go
index 68056ab..aa91722 100644
--- a/cue/eval.go
+++ b/cue/eval.go
@@ -169,6 +169,14 @@
return e.err(err)
}
+func (x *customValidator) evalPartial(ctx *context) (result evaluated) {
+ if ctx.trace {
+ defer uni(indent(ctx, "custom", x))
+ defer func() { ctx.debugPrint("result:", result) }()
+ }
+ return x
+}
+
func (x *bound) evalPartial(ctx *context) (result evaluated) {
if ctx.trace {
defer uni(indent(ctx, "bound", x))
diff --git a/cue/export.go b/cue/export.go
index 6adf44e..40c58d0 100644
--- a/cue/export.go
+++ b/cue/export.go
@@ -17,6 +17,7 @@
import (
"fmt"
"math/rand"
+ "sort"
"strconv"
"strings"
"unicode"
@@ -30,15 +31,70 @@
return !m.raw
}
-func export(ctx *context, v value, m options) ast.Expr {
- e := exporter{ctx, m, nil}
- return e.expr(v)
+func export(ctx *context, v value, m options) ast.Node {
+ e := exporter{ctx, m, nil, map[label]bool{}, map[string]importInfo{}}
+ top, ok := v.evalPartial(ctx).(*structLit)
+ if ok {
+ top = top.expandFields(ctx)
+ for _, a := range top.arcs {
+ e.top[a.feature] = true
+ }
+ }
+
+ value := e.expr(v)
+ if len(e.imports) == 0 {
+ return value
+ }
+ imports := make([]string, 0, len(e.imports))
+ for k := range e.imports {
+ imports = append(imports, k)
+ }
+ sort.Strings(imports)
+
+ importDecl := &ast.ImportDecl{}
+ file := &ast.File{Decls: []ast.Decl{importDecl}}
+
+ for _, k := range imports {
+ info := e.imports[k]
+ ident := (*ast.Ident)(nil)
+ if info.name != "" {
+ ident = ast.NewIdent(info.name)
+ }
+ if info.alias != "" {
+ file.Decls = append(file.Decls, &ast.Alias{
+ Ident: ast.NewIdent(info.alias),
+ Expr: ast.NewIdent(info.short),
+ })
+ }
+ importDecl.Specs = append(importDecl.Specs, &ast.ImportSpec{
+ Name: ident,
+ Path: &ast.BasicLit{Kind: token.STRING, Value: quote(k, '"')},
+ })
+ }
+
+ // TODO: should we unwrap structs?
+ if obj, ok := value.(*ast.StructLit); ok {
+ file.Decls = append(file.Decls, obj.Elts...)
+ } else {
+ file.Decls = append(file.Decls, &ast.EmitDecl{Expr: value})
+ }
+
+ // resolve the file.
+ return file
}
type exporter struct {
- ctx *context
- mode options
- stack []remap
+ ctx *context
+ mode options
+ stack []remap
+ top map[label]bool // label to alias or ""
+ imports map[string]importInfo // pkg path to info
+}
+
+type importInfo struct {
+ name string
+ short string
+ alias string
}
type remap struct {
@@ -111,7 +167,7 @@
if doEval(p.mode) {
x := p.ctx.manifest(v)
if isIncomplete(x) {
- p = &exporter{p.ctx, options{raw: true}, p.stack}
+ p = &exporter{p.ctx, options{raw: true}, p.stack, p.top, p.imports}
return p.expr(v)
}
v = x
@@ -123,7 +179,42 @@
// TODO: also add position information.
switch x := v.(type) {
case *builtin:
- return &ast.Ident{Name: x.Name}
+ name := ast.NewIdent(x.Name)
+ if x.pkg == 0 {
+ return name
+ }
+ pkg := p.ctx.labelStr(x.pkg)
+ info, ok := p.imports[pkg]
+ short := info.short
+ if !ok {
+ info.short = ""
+ short = pkg
+ if i := strings.LastIndexByte(pkg, '.'); i >= 0 {
+ short = pkg[i+1:]
+ }
+ for {
+ if _, ok := p.top[p.ctx.label(short, true)]; !ok {
+ break
+ }
+ short += "x"
+ info.name = short
+ }
+ info.short = short
+ p.top[p.ctx.label(short, true)] = true
+ p.imports[pkg] = info
+ }
+ f := p.ctx.label(short, true)
+ for _, e := range p.stack {
+ if e.from == f {
+ if info.alias == "" {
+ info.alias = p.unique(short)
+ p.imports[pkg] = info
+ }
+ short = info.alias
+ break
+ }
+ }
+ return &ast.SelectorExpr{X: ast.NewIdent(short), Sel: name}
case *nodeRef:
return nil
@@ -139,6 +230,7 @@
// TODO: should not happen: report error
return ident
}
+ // TODO: nodes may have changed. Use different algorithm.
conflict := false
for i := len(p.stack) - 1; i >= 0; i-- {
e := &p.stack[i]
@@ -183,6 +275,13 @@
}
return call
+ case *customValidator:
+ call := &ast.CallExpr{Fun: p.expr(x.call)}
+ for _, a := range x.args {
+ call.Args = append(call.Args, p.expr(a))
+ }
+ return call
+
case *unaryExpr:
return &ast.UnaryExpr{Op: opMap[x.op], X: p.expr(x.x)}
@@ -225,6 +324,7 @@
case *structLit:
obj := &ast.StructLit{}
if doEval(p.mode) {
+ x = x.expandFields(p.ctx)
for _, a := range x.arcs {
p.stack = append(p.stack, remap{
key: x,
@@ -233,7 +333,6 @@
syn: obj,
})
}
- x = x.expandFields(p.ctx)
}
if x.emit != nil {
obj.Elts = append(obj.Elts, &ast.EmitDecl{Expr: p.expr(x.emit)})
@@ -271,7 +370,7 @@
if !doEval(p.mode) {
f.Value = p.expr(a.v)
} else if v := p.ctx.manifest(x.at(p.ctx, i)); isIncomplete(v) && !p.mode.concrete {
- p := &exporter{p.ctx, options{raw: true}, p.stack}
+ p := &exporter{p.ctx, options{raw: true}, p.stack, p.top, p.imports}
f.Value = p.expr(a.v)
} else {
f.Value = p.expr(v)
diff --git a/cue/export_test.go b/cue/export_test.go
index c4289e6..6835f60 100644
--- a/cue/export_test.go
+++ b/cue/export_test.go
@@ -288,6 +288,59 @@
}
}
+func TestExportFile(t *testing.T) {
+ testCases := []struct {
+ eval bool // evaluate the full export
+ in, out string
+ }{{
+ in: `
+ import "strings"
+
+ a: strings.ContainsAny("c")
+ `,
+ out: unindent(`
+ import "strings"
+
+ a: strings.ContainsAny("c")`),
+ }, {
+ in: `
+ import "strings"
+
+ stringsx = strings
+
+ a: {
+ strings: stringsx.ContainsAny("c")
+ }
+ `,
+ out: unindent(`
+ import "strings"
+
+ STRINGS = strings
+ a strings: STRINGS.ContainsAny("c")`),
+ }}
+ for _, tc := range testCases {
+ t.Run("", func(t *testing.T) {
+ var r Runtime
+ inst, err := r.Parse("test", tc.in)
+ if err != nil {
+ t.Fatal(err)
+ }
+ v := inst.Value()
+ ctx := r.index().newContext()
+
+ buf := &bytes.Buffer{}
+ opts := options{raw: false}
+ err = format.Node(buf, export(ctx, v.eval(ctx), opts))
+ if err != nil {
+ log.Fatal(err)
+ }
+ if got := strings.TrimSpace(buf.String()); got != tc.out {
+ t.Errorf("\ngot:\n%v\nwant:\n%v", got, tc.out)
+ }
+ })
+ }
+}
+
func unindent(s string) string {
lines := strings.Split(s, "\n")[1:]
ws := lines[0][:len(lines[0])-len(strings.TrimLeft(lines[0], " \t"))]
diff --git a/cue/resolve_test.go b/cue/resolve_test.go
index 75ea63c..d54784b 100644
--- a/cue/resolve_test.go
+++ b/cue/resolve_test.go
@@ -19,7 +19,6 @@
"testing"
"cuelang.org/go/cue/errors"
- "cuelang.org/go/cue/parser"
)
var traceOn = flag.Bool("debug", false, "enable tracing")
@@ -42,14 +41,19 @@
func compileInstance(t *testing.T, body string) (*context, *Instance, errors.List) {
t.Helper()
- x := newIndex().NewInstance(nil)
- f, err := parser.ParseFile("test", body)
- ctx := x.newContext()
+ var r Runtime
+ x, err := r.Parse("test", body)
+ ctx := r.index().newContext()
switch errs := err.(type) {
case nil:
- x.insertFile(f)
+ var r Runtime
+ inst, _ := r.Parse("test", body)
+ return r.index().newContext(), inst, nil
case errors.List:
+ x := newIndex().NewInstance(nil)
+ ctx := x.newContext()
+
return ctx, x, errs
default:
t.Fatal(err)
@@ -735,11 +739,26 @@
`e8: _|_(incompatible bounds >11 and <=11), ` +
`e9: _|_((>"a" & <1):unsupported op &((string)*, (number)*))}`,
}, {
+ desc: "custom validators",
+ in: `
+ import "strings"
+
+ a: strings.ContainsAny("ab")
+ a: "after"
+
+ b: strings.ContainsAny("c")
+ b: "dog"
+ `,
+ out: `<0>{` +
+ `a: "after", ` +
+ `b: _|_(builtin:ContainsAny ("c"):value "dog" not in ContainsAny("c"))` +
+ `}`,
+ }, {
desc: "null coalescing",
in: `
- a: null
- b: a.x | "b"
- c: a["x"] | "c"
+ a: null
+ b: a.x | "b"
+ c: a["x"] | "c"
`,
out: `<1>{a: null, b: "b", c: "c"}`,
}, {
diff --git a/cue/rewrite.go b/cue/rewrite.go
index 44dc982..50c2522 100644
--- a/cue/rewrite.go
+++ b/cue/rewrite.go
@@ -83,6 +83,20 @@
func (x *numLit) rewrite(ctx *context, fn rewriteFunc) value { return x }
func (x *durationLit) rewrite(ctx *context, fn rewriteFunc) value { return x }
+func (x *customValidator) rewrite(ctx *context, fn rewriteFunc) value {
+ args := make([]evaluated, len(x.args))
+ changed := false
+ for i, a := range x.args {
+ v := rewrite(ctx, a, fn)
+ args[i] = v.(evaluated)
+ changed = changed || v != a
+ }
+ if !changed {
+ return x
+ }
+ return &customValidator{baseValue: x.baseValue, args: args, call: x.call}
+}
+
func (x *bound) rewrite(ctx *context, fn rewriteFunc) value {
v := rewrite(ctx, x.value, fn)
if v == x.value {
diff --git a/cue/subsume.go b/cue/subsume.go
index 860d630..892742c 100644
--- a/cue/subsume.go
+++ b/cue/subsume.go
@@ -405,6 +405,23 @@
}
// structural equivalence
+func (x *customValidator) subsumesImpl(ctx *context, v value, mode subsumeMode) bool {
+ y, ok := v.(*customValidator)
+ if !ok {
+ return isBottom(v)
+ }
+ if x.call != y.call {
+ return false
+ }
+ for i, v := range x.args {
+ if !subsumes(ctx, v, y.args[i], mode) {
+ return false
+ }
+ }
+ return true
+}
+
+// structural equivalence
func (x *callExpr) subsumesImpl(ctx *context, v value, mode subsumeMode) bool {
if c, ok := v.(*callExpr); ok {
if len(x.args) != len(c.args) {
diff --git a/cue/types.go b/cue/types.go
index e7a87c4..a4a5c85 100644
--- a/cue/types.go
+++ b/cue/types.go
@@ -660,7 +660,7 @@
// Syntax converts the possibly partially evaluated value into syntax. This
// can use used to print the value with package format.
-func (v Value) Syntax(opts ...Option) ast.Expr {
+func (v Value) Syntax(opts ...Option) ast.Node {
if v.path == nil || v.path.cache == nil {
return nil
}
diff --git a/cue/value.go b/cue/value.go
index dfa5a10..f862ca1 100644
--- a/cue/value.go
+++ b/cue/value.go
@@ -967,12 +967,37 @@
func (x *callExpr) kind() kind {
// TODO: could this be narrowed down?
- if l, ok := x.x.(*lambdaExpr); ok {
- return l.returnKind() | nonGround
+ switch c := x.x.(type) {
+ case *lambdaExpr:
+ return c.returnKind() | nonGround
+ case *builtin:
+ switch len(x.args) {
+ case len(c.Params):
+ return c.Result
+ case len(c.Params) - 1:
+ if len(c.Params) == 0 || c.Result&boolKind == 0 {
+ return bottomKind
+ }
+ return c.Params[0]
+ }
}
return topKind | referenceKind
}
+type customValidator struct {
+ baseValue
+
+ args []evaluated // any but the first value
+ call *builtin // function must return a bool
+}
+
+func (x *customValidator) kind() kind {
+ if len(x.call.Params) == 0 {
+ return bottomKind
+ }
+ return x.call.Params[0] | nonGround
+}
+
type params struct {
arcs []arc
}