This repository has been archived by the owner on Oct 24, 2024. It is now read-only.
forked from smithoss/gonymizer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
db_client.go
360 lines (314 loc) · 9.27 KB
/
db_client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
package gonymizer
import (
"database/sql"
"errors"
"fmt"
"strings"
"github.com/lib/pq"
log "github.com/sirupsen/logrus"
)
// RowCounts pairs a schema-qualified table with the number of rows counted in it.
// Every field is a pointer, so the count can be filled in after the struct has been
// appended to a slice.
type RowCounts struct {
	SchemaName, TableName *string // schema and table this count refers to
	Count                 *int    // number of rows in SchemaName.TableName
}
// CheckIfDbExists reports whether a database named dbName exists on the server reachable
// through db. The comparison is case-insensitive (both sides are passed through lower()).
//
// If the query or the scan fails, the error is logged and returned and exists is false.
func CheckIfDbExists(db *sql.DB, dbName string) (exists bool, err error) {
	const query = "SELECT exists(SELECT datname FROM pg_catalog.pg_database WHERE lower(datname) = lower($1));"
	// Propagate scan failures so callers can tell "query failed" apart from
	// "database absent"; previously the error was discarded and err was always nil.
	if err = db.QueryRow(query, dbName).Scan(&exists); err != nil {
		log.Error(err)
		return false, err
	}
	log.Debugf("Exists: %t", exists)
	return exists, nil
}
// GetAllProceduresInSchema returns the SQL definition (pg_get_functiondef) of every entry in
// pg_catalog.pg_proc that belongs to the given schema.
//
// It opens its own connection from conf and closes it before returning. Errors from
// connecting, querying, scanning, or iterating the result set are logged and returned. On
// error the returned slice is nil.
//
// NOTE(review): pg_proc also lists aggregate functions, and pg_get_functiondef raises an
// error for those. A schema containing aggregates may need a filter (e.g. on prokind for
// PostgreSQL 11+). Confirm against the supported server versions.
func GetAllProceduresInSchema(conf PGConfig, schema string) ([]string, error) {
	db, err := OpenDB(conf)
	if err != nil {
		log.Error(err)
		return nil, err
	}
	defer db.Close()

	rows, err := db.Query(`
		SELECT pg_get_functiondef(f.oid)
		FROM pg_catalog.pg_proc f
		INNER JOIN pg_catalog.pg_namespace n ON (f.pronamespace = n.oid)
		WHERE n.nspname = $1`, schema)
	if err != nil {
		log.Error(err)
		return nil, err
	}
	defer rows.Close()

	var procedures []string
	for {
		for rows.Next() {
			var procedure string
			if err := rows.Scan(&procedure); err != nil {
				log.Error(err)
				return nil, err
			}
			procedures = append(procedures, procedure)
		}
		if !rows.NextResultSet() {
			break
		}
	}
	// Next/NextResultSet return false both at the end and on failure; rows.Err
	// distinguishes the two so a truncated result is not reported as success.
	if err := rows.Err(); err != nil {
		log.Error(err)
		return nil, err
	}
	return procedures, nil
}
// GetAllSchemaColumns runs a single query against information_schema.columns covering every
// schema except information_schema and pg_catalog, and hands back the open result set.
//
// Each row holds table_catalog, table_schema, table_name, column_name, data_type,
// ordinal_position and is_nullable (converted from 'YES'/'NO' to a boolean). Rows are sorted
// by schema, table, then column position. The caller owns the returned *sql.Rows and must
// close it. A query failure is logged and returned with nil rows.
func GetAllSchemaColumns(db *sql.DB) (*sql.Rows, error) {
	rows, queryErr := db.Query(`
SELECT table_catalog, table_schema, table_name, column_name, data_type, ordinal_position,
CASE
WHEN is_nullable = 'YES' THEN
TRUE
WHEN is_nullable = 'NO' THEN
FALSE
END AS is_nullable
FROM information_schema.columns
WHERE table_schema NOT IN ('information_schema', 'pg_catalog')
ORDER BY table_schema, table_name, ordinal_position
`)
	if queryErr == nil {
		return rows, nil
	}
	log.Error(queryErr)
	return nil, queryErr
}
// GetAllTablesInSchema returns the table_name of every information_schema.tables entry in the
// given schema. An empty schema means "public". No table_type filter is applied, so views
// listed in information_schema.tables are included as well.
//
// It opens its own connection from conf and closes it before returning. Errors from
// connecting, querying, scanning, or iterating the result set are logged and returned. On
// error the returned slice is nil.
func GetAllTablesInSchema(conf PGConfig, schema string) ([]string, error) {
	db, err := OpenDB(conf)
	if err != nil {
		log.Error(err)
		return nil, err
	}
	defer db.Close()

	// Default to the public schema.
	if schema == "" {
		schema = "public"
	}

	rows, err := db.Query(`
		SELECT table_name
		FROM information_schema.tables
		WHERE table_schema = $1`,
		schema,
	)
	if err != nil {
		log.Error(err)
		return nil, err
	}
	defer rows.Close()

	var tableNames []string
	for {
		for rows.Next() {
			var tableName string
			if err := rows.Scan(&tableName); err != nil {
				log.Error(err)
				return nil, err
			}
			tableNames = append(tableNames, tableName)
		}
		if !rows.NextResultSet() {
			break
		}
	}
	// Surface errors that ended iteration early instead of returning a partial list.
	if err := rows.Err(); err != nil {
		log.Error(err)
		return nil, err
	}
	return tableNames, nil
}
// GetSchemasInDatabase returns the names of the schemas (from information_schema.schemata) in
// the database described by conf. Any schema whose name exactly matches (case-sensitively) an
// entry of excludeSchemas is left out. A nil or empty excludeSchemas returns every schema.
//
// It opens its own connection from conf and closes it on every return path. Errors from
// connecting, querying, scanning, or iterating the result set are logged and returned. On
// error the returned slice is nil.
func GetSchemasInDatabase(conf PGConfig, excludeSchemas []string) ([]string, error) {
	db, err := OpenDB(conf)
	if err != nil {
		log.Error(err)
		return nil, err
	}
	// Deferred so the connection is also released when the query fails.
	defer db.Close()

	// Exclusion is done here in Go with a set. The former SQL filter `schema_name NOT IN ($1)`
	// compared each name with the whole array literal as a single value. It therefore never
	// excluded anything, and for a nil slice (sent as NULL) it dropped every schema.
	excluded := make(map[string]struct{}, len(excludeSchemas))
	for _, name := range excludeSchemas {
		excluded[name] = struct{}{}
	}

	rows, err := db.Query(`SELECT schema_name FROM information_schema.schemata`)
	if err != nil {
		log.Error(err)
		return nil, err
	}
	defer rows.Close()

	var includedSchemas []string
	// Loop until every result set has been consumed.
	for {
		for rows.Next() {
			var schema string
			if err := rows.Scan(&schema); err != nil {
				log.Error(err)
				return nil, err
			}
			if _, skip := excluded[schema]; !skip {
				includedSchemas = append(includedSchemas, schema)
			}
		}
		if !rows.NextResultSet() {
			break
		}
	}
	if err := rows.Err(); err != nil {
		log.Error(err)
		return nil, err
	}
	return includedSchemas, nil
}
// GetSchemaColumnEquals queries information_schema.columns for the columns of exactly one
// schema (table_schema = schema) and returns the open result set.
//
// The row layout matches GetAllSchemaColumns: table_catalog, table_schema, table_name,
// column_name, data_type, ordinal_position and a boolean is_nullable, sorted by schema, table
// and column position. The caller must close the returned *sql.Rows. A query failure is
// logged and returned with nil rows.
func GetSchemaColumnEquals(db *sql.DB, schema string) (*sql.Rows, error) {
	const columnsQuery = `
SELECT table_catalog, table_schema, table_name, column_name, data_type, ordinal_position,
CASE
WHEN is_nullable = 'YES' THEN
TRUE
WHEN is_nullable = 'NO' THEN
FALSE
END AS is_nullable
FROM information_schema.columns
WHERE table_schema = $1
ORDER BY table_schema, table_name, ordinal_position`

	result, err := db.Query(columnsQuery, schema)
	if err != nil {
		log.Error(err)
		return nil, err
	}
	return result, nil
}
// GetSchemaColumnsLike will return a pointer to a list of database rows containing the names of tables and columns for
// the provided schema (using the SQL LIKE operator).
func GetSchemaColumnsLike(db *sql.DB, schemaPrefix string) (*sql.Rows, error) {
var selectedSchema string
// NOTE: Since we are grabbing a schema that matches the schemaPrefix we will assume UNIFORMITY in the DDL across all
// tables in each schema that match the prefix. Following this requirement, we can assume that we only need to grab a
// single schema that matches the prefix and use it as the map for all schemas that match the schemaPrefix.
err := db.QueryRow("SELECT table_schema FROM information_schema.columns WHERE table_schema LIKE $1 LIMIT 1",
schemaPrefix+"%").Scan(&selectedSchema)
switch err {
case sql.ErrNoRows:
fmt.Println("No rows were returned!")
case nil:
break
default:
panic(err)
}
// Now grab all the columns from this schema
rows, err := db.Query(`
SELECT table_catalog, table_schema, table_name, column_name, data_type, ordinal_position,
CASE
WHEN is_nullable = 'YES' THEN
TRUE
WHEN is_nullable = 'NO' THEN
FALSE
END AS is_nullable
FROM information_schema.columns
WHERE table_schema = $1
ORDER BY table_schema, table_name, ordinal_position`, selectedSchema)
if err != nil {
log.Error(err)
return nil, err
}
return rows, nil
}
// GetTableRowCountsInDB counts the rows of every table in pg_catalog.pg_tables for the database
// described by conf. It skips schemas starting with "pg_" and information_schema, and leaves out
// tables matched by excludeTable. It returns one RowCounts per counted table, ordered by schema
// then table name.
//
// Each excludeTable entry must have the form "schema.table" and is split at the first '.'. An
// entry without a '.' makes the function return an error (it used to panic). An entry excludes
// a table when its table name matches exactly and either:
//   - its schema matches exactly, or
//   - schemaPrefix is non-empty and the entry's schema starts with schemaPrefix.
//
// NOTE(review): schemaPrefix does not restrict which schemas are counted, and the prefix rule
// does not check the row's own schema. Confirm this is intended.
//
// Per-table COUNT failures are logged and leave that table's Count at 0. Connection, query,
// scan and iteration errors are logged (where noted) and returned.
func GetTableRowCountsInDB(conf PGConfig, schemaPrefix string, excludeTable []string) (*[]RowCounts, error) {
	// Parse the exclude list once up front instead of re-splitting it for every table row.
	type exclusion struct{ schema, table string }
	exclusions := make([]exclusion, 0, len(excludeTable))
	for _, e := range excludeTable {
		parts := strings.SplitN(e, ".", 2)
		if len(parts) != 2 {
			return nil, errors.New("invalid exclude table '" + e + "': expected <schema>.<table>")
		}
		exclusions = append(exclusions, exclusion{schema: parts[0], table: parts[1]})
	}
	isExcluded := func(schemaName, tableName string) bool {
		for _, ex := range exclusions {
			if ex.table != tableName {
				continue
			}
			if ex.schema == schemaName {
				return true
			}
			if schemaPrefix != "" && strings.HasPrefix(ex.schema, schemaPrefix) {
				return true
			}
		}
		return false
	}

	db, err := OpenDB(conf)
	if err != nil {
		log.Error(err)
		return nil, err
	}
	defer db.Close()

	// Exclusions are applied in Go by isExcluded. The former SQL `tablename NOT IN ($1)` compared
	// bare table names with an array literal of "schema.table" strings and never matched.
	rows, err := db.Query(`
		SELECT schemaname, tablename
		FROM pg_catalog.pg_tables
		WHERE schemaname NOT LIKE 'pg_%'
		AND schemaname != 'information_schema'
		ORDER BY schemaname, tablename;`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var dbRowCounts []RowCounts
	for {
		for rows.Next() {
			// Declared per row: RowCounts keeps pointers to these variables.
			var schemaName, tableName string
			if err := rows.Scan(&schemaName, &tableName); err != nil {
				log.Error(err)
				return nil, err
			}
			if isExcluded(schemaName, tableName) {
				continue
			}
			count := 0
			dbRowCounts = append(dbRowCounts, RowCounts{SchemaName: &schemaName, TableName: &tableName, Count: &count})
		}
		if !rows.NextResultSet() {
			break
		}
	}
	if err := rows.Err(); err != nil {
		log.Error(err)
		return nil, err
	}

	// Luckily Postgres is smart and does not blow away cache for a
	// simple Count(*). See -> https://stackoverflow.com/questions/37097736/understanding-postgres-caching
	for _, row := range dbRowCounts {
		// Quote the identifiers. Unquoted names are folded to lower case and fail for
		// mixed-case or reserved-word table names (e.g. "Users", "order").
		query := fmt.Sprintf("SELECT COUNT(*) FROM %s.%s;",
			pq.QuoteIdentifier(*row.SchemaName), pq.QuoteIdentifier(*row.TableName))
		if err := db.QueryRow(query).Scan(row.Count); err != nil {
			log.Error(err)
		}
	}
	return &dbRowCounts, nil
}
// KillDatabaseConnections calls pg_terminate_backend for every pg_stat_activity session
// connected to dbName, except the session running this query. Finding no other connections is
// not an error.
//
// Query, scan and iteration errors are logged and returned. The number of backends for which
// pg_terminate_backend returned true is logged at debug level.
func KillDatabaseConnections(db *sql.DB, dbName string) (err error) {
	const query = `
		SELECT pg_terminate_backend(pid)
		FROM pg_stat_activity
		WHERE pid != pg_backend_pid()
		AND datname = $1;`

	// Iterate the whole result set rather than using QueryRow. QueryRow turned the
	// "no other connections" case into sql.ErrNoRows, which was logged and returned as a failure.
	rows, err := db.Query(query, dbName)
	if err != nil {
		log.Error(err)
		return err
	}
	defer rows.Close()

	terminated := 0
	for rows.Next() {
		var success bool
		if err = rows.Scan(&success); err != nil {
			log.Error(err)
			return err
		}
		if success {
			terminated++
		}
	}
	if err = rows.Err(); err != nil {
		log.Error(err)
		return err
	}
	log.Debugf("Terminated %d connection(s) to %s", terminated, dbName)
	return nil
}
// RenameDatabase issues ALTER DATABASE fromName RENAME TO toName.
//
// Both names are formatted into the statement without quoting. PostgreSQL therefore folds them
// to lower case like any unquoted identifier, and they must be trusted, valid identifiers.
// On failure the two names and the error are logged and the error is returned. On success
// the result is nil.
func RenameDatabase(db *sql.DB, fromName, toName string) (err error) {
	stmt := fmt.Sprintf("ALTER DATABASE %s RENAME TO %s", fromName, toName)
	if _, err = db.Exec(stmt); err == nil {
		return nil
	}
	log.Errorf("Unable to rename database '%s' -> '%s'", fromName, toName)
	log.Error(err)
	return err
}