Skip to content

Commit

Permalink
Use context-based database calls, add golden test for team
Browse files Browse the repository at this point in the history
  • Loading branch information
wilsonehusin committed Apr 5, 2024
1 parent d5e4547 commit f897680
Show file tree
Hide file tree
Showing 13 changed files with 362 additions and 108 deletions.
3 changes: 2 additions & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"userEnvProbe": "loginInteractiveShell",
"features": {
"ghcr.io/devcontainers/features/terraform:1": {},
"ghcr.io/devcontainers-contrib/features/direnv:1": {}
"ghcr.io/devcontainers-contrib/features/direnv:1": {},
"ghcr.io/guiyomh/features/just:0": {}
}
}
24 changes: 24 additions & 0 deletions Justfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
set positional-arguments

# Run dev mode, which includes extra logging
dev *args='': generate
go run -tags dev . "$@"

# Run a regular build
run *args='': generate
go run . "$@"

test: generate
go test -v ./...

update-golden: generate
go test -v ./... -test.update-golden

mod:
go mod tidy

dependencies *args='':
./deps.sh "$@"

generate: dependencies
go generate ./...
29 changes: 16 additions & 13 deletions cmd/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,35 +47,38 @@ var ImportCommand = &cli.Command{
Flags: ConcatFlags([][]cli.Flag{importFlags, flags}),
}

func importAction(ctx *cli.Context) error {
providerName := ctx.String("provider")
provider, err := pager.NewPager(providerName, ctx.String("provider-api-key"), ctx.String("provider-app-id"))
func importAction(cliCtx *cli.Context) error {
providerName := cliCtx.String("provider")
provider, err := pager.NewPager(providerName, cliCtx.String("provider-api-key"), cliCtx.String("provider-app-id"))
if err != nil {
return fmt.Errorf("initializing pager provider: %w", err)
}
fh, err := pager.NewFireHydrant(ctx.String("firehydrant-api-key"), ctx.String("firehydrant-api-endpoint"))
fh, err := pager.NewFireHydrant(cliCtx.String("firehydrant-api-key"), cliCtx.String("firehydrant-api-endpoint"))
if err != nil {
return fmt.Errorf("initializing FireHydrant client: %w", err)
}

if err := importUsers(ctx.Context, provider, fh); err != nil {
ctx := store.WithContext(cliCtx.Context)
defer store.FromContext(ctx).Close()

if err := importUsers(ctx, provider, fh); err != nil {
return fmt.Errorf("importing users: %w", err)
}
console.Infof("Imported users from %s.\n", providerName)

if err := importTeams(ctx.Context, provider, fh); err != nil {
if err := importTeams(ctx, provider, fh); err != nil {
return fmt.Errorf("importing teams: %w", err)
}
console.Infof("Imported teams from %s.\n", providerName)

tfr, err := tfrender.New(
ctx.String("output-dir"),
cliCtx.String("output-dir"),
fmt.Sprintf("%s_to_fh_signals.tf", strings.ToLower(providerName)),
)
if err != nil {
return fmt.Errorf("initializing Terraform render space: %w", err)
}
return tfr.Write(ctx.Context)
return tfr.Write(ctx)
}

func importTeams(ctx context.Context, provider pager.Pager, fh *pager.FireHydrant) error {
Expand Down Expand Up @@ -121,7 +124,7 @@ func importTeams(ctx context.Context, provider pager.Pager, fh *pager.FireHydran
continue
case 1:
console.Successf("[+ CREATE] '%s' will be created in FireHydrant.\n", extTeam.String())
if err := store.Query.InsertExtTeam(ctx, store.InsertExtTeamParams{
if err := store.UseQueries(ctx).InsertExtTeam(ctx, store.InsertExtTeamParams{
ID: extTeam.ID,
Name: extTeam.Name,
Slug: extTeam.Slug,
Expand All @@ -131,7 +134,7 @@ func importTeams(ctx context.Context, provider pager.Pager, fh *pager.FireHydran
}
continue
default:
if err := store.Query.InsertExtTeam(ctx, store.InsertExtTeamParams{
if err := store.UseQueries(ctx).InsertExtTeam(ctx, store.InsertExtTeamParams{
ID: extTeam.ID,
Name: extTeam.Name,
Slug: extTeam.Slug,
Expand All @@ -144,7 +147,7 @@ func importTeams(ctx context.Context, provider pager.Pager, fh *pager.FireHydran
}
}

allTeams, err := store.Query.ListExtTeams(ctx)
allTeams, err := store.UseQueries(ctx).ListExtTeams(ctx)
if err != nil {
return fmt.Errorf("unable to list all teams: %w", err)
}
Expand All @@ -161,7 +164,7 @@ func importTeams(ctx context.Context, provider pager.Pager, fh *pager.FireHydran
}

for _, member := range t.Members {
if err := store.Query.InsertExtMembership(ctx, store.InsertExtMembershipParams{
if err := store.UseQueries(ctx).InsertExtMembership(ctx, store.InsertExtMembershipParams{
TeamID: extTeam.ID,
UserID: member.ID,
}); err != nil {
Expand All @@ -185,7 +188,7 @@ func importUsers(ctx context.Context, provider pager.Pager, fh *pager.FireHydran
}
console.Successf("Found %d users from provider.\n", len(providerUsers))
for _, user := range providerUsers {
if err := store.Query.InsertExtUser(ctx, store.InsertExtUserParams{
if err := store.UseQueries(ctx).InsertExtUser(ctx, store.InsertExtUserParams{
ID: user.ID,
Name: user.Name,
Email: user.Email,
Expand Down
12 changes: 6 additions & 6 deletions pager/firehydrant.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func NewFireHydrant(apiKey string, apiURL string) (*FireHydrant, error) {

func (f *FireHydrant) ListTeams(ctx context.Context) ([]*Team, error) {
teams := []*Team{}
stored, err := store.Query.ListFhTeams(ctx)
stored, err := store.UseQueries(ctx).ListFhTeams(ctx)
if err == nil && len(stored) > 0 {
for _, t := range stored {
teams = append(teams, &Team{
Expand Down Expand Up @@ -59,7 +59,7 @@ func (f *FireHydrant) ListTeams(ctx context.Context) ([]*Team, error) {
}

for _, t := range teams {
if err := store.Query.InsertFhTeam(ctx, store.InsertFhTeamParams{
if err := store.UseQueries(ctx).InsertFhTeam(ctx, store.InsertFhTeamParams{
ID: t.ID,
Name: t.Name,
Slug: t.Slug,
Expand All @@ -85,7 +85,7 @@ func (f *FireHydrant) toTeam(team firehydrant.TeamResponse) *Team {
// the provided API key access.
func (f *FireHydrant) ListUsers(ctx context.Context) ([]*User, error) {
users := []*User{}
stored, err := store.Query.ListFhUsers(ctx)
stored, err := store.UseQueries(ctx).ListFhUsers(ctx)
if err == nil && len(stored) > 0 {
for _, u := range stored {
users = append(users, &User{
Expand Down Expand Up @@ -115,7 +115,7 @@ func (f *FireHydrant) ListUsers(ctx context.Context) ([]*User, error) {
}

for _, u := range users {
if err := store.Query.InsertFhUser(ctx, store.InsertFhUserParams{
if err := store.UseQueries(ctx).InsertFhUser(ctx, store.InsertFhUserParams{
ID: u.ID,
Email: u.Email,
Name: u.Name,
Expand Down Expand Up @@ -147,7 +147,7 @@ func (f *FireHydrant) MatchUsers(ctx context.Context, users []*User) ([]*User, e

unmatchedUsers := []*User{}
for _, user := range users {
fhUser, err := store.Query.GetFhUserByEmail(ctx, user.Email)
fhUser, err := store.UseQueries(ctx).GetFhUserByEmail(ctx, user.Email)
if err == nil {
if err := f.PairUsers(ctx, fhUser.ID, user.ID); err != nil {
return nil, fmt.Errorf("pairing users: %w", err)
Expand All @@ -161,7 +161,7 @@ func (f *FireHydrant) MatchUsers(ctx context.Context, users []*User) ([]*User, e
}

func (f *FireHydrant) PairUsers(ctx context.Context, fhUserID string, extUserID string) error {
return store.Query.LinkExtUser(ctx, store.LinkExtUserParams{
return store.UseQueries(ctx).LinkExtUser(ctx, store.LinkExtUserParams{
FhUserID: sql.NullString{Valid: true, String: fhUserID},
ID: extUserID,
})
Expand Down
16 changes: 16 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,19 @@ Afterwards, the tool will generate the mapping appropriately, handling de-duplic
- [ ] Support for importing escalation policies
- [ ] Auto-run `terraform apply` for users who would not manage their organization with Terraform after importing
- [ ] Build + publish Docker image to simplify usage down to `docker run firehydrant/signals-migrator import`

## Developing

A devcontainer setup has been prepared to be used in VS Code. Run `direnv allow` to auto-load `.env` file.

If alternative method is preferred, you will need:

- Go compiler
- Install tools in `./deps.sh`

Also recommended to smooth out development outside of devcontainer:

- [direnv](https://direnv.net/) for autoloading `.env` file, also automatically adds `./bin` to `$PATH`.
- [just](https://just.systems/) for running tasks defined in `Justfile`.

Most commands in `Justfile` can be run out in regular bash too.
32 changes: 18 additions & 14 deletions store/open.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,28 @@ package store
import (
"context"
"database/sql"
"time"
)

func openDB() *Queries {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

func NewStore() *Store {
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
panic(err)
}
_, err = db.ExecContext(ctx, `PRAGMA foreign_keys = true;`)
if err != nil {
panic(err)
}
_, err = db.ExecContext(ctx, schema)
if err != nil {
panic(err)
}
return New(db)
return &Store{conn: db}
}

func (s *Store) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
return s.conn.ExecContext(ctx, query, args...)
}

func (s *Store) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
return s.conn.PrepareContext(ctx, query)
}

func (s *Store) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
return s.conn.QueryContext(ctx, query, args...)
}

func (s *Store) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
return s.conn.QueryRowContext(ctx, query, args...)
}
60 changes: 22 additions & 38 deletions store/open_dev.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,18 @@ import (
"github.com/fatih/color"
)

type loggedQueries struct {
q *sql.DB
func NewStore() *Store {
f := filepath.Join(os.TempDir(), "signals-migrator.db")
log.Printf("using db file: %s", f)

db, err := sql.Open("sqlite", f)
if err != nil {
panic(err)
}
return &Store{conn: db}
}

func (q *loggedQueries) log(t time.Duration, queryStr string) {
func (s *Store) log(t time.Duration, queryStr string) {
qInfo := strings.SplitN(queryStr, "\n", 2)
name := strings.TrimSpace(qInfo[0])
query := ""
Expand All @@ -33,49 +40,26 @@ func (q *loggedQueries) log(t time.Duration, queryStr string) {
)
}

func (q *loggedQueries) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
func (s *Store) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
t := time.Now()
defer func() { q.log(time.Since(t), query) }()
return q.q.ExecContext(ctx, query, args...)
defer func() { s.log(time.Since(t), query) }()
return s.conn.ExecContext(ctx, query, args...)
}

func (q *loggedQueries) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
func (s *Store) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
t := time.Now()
defer func() { q.log(time.Since(t), query) }()
return q.q.PrepareContext(ctx, query)
defer func() { s.log(time.Since(t), query) }()
return s.conn.PrepareContext(ctx, query)
}

func (q *loggedQueries) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
func (s *Store) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
t := time.Now()
defer func() { q.log(time.Since(t), query) }()
return q.q.QueryContext(ctx, query, args...)
defer func() { s.log(time.Since(t), query) }()
return s.conn.QueryContext(ctx, query, args...)
}

func (q *loggedQueries) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
func (s *Store) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
t := time.Now()
defer func() { q.log(time.Since(t), query) }()
return q.q.QueryRowContext(ctx, query, args...)
}

func openDB() *Queries {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

f := filepath.Join(os.TempDir(), "signals-migrator.db")
log.Printf("using db file %s", f)

db, err := sql.Open("sqlite", f)
if err != nil {
panic(err)
}
dbtx := &loggedQueries{q: db}
_, err = dbtx.ExecContext(ctx, `PRAGMA foreign_keys = true;`)
if err != nil {
panic(err)
}
_, err = dbtx.ExecContext(ctx, schema)
if err != nil {
panic(err)
}
return New(dbtx)
defer func() { s.log(time.Since(t), query) }()
return s.conn.QueryRowContext(ctx, query, args...)
}
Loading

0 comments on commit f897680

Please sign in to comment.