From 348251c1cfd7e69764ede8cc21a8d2531ae3a2f2 Mon Sep 17 00:00:00 2001 From: Leon Mika Date: Sun, 29 Jan 2023 09:12:39 +1100 Subject: [PATCH] Finished the mapping between dynamo attribute values and tamarin values --- internal/common/maputils/map.go | 23 +++ internal/common/sliceutils/map.go | 9 ++ .../services/scriptmanager/resultsetproxy.go | 54 ++++--- .../scriptmanager/resultsetproxy_test.go | 96 +++++++++++++ .../services/scriptmanager/typemapping.go | 135 ++++++++++++++++++ 5 files changed, 294 insertions(+), 23 deletions(-) create mode 100644 internal/common/maputils/map.go create mode 100644 internal/dynamo-browse/services/scriptmanager/typemapping.go diff --git a/internal/common/maputils/map.go b/internal/common/maputils/map.go new file mode 100644 index 0000000..bffa7a1 --- /dev/null +++ b/internal/common/maputils/map.go @@ -0,0 +1,23 @@ +package maputils + +func Values[K comparable, T any](ts map[K]T) []T { + values := make([]T, 0, len(ts)) + for _, v := range ts { + values = append(values, v) + } + return values +} + +func MapValuesWithError[K comparable, T, U any](ts map[K]T, fn func(t T) (U, error)) (map[K]U, error) { + us := make(map[K]U) + + for k, t := range ts { + var err error + us[k], err = fn(t) + if err != nil { + return nil, err + } + } + + return us, nil +} diff --git a/internal/common/sliceutils/map.go b/internal/common/sliceutils/map.go index 0d233b3..43b69e7 100644 --- a/internal/common/sliceutils/map.go +++ b/internal/common/sliceutils/map.go @@ -1,5 +1,14 @@ package sliceutils +func All[T any](ts []T, predicate func(t T) bool) bool { + for _, t := range ts { + if !predicate(t) { + return false + } + } + return true +} + func Map[T, U any](ts []T, fn func(t T) U) []U { us := make([]U, len(ts)) for i, t := range ts { diff --git a/internal/dynamo-browse/services/scriptmanager/resultsetproxy.go b/internal/dynamo-browse/services/scriptmanager/resultsetproxy.go index a8a49a4..d00f7e4 100644 --- a/internal/dynamo-browse/services/scriptmanager/resultsetproxy.go +++ b/internal/dynamo-browse/services/scriptmanager/resultsetproxy.go @@ -2,13 +2,11 @@ package scriptmanager import ( "context" - "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/cloudcmds/tamarin/arg" "github.com/cloudcmds/tamarin/object" "github.com/lmika/audax/internal/dynamo-browse/models" "github.com/lmika/audax/internal/dynamo-browse/models/queryexpr" "github.com/pkg/errors" - "strconv" ) type resultSetProxy struct { @@ -172,19 +170,32 @@ func (i *itemProxy) value(ctx context.Context, args ...object.Object) object.Obj return object.NewError(errors.Errorf("arg error: path expression evaluate error: %v", err)) } - // TODO - switch v := av.(type) { - case *types.AttributeValueMemberS: - return object.NewString(v.Value) - case *types.AttributeValueMemberN: - // TODO: better - f, err := strconv.ParseFloat(v.Value, 64) - if err != nil { - return object.NewError(errors.Errorf("value error: invalid N value: %v", v.Value)) - } - return object.NewFloat(f) + tVal, err := attributeValueToTamarin(av) + if err != nil { + return object.NewError(err) } - return object.NewError(errors.New("TODO")) + return tVal + + // TODO + //switch v := av.(type) { + //case *types.AttributeValueMemberS: + // return object.NewString(v.Value) + //case *types.AttributeValueMemberN: + // // TODO: better + // f, err := strconv.ParseFloat(v.Value, 64) + // if err != nil { + // return object.NewError(errors.Errorf("value error: invalid N value: %v", v.Value)) + // } + // return object.NewFloat(f) + //case *types.AttributeValueMemberBOOL: + // if v.Value { + // return object.True + // } + // return object.False + //case *types.AttributeValueMemberNULL: + // return object.Nil + //} + //return object.NewError(errors.New("TODO")) } func (i *itemProxy) setValue(ctx context.Context, args ...object.Object) object.Object { @@ -202,15 +213,12 @@ func (i *itemProxy) setValue(ctx context.Context, args ...object.Object) object. return object.Errorf("arg error: invalid path expression: %v", err) } - // TODO - newValue := args[1] - switch v := newValue.(type) { - case *object.String: - if err := path.SetEvalItem(i.item, &types.AttributeValueMemberS{Value: v.Value()}); err != nil { - return object.NewError(err) - } - default: - return object.Errorf("type error: unsupported value type (got %v)", newValue.Type()) + newValue, err := tamarinValueToAttributeValue(args[1]) + if err != nil { + return object.NewError(err) + } + if err := path.SetEvalItem(i.item, newValue); err != nil { + return object.NewError(err) } i.resultSetProxy.resultSet.SetDirty(i.itemIndex, true) diff --git a/internal/dynamo-browse/services/scriptmanager/resultsetproxy_test.go b/internal/dynamo-browse/services/scriptmanager/resultsetproxy_test.go index 63f6502..2c04967 100644 --- a/internal/dynamo-browse/services/scriptmanager/resultsetproxy_test.go +++ b/internal/dynamo-browse/services/scriptmanager/resultsetproxy_test.go @@ -54,6 +54,58 @@ func TestResultSetProxy(t *testing.T) { }) } +func TestResultSetProxy_GetAttr(t *testing.T) { + t.Run("should return the value of items within a result set", func(t *testing.T) { + rs := &models.ResultSet{} + rs.SetItems([]models.Item{ + { + "pk": &types.AttributeValueMemberS{Value: "abc"}, + "sk": &types.AttributeValueMemberN{Value: "123"}, + "bool": &types.AttributeValueMemberBOOL{Value: true}, + "null": &types.AttributeValueMemberNULL{Value: true}, + "list": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: "apple"}, + &types.AttributeValueMemberS{Value: "banana"}, + &types.AttributeValueMemberS{Value: "cherry"}, + }}, + "map": &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "this": &types.AttributeValueMemberS{Value: "that"}, + "another": &types.AttributeValueMemberS{Value: "thing"}, + }}, + "strSet": &types.AttributeValueMemberSS{Value: []string{"apple", "banana", "cherry"}}, + "numSet": &types.AttributeValueMemberNS{Value: []string{"123", "45.67", "8.911", "-321"}}, + }, + }) + + mockedSessionService := mocks.NewSessionService(t) + mockedSessionService.EXPECT().Query(mock.Anything, "some expr", scriptmanager.QueryOptions{}).Return(rs, nil) + + testFS := testScriptFile(t, "test.tm", ` + res := session.query("some expr").unwrap() + + assert(res[0].attr("pk") == "abc", "str attr") + assert(res[0].attr("sk") == 123, "num attr") + assert(res[0].attr("bool") == true, "bool attr") + assert(res[0].attr("null") == nil, "null attr") + assert(res[0].attr("list") == ["apple","banana","cherry"], "list attr") + assert(res[0].attr("map") == {"this":"that", "another":"thing"}, "map attr") + assert(res[0].attr("strSet") == {"apple","banana","cherry"}, "string set") + assert(res[0].attr("numSet") == {123, 45.67, 8.911, -321}, "number set") + `) + + srv := scriptmanager.New(scriptmanager.WithFS(testFS)) + srv.SetIFaces(scriptmanager.Ifaces{ + Session: mockedSessionService, + }) + + ctx := context.Background() + err := <-srv.RunAdHocScript(ctx, "test.tm") + assert.NoError(t, err) + + mockedSessionService.AssertExpectations(t) + }) +} + func TestResultSetProxy_SetAttr(t *testing.T) { t.Run("should set the value of the item within a result set", func(t *testing.T) { rs := &models.ResultSet{} @@ -66,6 +118,39 @@ func TestResultSetProxy_SetAttr(t *testing.T) { mockedSessionService.EXPECT().Query(mock.Anything, "some expr", scriptmanager.QueryOptions{}).Return(rs, nil) mockedSessionService.EXPECT().SetResultSet(mock.Anything, mock.MatchedBy(func(rs *models.ResultSet) bool { assert.Equal(t, "bla-di-bla", rs.Items()[0]["pk"].(*types.AttributeValueMemberS).Value) + assert.Equal(t, "123", rs.Items()[0]["num"].(*types.AttributeValueMemberN).Value) + assert.Equal(t, "123.45", rs.Items()[0]["numFloat"].(*types.AttributeValueMemberN).Value) + assert.Equal(t, true, rs.Items()[0]["bool"].(*types.AttributeValueMemberBOOL).Value) + assert.Equal(t, true, rs.Items()[0]["nil"].(*types.AttributeValueMemberNULL).Value) + + list := rs.Items()[0]["lists"].(*types.AttributeValueMemberL).Value + assert.Equal(t, "abc", list[0].(*types.AttributeValueMemberS).Value) + assert.Equal(t, "123", list[1].(*types.AttributeValueMemberN).Value) + assert.Equal(t, true, list[2].(*types.AttributeValueMemberBOOL).Value) + + nestedLists := rs.Items()[0]["nestedLists"].(*types.AttributeValueMemberL).Value + assert.Equal(t, "1", nestedLists[0].(*types.AttributeValueMemberL).Value[0].(*types.AttributeValueMemberN).Value) + assert.Equal(t, "2", nestedLists[0].(*types.AttributeValueMemberL).Value[1].(*types.AttributeValueMemberN).Value) + assert.Equal(t, "3", nestedLists[1].(*types.AttributeValueMemberL).Value[0].(*types.AttributeValueMemberN).Value) + assert.Equal(t, "4", nestedLists[1].(*types.AttributeValueMemberL).Value[1].(*types.AttributeValueMemberN).Value) + + mapValue := rs.Items()[0]["map"].(*types.AttributeValueMemberM).Value + assert.Equal(t, "world", mapValue["hello"].(*types.AttributeValueMemberS).Value) + assert.Equal(t, "213", mapValue["nums"].(*types.AttributeValueMemberN).Value) + + numSet := rs.Items()[0]["numSet"].(*types.AttributeValueMemberNS).Value + assert.Len(t, numSet, 4) + assert.Contains(t, numSet, "1") + assert.Contains(t, numSet, "2") + assert.Contains(t, numSet, "3") + assert.Contains(t, numSet, "4.5") + + strSet := rs.Items()[0]["strSet"].(*types.AttributeValueMemberSS).Value + assert.Len(t, strSet, 3) + assert.Contains(t, strSet, "a") + assert.Contains(t, strSet, "b") + assert.Contains(t, strSet, "c") + assert.True(t, rs.IsDirty(0)) return true })) @@ -74,7 +159,18 @@ func TestResultSetProxy_SetAttr(t *testing.T) { testFS := testScriptFile(t, "test.tm", ` res := session.query("some expr").unwrap() + res[0].set_attr("pk", "bla-di-bla") + res[0].set_attr("num", 123) + res[0].set_attr("numFloat", 123.45) + res[0].set_attr("bool", true) + res[0].set_attr("nil", nil) + res[0].set_attr("lists", ['abc', 123, true]) + res[0].set_attr("nestedLists", [[1,2], [3,4]]) + res[0].set_attr("map", {"hello": "world", "nums": 213}) + res[0].set_attr("numSet", {1,2,3,4.5}) + res[0].set_attr("strSet", {"a","b","c"}) + session.set_result_set(res) `) diff --git a/internal/dynamo-browse/services/scriptmanager/typemapping.go b/internal/dynamo-browse/services/scriptmanager/typemapping.go new file mode 100644 index 0000000..6cf1583 --- /dev/null +++ b/internal/dynamo-browse/services/scriptmanager/typemapping.go @@ -0,0 +1,135 @@ +package scriptmanager + +import ( + "fmt" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/cloudcmds/tamarin/object" + "github.com/lmika/audax/internal/common/maputils" + "github.com/lmika/audax/internal/common/sliceutils" + "github.com/pkg/errors" + "regexp" + "strconv" +) + +func tamarinValueToAttributeValue(val object.Object) (types.AttributeValue, error) { + switch v := val.(type) { + case *object.String: + return &types.AttributeValueMemberS{Value: v.Value()}, nil + case *object.Int: + return &types.AttributeValueMemberN{Value: strconv.FormatInt(v.Value(), 10)}, nil + case *object.Float: + return &types.AttributeValueMemberN{Value: strconv.FormatFloat(v.Value(), 'f', -1, 64)}, nil + case *object.Bool: + return &types.AttributeValueMemberBOOL{Value: v.Value()}, nil + case *object.NilType: + return &types.AttributeValueMemberNULL{Value: true}, nil + case *object.List: + attrValue, err := sliceutils.MapWithError(v.Value(), tamarinValueToAttributeValue) + if err != nil { + return nil, err + } + return &types.AttributeValueMemberL{Value: attrValue}, nil + case *object.Map: + attrValue, err := maputils.MapValuesWithError(v.Value(), tamarinValueToAttributeValue) + if err != nil { + return nil, err + } + return &types.AttributeValueMemberM{Value: attrValue}, nil + case *object.Set: + values := maputils.Values(v.Value()) + canBeNumSet := sliceutils.All(values, func(t object.Object) bool { + _, isInt := t.(*object.Int) + _, isFloat := t.(*object.Float) + return isInt || isFloat + }) + + if canBeNumSet { + return &types.AttributeValueMemberNS{ + Value: sliceutils.Map(values, func(t object.Object) string { + switch v := t.(type) { + case *object.Int: + return strconv.FormatInt(v.Value(), 10) + case *object.Float: + return strconv.FormatFloat(v.Value(), 'f', -1, 64) + } + panic(fmt.Sprintf("unhandled object type: %v", t.Type())) + }), + }, nil + } + return &types.AttributeValueMemberSS{ + Value: sliceutils.Map(values, func(t object.Object) string { + v, _ := object.AsString(t) + return v + }), + }, nil + } + return nil, errors.Errorf("type error: unsupported value type (got %v)", val.Type()) +} + +func attributeValueToTamarin(val types.AttributeValue) (object.Object, error) { + if val == nil { + return object.Nil, nil + } + + switch v := val.(type) { + case *types.AttributeValueMemberS: + return object.NewString(v.Value), nil + case *types.AttributeValueMemberN: + f, err := convertNumAttributeToTamarinValue(v.Value) + if err != nil { + return nil, errors.Errorf("value error: invalid N value: %v", v.Value) + } + return f, nil + case *types.AttributeValueMemberBOOL: + if v.Value { + return object.True, nil + } + return object.False, nil + case *types.AttributeValueMemberNULL: + return object.Nil, nil + case *types.AttributeValueMemberL: + list, err := sliceutils.MapWithError(v.Value, attributeValueToTamarin) + if err != nil { + return nil, err + } + return object.NewList(list), nil + case *types.AttributeValueMemberM: + objMap, err := maputils.MapValuesWithError(v.Value, attributeValueToTamarin) + if err != nil { + return nil, err + } + return object.NewMap(objMap), nil + case *types.AttributeValueMemberSS: + return object.NewSet(sliceutils.Map(v.Value, func(s string) object.Object { + return object.NewString(s) + })), nil + case *types.AttributeValueMemberNS: + nums, err := sliceutils.MapWithError(v.Value, func(s string) (object.Object, error) { + return convertNumAttributeToTamarinValue(s) + }) + if err != nil { + return nil, err + } + return object.NewSet(nums), nil + } + return nil, errors.Errorf("value error: cannot convert type %T to tamarin object", val) +} + +var intNumberPattern = regexp.MustCompile(`^[-]?[0-9]+$`) + +// XXX - this is pretty crappy in that it does not support large values +func convertNumAttributeToTamarinValue(n string) (object.Object, error) { + if intNumberPattern.MatchString(n) { + parsedInt, err := strconv.ParseInt(n, 10, 64) + if err != nil { + return nil, err + } + return object.NewInt(parsedInt), nil + } + + f, err := strconv.ParseFloat(n, 64) + if err != nil { + return nil, err + } + return object.NewFloat(f), nil +}