Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

user trip feature #59

Merged
merged 4 commits into from
Mar 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion server/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
transactionRepository "github.com/Taehoya/pocket-mate/internal/pkg/repositories/transaction"
tripRepository "github.com/Taehoya/pocket-mate/internal/pkg/repositories/trip"
userRepository "github.com/Taehoya/pocket-mate/internal/pkg/repositories/user"
userTripRepository "github.com/Taehoya/pocket-mate/internal/pkg/repositories/usertrip"
countryUseCase "github.com/Taehoya/pocket-mate/internal/pkg/usecases/country"
transactionUseCase "github.com/Taehoya/pocket-mate/internal/pkg/usecases/transaction"
tripUsecase "github.com/Taehoya/pocket-mate/internal/pkg/usecases/trip"
Expand Down Expand Up @@ -44,8 +45,9 @@ func main() {
countryUseCase := countryUseCase.NewCountryUseCase(countryRepository)
userRepository := userRepository.NewUserRepository(db)
tripRepository := tripRepository.NewTripRepository(db)
userTripRepository := userTripRepository.NewUserTripRepository(db)
transactionRepository := transactionRepository.NewTransactionRepository(db)
tripUseCase := tripUsecase.NewTripUseCase(tripRepository, countryRepository, transactionRepository)
tripUseCase := tripUsecase.NewTripUseCase(tripRepository, userTripRepository, countryRepository, transactionRepository)
userUsecase := userUsecase.NewUserUseCase(userRepository)
transactionUseCase := transactionUseCase.NewTransactionUseCase(transactionRepository)

Expand Down
8 changes: 7 additions & 1 deletion server/internal/pkg/entities/trip.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type Trip struct {
id int
name string
budget float64
leader bool
countryId int
description string
note Note
Expand All @@ -27,11 +28,12 @@ type Trip struct {
updatedAt time.Time
}

func NewTrip(id int, name string, budget float64, countryId int, description string, note Note, startDateTime time.Time, endDateTime time.Time, createdAt time.Time, updatedAt time.Time) *Trip {
func NewTrip(id int, name string, budget float64, leader bool, countryId int, description string, note Note, startDateTime time.Time, endDateTime time.Time, createdAt time.Time, updatedAt time.Time) *Trip {
return &Trip{
id: id,
name: name,
budget: budget,
leader: leader,
countryId: countryId,
description: description,
note: note,
Expand All @@ -55,6 +57,10 @@ func (t *Trip) Budget() float64 {
return t.budget
}

func (t *Trip) Leader() bool {
return t.leader
}

func (t *Trip) CountryID() int {
return t.countryId
}
Expand Down
13 changes: 9 additions & 4 deletions server/internal/pkg/mocks/trip/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,20 @@ func NewTripRepositoryMock() *TripRepositoryMock {
return new(TripRepositoryMock)
}

func (m *TripRepositoryMock) SaveTrip(ctx context.Context, name string, userId int, budget float64, countryId int, description string, note entities.Note, startDateTime time.Time, endDateTime time.Time) error {
func (m *TripRepositoryMock) SaveTrip(ctx context.Context, name string, userId int, budget float64, countryId int, description string, note entities.Note, startDateTime time.Time, endDateTime time.Time) (int, error) {
ret := m.Called(ctx, name, userId, budget, countryId, description, note, startDateTime, endDateTime)

var r0 error
var r0 int
if ret.Get(0) != nil {
r0 = ret.Get(0).(error)
r0 = ret.Get(0).(int)
}

return r0
var r1 error
if ret.Get(1) != nil {
r1 = ret.Get(1).(error)
}

return r0, r1
}

func (m *TripRepositoryMock) GetTrip(ctx context.Context, userId int) ([]*entities.Trip, error) {
Expand Down
48 changes: 48 additions & 0 deletions server/internal/pkg/mocks/usertrip/repository.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package mocks

import (
"context"

"github.com/stretchr/testify/mock"
)

type UserTripRepositoryMock struct {
mock.Mock
}

func NewUserTripRepositoryMock() *UserTripRepositoryMock {
return new(UserTripRepositoryMock)
}

func (m *UserTripRepositoryMock) SaveUserTrip(ctx context.Context, userId int, tripId int, leader bool) error {
ret := m.Called(ctx, userId, tripId, leader)

var r0 error
if ret.Get(0) != nil {
r0 = ret.Get(0).(error)
}

return r0
}

func (m *UserTripRepositoryMock) DeleteUserTrip(ctx context.Context, userId int, tripId int) error {
ret := m.Called(ctx, userId, tripId)

var r0 error
if ret.Get(0) != nil {
r0 = ret.Get(0).(error)
}

return r0
}

func (m *UserTripRepositoryMock) UpdateUserTrip(ctx context.Context, userId int, tripId int, leader bool) error {
ret := m.Called(ctx, userId, tripId, leader)

var r0 error
if ret.Get(0) != nil {
r0 = ret.Get(0).(error)
}

return r0
}
70 changes: 56 additions & 14 deletions server/internal/pkg/repositories/trip/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ func NewTripRepository(db *sql.DB) *TripRepository {
}
}

func (r *TripRepository) SaveTrip(ctx context.Context, name string, userId int, budget float64, countryId int, description string, note entities.Note, startDateTime time.Time, endDateTime time.Time) error {
func (r *TripRepository) SaveTrip(ctx context.Context, name string, userId int, budget float64, countryId int, description string, note entities.Note, startDateTime time.Time, endDateTime time.Time) (int, error) {
noteJson, err := json.Marshal(note)

if err != nil {
log.Printf("failed to marshal note: %v\n", err)
return fmt.Errorf("internal Server Error")
return -1, fmt.Errorf("internal Server Error")
}
noteString := string(noteJson)

Expand All @@ -40,31 +40,51 @@ func (r *TripRepository) SaveTrip(ctx context.Context, name string, userId int,
result, err := r.db.ExecContext(ctx, query, name, userId, budget, countryId, description, noteString, startDateTime, endDateTime)
if err != nil {
log.Printf("failed to execute query: %v\n", err)
return fmt.Errorf("internal Server Error")
return -1, fmt.Errorf("internal Server Error")
}

rows, err := result.RowsAffected()
if err != nil {
log.Printf("failed to get affected rows: %\nv", err)
return fmt.Errorf("internal Server Error")
return -1, fmt.Errorf("internal Server Error")
}

if rows != 1 {
log.Printf("expected 1 affected row, got %d\n", rows)
return fmt.Errorf("internal Server Error")
return -1, fmt.Errorf("internal Server Error")
}

return nil
id, err := result.LastInsertId()
if err != nil {
log.Printf("failed to get last inserted id: %v\n", err)
return -1, fmt.Errorf("internal Server Error")
}
return int(id), nil
}

func (r *TripRepository) GetTrip(ctx context.Context, userId int) ([]*entities.Trip, error) {
var trips []*entities.Trip
query := `
SELECT
id, name, budget, country_id, description, note, start_date_time, end_date_time, created_at, updated_at
id,
name,
budget,
leader,
country_id,
description,
note,
start_date_time,
end_date_time,
created_at,
updated_at
FROM
trips
WHERE user_id = ?`
user_trips ut
LEFT JOIN
trips t
ON
ut.trip_id = t.id
WHERE
t.user_id = ?`

rows, err := r.db.QueryContext(ctx, query, userId)
if err != nil {
Expand All @@ -77,6 +97,7 @@ func (r *TripRepository) GetTrip(ctx context.Context, userId int) ([]*entities.T
var id int
var name string
var budget float64
var leader bool
var countryId int
var description string
var noteJson string
Expand All @@ -86,7 +107,7 @@ func (r *TripRepository) GetTrip(ctx context.Context, userId int) ([]*entities.T
var createdAt time.Time
var updatedAt time.Time

if err := rows.Scan(&id, &name, &budget, &countryId, &description, &noteJson, &startDateTime, &endDateTime, &createdAt, &updatedAt); err != nil {
if err := rows.Scan(&id, &name, &budget, &leader, &countryId, &description, &noteJson, &startDateTime, &endDateTime, &createdAt, &updatedAt); err != nil {
log.Printf("failed to scan trip: %v\n", err)
return nil, fmt.Errorf("internal server error")
}
Expand All @@ -96,7 +117,7 @@ func (r *TripRepository) GetTrip(ctx context.Context, userId int) ([]*entities.T
return nil, fmt.Errorf("internal server error")
}

trip := entities.NewTrip(id, name, budget, countryId, description, note, startDateTime, endDateTime, createdAt, updatedAt)
trip := entities.NewTrip(id, name, budget, leader, countryId, description, note, startDateTime, endDateTime, createdAt, updatedAt)
trips = append(trips, trip)
}

Expand Down Expand Up @@ -158,7 +179,27 @@ func (r *TripRepository) UpdateTrip(ctx context.Context, tripId int, name string

func (r *TripRepository) GetTripById(ctx context.Context, tripId int) (*entities.Trip, error) {
var trip *entities.Trip
query := `SELECT id, name, budget, country_id, description, note, start_date_time, end_date_time, created_at, updated_at FROM trips WHERE id = ?;`
query := `
SELECT
id,
name,
budget,
leader,
country_id,
description,
note,
start_date_time,
end_date_time,
created_at,
updated_at
FROM
user_trips ut
LEFT JOIN
trips t
ON
ut.trip_id = t.id
WHERE
t.id = ?;`
rows, err := r.db.QueryContext(ctx, query, tripId)

if err != nil {
Expand All @@ -171,6 +212,7 @@ func (r *TripRepository) GetTripById(ctx context.Context, tripId int) (*entities
var id int
var name string
var budget float64
var leader bool
var countryId int
var description string
var noteJson string
Expand All @@ -180,7 +222,7 @@ func (r *TripRepository) GetTripById(ctx context.Context, tripId int) (*entities
var createdAt time.Time
var updatedAt time.Time

if err := rows.Scan(&id, &name, &budget, &countryId, &description, &noteJson, &startDateTime, &endDateTime, &createdAt, &updatedAt); err != nil {
if err := rows.Scan(&id, &name, &budget, &leader, &countryId, &description, &noteJson, &startDateTime, &endDateTime, &createdAt, &updatedAt); err != nil {
log.Printf("failed to scan trip: %v\n", err)
return nil, fmt.Errorf("internal server error")
}
Expand All @@ -190,7 +232,7 @@ func (r *TripRepository) GetTripById(ctx context.Context, tripId int) (*entities
return nil, fmt.Errorf("internal server error")
}

trip = entities.NewTrip(id, name, budget, countryId, description, note, startDateTime, endDateTime, createdAt, updatedAt)
trip = entities.NewTrip(id, name, budget, leader, countryId, description, note, startDateTime, endDateTime, createdAt, updatedAt)
}

if err := rows.Err(); err != nil {
Expand Down
18 changes: 4 additions & 14 deletions server/internal/pkg/repositories/trip/repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ func TestSaveTrip(t *testing.T) {
startDateTime := time.Now()
endDateTime := time.Now()

err = repository.SaveTrip(ctx, name, userId, budget, countryId, description, note, startDateTime, endDateTime)
id, err := repository.SaveTrip(ctx, name, userId, budget, countryId, description, note, startDateTime, endDateTime)
assert.NoError(t, err)
assert.NotEqual(t, id, -1)
})
}

Expand All @@ -55,20 +56,10 @@ func TestGetTrip(t *testing.T) {

ctx := context.TODO()
userId := 1
note := entities.Note{
NoteType: "test-type",
NoteColor: "test-note-color",
BoundColor: "test-bound-color",
}

expected := []*entities.Trip{
entities.NewTrip(1, "test-name", 1, 1.0000, "test-description", note, time.Now(), time.Now(), time.Now(), time.Now()),
}

trips, err := repository.GetTrip(ctx, userId)
assert.NoError(t, err)
assert.NotNil(t, trips)
assert.Equal(t, trips[0].ID(), expected[0].ID())
})
}

Expand Down Expand Up @@ -109,8 +100,7 @@ func TestUpdateTrip(t *testing.T) {
mysqltest.SetUp(db, "./setup_test.sql")

ctx := context.TODO()
tripId := 1
userId := 1
tripId := 2
name := "updated-name"
budget := 1000.0
countryId := 1
Expand All @@ -124,7 +114,7 @@ func TestUpdateTrip(t *testing.T) {
startDateTime := time.Now()
endDateTime := time.Now()

err = repository.UpdateTrip(ctx, userId, name, budget, countryId, description, note, startDateTime, endDateTime)
err = repository.UpdateTrip(ctx, tripId, name, budget, countryId, description, note, startDateTime, endDateTime)
assert.NoError(t, err)

trip, err := repository.GetTripById(ctx, tripId)
Expand Down
8 changes: 7 additions & 1 deletion server/internal/pkg/repositories/trip/setup_test.sql
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,10 @@ VALUES
INSERT INTO trips
(id, name, user_id, budget, country_id, description, note, start_date_time, end_date_time, created_at, updated_at, deleted_at)
VALUES
(1, 'test-name', 1, 1.0000, 1, 'test-description', '{"Bound": 0, "NoteColor": "#000000", "BoundColor": "#111111"}', '2023-11-13 15:04:05', '2023-11-13 15:04:05', '2023-11-13 14:05:27', '2023-11-13 14:05:27', NULL);
(1, 'test-name', 1, 1.0000, 1, 'test-description', '{"Bound": 0, "NoteColor": "#000000", "BoundColor": "#111111"}', '2023-11-13 15:04:05', '2023-11-13 15:04:05', '2023-11-13 14:05:27', '2023-11-13 14:05:27', NULL),
(2, 'test-name', 1, 1.0000, 1, 'test-description', '{"Bound": 0, "NoteColor": "#000000", "BoundColor": "#111111"}', '2023-11-13 15:04:05', '2023-11-13 15:04:05', '2023-11-13 14:05:27', '2023-11-13 14:05:27', NULL);

INSERT INTO user_trips
(user_id, trip_id, leader)
VALUES
(1, 2, 1);
3 changes: 2 additions & 1 deletion server/internal/pkg/repositories/trip/teardown_test.sql
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
DELETE FROM user_trips;
DELETE FROM trips;
DELETE FROM users;
DELETE FROM users;
Loading
Loading