-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from Adarsh-jaiss/main
Added Snowflake support for the library
- Loading branch information
Showing
9 changed files
with
557 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,221 @@ | ||
package snowflake | ||
|
||
import ( | ||
"database/sql" | ||
"encoding/json" | ||
"fmt" | ||
"os" | ||
|
||
sf "github.com/snowflakedb/gosnowflake" | ||
"github.com/thesaas-company/xray/config" | ||
"github.com/thesaas-company/xray/types" | ||
) | ||
|
||
type Snowflake struct { | ||
Client *sql.DB | ||
Config *config.Config | ||
} | ||
|
||
var DB_PASSWORD string = "root" | ||
|
||
const ( | ||
SNOWFLAKE_TABLES_LIST_QUERY = "SHOW TERSE TABLES" | ||
SNOWFLAKE_SCHEMA_QUERY = ` | ||
SELECT | ||
COLUMN_NAME, | ||
DATA_TYPE, | ||
IS_NULLABLE, | ||
COLUMN_DEFAULT, | ||
IS_UPDATABLE, | ||
IS_IDENTITY, | ||
IS_GENERATED, | ||
IS_UNIQUE, | ||
IS_SYSTEM_COLUMN, | ||
IS_HIDDEN, | ||
IS_READ_ONLY, | ||
IS_COMPUTED, | ||
IS_SPARSE, | ||
IS_COLUMN_SET, | ||
IS_SELF_REFERENCING, | ||
SCOPE_NAME, | ||
SCOPE_SCHEMA, | ||
ORDINAL_POSITION | ||
FROM INFORMATION_SCHEMA.COLUMNS | ||
WHERE TABLE_NAME = ?) | ||
` | ||
) | ||
|
||
// The NewSnowflake function is responsible for creating a new Snowflake object with an initialized database client and configuration. | ||
func NewSnowflake(dbClient *sql.DB) (types.ISQL, error) { | ||
return &Snowflake{ | ||
Client: dbClient, | ||
Config: &config.Config{}, | ||
}, nil | ||
} | ||
|
||
// The NewSnowflakeWithConfig function is responsible for creating a new Snowflake object with an initialized database client and configuration. | ||
func NewSnowflakeWithConfig(config *config.Config) (types.ISQL, error) { | ||
if os.Getenv(DB_PASSWORD) == "" || len(os.Getenv(DB_PASSWORD)) == 0 { | ||
return nil, fmt.Errorf("please set %s env variable for the database", DB_PASSWORD) | ||
} | ||
DB_PASSWORD = os.Getenv(DB_PASSWORD) | ||
|
||
|
||
dsn, err := sf.DSN(&sf.Config{ | ||
Account: config.Account, | ||
User: config.Username, | ||
Password: DB_PASSWORD, | ||
Database: config.DatabaseName, | ||
Warehouse: config.Warehouse, | ||
}) | ||
if err != nil { | ||
return nil, fmt.Errorf("error creating snowflake DSN: %v", err) | ||
} | ||
|
||
dbType := types.Snowflake | ||
db, err := sql.Open(dbType.String(), dsn) | ||
if err != nil { | ||
return nil, fmt.Errorf("error opening connection to snowflake database: %v", err) | ||
} | ||
|
||
return &Snowflake{ | ||
Client: db, | ||
Config: config, | ||
}, nil | ||
|
||
} | ||
|
||
// The Schema function returns the schema of a table in Snowflake. | ||
func (s *Snowflake) Schema(table string) (types.Table, error) { | ||
var res types.Table | ||
|
||
rows, err := s.Client.Query(SNOWFLAKE_SCHEMA_QUERY, table) | ||
if err != nil { | ||
return res, fmt.Errorf("error executing sql statement: %v", err) | ||
} | ||
defer rows.Close() | ||
|
||
var columns []types.Column | ||
for rows.Next() { | ||
var column types.Column | ||
if err := rows.Scan( | ||
&column.Name, | ||
&column.Type, | ||
&column.IsNullable, | ||
&column.ColumnDefault, | ||
&column.IsUpdatable, | ||
&column.IsIdentity, | ||
&column.IsGenerated, | ||
&column.IsUnique, | ||
&column.IsSystemColumn, | ||
&column.IsHidden, | ||
&column.IsReadOnly, | ||
&column.IsComputed, | ||
&column.IsSparse, | ||
&column.IsColumnSet, | ||
&column.IsSelfReferencing, | ||
&column.ScopeName, | ||
&column.ScopeSchema, | ||
&column.OrdinalPosition); err != nil { | ||
return res, fmt.Errorf("error scanning rows: %v", err) | ||
} | ||
column.Description = "" // default description | ||
column.Metatags = []string{} // default metatags as an empty string slice | ||
column.Metatags = append(column.Metatags, column.Name) | ||
column.Visibility = true // default visibility | ||
columns = append(columns, column) | ||
} | ||
|
||
// checking for erros from iterating over the rows | ||
if err := rows.Err(); err != nil { | ||
return res, fmt.Errorf("error iterating over rows: %v", err) | ||
} | ||
|
||
return types.Table{ | ||
Name: table, | ||
Columns: columns, | ||
ColumnCount: int64(len(columns)), | ||
Description: "", | ||
Metatags: []string{}, | ||
}, nil | ||
} | ||
|
||
// Every table in Snowflake lives "inside" a schema. Every schema lives "inside" a database. It's a hierarchical system. | ||
// The Tables function returns a list of tables in a Snowflake database. | ||
func (s *Snowflake) Tables(DatabaseName string) ([]string, error) { | ||
query := fmt.Sprintf("USE WAREHOUSE %s", s.Config.Warehouse) | ||
_, err := s.Client.Query(query) | ||
if err != nil { | ||
return nil, fmt.Errorf("error executing sql statement: %v", err) | ||
} | ||
|
||
rows,err := s.Client.Query(SNOWFLAKE_TABLES_LIST_QUERY) | ||
if err != nil { | ||
return nil, fmt.Errorf("error executing sql statement and querying tables list: %v", err) | ||
} | ||
defer rows.Close() | ||
|
||
var tables []string | ||
for rows.Next() { | ||
var table string | ||
if err := rows.Scan(&table); err != nil { | ||
return nil, fmt.Errorf("error scanning database: %v", err) | ||
} | ||
tables = append(tables, table) | ||
} | ||
|
||
// checking for errors in iterating over rows | ||
if err := rows.Err(); err != nil { | ||
return nil, fmt.Errorf("error iterating over rows:%v", err) | ||
} | ||
|
||
return tables, nil | ||
} | ||
|
||
// The Execute function executes a query on a Snowflake database and returns the result as a JSON byte slice. | ||
func (s *Snowflake) Execute(query string) ([]byte, error) { | ||
rows, err := s.Client.Query(query) | ||
if err != nil { | ||
return nil, fmt.Errorf("error executing sql statement: %v", err) | ||
} | ||
defer rows.Close() | ||
|
||
columns, err := rows.Columns() | ||
if err != nil { | ||
return nil, fmt.Errorf("error getting columns: %v", err) | ||
} | ||
|
||
// Scan the result into a slice of slices | ||
var results [][]interface{} | ||
for rows.Next() { | ||
values := make([]interface{}, len(columns)) | ||
pointers := make([]interface{}, len(columns)) | ||
for i := range values { | ||
pointers[i] = &values[i] | ||
} | ||
|
||
if err := rows.Scan(pointers...); err != nil { | ||
return nil, fmt.Errorf("error scanning row: %v", err) | ||
} | ||
|
||
results = append(results, values) | ||
} | ||
|
||
// Check for errors from iterating over rows | ||
if err := rows.Err(); err != nil { | ||
return nil, fmt.Errorf("error iterating rows: %v", err) | ||
} | ||
|
||
// Convert the result to JSON | ||
queryResult := types.QueryResult{ | ||
Columns: columns, | ||
Rows: results, | ||
} | ||
|
||
jsonData, err := json.Marshal(queryResult) | ||
if err != nil { | ||
return nil, fmt.Errorf("error marshaling json: %v", err) | ||
} | ||
|
||
return jsonData, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
package snowflake | ||
|
||
import ( | ||
"database/sql" | ||
"encoding/json" | ||
"fmt" | ||
"reflect" | ||
"regexp" | ||
"testing" | ||
|
||
"github.com/DATA-DOG/go-sqlmock" | ||
"github.com/thesaas-company/xray/types" | ||
) | ||
|
||
func MockDB() (*sql.DB, sqlmock.Sqlmock) { | ||
db, mock, err := sqlmock.New() | ||
if err != nil { | ||
panic("An error occured while creating a new mock database connection") | ||
} | ||
|
||
return db, mock | ||
} | ||
|
||
func TestSchema(t *testing.T) { | ||
db, mock := MockDB() | ||
defer db.Close() | ||
|
||
table_name := "user" | ||
|
||
columns := []string{"name", "type", "IsNullable", "DefaultValue", "IsUpdatable", "IsIdenity", "IsGenerated", "IsUnique", "IsSystemColumn", "IsHidden", "IsReadOnly", "IsComputed", "IsSparse", "IsColumnSet", "IsSelfReplacing", "ScopeName", "ScopeSchema", "OrdinalPosition"} | ||
mockRows := sqlmock.NewRows(columns).AddRow("id", "int", true, 1, true, false, true, true, false, false, true, false, true, false, true, "scope1", "schema1", 1) | ||
|
||
mock.ExpectQuery(regexp.QuoteMeta(SNOWFLAKE_SCHEMA_QUERY)).WithArgs(table_name).WillReturnRows(mockRows) | ||
|
||
s, err := NewSnowflake(db) | ||
if err != nil { | ||
t.Errorf("error initialising snowflake: %s", err) | ||
} | ||
|
||
res, err := s.Schema(table_name) | ||
if err != nil { | ||
t.Errorf("error executing query : %v", err) | ||
} | ||
|
||
fmt.Printf("Table schema %+v\n", res) | ||
|
||
if err := mock.ExpectationsWereMet(); err != nil { | ||
t.Errorf("there was unfulfilled expectations: %s", err) | ||
} | ||
|
||
} | ||
|
||
func TestExecute(t *testing.T) { | ||
db, mock := MockDB() | ||
defer db.Close() | ||
|
||
query := `SELECT id, name FROM "user"` | ||
mockRows := sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "Rohan") | ||
|
||
mock.ExpectQuery(regexp.QuoteMeta(query)).WillReturnRows(mockRows) | ||
|
||
p, err := NewSnowflake(db) | ||
if err != nil { | ||
t.Errorf("error executing query: %s", err) | ||
} | ||
res, err := p.Execute(query) | ||
if err != nil { | ||
t.Errorf("error executing the query: %s", err) | ||
} | ||
|
||
var result types.QueryResult | ||
if err := json.Unmarshal(res, &result); err != nil { | ||
t.Errorf("error unmarshalling the result: %s", err) | ||
} | ||
|
||
fmt.Printf("Query result: %+v\n", result) | ||
if err := mock.ExpectationsWereMet(); err != nil { | ||
t.Errorf("there were unfulfilled expectations: %s", err) | ||
} | ||
} | ||
|
||
func TestTables(t *testing.T) { | ||
db, mock := MockDB() | ||
defer db.Close() | ||
|
||
tableList := []string{"user", "product", "order"} | ||
Warehouse := "datasherlock" | ||
mock.ExpectQuery("USE WAREHOUSE ").WithArgs(Warehouse).WillReturnRows(sqlmock.NewRows([]string{"result"}).AddRow("")) | ||
|
||
rows := sqlmock.NewRows([]string{"table_name"}). | ||
AddRow(tableList[0]). | ||
AddRow(tableList[1]). | ||
AddRow(tableList[2]) | ||
mock.ExpectQuery(regexp.QuoteMeta(SNOWFLAKE_TABLES_LIST_QUERY)).WillReturnRows(rows) | ||
|
||
s, err := NewSnowflake(db) | ||
if err != nil { | ||
t.Fatalf("error initializing snowflake: %s", err) | ||
} | ||
|
||
query := fmt.Sprintf("USE WAREHOUSE %s", Warehouse) | ||
_, err = s.Tables(query) | ||
if err != nil { | ||
return | ||
} | ||
|
||
tables, err := s.Tables("test") // Database name isn't used in the query, so you can pass any value here | ||
if err != nil { | ||
t.Errorf("error retrieving table names: %s", err) | ||
} | ||
|
||
expected := tableList // Using the same list as returned by the mock | ||
if !reflect.DeepEqual(tables, expected) { | ||
t.Errorf("expected: %v, got: %v", expected, tables) | ||
} | ||
|
||
if err := mock.ExpectationsWereMet(); err != nil { | ||
t.Errorf("there were unfulfilled expectations: %s", err) | ||
} | ||
} |
Oops, something went wrong.