From 934252e1bbbc58e676ddb9821b13a0801578fe0d Mon Sep 17 00:00:00 2001 From: Leon Mika Date: Mon, 13 Jan 2025 21:37:54 +1100 Subject: [PATCH] Added strs:split and fixed closures Closures now properly close over the context in which it was created --- ucl/builtins.go | 40 +++++++++--------- ucl/builtins/strs.go | 34 ++++++++++++++++ ucl/builtins/strs_test.go | 31 ++++++++++++++ ucl/env.go | 2 +- ucl/eval.go | 8 ++-- ucl/objs.go | 21 ++++++---- ucl/testbuiltins_test.go | 86 +++++++++++++++++++++++++++++++++++---- 7 files changed, 181 insertions(+), 41 deletions(-) diff --git a/ucl/builtins.go b/ucl/builtins.go index 791f9d5..381abac 100644 --- a/ucl/builtins.go +++ b/ucl/builtins.go @@ -52,7 +52,7 @@ func addBuiltin(ctx context.Context, args invocationArgs) (Object, error) { switch t := a.(type) { case intObject: n += int(t) - case strObject: + case StringObject: v, err := strconv.Atoi(string(t)) if err != nil { return nil, fmt.Errorf("arg %v of 'add' not convertable to an int", i) @@ -77,7 +77,7 @@ func subBuiltin(ctx context.Context, args invocationArgs) (Object, error) { switch t := a.(type) { case intObject: p = int(t) - case strObject: + case StringObject: v, err := strconv.Atoi(string(t)) if err != nil { return nil, fmt.Errorf("arg %v of 'sub' not convertable to an int", i) @@ -106,7 +106,7 @@ func mupBuiltin(ctx context.Context, args invocationArgs) (Object, error) { switch t := a.(type) { case intObject: n *= int(t) - case strObject: + case StringObject: v, err := strconv.Atoi(string(t)) if err != nil { return nil, fmt.Errorf("arg %v of 'mup' not convertable to an int", i) @@ -131,7 +131,7 @@ func divBuiltin(ctx context.Context, args invocationArgs) (Object, error) { switch t := a.(type) { case intObject: p = int(t) - case strObject: + case StringObject: v, err := strconv.Atoi(string(t)) if err != nil { return nil, fmt.Errorf("arg %v of 'div' not convertable to an int", i) @@ -161,7 +161,7 @@ func modBuiltin(ctx context.Context, args invocationArgs) (Object, error) { switch t := a.(type) { case intObject: p = int(t) - case strObject: + case StringObject: v, err := strconv.Atoi(string(t)) if err != nil { return nil, fmt.Errorf("arg %v of 'mod' not convertable to an int", i) @@ -204,7 +204,7 @@ func toUpperBuiltin(ctx context.Context, args invocationArgs) (Object, error) { if err != nil { return nil, err } - return strObject(strings.ToUpper(sarg)), nil + return StringObject(strings.ToUpper(sarg)), nil } func eqBuiltin(ctx context.Context, args invocationArgs) (Object, error) { @@ -319,8 +319,8 @@ func objectsEqual(l, r Object) bool { } switch lv := l.(type) { - case strObject: - if rv, ok := r.(strObject); ok { + case StringObject: + if rv, ok := r.(StringObject); ok { return lv == rv } case intObject: @@ -380,8 +380,8 @@ func objectsEqual(l, r Object) bool { func objectsLessThan(l, r Object) (bool, error) { switch lv := l.(type) { - case strObject: - if rv, ok := r.(strObject); ok { + case StringObject: + if rv, ok := r.(StringObject); ok { return lv < rv, nil } case intObject: @@ -398,10 +398,10 @@ func strBuiltin(ctx context.Context, args invocationArgs) (Object, error) { } if args.args[0] == nil { - return strObject(""), nil + return StringObject(""), nil } - return strObject(args.args[0].String()), nil + return StringObject(args.args[0].String()), nil } func intBuiltin(ctx context.Context, args invocationArgs) (Object, error) { @@ -416,7 +416,7 @@ func intBuiltin(ctx context.Context, args invocationArgs) (Object, error) { switch v := args.args[0].(type) { case intObject: return v, nil - case strObject: + case StringObject: i, err := strconv.Atoi(string(v)) if err != nil { return nil, errors.New("cannot convert to int") @@ -442,7 +442,7 @@ func concatBuiltin(ctx context.Context, args invocationArgs) (Object, error) { sb.WriteString(a.String()) } - return strObject(sb.String()), nil + return StringObject(sb.String()), nil } func callBuiltin(ctx context.Context, args invocationArgs) (Object, error) { @@ -464,7 +464,7 @@ func lenBuiltin(ctx context.Context, args invocationArgs) (Object, error) { } switch v := args.args[0].(type) { - case strObject: + case StringObject: return intObject(len(string(v))), nil case Listable: return intObject(v.Len()), nil @@ -487,7 +487,7 @@ func indexLookup(ctx context.Context, obj, elem Object) (Object, error) { } return nil, nil case hashable: - strIdx, ok := elem.(strObject) + strIdx, ok := elem.(StringObject) if !ok { return nil, errors.New("expected string for hashable") } @@ -523,7 +523,7 @@ func keysBuiltin(ctx context.Context, args invocationArgs) (Object, error) { case hashable: keys := make(listObject, 0, v.Len()) if err := v.Each(func(k string, _ Object) error { - keys = append(keys, strObject(k)) + keys = append(keys, StringObject(k)) return nil }); err != nil { return nil, err @@ -588,7 +588,7 @@ func filterBuiltin(ctx context.Context, args invocationArgs) (Object, error) { case hashable: newHash := hashObject{} if err := t.Each(func(k string, v Object) error { - if m, err := inv.invoke(ctx, args.fork([]Object{strObject(k), v})); err != nil { + if m, err := inv.invoke(ctx, args.fork([]Object{StringObject(k), v})); err != nil { return err } else if m.Truthy() { newHash[k] = v @@ -649,7 +649,7 @@ func reduceBuiltin(ctx context.Context, args invocationArgs) (Object, error) { case hashable: // TODO: should raise error? if err := t.Each(func(k string, v Object) error { - newAccum, err := block.invoke(ctx, args.fork([]Object{strObject(k), v, accum})) + newAccum, err := block.invoke(ctx, args.fork([]Object{StringObject(k), v, accum})) if err != nil { return err } @@ -885,7 +885,7 @@ func foreachBuiltin(ctx context.Context, args macroArgs) (Object, error) { } case hashable: err := t.Each(func(k string, v Object) error { - last, err = args.evalBlock(ctx, blockIdx, []Object{strObject(k), v}, true) + last, err = args.evalBlock(ctx, blockIdx, []Object{StringObject(k), v}, true) return err }) if errors.As(err, &breakErr) { diff --git a/ucl/builtins/strs.go b/ucl/builtins/strs.go index 4b813ca..03cd9e9 100644 --- a/ucl/builtins/strs.go +++ b/ucl/builtins/strs.go @@ -15,6 +15,7 @@ func Strs() ucl.Module { "trim": trim, "join": join, "has-suffix": hasSuffix, + "split": split, }, } } @@ -83,3 +84,36 @@ func hasSuffix(ctx context.Context, args ucl.CallArgs) (any, error) { return strings.HasSuffix(s, suffix), nil } + +func split(ctx context.Context, args ucl.CallArgs) (any, error) { + var ( + s string + joinStr string + ) + if err := args.Bind(&s, &joinStr); err != nil { + return nil, err + } + + return stringSlice(strings.Split(s, joinStr)), nil +} + +type stringSlice []string + +func (s stringSlice) String() string { + return strings.Join(s, ",") +} + +func (s stringSlice) Truthy() bool { + return len(s) > 0 +} + +func (s stringSlice) Len() int { + return len(s) +} + +func (s stringSlice) Index(i int) ucl.Object { + if i < 0 || i >= len(s) { + return nil + } + return ucl.StringObject(s[i]) +} diff --git a/ucl/builtins/strs_test.go b/ucl/builtins/strs_test.go index 9aef92d..a5d166e 100644 --- a/ucl/builtins/strs_test.go +++ b/ucl/builtins/strs_test.go @@ -163,3 +163,34 @@ func TestStrs_Join(t *testing.T) { }) } } + +func TestStrs_Split(t *testing.T) { + tests := []struct { + desc string + eval string + want any + wantErr bool + }{ + {desc: "split 1", eval: `strs:split "a,b,c" "," | cat`, want: "a,b,c"}, + {desc: "split 2", eval: `strs:split "a;b;c" ";" | cat`, want: "a,b,c"}, + {desc: "split 3", eval: `strs:split "a|b;c" "|" | cat`, want: "a,b;c"}, + {desc: "split 4", eval: `strs:split "abcde" "" | cat`, want: "a,b,c,d,e"}, + {desc: "split 5", eval: `strs:split "abcde" "," | cat`, want: "abcde"}, + {desc: "split 6", eval: `strs:split "" "," | len`, want: 1}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + inst := ucl.New( + ucl.WithModule(builtins.Strs()), + ) + res, err := inst.Eval(context.Background(), tt.eval) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, res) + } + }) + } +} diff --git a/ucl/env.go b/ucl/env.go index 533d984..11b79a3 100644 --- a/ucl/env.go +++ b/ucl/env.go @@ -59,7 +59,7 @@ func (ec *evalCtx) setOrDefineVar(name string, val Object) { } func (ec *evalCtx) getVar(name string) (Object, bool) { - if ec.vars == nil { + if ec == nil { return nil, false } diff --git a/ucl/eval.go b/ucl/eval.go index 90bde5e..2d89ccf 100644 --- a/ucl/eval.go +++ b/ucl/eval.go @@ -159,7 +159,7 @@ func (e evaluator) evalDot(ctx context.Context, ec *evalCtx, n astDot) (Object, for _, dot := range n.DotSuffix { var idx Object if dot.KeyIdent != nil { - idx = strObject(dot.KeyIdent.String()) + idx = StringObject(dot.KeyIdent.String()) } else { idx, err = e.evalPipeline(ctx, ec, dot.Pipeline) if err != nil { @@ -180,7 +180,7 @@ func (e evaluator) evalArg(ctx context.Context, ec *evalCtx, n astCmdArg) (Objec case n.Literal != nil: return e.evalLiteral(ctx, ec, n.Literal) case n.Ident != nil: - return strObject(n.Ident.String()), nil + return StringObject(n.Ident.String()), nil case n.Var != nil: if v, ok := ec.getVar(*n.Var); ok { return v, nil @@ -195,7 +195,7 @@ func (e evaluator) evalArg(ctx context.Context, ec *evalCtx, n astCmdArg) (Objec case n.ListOrHash != nil: return e.evalListOrHash(ctx, ec, n.ListOrHash) case n.Block != nil: - return blockObject{block: n.Block}, nil + return blockObject{block: n.Block, closedEC: ec}, nil } return nil, errors.New("unhandled arg type") } @@ -250,7 +250,7 @@ func (e evaluator) evalLiteral(ctx context.Context, ec *evalCtx, n *astLiteral) if err != nil { return nil, err } - return strObject(uq), nil + return StringObject(uq), nil case n.Int != nil: return intObject(*n.Int), nil } diff --git a/ucl/objs.go b/ucl/objs.go index 50bb7a1..faed869 100644 --- a/ucl/objs.go +++ b/ucl/objs.go @@ -91,13 +91,13 @@ func (s hashObject) Each(fn func(k string, v Object) error) error { return nil } -type strObject string +type StringObject string -func (s strObject) String() string { +func (s StringObject) String() string { return string(s) } -func (s strObject) Truthy() bool { +func (s StringObject) Truthy() bool { return string(s) != "" } @@ -130,7 +130,7 @@ func toGoValue(obj Object) (interface{}, bool) { return v.v, true case nil: return nil, true - case strObject: + case StringObject: return string(v), true case intObject: return int(v), true @@ -169,12 +169,14 @@ func toGoValue(obj Object) (interface{}, bool) { func fromGoValue(v any) (Object, error) { switch t := v.(type) { + case Object: + return t, nil case OpaqueObject: return t, nil case nil: return nil, nil case string: - return strObject(t), nil + return StringObject(t), nil case int: return intObject(t), nil case bool: @@ -286,7 +288,7 @@ func (ma macroArgs) evalBlock(ctx context.Context, n int, args []Object, pushSco } return ma.eval.evalBlock(ctx, ec, v.block) - case strObject: + case StringObject: iv := ma.ec.lookupInvokable(string(v)) if iv == nil { return nil, errors.New("'" + string(v) + "' is not invokable") @@ -360,7 +362,7 @@ func (ia invocationArgs) invokableArg(i int) (invokable, error) { switch v := ia.args[i].(type) { case invokable: return v, nil - case strObject: + case StringObject: iv := ia.ec.lookupInvokable(string(v)) if iv == nil { return nil, errors.New("'" + string(v) + "' is not invokable") @@ -413,7 +415,8 @@ func (i invokableFunc) invoke(ctx context.Context, args invocationArgs) (Object, } type blockObject struct { - block *astBlock + block *astBlock + closedEC *evalCtx } func (bo blockObject) String() string { @@ -425,7 +428,7 @@ func (bo blockObject) Truthy() bool { } func (bo blockObject) invoke(ctx context.Context, args invocationArgs) (Object, error) { - ec := args.ec.fork() + ec := bo.closedEC.fork() for i, n := range bo.block.Names { if i < len(args.args) { ec.setOrDefineVar(n, args.args[i]) diff --git a/ucl/testbuiltins_test.go b/ucl/testbuiltins_test.go index edaa5e8..c3b08f8 100644 --- a/ucl/testbuiltins_test.go +++ b/ucl/testbuiltins_test.go @@ -19,12 +19,12 @@ func WithTestBuiltin() InstOption { })) i.rootEC.addCmd("toUpper", invokableFunc(func(ctx context.Context, args invocationArgs) (Object, error) { - return strObject(strings.ToUpper(args.args[0].String())), nil + return StringObject(strings.ToUpper(args.args[0].String())), nil })) i.rootEC.addCmd("sjoin", invokableFunc(func(ctx context.Context, args invocationArgs) (Object, error) { if len(args.args) == 0 { - return strObject(""), nil + return StringObject(""), nil } var line strings.Builder @@ -34,7 +34,7 @@ func WithTestBuiltin() InstOption { } } - return strObject(line.String()), nil + return StringObject(line.String()), nil })) i.rootEC.addCmd("list", invokableFunc(func(ctx context.Context, args invocationArgs) (Object, error) { @@ -53,7 +53,7 @@ func WithTestBuiltin() InstOption { lst, ok := args.args[0].(Listable) if !ok { - return strObject(""), nil + return StringObject(""), nil } l := lst.Len() @@ -63,11 +63,11 @@ func WithTestBuiltin() InstOption { } sb.WriteString(lst.Index(x).String()) } - return strObject(sb.String()), nil + return StringObject(sb.String()), nil })) - i.rootEC.setOrDefineVar("a", strObject("alpha")) - i.rootEC.setOrDefineVar("bee", strObject("buzz")) + i.rootEC.setOrDefineVar("a", StringObject("alpha")) + i.rootEC.setOrDefineVar("bee", StringObject("buzz")) } } @@ -704,6 +704,78 @@ func TestBuiltins_Return(t *testing.T) { four4 `, want: "xxxx\n"}, + {desc: "check closure 1", expr: ` + proc do-thing { |p| + call $p + } + + proc test-thing { + foreach [1 2 3] { |x| + do-thing { + echo $x + } + } + } + + test-thing + `, want: "1\n2\n3\n(nil)\n"}, + {desc: "check closure 2", expr: ` + proc do-thing { |p| + call $p + } + + proc test-thing { + foreach [1 2 3] { |x| + do-thing (proc { + echo $x + }) + } + } + + test-thing + `, want: "1\n2\n3\n(nil)\n"}, + {desc: "check closure 3", expr: ` + proc do-thing { |p| + call $p + } + + proc test-thing { + foreach [1 2 3] { |x| + set myClosure (proc { echo $x }) + do-thing $myClosure + } + } + + test-thing + `, want: "1\n2\n3\n(nil)\n"}, + {desc: "check closure 4", expr: ` + proc do-thing { |p| + call $p + } + + proc test-thing { + [1 2 3] | map { |x| + proc { echo $x } + } + } + + foreach (test-thing) { |y| call $y } + `, want: "1\n2\n3\n(nil)\n"}, + {desc: "check closure 5", expr: ` + proc do-thing { |p| + call $p + } + + proc test-thing { + [1 2 3] | map { |x| + set myProc (proc { echo $x }) + proc { do-thing $myProc } + } + } + + set hello "xx" + foreach (test-thing) { |y| call $y ; echo $hello } + `, want: "1\nxx\n2\nxx\n3\nxx\n(nil)\n"}, } for _, tt := range tests {