Added strs:split and fixed closures

Closures now properly close over the context in which it was created
This commit is contained in:
Leon Mika 2025-01-13 21:37:54 +11:00
parent d111d84dbf
commit 934252e1bb
7 changed files with 181 additions and 41 deletions

View file

@ -52,7 +52,7 @@ func addBuiltin(ctx context.Context, args invocationArgs) (Object, error) {
switch t := a.(type) { switch t := a.(type) {
case intObject: case intObject:
n += int(t) n += int(t)
case strObject: case StringObject:
v, err := strconv.Atoi(string(t)) v, err := strconv.Atoi(string(t))
if err != nil { if err != nil {
return nil, fmt.Errorf("arg %v of 'add' not convertable to an int", i) return nil, fmt.Errorf("arg %v of 'add' not convertable to an int", i)
@ -77,7 +77,7 @@ func subBuiltin(ctx context.Context, args invocationArgs) (Object, error) {
switch t := a.(type) { switch t := a.(type) {
case intObject: case intObject:
p = int(t) p = int(t)
case strObject: case StringObject:
v, err := strconv.Atoi(string(t)) v, err := strconv.Atoi(string(t))
if err != nil { if err != nil {
return nil, fmt.Errorf("arg %v of 'sub' not convertable to an int", i) return nil, fmt.Errorf("arg %v of 'sub' not convertable to an int", i)
@ -106,7 +106,7 @@ func mupBuiltin(ctx context.Context, args invocationArgs) (Object, error) {
switch t := a.(type) { switch t := a.(type) {
case intObject: case intObject:
n *= int(t) n *= int(t)
case strObject: case StringObject:
v, err := strconv.Atoi(string(t)) v, err := strconv.Atoi(string(t))
if err != nil { if err != nil {
return nil, fmt.Errorf("arg %v of 'mup' not convertable to an int", i) return nil, fmt.Errorf("arg %v of 'mup' not convertable to an int", i)
@ -131,7 +131,7 @@ func divBuiltin(ctx context.Context, args invocationArgs) (Object, error) {
switch t := a.(type) { switch t := a.(type) {
case intObject: case intObject:
p = int(t) p = int(t)
case strObject: case StringObject:
v, err := strconv.Atoi(string(t)) v, err := strconv.Atoi(string(t))
if err != nil { if err != nil {
return nil, fmt.Errorf("arg %v of 'div' not convertable to an int", i) return nil, fmt.Errorf("arg %v of 'div' not convertable to an int", i)
@ -161,7 +161,7 @@ func modBuiltin(ctx context.Context, args invocationArgs) (Object, error) {
switch t := a.(type) { switch t := a.(type) {
case intObject: case intObject:
p = int(t) p = int(t)
case strObject: case StringObject:
v, err := strconv.Atoi(string(t)) v, err := strconv.Atoi(string(t))
if err != nil { if err != nil {
return nil, fmt.Errorf("arg %v of 'mod' not convertable to an int", i) return nil, fmt.Errorf("arg %v of 'mod' not convertable to an int", i)
@ -204,7 +204,7 @@ func toUpperBuiltin(ctx context.Context, args invocationArgs) (Object, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return strObject(strings.ToUpper(sarg)), nil return StringObject(strings.ToUpper(sarg)), nil
} }
func eqBuiltin(ctx context.Context, args invocationArgs) (Object, error) { func eqBuiltin(ctx context.Context, args invocationArgs) (Object, error) {
@ -319,8 +319,8 @@ func objectsEqual(l, r Object) bool {
} }
switch lv := l.(type) { switch lv := l.(type) {
case strObject: case StringObject:
if rv, ok := r.(strObject); ok { if rv, ok := r.(StringObject); ok {
return lv == rv return lv == rv
} }
case intObject: case intObject:
@ -380,8 +380,8 @@ func objectsEqual(l, r Object) bool {
func objectsLessThan(l, r Object) (bool, error) { func objectsLessThan(l, r Object) (bool, error) {
switch lv := l.(type) { switch lv := l.(type) {
case strObject: case StringObject:
if rv, ok := r.(strObject); ok { if rv, ok := r.(StringObject); ok {
return lv < rv, nil return lv < rv, nil
} }
case intObject: case intObject:
@ -398,10 +398,10 @@ func strBuiltin(ctx context.Context, args invocationArgs) (Object, error) {
} }
if args.args[0] == nil { if args.args[0] == nil {
return strObject(""), nil return StringObject(""), nil
} }
return strObject(args.args[0].String()), nil return StringObject(args.args[0].String()), nil
} }
func intBuiltin(ctx context.Context, args invocationArgs) (Object, error) { func intBuiltin(ctx context.Context, args invocationArgs) (Object, error) {
@ -416,7 +416,7 @@ func intBuiltin(ctx context.Context, args invocationArgs) (Object, error) {
switch v := args.args[0].(type) { switch v := args.args[0].(type) {
case intObject: case intObject:
return v, nil return v, nil
case strObject: case StringObject:
i, err := strconv.Atoi(string(v)) i, err := strconv.Atoi(string(v))
if err != nil { if err != nil {
return nil, errors.New("cannot convert to int") return nil, errors.New("cannot convert to int")
@ -442,7 +442,7 @@ func concatBuiltin(ctx context.Context, args invocationArgs) (Object, error) {
sb.WriteString(a.String()) sb.WriteString(a.String())
} }
return strObject(sb.String()), nil return StringObject(sb.String()), nil
} }
func callBuiltin(ctx context.Context, args invocationArgs) (Object, error) { func callBuiltin(ctx context.Context, args invocationArgs) (Object, error) {
@ -464,7 +464,7 @@ func lenBuiltin(ctx context.Context, args invocationArgs) (Object, error) {
} }
switch v := args.args[0].(type) { switch v := args.args[0].(type) {
case strObject: case StringObject:
return intObject(len(string(v))), nil return intObject(len(string(v))), nil
case Listable: case Listable:
return intObject(v.Len()), nil return intObject(v.Len()), nil
@ -487,7 +487,7 @@ func indexLookup(ctx context.Context, obj, elem Object) (Object, error) {
} }
return nil, nil return nil, nil
case hashable: case hashable:
strIdx, ok := elem.(strObject) strIdx, ok := elem.(StringObject)
if !ok { if !ok {
return nil, errors.New("expected string for hashable") return nil, errors.New("expected string for hashable")
} }
@ -523,7 +523,7 @@ func keysBuiltin(ctx context.Context, args invocationArgs) (Object, error) {
case hashable: case hashable:
keys := make(listObject, 0, v.Len()) keys := make(listObject, 0, v.Len())
if err := v.Each(func(k string, _ Object) error { if err := v.Each(func(k string, _ Object) error {
keys = append(keys, strObject(k)) keys = append(keys, StringObject(k))
return nil return nil
}); err != nil { }); err != nil {
return nil, err return nil, err
@ -588,7 +588,7 @@ func filterBuiltin(ctx context.Context, args invocationArgs) (Object, error) {
case hashable: case hashable:
newHash := hashObject{} newHash := hashObject{}
if err := t.Each(func(k string, v Object) error { if err := t.Each(func(k string, v Object) error {
if m, err := inv.invoke(ctx, args.fork([]Object{strObject(k), v})); err != nil { if m, err := inv.invoke(ctx, args.fork([]Object{StringObject(k), v})); err != nil {
return err return err
} else if m.Truthy() { } else if m.Truthy() {
newHash[k] = v newHash[k] = v
@ -649,7 +649,7 @@ func reduceBuiltin(ctx context.Context, args invocationArgs) (Object, error) {
case hashable: case hashable:
// TODO: should raise error? // TODO: should raise error?
if err := t.Each(func(k string, v Object) error { if err := t.Each(func(k string, v Object) error {
newAccum, err := block.invoke(ctx, args.fork([]Object{strObject(k), v, accum})) newAccum, err := block.invoke(ctx, args.fork([]Object{StringObject(k), v, accum}))
if err != nil { if err != nil {
return err return err
} }
@ -885,7 +885,7 @@ func foreachBuiltin(ctx context.Context, args macroArgs) (Object, error) {
} }
case hashable: case hashable:
err := t.Each(func(k string, v Object) error { err := t.Each(func(k string, v Object) error {
last, err = args.evalBlock(ctx, blockIdx, []Object{strObject(k), v}, true) last, err = args.evalBlock(ctx, blockIdx, []Object{StringObject(k), v}, true)
return err return err
}) })
if errors.As(err, &breakErr) { if errors.As(err, &breakErr) {

View file

@ -15,6 +15,7 @@ func Strs() ucl.Module {
"trim": trim, "trim": trim,
"join": join, "join": join,
"has-suffix": hasSuffix, "has-suffix": hasSuffix,
"split": split,
}, },
} }
} }
@ -83,3 +84,36 @@ func hasSuffix(ctx context.Context, args ucl.CallArgs) (any, error) {
return strings.HasSuffix(s, suffix), nil return strings.HasSuffix(s, suffix), nil
} }
func split(ctx context.Context, args ucl.CallArgs) (any, error) {
var (
s string
joinStr string
)
if err := args.Bind(&s, &joinStr); err != nil {
return nil, err
}
return stringSlice(strings.Split(s, joinStr)), nil
}
type stringSlice []string
func (s stringSlice) String() string {
return strings.Join(s, ",")
}
func (s stringSlice) Truthy() bool {
return len(s) > 0
}
func (s stringSlice) Len() int {
return len(s)
}
func (s stringSlice) Index(i int) ucl.Object {
if i < 0 || i >= len(s) {
return nil
}
return ucl.StringObject(s[i])
}

View file

@ -163,3 +163,34 @@ func TestStrs_Join(t *testing.T) {
}) })
} }
} }
func TestStrs_Split(t *testing.T) {
tests := []struct {
desc string
eval string
want any
wantErr bool
}{
{desc: "split 1", eval: `strs:split "a,b,c" "," | cat`, want: "a,b,c"},
{desc: "split 2", eval: `strs:split "a;b;c" ";" | cat`, want: "a,b,c"},
{desc: "split 3", eval: `strs:split "a|b;c" "|" | cat`, want: "a,b;c"},
{desc: "split 4", eval: `strs:split "abcde" "" | cat`, want: "a,b,c,d,e"},
{desc: "split 5", eval: `strs:split "abcde" "," | cat`, want: "abcde"},
{desc: "split 6", eval: `strs:split "" "," | len`, want: 1},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
inst := ucl.New(
ucl.WithModule(builtins.Strs()),
)
res, err := inst.Eval(context.Background(), tt.eval)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, res)
}
})
}
}

View file

@ -59,7 +59,7 @@ func (ec *evalCtx) setOrDefineVar(name string, val Object) {
} }
func (ec *evalCtx) getVar(name string) (Object, bool) { func (ec *evalCtx) getVar(name string) (Object, bool) {
if ec.vars == nil { if ec == nil {
return nil, false return nil, false
} }

View file

@ -159,7 +159,7 @@ func (e evaluator) evalDot(ctx context.Context, ec *evalCtx, n astDot) (Object,
for _, dot := range n.DotSuffix { for _, dot := range n.DotSuffix {
var idx Object var idx Object
if dot.KeyIdent != nil { if dot.KeyIdent != nil {
idx = strObject(dot.KeyIdent.String()) idx = StringObject(dot.KeyIdent.String())
} else { } else {
idx, err = e.evalPipeline(ctx, ec, dot.Pipeline) idx, err = e.evalPipeline(ctx, ec, dot.Pipeline)
if err != nil { if err != nil {
@ -180,7 +180,7 @@ func (e evaluator) evalArg(ctx context.Context, ec *evalCtx, n astCmdArg) (Objec
case n.Literal != nil: case n.Literal != nil:
return e.evalLiteral(ctx, ec, n.Literal) return e.evalLiteral(ctx, ec, n.Literal)
case n.Ident != nil: case n.Ident != nil:
return strObject(n.Ident.String()), nil return StringObject(n.Ident.String()), nil
case n.Var != nil: case n.Var != nil:
if v, ok := ec.getVar(*n.Var); ok { if v, ok := ec.getVar(*n.Var); ok {
return v, nil return v, nil
@ -195,7 +195,7 @@ func (e evaluator) evalArg(ctx context.Context, ec *evalCtx, n astCmdArg) (Objec
case n.ListOrHash != nil: case n.ListOrHash != nil:
return e.evalListOrHash(ctx, ec, n.ListOrHash) return e.evalListOrHash(ctx, ec, n.ListOrHash)
case n.Block != nil: case n.Block != nil:
return blockObject{block: n.Block}, nil return blockObject{block: n.Block, closedEC: ec}, nil
} }
return nil, errors.New("unhandled arg type") return nil, errors.New("unhandled arg type")
} }
@ -250,7 +250,7 @@ func (e evaluator) evalLiteral(ctx context.Context, ec *evalCtx, n *astLiteral)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return strObject(uq), nil return StringObject(uq), nil
case n.Int != nil: case n.Int != nil:
return intObject(*n.Int), nil return intObject(*n.Int), nil
} }

View file

@ -91,13 +91,13 @@ func (s hashObject) Each(fn func(k string, v Object) error) error {
return nil return nil
} }
type strObject string type StringObject string
func (s strObject) String() string { func (s StringObject) String() string {
return string(s) return string(s)
} }
func (s strObject) Truthy() bool { func (s StringObject) Truthy() bool {
return string(s) != "" return string(s) != ""
} }
@ -130,7 +130,7 @@ func toGoValue(obj Object) (interface{}, bool) {
return v.v, true return v.v, true
case nil: case nil:
return nil, true return nil, true
case strObject: case StringObject:
return string(v), true return string(v), true
case intObject: case intObject:
return int(v), true return int(v), true
@ -169,12 +169,14 @@ func toGoValue(obj Object) (interface{}, bool) {
func fromGoValue(v any) (Object, error) { func fromGoValue(v any) (Object, error) {
switch t := v.(type) { switch t := v.(type) {
case Object:
return t, nil
case OpaqueObject: case OpaqueObject:
return t, nil return t, nil
case nil: case nil:
return nil, nil return nil, nil
case string: case string:
return strObject(t), nil return StringObject(t), nil
case int: case int:
return intObject(t), nil return intObject(t), nil
case bool: case bool:
@ -286,7 +288,7 @@ func (ma macroArgs) evalBlock(ctx context.Context, n int, args []Object, pushSco
} }
return ma.eval.evalBlock(ctx, ec, v.block) return ma.eval.evalBlock(ctx, ec, v.block)
case strObject: case StringObject:
iv := ma.ec.lookupInvokable(string(v)) iv := ma.ec.lookupInvokable(string(v))
if iv == nil { if iv == nil {
return nil, errors.New("'" + string(v) + "' is not invokable") return nil, errors.New("'" + string(v) + "' is not invokable")
@ -360,7 +362,7 @@ func (ia invocationArgs) invokableArg(i int) (invokable, error) {
switch v := ia.args[i].(type) { switch v := ia.args[i].(type) {
case invokable: case invokable:
return v, nil return v, nil
case strObject: case StringObject:
iv := ia.ec.lookupInvokable(string(v)) iv := ia.ec.lookupInvokable(string(v))
if iv == nil { if iv == nil {
return nil, errors.New("'" + string(v) + "' is not invokable") return nil, errors.New("'" + string(v) + "' is not invokable")
@ -413,7 +415,8 @@ func (i invokableFunc) invoke(ctx context.Context, args invocationArgs) (Object,
} }
type blockObject struct { type blockObject struct {
block *astBlock block *astBlock
closedEC *evalCtx
} }
func (bo blockObject) String() string { func (bo blockObject) String() string {
@ -425,7 +428,7 @@ func (bo blockObject) Truthy() bool {
} }
func (bo blockObject) invoke(ctx context.Context, args invocationArgs) (Object, error) { func (bo blockObject) invoke(ctx context.Context, args invocationArgs) (Object, error) {
ec := args.ec.fork() ec := bo.closedEC.fork()
for i, n := range bo.block.Names { for i, n := range bo.block.Names {
if i < len(args.args) { if i < len(args.args) {
ec.setOrDefineVar(n, args.args[i]) ec.setOrDefineVar(n, args.args[i])

View file

@ -19,12 +19,12 @@ func WithTestBuiltin() InstOption {
})) }))
i.rootEC.addCmd("toUpper", invokableFunc(func(ctx context.Context, args invocationArgs) (Object, error) { i.rootEC.addCmd("toUpper", invokableFunc(func(ctx context.Context, args invocationArgs) (Object, error) {
return strObject(strings.ToUpper(args.args[0].String())), nil return StringObject(strings.ToUpper(args.args[0].String())), nil
})) }))
i.rootEC.addCmd("sjoin", invokableFunc(func(ctx context.Context, args invocationArgs) (Object, error) { i.rootEC.addCmd("sjoin", invokableFunc(func(ctx context.Context, args invocationArgs) (Object, error) {
if len(args.args) == 0 { if len(args.args) == 0 {
return strObject(""), nil return StringObject(""), nil
} }
var line strings.Builder var line strings.Builder
@ -34,7 +34,7 @@ func WithTestBuiltin() InstOption {
} }
} }
return strObject(line.String()), nil return StringObject(line.String()), nil
})) }))
i.rootEC.addCmd("list", invokableFunc(func(ctx context.Context, args invocationArgs) (Object, error) { i.rootEC.addCmd("list", invokableFunc(func(ctx context.Context, args invocationArgs) (Object, error) {
@ -53,7 +53,7 @@ func WithTestBuiltin() InstOption {
lst, ok := args.args[0].(Listable) lst, ok := args.args[0].(Listable)
if !ok { if !ok {
return strObject(""), nil return StringObject(""), nil
} }
l := lst.Len() l := lst.Len()
@ -63,11 +63,11 @@ func WithTestBuiltin() InstOption {
} }
sb.WriteString(lst.Index(x).String()) sb.WriteString(lst.Index(x).String())
} }
return strObject(sb.String()), nil return StringObject(sb.String()), nil
})) }))
i.rootEC.setOrDefineVar("a", strObject("alpha")) i.rootEC.setOrDefineVar("a", StringObject("alpha"))
i.rootEC.setOrDefineVar("bee", strObject("buzz")) i.rootEC.setOrDefineVar("bee", StringObject("buzz"))
} }
} }
@ -704,6 +704,78 @@ func TestBuiltins_Return(t *testing.T) {
four4 four4
`, want: "xxxx\n"}, `, want: "xxxx\n"},
{desc: "check closure 1", expr: `
proc do-thing { |p|
call $p
}
proc test-thing {
foreach [1 2 3] { |x|
do-thing {
echo $x
}
}
}
test-thing
`, want: "1\n2\n3\n(nil)\n"},
{desc: "check closure 2", expr: `
proc do-thing { |p|
call $p
}
proc test-thing {
foreach [1 2 3] { |x|
do-thing (proc {
echo $x
})
}
}
test-thing
`, want: "1\n2\n3\n(nil)\n"},
{desc: "check closure 3", expr: `
proc do-thing { |p|
call $p
}
proc test-thing {
foreach [1 2 3] { |x|
set myClosure (proc { echo $x })
do-thing $myClosure
}
}
test-thing
`, want: "1\n2\n3\n(nil)\n"},
{desc: "check closure 4", expr: `
proc do-thing { |p|
call $p
}
proc test-thing {
[1 2 3] | map { |x|
proc { echo $x }
}
}
foreach (test-thing) { |y| call $y }
`, want: "1\n2\n3\n(nil)\n"},
{desc: "check closure 5", expr: `
proc do-thing { |p|
call $p
}
proc test-thing {
[1 2 3] | map { |x|
set myProc (proc { echo $x })
proc { do-thing $myProc }
}
}
set hello "xx"
foreach (test-thing) { |y| call $y ; echo $hello }
`, want: "1\nxx\n2\nxx\n3\nxx\n(nil)\n"},
} }
for _, tt := range tests { for _, tt := range tests {