Skip to content

Commit

Permalink
Merge pull request #6 from Adarsh-jaiss/main
Browse files Browse the repository at this point in the history
Added Snowflake support for the library
  • Loading branch information
tqindia authored Apr 26, 2024
2 parents 3e63b2c + 6083c70 commit 0e03a9e
Show file tree
Hide file tree
Showing 9 changed files with 557 additions and 7 deletions.
13 changes: 13 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/thesaas-company/xray/config"
"github.com/thesaas-company/xray/databases/mysql"
"github.com/thesaas-company/xray/databases/postgres"
"github.com/thesaas-company/xray/databases/snowflake"
"github.com/thesaas-company/xray/logger"
"github.com/thesaas-company/xray/types"
)
Expand All @@ -27,6 +28,12 @@ func NewClientWithConfig(dbConfig *config.Config, dbType types.DbType) (types.IS
return nil, err
}
return logger.NewLogger(sqlClient), nil
case types.Snowflake:
sqlClient, err := snowflake.NewSnowflakeWithConfig(dbConfig)
if err != nil {
return nil, err
}
return logger.NewLogger(sqlClient), nil
default:
return nil, fmt.Errorf("unsupported database type: %s", dbType)
}
Expand All @@ -49,6 +56,12 @@ func NewClient(dbClient *sql.DB, dbType types.DbType) (types.ISQL, error) {
return nil, err
}
return logger.NewLogger(sqlClient), nil
case types.Snowflake:
sqlClient, err := snowflake.NewSnowflake(dbClient)
if err != nil {
return nil, err
}
return logger.NewLogger(sqlClient), nil
default:
return nil, fmt.Errorf("unsupported database type: %s", dbType)
}
Expand Down
2 changes: 1 addition & 1 deletion databases/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestSchema(t *testing.T) {
// we then create a new instance of our MySQL object and test the function
m, err := NewMySQL(db)
if err != nil {
t.Errorf("error executing query: %s", err)
t.Errorf("error initialising mysql: %s", err)
}
response, err := m.Schema(tableName)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion databases/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestSchema(t *testing.T) {

m, err := NewPostgres(db)
if err != nil {
t.Errorf("error executing query: %s", err)
t.Errorf("error initialising postgres: %s", err)
}
response, err := m.Schema(table_name)
if err != nil {
Expand Down
221 changes: 221 additions & 0 deletions databases/snowflake/snowflake.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
package snowflake

import (
"database/sql"
"encoding/json"
"fmt"
"os"

sf "github.com/snowflakedb/gosnowflake"
"github.com/thesaas-company/xray/config"
"github.com/thesaas-company/xray/types"
)

type Snowflake struct {
Client *sql.DB
Config *config.Config
}

var DB_PASSWORD string = "root"

const (
SNOWFLAKE_TABLES_LIST_QUERY = "SHOW TERSE TABLES"
SNOWFLAKE_SCHEMA_QUERY = `
SELECT
COLUMN_NAME,
DATA_TYPE,
IS_NULLABLE,
COLUMN_DEFAULT,
IS_UPDATABLE,
IS_IDENTITY,
IS_GENERATED,
IS_UNIQUE,
IS_SYSTEM_COLUMN,
IS_HIDDEN,
IS_READ_ONLY,
IS_COMPUTED,
IS_SPARSE,
IS_COLUMN_SET,
IS_SELF_REFERENCING,
SCOPE_NAME,
SCOPE_SCHEMA,
ORDINAL_POSITION
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME = ?)
`
)

// The NewSnowflake function is responsible for creating a new Snowflake object with an initialized database client and configuration.
func NewSnowflake(dbClient *sql.DB) (types.ISQL, error) {
return &Snowflake{
Client: dbClient,
Config: &config.Config{},
}, nil
}

// The NewSnowflakeWithConfig function is responsible for creating a new Snowflake object with an initialized database client and configuration.
func NewSnowflakeWithConfig(config *config.Config) (types.ISQL, error) {
if os.Getenv(DB_PASSWORD) == "" || len(os.Getenv(DB_PASSWORD)) == 0 {
return nil, fmt.Errorf("please set %s env variable for the database", DB_PASSWORD)
}
DB_PASSWORD = os.Getenv(DB_PASSWORD)


dsn, err := sf.DSN(&sf.Config{
Account: config.Account,
User: config.Username,
Password: DB_PASSWORD,
Database: config.DatabaseName,
Warehouse: config.Warehouse,
})
if err != nil {
return nil, fmt.Errorf("error creating snowflake DSN: %v", err)
}

dbType := types.Snowflake
db, err := sql.Open(dbType.String(), dsn)
if err != nil {
return nil, fmt.Errorf("error opening connection to snowflake database: %v", err)
}

return &Snowflake{
Client: db,
Config: config,
}, nil

}

// The Schema function returns the schema of a table in Snowflake.
func (s *Snowflake) Schema(table string) (types.Table, error) {
var res types.Table

rows, err := s.Client.Query(SNOWFLAKE_SCHEMA_QUERY, table)
if err != nil {
return res, fmt.Errorf("error executing sql statement: %v", err)
}
defer rows.Close()

var columns []types.Column
for rows.Next() {
var column types.Column
if err := rows.Scan(
&column.Name,
&column.Type,
&column.IsNullable,
&column.ColumnDefault,
&column.IsUpdatable,
&column.IsIdentity,
&column.IsGenerated,
&column.IsUnique,
&column.IsSystemColumn,
&column.IsHidden,
&column.IsReadOnly,
&column.IsComputed,
&column.IsSparse,
&column.IsColumnSet,
&column.IsSelfReferencing,
&column.ScopeName,
&column.ScopeSchema,
&column.OrdinalPosition); err != nil {
return res, fmt.Errorf("error scanning rows: %v", err)
}
column.Description = "" // default description
column.Metatags = []string{} // default metatags as an empty string slice
column.Metatags = append(column.Metatags, column.Name)
column.Visibility = true // default visibility
columns = append(columns, column)
}

// checking for erros from iterating over the rows
if err := rows.Err(); err != nil {
return res, fmt.Errorf("error iterating over rows: %v", err)
}

return types.Table{
Name: table,
Columns: columns,
ColumnCount: int64(len(columns)),
Description: "",
Metatags: []string{},
}, nil
}

// Every table in Snowflake lives "inside" a schema. Every schema lives "inside" a database. It's a hierarchical system.
// The Tables function returns a list of tables in a Snowflake database.
func (s *Snowflake) Tables(DatabaseName string) ([]string, error) {
query := fmt.Sprintf("USE WAREHOUSE %s", s.Config.Warehouse)
_, err := s.Client.Query(query)
if err != nil {
return nil, fmt.Errorf("error executing sql statement: %v", err)
}

rows,err := s.Client.Query(SNOWFLAKE_TABLES_LIST_QUERY)
if err != nil {
return nil, fmt.Errorf("error executing sql statement and querying tables list: %v", err)
}
defer rows.Close()

var tables []string
for rows.Next() {
var table string
if err := rows.Scan(&table); err != nil {
return nil, fmt.Errorf("error scanning database: %v", err)
}
tables = append(tables, table)
}

// checking for errors in iterating over rows
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating over rows:%v", err)
}

return tables, nil
}

// The Execute function executes a query on a Snowflake database and returns the result as a JSON byte slice.
func (s *Snowflake) Execute(query string) ([]byte, error) {
rows, err := s.Client.Query(query)
if err != nil {
return nil, fmt.Errorf("error executing sql statement: %v", err)
}
defer rows.Close()

columns, err := rows.Columns()
if err != nil {
return nil, fmt.Errorf("error getting columns: %v", err)
}

// Scan the result into a slice of slices
var results [][]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
pointers := make([]interface{}, len(columns))
for i := range values {
pointers[i] = &values[i]
}

if err := rows.Scan(pointers...); err != nil {
return nil, fmt.Errorf("error scanning row: %v", err)
}

results = append(results, values)
}

// Check for errors from iterating over rows
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating rows: %v", err)
}

// Convert the result to JSON
queryResult := types.QueryResult{
Columns: columns,
Rows: results,
}

jsonData, err := json.Marshal(queryResult)
if err != nil {
return nil, fmt.Errorf("error marshaling json: %v", err)
}

return jsonData, nil
}
120 changes: 120 additions & 0 deletions databases/snowflake/snowflake_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package snowflake

import (
"database/sql"
"encoding/json"
"fmt"
"reflect"
"regexp"
"testing"

"github.com/DATA-DOG/go-sqlmock"
"github.com/thesaas-company/xray/types"
)

func MockDB() (*sql.DB, sqlmock.Sqlmock) {
db, mock, err := sqlmock.New()
if err != nil {
panic("An error occured while creating a new mock database connection")
}

return db, mock
}

func TestSchema(t *testing.T) {
db, mock := MockDB()
defer db.Close()

table_name := "user"

columns := []string{"name", "type", "IsNullable", "DefaultValue", "IsUpdatable", "IsIdenity", "IsGenerated", "IsUnique", "IsSystemColumn", "IsHidden", "IsReadOnly", "IsComputed", "IsSparse", "IsColumnSet", "IsSelfReplacing", "ScopeName", "ScopeSchema", "OrdinalPosition"}
mockRows := sqlmock.NewRows(columns).AddRow("id", "int", true, 1, true, false, true, true, false, false, true, false, true, false, true, "scope1", "schema1", 1)

mock.ExpectQuery(regexp.QuoteMeta(SNOWFLAKE_SCHEMA_QUERY)).WithArgs(table_name).WillReturnRows(mockRows)

s, err := NewSnowflake(db)
if err != nil {
t.Errorf("error initialising snowflake: %s", err)
}

res, err := s.Schema(table_name)
if err != nil {
t.Errorf("error executing query : %v", err)
}

fmt.Printf("Table schema %+v\n", res)

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there was unfulfilled expectations: %s", err)
}

}

func TestExecute(t *testing.T) {
db, mock := MockDB()
defer db.Close()

query := `SELECT id, name FROM "user"`
mockRows := sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "Rohan")

mock.ExpectQuery(regexp.QuoteMeta(query)).WillReturnRows(mockRows)

p, err := NewSnowflake(db)
if err != nil {
t.Errorf("error executing query: %s", err)
}
res, err := p.Execute(query)
if err != nil {
t.Errorf("error executing the query: %s", err)
}

var result types.QueryResult
if err := json.Unmarshal(res, &result); err != nil {
t.Errorf("error unmarshalling the result: %s", err)
}

fmt.Printf("Query result: %+v\n", result)
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

func TestTables(t *testing.T) {
db, mock := MockDB()
defer db.Close()

tableList := []string{"user", "product", "order"}
Warehouse := "datasherlock"
mock.ExpectQuery("USE WAREHOUSE ").WithArgs(Warehouse).WillReturnRows(sqlmock.NewRows([]string{"result"}).AddRow(""))

rows := sqlmock.NewRows([]string{"table_name"}).
AddRow(tableList[0]).
AddRow(tableList[1]).
AddRow(tableList[2])
mock.ExpectQuery(regexp.QuoteMeta(SNOWFLAKE_TABLES_LIST_QUERY)).WillReturnRows(rows)

s, err := NewSnowflake(db)
if err != nil {
t.Fatalf("error initializing snowflake: %s", err)
}

query := fmt.Sprintf("USE WAREHOUSE %s", Warehouse)
_, err = s.Tables(query)
if err != nil {
return
}

tables, err := s.Tables("test") // Database name isn't used in the query, so you can pass any value here
if err != nil {
t.Errorf("error retrieving table names: %s", err)
}

expected := tableList // Using the same list as returned by the mock
if !reflect.DeepEqual(tables, expected) {
t.Errorf("expected: %v, got: %v", expected, tables)
}

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
Loading

0 comments on commit 0e03a9e

Please sign in to comment.