From d800d7a5757755bbe3a41fa98cbbec0828976cb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rok=20Su=C5=A1nik?= Date: Sun, 17 Mar 2024 20:27:48 +0100 Subject: [PATCH] implement spredsheet group-by --- env/spreadsheet.go | 2 +- evaldo/builtins_spreadsheet.go | 151 +++++++++++++++++++++++++++++++-- tests/structures.rye | 17 ++++ 3 files changed, 161 insertions(+), 9 deletions(-) diff --git a/env/spreadsheet.go b/env/spreadsheet.go index b5fd2558..365534d1 100644 --- a/env/spreadsheet.go +++ b/env/spreadsheet.go @@ -222,7 +222,7 @@ func (s Spreadsheet) GetRowValue(column string, rrow SpreadsheetRow) (any, error } } if index < 0 { - return "", nil + return "", fmt.Errorf("column %s not found", column) } return rrow.Values[index], nil } diff --git a/evaldo/builtins_spreadsheet.go b/evaldo/builtins_spreadsheet.go index cf657037..12b95315 100644 --- a/evaldo/builtins_spreadsheet.go +++ b/evaldo/builtins_spreadsheet.go @@ -713,6 +713,52 @@ var Builtins_spreadsheet = map[string]*env.Builtin{ } }, }, + "group-by": { + Argsn: 3, + Doc: "Groups a spreadsheet by the given column and (optional) aggregations.", + Fn: func(ps *env.ProgramState, arg0 env.Object, arg1 env.Object, arg2 env.Object, arg3 env.Object, arg4 env.Object) (res env.Object) { + switch spr := arg0.(type) { + case env.Spreadsheet: + switch aggBlock := arg2.(type) { + case env.Block: + if len(aggBlock.Series.S)%2 != 0 { + return MakeBuiltinError(ps, "Aggregation block must contain pairs of column name and function for each aggregation.", "group-by") + } + aggregations := make(map[string][]string) + for i := 0; i < len(aggBlock.Series.S); i += 2 { + col := aggBlock.Series.S[i] + fun, ok := aggBlock.Series.S[i+1].(env.Word) + if !ok { + return MakeBuiltinError(ps, "Aggregation function must be a word", "group-by") + } + colStr := "" + switch col := col.(type) { + case env.Tagword: + colStr = ps.Idx.GetWord(col.Index) + case env.String: + colStr = col.Value + default: + return MakeBuiltinError(ps, "Aggregation column must be a word or string", "group-by") + } + funStr := ps.Idx.GetWord(fun.Index) + aggregations[colStr] = append(aggregations[colStr], funStr) + } + switch col := arg1.(type) { + case env.Word: + return GroupBy(ps, spr, ps.Idx.GetWord(col.Index), aggregations) + case env.String: + return GroupBy(ps, spr, col.Value, aggregations) + default: + return MakeArgError(ps, 2, []env.Type{env.WordType, env.StringType}, "group-by") + } + default: + return MakeArgError(ps, 3, []env.Type{env.BlockType}, "group-by") + } + default: + return MakeArgError(ps, 1, []env.Type{env.SpreadsheetType}, "group-by") + } + }, + }, } func GenerateColumn(ps *env.ProgramState, s env.Spreadsheet, name env.Word, extractCols env.Block, code env.Block) env.Object { @@ -759,7 +805,7 @@ func GenerateColumnRegexReplace(ps *env.ProgramState, s *env.Spreadsheet, name e // get value from current row val, err := s.GetRowValue(ps.Idx.GetWord(fromColName.Index), row) if err != nil { - return MakeError(ps, "Couldn't retrieve value at row "+strconv.Itoa(ix)) + return MakeError(ps, fmt.Sprintf("Couldn't retrieve value at row %d (%s)", ix, err)) } var newVal any @@ -1041,10 +1087,10 @@ func LeftJoin(ps *env.ProgramState, s1 env.Spreadsheet, s2 env.Spreadsheet, col1 } } nspr := env.NewSpreadsheet(combinedCols) - for _, row1 := range s1.GetRows() { + for i, row1 := range s1.GetRows() { val1, err := s1.GetRowValue(col1, row1) if err != nil { - return MakeError(ps, "Couldn't retrieve value at row") + return MakeError(ps, fmt.Sprintf("Couldn't retrieve value at row %d (%s)", i, err)) } newRow := make([]any, len(combinedCols)) @@ -1057,13 +1103,13 @@ func LeftJoin(ps *env.ProgramState, s1 env.Spreadsheet, s2 env.Spreadsheet, col1 s2RowId = rowIds[0] } } else { - for i, row2 := range s2.GetRows() { + for j, row2 := range s2.GetRows() { val2, err := s2.GetRowValue(col2, row2) if err != nil { - return MakeError(ps, "Couldn't retrieve value at row") + return MakeError(ps, fmt.Sprintf("Couldn't retrieve value at row %d (%s)", j, err)) } if val1.(env.Object).Equal(val2.(env.Object)) { - s2RowId = i + s2RowId = j break } } @@ -1077,11 +1123,100 @@ func LeftJoin(ps *env.ProgramState, s1 env.Spreadsheet, s2 env.Spreadsheet, col1 newRow[i+len(s1.Cols)] = v } } else { - for i := range s2.Cols { - newRow[i+len(s1.Cols)] = env.Void{} + for k := range s2.Cols { + newRow[k+len(s1.Cols)] = env.Void{} } } nspr.AddRow(*env.NewSpreadsheetRow(newRow, nspr)) } return *nspr } + +func GroupBy(ps *env.ProgramState, s env.Spreadsheet, col string, aggregations map[string][]string) env.Object { + if !slices.Contains(s.Cols, col) { + return MakeBuiltinError(ps, "Column not found.", "group-by") + } + + aggregatesByGroup := make(map[string]map[string]float64) + countByGroup := make(map[string]int) + for i, row := range s.Rows { + groupingVal, err := s.GetRowValue(col, row) + if err != nil { + return MakeError(ps, fmt.Sprintf("Couldn't retrieve value at row %d (%s)", i, err)) + } + groupValStr, ok := groupingVal.(env.String) + if !ok { + return MakeBuiltinError(ps, "Grouping column value must be a string", "group-by") + } + + if _, ok := aggregatesByGroup[groupValStr.Value]; !ok { + aggregatesByGroup[groupValStr.Value] = make(map[string]float64) + } + groupAggregates := aggregatesByGroup[groupValStr.Value] + + for aggCol, funs := range aggregations { + for _, fun := range funs { + colAgg := aggCol + "_" + fun + if fun == "count" { + if aggCol != col { + return MakeBuiltinError(ps, "Count aggregation can only be applied on the grouping column", "group-by") + } + groupAggregates[colAgg]++ + continue + } + valObj, err := s.GetRowValue(aggCol, row) + if err != nil { + return MakeError(ps, fmt.Sprintf("Couldn't retrieve value at row %d (%s)", i, err)) + } + var val float64 + switch valObj := env.ToRyeValue(valObj).(type) { + case env.Integer: + val = float64(valObj.Value) + case env.Decimal: + val = valObj.Value + default: + return MakeBuiltinError(ps, "Aggregation column value must be a number", "group-by") + } + switch fun { + case "sum": + groupAggregates[colAgg] += val + case "avg": + groupAggregates[colAgg] += val + countByGroup[groupValStr.Value]++ + case "min": + if min, ok := groupAggregates[colAgg]; !ok || val < min { + groupAggregates[colAgg] = val + } + case "max": + if max, ok := groupAggregates[colAgg]; !ok || val > max { + groupAggregates[colAgg] = val + } + default: + return MakeBuiltinError(ps, fmt.Sprintf("Unknown aggregation function: %s", fun), "group-by") + } + } + } + } + newCols := []string{col} + for aggCol, funs := range aggregations { + for _, fun := range funs { + newCols = append(newCols, aggCol+"_"+fun) + } + } + newS := env.NewSpreadsheet(newCols) + for groupVal, groupAggregates := range aggregatesByGroup { + newRow := make([]any, len(newCols)) + newRow[0] = *env.NewString(groupVal) + for i, col := range newCols[1:] { + if strings.HasSuffix(col, "_count") { + newRow[i+1] = *env.NewInteger(int64(groupAggregates[col])) + } else if strings.HasSuffix(col, "_avg") { + newRow[i+1] = *env.NewDecimal(groupAggregates[col] / float64(countByGroup[groupVal])) + } else { + newRow[i+1] = *env.NewDecimal(groupAggregates[col]) + } + } + newS.AddRow(*env.NewSpreadsheetRow(newRow, newS)) + } + return *newS +} diff --git a/tests/structures.rye b/tests/structures.rye index fab6707a..419181a5 100644 --- a/tests/structures.rye +++ b/tests/structures.rye @@ -653,6 +653,23 @@ section "Spreadsheet related functions" names .inner-join houses 'id 'id } spreadsheet { "id" "name" "id_2" "house" } { 1 "Paul" 1 "Atreides" 3 "Vladimir" 3 "Harkonnen" } } + + group "group by" + mold\nowrap ?group-by + { { block } } + { + equal { spreadsheet { "name" "val" } { "a" 1 "b" 2 } |group-by 'name { } |sort-col! 'name + } spreadsheet { "name" } { "a" "b" } + + equal { spreadsheet { "name" "val" } { "a" 1 "b" 6 "a" 5 "b" 10 "a" 7 } + |group-by 'name { 'name count 'val sum 'val min 'val max 'val avg } + |sort-col! 'name + } spreadsheet { "name" "name_count" "val_sum" "val_min" "val_max" "val_avg" } + { + "a" 3 13.0 1.0 7.0 4.333333333333333 + "b" 2 16.0 6.0 10.0 8.0 + } + } }