From 9c66b6530cb9da4cb863a86830e3405a687225da Mon Sep 17 00:00:00 2001 From: I539231 Date: Sun, 20 Aug 2023 10:00:27 +0200 Subject: [PATCH 1/2] feat: Start Gorm migration process --- .../jackc/pgservicefile/.travis.yml | 9 + .../pgx/v5/internal/nbconn/bufferqueue.go | 70 + .../jackc/pgx/v5/internal/nbconn/nbconn.go | 520 ++++ .../internal/nbconn/nbconn_fake_non_block.go | 11 + .../internal/nbconn/nbconn_real_non_block.go | 81 + .../vendor/github.com/jinzhu/gorm/README.md | 5 - .../github.com/jinzhu/gorm/association.go | 377 --- .../vendor/github.com/jinzhu/gorm/callback.go | 250 -- .../github.com/jinzhu/gorm/callback_create.go | 197 -- .../github.com/jinzhu/gorm/callback_delete.go | 63 - .../github.com/jinzhu/gorm/callback_query.go | 109 - .../jinzhu/gorm/callback_query_preload.go | 410 ---- .../jinzhu/gorm/callback_row_query.go | 41 - .../github.com/jinzhu/gorm/callback_save.go | 170 -- .../github.com/jinzhu/gorm/callback_update.go | 121 - .../vendor/github.com/jinzhu/gorm/dialect.go | 147 -- .../github.com/jinzhu/gorm/dialect_common.go | 196 -- .../github.com/jinzhu/gorm/dialect_mysql.go | 246 -- .../jinzhu/gorm/dialect_postgres.go | 147 -- .../github.com/jinzhu/gorm/dialect_sqlite3.go | 107 - .../jinzhu/gorm/dialects/mysql/mysql.go | 3 - .../jinzhu/gorm/dialects/postgres/postgres.go | 81 - .../github.com/jinzhu/gorm/docker-compose.yml | 30 - .../vendor/github.com/jinzhu/gorm/errors.go | 72 - .../vendor/github.com/jinzhu/gorm/field.go | 66 - .../github.com/jinzhu/gorm/interface.go | 24 - .../jinzhu/gorm/join_table_handler.go | 211 -- .../vendor/github.com/jinzhu/gorm/logger.go | 141 -- .../vendor/github.com/jinzhu/gorm/main.go | 886 ------- .../vendor/github.com/jinzhu/gorm/model.go | 14 - .../github.com/jinzhu/gorm/model_struct.go | 677 ------ .../vendor/github.com/jinzhu/gorm/naming.go | 124 - .../vendor/github.com/jinzhu/gorm/scope.go | 1425 ----------- .../vendor/github.com/jinzhu/gorm/search.go | 203 -- .../vendor/github.com/jinzhu/gorm/test_all.sh | 5 - .../vendor/github.com/jinzhu/gorm/utils.go | 226 -- .../vendor/github.com/jinzhu/gorm/wercker.yml | 149 -- .../vendor/github.com/jinzhu/now/Guardfile | 3 + .../github.com/jinzhu/{gorm => now}/License | 0 .../vendor/github.com/jinzhu/now/README.md | 137 ++ .../vendor/github.com/jinzhu/now/main.go | 200 ++ .../vendor/github.com/jinzhu/now/now.go | 245 ++ .../vendor/github.com/jinzhu/now/time.go | 9 + .../vendor/github.com/lib/pq/.gitignore | 6 - .../vendor/github.com/lib/pq/LICENSE.md | 8 - .../vendor/github.com/lib/pq/README.md | 36 - .../vendor/github.com/lib/pq/TESTS.md | 33 - .../vendor/github.com/lib/pq/array.go | 895 ------- .../vendor/github.com/lib/pq/buf.go | 91 - .../vendor/github.com/lib/pq/conn.go | 2112 ----------------- .../vendor/github.com/lib/pq/conn_go115.go | 8 - .../vendor/github.com/lib/pq/conn_go18.go | 247 -- .../vendor/github.com/lib/pq/connector.go | 120 - .../vendor/github.com/lib/pq/copy.go | 348 --- .../vendor/github.com/lib/pq/doc.go | 268 --- .../vendor/github.com/lib/pq/encode.go | 632 ----- .../vendor/github.com/lib/pq/error.go | 523 ---- .../vendor/github.com/lib/pq/hstore/hstore.go | 118 - .../vendor/github.com/lib/pq/krb.go | 27 - .../vendor/github.com/lib/pq/notice.go | 72 - .../vendor/github.com/lib/pq/notify.go | 858 ------- .../vendor/github.com/lib/pq/oid/doc.go | 6 - .../vendor/github.com/lib/pq/oid/types.go | 343 --- .../vendor/github.com/lib/pq/rows.go | 93 - .../vendor/github.com/lib/pq/scram/scram.go | 264 --- .../vendor/github.com/lib/pq/ssl.go | 204 -- 
.../github.com/lib/pq/ssl_permissions.go | 93 - .../vendor/github.com/lib/pq/ssl_windows.go | 10 - .../vendor/github.com/lib/pq/url.go | 76 - .../vendor/github.com/lib/pq/user_other.go | 10 - .../vendor/github.com/lib/pq/user_posix.go | 25 - .../vendor/github.com/lib/pq/user_windows.go | 27 - .../vendor/github.com/lib/pq/uuid.go | 23 - .../gorm => gorm.io/driver/mysql}/.gitignore | 3 + .../vendor/gorm.io/driver/mysql/License | 21 + .../vendor/gorm.io/driver/mysql/README.md | 51 + .../gorm.io/driver/mysql/error_translator.go | 21 + .../vendor/gorm.io/driver/mysql/migrator.go | 408 ++++ .../vendor/gorm.io/driver/mysql/mysql.go | 533 +++++ .../vendor/gorm.io/driver/postgres/.gitignore | 1 + .../vendor/gorm.io/driver/postgres/License | 21 + .../vendor/gorm.io/driver/postgres/README.md | 31 + .../driver/postgres/error_translator.go | 44 + .../gorm.io/driver/postgres/migrator.go | 771 ++++++ .../gorm.io/driver/postgres/postgres.go | 249 ++ .../vendor/gorm.io/gorm/.gitignore | 7 + .../vendor/gorm.io/gorm/.golangci.yml | 20 + .../vendor/gorm.io/gorm/LICENSE | 21 + .../vendor/gorm.io/gorm/README.md | 44 + .../vendor/gorm.io/gorm/association.go | 579 +++++ .../vendor/gorm.io/gorm/callbacks.go | 341 +++ .../gorm.io/gorm/callbacks/associations.go | 453 ++++ .../gorm.io/gorm/callbacks/callbacks.go | 83 + .../gorm.io/gorm/callbacks/callmethod.go | 32 + .../vendor/gorm.io/gorm/callbacks/create.go | 345 +++ .../vendor/gorm.io/gorm/callbacks/delete.go | 185 ++ .../vendor/gorm.io/gorm/callbacks/helper.go | 152 ++ .../gorm.io/gorm/callbacks/interfaces.go | 39 + .../vendor/gorm.io/gorm/callbacks/preload.go | 266 +++ .../vendor/gorm.io/gorm/callbacks/query.go | 316 +++ .../vendor/gorm.io/gorm/callbacks/raw.go | 17 + .../vendor/gorm.io/gorm/callbacks/row.go | 23 + .../gorm.io/gorm/callbacks/transaction.go | 32 + .../vendor/gorm.io/gorm/callbacks/update.go | 304 +++ .../vendor/gorm.io/gorm/chainable_api.go | 469 ++++ .../vendor/gorm.io/gorm/clause/clause.go | 89 + .../vendor/gorm.io/gorm/clause/delete.go | 23 + .../vendor/gorm.io/gorm/clause/expression.go | 385 +++ .../vendor/gorm.io/gorm/clause/from.go | 37 + .../vendor/gorm.io/gorm/clause/group_by.go | 48 + .../vendor/gorm.io/gorm/clause/insert.go | 39 + .../vendor/gorm.io/gorm/clause/joins.go | 47 + .../vendor/gorm.io/gorm/clause/limit.go | 48 + .../vendor/gorm.io/gorm/clause/locking.go | 31 + .../vendor/gorm.io/gorm/clause/on_conflict.go | 59 + .../vendor/gorm.io/gorm/clause/order_by.go | 54 + .../vendor/gorm.io/gorm/clause/returning.go | 34 + .../vendor/gorm.io/gorm/clause/select.go | 59 + .../vendor/gorm.io/gorm/clause/set.go | 60 + .../vendor/gorm.io/gorm/clause/update.go | 38 + .../vendor/gorm.io/gorm/clause/values.go | 45 + .../vendor/gorm.io/gorm/clause/where.go | 190 ++ .../vendor/gorm.io/gorm/clause/with.go | 3 + .../vendor/gorm.io/gorm/errors.go | 52 + .../vendor/gorm.io/gorm/finisher_api.go | 766 ++++++ .../vendor/gorm.io/gorm/gorm.go | 503 ++++ .../vendor/gorm.io/gorm/interfaces.go | 98 + .../vendor/gorm.io/gorm/logger/logger.go | 211 ++ .../vendor/gorm.io/gorm/logger/sql.go | 162 ++ .../vendor/gorm.io/gorm/migrator.go | 109 + .../gorm.io/gorm/migrator/column_type.go | 107 + .../vendor/gorm.io/gorm/migrator/index.go | 43 + .../vendor/gorm.io/gorm/migrator/migrator.go | 965 ++++++++ .../gorm.io/gorm/migrator/table_type.go | 33 + .../vendor/gorm.io/gorm/model.go | 16 + .../vendor/gorm.io/gorm/prepare_stmt.go | 229 ++ .../vendor/gorm.io/gorm/scan.go | 342 +++ .../vendor/gorm.io/gorm/schema/check.go | 35 + .../vendor/gorm.io/gorm/schema/field.go | 988 
++++++++ .../vendor/gorm.io/gorm/schema/index.go | 166 ++ .../vendor/gorm.io/gorm/schema/interfaces.go | 36 + .../vendor/gorm.io/gorm/schema/naming.go | 186 ++ .../vendor/gorm.io/gorm/schema/pool.go | 19 + .../gorm.io/gorm/schema/relationship.go | 699 ++++++ .../vendor/gorm.io/gorm/schema/schema.go | 370 +++ .../vendor/gorm.io/gorm/schema/serializer.go | 170 ++ .../vendor/gorm.io/gorm/schema/utils.go | 208 ++ .../vendor/gorm.io/gorm/soft_delete.go | 170 ++ .../vendor/gorm.io/gorm/statement.go | 728 ++++++ .../vendor/gorm.io/gorm/utils/utils.go | 150 ++ 150 files changed, 16017 insertions(+), 14499 deletions(-) create mode 100644 src/code.cloudfoundry.org/vendor/github.com/jackc/pgservicefile/.travis.yml create mode 100644 src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go create mode 100644 src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go create mode 100644 src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go create mode 100644 src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/README.md delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/association.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_create.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_delete.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_query.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_query_preload.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_row_query.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_save.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_update.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect_common.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect_mysql.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect_postgres.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialects/mysql/mysql.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialects/postgres/postgres.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/docker-compose.yml delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/errors.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/field.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/interface.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/join_table_handler.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/logger.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/main.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/model.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/model_struct.go delete mode 
100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/naming.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/scope.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/search.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/test_all.sh delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/utils.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/wercker.yml create mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/Guardfile rename src/code.cloudfoundry.org/vendor/github.com/jinzhu/{gorm => now}/License (100%) create mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/README.md create mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/main.go create mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/now.go create mode 100644 src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/time.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/.gitignore delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/LICENSE.md delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/README.md delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/TESTS.md delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/array.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/buf.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/conn.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/conn_go115.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/conn_go18.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/connector.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/copy.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/doc.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/encode.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/error.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/hstore/hstore.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/krb.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/notice.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/notify.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/oid/doc.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/oid/types.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/rows.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/scram/scram.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/ssl.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/ssl_permissions.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/ssl_windows.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/url.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/user_other.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/user_posix.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/user_windows.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/lib/pq/uuid.go rename src/code.cloudfoundry.org/vendor/{github.com/jinzhu/gorm => gorm.io/driver/mysql}/.gitignore (61%) create mode 100644 
src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/License create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/README.md create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/error_translator.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/migrator.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/mysql.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/.gitignore create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/License create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/README.md create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/error_translator.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/migrator.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/postgres.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/.gitignore create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/.golangci.yml create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/LICENSE create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/README.md create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/association.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/associations.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/callbacks.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/callmethod.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/create.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/delete.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/helper.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/interfaces.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/preload.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/query.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/raw.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/row.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/transaction.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/update.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/chainable_api.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/clause.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/delete.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/expression.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/from.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/group_by.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/insert.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/joins.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/limit.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/locking.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/on_conflict.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/order_by.go create mode 100644 
src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/returning.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/select.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/set.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/update.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/values.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/where.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/with.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/errors.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/finisher_api.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/gorm.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/interfaces.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/logger/logger.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/logger/sql.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/column_type.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/index.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/migrator.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/table_type.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/model.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/prepare_stmt.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/scan.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/check.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/field.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/index.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/interfaces.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/naming.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/pool.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/relationship.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/schema.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/serializer.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/utils.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/soft_delete.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/statement.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/utils/utils.go diff --git a/src/code.cloudfoundry.org/vendor/github.com/jackc/pgservicefile/.travis.yml b/src/code.cloudfoundry.org/vendor/github.com/jackc/pgservicefile/.travis.yml new file mode 100644 index 000000000..e176228e8 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/github.com/jackc/pgservicefile/.travis.yml @@ -0,0 +1,9 @@ +language: go + +go: + - 1.x + - tip + +matrix: + allow_failures: + - go: tip diff --git a/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go b/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go new file mode 100644 index 000000000..4bf25481c --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go @@ -0,0 +1,70 @@ +package nbconn + +import ( + "sync" +) + 
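+// bufferQueue below is a mutex-guarded FIFO of pooled byte buffers: r and w
+// index the next slots to pop from and push to. Once the queue drains
+// completely, both indexes reset and an oversized backing slice is shrunk back
+// to minBufferQueueLen, so a burst of buffered writes does not pin a large
+// array forever.
+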
+const minBufferQueueLen = 8 + +type bufferQueue struct { + lock sync.Mutex + queue []*[]byte + r, w int +} + +func (bq *bufferQueue) pushBack(buf *[]byte) { + bq.lock.Lock() + defer bq.lock.Unlock() + + if bq.w >= len(bq.queue) { + bq.growQueue() + } + bq.queue[bq.w] = buf + bq.w++ +} + +func (bq *bufferQueue) pushFront(buf *[]byte) { + bq.lock.Lock() + defer bq.lock.Unlock() + + if bq.w >= len(bq.queue) { + bq.growQueue() + } + copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w]) + bq.queue[bq.r] = buf + bq.w++ +} + +func (bq *bufferQueue) popFront() *[]byte { + bq.lock.Lock() + defer bq.lock.Unlock() + + if bq.r == bq.w { + return nil + } + + buf := bq.queue[bq.r] + bq.queue[bq.r] = nil // Clear reference so it can be garbage collected. + bq.r++ + + if bq.r == bq.w { + bq.r = 0 + bq.w = 0 + if len(bq.queue) > minBufferQueueLen { + bq.queue = make([]*[]byte, minBufferQueueLen) + } + } + + return buf +} + +func (bq *bufferQueue) growQueue() { + desiredLen := (len(bq.queue) + 1) * 3 / 2 + if desiredLen < minBufferQueueLen { + desiredLen = minBufferQueueLen + } + + newQueue := make([]*[]byte, desiredLen) + copy(newQueue, bq.queue) + bq.queue = newQueue +} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go b/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go new file mode 100644 index 000000000..7a38383f0 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go @@ -0,0 +1,520 @@ +// Package nbconn implements a non-blocking net.Conn wrapper. +// +// It is designed to solve three problems. +// +// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all +// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion. +// +// The second is the inability to use a write deadline with a TLS.Conn without killing the connection. +// +// The third is to efficiently check if a connection has been closed via a non-blocking read. +package nbconn + +import ( + "crypto/tls" + "errors" + "net" + "os" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/jackc/pgx/v5/internal/iobufpool" +) + +var errClosed = errors.New("closed") +var ErrWouldBlock = new(wouldBlockError) + +const fakeNonblockingWriteWaitDuration = 100 * time.Millisecond +const minNonblockingReadWaitDuration = time.Microsecond +const maxNonblockingReadWaitDuration = 100 * time.Millisecond + +// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read +// mode. +var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC) + +// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to +// ignore all future calls. +var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC) + +// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error. +type wouldBlockError struct{} + +func (*wouldBlockError) Error() string { + return "would block" +} + +func (*wouldBlockError) Timeout() bool { return true } +func (*wouldBlockError) Temporary() bool { return true } + +// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to +// the underlying connection. +type Conn interface { + net.Conn + + // Flush flushes any buffered writes. 
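+ // A typical sequence is Write, which only buffers and never blocks, followed
+ // by Flush to push the queued bytes to the wire. Read flushes implicitly
+ // before reading, so a request/response round trip cannot stall on an
+ // unflushed write.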
+ Flush() error + + // BufferReadUntilBlock reads and buffers any successfully read bytes until the read would block. + BufferReadUntilBlock() error +} + +// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn. +type NetConn struct { + // 64 bit fields accessed with atomics must be at beginning of struct to guarantee alignment for certain 32-bit + // architectures. See BUGS section of https://pkg.go.dev/sync/atomic and https://github.com/jackc/pgx/issues/1288 and + // https://github.com/jackc/pgx/issues/1307. Only access with atomics. + closed int64 // 0 = not closed, 1 = closed + + conn net.Conn + rawConn syscall.RawConn + + readQueue bufferQueue + writeQueue bufferQueue + + readFlushLock sync.Mutex + // non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the + // callback function's closure to pass the buf argument and receive the n and err results we avoid some allocations. + nonblockWriteFunc func(fd uintptr) (done bool) + nonblockWriteBuf []byte + nonblockWriteErr error + nonblockWriteN int + + // non-blocking reads with syscall.RawConn are done with a callback function. By using these fields instead of the + // callback function's closure to pass the buf argument and receive the n and err results we avoid some allocations. + nonblockReadFunc func(fd uintptr) (done bool) + nonblockReadBuf []byte + nonblockReadErr error + nonblockReadN int + + readDeadlineLock sync.Mutex + readDeadline time.Time + readNonblocking bool + fakeNonBlockingShortReadCount int + fakeNonblockingReadWaitDuration time.Duration + + writeDeadlineLock sync.Mutex + writeDeadline time.Time +} + +func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn { + nc := &NetConn{ + conn: conn, + fakeNonblockingReadWaitDuration: maxNonblockingReadWaitDuration, + } + + if !fakeNonBlockingIO { + if sc, ok := conn.(syscall.Conn); ok { + if rawConn, err := sc.SyscallConn(); err == nil { + nc.rawConn = rawConn + } + } + } + + return nc +} + +// Read implements io.Reader. +func (c *NetConn) Read(b []byte) (n int, err error) { + if c.isClosed() { + return 0, errClosed + } + + c.readFlushLock.Lock() + defer c.readFlushLock.Unlock() + + err = c.flush() + if err != nil { + return 0, err + } + + for n < len(b) { + buf := c.readQueue.popFront() + if buf == nil { + break + } + copiedN := copy(b[n:], *buf) + if copiedN < len(*buf) { + *buf = (*buf)[copiedN:] + c.readQueue.pushFront(buf) + } else { + iobufpool.Put(buf) + } + n += copiedN + } + + // If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to + // Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block. + if n > 0 { + return n, nil + } + + var readNonblocking bool + c.readDeadlineLock.Lock() + readNonblocking = c.readNonblocking + c.readDeadlineLock.Unlock() + + var readN int + if readNonblocking { + readN, err = c.nonblockingRead(b[n:]) + } else { + readN, err = c.conn.Read(b[n:]) + } + n += readN + return n, err +} + +// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is +// closed. Call Flush to actually write to the underlying connection.
+func (c *NetConn) Write(b []byte) (n int, err error) { + if c.isClosed() { + return 0, errClosed + } + + buf := iobufpool.Get(len(b)) + copy(*buf, b) + c.writeQueue.pushBack(buf) + return len(b), nil +} + +func (c *NetConn) Close() (err error) { + swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1) + if !swapped { + return errClosed + } + + defer func() { + closeErr := c.conn.Close() + if err == nil { + err = closeErr + } + }() + + c.readFlushLock.Lock() + defer c.readFlushLock.Unlock() + err = c.flush() + if err != nil { + return err + } + + return nil +} + +func (c *NetConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *NetConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline is the equivalent of calling SetReadDeadline(t) and SetWriteDeadline(t). +func (c *NetConn) SetDeadline(t time.Time) error { + err := c.SetReadDeadline(t) + if err != nil { + return err + } + return c.SetWriteDeadline(t) +} + +// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking. +func (c *NetConn) SetReadDeadline(t time.Time) error { + if c.isClosed() { + return errClosed + } + + c.readDeadlineLock.Lock() + defer c.readDeadlineLock.Unlock() + if c.readDeadline == disableSetDeadlineDeadline { + return nil + } + if t == disableSetDeadlineDeadline { + c.readDeadline = t + return nil + } + + if t == NonBlockingDeadline { + c.readNonblocking = true + t = time.Time{} + } else { + c.readNonblocking = false + } + + c.readDeadline = t + + return c.conn.SetReadDeadline(t) +} + +func (c *NetConn) SetWriteDeadline(t time.Time) error { + if c.isClosed() { + return errClosed + } + + c.writeDeadlineLock.Lock() + defer c.writeDeadlineLock.Unlock() + if c.writeDeadline == disableSetDeadlineDeadline { + return nil + } + if t == disableSetDeadlineDeadline { + c.writeDeadline = t + return nil + } + + c.writeDeadline = t + + return c.conn.SetWriteDeadline(t) +} + +func (c *NetConn) Flush() error { + if c.isClosed() { + return errClosed + } + + c.readFlushLock.Lock() + defer c.readFlushLock.Unlock() + return c.flush() +} + +// flush does the actual work of flushing the writeQueue. readFlushLock must already be held. +func (c *NetConn) flush() error { + var stopChan chan struct{} + var errChan chan error + + defer func() { + if stopChan != nil { + select { + case stopChan <- struct{}{}: + case <-errChan: + } + } + }() + + for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() { + remainingBuf := *buf + for len(remainingBuf) > 0 { + n, err := c.nonblockingWrite(remainingBuf) + remainingBuf = remainingBuf[n:] + if err != nil { + if !errors.Is(err, ErrWouldBlock) { + *buf = (*buf)[:len(remainingBuf)] + copy(*buf, remainingBuf) + c.writeQueue.pushFront(buf) + return err + } + + // Writing was blocked. Reading might unblock it.
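+ // For example, the peer may be unable to read because it is itself
+ // blocked writing to us; draining our receive buffer lets it make
+ // progress. This is the mutual-write deadlock described in the package
+ // comment.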
+ if stopChan == nil { + stopChan, errChan = c.bufferNonblockingRead() + } + + select { + case err := <-errChan: + stopChan = nil + return err + default: + } + + } + } + iobufpool.Put(buf) + } + + return nil +} + +func (c *NetConn) BufferReadUntilBlock() error { + for { + buf := iobufpool.Get(8 * 1024) + n, err := c.nonblockingRead(*buf) + if n > 0 { + *buf = (*buf)[:n] + c.readQueue.pushBack(buf) + } else if n == 0 { + iobufpool.Put(buf) + } + + if err != nil { + if errors.Is(err, ErrWouldBlock) { + return nil + } else { + return err + } + } + } +} + +func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) { + stopChan = make(chan struct{}) + errChan = make(chan error, 1) + + go func() { + for { + err := c.BufferReadUntilBlock() + if err != nil { + errChan <- err + return + } + + select { + case <-stopChan: + return + default: + } + } + }() + + return stopChan, errChan +} + +func (c *NetConn) isClosed() bool { + closed := atomic.LoadInt64(&c.closed) + return closed == 1 +} + +func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) { + if c.rawConn == nil { + return c.fakeNonblockingWrite(b) + } else { + return c.realNonblockingWrite(b) + } +} + +func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) { + c.writeDeadlineLock.Lock() + defer c.writeDeadlineLock.Unlock() + + deadline := time.Now().Add(fakeNonblockingWriteWaitDuration) + if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) { + err = c.conn.SetWriteDeadline(deadline) + if err != nil { + return 0, err + } + defer func() { + // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. + c.conn.SetWriteDeadline(c.writeDeadline) + + if err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + err = ErrWouldBlock + } + } + }() + } + + return c.conn.Write(b) +} + +func (c *NetConn) nonblockingRead(b []byte) (n int, err error) { + if c.rawConn == nil { + return c.fakeNonblockingRead(b) + } else { + return c.realNonblockingRead(b) + } +} + +func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) { + c.readDeadlineLock.Lock() + defer c.readDeadlineLock.Unlock() + + // The first 5 reads only read 1 byte at a time. This should give us 4 chances to read when we are sure the bytes are + // already in Go or the OS's receive buffer. + if c.fakeNonBlockingShortReadCount < 5 && len(b) > 0 && c.fakeNonblockingReadWaitDuration < minNonblockingReadWaitDuration { + b = b[:1] + } + + startTime := time.Now() + deadline := startTime.Add(c.fakeNonblockingReadWaitDuration) + if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) { + err = c.conn.SetReadDeadline(deadline) + if err != nil { + return 0, err + } + defer func() { + // If the read was successful and the wait duration is not already the minimum + if err == nil && c.fakeNonblockingReadWaitDuration > minNonblockingReadWaitDuration { + endTime := time.Now() + + if n > 0 && c.fakeNonBlockingShortReadCount < 5 { + c.fakeNonBlockingShortReadCount++ + } + + // The wait duration should be 2x the fastest read that has occurred. This should give reasonable assurance that + // a Read deadline will not block a read before it has a chance to read data already in Go or the OS's receive + // buffer. 
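+ // For example, a read that completed in 40µs proposes an 80µs wait. The
+ // proposal is floored at minNonblockingReadWaitDuration, and the wait
+ // duration only ever shrinks because larger proposals are discarded.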
+ proposedWait := endTime.Sub(startTime) * 2 + if proposedWait < minNonblockingReadWaitDuration { + proposedWait = minNonblockingReadWaitDuration + } + if proposedWait < c.fakeNonblockingReadWaitDuration { + c.fakeNonblockingReadWaitDuration = proposedWait + } + } + + // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. + c.conn.SetReadDeadline(c.readDeadline) + + if err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + err = ErrWouldBlock + } + } + }() + } + + return c.conn.Read(b) +} + +// syscall.Conn is an interface, so the type assertion in NewNetConn decides whether real non-blocking I/O is available. + +// TLSClient establishes a TLS connection as a client over conn using config. +// +// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby +// potentially causing read and write deadlines to behave unexpectedly, Handshake is called explicitly before the +// *TLSConn is returned. +func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) { + tc := tls.Client(conn, config) + err := tc.Handshake() + if err != nil { + return nil, err + } + + // Ensure the last written part of Handshake is actually sent. + err = conn.Flush() + if err != nil { + return nil, err + } + + return &TLSConn{ + tlsConn: tc, + nbConn: conn, + }, nil +} + +// TLSConn is a TLS wrapper around a *NetConn. It works around a temporary write error (such as a timeout) being fatal to a +// tls.Conn. +type TLSConn struct { + tlsConn *tls.Conn + nbConn *NetConn +} + +func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) } +func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) } +func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() } +func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() } +func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() } +func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() } + +func (tc *TLSConn) Close() error { + // tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then + // sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our + // own 5 second deadline then make all set deadlines no-op.
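+ // Both error returns are deliberately ignored: the first call arms a real
+ // deadline so Close cannot block forever, and the second stores the
+ // disableSetDeadlineDeadline sentinel, after which every Set*Deadline call
+ // that reaches the wrapped NetConn is a no-op returning nil.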
+ tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5)) + tc.tlsConn.SetDeadline(disableSetDeadlineDeadline) + + return tc.tlsConn.Close() +} + +func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) } +func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) } +func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) } diff --git a/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go b/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go new file mode 100644 index 000000000..4915c6219 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go @@ -0,0 +1,11 @@ +//go:build !unix + +package nbconn + +func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { + return c.fakeNonblockingWrite(b) +} + +func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { + return c.fakeNonblockingRead(b) +} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go b/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go new file mode 100644 index 000000000..e93372f25 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go @@ -0,0 +1,81 @@ +//go:build unix + +package nbconn + +import ( + "errors" + "io" + "syscall" +) + +// realNonblockingWrite does a non-blocking write. readFlushLock must already be held. +func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { + if c.nonblockWriteFunc == nil { + c.nonblockWriteFunc = func(fd uintptr) (done bool) { + c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf) + return true + } + } + c.nonblockWriteBuf = b + c.nonblockWriteN = 0 + c.nonblockWriteErr = nil + + err = c.rawConn.Write(c.nonblockWriteFunc) + n = c.nonblockWriteN + c.nonblockWriteBuf = nil // ensure that no reference to b is kept. + if err == nil && c.nonblockWriteErr != nil { + if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) { + err = ErrWouldBlock + } else { + err = c.nonblockWriteErr + } + } + if err != nil { + // n may be -1 when an error occurs. + if n < 0 { + n = 0 + } + + return n, err + } + + return n, nil +} + +func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { + if c.nonblockReadFunc == nil { + c.nonblockReadFunc = func(fd uintptr) (done bool) { + c.nonblockReadN, c.nonblockReadErr = syscall.Read(int(fd), c.nonblockReadBuf) + return true + } + } + c.nonblockReadBuf = b + c.nonblockReadN = 0 + c.nonblockReadErr = nil + + err = c.rawConn.Read(c.nonblockReadFunc) + n = c.nonblockReadN + c.nonblockReadBuf = nil // ensure that no reference to b is kept. + if err == nil && c.nonblockReadErr != nil { + if errors.Is(c.nonblockReadErr, syscall.EWOULDBLOCK) { + err = ErrWouldBlock + } else { + err = c.nonblockReadErr + } + } + if err != nil { + // n may be -1 when an error occurs. + if n < 0 { + n = 0 + } + + return n, err + } + + // syscall read did not return an error and 0 bytes were read means EOF. 
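+ // A POSIX read(2) on a stream socket returns 0 only at end of stream, so
+ // the raw syscall result is mapped onto the io.Reader convention of io.EOF.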
+ if n == 0 { + return 0, io.EOF + } + + return n, nil +} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/README.md b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/README.md deleted file mode 100644 index 85588a791..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# GORM - -GORM V2 moved to https://github.com/go-gorm/gorm - -GORM V1 Doc https://v1.gorm.io/ diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/association.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/association.go deleted file mode 100644 index a73344fe6..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/association.go +++ /dev/null @@ -1,377 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" -) - -// Association Mode contains some helper methods to handle relationship things easily. -type Association struct { - Error error - scope *Scope - column string - field *Field -} - -// Find find out all related associations -func (association *Association) Find(value interface{}) *Association { - association.scope.related(value, association.column) - return association.setErr(association.scope.db.Error) -} - -// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to -func (association *Association) Append(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - if relationship := association.field.Relationship; relationship.Kind == "has_one" { - return association.Replace(values...) - } - return association.saveAssociations(values...) -} - -// Replace replace current associations with new one -func (association *Association) Replace(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) - - // Append new values - association.field.Set(reflect.Zero(association.field.Field.Type())) - association.saveAssociations(values...) 
- - // Belongs To - if relationship.Kind == "belongs_to" { - // Set foreign key to be null when clearing value (length equals 0) - if len(values) == 0 { - // Set foreign key to be nil - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) - } - } else { - // Polymorphic Relations - if relationship.PolymorphicDBName != "" { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) - } - - // Delete Relations except new created - if len(values) > 0 { - var associationForeignFieldNames, associationForeignDBNames []string - if relationship.Kind == "many_to_many" { - // if many to many relations, get association fields name from association foreign keys - associationScope := scope.New(reflect.New(field.Type()).Interface()) - for idx, dbName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(dbName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx]) - } - } - } else { - // If has one/many relations, use primary keys - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - associationForeignDBNames = append(associationForeignDBNames, field.DBName) - } - } - - newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) - - if len(newPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) - } - } - - if relationship.Kind == "many_to_many" { - // if many to many relations, delete related relations from join table - var sourceForeignFieldNames []string - - for _, dbName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name) - } - } - - if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { - newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) 
- var foreignKeyMap = map[string]interface{}{} - for idx, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } - } - return association -} - -// Delete remove relationship between source & passed arguments, but won't delete those arguments -func (association *Association) Delete(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) - - if len(values) == 0 { - return association - } - - var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) - deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) - } - - deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) - - if relationship.Kind == "many_to_many" { - // source value's foreign keys - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - // get association's foreign fields name - var associationScope = scope.New(reflect.New(field.Type()).Interface()) - var associationForeignFieldNames []string - for _, associationDBName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(associationDBName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - } - } - - // association value's foreign keys - deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) - sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) - } else { - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - - if relationship.Kind == "belongs_to" { - // find with deleting relation's foreign keys - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) 
- newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // set foreign key to be null if there are some records affected - modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() - if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { - if results.RowsAffected > 0 { - scope.updatedAttrsWithValues(foreignKeyMap) - } - } else { - association.setErr(results.Error) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // find all relations - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // only include those deleting relations - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)), - toQueryValues(deletingPrimaryKeys)..., - ) - - // set matched relation's foreign key to be null - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } - } - - // Remove deleted records from source's field - if association.Error == nil { - if field.Kind() == reflect.Slice { - leftValues := reflect.Zero(field.Type()) - - for i := 0; i < field.Len(); i++ { - reflectValue := field.Index(i) - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] - var isDeleted = false - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - isDeleted = true - break - } - } - if !isDeleted { - leftValues = reflect.Append(leftValues, reflectValue) - } - } - - association.field.Set(leftValues) - } else if field.Kind() == reflect.Struct { - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - association.field.Set(reflect.Zero(field.Type())) - break - } - } - } - } - - return association -} - -// Clear remove relationship between source & current associations, won't delete those associations -func (association *Association) Clear() *Association { - return association.Replace() -} - -// Count return the count of current associations -func (association *Association) Count() int { - var ( - count = 0 - relationship = association.field.Relationship - scope = association.scope - fieldValue = association.field.Field.Interface() - query = scope.DB() - ) - - switch relationship.Kind { - case "many_to_many": - query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) - case "has_many", "has_one": - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - case "belongs_to": - primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - } - - if relationship.PolymorphicType != "" { - query = query.Where( - 
fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), - relationship.PolymorphicValue, - ) - } - - if err := query.Model(fieldValue).Count(&count).Error; err != nil { - association.Error = err - } - return count -} - -// saveAssociations save passed values as associations -func (association *Association) saveAssociations(values ...interface{}) *Association { - var ( - scope = association.scope - field = association.field - relationship = field.Relationship - ) - - saveAssociation := func(reflectValue reflect.Value) { - // value has to been pointer - if reflectValue.Kind() != reflect.Ptr { - reflectPtr := reflect.New(reflectValue.Type()) - reflectPtr.Elem().Set(reflectValue) - reflectValue = reflectPtr - } - - // value has to been saved for many2many - if relationship.Kind == "many_to_many" { - if scope.New(reflectValue.Interface()).PrimaryKeyZero() { - association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) - } - } - - // Assign Fields - var fieldType = field.Field.Type() - var setFieldBackToValue, setSliceFieldBackToValue bool - if reflectValue.Type().AssignableTo(fieldType) { - field.Set(reflectValue) - } else if reflectValue.Type().Elem().AssignableTo(fieldType) { - // if field's type is struct, then need to set value back to argument after save - setFieldBackToValue = true - field.Set(reflectValue.Elem()) - } else if fieldType.Kind() == reflect.Slice { - if reflectValue.Type().AssignableTo(fieldType.Elem()) { - field.Set(reflect.Append(field.Field, reflectValue)) - } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { - // if field's type is slice of struct, then need to set value back to argument after save - setSliceFieldBackToValue = true - field.Set(reflect.Append(field.Field, reflectValue.Elem())) - } - } - - if relationship.Kind == "many_to_many" { - association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) - } else { - association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) - - if setFieldBackToValue { - reflectValue.Elem().Set(field.Field) - } else if setSliceFieldBackToValue { - reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) - } - } - } - - for _, value := range values { - reflectValue := reflect.ValueOf(value) - indirectReflectValue := reflect.Indirect(reflectValue) - if indirectReflectValue.Kind() == reflect.Struct { - saveAssociation(reflectValue) - } else if indirectReflectValue.Kind() == reflect.Slice { - for i := 0; i < indirectReflectValue.Len(); i++ { - saveAssociation(indirectReflectValue.Index(i)) - } - } else { - association.setErr(errors.New("invalid value type")) - } - } - return association -} - -// setErr set error when the error is not nil. And return Association. 
-func (association *Association) setErr(err error) *Association { - if err != nil { - association.Error = err - } - return association -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback.go deleted file mode 100644 index 1f0e3c79c..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback.go +++ /dev/null @@ -1,250 +0,0 @@ -package gorm - -import "fmt" - -// DefaultCallback default callbacks defined by gorm -var DefaultCallback = &Callback{logger: nopLogger{}} - -// Callback is a struct that contains all CRUD callbacks -// Field `creates` contains callbacks will be call when creating object -// Field `updates` contains callbacks will be call when updating object -// Field `deletes` contains callbacks will be call when deleting object -// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association... -// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... -// Field `processors` contains all callback processors, will be used to generate above callbacks in order -type Callback struct { - logger logger - creates []*func(scope *Scope) - updates []*func(scope *Scope) - deletes []*func(scope *Scope) - queries []*func(scope *Scope) - rowQueries []*func(scope *Scope) - processors []*CallbackProcessor -} - -// CallbackProcessor contains callback informations -type CallbackProcessor struct { - logger logger - name string // current callback's name - before string // register current callback before a callback - after string // register current callback after a callback - replace bool // replace callbacks with same name - remove bool // delete callbacks with same name - kind string // callback type: create, update, delete, query, row_query - processor *func(scope *Scope) // callback handler - parent *Callback -} - -func (c *Callback) clone(logger logger) *Callback { - return &Callback{ - logger: logger, - creates: c.creates, - updates: c.updates, - deletes: c.deletes, - queries: c.queries, - rowQueries: c.rowQueries, - processors: c.processors, - } -} - -// Create could be used to register callbacks for creating object -// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) { -// // business logic -// ... -// -// // set error if some thing wrong happened, will rollback the creating -// scope.Err(errors.New("error")) -// }) -func (c *Callback) Create() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "create", parent: c} -} - -// Update could be used to register callbacks for updating object, refer `Create` for usage -func (c *Callback) Update() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "update", parent: c} -} - -// Delete could be used to register callbacks for deleting object, refer `Create` for usage -func (c *Callback) Delete() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c} -} - -// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`... 
-// Refer `Create` for usage -func (c *Callback) Query() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "query", parent: c} -} - -// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage -func (c *Callback) RowQuery() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c} -} - -// After insert a new callback after callback `callbackName`, refer `Callbacks.Create` -func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor { - cp.after = callbackName - return cp -} - -// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create` -func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { - cp.before = callbackName - return cp -} - -// Register a new callback, refer `Callbacks.Create` -func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { - if cp.kind == "row_query" { - if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { - cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName)) - cp.before = "gorm:row_query" - } - } - - cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum())) - cp.name = callbackName - cp.processor = &callback - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Remove a registered callback -// db.Callback().Create().Remove("gorm:update_time_stamp_when_create") -func (cp *CallbackProcessor) Remove(callbackName string) { - cp.logger.Print("info", fmt.Sprintf("[info] removing callback `%v` from %v", callbackName, fileWithLineNum())) - cp.name = callbackName - cp.remove = true - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Replace a registered callback with new callback -// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { -// scope.SetColumn("CreatedAt", now) -// scope.SetColumn("UpdatedAt", now) -// }) -func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum())) - cp.name = callbackName - cp.processor = &callback - cp.replace = true - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Get registered callback -// db.Callback().Create().Get("gorm:create") -func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { - for _, p := range cp.parent.processors { - if p.name == callbackName && p.kind == cp.kind { - if p.remove { - callback = nil - } else { - callback = *p.processor - } - } - } - return -} - -// getRIndex get right index from string slice -func getRIndex(strs []string, str string) int { - for i := len(strs) - 1; i >= 0; i-- { - if strs[i] == str { - return i - } - } - return -1 -} - -// sortProcessors sort callback processors based on its before, after, remove, replace -func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { - var ( - allNames, sortedNames []string - sortCallbackProcessor func(c *CallbackProcessor) - ) - - for _, cp := range cps { - // show warning message the callback name already exists - if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - 
cp.logger.Print("warning", fmt.Sprintf("[warning] duplicated callback `%v` from %v", cp.name, fileWithLineNum())) - } - allNames = append(allNames, cp.name) - } - - sortCallbackProcessor = func(c *CallbackProcessor) { - if getRIndex(sortedNames, c.name) == -1 { // if not sorted - if c.before != "" { // if defined before callback - if index := getRIndex(sortedNames, c.before); index != -1 { - // if before callback already sorted, append current callback just after it - sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) - } else if index := getRIndex(allNames, c.before); index != -1 { - // if before callback exists but haven't sorted, append current callback to last - sortedNames = append(sortedNames, c.name) - sortCallbackProcessor(cps[index]) - } - } - - if c.after != "" { // if defined after callback - if index := getRIndex(sortedNames, c.after); index != -1 { - // if after callback already sorted, append current callback just before it - sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) - } else if index := getRIndex(allNames, c.after); index != -1 { - // if after callback exists but haven't sorted - cp := cps[index] - // set after callback's before callback to current callback - if cp.before == "" { - cp.before = c.name - } - sortCallbackProcessor(cp) - } - } - - // if current callback haven't been sorted, append it to last - if getRIndex(sortedNames, c.name) == -1 { - sortedNames = append(sortedNames, c.name) - } - } - } - - for _, cp := range cps { - sortCallbackProcessor(cp) - } - - var sortedFuncs []*func(scope *Scope) - for _, name := range sortedNames { - if index := getRIndex(allNames, name); !cps[index].remove { - sortedFuncs = append(sortedFuncs, cps[index].processor) - } - } - - return sortedFuncs -} - -// reorder all registered processors, and reset CRUD callbacks -func (c *Callback) reorder() { - var creates, updates, deletes, queries, rowQueries []*CallbackProcessor - - for _, processor := range c.processors { - if processor.name != "" { - switch processor.kind { - case "create": - creates = append(creates, processor) - case "update": - updates = append(updates, processor) - case "delete": - deletes = append(deletes, processor) - case "query": - queries = append(queries, processor) - case "row_query": - rowQueries = append(rowQueries, processor) - } - } - } - - c.creates = sortProcessors(creates) - c.updates = sortProcessors(updates) - c.deletes = sortProcessors(deletes) - c.queries = sortProcessors(queries) - c.rowQueries = sortProcessors(rowQueries) -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_create.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_create.go deleted file mode 100644 index c4d25f372..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_create.go +++ /dev/null @@ -1,197 +0,0 @@ -package gorm - -import ( - "fmt" - "strings" -) - -// Define callbacks for creating -func init() { - DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) - DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) - DefaultCallback.Create().Register("gorm:create", createCallback) - DefaultCallback.Create().Register("gorm:force_reload_after_create", 
forceReloadAfterCreateCallback) - DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback) - DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating -func beforeCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeCreate") - } -} - -// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating -func updateTimeStampForCreateCallback(scope *Scope) { - if !scope.HasError() { - now := scope.db.nowFunc() - - if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { - if createdAtField.IsBlank { - createdAtField.Set(now) - } - } - - if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok { - if updatedAtField.IsBlank { - updatedAtField.Set(now) - } - } - } -} - -// createCallback the callback used to insert data into database -func createCallback(scope *Scope) { - if !scope.HasError() { - defer scope.trace(NowFunc()) - - var ( - columns, placeholders []string - blankColumnsWithDefaultValue []string - ) - - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if field.IsNormal && !field.IsIgnored { - if field.IsBlank && field.HasDefaultValue { - blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) - scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) - } else if !field.IsPrimaryKey || !field.IsBlank { - columns = append(columns, scope.Quote(field.DBName)) - placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) - } - } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { - for _, foreignKey := range field.Relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - columns = append(columns, scope.Quote(foreignField.DBName)) - placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) - } - } - } - } - } - - var ( - returningColumn = "*" - quotedTableName = scope.QuotedTableName() - primaryField = scope.PrimaryField() - extraOption string - insertModifier string - ) - - if str, ok := scope.Get("gorm:insert_option"); ok { - extraOption = fmt.Sprint(str) - } - if str, ok := scope.Get("gorm:insert_modifier"); ok { - insertModifier = strings.ToUpper(fmt.Sprint(str)) - if insertModifier == "INTO" { - insertModifier = "" - } - } - - if primaryField != nil { - returningColumn = scope.Quote(primaryField.DBName) - } - - lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns) - var lastInsertIDReturningSuffix string - if lastInsertIDOutputInterstitial == "" { - lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) - } - - if len(columns) == 0 { - scope.Raw(fmt.Sprintf( - "INSERT%v INTO %v %v%v%v", - addExtraSpaceIfExist(insertModifier), - quotedTableName, - scope.Dialect().DefaultValueStr(), - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } else { - scope.Raw(fmt.Sprintf( - "INSERT%v INTO %v (%v)%v VALUES (%v)%v%v", - addExtraSpaceIfExist(insertModifier), - scope.QuotedTableName(), - strings.Join(columns, ","), - 
addExtraSpaceIfExist(lastInsertIDOutputInterstitial), - strings.Join(placeholders, ","), - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } - - // execute create sql: no primaryField - if primaryField == nil { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - // set rows affected count - scope.db.RowsAffected, _ = result.RowsAffected() - - // set primary value to primary field - if primaryField != nil && primaryField.IsBlank { - if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { - scope.Err(primaryField.Set(primaryValue)) - } - } - } - return - } - - // execute create sql: lastInsertID implemention for majority of dialects - if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - // set rows affected count - scope.db.RowsAffected, _ = result.RowsAffected() - - // set primary value to primary field - if primaryField != nil && primaryField.IsBlank { - if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { - scope.Err(primaryField.Set(primaryValue)) - } - } - } - return - } - - // execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql) - if primaryField.Field.CanAddr() { - if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { - primaryField.IsBlank = false - scope.db.RowsAffected = 1 - } - } else { - scope.Err(ErrUnaddressable) - } - return - } -} - -// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object -func forceReloadAfterCreateCallback(scope *Scope) { - if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { - db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string)) - for _, field := range scope.Fields() { - if field.IsPrimaryKey && !field.IsBlank { - db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface()) - } - } - db.Scan(scope.Value) - } -} - -// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating -func afterCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterCreate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_delete.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_delete.go deleted file mode 100644 index 48b97acbf..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_delete.go +++ /dev/null @@ -1,63 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" -) - -// Define callbacks for deleting -func init() { - DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) - DefaultCallback.Delete().Register("gorm:delete", deleteCallback) - DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback) - DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeDeleteCallback will invoke `BeforeDelete` method before deleting -func beforeDeleteCallback(scope *Scope) { - if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("missing WHERE clause 
while deleting")) - return - } - if !scope.HasError() { - scope.CallMethod("BeforeDelete") - } -} - -// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) -func deleteCallback(scope *Scope) { - if !scope.HasError() { - var extraOption string - if str, ok := scope.Get("gorm:delete_option"); ok { - extraOption = fmt.Sprint(str) - } - - deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt") - - if !scope.Search.Unscoped && hasDeletedAtField { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v=%v%v%v", - scope.QuotedTableName(), - scope.Quote(deletedAtField.DBName), - scope.AddToVars(scope.db.nowFunc()), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } else { - scope.Raw(fmt.Sprintf( - "DELETE FROM %v%v%v", - scope.QuotedTableName(), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterDeleteCallback will invoke `AfterDelete` method after deleting -func afterDeleteCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterDelete") - } -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_query.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_query.go deleted file mode 100644 index 544afd631..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_query.go +++ /dev/null @@ -1,109 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" -) - -// Define callbacks for querying -func init() { - DefaultCallback.Query().Register("gorm:query", queryCallback) - DefaultCallback.Query().Register("gorm:preload", preloadCallback) - DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback) -} - -// queryCallback used to query data from database -func queryCallback(scope *Scope) { - if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { - return - } - - //we are only preloading relations, dont touch base model - if _, skip := scope.InstanceGet("gorm:only_preload"); skip { - return - } - - defer scope.trace(NowFunc()) - - var ( - isSlice, isPtr bool - resultType reflect.Type - results = scope.IndirectValue() - ) - - if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { - if primaryField := scope.PrimaryField(); primaryField != nil { - scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy)) - } - } - - if value, ok := scope.Get("gorm:query_destination"); ok { - results = indirect(reflect.ValueOf(value)) - } - - if kind := results.Kind(); kind == reflect.Slice { - isSlice = true - resultType = results.Type().Elem() - results.Set(reflect.MakeSlice(results.Type(), 0, 0)) - - if resultType.Kind() == reflect.Ptr { - isPtr = true - resultType = resultType.Elem() - } - } else if kind != reflect.Struct { - scope.Err(errors.New("unsupported destination, should be slice or struct")) - return - } - - scope.prepareQuerySQL() - - if !scope.HasError() { - scope.db.RowsAffected = 0 - - if str, ok := scope.Get("gorm:query_hint"); ok { - scope.SQL = fmt.Sprint(str) + scope.SQL - } - - if str, ok := scope.Get("gorm:query_option"); ok { - scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) - } - - if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - scope.db.RowsAffected++ - - elem := results - if isSlice { - elem = 
reflect.New(resultType).Elem() - } - - scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) - - if isSlice { - if isPtr { - results.Set(reflect.Append(results, elem.Addr())) - } else { - results.Set(reflect.Append(results, elem)) - } - } - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } else if scope.db.RowsAffected == 0 && !isSlice { - scope.Err(ErrRecordNotFound) - } - } - } -} - -// afterQueryCallback will invoke `AfterFind` method after querying -func afterQueryCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterFind") - } -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_query_preload.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_query_preload.go deleted file mode 100644 index a936180ad..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_query_preload.go +++ /dev/null @@ -1,410 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strconv" - "strings" -) - -// preloadCallback used to preload associations -func preloadCallback(scope *Scope) { - if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { - return - } - - if ap, ok := scope.Get("gorm:auto_preload"); ok { - // If gorm:auto_preload IS NOT a bool then auto preload. - // Else if it IS a bool, use the value - if apb, ok := ap.(bool); !ok { - autoPreload(scope) - } else if apb { - autoPreload(scope) - } - } - - if scope.Search.preload == nil || scope.HasError() { - return - } - - var ( - preloadedMap = map[string]bool{} - fields = scope.Fields() - ) - - for _, preload := range scope.Search.preload { - var ( - preloadFields = strings.Split(preload.schema, ".") - currentScope = scope - currentFields = fields - ) - - for idx, preloadField := range preloadFields { - var currentPreloadConditions []interface{} - - if currentScope == nil { - continue - } - - // if not preloaded - if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { - - // assign search conditions to last preload - if idx == len(preloadFields)-1 { - currentPreloadConditions = preload.conditions - } - - for _, field := range currentFields { - if field.Name != preloadField || field.Relationship == nil { - continue - } - - switch field.Relationship.Kind { - case "has_one": - currentScope.handleHasOnePreload(field, currentPreloadConditions) - case "has_many": - currentScope.handleHasManyPreload(field, currentPreloadConditions) - case "belongs_to": - currentScope.handleBelongsToPreload(field, currentPreloadConditions) - case "many_to_many": - currentScope.handleManyToManyPreload(field, currentPreloadConditions) - default: - scope.Err(errors.New("unsupported relation")) - } - - preloadedMap[preloadKey] = true - break - } - - if !preloadedMap[preloadKey] { - scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType)) - return - } - } - - // preload next level - if idx < len(preloadFields)-1 { - currentScope = currentScope.getColumnAsScope(preloadField) - if currentScope != nil { - currentFields = currentScope.Fields() - } - } - } - } -} - -func autoPreload(scope *Scope) { - for _, field := range scope.Fields() { - if field.Relationship == nil { - continue - } - - if val, ok := field.TagSettingsGet("PRELOAD"); ok { - if preload, err := strconv.ParseBool(val); err != nil { - scope.Err(errors.New("invalid preload option")) - return - } else if !preload { - continue - } - } - - scope.Search.Preload(field.Name) - } -} - -func (scope 
*Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { - var ( - preloadDB = scope.NewDB() - preloadConditions []interface{} - ) - - for _, condition := range conditions { - if scopes, ok := condition.(func(*DB) *DB); ok { - preloadDB = scopes(preloadDB) - } else { - preloadConditions = append(preloadConditions, condition) - } - } - - return preloadDB, preloadConditions -} - -// handleHasOnePreload used to preload has one associations -func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) - values := toQueryValues(primaryKeys) - if relation.PolymorphicType != "" { - query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, relation.PolymorphicValue) - } - - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - foreignValuesToResults := make(map[string]reflect.Value) - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames)) - foreignValuesToResults[foreignValues] = result - } - for j := 0; j < indirectScopeValue.Len(); j++ { - indirectValue := indirect(indirectScopeValue.Index(j)) - valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames)) - if result, found := foreignValuesToResults[valueString]; found { - indirectValue.FieldByName(field.Name).Set(result) - } - } - } else { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - scope.Err(field.Set(result)) - } - } -} - -// handleHasManyPreload used to preload has many associations -func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) - values := toQueryValues(primaryKeys) - if relation.PolymorphicType != "" { - query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, relation.PolymorphicValue) - } - - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - preloadMap := make(map[string][]reflect.Value) - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues 
:= getValueFromFields(result, relation.ForeignFieldNames) - preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result) - } - - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) - f := object.FieldByName(field.Name) - if results, ok := preloadMap[toString(objectRealValue)]; ok { - f.Set(reflect.Append(f, results...)) - } else { - f.Set(reflect.MakeSlice(f.Type(), 0, 0)) - } - } - } else { - scope.Err(field.Set(resultsValue)) - } -} - -// handleBelongsToPreload used to preload belongs to associations -func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // find relations - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - foreignFieldToObjects := make(map[string][]*reflect.Value) - if indirectScopeValue.Kind() == reflect.Slice { - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - valueString := toString(getValueFromFields(object, relation.ForeignFieldNames)) - foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object) - } - } - - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - if indirectScopeValue.Kind() == reflect.Slice { - valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames)) - if objects, found := foreignFieldToObjects[valueString]; found { - for _, object := range objects { - object.FieldByName(field.Name).Set(result) - } - } - } else { - scope.Err(field.Set(result)) - } - } -} - -// handleManyToManyPreload used to preload many to many associations -func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { - var ( - relation = field.Relationship - joinTableHandler = relation.JoinTableHandler - fieldType = field.Struct.Type.Elem() - foreignKeyValue interface{} - foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() - linkHash = map[string][]reflect.Value{} - isPtr bool - ) - - if fieldType.Kind() == reflect.Ptr { - isPtr = true - fieldType = fieldType.Elem() - } - - var sourceKeys = []string{} - for _, key := range joinTableHandler.SourceForeignKeys() { - sourceKeys = append(sourceKeys, key.DBName) - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // generate query with join table - newScope := scope.New(reflect.New(fieldType).Interface()) - preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value) - - if len(preloadDB.search.selects) == 0 { - preloadDB = preloadDB.Select("*") - } - - preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) - - // preload inline conditions - if len(preloadConditions) > 0 { - preloadDB = preloadDB.Where(preloadConditions[0], 
preloadConditions[1:]...)
-	}
-
-	rows, err := preloadDB.Rows()
-
-	if scope.Err(err) != nil {
-		return
-	}
-	defer rows.Close()
-
-	columns, _ := rows.Columns()
-	for rows.Next() {
-		var (
-			elem   = reflect.New(fieldType).Elem()
-			fields = scope.New(elem.Addr().Interface()).Fields()
-		)
-
-		// register foreign keys in join tables
-		var joinTableFields []*Field
-		for _, sourceKey := range sourceKeys {
-			joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
-		}
-
-		scope.scan(rows, columns, append(fields, joinTableFields...))
-
-		scope.New(elem.Addr().Interface()).
-			InstanceSet("gorm:skip_query_callback", true).
-			callCallbacks(scope.db.parent.callbacks.queries)
-
-		var foreignKeys = make([]interface{}, len(sourceKeys))
-		// generate hashed foreign keys in join table
-		for idx, joinTableField := range joinTableFields {
-			if !joinTableField.Field.IsNil() {
-				foreignKeys[idx] = joinTableField.Field.Elem().Interface()
-			}
-		}
-		hashedSourceKeys := toString(foreignKeys)
-
-		if isPtr {
-			linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
-		} else {
-			linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
-		}
-	}
-
-	if err := rows.Err(); err != nil {
-		scope.Err(err)
-	}
-
-	// assign find results
-	var (
-		indirectScopeValue = scope.IndirectValue()
-		fieldsSourceMap    = map[string][]reflect.Value{}
-		foreignFieldNames  = []string{}
-	)
-
-	for _, dbName := range relation.ForeignFieldNames {
-		if field, ok := scope.FieldByName(dbName); ok {
-			foreignFieldNames = append(foreignFieldNames, field.Name)
-		}
-	}
-
-	if indirectScopeValue.Kind() == reflect.Slice {
-		for j := 0; j < indirectScopeValue.Len(); j++ {
-			object := indirect(indirectScopeValue.Index(j))
-			key := toString(getValueFromFields(object, foreignFieldNames))
-			fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
-		}
-	} else if indirectScopeValue.IsValid() {
-		key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
-		fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
-	}
-
-	for source, fields := range fieldsSourceMap {
-		for _, f := range fields {
-			// if not 0, this means Value is a pointer and we already added preloaded models to it
-			if f.Len() != 0 {
-				continue
-			}
-
-			v := reflect.MakeSlice(f.Type(), 0, 0)
-			if len(linkHash[source]) > 0 {
-				v = reflect.Append(f, linkHash[source]...)
-			}
-
-			f.Set(v)
-		}
-	}
-}
diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_row_query.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_row_query.go
deleted file mode 100644
index 323b16054..000000000
--- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_row_query.go
+++ /dev/null
@@ -1,41 +0,0 @@
-package gorm
-
-import (
-	"database/sql"
-	"fmt"
-)
-
-// Define callbacks for row query
-func init() {
-	DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback)
-}
-
-type RowQueryResult struct {
-	Row *sql.Row
-}
-
-type RowsQueryResult struct {
-	Rows  *sql.Rows
-	Error error
-}
-
-// rowQueryCallback used to query data from database with `Row`, `Rows`
-func rowQueryCallback(scope *Scope) {
-	if result, ok := scope.InstanceGet("row_query_result"); ok {
-		scope.prepareQuerySQL()
-
-		if str, ok := scope.Get("gorm:query_hint"); ok {
-			scope.SQL = fmt.Sprint(str) + scope.SQL
-		}
-
-		if str, ok := scope.Get("gorm:query_option"); ok {
-			scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
-		}
-
-		if rowResult, ok := result.(*RowQueryResult); ok {
-			rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
-		} else if rowsResult, ok := result.(*RowsQueryResult); ok {
-			rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
-		}
-	}
-}
diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_save.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_save.go
deleted file mode 100644
index 3b4e05895..000000000
--- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_save.go
+++ /dev/null
@@ -1,170 +0,0 @@
-package gorm
-
-import (
-	"reflect"
-	"strings"
-)
-
-func beginTransactionCallback(scope *Scope) {
-	scope.Begin()
-}
-
-func commitOrRollbackTransactionCallback(scope *Scope) {
-	scope.CommitOrRollback()
-}
-
-func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) {
-	checkTruth := func(value interface{}) bool {
-		if v, ok := value.(bool); ok && !v {
-			return false
-		}
-
-		if v, ok := value.(string); ok {
-			v = strings.ToLower(v)
-			return v == "true"
-		}
-
-		return true
-	}
-
-	if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
-		if r = field.Relationship; r != nil {
-			autoUpdate, autoCreate, saveReference = true, true, true
-
-			if value, ok := scope.Get("gorm:save_associations"); ok {
-				autoUpdate = checkTruth(value)
-				autoCreate = autoUpdate
-				saveReference = autoUpdate
-			} else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok {
-				autoUpdate = checkTruth(value)
-				autoCreate = autoUpdate
-				saveReference = autoUpdate
-			}
-
-			if value, ok := scope.Get("gorm:association_autoupdate"); ok {
-				autoUpdate = checkTruth(value)
-			} else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok {
-				autoUpdate = checkTruth(value)
-			}
-
-			if value, ok := scope.Get("gorm:association_autocreate"); ok {
-				autoCreate = checkTruth(value)
-			} else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok {
-				autoCreate = checkTruth(value)
-			}
-
-			if value, ok := scope.Get("gorm:association_save_reference"); ok {
-				saveReference = checkTruth(value)
-			} else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok {
-				saveReference = checkTruth(value)
-			}
-		}
-	}
-
-	return
-}
-
-func saveBeforeAssociationsCallback(scope *Scope) {
-	for _, field := range scope.Fields() {
-		autoUpdate, autoCreate, saveReference, relationship :=
saveAssociationCheck(scope, field) - - if relationship != nil && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() - newScope := scope.New(fieldValue) - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(scope.NewDB().Save(fieldValue).Error) - } - } else if autoUpdate { - scope.Err(scope.NewDB().Save(fieldValue).Error) - } - - if saveReference { - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) - } - } - } - } - } - } -} - -func saveAfterAssociationsCallback(scope *Scope) { - for _, field := range scope.Fields() { - autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) - - if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { - value := field.Field - - switch value.Kind() { - case reflect.Slice: - for i := 0; i < value.Len(); i++ { - newDB := scope.NewDB() - elem := value.Index(i).Addr().Interface() - newScope := newDB.NewScope(elem) - - if saveReference { - if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - } - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(newDB.Save(elem).Error) - } - } else if autoUpdate { - scope.Err(newDB.Save(elem).Error) - } - - if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference { - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) - } - } - } - default: - elem := value.Addr().Interface() - newScope := scope.New(elem) - - if saveReference { - if len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - } - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(scope.NewDB().Save(elem).Error) - } - } else if autoUpdate { - scope.Err(scope.NewDB().Save(elem).Error) - } - } - } - } -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_update.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_update.go deleted file mode 100644 index 699e534b9..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/callback_update.go +++ /dev/null @@ -1,121 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "sort" - "strings" -) - -// Define callbacks for updating -func init() { - DefaultCallback.Update().Register("gorm:assign_updating_attributes", 
assignUpdatingAttributesCallback) - DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) - DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) - DefaultCallback.Update().Register("gorm:update", updateCallback) - DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) - DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// assignUpdatingAttributesCallback assign updating attributes to model -func assignUpdatingAttributesCallback(scope *Scope) { - if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { - if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate { - scope.InstanceSet("gorm:update_attrs", updateMaps) - } else { - scope.SkipLeft() - } - } -} - -// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating -func beforeUpdateCallback(scope *Scope) { - if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("missing WHERE clause while updating")) - return - } - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeUpdate") - } - } -} - -// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating -func updateTimeStampForUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - scope.SetColumn("UpdatedAt", scope.db.nowFunc()) - } -} - -// updateCallback the callback used to update data to database -func updateCallback(scope *Scope) { - if !scope.HasError() { - var sqls []string - - if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - // Sort the column names so that the generated SQL is the same every time. 
- updateMap := updateAttrs.(map[string]interface{}) - var columns []string - for c := range updateMap { - columns = append(columns, c) - } - sort.Strings(columns) - - for _, column := range columns { - value := updateMap[column] - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) - } - } else { - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if !field.IsPrimaryKey && field.IsNormal && (field.Name != "CreatedAt" || !field.IsBlank) { - if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - for _, foreignKey := range relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - sqls = append(sqls, - fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface()))) - } - } - } - } - } - } - - var extraOption string - if str, ok := scope.Get("gorm:update_option"); ok { - extraOption = fmt.Sprint(str) - } - - if len(sqls) > 0 { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v%v%v", - scope.QuotedTableName(), - strings.Join(sqls, ", "), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating -func afterUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("AfterUpdate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } - } -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect.go deleted file mode 100644 index 749587f44..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect.go +++ /dev/null @@ -1,147 +0,0 @@ -package gorm - -import ( - "database/sql" - "fmt" - "reflect" - "strconv" - "strings" -) - -// Dialect interface contains behaviors that differ across SQL database -type Dialect interface { - // GetName get dialect's name - GetName() string - - // SetDB set db for dialect - SetDB(db SQLCommon) - - // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 - BindVar(i int) string - // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name - Quote(key string) string - // DataTypeOf return data's sql type - DataTypeOf(field *StructField) string - - // HasIndex check has index or not - HasIndex(tableName string, indexName string) bool - // HasForeignKey check has foreign key or not - HasForeignKey(tableName string, foreignKeyName string) bool - // RemoveIndex remove index - RemoveIndex(tableName string, indexName string) error - // HasTable check has table or not - HasTable(tableName string) bool - // HasColumn check has column or not - HasColumn(tableName string, columnName string) bool - // ModifyColumn modify column's type - ModifyColumn(tableName string, columnName string, typ string) error - - // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case - LimitAndOffsetSQL(limit, offset interface{}) (string, error) - // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql 
needs `SELECT value FROM DUAL` - SelectFromDummyTable() string - // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT` - LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string - // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` - LastInsertIDReturningSuffix(tableName, columnName string) string - // DefaultValueStr - DefaultValueStr() string - - // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference - BuildKeyName(kind, tableName string, fields ...string) string - - // NormalizeIndexAndColumn returns valid index name and column name depending on each dialect - NormalizeIndexAndColumn(indexName, columnName string) (string, string) - - // CurrentDatabase return current database name - CurrentDatabase() string -} - -var dialectsMap = map[string]Dialect{} - -func newDialect(name string, db SQLCommon) Dialect { - if value, ok := dialectsMap[name]; ok { - dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) - dialect.SetDB(db) - return dialect - } - - fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) - commontDialect := &commonDialect{} - commontDialect.SetDB(db) - return commontDialect -} - -// RegisterDialect register new dialect -func RegisterDialect(name string, dialect Dialect) { - dialectsMap[name] = dialect -} - -// GetDialect gets the dialect for the specified dialect name -func GetDialect(name string) (dialect Dialect, ok bool) { - dialect, ok = dialectsMap[name] - return -} - -// ParseFieldStructForDialect get field's sql data type -var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { - // Get redirected field type - var ( - reflectType = field.Struct.Type - dataType, _ = field.TagSettingsGet("TYPE") - ) - - for reflectType.Kind() == reflect.Ptr { - reflectType = reflectType.Elem() - } - - // Get redirected field value - fieldValue = reflect.Indirect(reflect.New(reflectType)) - - if gormDataType, ok := fieldValue.Interface().(interface { - GormDataType(Dialect) string - }); ok { - dataType = gormDataType.GormDataType(dialect) - } - - // Get scanner's real value - if dataType == "" { - var getScannerValue func(reflect.Value) - getScannerValue = func(value reflect.Value) { - fieldValue = value - if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { - getScannerValue(fieldValue.Field(0)) - } - } - getScannerValue(fieldValue) - } - - // Default Size - if num, ok := field.TagSettingsGet("SIZE"); ok { - size, _ = strconv.Atoi(num) - } else { - size = 255 - } - - // Default type from tag setting - notNull, _ := field.TagSettingsGet("NOT NULL") - unique, _ := field.TagSettingsGet("UNIQUE") - additionalType = notNull + " " + unique - if value, ok := field.TagSettingsGet("DEFAULT"); ok { - additionalType = additionalType + " DEFAULT " + value - } - - if value, ok := field.TagSettingsGet("COMMENT"); ok { - additionalType = additionalType + " COMMENT " + value - } - - return fieldValue, dataType, size, strings.TrimSpace(additionalType) -} - -func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { - if strings.Contains(tableName, ".") { - splitStrings := strings.SplitN(tableName, ".", 2) - return splitStrings[0], splitStrings[1] - } - return dialect.CurrentDatabase(), 
tableName -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect_common.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect_common.go deleted file mode 100644 index d549510cc..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect_common.go +++ /dev/null @@ -1,196 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "regexp" - "strconv" - "strings" - "time" -) - -var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+") - -// DefaultForeignKeyNamer contains the default foreign key name generator method -type DefaultForeignKeyNamer struct { -} - -type commonDialect struct { - db SQLCommon - DefaultForeignKeyNamer -} - -func init() { - RegisterDialect("common", &commonDialect{}) -} - -func (commonDialect) GetName() string { - return "common" -} - -func (s *commonDialect) SetDB(db SQLCommon) { - s.db = db -} - -func (commonDialect) BindVar(i int) string { - return "$$$" // ? -} - -func (commonDialect) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) -} - -func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { - if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { - return strings.ToLower(value) != "false" - } - return field.IsPrimaryKey -} - -func (s *commonDialect) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "BOOLEAN" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - sqlType = "INTEGER AUTO_INCREMENT" - } else { - sqlType = "INTEGER" - } - case reflect.Int64, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - sqlType = "BIGINT AUTO_INCREMENT" - } else { - sqlType = "BIGINT" - } - case reflect.Float32, reflect.Float64: - sqlType = "FLOAT" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("VARCHAR(%d)", size) - } else { - sqlType = "VARCHAR(65532)" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "TIMESTAMP" - } - default: - if _, ok := dataValue.Interface().([]byte); ok { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("BINARY(%d)", size) - } else { - sqlType = "BINARY(65532)" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s commonDialect) HasIndex(tableName string, indexName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? 
AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) - return count > 0 -} - -func (s commonDialect) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) - return err -} - -func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { - return false -} - -func (s commonDialect) HasTable(tableName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) - return count > 0 -} - -func (s commonDialect) HasColumn(tableName string, columnName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) - return count > 0 -} - -func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) - return err -} - -func (s commonDialect) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - -// LimitAndOffsetSQL return generated SQL with Limit and Offset -func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { - if limit != nil { - if parsedLimit, err := s.parseInt(limit); err != nil { - return "", err - } else if parsedLimit >= 0 { - sql += fmt.Sprintf(" LIMIT %d", parsedLimit) - } - } - if offset != nil { - if parsedOffset, err := s.parseInt(offset); err != nil { - return "", err - } else if parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d", parsedOffset) - } - } - return -} - -func (commonDialect) SelectFromDummyTable() string { - return "" -} - -func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { - return "" -} - -func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { - return "" -} - -func (commonDialect) DefaultValueStr() string { - return "DEFAULT VALUES" -} - -// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference -func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { - keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) - keyName = keyNameRegex.ReplaceAllString(keyName, "_") - return keyName -} - -// NormalizeIndexAndColumn returns argument's index name and column name without doing anything -func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { - return indexName, columnName -} - -func (commonDialect) parseInt(value interface{}) (int64, error) { - return strconv.ParseInt(fmt.Sprint(value), 0, 0) -} - -// IsByteArrayOrSlice returns true of the reflected value is an array or slice -func IsByteArrayOrSlice(value reflect.Value) bool { - return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect_mysql.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect_mysql.go deleted file mode 100644 index b4467ffa1..000000000 --- 
a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect_mysql.go +++ /dev/null @@ -1,246 +0,0 @@ -package gorm - -import ( - "crypto/sha1" - "database/sql" - "fmt" - "reflect" - "regexp" - "strings" - "time" - "unicode/utf8" -) - -var mysqlIndexRegex = regexp.MustCompile(`^(.+)\((\d+)\)$`) - -type mysql struct { - commonDialect -} - -func init() { - RegisterDialect("mysql", &mysql{}) -} - -func (mysql) GetName() string { - return "mysql" -} - -func (mysql) Quote(key string) string { - return fmt.Sprintf("`%s`", key) -} - -// Get Data Type for MySQL Dialect -func (s *mysql) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - // MySQL allows only one auto increment column per table, and it must - // be a KEY column. - if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { - if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey { - field.TagSettingsDelete("AUTO_INCREMENT") - } - } - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "boolean" - case reflect.Int8: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "tinyint AUTO_INCREMENT" - } else { - sqlType = "tinyint" - } - case reflect.Int, reflect.Int16, reflect.Int32: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "int AUTO_INCREMENT" - } else { - sqlType = "int" - } - case reflect.Uint8: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "tinyint unsigned AUTO_INCREMENT" - } else { - sqlType = "tinyint unsigned" - } - case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "int unsigned AUTO_INCREMENT" - } else { - sqlType = "int unsigned" - } - case reflect.Int64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigint AUTO_INCREMENT" - } else { - sqlType = "bigint" - } - case reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigint unsigned AUTO_INCREMENT" - } else { - sqlType = "bigint unsigned" - } - case reflect.Float32, reflect.Float64: - sqlType = "double" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "longtext" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - precision := "" - if p, ok := field.TagSettingsGet("PRECISION"); ok { - precision = fmt.Sprintf("(%s)", p) - } - - if _, ok := field.TagSettings["NOT NULL"]; ok || field.IsPrimaryKey { - sqlType = fmt.Sprintf("DATETIME%v", precision) - } else { - sqlType = fmt.Sprintf("DATETIME%v NULL", precision) - } - } - default: - if IsByteArrayOrSlice(dataValue) { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varbinary(%d)", size) - } else { - sqlType = "longblob" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name)) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s mysql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, 
s.Quote(tableName))) - return err -} - -func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) - return err -} - -func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { - if limit != nil { - parsedLimit, err := s.parseInt(limit) - if err != nil { - return "", err - } - if parsedLimit >= 0 { - sql += fmt.Sprintf(" LIMIT %d", parsedLimit) - - if offset != nil { - parsedOffset, err := s.parseInt(offset) - if err != nil { - return "", err - } - if parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d", parsedOffset) - } - } - } - } - return -} - -func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s mysql) HasTable(tableName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - var name string - // allow mysql database name with '-' character - if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { - if err == sql.ErrNoRows { - return false - } - panic(err) - } else { - return true - } -} - -func (s mysql) HasIndex(tableName string, indexName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil { - panic(err) - } else { - defer rows.Close() - return rows.Next() - } -} - -func (s mysql) HasColumn(tableName string, columnName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil { - panic(err) - } else { - defer rows.Close() - return rows.Next() - } -} - -func (s mysql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - -func (mysql) SelectFromDummyTable() string { - return "FROM DUAL" -} - -func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { - keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...) 
- if utf8.RuneCountInString(keyName) <= 64 { - return keyName - } - h := sha1.New() - h.Write([]byte(keyName)) - bs := h.Sum(nil) - - // sha1 is 40 characters, keep first 24 characters of destination - destRunes := []rune(keyNameRegex.ReplaceAllString(fields[0], "_")) - if len(destRunes) > 24 { - destRunes = destRunes[:24] - } - - return fmt.Sprintf("%s%x", string(destRunes), bs) -} - -// NormalizeIndexAndColumn returns index name and column name for specify an index prefix length if needed -func (mysql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { - submatch := mysqlIndexRegex.FindStringSubmatch(indexName) - if len(submatch) != 3 { - return indexName, columnName - } - indexName = submatch[1] - columnName = fmt.Sprintf("%s(%s)", columnName, submatch[2]) - return indexName, columnName -} - -func (mysql) DefaultValueStr() string { - return "VALUES()" -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect_postgres.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect_postgres.go deleted file mode 100644 index d2df31318..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect_postgres.go +++ /dev/null @@ -1,147 +0,0 @@ -package gorm - -import ( - "encoding/json" - "fmt" - "reflect" - "strings" - "time" -) - -type postgres struct { - commonDialect -} - -func init() { - RegisterDialect("postgres", &postgres{}) - RegisterDialect("cloudsqlpostgres", &postgres{}) -} - -func (postgres) GetName() string { - return "postgres" -} - -func (postgres) BindVar(i int) string { - return fmt.Sprintf("$%v", i) -} - -func (s *postgres) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "serial" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint32, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigserial" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "numeric" - case reflect.String: - if _, ok := field.TagSettingsGet("SIZE"); !ok { - size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different - } - - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "timestamp with time zone" - } - case reflect.Map: - if dataValue.Type().Name() == "Hstore" { - sqlType = "hstore" - } - default: - if IsByteArrayOrSlice(dataValue) { - sqlType = "bytea" - - if isUUID(dataValue) { - sqlType = "uuid" - } - - if isJSON(dataValue) { - sqlType = "jsonb" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s postgres) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND 
schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) - return count > 0 -} - -func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s postgres) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) - return count > 0 -} - -func (s postgres) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) - return count > 0 -} - -func (s postgres) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) - return -} - -func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string { - return "" -} - -func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { - return fmt.Sprintf("RETURNING %v.%v", tableName, key) -} - -func (postgres) SupportLastInsertID() bool { - return false -} - -func isUUID(value reflect.Value) bool { - if value.Kind() != reflect.Array || value.Type().Len() != 16 { - return false - } - typename := value.Type().Name() - lower := strings.ToLower(typename) - return "uuid" == lower || "guid" == lower -} - -func isJSON(value reflect.Value) bool { - _, ok := value.Interface().(json.RawMessage) - return ok -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go deleted file mode 100644 index 5f96c363a..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go +++ /dev/null @@ -1,107 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "strings" - "time" -) - -type sqlite3 struct { - commonDialect -} - -func init() { - RegisterDialect("sqlite3", &sqlite3{}) -} - -func (sqlite3) GetName() string { - return "sqlite3" -} - -// Get Data Type for Sqlite Dialect -func (s *sqlite3) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "bool" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "integer primary key autoincrement" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "integer primary key autoincrement" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "real" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "datetime" - } - default: - if IsByteArrayOrSlice(dataValue) { - sqlType = "blob" - } - } - } - - if sqlType == "" { - 
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s sqlite3) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) CurrentDatabase() (name string) { - var ( - ifaces = make([]interface{}, 3) - pointers = make([]*string, 3) - i int - ) - for i = 0; i < 3; i++ { - ifaces[i] = &pointers[i] - } - if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { - return - } - if pointers[1] != nil { - name = *pointers[1] - } - return -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialects/mysql/mysql.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialects/mysql/mysql.go deleted file mode 100644 index 9deba48ae..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialects/mysql/mysql.go +++ /dev/null @@ -1,3 +0,0 @@ -package mysql - -import _ "github.com/go-sql-driver/mysql" diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialects/postgres/postgres.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialects/postgres/postgres.go deleted file mode 100644 index e6c088b1c..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/dialects/postgres/postgres.go +++ /dev/null @@ -1,81 +0,0 @@ -package postgres - -import ( - "database/sql" - "database/sql/driver" - - "encoding/json" - "errors" - "fmt" - - _ "github.com/lib/pq" - "github.com/lib/pq/hstore" -) - -type Hstore map[string]*string - -// Value get value of Hstore -func (h Hstore) Value() (driver.Value, error) { - hstore := hstore.Hstore{Map: map[string]sql.NullString{}} - if len(h) == 0 { - return nil, nil - } - - for key, value := range h { - var s sql.NullString - if value != nil { - s.String = *value - s.Valid = true - } - hstore.Map[key] = s - } - return hstore.Value() -} - -// Scan scan value into Hstore -func (h *Hstore) Scan(value interface{}) error { - hstore := hstore.Hstore{} - - if err := hstore.Scan(value); err != nil { - return err - } - - if len(hstore.Map) == 0 { - return nil - } - - *h = Hstore{} - for k := range hstore.Map { - if hstore.Map[k].Valid { - s := hstore.Map[k].String - (*h)[k] = &s - } else { - (*h)[k] = nil - } - } - - return nil -} - -// Jsonb Postgresql's JSONB data type -type Jsonb struct { - json.RawMessage -} - -// Value get value of Jsonb -func (j Jsonb) Value() (driver.Value, error) { - if len(j.RawMessage) == 0 { - return nil, nil - } - return j.MarshalJSON() -} - -// Scan scan value into Jsonb -func (j *Jsonb) Scan(value interface{}) error { - bytes, ok := value.([]byte) - if !ok { - return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value)) - } - - return json.Unmarshal(bytes, j) -} 
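For reference while porting, the Hstore and Jsonb wrappers deleted just above were the usual way to bind Postgres hstore/jsonb columns under jinzhu/gorm. The following is a minimal sketch of that pattern, not code from this repository: the Document model, the field names, and the DSN are illustrative only, and it assumes the hstore extension is enabled on the target database.

	package main

	import (
		"fmt"

		"github.com/jinzhu/gorm"
		"github.com/jinzhu/gorm/dialects/postgres" // the package removed in this patch; it pulls in lib/pq
	)

	// Document is a hypothetical model used only to illustrate the old API.
	type Document struct {
		gorm.Model
		Tags postgres.Hstore // map[string]*string; the postgres dialect maps it to hstore
		Body postgres.Jsonb  // wraps json.RawMessage; mapped to jsonb
	}

	func main() {
		// Illustrative DSN; the database needs CREATE EXTENSION hstore.
		db, err := gorm.Open("postgres", "host=localhost user=gorm dbname=gorm password=gorm sslmode=disable")
		if err != nil {
			panic(err)
		}
		defer db.Close()

		db.AutoMigrate(&Document{})

		visibility := "internal"
		doc := Document{
			Tags: postgres.Hstore{"visibility": &visibility},
			Body: postgres.Jsonb{RawMessage: []byte(`{"title":"hello"}`)},
		}
		db.Create(&doc)

		var out Document
		db.First(&out, doc.ID) // inline primary-key condition
		fmt.Println(*out.Tags["visibility"], string(out.Body.RawMessage))
	}

gorm.io/gorm ships no direct replacement for these wrappers in its core module; jsonb columns are typically mapped with a custom driver.Valuer/sql.Scanner pair or a helper package such as gorm.io/datatypes, so any caller of the deleted types will need a small shim as part of this migration.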
diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/docker-compose.yml b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/docker-compose.yml deleted file mode 100644 index 79bf5fc39..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/docker-compose.yml +++ /dev/null @@ -1,30 +0,0 @@ -version: '3' - -services: - mysql: - image: 'mysql:latest' - ports: - - 9910:3306 - environment: - - MYSQL_DATABASE=gorm - - MYSQL_USER=gorm - - MYSQL_PASSWORD=gorm - - MYSQL_RANDOM_ROOT_PASSWORD="yes" - postgres: - image: 'postgres:latest' - ports: - - 9920:5432 - environment: - - POSTGRES_USER=gorm - - POSTGRES_DB=gorm - - POSTGRES_PASSWORD=gorm - mssql: - image: 'mcmoe/mssqldocker:latest' - ports: - - 9930:1433 - environment: - - ACCEPT_EULA=Y - - SA_PASSWORD=LoremIpsum86 - - MSSQL_DB=gorm - - MSSQL_USER=gorm - - MSSQL_PASSWORD=LoremIpsum86 diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/errors.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/errors.go deleted file mode 100644 index d5ef8d571..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/errors.go +++ /dev/null @@ -1,72 +0,0 @@ -package gorm - -import ( - "errors" - "strings" -) - -var ( - // ErrRecordNotFound returns a "record not found error". Occurs only when attempting to query the database with a struct; querying with a slice won't return this error - ErrRecordNotFound = errors.New("record not found") - // ErrInvalidSQL occurs when you attempt a query with invalid SQL - ErrInvalidSQL = errors.New("invalid SQL") - // ErrInvalidTransaction occurs when you are trying to `Commit` or `Rollback` - ErrInvalidTransaction = errors.New("no valid transaction") - // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` - ErrCantStartTransaction = errors.New("can't start transaction") - // ErrUnaddressable unaddressable value - ErrUnaddressable = errors.New("using unaddressable value") -) - -// Errors contains all happened errors -type Errors []error - -// IsRecordNotFoundError returns true if error contains a RecordNotFound error -func IsRecordNotFoundError(err error) bool { - if errs, ok := err.(Errors); ok { - for _, err := range errs { - if err == ErrRecordNotFound { - return true - } - } - } - return err == ErrRecordNotFound -} - -// GetErrors gets all errors that have occurred and returns a slice of errors (Error type) -func (errs Errors) GetErrors() []error { - return errs -} - -// Add adds an error to a given slice of errors -func (errs Errors) Add(newErrors ...error) Errors { - for _, err := range newErrors { - if err == nil { - continue - } - - if errors, ok := err.(Errors); ok { - errs = errs.Add(errors...) 
- } else { - ok = true - for _, e := range errs { - if err == e { - ok = false - } - } - if ok { - errs = append(errs, err) - } - } - } - return errs -} - -// Error takes a slice of all errors that have occurred and returns it as a formatted string -func (errs Errors) Error() string { - var errors = []string{} - for _, e := range errs { - errors = append(errors, e.Error()) - } - return strings.Join(errors, "; ") -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/field.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/field.go deleted file mode 100644 index acd06e20d..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/field.go +++ /dev/null @@ -1,66 +0,0 @@ -package gorm - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "reflect" -) - -// Field model field definition -type Field struct { - *StructField - IsBlank bool - Field reflect.Value -} - -// Set set a value to the field -func (field *Field) Set(value interface{}) (err error) { - if !field.Field.IsValid() { - return errors.New("field value not valid") - } - - if !field.Field.CanAddr() { - return ErrUnaddressable - } - - reflectValue, ok := value.(reflect.Value) - if !ok { - reflectValue = reflect.ValueOf(value) - } - - fieldValue := field.Field - if reflectValue.IsValid() { - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else { - if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(field.Struct.Type.Elem())) - } - fieldValue = fieldValue.Elem() - } - - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - v := reflectValue.Interface() - if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - err = scanner.Scan(v) - } - } else { - err = scanner.Scan(v) - } - } else { - err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) - } - } - } else { - field.Field.Set(reflect.Zero(field.Field.Type())) - } - - field.IsBlank = isBlank(field.Field) - return err -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/interface.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/interface.go deleted file mode 100644 index fe6492314..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/interface.go +++ /dev/null @@ -1,24 +0,0 @@ -package gorm - -import ( - "context" - "database/sql" -) - -// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. 
-type SQLCommon interface { - Exec(query string, args ...interface{}) (sql.Result, error) - Prepare(query string) (*sql.Stmt, error) - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row -} - -type sqlDb interface { - Begin() (*sql.Tx, error) - BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) -} - -type sqlTx interface { - Commit() error - Rollback() error -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/join_table_handler.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/join_table_handler.go deleted file mode 100644 index a036d46d2..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/join_table_handler.go +++ /dev/null @@ -1,211 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strings" -) - -// JoinTableHandlerInterface is an interface for how to handle many2many relations -type JoinTableHandlerInterface interface { - // initialize join table handler - Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) - // Table return join table's table name - Table(db *DB) string - // Add create relationship in join table for source and destination - Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error - // Delete delete relationship in join table for sources - Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error - // JoinWith query with `Join` conditions - JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB - // SourceForeignKeys return source foreign keys - SourceForeignKeys() []JoinTableForeignKey - // DestinationForeignKeys return destination foreign keys - DestinationForeignKeys() []JoinTableForeignKey -} - -// JoinTableForeignKey join table foreign key struct -type JoinTableForeignKey struct { - DBName string - AssociationDBName string -} - -// JoinTableSource is a struct that contains model type and foreign keys -type JoinTableSource struct { - ModelType reflect.Type - ForeignKeys []JoinTableForeignKey -} - -// JoinTableHandler default join table handler -type JoinTableHandler struct { - TableName string `sql:"-"` - Source JoinTableSource `sql:"-"` - Destination JoinTableSource `sql:"-"` -} - -// SourceForeignKeys return source foreign keys -func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { - return s.Source.ForeignKeys -} - -// DestinationForeignKeys return destination foreign keys -func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { - return s.Destination.ForeignKeys -} - -// Setup initialize a default join table handler -func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { - s.TableName = tableName - - s.Source = JoinTableSource{ModelType: source} - s.Source.ForeignKeys = []JoinTableForeignKey{} - for idx, dbName := range relationship.ForeignFieldNames { - s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.ForeignDBNames[idx], - AssociationDBName: dbName, - }) - } - - s.Destination = JoinTableSource{ModelType: destination} - s.Destination.ForeignKeys = []JoinTableForeignKey{} - for idx, dbName := range relationship.AssociationForeignFieldNames { - s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.AssociationForeignDBNames[idx], - AssociationDBName: dbName, - }) - } -} - -// Table 
return join table's table name -func (s JoinTableHandler) Table(db *DB) string { - return DefaultTableNameHandler(db, s.TableName) -} - -func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) { - for _, source := range sources { - scope := db.NewScope(source) - modelType := scope.GetModelStruct().ModelType - - for _, joinTableSource := range joinTableSources { - if joinTableSource.ModelType == modelType { - for _, foreignKey := range joinTableSource.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - conditionMap[foreignKey.DBName] = field.Field.Interface() - } - } - break - } - } - } -} - -// Add create relationship in join table for source and destination -func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { - var ( - scope = db.NewScope("") - conditionMap = map[string]interface{}{} - ) - - // Update condition map for source - s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source) - - // Update condition map for destination - s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination) - - var assignColumns, binVars, conditions []string - var values []interface{} - for key, value := range conditionMap { - assignColumns = append(assignColumns, scope.Quote(key)) - binVars = append(binVars, `?`) - conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) - values = append(values, value) - } - - for _, value := range values { - values = append(values, value) - } - - quotedTable := scope.Quote(handler.Table(db)) - sql := fmt.Sprintf( - "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", - quotedTable, - strings.Join(assignColumns, ","), - strings.Join(binVars, ","), - scope.Dialect().SelectFromDummyTable(), - quotedTable, - strings.Join(conditions, " AND "), - ) - - return db.Exec(sql, values...).Error -} - -// Delete delete relationship in join table for sources -func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { - var ( - scope = db.NewScope(nil) - conditions []string - values []interface{} - conditionMap = map[string]interface{}{} - ) - - s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...) 
- - for key, value := range conditionMap { - conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) - values = append(values, value) - } - - return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error -} - -// JoinWith query with `Join` conditions -func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { - var ( - scope = db.NewScope(source) - tableName = handler.Table(db) - quotedTableName = scope.Quote(tableName) - joinConditions []string - values []interface{} - ) - - if s.Source.ModelType == scope.GetModelStruct().ModelType { - destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() - for _, foreignKey := range s.Destination.ForeignKeys { - joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) - } - - var foreignDBNames []string - var foreignFieldNames []string - - for _, foreignKey := range s.Source.ForeignKeys { - foreignDBNames = append(foreignDBNames, foreignKey.DBName) - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) - - var condString string - if len(foreignFieldValues) > 0 { - var quotedForeignDBNames []string - for _, dbName := range foreignDBNames { - quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName) - } - - condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) - - keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) - values = append(values, toQueryValues(keys)) - } else { - condString = fmt.Sprintf("1 <> 1") - } - - return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). - Where(condString, toQueryValues(foreignFieldValues)...) 
- } - - db.Error = errors.New("wrong source type for join table handler") - return db -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/logger.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/logger.go deleted file mode 100644 index 88e167dd6..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/logger.go +++ /dev/null @@ -1,141 +0,0 @@ -package gorm - -import ( - "database/sql/driver" - "fmt" - "log" - "os" - "reflect" - "regexp" - "strconv" - "time" - "unicode" -) - -var ( - defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} - sqlRegexp = regexp.MustCompile(`\?`) - numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`) -) - -func isPrintable(s string) bool { - for _, r := range s { - if !unicode.IsPrint(r) { - return false - } - } - return true -} - -var LogFormatter = func(values ...interface{}) (messages []interface{}) { - if len(values) > 1 { - var ( - sql string - formattedValues []string - level = values[0] - currentTime = "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" - source = fmt.Sprintf("\033[35m(%v)\033[0m", values[1]) - ) - - messages = []interface{}{source, currentTime} - - if len(values) == 2 { - //remove the line break - currentTime = currentTime[1:] - //remove the brackets - source = fmt.Sprintf("\033[35m%v\033[0m", values[1]) - - messages = []interface{}{currentTime, source} - } - - if level == "sql" { - // duration - messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) - // sql - - for _, value := range values[4].([]interface{}) { - indirectValue := reflect.Indirect(reflect.ValueOf(value)) - if indirectValue.IsValid() { - value = indirectValue.Interface() - if t, ok := value.(time.Time); ok { - if t.IsZero() { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00")) - } else { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05"))) - } - } else if b, ok := value.([]byte); ok { - if str := string(b); isPrintable(str) { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) - } else { - formattedValues = append(formattedValues, "''") - } - } else if r, ok := value.(driver.Valuer); ok { - if value, err := r.Value(); err == nil && value != nil { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } else { - formattedValues = append(formattedValues, "NULL") - } - } else { - switch value.(type) { - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: - formattedValues = append(formattedValues, fmt.Sprintf("%v", value)) - default: - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } - } - } else { - formattedValues = append(formattedValues, "NULL") - } - } - - // differentiate between $n placeholders or else treat like ? 
- if numericPlaceHolderRegexp.MatchString(values[3].(string)) { - sql = values[3].(string) - for index, value := range formattedValues { - placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1) - sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1") - } - } else { - formattedValuesLength := len(formattedValues) - for index, value := range sqlRegexp.Split(values[3].(string), -1) { - sql += value - if index < formattedValuesLength { - sql += formattedValues[index] - } - } - } - - messages = append(messages, sql) - messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned ")) - } else { - messages = append(messages, "\033[31;1m") - messages = append(messages, values[2:]...) - messages = append(messages, "\033[0m") - } - } - - return -} - -type logger interface { - Print(v ...interface{}) -} - -// LogWriter log writer interface -type LogWriter interface { - Println(v ...interface{}) -} - -// Logger default logger -type Logger struct { - LogWriter -} - -// Print format & print log -func (logger Logger) Print(values ...interface{}) { - logger.Println(LogFormatter(values...)...) -} - -type nopLogger struct{} - -func (nopLogger) Print(values ...interface{}) {} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/main.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/main.go deleted file mode 100644 index 466e80c33..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/main.go +++ /dev/null @@ -1,886 +0,0 @@ -package gorm - -import ( - "context" - "database/sql" - "errors" - "fmt" - "reflect" - "strings" - "sync" - "time" -) - -// DB contains information for current db connection -type DB struct { - sync.RWMutex - Value interface{} - Error error - RowsAffected int64 - - // single db - db SQLCommon - blockGlobalUpdate bool - logMode logModeValue - logger logger - search *search - values sync.Map - - // global db - parent *DB - callbacks *Callback - dialect Dialect - singularTable bool - - // function to be used to override the creating of a new timestamp - nowFuncOverride func() time.Time -} - -type logModeValue int - -const ( - defaultLogMode logModeValue = iota - noLogMode - detailedLogMode -) - -// Open initialize a new db connection, need to import driver first, e.g: -// -// import _ "github.com/go-sql-driver/mysql" -// func main() { -// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") -// } -// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with -// import _ "github.com/jinzhu/gorm/dialects/mysql" -// // import _ "github.com/jinzhu/gorm/dialects/postgres" -// // import _ "github.com/jinzhu/gorm/dialects/sqlite" -// // import _ "github.com/jinzhu/gorm/dialects/mssql" -func Open(dialect string, args ...interface{}) (db *DB, err error) { - if len(args) == 0 { - err = errors.New("invalid database source") - return nil, err - } - var source string - var dbSQL SQLCommon - var ownDbSQL bool - - switch value := args[0].(type) { - case string: - var driver = dialect - if len(args) == 1 { - source = value - } else if len(args) >= 2 { - driver = value - source = args[1].(string) - } - dbSQL, err = sql.Open(driver, source) - ownDbSQL = true - case SQLCommon: - dbSQL = value - ownDbSQL = false - default: - return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) - } - - db = &DB{ - db: dbSQL, - logger: defaultLogger, - callbacks: 
DefaultCallback, - dialect: newDialect(dialect, dbSQL), - } - db.parent = db - if err != nil { - return - } - // Send a ping to make sure the database connection is alive. - if d, ok := dbSQL.(*sql.DB); ok { - if err = d.Ping(); err != nil && ownDbSQL { - d.Close() - } - } - return -} - -// New clone a new db connection without search conditions -func (s *DB) New() *DB { - clone := s.clone() - clone.search = nil - clone.Value = nil - return clone -} - -type closer interface { - Close() error -} - -// Close close current db connection. If database connection is not an io.Closer, returns an error. -func (s *DB) Close() error { - if db, ok := s.parent.db.(closer); ok { - return db.Close() - } - return errors.New("can't close current db") -} - -// DB get `*sql.DB` from current connection -// If the underlying database connection is not a *sql.DB, returns nil -func (s *DB) DB() *sql.DB { - db, ok := s.db.(*sql.DB) - if !ok { - panic("can't support full GORM on currently status, maybe this is a TX instance.") - } - return db -} - -// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. -func (s *DB) CommonDB() SQLCommon { - return s.db -} - -// Dialect get dialect -func (s *DB) Dialect() Dialect { - return s.dialect -} - -// Callback return `Callbacks` container, you could add/change/delete callbacks with it -// db.Callback().Create().Register("update_created_at", updateCreated) -// Refer https://jinzhu.github.io/gorm/development.html#callbacks -func (s *DB) Callback() *Callback { - s.parent.callbacks = s.parent.callbacks.clone(s.logger) - return s.parent.callbacks -} - -// SetLogger replace default logger -func (s *DB) SetLogger(log logger) { - s.logger = log -} - -// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs -func (s *DB) LogMode(enable bool) *DB { - if enable { - s.logMode = detailedLogMode - } else { - s.logMode = noLogMode - } - return s -} - -// SetNowFuncOverride set the function to be used when creating a new timestamp -func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB { - s.nowFuncOverride = nowFuncOverride - return s -} - -// Get a new timestamp, using the provided nowFuncOverride on the DB instance if set, -// otherwise defaults to the global NowFunc() -func (s *DB) nowFunc() time.Time { - if s.nowFuncOverride != nil { - return s.nowFuncOverride() - } - - return NowFunc() -} - -// BlockGlobalUpdate if true, generates an error on update/delete without where clause. 
-// This is to prevent eventual error with empty objects updates/deletions -func (s *DB) BlockGlobalUpdate(enable bool) *DB { - s.blockGlobalUpdate = enable - return s -} - -// HasBlockGlobalUpdate return state of block -func (s *DB) HasBlockGlobalUpdate() bool { - return s.blockGlobalUpdate -} - -// SingularTable use singular table by default -func (s *DB) SingularTable(enable bool) { - s.parent.Lock() - defer s.parent.Unlock() - s.parent.singularTable = enable -} - -// NewScope create a scope for current operation -func (s *DB) NewScope(value interface{}) *Scope { - dbClone := s.clone() - dbClone.Value = value - scope := &Scope{db: dbClone, Value: value} - if s.search != nil { - scope.Search = s.search.clone() - } else { - scope.Search = &search{} - } - return scope -} - -// QueryExpr returns the query as SqlExpr object -func (s *DB) QueryExpr() *SqlExpr { - scope := s.NewScope(s.Value) - scope.InstanceSet("skip_bindvar", true) - scope.prepareQuerySQL() - - return Expr(scope.SQL, scope.SQLVars...) -} - -// SubQuery returns the query as sub query -func (s *DB) SubQuery() *SqlExpr { - scope := s.NewScope(s.Value) - scope.InstanceSet("skip_bindvar", true) - scope.prepareQuerySQL() - - return Expr(fmt.Sprintf("(%v)", scope.SQL), scope.SQLVars...) -} - -// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query -func (s *DB) Where(query interface{}, args ...interface{}) *DB { - return s.clone().search.Where(query, args...).db -} - -// Or filter records that match before conditions or this one, similar to `Where` -func (s *DB) Or(query interface{}, args ...interface{}) *DB { - return s.clone().search.Or(query, args...).db -} - -// Not filter records that don't match current conditions, similar to `Where` -func (s *DB) Not(query interface{}, args ...interface{}) *DB { - return s.clone().search.Not(query, args...).db -} - -// Limit specify the number of records to be retrieved -func (s *DB) Limit(limit interface{}) *DB { - return s.clone().search.Limit(limit).db -} - -// Offset specify the number of records to skip before starting to return the records -func (s *DB) Offset(offset interface{}) *DB { - return s.clone().search.Offset(offset).db -} - -// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions -// db.Order("name DESC") -// db.Order("name DESC", true) // reorder -// db.Order(gorm.Expr("name = ? 
DESC", "first")) // sql expression -func (s *DB) Order(value interface{}, reorder ...bool) *DB { - return s.clone().search.Order(value, reorder...).db -} - -// Select specify fields that you want to retrieve from database when querying, by default, will select all fields; -// When creating/updating, specify fields that you want to save to database -func (s *DB) Select(query interface{}, args ...interface{}) *DB { - return s.clone().search.Select(query, args...).db -} - -// Omit specify fields that you want to ignore when saving to database for creating, updating -func (s *DB) Omit(columns ...string) *DB { - return s.clone().search.Omit(columns...).db -} - -// Group specify the group method on the find -func (s *DB) Group(query string) *DB { - return s.clone().search.Group(query).db -} - -// Having specify HAVING conditions for GROUP BY -func (s *DB) Having(query interface{}, values ...interface{}) *DB { - return s.clone().search.Having(query, values...).db -} - -// Joins specify Joins conditions -// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -func (s *DB) Joins(query string, args ...interface{}) *DB { - return s.clone().search.Joins(query, args...).db -} - -// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically -// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { -// return db.Where("amount > ?", 1000) -// } -// -// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { -// return func (db *gorm.DB) *gorm.DB { -// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) -// } -// } -// -// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) -// Refer https://jinzhu.github.io/gorm/crud.html#scopes -func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { - for _, f := range funcs { - s = f(s) - } - return s -} - -// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/crud.html#soft-delete -func (s *DB) Unscoped() *DB { - return s.clone().search.unscoped().db -} - -// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) Attrs(attrs ...interface{}) *DB { - return s.clone().search.Attrs(attrs...).db -} - -// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) Assign(attrs ...interface{}) *DB { - return s.clone().search.Assign(attrs...).db -} - -// First find first record that match given conditions, order by primary key -func (s *DB) First(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - - return newScope.Set("gorm:order_by_primary_key", "ASC"). 
- inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Take return a record that match given conditions, the order will depend on the database implementation -func (s *DB) Take(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Last find last record that match given conditions, order by primary key -func (s *DB) Last(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - return newScope.Set("gorm:order_by_primary_key", "DESC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Find find records that match given conditions -func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -//Preloads preloads relations, don`t touch out -func (s *DB) Preloads(out interface{}) *DB { - return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db -} - -// Scan scan value to a struct -func (s *DB) Scan(dest interface{}) *DB { - return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db -} - -// Row return `*sql.Row` with given conditions -func (s *DB) Row() *sql.Row { - return s.NewScope(s.Value).row() -} - -// Rows return `*sql.Rows` with given conditions -func (s *DB) Rows() (*sql.Rows, error) { - return s.NewScope(s.Value).rows() -} - -// ScanRows scan `*sql.Rows` to give struct -func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { - var ( - scope = s.NewScope(result) - clone = scope.db - columns, err = rows.Columns() - ) - - if clone.AddError(err) == nil { - scope.scan(rows, columns, scope.Fields()) - } - - return clone.Error -} - -// Pluck used to query single column from a model as a map -// var ages []int64 -// db.Find(&users).Pluck("age", &ages) -func (s *DB) Pluck(column string, value interface{}) *DB { - return s.NewScope(s.Value).pluck(column, value).db -} - -// Count get how many records for a model -func (s *DB) Count(value interface{}) *DB { - return s.NewScope(s.Value).count(value).db -} - -// Related get related associations -func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.NewScope(s.Value).related(value, foreignKeys...).db -} - -// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/crud.html#firstorinit -func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := c.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - c.NewScope(out).inlineCondition(where...).initialize() - } else { - c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs) - } - return c -} - -// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := s.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db - } else if len(c.search.assignAttrs) > 0 { - return 
c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db - } - return c -} - -// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -// WARNING when update with struct, GORM will not update fields that with zero value -func (s *DB) Update(attrs ...interface{}) *DB { - return s.Updates(toSearchableMap(attrs...), true) -} - -// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.NewScope(s.Value). - Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). - InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) UpdateColumn(attrs ...interface{}) *DB { - return s.UpdateColumns(toSearchableMap(attrs...)) -} - -// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) UpdateColumns(values interface{}) *DB { - return s.NewScope(s.Value). - Set("gorm:update_column", true). - Set("gorm:save_associations", false). - InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// Save update value in database, if the value doesn't have primary key, will insert it -func (s *DB) Save(value interface{}) *DB { - scope := s.NewScope(value) - if !scope.PrimaryKeyZero() { - newDB := scope.callCallbacks(s.parent.callbacks.updates).db - if newDB.Error == nil && newDB.RowsAffected == 0 { - return s.New().Table(scope.TableName()).FirstOrCreate(value) - } - return newDB - } - return scope.callCallbacks(s.parent.callbacks.creates).db -} - -// Create insert the value into database -func (s *DB) Create(value interface{}) *DB { - scope := s.NewScope(value) - return scope.callCallbacks(s.parent.callbacks.creates).db -} - -// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition -// WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time -func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db -} - -// Raw use raw sql as conditions, won't run it unless invoked by other methods -// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) -func (s *DB) Raw(sql string, values ...interface{}) *DB { - return s.clone().search.Raw(true).Where(sql, values...).db -} - -// Exec execute raw sql -func (s *DB) Exec(sql string, values ...interface{}) *DB { - scope := s.NewScope(nil) - generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true) - generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") - scope.Raw(generatedSQL) - return scope.Exec().db -} - -// Model specify the model you would like to run db operations -// // update all users's name to `hello` -// db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` -// db.Model(&user).Update("name", "hello") -func (s *DB) Model(value interface{}) *DB { - c := s.clone() - c.Value = value - return c -} - -// Table specify the table you would like to run db operations -func (s *DB) 
Table(name string) *DB { - clone := s.clone() - clone.search.Table(name) - clone.Value = nil - return clone -} - -// Debug start debug mode -func (s *DB) Debug() *DB { - return s.clone().LogMode(true) -} - -// Transaction start a transaction as a block, -// return error will rollback, otherwise to commit. -func (s *DB) Transaction(fc func(tx *DB) error) (err error) { - - if _, ok := s.db.(*sql.Tx); ok { - return fc(s) - } - - panicked := true - tx := s.Begin() - defer func() { - // Make sure to rollback when panic, Block error or Commit error - if panicked || err != nil { - tx.Rollback() - } - }() - - err = fc(tx) - - if err == nil { - err = tx.Commit().Error - } - - panicked = false - return -} - -// Begin begins a transaction -func (s *DB) Begin() *DB { - return s.BeginTx(context.Background(), &sql.TxOptions{}) -} - -// BeginTx begins a transaction with options -func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB { - c := s.clone() - if db, ok := c.db.(sqlDb); ok && db != nil { - tx, err := db.BeginTx(ctx, opts) - c.db = interface{}(tx).(SQLCommon) - - c.dialect.SetDB(c.db) - c.AddError(err) - } else { - c.AddError(ErrCantStartTransaction) - } - return c -} - -// Commit commit a transaction -func (s *DB) Commit() *DB { - var emptySQLTx *sql.Tx - if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - s.AddError(db.Commit()) - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// Rollback rollback a transaction -func (s *DB) Rollback() *DB { - var emptySQLTx *sql.Tx - if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - if err := db.Rollback(); err != nil && err != sql.ErrTxDone { - s.AddError(err) - } - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// RollbackUnlessCommitted rollback a transaction if it has not yet been -// committed. -func (s *DB) RollbackUnlessCommitted() *DB { - var emptySQLTx *sql.Tx - if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - err := db.Rollback() - // Ignore the error indicating that the transaction has already - // been committed. 
- if err != sql.ErrTxDone { - s.AddError(err) - } - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// NewRecord check if value's primary key is blank -func (s *DB) NewRecord(value interface{}) bool { - return s.NewScope(value).PrimaryKeyZero() -} - -// RecordNotFound check if returning ErrRecordNotFound error -func (s *DB) RecordNotFound() bool { - for _, err := range s.GetErrors() { - if err == ErrRecordNotFound { - return true - } - } - return false -} - -// CreateTable create table for models -func (s *DB) CreateTable(models ...interface{}) *DB { - db := s.Unscoped() - for _, model := range models { - db = db.NewScope(model).createTable().db - } - return db -} - -// DropTable drop table for models -func (s *DB) DropTable(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if tableName, ok := value.(string); ok { - db = db.Table(tableName) - } - - db = db.NewScope(value).dropTable().db - } - return db -} - -// DropTableIfExists drop table if it is exist -func (s *DB) DropTableIfExists(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if s.HasTable(value) { - db.AddError(s.DropTable(value).Error) - } - } - return db -} - -// HasTable check has table or not -func (s *DB) HasTable(value interface{}) bool { - var ( - scope = s.NewScope(value) - tableName string - ) - - if name, ok := value.(string); ok { - tableName = name - } else { - tableName = scope.TableName() - } - - has := scope.Dialect().HasTable(tableName) - s.AddError(scope.db.Error) - return has -} - -// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data -func (s *DB) AutoMigrate(values ...interface{}) *DB { - db := s.Unscoped() - for _, value := range values { - db = db.NewScope(value).autoMigrate().db - } - return db -} - -// ModifyColumn modify column to type -func (s *DB) ModifyColumn(column string, typ string) *DB { - scope := s.NewScope(s.Value) - scope.modifyColumn(column, typ) - return scope.db -} - -// DropColumn drop a column -func (s *DB) DropColumn(column string) *DB { - scope := s.NewScope(s.Value) - scope.dropColumn(column) - return scope.db -} - -// AddIndex add index for columns with given name -func (s *DB) AddIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(false, indexName, columns...) - return scope.db -} - -// AddUniqueIndex add unique index for columns with given name -func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(true, indexName, columns...) 
-	return scope.db
-}
-
-// RemoveIndex remove index with name
-func (s *DB) RemoveIndex(indexName string) *DB {
-	scope := s.NewScope(s.Value)
-	scope.removeIndex(indexName)
-	return scope.db
-}
-
-// AddForeignKey Add foreign key to the given scope, e.g:
-//     db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
-func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
-	scope := s.NewScope(s.Value)
-	scope.addForeignKey(field, dest, onDelete, onUpdate)
-	return scope.db
-}
-
-// RemoveForeignKey Remove foreign key from the given scope, e.g:
-//     db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)")
-func (s *DB) RemoveForeignKey(field string, dest string) *DB {
-	scope := s.clone().NewScope(s.Value)
-	scope.removeForeignKey(field, dest)
-	return scope.db
-}
-
-// Association start `Association Mode` to handle relation things more easily in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode
-func (s *DB) Association(column string) *Association {
-	var err error
-	var scope = s.Set("gorm:association:source", s.Value).NewScope(s.Value)
-
-	if primaryField := scope.PrimaryField(); primaryField.IsBlank {
-		err = errors.New("primary key can't be nil")
-	} else {
-		if field, ok := scope.FieldByName(column); ok {
-			if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
-				err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
-			} else {
-				return &Association{scope: scope, column: column, field: field}
-			}
-		} else {
-			err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
-		}
-	}
-
-	return &Association{Error: err}
-}
-
-// Preload preload associations with given conditions
-//    db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
-func (s *DB) Preload(column string, conditions ...interface{}) *DB {
-	return s.clone().search.Preload(column, conditions...).db
-}
-
-// Set set setting by name, which could be used in callbacks; will clone a new db and update its setting
-func (s *DB) Set(name string, value interface{}) *DB {
-	return s.clone().InstantSet(name, value)
-}
-
-// InstantSet instant set setting, will affect current db
-func (s *DB) InstantSet(name string, value interface{}) *DB {
-	s.values.Store(name, value)
-	return s
-}
-
-// Get get setting by name
-func (s *DB) Get(name string) (value interface{}, ok bool) {
-	value, ok = s.values.Load(name)
-	return
-}
-
-// SetJoinTableHandler set a model's join table handler for a relation
-func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
-	scope := s.NewScope(source)
-	for _, field := range scope.GetModelStruct().StructFields {
-		if field.Name == column || field.DBName == column {
-			if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" {
-				source := (&Scope{Value: source}).GetModelStruct().ModelType
-				destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
-				handler.Setup(field.Relationship, many2many, source, destination)
-				field.Relationship.JoinTableHandler = handler
-				if table := handler.Table(s); scope.Dialect().HasTable(table) {
-					s.Table(table).AutoMigrate(handler)
-				}
-			}
-		}
-	}
-}
-
-// AddError add error to the db
-func (s *DB) AddError(err error) error {
-	if err != nil {
-		if err != ErrRecordNotFound {
-			if s.logMode == defaultLogMode {
-				go s.print("error", fileWithLineNum(), err)
-			} else {
-				s.log(err)
-			}
-
-			errors := Errors(s.GetErrors())
-			errors = errors.Add(err)
-			if len(errors) > 1 {
-				err = errors
-			}
-		}
-
-		s.Error = err
-	}
-	return err
-}
-
-// GetErrors get the errors that have occurred on the db
-func (s *DB) GetErrors() []error {
-	if errs, ok := s.Error.(Errors); ok {
-		return errs
-	} else if s.Error != nil {
-		return []error{s.Error}
-	}
-	return []error{}
-}
-
-////////////////////////////////////////////////////////////////////////////////
-// Private Methods For DB
-////////////////////////////////////////////////////////////////////////////////
-
-func (s *DB) clone() *DB {
-	db := &DB{
-		db:                s.db,
-		parent:            s.parent,
-		logger:            s.logger,
-		logMode:           s.logMode,
-		Value:             s.Value,
-		Error:             s.Error,
-		blockGlobalUpdate: s.blockGlobalUpdate,
-		dialect:           newDialect(s.dialect.GetName(), s.db),
-		nowFuncOverride:   s.nowFuncOverride,
-	}
-
-	s.values.Range(func(k, v interface{}) bool {
-		db.values.Store(k, v)
-		return true
-	})
-
-	if s.search == nil {
-		db.search = &search{limit: -1, offset: -1}
-	} else {
-		db.search = s.search.clone()
-	}
-
-	db.search.db = db
-	return db
-}
-
-func (s *DB) print(v ...interface{}) {
-	s.logger.Print(v...)
-}
-
-func (s *DB) log(v ...interface{}) {
-	if s != nil && s.logMode == detailedLogMode {
-		s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
-	}
-}
-
-func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
-	if s.logMode == detailedLogMode {
-		s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected)
-	}
-}
diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/model.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/model.go
deleted file mode 100644
index f37ff7eaa..000000000
--- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/model.go
+++ /dev/null
@@ -1,14 +0,0 @@
-package gorm
-
-import "time"
-
-// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models
-//    type User struct {
-//      gorm.Model
-//    }
-type Model struct {
-	ID        uint `gorm:"primary_key"`
-	CreatedAt time.Time
-	UpdatedAt time.Time
-	DeletedAt *time.Time `sql:"index"`
-}
diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/model_struct.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/model_struct.go
deleted file mode 100644
index 57dbec385..000000000
--- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/model_struct.go
+++ /dev/null
@@ -1,677 +0,0 @@
-package gorm
-
-import (
-	"database/sql"
-	"errors"
-	"go/ast"
-	"reflect"
-	"strings"
-	"sync"
-	"time"
-
-	"github.com/jinzhu/inflection"
-)
-
-// DefaultTableNameHandler default table name handler
-var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
-	return defaultTableName
-}
-
-// lock for mutating global cached model metadata
-var structsLock sync.Mutex
-
-// global cache of model metadata
-var modelStructsMap sync.Map
-
-// ModelStruct model definition
-type ModelStruct struct {
-	PrimaryFields []*StructField
-	StructFields  []*StructField
-	ModelType     reflect.Type
-
-	defaultTableName string
-	l                sync.Mutex
-}
-
-// TableName returns model's table name
-func (s *ModelStruct) TableName(db *DB) string {
-	s.l.Lock()
-	defer s.l.Unlock()
-
-	if s.defaultTableName == "" && db != nil && s.ModelType != nil {
-		// Set default table name
-		if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
-			s.defaultTableName = tabler.TableName()
-		} else {
-			tableName := ToTableName(s.ModelType.Name())
-			db.parent.RLock()
-			if db == nil || (db.parent != nil && !db.parent.singularTable) {
-				tableName = inflection.Plural(tableName)
-			}
-			db.parent.RUnlock()
-			s.defaultTableName = tableName
-		}
-	}
-
-	return DefaultTableNameHandler(db, s.defaultTableName)
-}
-
-// StructField model field's struct definition
-type StructField struct {
-	DBName          string
-	Name            string
-	Names           []string
-	IsPrimaryKey    bool
-	IsNormal        bool
-	IsIgnored       bool
-	IsScanner       bool
-	HasDefaultValue bool
-	Tag             reflect.StructTag
-	TagSettings     map[string]string
-	Struct          reflect.StructField
-	IsForeignKey    bool
-	Relationship    *Relationship
-
-	tagSettingsLock sync.RWMutex
-}
-
-// TagSettingsSet Sets a tag in the tag settings map
-func (sf *StructField) TagSettingsSet(key, val string) {
-	sf.tagSettingsLock.Lock()
-	defer sf.tagSettingsLock.Unlock()
-	sf.TagSettings[key] = val
-}
-
-// TagSettingsGet returns a tag from the tag settings
-func (sf *StructField) TagSettingsGet(key string) (string, bool) {
-	sf.tagSettingsLock.RLock()
-	defer sf.tagSettingsLock.RUnlock()
-	val, ok := sf.TagSettings[key]
-	return val, ok
-}
-
-// TagSettingsDelete deletes a tag
-func (sf *StructField) TagSettingsDelete(key string) {
-	sf.tagSettingsLock.Lock()
-	defer sf.tagSettingsLock.Unlock()
-	delete(sf.TagSettings, key)
-}
-
-func (sf *StructField) clone() *StructField {
-	clone := &StructField{
-		DBName:          sf.DBName,
-		Name:            sf.Name,
-		Names:           sf.Names,
-		IsPrimaryKey:    sf.IsPrimaryKey,
-		IsNormal:        sf.IsNormal,
-		IsIgnored:       sf.IsIgnored,
-		IsScanner:       sf.IsScanner,
-		HasDefaultValue: sf.HasDefaultValue,
-		Tag:             sf.Tag,
-		TagSettings:     map[string]string{},
-		Struct:          sf.Struct,
-		IsForeignKey:    sf.IsForeignKey,
-	}
-
-	if sf.Relationship != nil {
-		relationship := *sf.Relationship
-		clone.Relationship = &relationship
-	}
-
-	// copy the struct field tagSettings; the map must stay locked while it is copied
-	sf.tagSettingsLock.Lock()
-	defer sf.tagSettingsLock.Unlock()
-	for key, value := range sf.TagSettings {
-		clone.TagSettings[key] = value
-	}
-
-	return clone
-}
-
-// Relationship describes the relationship between models
-type Relationship struct {
-	Kind                         string
-	PolymorphicType              string
-	PolymorphicDBName            string
-	PolymorphicValue             string
-	ForeignFieldNames            []string
-	ForeignDBNames               []string
-	AssociationForeignFieldNames []string
-	AssociationForeignDBNames    []string
-	JoinTableHandler             JoinTableHandlerInterface
-}
-
-func getForeignField(column string, fields []*StructField) *StructField {
-	for _, field := range fields {
-		if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) {
-			return field
-		}
-	}
-	return nil
-}
-
-// GetModelStruct get value's model struct, relationships based on struct and tag definition
-func (scope *Scope) GetModelStruct() *ModelStruct {
-	return scope.getModelStruct(scope, make([]*StructField, 0))
-}
-
-func (scope *Scope) getModelStruct(rootScope *Scope, allFields []*StructField) *ModelStruct {
-	var modelStruct ModelStruct
-	// Scope value can't be nil
-	if scope.Value == nil {
-		return &modelStruct
-	}
-
-	reflectType := reflect.ValueOf(scope.Value).Type()
-	for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr {
-		reflectType = reflectType.Elem()
-	}
-
-	// Scope value needs to be a struct
-	if reflectType.Kind() != reflect.Struct {
-		return &modelStruct
-	}
-
-	// Get cached model struct
-	isSingularTable := false
-	if scope.db != nil && scope.db.parent != nil {
-		scope.db.parent.RLock()
-		isSingularTable = scope.db.parent.singularTable
-		scope.db.parent.RUnlock()
-	}
-
-	hashKey := struct
{ - singularTable bool - reflectType reflect.Type - }{isSingularTable, reflectType} - if value, ok := modelStructsMap.Load(hashKey); ok && value != nil { - return value.(*ModelStruct) - } - - modelStruct.ModelType = reflectType - - // Get all fields - for i := 0; i < reflectType.NumField(); i++ { - if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { - field := &StructField{ - Struct: fieldStruct, - Name: fieldStruct.Name, - Names: []string{fieldStruct.Name}, - Tag: fieldStruct.Tag, - TagSettings: parseTagSetting(fieldStruct.Tag), - } - - // is ignored field - if _, ok := field.TagSettingsGet("-"); ok { - field.IsIgnored = true - } else { - if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - - if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey { - field.HasDefaultValue = true - } - - if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey { - field.HasDefaultValue = true - } - - indirectType := fieldStruct.Type - for indirectType.Kind() == reflect.Ptr { - indirectType = indirectType.Elem() - } - - fieldValue := reflect.New(indirectType).Interface() - if _, isScanner := fieldValue.(sql.Scanner); isScanner { - // is scanner - field.IsScanner, field.IsNormal = true, true - if indirectType.Kind() == reflect.Struct { - for i := 0; i < indirectType.NumField(); i++ { - for key, value := range parseTagSetting(indirectType.Field(i).Tag) { - if _, ok := field.TagSettingsGet(key); !ok { - field.TagSettingsSet(key, value) - } - } - } - } - } else if _, isTime := fieldValue.(*time.Time); isTime { - // is time - field.IsNormal = true - } else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous { - // is embedded struct - for _, subField := range scope.New(fieldValue).getModelStruct(rootScope, allFields).StructFields { - subField = subField.clone() - subField.Names = append([]string{fieldStruct.Name}, subField.Names...) 
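				// Illustrative sketch (hypothetical types, not taken from this codebase):
				// the EMBEDDED_PREFIX setting handled below prepends a prefix to every
				// column of an embedded struct, e.g.
				//
				//	type Address struct{ City string }
				//	type User struct {
				//		Home Address `gorm:"EMBEDDED;EMBEDDED_PREFIX:home_"`
				//	}
				//
				// so Address.City maps to the column home_city.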
- if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok { - subField.DBName = prefix + subField.DBName - } - - if subField.IsPrimaryKey { - if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok { - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) - } else { - subField.IsPrimaryKey = false - } - } - - if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil { - if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok { - newJoinTableHandler := &JoinTableHandler{} - newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType) - subField.Relationship.JoinTableHandler = newJoinTableHandler - } - } - - modelStruct.StructFields = append(modelStruct.StructFields, subField) - allFields = append(allFields, subField) - } - continue - } else { - // build relationships - switch indirectType.Kind() { - case reflect.Slice: - defer func(field *StructField) { - var ( - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - foreignKeys []string - associationForeignKeys []string - elemType = field.Struct.Type - ) - - if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { - foreignKeys = strings.Split(foreignKey, ",") - } - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { - associationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { - associationForeignKeys = strings.Split(foreignKey, ",") - } - - for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - - if elemType.Kind() == reflect.Struct { - if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { - relationship.Kind = "many_to_many" - - { // Foreign Keys for Source - joinTableDBNames := []string{} - - if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" { - joinTableDBNames = strings.Split(foreignKey, ",") - } - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, field.DBName) - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - // source foreign keys (db names) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) - - // setup join table foreign keys for source - if len(joinTableDBNames) > idx { - // if defined join table's foreign key - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) - } else { - defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) - } - } - } - } - - { // Foreign Keys for Association (Destination) - associationJoinTableDBNames := []string{} - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" { - associationJoinTableDBNames = strings.Split(foreignKey, ",") - } - - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) - } - } - - for idx, name := range associationForeignKeys { - 
if field, ok := toScope.FieldByName(name); ok { - // association foreign keys (db names) - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - - // setup join table foreign keys for association - if len(associationJoinTableDBNames) > idx { - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) - } else { - // join table foreign keys for association - joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) - } - } - } - } - - joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, reflectType, elemType) - relationship.JoinTableHandler = &joinTableHandler - field.Relationship = relationship - } else { - // User has many comments, associationType is User, comment use UserID as foreign key - var associationType = reflectType.Name() - var toFields = toScope.GetStructFields() - relationship.Kind = "has_many" - - if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { - // Dog has many toys, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('dogs') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - // if Dog has multiple set of toys set name of the set (instead of default 'dogs') - if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { - relationship.PolymorphicValue = value - } else { - relationship.PolymorphicValue = scope.TableName() - } - polymorphicType.IsForeignKey = true - } - } - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+field.Name) - associationForeignKeys = append(associationForeignKeys, field.Name) - } - } else { - // generate foreign keys from defined association foreign keys - for _, scopeFieldName := range associationForeignKeys { - if foreignField := getForeignField(scopeFieldName, allFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{rootScope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], allFields); 
associationField != nil { - // mark field as foreignkey, use global lock to avoid race - structsLock.Lock() - foreignField.IsForeignKey = true - structsLock.Unlock() - - // association foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - field.Relationship = relationship - } - } - } else { - field.IsNormal = true - } - }(field) - case reflect.Struct: - defer func(field *StructField) { - var ( - // user has one profile, associationType is User, profile use UserID as foreign key - // user belongs to profile, associationType is Profile, user use ProfileID as foreign key - associationType = reflectType.Name() - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - toFields = toScope.GetStructFields() - tagForeignKeys []string - tagAssociationForeignKeys []string - ) - - if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { - tagForeignKeys = strings.Split(foreignKey, ",") - } - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { - tagAssociationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { - tagAssociationForeignKeys = strings.Split(foreignKey, ",") - } - - if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { - // Cat has one toy, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('cats') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - // if Cat has several different types of toys set name for each (instead of default 'cats') - if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { - relationship.PolymorphicValue = value - } else { - relationship.PolymorphicValue = scope.TableName() - } - polymorphicType.IsForeignKey = true - } - } - - // Has One - { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, primaryField := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys form association foreign keys - for _, associationForeignKey := range tagAssociationForeignKeys { - if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if 
strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{rootScope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if scopeField := getForeignField(associationForeignKeys[idx], allFields); scopeField != nil { - // mark field as foreignkey, use global lock to avoid race - structsLock.Lock() - foreignField.IsForeignKey = true - structsLock.Unlock() - - // association foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) - - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "has_one" - field.Relationship = relationship - } else { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - - if len(foreignKeys) == 0 { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, primaryField := range toScope.PrimaryFields() { - foreignKeys = append(foreignKeys, field.Name+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys with association foreign keys - for _, associationForeignKey := range associationForeignKeys { - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - foreignKeys = append(foreignKeys, field.Name+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, field.Name) { - associationForeignKey := strings.TrimPrefix(foreignKey, field.Name) - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{toScope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { - // mark field as foreignkey, use global lock to avoid race - structsLock.Lock() - foreignField.IsForeignKey = true - structsLock.Unlock() - - // association foreign keys - 
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
-								relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
-
-								// source foreign keys
-								relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
-								relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
-							}
-						}
-					}
-
-					if len(relationship.ForeignFieldNames) != 0 {
-						relationship.Kind = "belongs_to"
-						field.Relationship = relationship
-					}
-				}
-			}(field)
-		default:
-			field.IsNormal = true
-		}
-	}
-}
-
-			// Even if it is ignored, it is still possible to decode a db value into the field
-			if value, ok := field.TagSettingsGet("COLUMN"); ok {
-				field.DBName = value
-			} else {
-				field.DBName = ToColumnName(fieldStruct.Name)
-			}
-
-			modelStruct.StructFields = append(modelStruct.StructFields, field)
-			allFields = append(allFields, field)
-		}
-	}
-
-	if len(modelStruct.PrimaryFields) == 0 {
-		if field := getForeignField("id", modelStruct.StructFields); field != nil {
-			field.IsPrimaryKey = true
-			modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
-		}
-	}
-
-	modelStructsMap.Store(hashKey, &modelStruct)
-
-	return &modelStruct
-}
-
-// GetStructFields get model's field structs
-func (scope *Scope) GetStructFields() (fields []*StructField) {
-	return scope.GetModelStruct().StructFields
-}
-
-func parseTagSetting(tags reflect.StructTag) map[string]string {
-	setting := map[string]string{}
-	for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
-		if str == "" {
-			continue
-		}
-		tags := strings.Split(str, ";")
-		for _, value := range tags {
-			v := strings.Split(value, ":")
-			k := strings.TrimSpace(strings.ToUpper(v[0]))
-			if len(v) >= 2 {
-				setting[k] = strings.Join(v[1:], ":")
-			} else {
-				setting[k] = k
-			}
-		}
-	}
-	return setting
-}
diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/naming.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/naming.go
deleted file mode 100644
index 6b0a4fddb..000000000
--- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/naming.go
+++ /dev/null
@@ -1,124 +0,0 @@
-package gorm
-
-import (
-	"bytes"
-	"strings"
-)
-
-// Namer is a function type which is given a string and returns a string
-type Namer func(string) string
-
-// NamingStrategy represents naming strategies
-type NamingStrategy struct {
-	DB     Namer
-	Table  Namer
-	Column Namer
-}
-
-// TheNamingStrategy is initialized with defaultNamer
-var TheNamingStrategy = &NamingStrategy{
-	DB:     defaultNamer,
-	Table:  defaultNamer,
-	Column: defaultNamer,
-}
-
-// AddNamingStrategy sets the naming strategy
-func AddNamingStrategy(ns *NamingStrategy) {
-	if ns.DB == nil {
-		ns.DB = defaultNamer
-	}
-	if ns.Table == nil {
-		ns.Table = defaultNamer
-	}
-	if ns.Column == nil {
-		ns.Column = defaultNamer
-	}
-	TheNamingStrategy = ns
-}
-
-// DBName alters the given name by DB
-func (ns *NamingStrategy) DBName(name string) string {
-	return ns.DB(name)
-}
-
-// TableName alters the given name by Table
-func (ns *NamingStrategy) TableName(name string) string {
-	return ns.Table(name)
-}
-
-// ColumnName alters the given name by Column
-func (ns *NamingStrategy) ColumnName(name string) string {
-	return ns.Column(name)
-}
-
-// ToDBName convert string to db name
-func ToDBName(name string) string {
-	return TheNamingStrategy.DBName(name)
-}
-
-// ToTableName convert string to table name
-func ToTableName(name string) string {
-	return TheNamingStrategy.TableName(name)
-}
-
-// ToColumnName convert string to column name
-func ToColumnName(name string) string {
-	return TheNamingStrategy.ColumnName(name)
-}
-
-var smap = newSafeMap()
-
-func defaultNamer(name string) string {
-	const (
-		lower = false
-		upper = true
-	)
-
-	if v := smap.Get(name); v != "" {
-		return v
-	}
-
-	if name == "" {
-		return ""
-	}
-
-	var (
-		value                                    = commonInitialismsReplacer.Replace(name)
-		buf                                      = bytes.NewBufferString("")
-		lastCase, currCase, nextCase, nextNumber bool
-	)
-
-	for i, v := range value[:len(value)-1] {
-		nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z')
-		nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9')
-
-		if i > 0 {
-			if currCase == upper {
-				if lastCase == upper && (nextCase == upper || nextNumber == upper) {
-					buf.WriteRune(v)
-				} else {
-					if value[i-1] != '_' && value[i+1] != '_' {
-						buf.WriteRune('_')
-					}
-					buf.WriteRune(v)
-				}
-			} else {
-				buf.WriteRune(v)
-				if i == len(value)-2 && (nextCase == upper && nextNumber == lower) {
-					buf.WriteRune('_')
-				}
-			}
-		} else {
-			currCase = upper
-			buf.WriteRune(v)
-		}
-		lastCase = currCase
-		currCase = nextCase
-	}
-
-	buf.WriteByte(value[len(value)-1])
-
-	s := strings.ToLower(buf.String())
-	smap.Set(name, s)
-	return s
-}
diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/scope.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/scope.go
deleted file mode 100644
index 56c3d6e5e..000000000
--- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/scope.go
+++ /dev/null
@@ -1,1425 +0,0 @@
-package gorm
-
-import (
-	"bytes"
-	"database/sql"
-	"database/sql/driver"
-	"errors"
-	"fmt"
-	"reflect"
-	"regexp"
-	"strings"
-	"time"
-)
-
-// Scope contains the current operation's information when you perform any operation on the database
-type Scope struct {
-	Search          *search
-	Value           interface{}
-	SQL             string
-	SQLVars         []interface{}
-	db              *DB
-	instanceID      string
-	primaryKeyField *Field
-	skipLeft        bool
-	fields          *[]*Field
-	selectAttrs     *[]string
-}
-
-// IndirectValue return scope's reflect value's indirect value
-func (scope *Scope) IndirectValue() reflect.Value {
-	return indirect(reflect.ValueOf(scope.Value))
-}
-
-// New create a new Scope without search information
-func (scope *Scope) New(value interface{}) *Scope {
-	return &Scope{db: scope.NewDB(), Search: &search{}, Value: value}
-}
-
-////////////////////////////////////////////////////////////////////////////////
-// Scope DB
-////////////////////////////////////////////////////////////////////////////////
-
-// DB return scope's DB connection
-func (scope *Scope) DB() *DB {
-	return scope.db
-}
-
-// NewDB create a new DB without search information
-func (scope *Scope) NewDB() *DB {
-	if scope.db != nil {
-		db := scope.db.clone()
-		db.search = nil
-		db.Value = nil
-		return db
-	}
-	return nil
-}
-
-// SQLDB return *sql.DB
-func (scope *Scope) SQLDB() SQLCommon {
-	return scope.db.db
-}
-
-// Dialect get dialect
-func (scope *Scope) Dialect() Dialect {
-	return scope.db.dialect
-}
-
-// Quote quotes strings to escape them for the database
-func (scope *Scope) Quote(str string) string {
-	if strings.Contains(str, ".") {
-		newStrs := []string{}
-		for _, str := range strings.Split(str, ".") {
-			newStrs = append(newStrs, scope.Dialect().Quote(str))
-		}
-		return strings.Join(newStrs, ".")
-	}
-
-	return scope.Dialect().Quote(str)
-}
-
-// Err add error to Scope
-func (scope *Scope) Err(err error) error {
-	if err != nil {
-		scope.db.AddError(err)
-	}
-	return err
-}
-
-// HasError check if there is any error
-func (scope *Scope) HasError() bool {
-	return scope.db.Error != nil
-}
-
-// Log print log message
-func (scope *Scope) Log(v ...interface{}) {
-	scope.db.log(v...)
-}
-
-// SkipLeft skip remaining callbacks
-func (scope *Scope) SkipLeft() {
-	scope.skipLeft = true
-}
-
-// Fields get value's fields
-func (scope *Scope) Fields() []*Field {
-	if scope.fields == nil {
-		var (
-			fields             []*Field
-			indirectScopeValue = scope.IndirectValue()
-			isStruct           = indirectScopeValue.Kind() == reflect.Struct
-		)
-
-		for _, structField := range scope.GetModelStruct().StructFields {
-			if isStruct {
-				fieldValue := indirectScopeValue
-				for _, name := range structField.Names {
-					if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
-						fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
-					}
-					fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
-				}
-				fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)})
-			} else {
-				fields = append(fields, &Field{StructField: structField, IsBlank: true})
-			}
-		}
-		scope.fields = &fields
-	}
-
-	return *scope.fields
-}
-
-// FieldByName find `gorm.Field` with field name or db name
-func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
-	var (
-		dbName           = ToColumnName(name)
-		mostMatchedField *Field
-	)
-
-	for _, field := range scope.Fields() {
-		if field.Name == name || field.DBName == name {
-			return field, true
-		}
-		if field.DBName == dbName {
-			mostMatchedField = field
-		}
-	}
-	return mostMatchedField, mostMatchedField != nil
-}
-
-// PrimaryFields return scope's primary fields
-func (scope *Scope) PrimaryFields() (fields []*Field) {
-	for _, field := range scope.Fields() {
-		if field.IsPrimaryKey {
-			fields = append(fields, field)
-		}
-	}
-	return fields
-}
-
-// PrimaryField return scope's main primary field; if more than one primary field is defined, it will return the one having column name `id` or the first one
-func (scope *Scope) PrimaryField() *Field {
-	if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
-		if len(primaryFields) > 1 {
-			if field, ok := scope.FieldByName("id"); ok {
-				return field
-			}
-		}
-		return scope.PrimaryFields()[0]
-	}
-	return nil
-}
-
-// PrimaryKey get main primary field's db name
-func (scope *Scope) PrimaryKey() string {
-	if field := scope.PrimaryField(); field != nil {
-		return field.DBName
-	}
-	return ""
-}
-
-// PrimaryKeyZero check whether the main primary field's value is blank
-func (scope *Scope) PrimaryKeyZero() bool {
-	field := scope.PrimaryField()
-	return field == nil || field.IsBlank
-}
-
-// PrimaryKeyValue get the primary key's value
-func (scope *Scope) PrimaryKeyValue() interface{} {
-	if field := scope.PrimaryField(); field != nil && field.Field.IsValid() {
-		return field.Field.Interface()
-	}
-	return 0
-}
-
-// HasColumn check if the scope's value has the given column
-func (scope *Scope) HasColumn(column string) bool {
-	for _, field := range scope.GetStructFields() {
-		if field.IsNormal && (field.Name == column || field.DBName == column) {
-			return true
-		}
-	}
-	return false
-}
-
-// SetColumn to set the column's value, column could be field or field's name/dbname
-func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
-	var updateAttrs = map[string]interface{}{}
-	if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
-		updateAttrs = attrs.(map[string]interface{})
-		defer scope.InstanceSet("gorm:update_attrs", updateAttrs)
-	}
-
-	if field, ok := column.(*Field); ok {
-
updateAttrs[field.DBName] = value - return field.Set(value) - } else if name, ok := column.(string); ok { - var ( - dbName = ToDBName(name) - mostMatchedField *Field - ) - for _, field := range scope.Fields() { - if field.DBName == value { - updateAttrs[field.DBName] = value - return field.Set(value) - } - if !field.IsIgnored && ((field.DBName == dbName) || (field.Name == name && mostMatchedField == nil)) { - mostMatchedField = field - } - } - - if mostMatchedField != nil { - updateAttrs[mostMatchedField.DBName] = value - return mostMatchedField.Set(value) - } - } - return errors.New("could not convert column to field") -} - -// CallMethod call scope value's method, if it is a slice, will call its element's method one by one -func (scope *Scope) CallMethod(methodName string) { - if scope.Value == nil { - return - } - - if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice { - for i := 0; i < indirectScopeValue.Len(); i++ { - scope.callMethod(methodName, indirectScopeValue.Index(i)) - } - } else { - scope.callMethod(methodName, indirectScopeValue) - } -} - -// AddToVars add value as sql's vars, used to prevent SQL injection -func (scope *Scope) AddToVars(value interface{}) string { - _, skipBindVar := scope.InstanceGet("skip_bindvar") - - if expr, ok := value.(*SqlExpr); ok { - exp := expr.expr - for _, arg := range expr.args { - if skipBindVar { - scope.AddToVars(arg) - } else { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) - } - } - return exp - } - - scope.SQLVars = append(scope.SQLVars, value) - - if skipBindVar { - return "?" - } - return scope.Dialect().BindVar(len(scope.SQLVars)) -} - -// SelectAttrs return selected attributes -func (scope *Scope) SelectAttrs() []string { - if scope.selectAttrs == nil { - attrs := []string{} - for _, value := range scope.Search.selects { - if str, ok := value.(string); ok { - attrs = append(attrs, str) - } else if strs, ok := value.([]string); ok { - attrs = append(attrs, strs...) 
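			// Illustration (hypothetical calls): db.Select("name, age") is handled by
			// the string case above, db.Select([]string{"name", "age"}) by this
			// []string case, and the select args, stored as []interface{}, by the
			// case below.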
- } else if strs, ok := value.([]interface{}); ok { - for _, str := range strs { - attrs = append(attrs, fmt.Sprintf("%v", str)) - } - } - } - scope.selectAttrs = &attrs - } - return *scope.selectAttrs -} - -// OmitAttrs return omitted attributes -func (scope *Scope) OmitAttrs() []string { - return scope.Search.omits -} - -type tabler interface { - TableName() string -} - -type dbTabler interface { - TableName(*DB) string -} - -// TableName return table name -func (scope *Scope) TableName() string { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - return scope.Search.tableName - } - - if tabler, ok := scope.Value.(tabler); ok { - return tabler.TableName() - } - - if tabler, ok := scope.Value.(dbTabler); ok { - return tabler.TableName(scope.db) - } - - return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) -} - -// QuotedTableName return quoted table name -func (scope *Scope) QuotedTableName() (name string) { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - if strings.Contains(scope.Search.tableName, " ") { - return scope.Search.tableName - } - return scope.Quote(scope.Search.tableName) - } - - return scope.Quote(scope.TableName()) -} - -// CombinedConditionSql return combined condition sql -func (scope *Scope) CombinedConditionSql() string { - joinSQL := scope.joinsSQL() - whereSQL := scope.whereSQL() - if scope.Search.raw { - whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")") - } - return joinSQL + whereSQL + scope.groupSQL() + - scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() -} - -// Raw set raw sql -func (scope *Scope) Raw(sql string) *Scope { - scope.SQL = strings.Replace(sql, "$$$", "?", -1) - return scope -} - -// Exec perform generated SQL -func (scope *Scope) Exec() *Scope { - defer scope.trace(NowFunc()) - - if !scope.HasError() { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - if count, err := result.RowsAffected(); scope.Err(err) == nil { - scope.db.RowsAffected = count - } - } - } - return scope -} - -// Set set value by name -func (scope *Scope) Set(name string, value interface{}) *Scope { - scope.db.InstantSet(name, value) - return scope -} - -// Get get setting by name -func (scope *Scope) Get(name string) (interface{}, bool) { - return scope.db.Get(name) -} - -// InstanceID get InstanceID for scope -func (scope *Scope) InstanceID() string { - if scope.instanceID == "" { - scope.instanceID = fmt.Sprintf("%v%v", &scope, &scope.db) - } - return scope.instanceID -} - -// InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback -func (scope *Scope) InstanceSet(name string, value interface{}) *Scope { - return scope.Set(name+scope.InstanceID(), value) -} - -// InstanceGet get instance setting from current operation -func (scope *Scope) InstanceGet(name string) (interface{}, bool) { - return scope.Get(name + scope.InstanceID()) -} - -// Begin start a transaction -func (scope *Scope) Begin() *Scope { - if db, ok := scope.SQLDB().(sqlDb); ok { - if tx, err := db.Begin(); scope.Err(err) == nil { - scope.db.db = interface{}(tx).(SQLCommon) - scope.InstanceSet("gorm:started_transaction", true) - } - } - return scope -} - -// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it -func (scope *Scope) CommitOrRollback() *Scope { - if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { - if db, ok := scope.db.db.(sqlTx); ok { - if 
scope.HasError() { - db.Rollback() - } else { - scope.Err(db.Commit()) - } - scope.db.db = scope.db.parent.db - } - } - return scope -} - -//////////////////////////////////////////////////////////////////////////////// -// Private Methods For *gorm.Scope -//////////////////////////////////////////////////////////////////////////////// - -func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { - // Only get address from non-pointer - if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr { - reflectValue = reflectValue.Addr() - } - - if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() { - switch method := methodValue.Interface().(type) { - case func(): - method() - case func(*Scope): - method(scope) - case func(*DB): - newDB := scope.NewDB() - method(newDB) - scope.Err(newDB.Error) - case func() error: - scope.Err(method()) - case func(*Scope) error: - scope.Err(method(scope)) - case func(*DB) error: - newDB := scope.NewDB() - scope.Err(method(newDB)) - scope.Err(newDB.Error) - default: - scope.Err(fmt.Errorf("unsupported function %v", methodName)) - } - } -} - -var ( - columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name` - isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number - comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ") - countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") -) - -func (scope *Scope) quoteIfPossible(str string) string { - if columnRegexp.MatchString(str) { - return scope.Quote(str) - } - return str -} - -func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { - var ( - ignored interface{} - values = make([]interface{}, len(columns)) - selectFields []*Field - selectedColumnsMap = map[string]int{} - resetFields = map[int]*Field{} - ) - - for index, column := range columns { - values[index] = &ignored - - selectFields = fields - offset := 0 - if idx, ok := selectedColumnsMap[column]; ok { - offset = idx + 1 - selectFields = selectFields[offset:] - } - - for fieldIndex, field := range selectFields { - if field.DBName == column { - if field.Field.Kind() == reflect.Ptr { - values[index] = field.Field.Addr().Interface() - } else { - reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) - reflectValue.Elem().Set(field.Field.Addr()) - values[index] = reflectValue.Interface() - resetFields[index] = field - } - - selectedColumnsMap[column] = offset + fieldIndex - - if field.IsNormal { - break - } - } - } - } - - scope.Err(rows.Scan(values...)) - - for index, field := range resetFields { - if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { - field.Field.Set(v) - } - } -} - -func (scope *Scope) primaryCondition(value interface{}) string { - return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value) -} - -func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) { - var ( - quotedTableName = scope.QuotedTableName() - quotedPrimaryKey = scope.Quote(scope.PrimaryKey()) - equalSQL = "=" - inSQL = "IN" - ) - - // If building not conditions - if !include { - equalSQL = "<>" - inSQL = "NOT IN" - } - - switch value := clause["query"].(type) { - case sql.NullInt64: - return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64) - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - return fmt.Sprintf("(%v.%v %s %v)", 
quotedTableName, quotedPrimaryKey, equalSQL, value) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - if !include && reflect.ValueOf(value).Len() == 0 { - return - } - str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL) - clause["args"] = []interface{}{value} - case string: - if isNumberRegexp.MatchString(value) { - return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value)) - } - - if value != "" { - if !include { - if comparisonRegexp.MatchString(value) { - str = fmt.Sprintf("NOT (%v)", value) - } else { - str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value)) - } - } else { - str = fmt.Sprintf("(%v)", value) - } - } - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value))) - } else { - if !include { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key))) - } - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - newScope := scope.New(value) - - if len(newScope.Fields()) == 0 { - scope.Err(fmt.Errorf("invalid query condition: %v", value)) - return - } - scopeQuotedTableName := newScope.QuotedTableName() - for _, field := range newScope.Fields() { - if !field.IsIgnored && !field.IsBlank && field.Relationship == nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") - default: - scope.Err(fmt.Errorf("invalid query condition: %v", value)) - return - } - - replacements := []string{} - args := clause["args"].([]interface{}) - for _, arg := range args { - var err error - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = scanner.Value() - replacements = append(replacements, scope.AddToVars(arg)) - } else if b, ok := arg.([]byte); ok { - replacements = append(replacements, scope.AddToVars(b)) - } else if as, ok := arg.([][]interface{}); ok { - var tempMarks []string - for _, a := range as { - var arrayMarks []string - for _, v := range a { - arrayMarks = append(arrayMarks, scope.AddToVars(v)) - } - - if len(arrayMarks) > 0 { - tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ","))) - } - } - - if len(tempMarks) > 0 { - replacements = append(replacements, strings.Join(tempMarks, ",")) - } - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - replacements = append(replacements, strings.Join(tempMarks, ",")) - } else { - replacements = append(replacements, scope.AddToVars(Expr("NULL"))) - } - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = valuer.Value() - } - - replacements = append(replacements, scope.AddToVars(arg)) - } - - if err != nil { - scope.Err(err) - } - } - - buff := bytes.NewBuffer([]byte{}) - i := 0 - for _, s := range str { - if s == '?' 
&& len(replacements) > i { - buff.WriteString(replacements[i]) - i++ - } else { - buff.WriteRune(s) - } - } - - str = buff.String() - - return -} - -func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { - switch value := clause["query"].(type) { - case string: - str = value - case []string: - str = strings.Join(value, ", ") - } - - args := clause["args"].([]interface{}) - replacements := []string{} - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: - values := reflect.ValueOf(arg) - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - replacements = append(replacements, strings.Join(tempMarks, ",")) - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() - } - replacements = append(replacements, scope.AddToVars(arg)) - } - } - - buff := bytes.NewBuffer([]byte{}) - i := 0 - for pos, char := range str { - if str[pos] == '?' { - buff.WriteString(replacements[i]) - i++ - } else { - buff.WriteRune(char) - } - } - - str = buff.String() - - return -} - -func (scope *Scope) whereSQL() (sql string) { - var ( - quotedTableName = scope.QuotedTableName() - deletedAtField, hasDeletedAtField = scope.FieldByName("DeletedAt") - primaryConditions, andConditions, orConditions []string - ) - - if !scope.Search.Unscoped && hasDeletedAtField { - sql := fmt.Sprintf("%v.%v IS NULL", quotedTableName, scope.Quote(deletedAtField.DBName)) - primaryConditions = append(primaryConditions, sql) - } - - if !scope.PrimaryKeyZero() { - for _, field := range scope.PrimaryFields() { - sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) - primaryConditions = append(primaryConditions, sql) - } - } - - for _, clause := range scope.Search.whereConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - andConditions = append(andConditions, sql) - } - } - - for _, clause := range scope.Search.orConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - orConditions = append(orConditions, sql) - } - } - - for _, clause := range scope.Search.notConditions { - if sql := scope.buildCondition(clause, false); sql != "" { - andConditions = append(andConditions, sql) - } - } - - orSQL := strings.Join(orConditions, " OR ") - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) > 0 { - if len(orSQL) > 0 { - combinedSQL = combinedSQL + " OR " + orSQL - } - } else { - combinedSQL = orSQL - } - - if len(primaryConditions) > 0 { - sql = "WHERE " + strings.Join(primaryConditions, " AND ") - if len(combinedSQL) > 0 { - sql = sql + " AND (" + combinedSQL + ")" - } - } else if len(combinedSQL) > 0 { - sql = "WHERE " + combinedSQL - } - return -} - -func (scope *Scope) selectSQL() string { - if len(scope.Search.selects) == 0 { - if len(scope.Search.joinConditions) > 0 { - return fmt.Sprintf("%v.*", scope.QuotedTableName()) - } - return "*" - } - return scope.buildSelectQuery(scope.Search.selects) -} - -func (scope *Scope) orderSQL() string { - if len(scope.Search.orders) == 0 || scope.Search.ignoreOrderQuery { - return "" - } - - var orders []string - for _, order := range scope.Search.orders { - if str, ok := order.(string); ok { - orders = append(orders, scope.quoteIfPossible(str)) - } else if expr, ok := order.(*SqlExpr); ok { - exp := expr.expr - for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 
1) - } - orders = append(orders, exp) - } - } - return " ORDER BY " + strings.Join(orders, ",") -} - -func (scope *Scope) limitAndOffsetSQL() string { - sql, err := scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) - scope.Err(err) - return sql -} - -func (scope *Scope) groupSQL() string { - if len(scope.Search.group) == 0 { - return "" - } - return " GROUP BY " + scope.Search.group -} - -func (scope *Scope) havingSQL() string { - if len(scope.Search.havingConditions) == 0 { - return "" - } - - var andConditions []string - for _, clause := range scope.Search.havingConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - andConditions = append(andConditions, sql) - } - } - - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) == 0 { - return "" - } - - return " HAVING " + combinedSQL -} - -func (scope *Scope) joinsSQL() string { - var joinConditions []string - for _, clause := range scope.Search.joinConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) - } - } - - return strings.Join(joinConditions, " ") + " " -} - -func (scope *Scope) prepareQuerySQL() { - if scope.Search.raw { - scope.Raw(scope.CombinedConditionSql()) - } else { - scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) - } - return -} - -func (scope *Scope) inlineCondition(values ...interface{}) *Scope { - if len(values) > 0 { - scope.Search.Where(values[0], values[1:]...) - } - return scope -} - -func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { - defer func() { - if err := recover(); err != nil { - if db, ok := scope.db.db.(sqlTx); ok { - db.Rollback() - } - panic(err) - } - }() - for _, f := range funcs { - (*f)(scope) - if scope.skipLeft { - break - } - } - return scope -} - -func convertInterfaceToMap(values interface{}, withIgnoredField bool, db *DB) map[string]interface{} { - var attrs = map[string]interface{}{} - - switch value := values.(type) { - case map[string]interface{}: - return value - case []interface{}: - for _, v := range value { - for key, value := range convertInterfaceToMap(v, withIgnoredField, db) { - attrs[key] = value - } - } - case interface{}: - reflectValue := reflect.ValueOf(values) - - switch reflectValue.Kind() { - case reflect.Map: - for _, key := range reflectValue.MapKeys() { - attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() - } - default: - for _, field := range (&Scope{Value: values, db: db}).Fields() { - if !field.IsBlank && (withIgnoredField || !field.IsIgnored) { - attrs[field.DBName] = field.Field.Interface() - } - } - } - } - return attrs -} - -func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) { - if scope.IndirectValue().Kind() != reflect.Struct { - return convertInterfaceToMap(value, false, scope.db), true - } - - results = map[string]interface{}{} - - for key, value := range convertInterfaceToMap(value, true, scope.db) { - if field, ok := scope.FieldByName(key); ok { - if scope.changeableField(field) { - if _, ok := value.(*SqlExpr); ok { - hasUpdate = true - results[field.DBName] = value - } else { - err := field.Set(value) - if field.IsNormal && !field.IsIgnored { - hasUpdate = true - if err == ErrUnaddressable { - results[field.DBName] = value - } else { - results[field.DBName] = field.Field.Interface() - } - } - } - } 
- } else { - results[key] = value - } - } - return -} - -func (scope *Scope) row() *sql.Row { - defer scope.trace(NowFunc()) - - result := &RowQueryResult{} - scope.InstanceSet("row_query_result", result) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - - return result.Row -} - -func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.trace(NowFunc()) - - result := &RowsQueryResult{} - scope.InstanceSet("row_query_result", result) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - - return result.Rows, result.Error -} - -func (scope *Scope) initialize() *Scope { - for _, clause := range scope.Search.whereConditions { - scope.updatedAttrsWithValues(clause["query"]) - } - scope.updatedAttrsWithValues(scope.Search.initAttrs) - scope.updatedAttrsWithValues(scope.Search.assignAttrs) - return scope -} - -func (scope *Scope) isQueryForColumn(query interface{}, column string) bool { - queryStr := strings.ToLower(fmt.Sprint(query)) - if queryStr == column { - return true - } - - if strings.HasSuffix(queryStr, "as "+column) { - return true - } - - if strings.HasSuffix(queryStr, "as "+scope.Quote(column)) { - return true - } - - return false -} - -func (scope *Scope) pluck(column string, value interface{}) *Scope { - dest := reflect.Indirect(reflect.ValueOf(value)) - if dest.Kind() != reflect.Slice { - scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) - return scope - } - - if dest.Len() > 0 { - dest.Set(reflect.Zero(dest.Type())) - } - - if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) { - scope.Search.Select(column) - } - - rows, err := scope.rows() - if scope.Err(err) == nil { - defer rows.Close() - for rows.Next() { - elem := reflect.New(dest.Type().Elem()).Interface() - scope.Err(rows.Scan(elem)) - dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } - } - return scope -} - -func (scope *Scope) count(value interface{}) *Scope { - if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { - if len(scope.Search.group) != 0 { - if len(scope.Search.havingConditions) != 0 { - scope.prepareQuerySQL() - scope.Search = &search{} - scope.Search.Select("count(*)") - scope.Search.Table(fmt.Sprintf("( %s ) AS count_table", scope.SQL)) - } else { - scope.Search.Select("count(*) FROM ( SELECT count(*) as name ") - scope.Search.group += " ) AS count_table" - } - } else { - scope.Search.Select("count(*)") - } - } - scope.Search.ignoreOrderQuery = true - scope.Err(scope.row().Scan(value)) - return scope -} - -func (scope *Scope) typeName() string { - typ := scope.IndirectValue().Type() - - for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - - return typ.Name() -} - -// trace print sql log -func (scope *Scope) trace(t time.Time) { - if len(scope.SQL) > 0 { - scope.db.slog(scope.SQL, t, scope.SQLVars...) 
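	// Callers invoke this as `defer scope.trace(NowFunc())`: the start time is
	// captured when the defer statement is evaluated, and slog (defined in
	// main.go above) derives the elapsed duration via NowFunc().Sub(t).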
- } -} - -func (scope *Scope) changeableField(field *Field) bool { - if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 { - for _, attr := range selectAttrs { - if field.Name == attr || field.DBName == attr { - return true - } - } - return false - } - - for _, attr := range scope.OmitAttrs() { - if field.Name == attr || field.DBName == attr { - return false - } - } - - return true -} - -func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { - toScope := scope.db.NewScope(value) - tx := scope.db.Set("gorm:association:source", scope.Value) - - for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - fromField, _ := scope.FieldByName(foreignKey) - toField, _ := toScope.FieldByName(foreignKey) - - if fromField != nil { - if relationship := fromField.Relationship; relationship != nil { - if relationship.Kind == "many_to_many" { - joinTableHandler := relationship.JoinTableHandler - scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Find(value).Error) - } else if relationship.Kind == "belongs_to" { - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(foreignKey); ok { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) - } - } - scope.Err(tx.Find(value).Error) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - if relationship.PolymorphicType != "" { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) - } - scope.Err(tx.Find(value).Error) - } - } else { - sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - scope.Err(tx.Where(sql, fromField.Field.Interface()).Find(value).Error) - } - return scope - } else if toField != nil { - sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) - scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) - return scope - } - } - - scope.Err(fmt.Errorf("invalid association %v", foreignKeys)) - return scope -} - -// getTableOptions return the table options string or an empty string if the table options does not exist -func (scope *Scope) getTableOptions() string { - tableOptions, ok := scope.Get("gorm:table_options") - if !ok { - return "" - } - return " " + tableOptions.(string) -} - -func (scope *Scope) createJoinTable(field *StructField) { - if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { - joinTableHandler := relationship.JoinTableHandler - joinTable := joinTableHandler.Table(scope.db) - if !scope.Dialect().HasTable(joinTable) { - toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} - - var sqlTypes, primaryKeys []string - for idx, fieldName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") - foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") - sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, 
scope.Quote(relationship.ForeignDBNames[idx])) - } - } - - for idx, fieldName := range relationship.AssociationForeignFieldNames { - if field, ok := toScope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") - foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") - sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) - } - } - - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) - } - scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) - } -} - -func (scope *Scope) createTable() *Scope { - var tags []string - var primaryKeys []string - var primaryKeyInColumnType = false - for _, field := range scope.GetModelStruct().StructFields { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - - // Check if the primary key constraint was specified as - // part of the column type. If so, we can only support - // one column as the primary key. - if strings.Contains(strings.ToLower(sqlTag), "primary key") { - primaryKeyInColumnType = true - } - - tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag) - } - - if field.IsPrimaryKey { - primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) - } - scope.createJoinTable(field) - } - - var primaryKeyStr string - if len(primaryKeys) > 0 && !primaryKeyInColumnType { - primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) - } - - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() - - scope.autoIndex() - return scope -} - -func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() - return scope -} - -func (scope *Scope) modifyColumn(column string, typ string) { - scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ)) -} - -func (scope *Scope) dropColumn(column string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() -} - -func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { - if scope.Dialect().HasIndex(scope.TableName(), indexName) { - return - } - - var columns []string - for _, name := range column { - columns = append(columns, scope.quoteIfPossible(name)) - } - - sqlCreate := "CREATE INDEX" - if unique { - sqlCreate = "CREATE UNIQUE INDEX" - } - - scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() -} - -func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - // Compatible with old generated key - keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - - if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { - return - } - var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() -} - -func (scope *Scope) 
removeForeignKey(field string, dest string) { - keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { - return - } - var mysql mysql - var query string - if scope.Dialect().GetName() == mysql.GetName() { - query = `ALTER TABLE %s DROP FOREIGN KEY %s;` - } else { - query = `ALTER TABLE %s DROP CONSTRAINT %s;` - } - - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() -} - -func (scope *Scope) removeIndex(indexName string) { - scope.Dialect().RemoveIndex(scope.TableName(), indexName) -} - -func (scope *Scope) autoMigrate() *Scope { - tableName := scope.TableName() - quotedTableName := scope.QuotedTableName() - - if !scope.Dialect().HasTable(tableName) { - scope.createTable() - } else { - for _, field := range scope.GetModelStruct().StructFields { - if !scope.Dialect().HasColumn(tableName, field.DBName) { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() - } - } - scope.createJoinTable(field) - } - scope.autoIndex() - } - return scope -} - -func (scope *Scope) autoIndex() *Scope { - var indexes = map[string][]string{} - var uniqueIndexes = map[string][]string{} - - for _, field := range scope.GetStructFields() { - if name, ok := field.TagSettingsGet("INDEX"); ok { - names := strings.Split(name, ",") - - for _, name := range names { - if name == "INDEX" || name == "" { - name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName) - } - name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName) - indexes[name] = append(indexes[name], column) - } - } - - if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok { - names := strings.Split(name, ",") - - for _, name := range names { - if name == "UNIQUE_INDEX" || name == "" { - name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName) - } - name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName) - uniqueIndexes[name] = append(uniqueIndexes[name], column) - } - } - } - - for name, columns := range indexes { - if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddIndex(name, columns...); db.Error != nil { - scope.db.AddError(db.Error) - } - } - - for name, columns := range uniqueIndexes { - if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil { - scope.db.AddError(db.Error) - } - } - - return scope -} - -func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { - resultMap := make(map[string][]interface{}) - for _, value := range values { - indirectValue := indirect(reflect.ValueOf(value)) - - switch indirectValue.Kind() { - case reflect.Slice: - for i := 0; i < indirectValue.Len(); i++ { - var result []interface{} - var object = indirect(indirectValue.Index(i)) - var hasValue = false - for _, column := range columns { - field := object.FieldByName(column) - if hasValue || !isBlank(field) { - hasValue = true - } - result = append(result, field.Interface()) - } - - if hasValue { - h := fmt.Sprint(result...) 
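// (editor's note at this step: h is a cheap deduplication key — fmt.Sprint
// concatenates the collected column values, and resultMap keeps only the
// first tuple per key. The key is not collision-proof: fmt.Sprint(1, "2")
// and fmt.Sprint("1", 2) both yield "12".)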
- if _, exist := resultMap[h]; !exist { - resultMap[h] = result - } - } - } - case reflect.Struct: - var result []interface{} - var hasValue = false - for _, column := range columns { - field := indirectValue.FieldByName(column) - if hasValue || !isBlank(field) { - hasValue = true - } - result = append(result, field.Interface()) - } - - if hasValue { - h := fmt.Sprint(result...) - if _, exist := resultMap[h]; !exist { - resultMap[h] = result - } - } - } - } - for _, v := range resultMap { - results = append(results, v) - } - return -} - -func (scope *Scope) getColumnAsScope(column string) *Scope { - indirectScopeValue := scope.IndirectValue() - - switch indirectScopeValue.Kind() { - case reflect.Slice: - if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok { - fieldType := fieldStruct.Type - if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - - resultsMap := map[interface{}]bool{} - results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() - - for i := 0; i < indirectScopeValue.Len(); i++ { - result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column)) - - if result.Kind() == reflect.Slice { - for j := 0; j < result.Len(); j++ { - if elem := result.Index(j); elem.CanAddr() && resultsMap[elem.Addr()] != true { - resultsMap[elem.Addr()] = true - results = reflect.Append(results, elem.Addr()) - } - } - } else if result.CanAddr() && resultsMap[result.Addr()] != true { - resultsMap[result.Addr()] = true - results = reflect.Append(results, result.Addr()) - } - } - return scope.New(results.Interface()) - } - case reflect.Struct: - if field := indirectScopeValue.FieldByName(column); field.CanAddr() { - return scope.New(field.Addr().Interface()) - } - } - return nil -} - -func (scope *Scope) hasConditions() bool { - return !scope.PrimaryKeyZero() || - len(scope.Search.whereConditions) > 0 || - len(scope.Search.orConditions) > 0 || - len(scope.Search.notConditions) > 0 -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/search.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/search.go deleted file mode 100644 index 52ae2efcb..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/search.go +++ /dev/null @@ -1,203 +0,0 @@ -package gorm - -import ( - "fmt" -) - -type search struct { - db *DB - whereConditions []map[string]interface{} - orConditions []map[string]interface{} - notConditions []map[string]interface{} - havingConditions []map[string]interface{} - joinConditions []map[string]interface{} - initAttrs []interface{} - assignAttrs []interface{} - selects map[string]interface{} - omits []string - orders []interface{} - preload []searchPreload - offset interface{} - limit interface{} - group string - tableName string - raw bool - Unscoped bool - ignoreOrderQuery bool -} - -type searchPreload struct { - schema string - conditions []interface{} -} - -func (s *search) clone() *search { - clone := search{ - db: s.db, - whereConditions: make([]map[string]interface{}, len(s.whereConditions)), - orConditions: make([]map[string]interface{}, len(s.orConditions)), - notConditions: make([]map[string]interface{}, len(s.notConditions)), - havingConditions: make([]map[string]interface{}, len(s.havingConditions)), - joinConditions: make([]map[string]interface{}, len(s.joinConditions)), - initAttrs: make([]interface{}, len(s.initAttrs)), - assignAttrs: make([]interface{}, len(s.assignAttrs)), - selects: s.selects, - omits: make([]string, 
len(s.omits)), - orders: make([]interface{}, len(s.orders)), - preload: make([]searchPreload, len(s.preload)), - offset: s.offset, - limit: s.limit, - group: s.group, - tableName: s.tableName, - raw: s.raw, - Unscoped: s.Unscoped, - ignoreOrderQuery: s.ignoreOrderQuery, - } - for i, value := range s.whereConditions { - clone.whereConditions[i] = value - } - for i, value := range s.orConditions { - clone.orConditions[i] = value - } - for i, value := range s.notConditions { - clone.notConditions[i] = value - } - for i, value := range s.havingConditions { - clone.havingConditions[i] = value - } - for i, value := range s.joinConditions { - clone.joinConditions[i] = value - } - for i, value := range s.initAttrs { - clone.initAttrs[i] = value - } - for i, value := range s.assignAttrs { - clone.assignAttrs[i] = value - } - for i, value := range s.omits { - clone.omits[i] = value - } - for i, value := range s.orders { - clone.orders[i] = value - } - for i, value := range s.preload { - clone.preload[i] = value - } - return &clone -} - -func (s *search) Where(query interface{}, values ...interface{}) *search { - s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Not(query interface{}, values ...interface{}) *search { - s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Or(query interface{}, values ...interface{}) *search { - s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Attrs(attrs ...interface{}) *search { - s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...)) - return s -} - -func (s *search) Assign(attrs ...interface{}) *search { - s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...)) - return s -} - -func (s *search) Order(value interface{}, reorder ...bool) *search { - if len(reorder) > 0 && reorder[0] { - s.orders = []interface{}{} - } - - if value != nil && value != "" { - s.orders = append(s.orders, value) - } - return s -} - -func (s *search) Select(query interface{}, args ...interface{}) *search { - s.selects = map[string]interface{}{"query": query, "args": args} - return s -} - -func (s *search) Omit(columns ...string) *search { - s.omits = columns - return s -} - -func (s *search) Limit(limit interface{}) *search { - s.limit = limit - return s -} - -func (s *search) Offset(offset interface{}) *search { - s.offset = offset - return s -} - -func (s *search) Group(query string) *search { - s.group = s.getInterfaceAsSQL(query) - return s -} - -func (s *search) Having(query interface{}, values ...interface{}) *search { - if val, ok := query.(*SqlExpr); ok { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args}) - } else { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) - } - return s -} - -func (s *search) Joins(query string, values ...interface{}) *search { - s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Preload(schema string, values ...interface{}) *search { - var preloads []searchPreload - for _, preload := range s.preload { - if preload.schema != schema { - preloads = append(preloads, preload) - } - } - preloads = append(preloads, searchPreload{schema, values}) - s.preload = preloads - return s -} - -func (s *search) Raw(b 
bool) *search { - s.raw = b - return s -} - -func (s *search) unscoped() *search { - s.Unscoped = true - return s -} - -func (s *search) Table(name string) *search { - s.tableName = name - return s -} - -func (s *search) getInterfaceAsSQL(value interface{}) (str string) { - switch value.(type) { - case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - str = fmt.Sprintf("%v", value) - default: - s.db.AddError(ErrInvalidSQL) - } - - if str == "-1" { - return "" - } - return -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/test_all.sh b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/test_all.sh deleted file mode 100644 index 5cfb3321a..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/test_all.sh +++ /dev/null @@ -1,5 +0,0 @@ -dialects=("postgres" "mysql" "mssql" "sqlite") - -for dialect in "${dialects[@]}" ; do - DEBUG=false GORM_DIALECT=${dialect} go test -done diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/utils.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/utils.go deleted file mode 100644 index d2ae9465d..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/utils.go +++ /dev/null @@ -1,226 +0,0 @@ -package gorm - -import ( - "database/sql/driver" - "fmt" - "reflect" - "regexp" - "runtime" - "strings" - "sync" - "time" -) - -// NowFunc returns current time, this function is exported in order to be able -// to give the flexibility to the developer to customize it according to their -// needs, e.g: -// gorm.NowFunc = func() time.Time { -// return time.Now().UTC() -// } -var NowFunc = func() time.Time { - return time.Now() -} - -// Copied from golint -var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} -var commonInitialismsReplacer *strings.Replacer - -var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) -var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) - -func init() { - var commonInitialismsForReplacer []string - for _, initialism := range commonInitialisms { - commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) - } - commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) -} - -type safeMap struct { - m map[string]string - l *sync.RWMutex -} - -func (s *safeMap) Set(key string, value string) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeMap) Get(key string) string { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newSafeMap() *safeMap { - return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} -} - -// SQL expression -type SqlExpr struct { - expr string - args []interface{} -} - -// Expr generate raw SQL expression, for example: -// DB.Model(&product).Update("price", gorm.Expr("price * ? 
+ ?", 2, 100)) -func Expr(expression string, args ...interface{}) *SqlExpr { - return &SqlExpr{expr: expression, args: args} -} - -func indirect(reflectValue reflect.Value) reflect.Value { - for reflectValue.Kind() == reflect.Ptr { - reflectValue = reflectValue.Elem() - } - return reflectValue -} - -func toQueryMarks(primaryValues [][]interface{}) string { - var results []string - - for _, primaryValue := range primaryValues { - var marks []string - for range primaryValue { - marks = append(marks, "?") - } - - if len(marks) > 1 { - results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) - } else { - results = append(results, strings.Join(marks, "")) - } - } - return strings.Join(results, ",") -} - -func toQueryCondition(scope *Scope, columns []string) string { - var newColumns []string - for _, column := range columns { - newColumns = append(newColumns, scope.Quote(column)) - } - - if len(columns) > 1 { - return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) - } - return strings.Join(newColumns, ",") -} - -func toQueryValues(values [][]interface{}) (results []interface{}) { - for _, value := range values { - for _, v := range value { - results = append(results, v) - } - } - return -} - -func fileWithLineNum() string { - for i := 2; i < 15; i++ { - _, file, line, ok := runtime.Caller(i) - if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) { - return fmt.Sprintf("%v:%v", file, line) - } - } - return "" -} - -func isBlank(value reflect.Value) bool { - switch value.Kind() { - case reflect.String: - return value.Len() == 0 - case reflect.Bool: - return !value.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return value.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return value.Uint() == 0 - case reflect.Float32, reflect.Float64: - return value.Float() == 0 - case reflect.Interface, reflect.Ptr: - return value.IsNil() - } - - return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) -} - -func toSearchableMap(attrs ...interface{}) (result interface{}) { - if len(attrs) > 1 { - if str, ok := attrs[0].(string); ok { - result = map[string]interface{}{str: attrs[1]} - } - } else if len(attrs) == 1 { - if attr, ok := attrs[0].(map[string]interface{}); ok { - result = attr - } - - if attr, ok := attrs[0].(interface{}); ok { - result = attr - } - } - return -} - -func equalAsString(a interface{}, b interface{}) bool { - return toString(a) == toString(b) -} - -func toString(str interface{}) string { - if values, ok := str.([]interface{}); ok { - var results []string - for _, value := range values { - results = append(results, toString(value)) - } - return strings.Join(results, "_") - } else if bytes, ok := str.([]byte); ok { - return string(bytes) - } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { - return fmt.Sprintf("%v", reflectValue.Interface()) - } - return "" -} - -func makeSlice(elemType reflect.Type) interface{} { - if elemType.Kind() == reflect.Slice { - elemType = elemType.Elem() - } - sliceType := reflect.SliceOf(elemType) - slice := reflect.New(sliceType) - slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) - return slice.Interface() -} - -func strInSlice(a string, list []string) bool { - for _, b := range list { - if b == a { - return true - } - } - return false -} - -// getValueFromFields return given fields's value -func getValueFromFields(value reflect.Value, 
fieldNames []string) (results []interface{}) { - // If value is a nil pointer, Indirect returns a zero Value! - // Therefor we need to check for a zero value, - // as FieldByName could panic - if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { - for _, fieldName := range fieldNames { - if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() { - result := fieldValue.Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() - } - results = append(results, result) - } - } - } - return -} - -func addExtraSpaceIfExist(str string) string { - if str != "" { - return " " + str - } - return "" -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/wercker.yml b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/wercker.yml deleted file mode 100644 index 1de947b8a..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/wercker.yml +++ /dev/null @@ -1,149 +0,0 @@ -# use the default golang container from Docker Hub -box: golang - -services: - - name: mariadb - id: mariadb:latest - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql - id: mysql:latest - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql57 - id: mysql:5.7 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql56 - id: mysql:5.6 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: postgres - id: postgres:latest - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres96 - id: postgres:9.6 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres95 - id: postgres:9.5 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres94 - id: postgres:9.4 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres93 - id: postgres:9.3 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: mssql - id: mcmoe/mssqldocker:latest - env: - ACCEPT_EULA: Y - SA_PASSWORD: LoremIpsum86 - MSSQL_DB: gorm - MSSQL_USER: gorm - MSSQL_PASSWORD: LoremIpsum86 - -# The steps that will be executed in the build pipeline -build: - # The steps that will be executed on build - steps: - # Sets the go workspace and places you package - # at the right place in the workspace tree - - setup-go-workspace - - # Gets the dependencies - - script: - name: go get - code: | - cd $WERCKER_SOURCE_DIR - go version - go get -t -v ./... - - # Build the project - - script: - name: go build - code: | - go build ./... - - # Test the project - - script: - name: test sqlite - code: | - go test -race -v ./... - - - script: - name: test mariadb - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - - script: - name: test mysql - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - - script: - name: test mysql5.7 - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - - script: - name: test mysql5.6 - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... 
- - - script: - name: test postgres - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test postgres96 - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test postgres95 - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test postgres94 - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test postgres93 - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: codecov - code: | - go test -race -coverprofile=coverage.txt -covermode=atomic ./... - bash <(curl -s https://codecov.io/bash) diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/Guardfile b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/Guardfile new file mode 100644 index 000000000..0b860b065 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/Guardfile @@ -0,0 +1,3 @@ +guard 'gotest' do + watch(%r{\.go$}) +end diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/License b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/License similarity index 100% rename from src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/License rename to src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/License diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/README.md b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/README.md new file mode 100644 index 000000000..e81d31de6 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/README.md @@ -0,0 +1,137 @@ +## Now + +Now is a time toolkit for golang + +[![go report card](https://goreportcard.com/badge/github.com/jinzhu/now "go report card")](https://goreportcard.com/report/github.com/jinzhu/now) +[![test status](https://github.com/jinzhu/now/workflows/tests/badge.svg?branch=master "test status")](https://github.com/jinzhu/now/actions) +[![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) + +## Install + +``` +go get -u github.com/jinzhu/now +``` + +## Usage + +Calculating time based on current time + +```go +import "github.com/jinzhu/now" + +time.Now() // 2013-11-18 17:51:49.123456789 Mon + +now.BeginningOfMinute() // 2013-11-18 17:51:00 Mon +now.BeginningOfHour() // 2013-11-18 17:00:00 Mon +now.BeginningOfDay() // 2013-11-18 00:00:00 Mon +now.BeginningOfWeek() // 2013-11-17 00:00:00 Sun +now.BeginningOfMonth() // 2013-11-01 00:00:00 Fri +now.BeginningOfQuarter() // 2013-10-01 00:00:00 Tue +now.BeginningOfYear() // 2013-01-01 00:00:00 Tue + +now.EndOfMinute() // 2013-11-18 17:51:59.999999999 Mon +now.EndOfHour() // 2013-11-18 17:59:59.999999999 Mon +now.EndOfDay() // 2013-11-18 23:59:59.999999999 Mon +now.EndOfWeek() // 2013-11-23 23:59:59.999999999 Sat +now.EndOfMonth() // 2013-11-30 23:59:59.999999999 Sat +now.EndOfQuarter() // 2013-12-31 23:59:59.999999999 Tue +now.EndOfYear() // 2013-12-31 23:59:59.999999999 Tue + +now.WeekStartDay = time.Monday // Set Monday as first day, default is Sunday +now.EndOfWeek() // 2013-11-24 23:59:59.999999999 Sun +``` + +Calculating time 
based on another time + +```go +t := time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.Now().Location()) +now.With(t).EndOfMonth() // 2013-02-28 23:59:59.999999999 Thu +``` + +Calculating time based on configuration + +```go +location, err := time.LoadLocation("Asia/Shanghai") + +myConfig := &now.Config{ + WeekStartDay: time.Monday, + TimeLocation: location, + TimeFormats: []string{"2006-01-02 15:04:05"}, +} + +t := time.Date(2013, 11, 18, 17, 51, 49, 123456789, time.Now().Location()) // // 2013-11-18 17:51:49.123456789 Mon +myConfig.With(t).BeginningOfWeek() // 2013-11-18 00:00:00 Mon + +myConfig.Parse("2002-10-12 22:14:01") // 2002-10-12 22:14:01 +myConfig.Parse("2002-10-12 22:14") // returns error 'can't parse string as time: 2002-10-12 22:14' +``` + +### Monday/Sunday + +Don't be bothered with the `WeekStartDay` setting, you can use `Monday`, `Sunday` + +```go +now.Monday() // 2013-11-18 00:00:00 Mon +now.Monday("17:44") // 2013-11-18 17:44:00 Mon +now.Sunday() // 2013-11-24 00:00:00 Sun (Next Sunday) +now.Sunday("18:19:24") // 2013-11-24 18:19:24 Sun (Next Sunday) +now.EndOfSunday() // 2013-11-24 23:59:59.999999999 Sun (End of next Sunday) + +t := time.Date(2013, 11, 24, 17, 51, 49, 123456789, time.Now().Location()) // 2013-11-24 17:51:49.123456789 Sun +now.With(t).Monday() // 2013-11-18 00:00:00 Mon (Last Monday if today is Sunday) +now.With(t).Monday("17:44") // 2013-11-18 17:44:00 Mon (Last Monday if today is Sunday) +now.With(t).Sunday() // 2013-11-24 00:00:00 Sun (Beginning Of Today if today is Sunday) +now.With(t).Sunday("18:19:24") // 2013-11-24 18:19:24 Sun (Beginning Of Today if today is Sunday) +now.With(t).EndOfSunday() // 2013-11-24 23:59:59.999999999 Sun (End of Today if today is Sunday) +``` + +### Parse String to Time + +```go +time.Now() // 2013-11-18 17:51:49.123456789 Mon + +// Parse(string) (time.Time, error) +t, err := now.Parse("2017") // 2017-01-01 00:00:00, nil +t, err := now.Parse("2017-10") // 2017-10-01 00:00:00, nil +t, err := now.Parse("2017-10-13") // 2017-10-13 00:00:00, nil +t, err := now.Parse("1999-12-12 12") // 1999-12-12 12:00:00, nil +t, err := now.Parse("1999-12-12 12:20") // 1999-12-12 12:20:00, nil +t, err := now.Parse("1999-12-12 12:20:21") // 1999-12-12 12:20:21, nil +t, err := now.Parse("10-13") // 2013-10-13 00:00:00, nil +t, err := now.Parse("12:20") // 2013-11-18 12:20:00, nil +t, err := now.Parse("12:20:13") // 2013-11-18 12:20:13, nil +t, err := now.Parse("14") // 2013-11-18 14:00:00, nil +t, err := now.Parse("99:99") // 2013-11-18 12:20:00, Can't parse string as time: 99:99 + +// MustParse must parse string to time or it will panic +now.MustParse("2013-01-13") // 2013-01-13 00:00:00 +now.MustParse("02-17") // 2013-02-17 00:00:00 +now.MustParse("2-17") // 2013-02-17 00:00:00 +now.MustParse("8") // 2013-11-18 08:00:00 +now.MustParse("2002-10-12 22:14") // 2002-10-12 22:14:00 +now.MustParse("99:99") // panic: Can't parse string as time: 99:99 +``` + +Extend `now` to support more formats is quite easy, just update `now.TimeFormats` with other time layouts, e.g: + +```go +now.TimeFormats = append(now.TimeFormats, "02 Jan 2006 15:04") +``` + +Please send me pull requests if you want a format to be supported officially + +## Contributing + +You can help to make the project better, check out [http://gorm.io/contribute.html](http://gorm.io/contribute.html) for things you can do. + +# Author + +**jinzhu** + +* +* +* + +## License + +Released under the [MIT License](http://www.opensource.org/licenses/MIT). 
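Reviewer aside: `jinzhu/now` enters the vendor tree here because it is a dependency of the new `gorm.io/gorm`. A minimal, self-contained sketch of the API vendored in this patch — the function names come from the files above and below, and the reference time is arbitrary:

```go
package main

import (
	"fmt"
	"time"

	"github.com/jinzhu/now"
)

func main() {
	// Pin a reference time so the output is deterministic.
	t := time.Date(2013, 11, 18, 17, 51, 49, 123456789, time.UTC)
	n := now.With(t)

	fmt.Println(n.BeginningOfDay()) // 2013-11-18 00:00:00 +0000 UTC
	fmt.Println(n.EndOfMonth())     // 2013-11-30 23:59:59.999999999 +0000 UTC

	// Parse fills in whatever the input string omits from the
	// reference time: "12:20" keeps t's year, month and day.
	if parsed, err := n.Parse("12:20"); err == nil {
		fmt.Println(parsed) // 2013-11-18 12:20:00 +0000 UTC
	}
}
```

Internally `Parse` normalizes times through `formatTimeToList` (vendored in `time.go` below), which orders components as `[nanosecond, second, minute, hour, day, month, year]`; the index arithmetic in `Now.Parse` relies on that ordering.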
diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/main.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/main.go new file mode 100644 index 000000000..8f78bc752 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/main.go @@ -0,0 +1,200 @@ +// Package now is a time toolkit for golang. +// +// More details README here: https://github.com/jinzhu/now +// +// import "github.com/jinzhu/now" +// +// now.BeginningOfMinute() // 2013-11-18 17:51:00 Mon +// now.BeginningOfDay() // 2013-11-18 00:00:00 Mon +// now.EndOfDay() // 2013-11-18 23:59:59.999999999 Mon +package now + +import "time" + +// WeekStartDay set week start day, default is sunday +var WeekStartDay = time.Sunday + +// TimeFormats default time formats will be parsed as +var TimeFormats = []string{ + "2006", "2006-1", "2006-1-2", "2006-1-2 15", "2006-1-2 15:4", "2006-1-2 15:4:5", "1-2", + "15:4:5", "15:4", "15", + "15:4:5 Jan 2, 2006 MST", "2006-01-02 15:04:05.999999999 -0700 MST", "2006-01-02T15:04:05Z0700", "2006-01-02T15:04:05Z07", + "2006.1.2", "2006.1.2 15:04:05", "2006.01.02", "2006.01.02 15:04:05", "2006.01.02 15:04:05.999999999", + "1/2/2006", "1/2/2006 15:4:5", "2006/01/02", "20060102", "2006/01/02 15:04:05", + time.ANSIC, time.UnixDate, time.RubyDate, time.RFC822, time.RFC822Z, time.RFC850, + time.RFC1123, time.RFC1123Z, time.RFC3339, time.RFC3339Nano, + time.Kitchen, time.Stamp, time.StampMilli, time.StampMicro, time.StampNano, +} + +// Config configuration for now package +type Config struct { + WeekStartDay time.Weekday + TimeLocation *time.Location + TimeFormats []string +} + +// DefaultConfig default config +var DefaultConfig *Config + +// New initialize Now based on configuration +func (config *Config) With(t time.Time) *Now { + return &Now{Time: t, Config: config} +} + +// Parse parse string to time based on configuration +func (config *Config) Parse(strs ...string) (time.Time, error) { + if config.TimeLocation == nil { + return config.With(time.Now()).Parse(strs...) + } else { + return config.With(time.Now().In(config.TimeLocation)).Parse(strs...) + } +} + +// MustParse must parse string to time or will panic +func (config *Config) MustParse(strs ...string) time.Time { + if config.TimeLocation == nil { + return config.With(time.Now()).MustParse(strs...) + } else { + return config.With(time.Now().In(config.TimeLocation)).MustParse(strs...) 
+ } +} + +// Now now struct +type Now struct { + time.Time + *Config +} + +// With initialize Now with time +func With(t time.Time) *Now { + config := DefaultConfig + if config == nil { + config = &Config{ + WeekStartDay: WeekStartDay, + TimeFormats: TimeFormats, + } + } + + return &Now{Time: t, Config: config} +} + +// New initialize Now with time +func New(t time.Time) *Now { + return With(t) +} + +// BeginningOfMinute beginning of minute +func BeginningOfMinute() time.Time { + return With(time.Now()).BeginningOfMinute() +} + +// BeginningOfHour beginning of hour +func BeginningOfHour() time.Time { + return With(time.Now()).BeginningOfHour() +} + +// BeginningOfDay beginning of day +func BeginningOfDay() time.Time { + return With(time.Now()).BeginningOfDay() +} + +// BeginningOfWeek beginning of week +func BeginningOfWeek() time.Time { + return With(time.Now()).BeginningOfWeek() +} + +// BeginningOfMonth beginning of month +func BeginningOfMonth() time.Time { + return With(time.Now()).BeginningOfMonth() +} + +// BeginningOfQuarter beginning of quarter +func BeginningOfQuarter() time.Time { + return With(time.Now()).BeginningOfQuarter() +} + +// BeginningOfYear beginning of year +func BeginningOfYear() time.Time { + return With(time.Now()).BeginningOfYear() +} + +// EndOfMinute end of minute +func EndOfMinute() time.Time { + return With(time.Now()).EndOfMinute() +} + +// EndOfHour end of hour +func EndOfHour() time.Time { + return With(time.Now()).EndOfHour() +} + +// EndOfDay end of day +func EndOfDay() time.Time { + return With(time.Now()).EndOfDay() +} + +// EndOfWeek end of week +func EndOfWeek() time.Time { + return With(time.Now()).EndOfWeek() +} + +// EndOfMonth end of month +func EndOfMonth() time.Time { + return With(time.Now()).EndOfMonth() +} + +// EndOfQuarter end of quarter +func EndOfQuarter() time.Time { + return With(time.Now()).EndOfQuarter() +} + +// EndOfYear end of year +func EndOfYear() time.Time { + return With(time.Now()).EndOfYear() +} + +// Monday monday + +func Monday(strs ...string) time.Time { + return With(time.Now()).Monday(strs...) +} + +// Sunday sunday +func Sunday(strs ...string) time.Time { + return With(time.Now()).Sunday(strs...) +} + +// EndOfSunday end of sunday +func EndOfSunday() time.Time { + return With(time.Now()).EndOfSunday() +} + +// Quarter returns the yearly quarter +func Quarter() uint { + return With(time.Now()).Quarter() +} + +// Parse parse string to time +func Parse(strs ...string) (time.Time, error) { + return With(time.Now()).Parse(strs...) +} + +// ParseInLocation parse string to time in location +func ParseInLocation(loc *time.Location, strs ...string) (time.Time, error) { + return With(time.Now().In(loc)).Parse(strs...) +} + +// MustParse must parse string to time or will panic +func MustParse(strs ...string) time.Time { + return With(time.Now()).MustParse(strs...) +} + +// MustParseInLocation must parse string to time in location or will panic +func MustParseInLocation(loc *time.Location, strs ...string) time.Time { + return With(time.Now().In(loc)).MustParse(strs...) 
+} + +// Between check now between the begin, end time or not +func Between(time1, time2 string) bool { + return With(time.Now()).Between(time1, time2) +} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/now.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/now.go new file mode 100644 index 000000000..2f524cc8d --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/now.go @@ -0,0 +1,245 @@ +package now + +import ( + "errors" + "regexp" + "time" +) + +// BeginningOfMinute beginning of minute +func (now *Now) BeginningOfMinute() time.Time { + return now.Truncate(time.Minute) +} + +// BeginningOfHour beginning of hour +func (now *Now) BeginningOfHour() time.Time { + y, m, d := now.Date() + return time.Date(y, m, d, now.Time.Hour(), 0, 0, 0, now.Time.Location()) +} + +// BeginningOfDay beginning of day +func (now *Now) BeginningOfDay() time.Time { + y, m, d := now.Date() + return time.Date(y, m, d, 0, 0, 0, 0, now.Time.Location()) +} + +// BeginningOfWeek beginning of week +func (now *Now) BeginningOfWeek() time.Time { + t := now.BeginningOfDay() + weekday := int(t.Weekday()) + + if now.WeekStartDay != time.Sunday { + weekStartDayInt := int(now.WeekStartDay) + + if weekday < weekStartDayInt { + weekday = weekday + 7 - weekStartDayInt + } else { + weekday = weekday - weekStartDayInt + } + } + return t.AddDate(0, 0, -weekday) +} + +// BeginningOfMonth beginning of month +func (now *Now) BeginningOfMonth() time.Time { + y, m, _ := now.Date() + return time.Date(y, m, 1, 0, 0, 0, 0, now.Location()) +} + +// BeginningOfQuarter beginning of quarter +func (now *Now) BeginningOfQuarter() time.Time { + month := now.BeginningOfMonth() + offset := (int(month.Month()) - 1) % 3 + return month.AddDate(0, -offset, 0) +} + +// BeginningOfHalf beginning of half year +func (now *Now) BeginningOfHalf() time.Time { + month := now.BeginningOfMonth() + offset := (int(month.Month()) - 1) % 6 + return month.AddDate(0, -offset, 0) +} + +// BeginningOfYear BeginningOfYear beginning of year +func (now *Now) BeginningOfYear() time.Time { + y, _, _ := now.Date() + return time.Date(y, time.January, 1, 0, 0, 0, 0, now.Location()) +} + +// EndOfMinute end of minute +func (now *Now) EndOfMinute() time.Time { + return now.BeginningOfMinute().Add(time.Minute - time.Nanosecond) +} + +// EndOfHour end of hour +func (now *Now) EndOfHour() time.Time { + return now.BeginningOfHour().Add(time.Hour - time.Nanosecond) +} + +// EndOfDay end of day +func (now *Now) EndOfDay() time.Time { + y, m, d := now.Date() + return time.Date(y, m, d, 23, 59, 59, int(time.Second-time.Nanosecond), now.Location()) +} + +// EndOfWeek end of week +func (now *Now) EndOfWeek() time.Time { + return now.BeginningOfWeek().AddDate(0, 0, 7).Add(-time.Nanosecond) +} + +// EndOfMonth end of month +func (now *Now) EndOfMonth() time.Time { + return now.BeginningOfMonth().AddDate(0, 1, 0).Add(-time.Nanosecond) +} + +// EndOfQuarter end of quarter +func (now *Now) EndOfQuarter() time.Time { + return now.BeginningOfQuarter().AddDate(0, 3, 0).Add(-time.Nanosecond) +} + +// EndOfHalf end of half year +func (now *Now) EndOfHalf() time.Time { + return now.BeginningOfHalf().AddDate(0, 6, 0).Add(-time.Nanosecond) +} + +// EndOfYear end of year +func (now *Now) EndOfYear() time.Time { + return now.BeginningOfYear().AddDate(1, 0, 0).Add(-time.Nanosecond) +} + +// Monday monday +/* +func (now *Now) Monday() time.Time { + t := now.BeginningOfDay() + weekday := int(t.Weekday()) + if weekday == 0 { + weekday = 7 + } + return 
t.AddDate(0, 0, -weekday+1) +} +*/ + +func (now *Now) Monday(strs ...string) time.Time { + var parseTime time.Time + var err error + if len(strs) > 0 { + parseTime, err = now.Parse(strs...) + if err != nil { + panic(err) + } + } else { + parseTime = now.BeginningOfDay() + } + weekday := int(parseTime.Weekday()) + if weekday == 0 { + weekday = 7 + } + return parseTime.AddDate(0, 0, -weekday+1) +} + +func (now *Now) Sunday(strs ...string) time.Time { + var parseTime time.Time + var err error + if len(strs) > 0 { + parseTime, err = now.Parse(strs...) + if err != nil { + panic(err) + } + } else { + parseTime = now.BeginningOfDay() + } + weekday := int(parseTime.Weekday()) + if weekday == 0 { + weekday = 7 + } + return parseTime.AddDate(0, 0, (7 - weekday)) +} + +// EndOfSunday end of sunday +func (now *Now) EndOfSunday() time.Time { + return New(now.Sunday()).EndOfDay() +} + +// Quarter returns the yearly quarter +func (now *Now) Quarter() uint { + return (uint(now.Month())-1)/3 + 1 +} + +func (now *Now) parseWithFormat(str string, location *time.Location) (t time.Time, err error) { + for _, format := range now.TimeFormats { + t, err = time.ParseInLocation(format, str, location) + + if err == nil { + return + } + } + err = errors.New("Can't parse string as time: " + str) + return +} + +var hasTimeRegexp = regexp.MustCompile(`(\s+|^\s*|T)\d{1,2}((:\d{1,2})*|((:\d{1,2}){2}\.(\d{3}|\d{6}|\d{9})))(\s*$|[Z+-])`) // match 15:04:05, 15:04:05.000, 15:04:05.000000 15, 2017-01-01 15:04, 2021-07-20T00:59:10Z, 2021-07-20T00:59:10+08:00, 2021-07-20T00:00:10-07:00 etc +var onlyTimeRegexp = regexp.MustCompile(`^\s*\d{1,2}((:\d{1,2})*|((:\d{1,2}){2}\.(\d{3}|\d{6}|\d{9})))\s*$`) // match 15:04:05, 15, 15:04:05.000, 15:04:05.000000, etc + +// Parse parse string to time +func (now *Now) Parse(strs ...string) (t time.Time, err error) { + var ( + setCurrentTime bool + parseTime []int + currentLocation = now.Location() + onlyTimeInStr = true + currentTime = formatTimeToList(now.Time) + ) + + for _, str := range strs { + hasTimeInStr := hasTimeRegexp.MatchString(str) // match 15:04:05, 15 + onlyTimeInStr = hasTimeInStr && onlyTimeInStr && onlyTimeRegexp.MatchString(str) + if t, err = now.parseWithFormat(str, currentLocation); err == nil { + location := t.Location() + parseTime = formatTimeToList(t) + + for i, v := range parseTime { + // Don't reset hour, minute, second if current time str including time + if hasTimeInStr && i <= 3 { + continue + } + + // If value is zero, replace it with current time + if v == 0 { + if setCurrentTime { + parseTime[i] = currentTime[i] + } + } else { + setCurrentTime = true + } + + // if current time only includes time, should change day, month to current time + if onlyTimeInStr { + if i == 4 || i == 5 { + parseTime[i] = currentTime[i] + continue + } + } + } + + t = time.Date(parseTime[6], time.Month(parseTime[5]), parseTime[4], parseTime[3], parseTime[2], parseTime[1], parseTime[0], location) + currentTime = formatTimeToList(t) + } + } + return +} + +// MustParse must parse string to time or it will panic +func (now *Now) MustParse(strs ...string) (t time.Time) { + t, err := now.Parse(strs...) 
+ if err != nil { + panic(err) + } + return t +} + +// Between check time between the begin, end time or not +func (now *Now) Between(begin, end string) bool { + beginTime := now.MustParse(begin) + endTime := now.MustParse(end) + return now.After(beginTime) && now.Before(endTime) +} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/time.go b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/time.go new file mode 100644 index 000000000..52dd8b2a0 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/github.com/jinzhu/now/time.go @@ -0,0 +1,9 @@ +package now + +import "time" + +func formatTimeToList(t time.Time) []int { + hour, min, sec := t.Clock() + year, month, day := t.Date() + return []int{t.Nanosecond(), sec, min, hour, day, int(month), year} +} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/.gitignore b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/.gitignore deleted file mode 100644 index 3243952a4..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/.gitignore +++ /dev/null @@ -1,6 +0,0 @@ -.db -*.test -*~ -*.swp -.idea -.vscode \ No newline at end of file diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/LICENSE.md b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/LICENSE.md deleted file mode 100644 index 5773904a3..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/LICENSE.md +++ /dev/null @@ -1,8 +0,0 @@ -Copyright (c) 2011-2013, 'pq' Contributors -Portions Copyright (C) 2011 Blake Mizerany - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/README.md b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/README.md deleted file mode 100644 index 126ee5d35..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/README.md +++ /dev/null @@ -1,36 +0,0 @@ -# pq - A pure Go postgres driver for Go's database/sql package - -[![GoDoc](https://godoc.org/github.com/lib/pq?status.svg)](https://pkg.go.dev/github.com/lib/pq?tab=doc) - -## Install - - go get github.com/lib/pq - -## Features - -* SSL -* Handles bad connections for `database/sql` -* Scan `time.Time` correctly (i.e. `timestamp[tz]`, `time[tz]`, `date`) -* Scan binary blobs correctly (i.e. `bytea`) -* Package for `hstore` support -* COPY FROM support -* pq.ParseURL for converting urls to connection strings for sql.Open. -* Many libpq compatible environment variables -* Unix socket support -* Notifications: `LISTEN`/`NOTIFY` -* pgpass support -* GSS (Kerberos) auth - -## Tests - -`go test` is used for testing. 
See [TESTS.md](TESTS.md) for more details. - -## Status - -This package is currently in maintenance mode, which means: -1. It generally does not accept new features. -2. It does accept bug fixes and version compatability changes provided by the community. -3. Maintainers usually do not resolve reported issues. -4. Community members are encouraged to help each other with reported issues. - -For users that require new features or reliable resolution of reported bugs, we recommend using [pgx](https://github.com/jackc/pgx) which is under active development. diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/TESTS.md b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/TESTS.md deleted file mode 100644 index f05021115..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/TESTS.md +++ /dev/null @@ -1,33 +0,0 @@ -# Tests - -## Running Tests - -`go test` is used for testing. A running PostgreSQL -server is required, with the ability to log in. The -database to connect to test with is "pqgotest," on -"localhost" but these can be overridden using [environment -variables](https://www.postgresql.org/docs/9.3/static/libpq-envars.html). - -Example: - - PGHOST=/run/postgresql go test - -## Benchmarks - -A benchmark suite can be run as part of the tests: - - go test -bench . - -## Example setup (Docker) - -Run a postgres container: - -``` -docker run --expose 5432:5432 postgres -``` - -Run tests: - -``` -PGHOST=localhost PGPORT=5432 PGUSER=postgres PGSSLMODE=disable PGDATABASE=postgres go test -``` diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/array.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/array.go deleted file mode 100644 index 39c8f7e2e..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/array.go +++ /dev/null @@ -1,895 +0,0 @@ -package pq - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "encoding/hex" - "fmt" - "reflect" - "strconv" - "strings" -) - -var typeByteSlice = reflect.TypeOf([]byte{}) -var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem() -var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() - -// Array returns the optimal driver.Valuer and sql.Scanner for an array or -// slice of any dimension. -// -// For example: -// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) -// -// var x []sql.NullInt64 -// db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x)) -// -// Scanning multi-dimensional arrays is not supported. Arrays where the lower -// bound is not one (such as `[0:0]={1}') are not supported. -func Array(a interface{}) interface { - driver.Valuer - sql.Scanner -} { - switch a := a.(type) { - case []bool: - return (*BoolArray)(&a) - case []float64: - return (*Float64Array)(&a) - case []float32: - return (*Float32Array)(&a) - case []int64: - return (*Int64Array)(&a) - case []int32: - return (*Int32Array)(&a) - case []string: - return (*StringArray)(&a) - case [][]byte: - return (*ByteaArray)(&a) - - case *[]bool: - return (*BoolArray)(a) - case *[]float64: - return (*Float64Array)(a) - case *[]float32: - return (*Float32Array)(a) - case *[]int64: - return (*Int64Array)(a) - case *[]int32: - return (*Int32Array)(a) - case *[]string: - return (*StringArray)(a) - case *[][]byte: - return (*ByteaArray)(a) - } - - return GenericArray{a} -} - -// ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner -// to override the array delimiter used by GenericArray. 
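// (editor's note: GenericArray defaults to "," when the element type does not
// implement this; PostgreSQL uses a non-comma array delimiter only for a few
// types, e.g. ";" for box.)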
-type ArrayDelimiter interface { - // ArrayDelimiter returns the delimiter character(s) for this element's type. - ArrayDelimiter() string -} - -// BoolArray represents a one-dimensional array of the PostgreSQL boolean type. -type BoolArray []bool - -// Scan implements the sql.Scanner interface. -func (a *BoolArray) Scan(src interface{}) error { - switch src := src.(type) { - case []byte: - return a.scanBytes(src) - case string: - return a.scanBytes([]byte(src)) - case nil: - *a = nil - return nil - } - - return fmt.Errorf("pq: cannot convert %T to BoolArray", src) -} - -func (a *BoolArray) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "BoolArray") - if err != nil { - return err - } - if *a != nil && len(elems) == 0 { - *a = (*a)[:0] - } else { - b := make(BoolArray, len(elems)) - for i, v := range elems { - if len(v) != 1 { - return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) - } - switch v[0] { - case 't': - b[i] = true - case 'f': - b[i] = false - default: - return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) - } - } - *a = b - } - return nil -} - -// Value implements the driver.Valuer interface. -func (a BoolArray) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be exactly two curly brackets, N bytes of values, - // and N-1 bytes of delimiters. - b := make([]byte, 1+2*n) - - for i := 0; i < n; i++ { - b[2*i] = ',' - if a[i] { - b[1+2*i] = 't' - } else { - b[1+2*i] = 'f' - } - } - - b[0] = '{' - b[2*n] = '}' - - return string(b), nil - } - - return "{}", nil -} - -// ByteaArray represents a one-dimensional array of the PostgreSQL bytea type. -type ByteaArray [][]byte - -// Scan implements the sql.Scanner interface. -func (a *ByteaArray) Scan(src interface{}) error { - switch src := src.(type) { - case []byte: - return a.scanBytes(src) - case string: - return a.scanBytes([]byte(src)) - case nil: - *a = nil - return nil - } - - return fmt.Errorf("pq: cannot convert %T to ByteaArray", src) -} - -func (a *ByteaArray) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "ByteaArray") - if err != nil { - return err - } - if *a != nil && len(elems) == 0 { - *a = (*a)[:0] - } else { - b := make(ByteaArray, len(elems)) - for i, v := range elems { - b[i], err = parseBytea(v) - if err != nil { - return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error()) - } - } - *a = b - } - return nil -} - -// Value implements the driver.Valuer interface. It uses the "hex" format which -// is only supported on PostgreSQL 9.0 or newer. -func (a ByteaArray) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be at least two curly brackets, 2*N bytes of quotes, - // 3*N bytes of hex formatting, and N-1 bytes of delimiters. - size := 1 + 6*n - for _, x := range a { - size += hex.EncodedLen(len(x)) - } - - b := make([]byte, size) - - for i, s := 0, b; i < n; i++ { - o := copy(s, `,"\\x`) - o += hex.Encode(s[o:], a[i]) - s[o] = '"' - s = s[o+1:] - } - - b[0] = '{' - b[size-1] = '}' - - return string(b), nil - } - - return "{}", nil -} - -// Float64Array represents a one-dimensional array of the PostgreSQL double -// precision type. -type Float64Array []float64 - -// Scan implements the sql.Scanner interface. 
-func (a *Float64Array) Scan(src interface{}) error { - switch src := src.(type) { - case []byte: - return a.scanBytes(src) - case string: - return a.scanBytes([]byte(src)) - case nil: - *a = nil - return nil - } - - return fmt.Errorf("pq: cannot convert %T to Float64Array", src) -} - -func (a *Float64Array) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "Float64Array") - if err != nil { - return err - } - if *a != nil && len(elems) == 0 { - *a = (*a)[:0] - } else { - b := make(Float64Array, len(elems)) - for i, v := range elems { - if b[i], err = strconv.ParseFloat(string(v), 64); err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) - } - } - *a = b - } - return nil -} - -// Value implements the driver.Valuer interface. -func (a Float64Array) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be at least two curly brackets, N bytes of values, - // and N-1 bytes of delimiters. - b := make([]byte, 1, 1+2*n) - b[0] = '{' - - b = strconv.AppendFloat(b, a[0], 'f', -1, 64) - for i := 1; i < n; i++ { - b = append(b, ',') - b = strconv.AppendFloat(b, a[i], 'f', -1, 64) - } - - return string(append(b, '}')), nil - } - - return "{}", nil -} - -// Float32Array represents a one-dimensional array of the PostgreSQL double -// precision type. -type Float32Array []float32 - -// Scan implements the sql.Scanner interface. -func (a *Float32Array) Scan(src interface{}) error { - switch src := src.(type) { - case []byte: - return a.scanBytes(src) - case string: - return a.scanBytes([]byte(src)) - case nil: - *a = nil - return nil - } - - return fmt.Errorf("pq: cannot convert %T to Float32Array", src) -} - -func (a *Float32Array) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "Float32Array") - if err != nil { - return err - } - if *a != nil && len(elems) == 0 { - *a = (*a)[:0] - } else { - b := make(Float32Array, len(elems)) - for i, v := range elems { - var x float64 - if x, err = strconv.ParseFloat(string(v), 32); err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) - } - b[i] = float32(x) - } - *a = b - } - return nil -} - -// Value implements the driver.Valuer interface. -func (a Float32Array) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be at least two curly brackets, N bytes of values, - // and N-1 bytes of delimiters. - b := make([]byte, 1, 1+2*n) - b[0] = '{' - - b = strconv.AppendFloat(b, float64(a[0]), 'f', -1, 32) - for i := 1; i < n; i++ { - b = append(b, ',') - b = strconv.AppendFloat(b, float64(a[i]), 'f', -1, 32) - } - - return string(append(b, '}')), nil - } - - return "{}", nil -} - -// GenericArray implements the driver.Valuer and sql.Scanner interfaces for -// an array or slice of any dimension. -type GenericArray struct{ A interface{} } - -func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) { - var assign func([]byte, reflect.Value) error - var del = "," - - // TODO calculate the assign function for other types - // TODO repeat this section on the element type of arrays or slices (multidimensional) - { - if reflect.PtrTo(rt).Implements(typeSQLScanner) { - // dest is always addressable because it is an element of a slice. 
- assign = func(src []byte, dest reflect.Value) (err error) { - ss := dest.Addr().Interface().(sql.Scanner) - if src == nil { - err = ss.Scan(nil) - } else { - err = ss.Scan(src) - } - return - } - goto FoundType - } - - assign = func([]byte, reflect.Value) error { - return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt) - } - } - -FoundType: - - if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok { - del = ad.ArrayDelimiter() - } - - return rt, assign, del -} - -// Scan implements the sql.Scanner interface. -func (a GenericArray) Scan(src interface{}) error { - dpv := reflect.ValueOf(a.A) - switch { - case dpv.Kind() != reflect.Ptr: - return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) - case dpv.IsNil(): - return fmt.Errorf("pq: destination %T is nil", a.A) - } - - dv := dpv.Elem() - switch dv.Kind() { - case reflect.Slice: - case reflect.Array: - default: - return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) - } - - switch src := src.(type) { - case []byte: - return a.scanBytes(src, dv) - case string: - return a.scanBytes([]byte(src), dv) - case nil: - if dv.Kind() == reflect.Slice { - dv.Set(reflect.Zero(dv.Type())) - return nil - } - } - - return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type()) -} - -func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error { - dtype, assign, del := a.evaluateDestination(dv.Type().Elem()) - dims, elems, err := parseArray(src, []byte(del)) - if err != nil { - return err - } - - // TODO allow multidimensional - - if len(dims) > 1 { - return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented", - strings.Replace(fmt.Sprint(dims), " ", "][", -1)) - } - - // Treat a zero-dimensional array like an array with a single dimension of zero. - if len(dims) == 0 { - dims = append(dims, 0) - } - - for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() { - switch rt.Kind() { - case reflect.Slice: - case reflect.Array: - if rt.Len() != dims[i] { - return fmt.Errorf("pq: cannot convert ARRAY%s to %s", - strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type()) - } - default: - // TODO handle multidimensional - } - } - - values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems)) - for i, e := range elems { - if err := assign(e, values.Index(i)); err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) - } - } - - // TODO handle multidimensional - - switch dv.Kind() { - case reflect.Slice: - dv.Set(values.Slice(0, dims[0])) - case reflect.Array: - for i := 0; i < dims[0]; i++ { - dv.Index(i).Set(values.Index(i)) - } - } - - return nil -} - -// Value implements the driver.Valuer interface. -func (a GenericArray) Value() (driver.Value, error) { - if a.A == nil { - return nil, nil - } - - rv := reflect.ValueOf(a.A) - - switch rv.Kind() { - case reflect.Slice: - if rv.IsNil() { - return nil, nil - } - case reflect.Array: - default: - return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A) - } - - if n := rv.Len(); n > 0 { - // There will be at least two curly brackets, N bytes of values, - // and N-1 bytes of delimiters. - b := make([]byte, 0, 1+2*n) - - b, _, err := appendArray(b, rv, n) - return string(b), err - } - - return "{}", nil -} - -// Int64Array represents a one-dimensional array of the PostgreSQL integer types. -type Int64Array []int64 - -// Scan implements the sql.Scanner interface. 
-func (a *Int64Array) Scan(src interface{}) error { - switch src := src.(type) { - case []byte: - return a.scanBytes(src) - case string: - return a.scanBytes([]byte(src)) - case nil: - *a = nil - return nil - } - - return fmt.Errorf("pq: cannot convert %T to Int64Array", src) -} - -func (a *Int64Array) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "Int64Array") - if err != nil { - return err - } - if *a != nil && len(elems) == 0 { - *a = (*a)[:0] - } else { - b := make(Int64Array, len(elems)) - for i, v := range elems { - if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) - } - } - *a = b - } - return nil -} - -// Value implements the driver.Valuer interface. -func (a Int64Array) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be at least two curly brackets, N bytes of values, - // and N-1 bytes of delimiters. - b := make([]byte, 1, 1+2*n) - b[0] = '{' - - b = strconv.AppendInt(b, a[0], 10) - for i := 1; i < n; i++ { - b = append(b, ',') - b = strconv.AppendInt(b, a[i], 10) - } - - return string(append(b, '}')), nil - } - - return "{}", nil -} - -// Int32Array represents a one-dimensional array of the PostgreSQL integer types. -type Int32Array []int32 - -// Scan implements the sql.Scanner interface. -func (a *Int32Array) Scan(src interface{}) error { - switch src := src.(type) { - case []byte: - return a.scanBytes(src) - case string: - return a.scanBytes([]byte(src)) - case nil: - *a = nil - return nil - } - - return fmt.Errorf("pq: cannot convert %T to Int32Array", src) -} - -func (a *Int32Array) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "Int32Array") - if err != nil { - return err - } - if *a != nil && len(elems) == 0 { - *a = (*a)[:0] - } else { - b := make(Int32Array, len(elems)) - for i, v := range elems { - x, err := strconv.ParseInt(string(v), 10, 32) - if err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) - } - b[i] = int32(x) - } - *a = b - } - return nil -} - -// Value implements the driver.Valuer interface. -func (a Int32Array) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be at least two curly brackets, N bytes of values, - // and N-1 bytes of delimiters. - b := make([]byte, 1, 1+2*n) - b[0] = '{' - - b = strconv.AppendInt(b, int64(a[0]), 10) - for i := 1; i < n; i++ { - b = append(b, ',') - b = strconv.AppendInt(b, int64(a[i]), 10) - } - - return string(append(b, '}')), nil - } - - return "{}", nil -} - -// StringArray represents a one-dimensional array of the PostgreSQL character types. -type StringArray []string - -// Scan implements the sql.Scanner interface. 
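
// A round-trip sketch for the typed wrappers removed above (Int64Array shown),
// using a hypothetical table events(ids bigint[]) and DSN. Value() renders the
// {1,2,3} text form on the way in; Scan() parses it on the way out.
package main

import (
	"database/sql"
	"fmt"
	"log"

	"github.com/lib/pq"
)

func main() {
	db, err := sql.Open("postgres", "host=localhost dbname=example sslmode=disable") // hypothetical DSN
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	if _, err := db.Exec(`INSERT INTO events(ids) VALUES ($1)`, pq.Int64Array{1, 2, 3}); err != nil {
		log.Fatal(err) // parameter is sent as the array literal {1,2,3}
	}

	var ids pq.Int64Array
	if err := db.QueryRow(`SELECT ids FROM events LIMIT 1`).Scan(&ids); err != nil {
		log.Fatal(err)
	}
	fmt.Println(ids) // [1 2 3]
}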
-func (a *StringArray) Scan(src interface{}) error { - switch src := src.(type) { - case []byte: - return a.scanBytes(src) - case string: - return a.scanBytes([]byte(src)) - case nil: - *a = nil - return nil - } - - return fmt.Errorf("pq: cannot convert %T to StringArray", src) -} - -func (a *StringArray) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "StringArray") - if err != nil { - return err - } - if *a != nil && len(elems) == 0 { - *a = (*a)[:0] - } else { - b := make(StringArray, len(elems)) - for i, v := range elems { - if b[i] = string(v); v == nil { - return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i) - } - } - *a = b - } - return nil -} - -// Value implements the driver.Valuer interface. -func (a StringArray) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be at least two curly brackets, 2*N bytes of quotes, - // and N-1 bytes of delimiters. - b := make([]byte, 1, 1+3*n) - b[0] = '{' - - b = appendArrayQuotedBytes(b, []byte(a[0])) - for i := 1; i < n; i++ { - b = append(b, ',') - b = appendArrayQuotedBytes(b, []byte(a[i])) - } - - return string(append(b, '}')), nil - } - - return "{}", nil -} - -// appendArray appends rv to the buffer, returning the extended buffer and -// the delimiter used between elements. -// -// It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice. -func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) { - var del string - var err error - - b = append(b, '{') - - if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil { - return b, del, err - } - - for i := 1; i < n; i++ { - b = append(b, del...) - if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil { - return b, del, err - } - } - - return append(b, '}'), del, nil -} - -// appendArrayElement appends rv to the buffer, returning the extended buffer -// and the delimiter to use before the next element. -// -// When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted -// using driver.DefaultParameterConverter and the resulting []byte or string -// is double-quoted. -// -// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO -func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { - if k := rv.Kind(); k == reflect.Array || k == reflect.Slice { - if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) { - if n := rv.Len(); n > 0 { - return appendArray(b, rv, n) - } - - return b, "", nil - } - } - - var del = "," - var err error - var iv interface{} = rv.Interface() - - if ad, ok := iv.(ArrayDelimiter); ok { - del = ad.ArrayDelimiter() - } - - if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil { - return b, del, err - } - - switch v := iv.(type) { - case nil: - return append(b, "NULL"...), del, nil - case []byte: - return appendArrayQuotedBytes(b, v), del, nil - case string: - return appendArrayQuotedBytes(b, []byte(v)), del, nil - } - - b, err = appendValue(b, iv) - return b, del, err -} - -func appendArrayQuotedBytes(b, v []byte) []byte { - b = append(b, '"') - for { - i := bytes.IndexAny(v, `"\`) - if i < 0 { - b = append(b, v...) - break - } - if i > 0 { - b = append(b, v[:i]...) 
- } - b = append(b, '\\', v[i]) - v = v[i+1:] - } - return append(b, '"') -} - -func appendValue(b []byte, v driver.Value) ([]byte, error) { - return append(b, encode(nil, v, 0)...), nil -} - -// parseArray extracts the dimensions and elements of an array represented in -// text format. Only representations emitted by the backend are supported. -// Notably, whitespace around brackets and delimiters is significant, and NULL -// is case-sensitive. -// -// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO -func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) { - var depth, i int - - if len(src) < 1 || src[0] != '{' { - return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0) - } - -Open: - for i < len(src) { - switch src[i] { - case '{': - depth++ - i++ - case '}': - elems = make([][]byte, 0) - goto Close - default: - break Open - } - } - dims = make([]int, i) - -Element: - for i < len(src) { - switch src[i] { - case '{': - if depth == len(dims) { - break Element - } - depth++ - dims[depth-1] = 0 - i++ - case '"': - var elem = []byte{} - var escape bool - for i++; i < len(src); i++ { - if escape { - elem = append(elem, src[i]) - escape = false - } else { - switch src[i] { - default: - elem = append(elem, src[i]) - case '\\': - escape = true - case '"': - elems = append(elems, elem) - i++ - break Element - } - } - } - default: - for start := i; i < len(src); i++ { - if bytes.HasPrefix(src[i:], del) || src[i] == '}' { - elem := src[start:i] - if len(elem) == 0 { - return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) - } - if bytes.Equal(elem, []byte("NULL")) { - elem = nil - } - elems = append(elems, elem) - break Element - } - } - } - } - - for i < len(src) { - if bytes.HasPrefix(src[i:], del) && depth > 0 { - dims[depth-1]++ - i += len(del) - goto Element - } else if src[i] == '}' && depth > 0 { - dims[depth-1]++ - depth-- - i++ - } else { - return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) - } - } - -Close: - for i < len(src) { - if src[i] == '}' && depth > 0 { - depth-- - i++ - } else { - return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) - } - } - if depth > 0 { - err = fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '}', i) - } - if err == nil { - for _, d := range dims { - if (len(elems) % d) != 0 { - err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions") - } - } - } - return -} - -func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) { - dims, elems, err := parseArray(src, del) - if err != nil { - return nil, err - } - if len(dims) > 1 { - return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ) - } - return elems, err -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/buf.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/buf.go deleted file mode 100644 index 4b0a0a8f7..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/buf.go +++ /dev/null @@ -1,91 +0,0 @@ -package pq - -import ( - "bytes" - "encoding/binary" - - "github.com/lib/pq/oid" -) - -type readBuf []byte - -func (b *readBuf) int32() (n int) { - n = int(int32(binary.BigEndian.Uint32(*b))) - *b = (*b)[4:] - return -} - -func (b *readBuf) oid() (n oid.Oid) { - n = oid.Oid(binary.BigEndian.Uint32(*b)) - *b = (*b)[4:] - return -} - -// N.B: this is 
actually an unsigned 16-bit integer, unlike int32 -func (b *readBuf) int16() (n int) { - n = int(binary.BigEndian.Uint16(*b)) - *b = (*b)[2:] - return -} - -func (b *readBuf) string() string { - i := bytes.IndexByte(*b, 0) - if i < 0 { - errorf("invalid message format; expected string terminator") - } - s := (*b)[:i] - *b = (*b)[i+1:] - return string(s) -} - -func (b *readBuf) next(n int) (v []byte) { - v = (*b)[:n] - *b = (*b)[n:] - return -} - -func (b *readBuf) byte() byte { - return b.next(1)[0] -} - -type writeBuf struct { - buf []byte - pos int -} - -func (b *writeBuf) int32(n int) { - x := make([]byte, 4) - binary.BigEndian.PutUint32(x, uint32(n)) - b.buf = append(b.buf, x...) -} - -func (b *writeBuf) int16(n int) { - x := make([]byte, 2) - binary.BigEndian.PutUint16(x, uint16(n)) - b.buf = append(b.buf, x...) -} - -func (b *writeBuf) string(s string) { - b.buf = append(append(b.buf, s...), '\000') -} - -func (b *writeBuf) byte(c byte) { - b.buf = append(b.buf, c) -} - -func (b *writeBuf) bytes(v []byte) { - b.buf = append(b.buf, v...) -} - -func (b *writeBuf) wrap() []byte { - p := b.buf[b.pos:] - binary.BigEndian.PutUint32(p, uint32(len(p))) - return b.buf -} - -func (b *writeBuf) next(c byte) { - p := b.buf[b.pos:] - binary.BigEndian.PutUint32(p, uint32(len(p))) - b.pos = len(b.buf) + 1 - b.buf = append(b.buf, c, 0, 0, 0, 0) -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/conn.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/conn.go deleted file mode 100644 index da4ff9de6..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/conn.go +++ /dev/null @@ -1,2112 +0,0 @@ -package pq - -import ( - "bufio" - "bytes" - "context" - "crypto/md5" - "crypto/sha256" - "database/sql" - "database/sql/driver" - "encoding/binary" - "errors" - "fmt" - "io" - "net" - "os" - "os/user" - "path" - "path/filepath" - "strconv" - "strings" - "sync" - "time" - "unicode" - - "github.com/lib/pq/oid" - "github.com/lib/pq/scram" -) - -// Common error types -var ( - ErrNotSupported = errors.New("pq: Unsupported command") - ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction") - ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") - ErrSSLKeyUnknownOwnership = errors.New("pq: Could not get owner information for private key, may not be properly protected") - ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key has world access. Permissions should be u=rw,g=r (0640) if owned by root, or u=rw (0600), or less") - - ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly") - - errUnexpectedReady = errors.New("unexpected ReadyForQuery") - errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") - errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") -) - -// Compile time validation that our types implement the expected interfaces -var ( - _ driver.Driver = Driver{} -) - -// Driver is the Postgres database driver. -type Driver struct{} - -// Open opens a new connection to the database. name is a connection string. -// Most users should only use it through database/sql package from the standard -// library. 
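
// A standalone sketch of the framing that the readBuf/writeBuf helpers above
// implement: each regular frontend message is a one-byte type, a big-endian
// int32 length that counts itself (payload + 4), then the payload. Not wired
// to a real server; the query text is arbitrary.
package main

import (
	"encoding/binary"
	"fmt"
)

func frame(typ byte, payload []byte) []byte {
	msg := make([]byte, 5+len(payload))
	msg[0] = typ
	binary.BigEndian.PutUint32(msg[1:5], uint32(4+len(payload)))
	copy(msg[5:], payload)
	return msg
}

func main() {
	// A simple-query message ('Q') carries one NUL-terminated SQL string,
	// which is what writeBuf.string produces.
	fmt.Printf("% x\n", frame('Q', append([]byte("SELECT 1"), 0)))
}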
-func (d Driver) Open(name string) (driver.Conn, error) { - return Open(name) -} - -func init() { - sql.Register("postgres", &Driver{}) -} - -type parameterStatus struct { - // server version in the same format as server_version_num, or 0 if - // unavailable - serverVersion int - - // the current location based on the TimeZone value of the session, if - // available - currentLocation *time.Location -} - -type transactionStatus byte - -const ( - txnStatusIdle transactionStatus = 'I' - txnStatusIdleInTransaction transactionStatus = 'T' - txnStatusInFailedTransaction transactionStatus = 'E' -) - -func (s transactionStatus) String() string { - switch s { - case txnStatusIdle: - return "idle" - case txnStatusIdleInTransaction: - return "idle in transaction" - case txnStatusInFailedTransaction: - return "in a failed transaction" - default: - errorf("unknown transactionStatus %d", s) - } - - panic("not reached") -} - -// Dialer is the dialer interface. It can be used to obtain more control over -// how pq creates network connections. -type Dialer interface { - Dial(network, address string) (net.Conn, error) - DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) -} - -// DialerContext is the context-aware dialer interface. -type DialerContext interface { - DialContext(ctx context.Context, network, address string) (net.Conn, error) -} - -type defaultDialer struct { - d net.Dialer -} - -func (d defaultDialer) Dial(network, address string) (net.Conn, error) { - return d.d.Dial(network, address) -} -func (d defaultDialer) DialTimeout( - network, address string, timeout time.Duration, -) (net.Conn, error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return d.DialContext(ctx, network, address) -} -func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return d.d.DialContext(ctx, network, address) -} - -type conn struct { - c net.Conn - buf *bufio.Reader - namei int - scratch [512]byte - txnStatus transactionStatus - txnFinish func() - - // Save connection arguments to use during CancelRequest. - dialer Dialer - opts values - - // Cancellation key data for use with CancelRequest messages. - processID int - secretKey int - - parameterStatus parameterStatus - - saveMessageType byte - saveMessageBuffer []byte - - // If an error is set, this connection is bad and all public-facing - // functions should return the appropriate error by calling get() - // (ErrBadConn) or getForNext(). - err syncErr - - // If set, this connection should never use the binary format when - // receiving query results from prepared statements. Only provided for - // debugging. - disablePreparedBinaryResult bool - - // Whether to always send []byte parameters over as binary. Enables single - // round-trip mode for non-prepared Query calls. - binaryParameters bool - - // If true this connection is in the middle of a COPY - inCopy bool - - // If not nil, notices will be synchronously sent here - noticeHandler func(*Error) - - // If not nil, notifications will be synchronously sent here - notificationHandler func(*Notification) - - // GSSAPI context - gss GSS -} - -type syncErr struct { - err error - sync.Mutex -} - -// Return ErrBadConn if connection is bad. -func (e *syncErr) get() error { - e.Lock() - defer e.Unlock() - if e.err != nil { - return driver.ErrBadConn - } - return nil -} - -// Return the error set on the connection. Currently only used by rows.Next. 
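
// The syncErr type above gives the connection "sticky" error semantics: the
// first error wins, and get() reports driver.ErrBadConn so database/sql
// retires the connection. A minimal standalone sketch of the same pattern:
package main

import (
	"errors"
	"fmt"
	"sync"
)

type stickyErr struct {
	mu  sync.Mutex
	err error
}

// set records err only if no error has been recorded yet.
func (s *stickyErr) set(err error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.err == nil {
		s.err = err
	}
}

// get returns the first recorded error, if any.
func (s *stickyErr) get() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.err
}

func main() {
	var s stickyErr
	s.set(errors.New("broken pipe"))
	s.set(errors.New("later error")) // ignored: the first error is sticky
	fmt.Println(s.get())             // broken pipe
}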
-func (e *syncErr) getForNext() error {
-	e.Lock()
-	defer e.Unlock()
-	return e.err
-}
-
-// Set error, only if it isn't set yet.
-func (e *syncErr) set(err error) {
-	if err == nil {
-		panic("attempt to set nil err")
-	}
-	e.Lock()
-	defer e.Unlock()
-	if e.err == nil {
-		e.err = err
-	}
-}
-
-// Handle driver-side settings in parsed connection string.
-func (cn *conn) handleDriverSettings(o values) (err error) {
-	boolSetting := func(key string, val *bool) error {
-		if value, ok := o[key]; ok {
-			if value == "yes" {
-				*val = true
-			} else if value == "no" {
-				*val = false
-			} else {
-				return fmt.Errorf("unrecognized value %q for %s", value, key)
-			}
-		}
-		return nil
-	}
-
-	err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
-	if err != nil {
-		return err
-	}
-	return boolSetting("binary_parameters", &cn.binaryParameters)
-}
-
-func (cn *conn) handlePgpass(o values) {
-	// if a password was supplied, do not process .pgpass
-	if _, ok := o["password"]; ok {
-		return
-	}
-	filename := os.Getenv("PGPASSFILE")
-	if filename == "" {
-		// XXX this code doesn't work on Windows where the default filename is
-		// XXX %APPDATA%\postgresql\pgpass.conf
-		// Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470
-		userHome := os.Getenv("HOME")
-		if userHome == "" {
-			user, err := user.Current()
-			if err != nil {
-				return
-			}
-			userHome = user.HomeDir
-		}
-		filename = filepath.Join(userHome, ".pgpass")
-	}
-	fileinfo, err := os.Stat(filename)
-	if err != nil {
-		return
-	}
-	mode := fileinfo.Mode()
-	if mode&(0x77) != 0 {
-		// XXX should warn about incorrect .pgpass permissions as psql does
-		return
-	}
-	file, err := os.Open(filename)
-	if err != nil {
-		return
-	}
-	defer file.Close()
-	scanner := bufio.NewScanner(io.Reader(file))
-	// From: https://github.com/tg/pgpass/blob/master/reader.go
-	for scanner.Scan() {
-		if scanText(scanner.Text(), o) {
-			break
-		}
-	}
-}
-
-// getFields is a helper function for scanText.
-func getFields(s string) []string {
-	fs := make([]string, 0, 5)
-	f := make([]rune, 0, len(s))
-
-	var esc bool
-	for _, c := range s {
-		switch {
-		case esc:
-			f = append(f, c)
-			esc = false
-		case c == '\\':
-			esc = true
-		case c == ':':
-			fs = append(fs, string(f))
-			f = f[:0]
-		default:
-			f = append(f, c)
-		}
-	}
-	return append(fs, string(f))
-}
-
-// scanText assists handlePgpass in its objective.
-func scanText(line string, o values) bool {
-	hostname := o["host"]
-	ntw, _ := network(o)
-	port := o["port"]
-	db := o["dbname"]
-	username := o["user"]
-	if len(line) == 0 || line[0] == '#' {
-		return false
-	}
-	split := getFields(line)
-	if len(split) != 5 {
-		return false
-	}
-	if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
-		o["password"] = split[4]
-		return true
-	}
-	return false
-}
-
-func (cn *conn) writeBuf(b byte) *writeBuf {
-	cn.scratch[0] = b
-	return &writeBuf{
-		buf: cn.scratch[:5],
-		pos: 1,
-	}
-}
-
-// Open opens a new connection to the database. dsn is a connection string.
-// Most users should only use it through database/sql package from the standard
-// library.
-func Open(dsn string) (_ driver.Conn, err error) {
-	return DialOpen(defaultDialer{}, dsn)
-}
-
-// DialOpen opens a new connection to the database using a dialer.
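
// A sketch of the .pgpass field splitting done by getFields above: five
// colon-separated fields (hostname:port:database:username:password), where
// '\' escapes the next character so ':' and '\' can appear in values. The
// sample line is hypothetical.
package main

import "fmt"

func splitPgpass(line string) []string {
	fs := make([]string, 0, 5)
	var f []rune
	var esc bool
	for _, c := range line {
		switch {
		case esc:
			f = append(f, c)
			esc = false
		case c == '\\':
			esc = true
		case c == ':':
			fs = append(fs, string(f))
			f = f[:0]
		default:
			f = append(f, c)
		}
	}
	return append(fs, string(f))
}

func main() {
	fmt.Println(splitPgpass(`db.example.com:5432:*:alice:pa\:ss`))
	// [db.example.com 5432 * alice pa:ss]
}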
-func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { - c, err := NewConnector(dsn) - if err != nil { - return nil, err - } - c.Dialer(d) - return c.open(context.Background()) -} - -func (c *Connector) open(ctx context.Context) (cn *conn, err error) { - // Handle any panics during connection initialization. Note that we - // specifically do *not* want to use errRecover(), as that would turn any - // connection errors into ErrBadConns, hiding the real error message from - // the user. - defer errRecoverNoErrBadConn(&err) - - // Create a new values map (copy). This makes it so maps in different - // connections do not reference the same underlying data structure, so it - // is safe for multiple connections to concurrently write to their opts. - o := make(values) - for k, v := range c.opts { - o[k] = v - } - - cn = &conn{ - opts: o, - dialer: c.dialer, - } - err = cn.handleDriverSettings(o) - if err != nil { - return nil, err - } - cn.handlePgpass(o) - - cn.c, err = dial(ctx, c.dialer, o) - if err != nil { - return nil, err - } - - err = cn.ssl(o) - if err != nil { - if cn.c != nil { - cn.c.Close() - } - return nil, err - } - - // cn.startup panics on error. Make sure we don't leak cn.c. - panicking := true - defer func() { - if panicking { - cn.c.Close() - } - }() - - cn.buf = bufio.NewReader(cn.c) - cn.startup(o) - - // reset the deadline, in case one was set (see dial) - if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { - err = cn.c.SetDeadline(time.Time{}) - } - panicking = false - return cn, err -} - -func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) { - network, address := network(o) - - // Zero or not specified means wait indefinitely. - if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { - seconds, err := strconv.ParseInt(timeout, 10, 0) - if err != nil { - return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) - } - duration := time.Duration(seconds) * time.Second - - // connect_timeout should apply to the entire connection establishment - // procedure, so we both use a timeout for the TCP connection - // establishment and set a deadline for doing the initial handshake. - // The deadline is then reset after startup() is done. - deadline := time.Now().Add(duration) - var conn net.Conn - if dctx, ok := d.(DialerContext); ok { - ctx, cancel := context.WithTimeout(ctx, duration) - defer cancel() - conn, err = dctx.DialContext(ctx, network, address) - } else { - conn, err = d.DialTimeout(network, address, duration) - } - if err != nil { - return nil, err - } - err = conn.SetDeadline(deadline) - return conn, err - } - if dctx, ok := d.(DialerContext); ok { - return dctx.DialContext(ctx, network, address) - } - return d.Dial(network, address) -} - -func network(o values) (string, string) { - host := o["host"] - - if strings.HasPrefix(host, "/") { - sockPath := path.Join(host, ".s.PGSQL."+o["port"]) - return "unix", sockPath - } - - return "tcp", net.JoinHostPort(host, o["port"]) -} - -type values map[string]string - -// scanner implements a tokenizer for libpq-style option strings. -type scanner struct { - s []rune - i int -} - -// newScanner returns a new scanner initialized with the option string s. -func newScanner(s string) *scanner { - return &scanner{[]rune(s), 0} -} - -// Next returns the next rune. -// It returns 0, false if the end of the text has been reached. 
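
// A sketch of supplying a custom Dialer, which the DialOpen/dial path above
// supports; NewConnector and Connector.Dialer are the public entry points in
// lib/pq. The DSN and the logging behavior are hypothetical.
package main

import (
	"database/sql"
	"log"
	"net"
	"time"

	"github.com/lib/pq"
)

type loggingDialer struct{ d net.Dialer }

func (ld loggingDialer) Dial(network, address string) (net.Conn, error) {
	log.Printf("dialing %s %s", network, address)
	return ld.d.Dial(network, address)
}

func (ld loggingDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
	ld.d.Timeout = timeout
	return ld.Dial(network, address)
}

func main() {
	connector, err := pq.NewConnector("host=localhost dbname=example sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	connector.Dialer(loggingDialer{})
	db := sql.OpenDB(connector)
	defer db.Close()
}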
-func (s *scanner) Next() (rune, bool) { - if s.i >= len(s.s) { - return 0, false - } - r := s.s[s.i] - s.i++ - return r, true -} - -// SkipSpaces returns the next non-whitespace rune. -// It returns 0, false if the end of the text has been reached. -func (s *scanner) SkipSpaces() (rune, bool) { - r, ok := s.Next() - for unicode.IsSpace(r) && ok { - r, ok = s.Next() - } - return r, ok -} - -// parseOpts parses the options from name and adds them to the values. -// -// The parsing code is based on conninfo_parse from libpq's fe-connect.c -func parseOpts(name string, o values) error { - s := newScanner(name) - - for { - var ( - keyRunes, valRunes []rune - r rune - ok bool - ) - - if r, ok = s.SkipSpaces(); !ok { - break - } - - // Scan the key - for !unicode.IsSpace(r) && r != '=' { - keyRunes = append(keyRunes, r) - if r, ok = s.Next(); !ok { - break - } - } - - // Skip any whitespace if we're not at the = yet - if r != '=' { - r, ok = s.SkipSpaces() - } - - // The current character should be = - if r != '=' || !ok { - return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) - } - - // Skip any whitespace after the = - if r, ok = s.SkipSpaces(); !ok { - // If we reach the end here, the last value is just an empty string as per libpq. - o[string(keyRunes)] = "" - break - } - - if r != '\'' { - for !unicode.IsSpace(r) { - if r == '\\' { - if r, ok = s.Next(); !ok { - return fmt.Errorf(`missing character after backslash`) - } - } - valRunes = append(valRunes, r) - - if r, ok = s.Next(); !ok { - break - } - } - } else { - quote: - for { - if r, ok = s.Next(); !ok { - return fmt.Errorf(`unterminated quoted string literal in connection string`) - } - switch r { - case '\'': - break quote - case '\\': - r, _ = s.Next() - fallthrough - default: - valRunes = append(valRunes, r) - } - } - } - - o[string(keyRunes)] = string(valRunes) - } - - return nil -} - -func (cn *conn) isInTransaction() bool { - return cn.txnStatus == txnStatusIdleInTransaction || - cn.txnStatus == txnStatusInFailedTransaction -} - -func (cn *conn) checkIsInTransaction(intxn bool) { - if cn.isInTransaction() != intxn { - cn.err.set(driver.ErrBadConn) - errorf("unexpected transaction status %v", cn.txnStatus) - } -} - -func (cn *conn) Begin() (_ driver.Tx, err error) { - return cn.begin("") -} - -func (cn *conn) begin(mode string) (_ driver.Tx, err error) { - if err := cn.err.get(); err != nil { - return nil, err - } - defer cn.errRecover(&err) - - cn.checkIsInTransaction(false) - _, commandTag, err := cn.simpleExec("BEGIN" + mode) - if err != nil { - return nil, err - } - if commandTag != "BEGIN" { - cn.err.set(driver.ErrBadConn) - return nil, fmt.Errorf("unexpected command tag %s", commandTag) - } - if cn.txnStatus != txnStatusIdleInTransaction { - cn.err.set(driver.ErrBadConn) - return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) - } - return cn, nil -} - -func (cn *conn) closeTxn() { - if finish := cn.txnFinish; finish != nil { - finish() - } -} - -func (cn *conn) Commit() (err error) { - defer cn.closeTxn() - if err := cn.err.get(); err != nil { - return err - } - defer cn.errRecover(&err) - - cn.checkIsInTransaction(true) - // We don't want the client to think that everything is okay if it tries - // to commit a failed transaction. However, no matter what we return, - // database/sql will release this connection back into the free connection - // pool so we have to abort the current transaction here. 
Note that you - // would get the same behaviour if you issued a COMMIT in a failed - // transaction, so it's also the least surprising thing to do here. - if cn.txnStatus == txnStatusInFailedTransaction { - if err := cn.rollback(); err != nil { - return err - } - return ErrInFailedTransaction - } - - _, commandTag, err := cn.simpleExec("COMMIT") - if err != nil { - if cn.isInTransaction() { - cn.err.set(driver.ErrBadConn) - } - return err - } - if commandTag != "COMMIT" { - cn.err.set(driver.ErrBadConn) - return fmt.Errorf("unexpected command tag %s", commandTag) - } - cn.checkIsInTransaction(false) - return nil -} - -func (cn *conn) Rollback() (err error) { - defer cn.closeTxn() - if err := cn.err.get(); err != nil { - return err - } - defer cn.errRecover(&err) - return cn.rollback() -} - -func (cn *conn) rollback() (err error) { - cn.checkIsInTransaction(true) - _, commandTag, err := cn.simpleExec("ROLLBACK") - if err != nil { - if cn.isInTransaction() { - cn.err.set(driver.ErrBadConn) - } - return err - } - if commandTag != "ROLLBACK" { - return fmt.Errorf("unexpected command tag %s", commandTag) - } - cn.checkIsInTransaction(false) - return nil -} - -func (cn *conn) gname() string { - cn.namei++ - return strconv.FormatInt(int64(cn.namei), 10) -} - -func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) { - b := cn.writeBuf('Q') - b.string(q) - cn.send(b) - - for { - t, r := cn.recv1() - switch t { - case 'C': - res, commandTag = cn.parseComplete(r.string()) - case 'Z': - cn.processReadyForQuery(r) - if res == nil && err == nil { - err = errUnexpectedReady - } - // done - return - case 'E': - err = parseError(r) - case 'I': - res = emptyRows - case 'T', 'D': - // ignore any results - default: - cn.err.set(driver.ErrBadConn) - errorf("unknown response for simple query: %q", t) - } - } -} - -func (cn *conn) simpleQuery(q string) (res *rows, err error) { - defer cn.errRecover(&err) - - b := cn.writeBuf('Q') - b.string(q) - cn.send(b) - - for { - t, r := cn.recv1() - switch t { - case 'C', 'I': - // We allow queries which don't return any results through Query as - // well as Exec. We still have to give database/sql a rows object - // the user can close, though, to avoid connections from being - // leaked. A "rows" with done=true works fine for that purpose. - if err != nil { - cn.err.set(driver.ErrBadConn) - errorf("unexpected message %q in simple query execution", t) - } - if res == nil { - res = &rows{ - cn: cn, - } - } - // Set the result and tag to the last command complete if there wasn't a - // query already run. Although queries usually return from here and cede - // control to Next, a query with zero results does not. - if t == 'C' { - res.result, res.tag = cn.parseComplete(r.string()) - if res.colNames != nil { - return - } - } - res.done = true - case 'Z': - cn.processReadyForQuery(r) - // done - return - case 'E': - res = nil - err = parseError(r) - case 'D': - if res == nil { - cn.err.set(driver.ErrBadConn) - errorf("unexpected DataRow in simple query execution") - } - // the query didn't fail; kick off to Next - cn.saveMessage(t, r) - return - case 'T': - // res might be non-nil here if we received a previous - // CommandComplete, but that's fine; just overwrite it - res = &rows{cn: cn} - res.rowsHeader = parsePortalRowDescribe(r) - - // To work around a bug in QueryRow in Go 1.2 and earlier, wait - // until the first DataRow has been received. 
- default: - cn.err.set(driver.ErrBadConn) - errorf("unknown response for simple query: %q", t) - } - } -} - -type noRows struct{} - -var emptyRows noRows - -var _ driver.Result = noRows{} - -func (noRows) LastInsertId() (int64, error) { - return 0, errNoLastInsertID -} - -func (noRows) RowsAffected() (int64, error) { - return 0, errNoRowsAffected -} - -// Decides which column formats to use for a prepared statement. The input is -// an array of type oids, one element per result column. -func decideColumnFormats( - colTyps []fieldDesc, forceText bool, -) (colFmts []format, colFmtData []byte) { - if len(colTyps) == 0 { - return nil, colFmtDataAllText - } - - colFmts = make([]format, len(colTyps)) - if forceText { - return colFmts, colFmtDataAllText - } - - allBinary := true - allText := true - for i, t := range colTyps { - switch t.OID { - // This is the list of types to use binary mode for when receiving them - // through a prepared statement. If a type appears in this list, it - // must also be implemented in binaryDecode in encode.go. - case oid.T_bytea: - fallthrough - case oid.T_int8: - fallthrough - case oid.T_int4: - fallthrough - case oid.T_int2: - fallthrough - case oid.T_uuid: - colFmts[i] = formatBinary - allText = false - - default: - allBinary = false - } - } - - if allBinary { - return colFmts, colFmtDataAllBinary - } else if allText { - return colFmts, colFmtDataAllText - } else { - colFmtData = make([]byte, 2+len(colFmts)*2) - binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts))) - for i, v := range colFmts { - binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v)) - } - return colFmts, colFmtData - } -} - -func (cn *conn) prepareTo(q, stmtName string) *stmt { - st := &stmt{cn: cn, name: stmtName} - - b := cn.writeBuf('P') - b.string(st.name) - b.string(q) - b.int16(0) - - b.next('D') - b.byte('S') - b.string(st.name) - - b.next('S') - cn.send(b) - - cn.readParseResponse() - st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse() - st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult) - cn.readReadyForQuery() - return st -} - -func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { - if err := cn.err.get(); err != nil { - return nil, err - } - defer cn.errRecover(&err) - - if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { - s, err := cn.prepareCopyIn(q) - if err == nil { - cn.inCopy = true - } - return s, err - } - return cn.prepareTo(q, cn.gname()), nil -} - -func (cn *conn) Close() (err error) { - // Skip cn.bad return here because we always want to close a connection. - defer cn.errRecover(&err) - - // Ensure that cn.c.Close is always run. Since error handling is done with - // panics and cn.errRecover, the Close must be in a defer. - defer func() { - cerr := cn.c.Close() - if err == nil { - err = cerr - } - }() - - // Don't go through send(); ListenerConn relies on us not scribbling on the - // scratch buffer of this connection. 
- return cn.sendSimpleMessage('X') -} - -// Implement the "Queryer" interface -func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { - return cn.query(query, args) -} - -func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { - if err := cn.err.get(); err != nil { - return nil, err - } - if cn.inCopy { - return nil, errCopyInProgress - } - defer cn.errRecover(&err) - - // Check to see if we can use the "simpleQuery" interface, which is - // *much* faster than going through prepare/exec - if len(args) == 0 { - return cn.simpleQuery(query) - } - - if cn.binaryParameters { - cn.sendBinaryModeQuery(query, args) - - cn.readParseResponse() - cn.readBindResponse() - rows := &rows{cn: cn} - rows.rowsHeader = cn.readPortalDescribeResponse() - cn.postExecuteWorkaround() - return rows, nil - } - st := cn.prepareTo(query, "") - st.exec(args) - return &rows{ - cn: cn, - rowsHeader: st.rowsHeader, - }, nil -} - -// Implement the optional "Execer" interface for one-shot queries -func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { - if err := cn.err.get(); err != nil { - return nil, err - } - defer cn.errRecover(&err) - - // Check to see if we can use the "simpleExec" interface, which is - // *much* faster than going through prepare/exec - if len(args) == 0 { - // ignore commandTag, our caller doesn't care - r, _, err := cn.simpleExec(query) - return r, err - } - - if cn.binaryParameters { - cn.sendBinaryModeQuery(query, args) - - cn.readParseResponse() - cn.readBindResponse() - cn.readPortalDescribeResponse() - cn.postExecuteWorkaround() - res, _, err = cn.readExecuteResponse("Execute") - return res, err - } - // Use the unnamed statement to defer planning until bind - // time, or else value-based selectivity estimates cannot be - // used. - st := cn.prepareTo(query, "") - r, err := st.Exec(args) - if err != nil { - panic(err) - } - return r, err -} - -type safeRetryError struct { - Err error -} - -func (se *safeRetryError) Error() string { - return se.Err.Error() -} - -func (cn *conn) send(m *writeBuf) { - n, err := cn.c.Write(m.wrap()) - if err != nil { - if n == 0 { - err = &safeRetryError{Err: err} - } - panic(err) - } -} - -func (cn *conn) sendStartupPacket(m *writeBuf) error { - _, err := cn.c.Write((m.wrap())[1:]) - return err -} - -// Send a message of type typ to the server on the other end of cn. The -// message should have no payload. This method does not use the scratch -// buffer. -func (cn *conn) sendSimpleMessage(typ byte) (err error) { - _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'}) - return err -} - -// saveMessage memorizes a message and its buffer in the conn struct. -// recvMessage will then return these values on the next call to it. This -// method is useful in cases where you have to see what the next message is -// going to be (e.g. to see whether it's an error or not) but you can't handle -// the message yourself. -func (cn *conn) saveMessage(typ byte, buf *readBuf) { - if cn.saveMessageType != 0 { - cn.err.set(driver.ErrBadConn) - errorf("unexpected saveMessageType %d", cn.saveMessageType) - } - cn.saveMessageType = typ - cn.saveMessageBuffer = *buf -} - -// recvMessage receives any message from the backend, or returns an error if -// a problem occurred while reading the message. 
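
// The query/Exec methods above pick the wire protocol per call: with no
// parameters they use the single-message simple-query path ('Q'), otherwise
// the extended Parse/Bind/Execute flow (or its single-round-trip variant when
// binary_parameters is enabled). A usage sketch; table and DSN are
// hypothetical.
package main

import (
	"database/sql"
	"log"

	_ "github.com/lib/pq"
)

func main() {
	db, err := sql.Open("postgres", "host=localhost dbname=example sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// No arguments: goes through simpleExec, a single 'Q' message.
	if _, err := db.Exec(`TRUNCATE audit_log`); err != nil {
		log.Fatal(err)
	}
	// With arguments: prepared under the hood via Parse/Bind/Execute.
	if _, err := db.Exec(`DELETE FROM audit_log WHERE id = $1`, 42); err != nil {
		log.Fatal(err)
	}
}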
-func (cn *conn) recvMessage(r *readBuf) (byte, error) { - // workaround for a QueryRow bug, see exec - if cn.saveMessageType != 0 { - t := cn.saveMessageType - *r = cn.saveMessageBuffer - cn.saveMessageType = 0 - cn.saveMessageBuffer = nil - return t, nil - } - - x := cn.scratch[:5] - _, err := io.ReadFull(cn.buf, x) - if err != nil { - return 0, err - } - - // read the type and length of the message that follows - t := x[0] - n := int(binary.BigEndian.Uint32(x[1:])) - 4 - var y []byte - if n <= len(cn.scratch) { - y = cn.scratch[:n] - } else { - y = make([]byte, n) - } - _, err = io.ReadFull(cn.buf, y) - if err != nil { - return 0, err - } - *r = y - return t, nil -} - -// recv receives a message from the backend, but if an error happened while -// reading the message or the received message was an ErrorResponse, it panics. -// NoticeResponses are ignored. This function should generally be used only -// during the startup sequence. -func (cn *conn) recv() (t byte, r *readBuf) { - for { - var err error - r = &readBuf{} - t, err = cn.recvMessage(r) - if err != nil { - panic(err) - } - switch t { - case 'E': - panic(parseError(r)) - case 'N': - if n := cn.noticeHandler; n != nil { - n(parseError(r)) - } - case 'A': - if n := cn.notificationHandler; n != nil { - n(recvNotification(r)) - } - default: - return - } - } -} - -// recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by -// the caller to avoid an allocation. -func (cn *conn) recv1Buf(r *readBuf) byte { - for { - t, err := cn.recvMessage(r) - if err != nil { - panic(err) - } - - switch t { - case 'A': - if n := cn.notificationHandler; n != nil { - n(recvNotification(r)) - } - case 'N': - if n := cn.noticeHandler; n != nil { - n(parseError(r)) - } - case 'S': - cn.processParameterStatus(r) - default: - return t - } - } -} - -// recv1 receives a message from the backend, panicking if an error occurs -// while attempting to read it. All asynchronous messages are ignored, with -// the exception of ErrorResponse. -func (cn *conn) recv1() (t byte, r *readBuf) { - r = &readBuf{} - t = cn.recv1Buf(r) - return t, r -} - -func (cn *conn) ssl(o values) error { - upgrade, err := ssl(o) - if err != nil { - return err - } - - if upgrade == nil { - // Nothing to do - return nil - } - - w := cn.writeBuf(0) - w.int32(80877103) - if err = cn.sendStartupPacket(w); err != nil { - return err - } - - b := cn.scratch[:1] - _, err = io.ReadFull(cn.c, b) - if err != nil { - return err - } - - if b[0] != 'S' { - return ErrSSLNotSupported - } - - cn.c, err = upgrade(cn.c) - return err -} - -// isDriverSetting returns true iff a setting is purely for configuring the -// driver's options and should not be sent to the server in the connection -// startup packet. -func isDriverSetting(key string) bool { - switch key { - case "host", "port": - return true - case "password": - return true - case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni": - return true - case "fallback_application_name": - return true - case "connect_timeout": - return true - case "disable_prepared_binary_result": - return true - case "binary_parameters": - return true - case "krbsrvname": - return true - case "krbspn": - return true - default: - return false - } -} - -func (cn *conn) startup(o values) { - w := cn.writeBuf(0) - w.int32(196608) - // Send the backend the name of the database we want to connect to, and the - // user we want to connect as. 
Additionally, we send over any run-time - // parameters potentially included in the connection string. If the server - // doesn't recognize any of them, it will reply with an error. - for k, v := range o { - if isDriverSetting(k) { - // skip options which can't be run-time parameters - continue - } - // The protocol requires us to supply the database name as "database" - // instead of "dbname". - if k == "dbname" { - k = "database" - } - w.string(k) - w.string(v) - } - w.string("") - if err := cn.sendStartupPacket(w); err != nil { - panic(err) - } - - for { - t, r := cn.recv() - switch t { - case 'K': - cn.processBackendKeyData(r) - case 'S': - cn.processParameterStatus(r) - case 'R': - cn.auth(r, o) - case 'Z': - cn.processReadyForQuery(r) - return - default: - errorf("unknown response for startup: %q", t) - } - } -} - -func (cn *conn) auth(r *readBuf, o values) { - switch code := r.int32(); code { - case 0: - // OK - case 3: - w := cn.writeBuf('p') - w.string(o["password"]) - cn.send(w) - - t, r := cn.recv() - if t != 'R' { - errorf("unexpected password response: %q", t) - } - - if r.int32() != 0 { - errorf("unexpected authentication response: %q", t) - } - case 5: - s := string(r.next(4)) - w := cn.writeBuf('p') - w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) - cn.send(w) - - t, r := cn.recv() - if t != 'R' { - errorf("unexpected password response: %q", t) - } - - if r.int32() != 0 { - errorf("unexpected authentication response: %q", t) - } - case 7: // GSSAPI, startup - if newGss == nil { - errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)") - } - cli, err := newGss() - if err != nil { - errorf("kerberos error: %s", err.Error()) - } - - var token []byte - - if spn, ok := o["krbspn"]; ok { - // Use the supplied SPN if provided.. - token, err = cli.GetInitTokenFromSpn(spn) - } else { - // Allow the kerberos service name to be overridden - service := "postgres" - if val, ok := o["krbsrvname"]; ok { - service = val - } - - token, err = cli.GetInitToken(o["host"], service) - } - - if err != nil { - errorf("failed to get Kerberos ticket: %q", err) - } - - w := cn.writeBuf('p') - w.bytes(token) - cn.send(w) - - // Store for GSSAPI continue message - cn.gss = cli - - case 8: // GSSAPI continue - - if cn.gss == nil { - errorf("GSSAPI protocol error") - } - - b := []byte(*r) - - done, tokOut, err := cn.gss.Continue(b) - if err == nil && !done { - w := cn.writeBuf('p') - w.bytes(tokOut) - cn.send(w) - } - - // Errors fall through and read the more detailed message - // from the server.. 
- - case 10: - sc := scram.NewClient(sha256.New, o["user"], o["password"]) - sc.Step(nil) - if sc.Err() != nil { - errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) - } - scOut := sc.Out() - - w := cn.writeBuf('p') - w.string("SCRAM-SHA-256") - w.int32(len(scOut)) - w.bytes(scOut) - cn.send(w) - - t, r := cn.recv() - if t != 'R' { - errorf("unexpected password response: %q", t) - } - - if r.int32() != 11 { - errorf("unexpected authentication response: %q", t) - } - - nextStep := r.next(len(*r)) - sc.Step(nextStep) - if sc.Err() != nil { - errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) - } - - scOut = sc.Out() - w = cn.writeBuf('p') - w.bytes(scOut) - cn.send(w) - - t, r = cn.recv() - if t != 'R' { - errorf("unexpected password response: %q", t) - } - - if r.int32() != 12 { - errorf("unexpected authentication response: %q", t) - } - - nextStep = r.next(len(*r)) - sc.Step(nextStep) - if sc.Err() != nil { - errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) - } - - default: - errorf("unknown authentication response: %d", code) - } -} - -type format int - -const formatText format = 0 -const formatBinary format = 1 - -// One result-column format code with the value 1 (i.e. all binary). -var colFmtDataAllBinary = []byte{0, 1, 0, 1} - -// No result-column format codes (i.e. all text). -var colFmtDataAllText = []byte{0, 0} - -type stmt struct { - cn *conn - name string - rowsHeader - colFmtData []byte - paramTyps []oid.Oid - closed bool -} - -func (st *stmt) Close() (err error) { - if st.closed { - return nil - } - if err := st.cn.err.get(); err != nil { - return err - } - defer st.cn.errRecover(&err) - - w := st.cn.writeBuf('C') - w.byte('S') - w.string(st.name) - st.cn.send(w) - - st.cn.send(st.cn.writeBuf('S')) - - t, _ := st.cn.recv1() - if t != '3' { - st.cn.err.set(driver.ErrBadConn) - errorf("unexpected close response: %q", t) - } - st.closed = true - - t, r := st.cn.recv1() - if t != 'Z' { - st.cn.err.set(driver.ErrBadConn) - errorf("expected ready for query, but got: %q", t) - } - st.cn.processReadyForQuery(r) - - return nil -} - -func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { - return st.query(v) -} - -func (st *stmt) query(v []driver.Value) (r *rows, err error) { - if err := st.cn.err.get(); err != nil { - return nil, err - } - defer st.cn.errRecover(&err) - - st.exec(v) - return &rows{ - cn: st.cn, - rowsHeader: st.rowsHeader, - }, nil -} - -func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { - if err := st.cn.err.get(); err != nil { - return nil, err - } - defer st.cn.errRecover(&err) - - st.exec(v) - res, _, err = st.cn.readExecuteResponse("simple query") - return res, err -} - -func (st *stmt) exec(v []driver.Value) { - if len(v) >= 65536 { - errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v)) - } - if len(v) != len(st.paramTyps) { - errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps)) - } - - cn := st.cn - w := cn.writeBuf('B') - w.byte(0) // unnamed portal - w.string(st.name) - - if cn.binaryParameters { - cn.sendBinaryParameters(w, v) - } else { - w.int16(0) - w.int16(len(v)) - for i, x := range v { - if x == nil { - w.int32(-1) - } else { - b := encode(&cn.parameterStatus, x, st.paramTyps[i]) - w.int32(len(b)) - w.bytes(b) - } - } - } - w.bytes(st.colFmtData) - - w.next('E') - w.byte(0) - w.int32(0) - - w.next('S') - cn.send(w) - - cn.readBindResponse() - cn.postExecuteWorkaround() - -} - -func (st *stmt) NumInput() int { - return len(st.paramTyps) -} - -// 
parseComplete parses the "command tag" from a CommandComplete message, and -// returns the number of rows affected (if applicable) and a string -// identifying only the command that was executed, e.g. "ALTER TABLE". If the -// command tag could not be parsed, parseComplete panics. -func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { - commandsWithAffectedRows := []string{ - "SELECT ", - // INSERT is handled below - "UPDATE ", - "DELETE ", - "FETCH ", - "MOVE ", - "COPY ", - } - - var affectedRows *string - for _, tag := range commandsWithAffectedRows { - if strings.HasPrefix(commandTag, tag) { - t := commandTag[len(tag):] - affectedRows = &t - commandTag = tag[:len(tag)-1] - break - } - } - // INSERT also includes the oid of the inserted row in its command tag. - // Oids in user tables are deprecated, and the oid is only returned when - // exactly one row is inserted, so it's unlikely to be of value to any - // real-world application and we can ignore it. - if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { - parts := strings.Split(commandTag, " ") - if len(parts) != 3 { - cn.err.set(driver.ErrBadConn) - errorf("unexpected INSERT command tag %s", commandTag) - } - affectedRows = &parts[len(parts)-1] - commandTag = "INSERT" - } - // There should be no affected rows attached to the tag, just return it - if affectedRows == nil { - return driver.RowsAffected(0), commandTag - } - n, err := strconv.ParseInt(*affectedRows, 10, 64) - if err != nil { - cn.err.set(driver.ErrBadConn) - errorf("could not parse commandTag: %s", err) - } - return driver.RowsAffected(n), commandTag -} - -type rowsHeader struct { - colNames []string - colTyps []fieldDesc - colFmts []format -} - -type rows struct { - cn *conn - finish func() - rowsHeader - done bool - rb readBuf - result driver.Result - tag string - - next *rowsHeader -} - -func (rs *rows) Close() error { - if finish := rs.finish; finish != nil { - defer finish() - } - // no need to look at cn.bad as Next() will - for { - err := rs.Next(nil) - switch err { - case nil: - case io.EOF: - // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row - // description, used with HasNextResultSet). We need to fetch messages until - // we hit a 'Z', which is done by waiting for done to be set. 
- if rs.done { - return nil - } - default: - return err - } - } -} - -func (rs *rows) Columns() []string { - return rs.colNames -} - -func (rs *rows) Result() driver.Result { - if rs.result == nil { - return emptyRows - } - return rs.result -} - -func (rs *rows) Tag() string { - return rs.tag -} - -func (rs *rows) Next(dest []driver.Value) (err error) { - if rs.done { - return io.EOF - } - - conn := rs.cn - if err := conn.err.getForNext(); err != nil { - return err - } - defer conn.errRecover(&err) - - for { - t := conn.recv1Buf(&rs.rb) - switch t { - case 'E': - err = parseError(&rs.rb) - case 'C', 'I': - if t == 'C' { - rs.result, rs.tag = conn.parseComplete(rs.rb.string()) - } - continue - case 'Z': - conn.processReadyForQuery(&rs.rb) - rs.done = true - if err != nil { - return err - } - return io.EOF - case 'D': - n := rs.rb.int16() - if err != nil { - conn.err.set(driver.ErrBadConn) - errorf("unexpected DataRow after error %s", err) - } - if n < len(dest) { - dest = dest[:n] - } - for i := range dest { - l := rs.rb.int32() - if l == -1 { - dest[i] = nil - continue - } - dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i]) - } - return - case 'T': - next := parsePortalRowDescribe(&rs.rb) - rs.next = &next - return io.EOF - default: - errorf("unexpected message after execute: %q", t) - } - } -} - -func (rs *rows) HasNextResultSet() bool { - hasNext := rs.next != nil && !rs.done - return hasNext -} - -func (rs *rows) NextResultSet() error { - if rs.next == nil { - return io.EOF - } - rs.rowsHeader = *rs.next - rs.next = nil - return nil -} - -// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be -// used as part of an SQL statement. For example: -// -// tblname := "my_table" -// data := "my_data" -// quoted := pq.QuoteIdentifier(tblname) -// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) -// -// Any double quotes in name will be escaped. The quoted identifier will be -// case sensitive when used in a query. If the input string contains a zero -// byte, the result will be truncated immediately before it. -func QuoteIdentifier(name string) string { - end := strings.IndexRune(name, 0) - if end > -1 { - name = name[:end] - } - return `"` + strings.Replace(name, `"`, `""`, -1) + `"` -} - -// BufferQuoteIdentifier satisfies the same purpose as QuoteIdentifier, but backed by a -// byte buffer. -func BufferQuoteIdentifier(name string, buffer *bytes.Buffer) { - end := strings.IndexRune(name, 0) - if end > -1 { - name = name[:end] - } - buffer.WriteRune('"') - buffer.WriteString(strings.Replace(name, `"`, `""`, -1)) - buffer.WriteRune('"') -} - -// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal -// to DDL and other statements that do not accept parameters) to be used as part -// of an SQL statement. For example: -// -// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") -// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) -// -// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be -// replaced by two backslashes (i.e. "\\") and the C-style escape identifier -// that PostgreSQL provides ('E') will be prepended to the string. 
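
// A short sketch of the quoting helpers documented above; the expected
// outputs in the comments follow the escaping rules the docs describe.
package main

import (
	"fmt"

	"github.com/lib/pq"
)

func main() {
	fmt.Println(pq.QuoteIdentifier(`weird"name`)) // "weird""name"
	fmt.Println(pq.QuoteLiteral(`it's`))          // 'it''s'
	fmt.Println(pq.QuoteLiteral(`C:\dir`))        //  E'C:\\dir'
}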
-func QuoteLiteral(literal string) string { - // This follows the PostgreSQL internal algorithm for handling quoted literals - // from libpq, which can be found in the "PQEscapeStringInternal" function, - // which is found in the libpq/fe-exec.c source file: - // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c - // - // substitute any single-quotes (') with two single-quotes ('') - literal = strings.Replace(literal, `'`, `''`, -1) - // determine if the string has any backslashes (\) in it. - // if it does, replace any backslashes (\) with two backslashes (\\) - // then, we need to wrap the entire string with a PostgreSQL - // C-style escape. Per how "PQEscapeStringInternal" handles this case, we - // also add a space before the "E" - if strings.Contains(literal, `\`) { - literal = strings.Replace(literal, `\`, `\\`, -1) - literal = ` E'` + literal + `'` - } else { - // otherwise, we can just wrap the literal with a pair of single quotes - literal = `'` + literal + `'` - } - return literal -} - -func md5s(s string) string { - h := md5.New() - h.Write([]byte(s)) - return fmt.Sprintf("%x", h.Sum(nil)) -} - -func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) { - // Do one pass over the parameters to see if we're going to send any of - // them over in binary. If we are, create a paramFormats array at the - // same time. - var paramFormats []int - for i, x := range args { - _, ok := x.([]byte) - if ok { - if paramFormats == nil { - paramFormats = make([]int, len(args)) - } - paramFormats[i] = 1 - } - } - if paramFormats == nil { - b.int16(0) - } else { - b.int16(len(paramFormats)) - for _, x := range paramFormats { - b.int16(x) - } - } - - b.int16(len(args)) - for _, x := range args { - if x == nil { - b.int32(-1) - } else { - datum := binaryEncode(&cn.parameterStatus, x) - b.int32(len(datum)) - b.bytes(datum) - } - } -} - -func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { - if len(args) >= 65536 { - errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) - } - - b := cn.writeBuf('P') - b.byte(0) // unnamed statement - b.string(query) - b.int16(0) - - b.next('B') - b.int16(0) // unnamed portal and statement - cn.sendBinaryParameters(b, args) - b.bytes(colFmtDataAllText) - - b.next('D') - b.byte('P') - b.byte(0) // unnamed portal - - b.next('E') - b.byte(0) - b.int32(0) - - b.next('S') - cn.send(b) -} - -func (cn *conn) processParameterStatus(r *readBuf) { - var err error - - param := r.string() - switch param { - case "server_version": - var major1 int - var major2 int - _, err = fmt.Sscanf(r.string(), "%d.%d", &major1, &major2) - if err == nil { - cn.parameterStatus.serverVersion = major1*10000 + major2*100 - } - - case "TimeZone": - cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) - if err != nil { - cn.parameterStatus.currentLocation = nil - } - - default: - // ignore - } -} - -func (cn *conn) processReadyForQuery(r *readBuf) { - cn.txnStatus = transactionStatus(r.byte()) -} - -func (cn *conn) readReadyForQuery() { - t, r := cn.recv1() - switch t { - case 'Z': - cn.processReadyForQuery(r) - return - default: - cn.err.set(driver.ErrBadConn) - errorf("unexpected message %q; expected ReadyForQuery", t) - } -} - -func (cn *conn) processBackendKeyData(r *readBuf) { - cn.processID = r.int32() - cn.secretKey = r.int32() -} - -func (cn *conn) readParseResponse() { - t, r := cn.recv1() - switch t { - case '1': - return - case 'E': - err := parseError(r) - 
cn.readReadyForQuery() - panic(err) - default: - cn.err.set(driver.ErrBadConn) - errorf("unexpected Parse response %q", t) - } -} - -func (cn *conn) readStatementDescribeResponse() ( - paramTyps []oid.Oid, - colNames []string, - colTyps []fieldDesc, -) { - for { - t, r := cn.recv1() - switch t { - case 't': - nparams := r.int16() - paramTyps = make([]oid.Oid, nparams) - for i := range paramTyps { - paramTyps[i] = r.oid() - } - case 'n': - return paramTyps, nil, nil - case 'T': - colNames, colTyps = parseStatementRowDescribe(r) - return paramTyps, colNames, colTyps - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) - default: - cn.err.set(driver.ErrBadConn) - errorf("unexpected Describe statement response %q", t) - } - } -} - -func (cn *conn) readPortalDescribeResponse() rowsHeader { - t, r := cn.recv1() - switch t { - case 'T': - return parsePortalRowDescribe(r) - case 'n': - return rowsHeader{} - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) - default: - cn.err.set(driver.ErrBadConn) - errorf("unexpected Describe response %q", t) - } - panic("not reached") -} - -func (cn *conn) readBindResponse() { - t, r := cn.recv1() - switch t { - case '2': - return - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) - default: - cn.err.set(driver.ErrBadConn) - errorf("unexpected Bind response %q", t) - } -} - -func (cn *conn) postExecuteWorkaround() { - // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores - // any errors from rows.Next, which masks errors that happened during the - // execution of the query. To avoid the problem in common cases, we wait - // here for one more message from the database. If it's not an error the - // query will likely succeed (or perhaps has already, if it's a - // CommandComplete), so we push the message into the conn struct; recv1 - // will return it as the next message for rows.Next or rows.Close. - // However, if it's an error, we wait until ReadyForQuery and then return - // the error to our caller. 
- for { - t, r := cn.recv1() - switch t { - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) - case 'C', 'D', 'I': - // the query didn't fail, but we can't process this message - cn.saveMessage(t, r) - return - default: - cn.err.set(driver.ErrBadConn) - errorf("unexpected message during extended query execution: %q", t) - } - } -} - -// Only for Exec(), since we ignore the returned data -func (cn *conn) readExecuteResponse( - protocolState string, -) (res driver.Result, commandTag string, err error) { - for { - t, r := cn.recv1() - switch t { - case 'C': - if err != nil { - cn.err.set(driver.ErrBadConn) - errorf("unexpected CommandComplete after error %s", err) - } - res, commandTag = cn.parseComplete(r.string()) - case 'Z': - cn.processReadyForQuery(r) - if res == nil && err == nil { - err = errUnexpectedReady - } - return res, commandTag, err - case 'E': - err = parseError(r) - case 'T', 'D', 'I': - if err != nil { - cn.err.set(driver.ErrBadConn) - errorf("unexpected %q after error %s", t, err) - } - if t == 'I' { - res = emptyRows - } - // ignore any results - default: - cn.err.set(driver.ErrBadConn) - errorf("unknown %s response: %q", protocolState, t) - } - } -} - -func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) { - n := r.int16() - colNames = make([]string, n) - colTyps = make([]fieldDesc, n) - for i := range colNames { - colNames[i] = r.string() - r.next(6) - colTyps[i].OID = r.oid() - colTyps[i].Len = r.int16() - colTyps[i].Mod = r.int32() - // format code not known when describing a statement; always 0 - r.next(2) - } - return -} - -func parsePortalRowDescribe(r *readBuf) rowsHeader { - n := r.int16() - colNames := make([]string, n) - colFmts := make([]format, n) - colTyps := make([]fieldDesc, n) - for i := range colNames { - colNames[i] = r.string() - r.next(6) - colTyps[i].OID = r.oid() - colTyps[i].Len = r.int16() - colTyps[i].Mod = r.int32() - colFmts[i] = format(r.int16()) - } - return rowsHeader{ - colNames: colNames, - colFmts: colFmts, - colTyps: colTyps, - } -} - -// parseEnviron tries to mimic some of libpq's environment handling -// -// To ease testing, it does not directly reference os.Environ, but is -// designed to accept its output. -// -// Environment-set connection information is intended to have a higher -// precedence than a library default but lower than any explicitly -// passed information (such as in the URL or connection string). -func parseEnviron(env []string) (out map[string]string) { - out = make(map[string]string) - - for _, v := range env { - parts := strings.SplitN(v, "=", 2) - - accrue := func(keyname string) { - out[keyname] = parts[1] - } - unsupported := func() { - panic(fmt.Sprintf("setting %v not supported", parts[0])) - } - - // The order of these is the same as is seen in the - // PostgreSQL 9.1 manual. Unsupported but well-defined - // keys cause a panic; these should be unset prior to - // execution. Options which pq expects to be set to a - // certain value are allowed, but must be set to that - // value if present (they can, of course, be absent). 
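To make the precedence concrete, a small illustrative sketch (the env slice and its values are invented; the key mappings come from the switch that follows):

	env := []string{"PGHOST=db.internal", "PGAPPNAME=myapp"}
	opts := parseEnviron(env)
	// opts now holds {"host": "db.internal", "application_name": "myapp"}.
	// NewConnector applies these over its built-in defaults, and explicit
	// connection string settings in turn override them.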
- switch parts[0] { - case "PGHOST": - accrue("host") - case "PGHOSTADDR": - unsupported() - case "PGPORT": - accrue("port") - case "PGDATABASE": - accrue("dbname") - case "PGUSER": - accrue("user") - case "PGPASSWORD": - accrue("password") - case "PGSERVICE", "PGSERVICEFILE", "PGREALM": - unsupported() - case "PGOPTIONS": - accrue("options") - case "PGAPPNAME": - accrue("application_name") - case "PGSSLMODE": - accrue("sslmode") - case "PGSSLCERT": - accrue("sslcert") - case "PGSSLKEY": - accrue("sslkey") - case "PGSSLROOTCERT": - accrue("sslrootcert") - case "PGSSLSNI": - accrue("sslsni") - case "PGREQUIRESSL", "PGSSLCRL": - unsupported() - case "PGREQUIREPEER": - unsupported() - case "PGKRBSRVNAME", "PGGSSLIB": - unsupported() - case "PGCONNECT_TIMEOUT": - accrue("connect_timeout") - case "PGCLIENTENCODING": - accrue("client_encoding") - case "PGDATESTYLE": - accrue("datestyle") - case "PGTZ": - accrue("timezone") - case "PGGEQO": - accrue("geqo") - case "PGSYSCONFDIR", "PGLOCALEDIR": - unsupported() - } - } - - return out -} - -// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8". -func isUTF8(name string) bool { - // Recognize all sorts of silly things as "UTF-8", like Postgres does - s := strings.Map(alnumLowerASCII, name) - return s == "utf8" || s == "unicode" -} - -func alnumLowerASCII(ch rune) rune { - if 'A' <= ch && ch <= 'Z' { - return ch + ('a' - 'A') - } - if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' { - return ch - } - return -1 // discard -} - -// The database/sql/driver package says: -// All Conn implementations should implement the following interfaces: Pinger, SessionResetter, and Validator. -var _ driver.Pinger = &conn{} -var _ driver.SessionResetter = &conn{} - -func (cn *conn) ResetSession(ctx context.Context) error { - // Ensure bad connections are reported: From database/sql/driver: - // If a connection is never returned to the connection pool but immediately reused, then - // ResetSession is called prior to reuse but IsValid is not called. 
- return cn.err.get() -} - -func (cn *conn) IsValid() bool { - return cn.err.get() == nil -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/conn_go115.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/conn_go115.go deleted file mode 100644 index f4ef030f9..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/conn_go115.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build go1.15 -// +build go1.15 - -package pq - -import "database/sql/driver" - -var _ driver.Validator = &conn{} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/conn_go18.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/conn_go18.go deleted file mode 100644 index 63d4ca6aa..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/conn_go18.go +++ /dev/null @@ -1,247 +0,0 @@ -package pq - -import ( - "context" - "database/sql" - "database/sql/driver" - "fmt" - "io" - "io/ioutil" - "time" -) - -const ( - watchCancelDialContextTimeout = time.Second * 10 -) - -// Implement the "QueryerContext" interface -func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - list := make([]driver.Value, len(args)) - for i, nv := range args { - list[i] = nv.Value - } - finish := cn.watchCancel(ctx) - r, err := cn.query(query, list) - if err != nil { - if finish != nil { - finish() - } - return nil, err - } - r.finish = finish - return r, nil -} - -// Implement the "ExecerContext" interface -func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - list := make([]driver.Value, len(args)) - for i, nv := range args { - list[i] = nv.Value - } - - if finish := cn.watchCancel(ctx); finish != nil { - defer finish() - } - - return cn.Exec(query, list) -} - -// Implement the "ConnPrepareContext" interface -func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - if finish := cn.watchCancel(ctx); finish != nil { - defer finish() - } - return cn.Prepare(query) -} - -// Implement the "ConnBeginTx" interface -func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - var mode string - - switch sql.IsolationLevel(opts.Isolation) { - case sql.LevelDefault: - // Don't touch mode: use the server's default - case sql.LevelReadUncommitted: - mode = " ISOLATION LEVEL READ UNCOMMITTED" - case sql.LevelReadCommitted: - mode = " ISOLATION LEVEL READ COMMITTED" - case sql.LevelRepeatableRead: - mode = " ISOLATION LEVEL REPEATABLE READ" - case sql.LevelSerializable: - mode = " ISOLATION LEVEL SERIALIZABLE" - default: - return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation) - } - - if opts.ReadOnly { - mode += " READ ONLY" - } else { - mode += " READ WRITE" - } - - tx, err := cn.begin(mode) - if err != nil { - return nil, err - } - cn.txnFinish = cn.watchCancel(ctx) - return tx, nil -} - -func (cn *conn) Ping(ctx context.Context) error { - if finish := cn.watchCancel(ctx); finish != nil { - defer finish() - } - rows, err := cn.simpleQuery(";") - if err != nil { - return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger - } - rows.Close() - return nil -} - -func (cn *conn) watchCancel(ctx context.Context) func() { - if done := ctx.Done(); done != nil { - finished := make(chan struct{}, 1) - go func() { - select { - case <-done: - select { - case finished <- struct{}{}: - default: - // We raced with the finish func, let the next query handle this with the - // context. 
- return - } - - // Set the connection state to bad so it does not get reused. - cn.err.set(ctx.Err()) - - // At this point the function level context is canceled, - // so it must not be used for the additional network - // request to cancel the query. - // Create a new context to pass into the dial. - ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout) - defer cancel() - - _ = cn.cancel(ctxCancel) - case <-finished: - } - }() - return func() { - select { - case <-finished: - cn.err.set(ctx.Err()) - cn.Close() - case finished <- struct{}{}: - } - } - } - return nil -} - -func (cn *conn) cancel(ctx context.Context) error { - // Create a new values map (copy). This makes sure the connection created - // in this method cannot write to the same underlying data, which could - // cause a concurrent map write panic. This is necessary because cancel - // is called from a goroutine in watchCancel. - o := make(values) - for k, v := range cn.opts { - o[k] = v - } - - c, err := dial(ctx, cn.dialer, o) - if err != nil { - return err - } - defer c.Close() - - { - can := conn{ - c: c, - } - err = can.ssl(o) - if err != nil { - return err - } - - w := can.writeBuf(0) - w.int32(80877102) // cancel request code - w.int32(cn.processID) - w.int32(cn.secretKey) - - if err := can.sendStartupPacket(w); err != nil { - return err - } - } - - // Read until EOF to ensure that the server received the cancel. - { - _, err := io.Copy(ioutil.Discard, c) - return err - } -} - -// Implement the "StmtQueryContext" interface -func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - list := make([]driver.Value, len(args)) - for i, nv := range args { - list[i] = nv.Value - } - finish := st.watchCancel(ctx) - r, err := st.query(list) - if err != nil { - if finish != nil { - finish() - } - return nil, err - } - r.finish = finish - return r, nil -} - -// Implement the "StmtExecContext" interface -func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - list := make([]driver.Value, len(args)) - for i, nv := range args { - list[i] = nv.Value - } - - if finish := st.watchCancel(ctx); finish != nil { - defer finish() - } - - return st.Exec(list) -} - -// watchCancel is implemented on stmt in order to not mark the parent conn as bad -func (st *stmt) watchCancel(ctx context.Context) func() { - if done := ctx.Done(); done != nil { - finished := make(chan struct{}) - go func() { - select { - case <-done: - // At this point the function level context is canceled, - // so it must not be used for the additional network - // request to cancel the query. - // Create a new context to pass into the dial. 
-				ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
-				defer cancel()
-
-				_ = st.cancel(ctxCancel)
-				finished <- struct{}{}
-			case <-finished:
-			}
-		}()
-		return func() {
-			select {
-			case <-finished:
-			case finished <- struct{}{}:
-			}
-		}
-	}
-	return nil
-}
-
-func (st *stmt) cancel(ctx context.Context) error {
-	return st.cn.cancel(ctx)
-}
diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/connector.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/connector.go
deleted file mode 100644
index 1145e1225..000000000
--- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/connector.go
+++ /dev/null
@@ -1,120 +0,0 @@
-package pq
-
-import (
-	"context"
-	"database/sql/driver"
-	"errors"
-	"fmt"
-	"os"
-	"strings"
-)
-
-// Connector represents a fixed configuration for the pq driver with a given
-// name. Connector satisfies the database/sql/driver Connector interface and
-// can be used to create any number of DB Conns via the database/sql OpenDB
-// function.
-//
-// See https://golang.org/pkg/database/sql/driver/#Connector.
-// See https://golang.org/pkg/database/sql/#OpenDB.
-type Connector struct {
-	opts   values
-	dialer Dialer
-}
-
-// Connect returns a connection to the database using the fixed configuration
-// of this Connector. Context is not used.
-func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
-	return c.open(ctx)
-}
-
-// Dialer allows changing the dialer used to open connections.
-func (c *Connector) Dialer(dialer Dialer) {
-	c.dialer = dialer
-}
-
-// Driver returns the underlying driver of this Connector.
-func (c *Connector) Driver() driver.Driver {
-	return &Driver{}
-}
-
-// NewConnector returns a connector for the pq driver in a fixed configuration
-// with the given dsn. The returned connector can be used to create any number
-// of equivalent Conns. The returned connector is intended to be used with
-// database/sql.OpenDB.
-//
-// See https://golang.org/pkg/database/sql/driver/#Connector.
-// See https://golang.org/pkg/database/sql/#OpenDB.
-func NewConnector(dsn string) (*Connector, error) {
-	var err error
-	o := make(values)
-
-	// A number of defaults are applied here, in this order:
-	//
-	// * Very low precedence defaults applied in every situation
-	// * Environment variables
-	// * Explicitly passed connection information
-	o["host"] = "localhost"
-	o["port"] = "5432"
-	// N.B.: Extra float digits should be set to 3, but that breaks
-	// Postgres 8.4 and older, where the max is 2.
-	o["extra_float_digits"] = "2"
-	for k, v := range parseEnviron(os.Environ()) {
-		o[k] = v
-	}
-
-	if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
-		dsn, err = ParseURL(dsn)
-		if err != nil {
-			return nil, err
-		}
-	}
-
-	if err := parseOpts(dsn, o); err != nil {
-		return nil, err
-	}
-
-	// Use the "fallback" application name if necessary
-	if fallback, ok := o["fallback_application_name"]; ok {
-		if _, ok := o["application_name"]; !ok {
-			o["application_name"] = fallback
-		}
-	}
-
-	// We can't work with any client_encoding other than UTF-8 currently.
-	// However, we have historically allowed the user to set it to UTF-8
-	// explicitly, and there's no reason to break such programs, so allow that.
-	// Note that the "options" setting could also set client_encoding, but
-	// parsing its value is not worth it. Instead, we always explicitly send
-	// client_encoding as a separate run-time parameter, which should override
-	// anything set in options.
- if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { - return nil, errors.New("client_encoding must be absent or 'UTF8'") - } - o["client_encoding"] = "UTF8" - // DateStyle needs a similar treatment. - if datestyle, ok := o["datestyle"]; ok { - if datestyle != "ISO, MDY" { - return nil, fmt.Errorf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle) - } - } else { - o["datestyle"] = "ISO, MDY" - } - - // If a user is not provided by any other means, the last - // resort is to use the current operating system provided user - // name. - if _, ok := o["user"]; !ok { - u, err := userCurrent() - if err != nil { - return nil, err - } - o["user"] = u - } - - // SSL is not necessary or supported over UNIX domain sockets - if network, _ := network(o); network == "unix" { - o["sslmode"] = "disable" - } - - return &Connector{opts: o, dialer: defaultDialer{}}, nil -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/copy.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/copy.go deleted file mode 100644 index a8f16b2b2..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/copy.go +++ /dev/null @@ -1,348 +0,0 @@ -package pq - -import ( - "bytes" - "context" - "database/sql/driver" - "encoding/binary" - "errors" - "fmt" - "sync" -) - -var ( - errCopyInClosed = errors.New("pq: copyin statement has already been closed") - errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY") - errCopyToNotSupported = errors.New("pq: COPY TO is not supported") - errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction") - errCopyInProgress = errors.New("pq: COPY in progress") -) - -// CopyIn creates a COPY FROM statement which can be prepared with -// Tx.Prepare(). The target table should be visible in search_path. -func CopyIn(table string, columns ...string) string { - buffer := bytes.NewBufferString("COPY ") - BufferQuoteIdentifier(table, buffer) - buffer.WriteString(" (") - makeStmt(buffer, columns...) - return buffer.String() -} - -// MakeStmt makes the stmt string for CopyIn and CopyInSchema. -func makeStmt(buffer *bytes.Buffer, columns ...string) { - //s := bytes.NewBufferString() - for i, col := range columns { - if i != 0 { - buffer.WriteString(", ") - } - BufferQuoteIdentifier(col, buffer) - } - buffer.WriteString(") FROM STDIN") -} - -// CopyInSchema creates a COPY FROM statement which can be prepared with -// Tx.Prepare(). -func CopyInSchema(schema, table string, columns ...string) string { - buffer := bytes.NewBufferString("COPY ") - BufferQuoteIdentifier(schema, buffer) - buffer.WriteRune('.') - BufferQuoteIdentifier(table, buffer) - buffer.WriteString(" (") - makeStmt(buffer, columns...) 
- return buffer.String() -} - -type copyin struct { - cn *conn - buffer []byte - rowData chan []byte - done chan bool - - closed bool - - mu struct { - sync.Mutex - err error - driver.Result - } -} - -const ciBufferSize = 64 * 1024 - -// flush buffer before the buffer is filled up and needs reallocation -const ciBufferFlushSize = 63 * 1024 - -func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) { - if !cn.isInTransaction() { - return nil, errCopyNotSupportedOutsideTxn - } - - ci := ©in{ - cn: cn, - buffer: make([]byte, 0, ciBufferSize), - rowData: make(chan []byte), - done: make(chan bool, 1), - } - // add CopyData identifier + 4 bytes for message length - ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0) - - b := cn.writeBuf('Q') - b.string(q) - cn.send(b) - -awaitCopyInResponse: - for { - t, r := cn.recv1() - switch t { - case 'G': - if r.byte() != 0 { - err = errBinaryCopyNotSupported - break awaitCopyInResponse - } - go ci.resploop() - return ci, nil - case 'H': - err = errCopyToNotSupported - break awaitCopyInResponse - case 'E': - err = parseError(r) - case 'Z': - if err == nil { - ci.setBad(driver.ErrBadConn) - errorf("unexpected ReadyForQuery in response to COPY") - } - cn.processReadyForQuery(r) - return nil, err - default: - ci.setBad(driver.ErrBadConn) - errorf("unknown response for copy query: %q", t) - } - } - - // something went wrong, abort COPY before we return - b = cn.writeBuf('f') - b.string(err.Error()) - cn.send(b) - - for { - t, r := cn.recv1() - switch t { - case 'c', 'C', 'E': - case 'Z': - // correctly aborted, we're done - cn.processReadyForQuery(r) - return nil, err - default: - ci.setBad(driver.ErrBadConn) - errorf("unknown response for CopyFail: %q", t) - } - } -} - -func (ci *copyin) flush(buf []byte) { - // set message length (without message identifier) - binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1)) - - _, err := ci.cn.c.Write(buf) - if err != nil { - panic(err) - } -} - -func (ci *copyin) resploop() { - for { - var r readBuf - t, err := ci.cn.recvMessage(&r) - if err != nil { - ci.setBad(driver.ErrBadConn) - ci.setError(err) - ci.done <- true - return - } - switch t { - case 'C': - // complete - res, _ := ci.cn.parseComplete(r.string()) - ci.setResult(res) - case 'N': - if n := ci.cn.noticeHandler; n != nil { - n(parseError(&r)) - } - case 'Z': - ci.cn.processReadyForQuery(&r) - ci.done <- true - return - case 'E': - err := parseError(&r) - ci.setError(err) - default: - ci.setBad(driver.ErrBadConn) - ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t)) - ci.done <- true - return - } - } -} - -func (ci *copyin) setBad(err error) { - ci.cn.err.set(err) -} - -func (ci *copyin) getBad() error { - return ci.cn.err.get() -} - -func (ci *copyin) err() error { - ci.mu.Lock() - err := ci.mu.err - ci.mu.Unlock() - return err -} - -// setError() sets ci.err if one has not been set already. Caller must not be -// holding ci.Mutex. 
-func (ci *copyin) setError(err error) { - ci.mu.Lock() - if ci.mu.err == nil { - ci.mu.err = err - } - ci.mu.Unlock() -} - -func (ci *copyin) setResult(result driver.Result) { - ci.mu.Lock() - ci.mu.Result = result - ci.mu.Unlock() -} - -func (ci *copyin) getResult() driver.Result { - ci.mu.Lock() - result := ci.mu.Result - ci.mu.Unlock() - if result == nil { - return driver.RowsAffected(0) - } - return result -} - -func (ci *copyin) NumInput() int { - return -1 -} - -func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) { - return nil, ErrNotSupported -} - -// Exec inserts values into the COPY stream. The insert is asynchronous -// and Exec can return errors from previous Exec calls to the same -// COPY stmt. -// -// You need to call Exec(nil) to sync the COPY stream and to get any -// errors from pending data, since Stmt.Close() doesn't return errors -// to the user. -func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { - if ci.closed { - return nil, errCopyInClosed - } - - if err := ci.getBad(); err != nil { - return nil, err - } - defer ci.cn.errRecover(&err) - - if err := ci.err(); err != nil { - return nil, err - } - - if len(v) == 0 { - if err := ci.Close(); err != nil { - return driver.RowsAffected(0), err - } - - return ci.getResult(), nil - } - - numValues := len(v) - for i, value := range v { - ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value) - if i < numValues-1 { - ci.buffer = append(ci.buffer, '\t') - } - } - - ci.buffer = append(ci.buffer, '\n') - - if len(ci.buffer) > ciBufferFlushSize { - ci.flush(ci.buffer) - // reset buffer, keep bytes for message identifier and length - ci.buffer = ci.buffer[:5] - } - - return driver.RowsAffected(0), nil -} - -// CopyData inserts a raw string into the COPY stream. The insert is -// asynchronous and CopyData can return errors from previous CopyData calls to -// the same COPY stmt. -// -// You need to call Exec(nil) to sync the COPY stream and to get any -// errors from pending data, since Stmt.Close() doesn't return errors -// to the user. -func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, err error) { - if ci.closed { - return nil, errCopyInClosed - } - - if finish := ci.cn.watchCancel(ctx); finish != nil { - defer finish() - } - - if err := ci.getBad(); err != nil { - return nil, err - } - defer ci.cn.errRecover(&err) - - if err := ci.err(); err != nil { - return nil, err - } - - ci.buffer = append(ci.buffer, []byte(line)...) - ci.buffer = append(ci.buffer, '\n') - - if len(ci.buffer) > ciBufferFlushSize { - ci.flush(ci.buffer) - // reset buffer, keep bytes for message identifier and length - ci.buffer = ci.buffer[:5] - } - - return driver.RowsAffected(0), nil -} - -func (ci *copyin) Close() (err error) { - if ci.closed { // Don't do anything, we're already closed - return nil - } - ci.closed = true - - if err := ci.getBad(); err != nil { - return err - } - defer ci.cn.errRecover(&err) - - if len(ci.buffer) > 0 { - ci.flush(ci.buffer) - } - // Avoid touching the scratch buffer as resploop could be using it. 
-	err = ci.cn.sendSimpleMessage('c')
-	if err != nil {
-		return err
-	}
-
-	<-ci.done
-	ci.cn.inCopy = false
-
-	if err := ci.err(); err != nil {
-		return err
-	}
-	return nil
-}
diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/doc.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/doc.go
deleted file mode 100644
index b57184801..000000000
--- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/doc.go
+++ /dev/null
@@ -1,268 +0,0 @@
-/*
-Package pq is a pure Go Postgres driver for the database/sql package.
-
-In most cases clients will use the database/sql package instead of
-using this package directly. For example:
-
-	import (
-		"database/sql"
-
-		_ "github.com/lib/pq"
-	)
-
-	func main() {
-		connStr := "user=pqgotest dbname=pqgotest sslmode=verify-full"
-		db, err := sql.Open("postgres", connStr)
-		if err != nil {
-			log.Fatal(err)
-		}
-
-		age := 21
-		rows, err := db.Query("SELECT name FROM users WHERE age = $1", age)
-		…
-	}
-
-You can also connect to a database using a URL. For example:
-
-	connStr := "postgres://pqgotest:password@localhost/pqgotest?sslmode=verify-full"
-	db, err := sql.Open("postgres", connStr)
-
-
-Connection String Parameters
-
-
-Similarly to libpq, when establishing a connection using pq you are expected to
-supply a connection string containing zero or more parameters.
-A subset of the connection parameters supported by libpq are also supported by pq.
-Additionally, pq also lets you specify run-time parameters (such as search_path or work_mem)
-directly in the connection string. This is different from libpq, which does not allow
-run-time parameters in the connection string, instead requiring you to supply
-them in the options parameter.
-
-For compatibility with libpq, the following special connection parameters are
-supported:
-
-	* dbname - The name of the database to connect to
-	* user - The user to sign in as
-	* password - The user's password
-	* host - The host to connect to. Values that start with / are for unix
-	  domain sockets. (default is localhost)
-	* port - The port to bind to. (default is 5432)
-	* sslmode - Whether or not to use SSL (default is require, this is not
-	  the default for libpq)
-	* fallback_application_name - An application_name to fall back to if one isn't provided.
-	* connect_timeout - Maximum wait for connection, in seconds. Zero or
-	  not specified means wait indefinitely.
-	* sslcert - Cert file location. The file must contain PEM encoded data.
-	* sslkey - Key file location. The file must contain PEM encoded data.
-	* sslrootcert - The location of the root certificate file. The file
-	  must contain PEM encoded data.
-
-Valid values for sslmode are:
-
-	* disable - No SSL
-	* require - Always SSL (skip verification)
-	* verify-ca - Always SSL (verify that the certificate presented by the
-	  server was signed by a trusted CA)
-	* verify-full - Always SSL (verify that the certificate presented by
-	  the server was signed by a trusted CA and the server host name
-	  matches the one in the certificate)
-
-See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING
-for more information about connection string parameters.
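For example, a single keyword/value connection string can combine several of
the parameters above with a run-time parameter such as search_path (a sketch;
the host, user, database, and schema names are placeholders):

	connStr := "host=db.example.com port=5432 user=pqgotest dbname=pqgotest sslmode=verify-full connect_timeout=10 search_path=myschema"
	db, err := sql.Open("postgres", connStr)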
- -Use single quotes for values that contain whitespace: - - "user=pqgotest password='with spaces'" - -A backslash will escape the next character in values: - - "user=space\ man password='it\'s valid'" - -Note that the connection parameter client_encoding (which sets the -text encoding for the connection) may be set but must be "UTF8", -matching with the same rules as Postgres. It is an error to provide -any other value. - -In addition to the parameters listed above, any run-time parameter that can be -set at backend start time can be set in the connection string. For more -information, see -http://www.postgresql.org/docs/current/static/runtime-config.html. - -Most environment variables as specified at http://www.postgresql.org/docs/current/static/libpq-envars.html -supported by libpq are also supported by pq. If any of the environment -variables not supported by pq are set, pq will panic during connection -establishment. Environment variables have a lower precedence than explicitly -provided connection parameters. - -The pgpass mechanism as described in http://www.postgresql.org/docs/current/static/libpq-pgpass.html -is supported, but on Windows PGPASSFILE must be specified explicitly. - - -Queries - - -database/sql does not dictate any specific format for parameter -markers in query strings, and pq uses the Postgres-native ordinal markers, -as shown above. The same marker can be reused for the same parameter: - - rows, err := db.Query(`SELECT name FROM users WHERE favorite_fruit = $1 - OR age BETWEEN $2 AND $2 + 3`, "orange", 64) - -pq does not support the LastInsertId() method of the Result type in database/sql. -To return the identifier of an INSERT (or UPDATE or DELETE), use the Postgres -RETURNING clause with a standard Query or QueryRow call: - - var userid int - err := db.QueryRow(`INSERT INTO users(name, favorite_fruit, age) - VALUES('beatrice', 'starfruit', 93) RETURNING id`).Scan(&userid) - -For more details on RETURNING, see the Postgres documentation: - - http://www.postgresql.org/docs/current/static/sql-insert.html - http://www.postgresql.org/docs/current/static/sql-update.html - http://www.postgresql.org/docs/current/static/sql-delete.html - -For additional instructions on querying see the documentation for the database/sql package. - - -Data Types - - -Parameters pass through driver.DefaultParameterConverter before they are handled -by this package. When the binary_parameters connection option is enabled, -[]byte values are sent directly to the backend as data in binary format. - -This package returns the following types for values from the PostgreSQL backend: - - - integer types smallint, integer, and bigint are returned as int64 - - floating-point types real and double precision are returned as float64 - - character types char, varchar, and text are returned as string - - temporal types date, time, timetz, timestamp, and timestamptz are - returned as time.Time - - the boolean type is returned as bool - - the bytea type is returned as []byte - -All other types are returned directly from the backend as []byte values in text format. - - -Errors - - -pq may return errors of type *pq.Error which can be interrogated for error details: - - if err, ok := err.(*pq.Error); ok { - fmt.Println("pq error:", err.Code.Name()) - } - -See the pq.Error type for details. - - -Bulk imports - -You can perform bulk imports by preparing a statement returned by pq.CopyIn (or -pq.CopyInSchema) in an explicit transaction (sql.Tx). 
The returned statement -handle can then be repeatedly "executed" to copy data into the target table. -After all data has been processed you should call Exec() once with no arguments -to flush all buffered data. Any call to Exec() might return an error which -should be handled appropriately, but because of the internal buffering an error -returned by Exec() might not be related to the data passed in the call that -failed. - -CopyIn uses COPY FROM internally. It is not possible to COPY outside of an -explicit transaction in pq. - -Usage example: - - txn, err := db.Begin() - if err != nil { - log.Fatal(err) - } - - stmt, err := txn.Prepare(pq.CopyIn("users", "name", "age")) - if err != nil { - log.Fatal(err) - } - - for _, user := range users { - _, err = stmt.Exec(user.Name, int64(user.Age)) - if err != nil { - log.Fatal(err) - } - } - - _, err = stmt.Exec() - if err != nil { - log.Fatal(err) - } - - err = stmt.Close() - if err != nil { - log.Fatal(err) - } - - err = txn.Commit() - if err != nil { - log.Fatal(err) - } - - -Notifications - - -PostgreSQL supports a simple publish/subscribe model over database -connections. See http://www.postgresql.org/docs/current/static/sql-notify.html -for more information about the general mechanism. - -To start listening for notifications, you first have to open a new connection -to the database by calling NewListener. This connection can not be used for -anything other than LISTEN / NOTIFY. Calling Listen will open a "notification -channel"; once a notification channel is open, a notification generated on that -channel will effect a send on the Listener.Notify channel. A notification -channel will remain open until Unlisten is called, though connection loss might -result in some notifications being lost. To solve this problem, Listener sends -a nil pointer over the Notify channel any time the connection is re-established -following a connection loss. The application can get information about the -state of the underlying connection by setting an event callback in the call to -NewListener. - -A single Listener can safely be used from concurrent goroutines, which means -that there is often no need to create more than one Listener in your -application. However, a Listener is always connected to a single database, so -you will need to create a new Listener instance for every database you want to -receive notifications in. - -The channel name in both Listen and Unlisten is case sensitive, and can contain -any characters legal in an identifier (see -http://www.postgresql.org/docs/current/static/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS -for more information). Note that the channel name will be truncated to 63 -bytes by the PostgreSQL server. - -You can find a complete, working example of Listener usage at -https://godoc.org/github.com/lib/pq/example/listen. - - -Kerberos Support - - -If you need support for Kerberos authentication, add the following to your main -package: - - import "github.com/lib/pq/auth/kerberos" - - func init() { - pq.RegisterGSSProvider(func() (pq.Gss, error) { return kerberos.NewGSS() }) - } - -This package is in a separate module so that users who don't need Kerberos -don't have to download unnecessary dependencies. - -When imported, additional connection string parameters are supported: - - * krbsrvname - GSS (Kerberos) service name when constructing the - SPN (default is `postgres`). This will be combined with the host - to form the full SPN: `krbsrvname/host`. - * krbspn - GSS (Kerberos) SPN. 
This takes priority over - `krbsrvname` if present. -*/ -package pq diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/encode.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/encode.go deleted file mode 100644 index bffe6096a..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/encode.go +++ /dev/null @@ -1,632 +0,0 @@ -package pq - -import ( - "bytes" - "database/sql/driver" - "encoding/binary" - "encoding/hex" - "errors" - "fmt" - "math" - "regexp" - "strconv" - "strings" - "sync" - "time" - - "github.com/lib/pq/oid" -) - -var time2400Regex = regexp.MustCompile(`^(24:00(?::00(?:\.0+)?)?)(?:[Z+-].*)?$`) - -func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte { - switch v := x.(type) { - case []byte: - return v - default: - return encode(parameterStatus, x, oid.T_unknown) - } -} - -func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte { - switch v := x.(type) { - case int64: - return strconv.AppendInt(nil, v, 10) - case float64: - return strconv.AppendFloat(nil, v, 'f', -1, 64) - case []byte: - if pgtypOid == oid.T_bytea { - return encodeBytea(parameterStatus.serverVersion, v) - } - - return v - case string: - if pgtypOid == oid.T_bytea { - return encodeBytea(parameterStatus.serverVersion, []byte(v)) - } - - return []byte(v) - case bool: - return strconv.AppendBool(nil, v) - case time.Time: - return formatTs(v) - - default: - errorf("encode: unknown type for %T", v) - } - - panic("not reached") -} - -func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} { - switch f { - case formatBinary: - return binaryDecode(parameterStatus, s, typ) - case formatText: - return textDecode(parameterStatus, s, typ) - default: - panic("not reached") - } -} - -func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { - switch typ { - case oid.T_bytea: - return s - case oid.T_int8: - return int64(binary.BigEndian.Uint64(s)) - case oid.T_int4: - return int64(int32(binary.BigEndian.Uint32(s))) - case oid.T_int2: - return int64(int16(binary.BigEndian.Uint16(s))) - case oid.T_uuid: - b, err := decodeUUIDBinary(s) - if err != nil { - panic(err) - } - return b - - default: - errorf("don't know how to decode binary parameter of type %d", uint32(typ)) - } - - panic("not reached") -} - -func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { - switch typ { - case oid.T_char, oid.T_varchar, oid.T_text: - return string(s) - case oid.T_bytea: - b, err := parseBytea(s) - if err != nil { - errorf("%s", err) - } - return b - case oid.T_timestamptz: - return parseTs(parameterStatus.currentLocation, string(s)) - case oid.T_timestamp, oid.T_date: - return parseTs(nil, string(s)) - case oid.T_time: - return mustParse("15:04:05", typ, s) - case oid.T_timetz: - return mustParse("15:04:05-07", typ, s) - case oid.T_bool: - return s[0] == 't' - case oid.T_int8, oid.T_int4, oid.T_int2: - i, err := strconv.ParseInt(string(s), 10, 64) - if err != nil { - errorf("%s", err) - } - return i - case oid.T_float4, oid.T_float8: - // We always use 64 bit parsing, regardless of whether the input text is for - // a float4 or float8, because clients expect float64s for all float datatypes - // and returning a 32-bit parsed float64 produces lossy results. 
- f, err := strconv.ParseFloat(string(s), 64) - if err != nil { - errorf("%s", err) - } - return f - } - - return s -} - -// appendEncodedText encodes item in text format as required by COPY -// and appends to buf -func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface{}) []byte { - switch v := x.(type) { - case int64: - return strconv.AppendInt(buf, v, 10) - case float64: - return strconv.AppendFloat(buf, v, 'f', -1, 64) - case []byte: - encodedBytea := encodeBytea(parameterStatus.serverVersion, v) - return appendEscapedText(buf, string(encodedBytea)) - case string: - return appendEscapedText(buf, v) - case bool: - return strconv.AppendBool(buf, v) - case time.Time: - return append(buf, formatTs(v)...) - case nil: - return append(buf, "\\N"...) - default: - errorf("encode: unknown type for %T", v) - } - - panic("not reached") -} - -func appendEscapedText(buf []byte, text string) []byte { - escapeNeeded := false - startPos := 0 - var c byte - - // check if we need to escape - for i := 0; i < len(text); i++ { - c = text[i] - if c == '\\' || c == '\n' || c == '\r' || c == '\t' { - escapeNeeded = true - startPos = i - break - } - } - if !escapeNeeded { - return append(buf, text...) - } - - // copy till first char to escape, iterate the rest - result := append(buf, text[:startPos]...) - for i := startPos; i < len(text); i++ { - c = text[i] - switch c { - case '\\': - result = append(result, '\\', '\\') - case '\n': - result = append(result, '\\', 'n') - case '\r': - result = append(result, '\\', 'r') - case '\t': - result = append(result, '\\', 't') - default: - result = append(result, c) - } - } - return result -} - -func mustParse(f string, typ oid.Oid, s []byte) time.Time { - str := string(s) - - // Check for a minute and second offset in the timezone. - if typ == oid.T_timestamptz || typ == oid.T_timetz { - for i := 3; i <= 6; i += 3 { - if str[len(str)-i] == ':' { - f += ":00" - continue - } - break - } - } - - // Special case for 24:00 time. - // Unfortunately, golang does not parse 24:00 as a proper time. - // In this case, we want to try "round to the next day", to differentiate. - // As such, we find if the 24:00 time matches at the beginning; if so, - // we default it back to 00:00 but add a day later. - var is2400Time bool - switch typ { - case oid.T_timetz, oid.T_time: - if matches := time2400Regex.FindStringSubmatch(str); matches != nil { - // Concatenate timezone information at the back. 
- str = "00:00:00" + str[len(matches[1]):] - is2400Time = true - } - } - t, err := time.Parse(f, str) - if err != nil { - errorf("decode: %s", err) - } - if is2400Time { - t = t.Add(24 * time.Hour) - } - return t -} - -var errInvalidTimestamp = errors.New("invalid timestamp") - -type timestampParser struct { - err error -} - -func (p *timestampParser) expect(str string, char byte, pos int) { - if p.err != nil { - return - } - if pos+1 > len(str) { - p.err = errInvalidTimestamp - return - } - if c := str[pos]; c != char && p.err == nil { - p.err = fmt.Errorf("expected '%v' at position %v; got '%v'", char, pos, c) - } -} - -func (p *timestampParser) mustAtoi(str string, begin int, end int) int { - if p.err != nil { - return 0 - } - if begin < 0 || end < 0 || begin > end || end > len(str) { - p.err = errInvalidTimestamp - return 0 - } - result, err := strconv.Atoi(str[begin:end]) - if err != nil { - if p.err == nil { - p.err = fmt.Errorf("expected number; got '%v'", str) - } - return 0 - } - return result -} - -// The location cache caches the time zones typically used by the client. -type locationCache struct { - cache map[int]*time.Location - lock sync.Mutex -} - -// All connections share the same list of timezones. Benchmarking shows that -// about 5% speed could be gained by putting the cache in the connection and -// losing the mutex, at the cost of a small amount of memory and a somewhat -// significant increase in code complexity. -var globalLocationCache = newLocationCache() - -func newLocationCache() *locationCache { - return &locationCache{cache: make(map[int]*time.Location)} -} - -// Returns the cached timezone for the specified offset, creating and caching -// it if necessary. -func (c *locationCache) getLocation(offset int) *time.Location { - c.lock.Lock() - defer c.lock.Unlock() - - location, ok := c.cache[offset] - if !ok { - location = time.FixedZone("", offset) - c.cache[offset] = location - } - - return location -} - -var infinityTsEnabled = false -var infinityTsNegative time.Time -var infinityTsPositive time.Time - -const ( - infinityTsEnabledAlready = "pq: infinity timestamp enabled already" - infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive" -) - -// EnableInfinityTs controls the handling of Postgres' "-infinity" and -// "infinity" "timestamp"s. -// -// If EnableInfinityTs is not called, "-infinity" and "infinity" will return -// []byte("-infinity") and []byte("infinity") respectively, and potentially -// cause error "sql: Scan error on column index 0: unsupported driver -> Scan -// pair: []uint8 -> *time.Time", when scanning into a time.Time value. -// -// Once EnableInfinityTs has been called, all connections created using this -// driver will decode Postgres' "-infinity" and "infinity" for "timestamp", -// "timestamp with time zone" and "date" types to the predefined minimum and -// maximum times, respectively. When encoding time.Time values, any time which -// equals or precedes the predefined minimum time will be encoded to -// "-infinity". Any values at or past the maximum time will similarly be -// encoded to "infinity". -// -// If EnableInfinityTs is called with negative >= positive, it will panic. -// Calling EnableInfinityTs after a connection has been established results in -// undefined behavior. If EnableInfinityTs is called more than once, it will -// panic. 
-func EnableInfinityTs(negative time.Time, positive time.Time) {
-	if infinityTsEnabled {
-		panic(infinityTsEnabledAlready)
-	}
-	if !negative.Before(positive) {
-		panic(infinityTsNegativeMustBeSmaller)
-	}
-	infinityTsEnabled = true
-	infinityTsNegative = negative
-	infinityTsPositive = positive
-}
-
-/*
- * Testing might want to toggle infinityTsEnabled
- */
-func disableInfinityTs() {
-	infinityTsEnabled = false
-}
-
-// This is a time function specific to the Postgres default DateStyle
-// setting ("ISO, MDY"), the only one we currently support. This
-// accounts for the discrepancies between the parsing available with
-// time.Parse and the Postgres date formatting quirks.
-func parseTs(currentLocation *time.Location, str string) interface{} {
-	switch str {
-	case "-infinity":
-		if infinityTsEnabled {
-			return infinityTsNegative
-		}
-		return []byte(str)
-	case "infinity":
-		if infinityTsEnabled {
-			return infinityTsPositive
-		}
-		return []byte(str)
-	}
-	t, err := ParseTimestamp(currentLocation, str)
-	if err != nil {
-		panic(err)
-	}
-	return t
-}
-
-// ParseTimestamp parses Postgres' text format. It returns a time.Time in
-// currentLocation iff that time's offset agrees with the offset sent from the
-// Postgres server. Otherwise, ParseTimestamp returns a time.Time with the
-// fixed offset provided by the Postgres server.
-func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) {
-	p := timestampParser{}
-
-	monSep := strings.IndexRune(str, '-')
-	// this is Gregorian year, not ISO Year
-	// In Gregorian system, the year 1 BC is followed by AD 1
-	year := p.mustAtoi(str, 0, monSep)
-	daySep := monSep + 3
-	month := p.mustAtoi(str, monSep+1, daySep)
-	p.expect(str, '-', daySep)
-	timeSep := daySep + 3
-	day := p.mustAtoi(str, daySep+1, timeSep)
-
-	minLen := monSep + len("01-01") + 1
-
-	isBC := strings.HasSuffix(str, " BC")
-	if isBC {
-		minLen += 3
-	}
-
-	var hour, minute, second int
-	if len(str) > minLen {
-		p.expect(str, ' ', timeSep)
-		minSep := timeSep + 3
-		p.expect(str, ':', minSep)
-		hour = p.mustAtoi(str, timeSep+1, minSep)
-		secSep := minSep + 3
-		p.expect(str, ':', secSep)
-		minute = p.mustAtoi(str, minSep+1, secSep)
-		secEnd := secSep + 3
-		second = p.mustAtoi(str, secSep+1, secEnd)
-	}
-	remainderIdx := monSep + len("01-01 00:00:00") + 1
-	// Three optional (but ordered) sections follow: the
-	// fractional seconds, the time zone offset, and the BC
-	// designation. We set them up here and adjust the other
-	// offsets if the preceding sections exist.
-
-	nanoSec := 0
-	tzOff := 0
-
-	if remainderIdx < len(str) && str[remainderIdx] == '.' {
-		fracStart := remainderIdx + 1
-		fracOff := strings.IndexAny(str[fracStart:], "-+Z ")
-		if fracOff < 0 {
-			fracOff = len(str) - fracStart
-		}
-		fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff)
-		nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff))))
-
-		remainderIdx += fracOff + 1
-	}
-	if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') {
-		// time zone separator is always '-' or '+' or 'Z' (UTC is +00)
-		var tzSign int
-		switch c := str[tzStart]; c {
-		case '-':
-			tzSign = -1
-		case '+':
-			tzSign = +1
-		default:
-			return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c)
-		}
-		tzHours := p.mustAtoi(str, tzStart+1, tzStart+3)
-		remainderIdx += 3
-		var tzMin, tzSec int
-		if remainderIdx < len(str) && str[remainderIdx] == ':' {
-			tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
-			remainderIdx += 3
-		}
-		if remainderIdx < len(str) && str[remainderIdx] == ':' {
-			tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
-			remainderIdx += 3
-		}
-		tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec)
-	} else if tzStart < len(str) && str[tzStart] == 'Z' {
-		// time zone Z separator indicates UTC is +00
-		remainderIdx += 1
-	}
-
-	var isoYear int
-
-	if isBC {
-		isoYear = 1 - year
-		remainderIdx += 3
-	} else {
-		isoYear = year
-	}
-	if remainderIdx < len(str) {
-		return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:])
-	}
-	t := time.Date(isoYear, time.Month(month), day,
-		hour, minute, second, nanoSec,
-		globalLocationCache.getLocation(tzOff))
-
-	if currentLocation != nil {
-		// Set the location of the returned Time based on the session's
-		// TimeZone value, but only if the local time zone database agrees with
-		// the remote database on the offset.
-		lt := t.In(currentLocation)
-		_, newOff := lt.Zone()
-		if newOff == tzOff {
-			t = lt
-		}
-	}
-
-	return t, p.err
-}
-
-// formatTs formats t into a format postgres understands.
-func formatTs(t time.Time) []byte {
-	if infinityTsEnabled {
-		// t <= -infinity : ! (t > -infinity)
-		if !t.After(infinityTsNegative) {
-			return []byte("-infinity")
-		}
-		// t >= infinity : ! (!t < infinity)
-		if !t.Before(infinityTsPositive) {
-			return []byte("infinity")
-		}
-	}
-	return FormatTimestamp(t)
-}
-
-// FormatTimestamp formats t into Postgres' text format for timestamps.
-func FormatTimestamp(t time.Time) []byte {
-	// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
-	// minus sign preferred by Go.
-	// Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
-	bc := false
-	if t.Year() <= 0 {
-		// flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
-		t = t.AddDate((-t.Year())*2+1, 0, 0)
-		bc = true
-	}
-	b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00"))
-
-	_, offset := t.Zone()
-	offset %= 60
-	if offset != 0 {
-		// RFC3339Nano already printed the minus sign
-		if offset < 0 {
-			offset = -offset
-		}
-
-		b = append(b, ':')
-		if offset < 10 {
-			b = append(b, '0')
-		}
-		b = strconv.AppendInt(b, int64(offset), 10)
-	}
-
-	if bc {
-		b = append(b, " BC"...)
-	}
-	return b
-}
-
-// Parse a bytea value received from the server. Both "hex" and the legacy
-// "escape" format are supported.
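As a concrete illustration of the two formats (a sketch from inside the
package; the byte values are invented):

	b1, _ := parseBytea([]byte(`\x48690a`)) // hex format: []byte{'H', 'i', '\n'}
	b2, _ := parseBytea([]byte(`Hi\012`))   // escape format: the same three bytes (\012 is octal 10)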
-func parseBytea(s []byte) (result []byte, err error) { - if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) { - // bytea_output = hex - s = s[2:] // trim off leading "\\x" - result = make([]byte, hex.DecodedLen(len(s))) - _, err := hex.Decode(result, s) - if err != nil { - return nil, err - } - } else { - // bytea_output = escape - for len(s) > 0 { - if s[0] == '\\' { - // escaped '\\' - if len(s) >= 2 && s[1] == '\\' { - result = append(result, '\\') - s = s[2:] - continue - } - - // '\\' followed by an octal number - if len(s) < 4 { - return nil, fmt.Errorf("invalid bytea sequence %v", s) - } - r, err := strconv.ParseUint(string(s[1:4]), 8, 8) - if err != nil { - return nil, fmt.Errorf("could not parse bytea value: %s", err.Error()) - } - result = append(result, byte(r)) - s = s[4:] - } else { - // We hit an unescaped, raw byte. Try to read in as many as - // possible in one go. - i := bytes.IndexByte(s, '\\') - if i == -1 { - result = append(result, s...) - break - } - result = append(result, s[:i]...) - s = s[i:] - } - } - } - - return result, nil -} - -func encodeBytea(serverVersion int, v []byte) (result []byte) { - if serverVersion >= 90000 { - // Use the hex format if we know that the server supports it - result = make([]byte, 2+hex.EncodedLen(len(v))) - result[0] = '\\' - result[1] = 'x' - hex.Encode(result[2:], v) - } else { - // .. or resort to "escape" - for _, b := range v { - if b == '\\' { - result = append(result, '\\', '\\') - } else if b < 0x20 || b > 0x7e { - result = append(result, []byte(fmt.Sprintf("\\%03o", b))...) - } else { - result = append(result, b) - } - } - } - - return result -} - -// NullTime represents a time.Time that may be null. NullTime implements the -// sql.Scanner interface so it can be used as a scan destination, similar to -// sql.NullString. -type NullTime struct { - Time time.Time - Valid bool // Valid is true if Time is not NULL -} - -// Scan implements the Scanner interface. -func (nt *NullTime) Scan(value interface{}) error { - nt.Time, nt.Valid = value.(time.Time) - return nil -} - -// Value implements the driver Valuer interface. -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil - } - return nt.Time, nil -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/error.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/error.go deleted file mode 100644 index f67c5a5fa..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/error.go +++ /dev/null @@ -1,523 +0,0 @@ -package pq - -import ( - "database/sql/driver" - "fmt" - "io" - "net" - "runtime" -) - -// Error severities -const ( - Efatal = "FATAL" - Epanic = "PANIC" - Ewarning = "WARNING" - Enotice = "NOTICE" - Edebug = "DEBUG" - Einfo = "INFO" - Elog = "LOG" -) - -// Error represents an error communicating with the server. -// -// See http://www.postgresql.org/docs/current/static/protocol-error-fields.html for details of the fields -type Error struct { - Severity string - Code ErrorCode - Message string - Detail string - Hint string - Position string - InternalPosition string - InternalQuery string - Where string - Schema string - Table string - Column string - DataTypeName string - Constraint string - File string - Line string - Routine string -} - -// ErrorCode is a five-character error code. -type ErrorCode string - -// Name returns a more human friendly rendering of the error code, namely the -// "condition name". -// -// See http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html for -// details. 
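For instance, with codes taken from the table below (a sketch of the lookup
behavior, assuming the package is imported as pq):

	code := pq.ErrorCode("23505")
	fmt.Println(code.Name())         // unique_violation
	fmt.Println(code.Class())        // 23
	fmt.Println(code.Class().Name()) // integrity_constraint_violation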
-func (ec ErrorCode) Name() string { - return errorCodeNames[ec] -} - -// ErrorClass is only the class part of an error code. -type ErrorClass string - -// Name returns the condition name of an error class. It is equivalent to the -// condition name of the "standard" error code (i.e. the one having the last -// three characters "000"). -func (ec ErrorClass) Name() string { - return errorCodeNames[ErrorCode(ec+"000")] -} - -// Class returns the error class, e.g. "28". -// -// See http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html for -// details. -func (ec ErrorCode) Class() ErrorClass { - return ErrorClass(ec[0:2]) -} - -// errorCodeNames is a mapping between the five-character error codes and the -// human readable "condition names". It is derived from the list at -// http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html -var errorCodeNames = map[ErrorCode]string{ - // Class 00 - Successful Completion - "00000": "successful_completion", - // Class 01 - Warning - "01000": "warning", - "0100C": "dynamic_result_sets_returned", - "01008": "implicit_zero_bit_padding", - "01003": "null_value_eliminated_in_set_function", - "01007": "privilege_not_granted", - "01006": "privilege_not_revoked", - "01004": "string_data_right_truncation", - "01P01": "deprecated_feature", - // Class 02 - No Data (this is also a warning class per the SQL standard) - "02000": "no_data", - "02001": "no_additional_dynamic_result_sets_returned", - // Class 03 - SQL Statement Not Yet Complete - "03000": "sql_statement_not_yet_complete", - // Class 08 - Connection Exception - "08000": "connection_exception", - "08003": "connection_does_not_exist", - "08006": "connection_failure", - "08001": "sqlclient_unable_to_establish_sqlconnection", - "08004": "sqlserver_rejected_establishment_of_sqlconnection", - "08007": "transaction_resolution_unknown", - "08P01": "protocol_violation", - // Class 09 - Triggered Action Exception - "09000": "triggered_action_exception", - // Class 0A - Feature Not Supported - "0A000": "feature_not_supported", - // Class 0B - Invalid Transaction Initiation - "0B000": "invalid_transaction_initiation", - // Class 0F - Locator Exception - "0F000": "locator_exception", - "0F001": "invalid_locator_specification", - // Class 0L - Invalid Grantor - "0L000": "invalid_grantor", - "0LP01": "invalid_grant_operation", - // Class 0P - Invalid Role Specification - "0P000": "invalid_role_specification", - // Class 0Z - Diagnostics Exception - "0Z000": "diagnostics_exception", - "0Z002": "stacked_diagnostics_accessed_without_active_handler", - // Class 20 - Case Not Found - "20000": "case_not_found", - // Class 21 - Cardinality Violation - "21000": "cardinality_violation", - // Class 22 - Data Exception - "22000": "data_exception", - "2202E": "array_subscript_error", - "22021": "character_not_in_repertoire", - "22008": "datetime_field_overflow", - "22012": "division_by_zero", - "22005": "error_in_assignment", - "2200B": "escape_character_conflict", - "22022": "indicator_overflow", - "22015": "interval_field_overflow", - "2201E": "invalid_argument_for_logarithm", - "22014": "invalid_argument_for_ntile_function", - "22016": "invalid_argument_for_nth_value_function", - "2201F": "invalid_argument_for_power_function", - "2201G": "invalid_argument_for_width_bucket_function", - "22018": "invalid_character_value_for_cast", - "22007": "invalid_datetime_format", - "22019": "invalid_escape_character", - "2200D": "invalid_escape_octet", - "22025": "invalid_escape_sequence", - "22P06": 
"nonstandard_use_of_escape_character", - "22010": "invalid_indicator_parameter_value", - "22023": "invalid_parameter_value", - "2201B": "invalid_regular_expression", - "2201W": "invalid_row_count_in_limit_clause", - "2201X": "invalid_row_count_in_result_offset_clause", - "22009": "invalid_time_zone_displacement_value", - "2200C": "invalid_use_of_escape_character", - "2200G": "most_specific_type_mismatch", - "22004": "null_value_not_allowed", - "22002": "null_value_no_indicator_parameter", - "22003": "numeric_value_out_of_range", - "2200H": "sequence_generator_limit_exceeded", - "22026": "string_data_length_mismatch", - "22001": "string_data_right_truncation", - "22011": "substring_error", - "22027": "trim_error", - "22024": "unterminated_c_string", - "2200F": "zero_length_character_string", - "22P01": "floating_point_exception", - "22P02": "invalid_text_representation", - "22P03": "invalid_binary_representation", - "22P04": "bad_copy_file_format", - "22P05": "untranslatable_character", - "2200L": "not_an_xml_document", - "2200M": "invalid_xml_document", - "2200N": "invalid_xml_content", - "2200S": "invalid_xml_comment", - "2200T": "invalid_xml_processing_instruction", - // Class 23 - Integrity Constraint Violation - "23000": "integrity_constraint_violation", - "23001": "restrict_violation", - "23502": "not_null_violation", - "23503": "foreign_key_violation", - "23505": "unique_violation", - "23514": "check_violation", - "23P01": "exclusion_violation", - // Class 24 - Invalid Cursor State - "24000": "invalid_cursor_state", - // Class 25 - Invalid Transaction State - "25000": "invalid_transaction_state", - "25001": "active_sql_transaction", - "25002": "branch_transaction_already_active", - "25008": "held_cursor_requires_same_isolation_level", - "25003": "inappropriate_access_mode_for_branch_transaction", - "25004": "inappropriate_isolation_level_for_branch_transaction", - "25005": "no_active_sql_transaction_for_branch_transaction", - "25006": "read_only_sql_transaction", - "25007": "schema_and_data_statement_mixing_not_supported", - "25P01": "no_active_sql_transaction", - "25P02": "in_failed_sql_transaction", - // Class 26 - Invalid SQL Statement Name - "26000": "invalid_sql_statement_name", - // Class 27 - Triggered Data Change Violation - "27000": "triggered_data_change_violation", - // Class 28 - Invalid Authorization Specification - "28000": "invalid_authorization_specification", - "28P01": "invalid_password", - // Class 2B - Dependent Privilege Descriptors Still Exist - "2B000": "dependent_privilege_descriptors_still_exist", - "2BP01": "dependent_objects_still_exist", - // Class 2D - Invalid Transaction Termination - "2D000": "invalid_transaction_termination", - // Class 2F - SQL Routine Exception - "2F000": "sql_routine_exception", - "2F005": "function_executed_no_return_statement", - "2F002": "modifying_sql_data_not_permitted", - "2F003": "prohibited_sql_statement_attempted", - "2F004": "reading_sql_data_not_permitted", - // Class 34 - Invalid Cursor Name - "34000": "invalid_cursor_name", - // Class 38 - External Routine Exception - "38000": "external_routine_exception", - "38001": "containing_sql_not_permitted", - "38002": "modifying_sql_data_not_permitted", - "38003": "prohibited_sql_statement_attempted", - "38004": "reading_sql_data_not_permitted", - // Class 39 - External Routine Invocation Exception - "39000": "external_routine_invocation_exception", - "39001": "invalid_sqlstate_returned", - "39004": "null_value_not_allowed", - "39P01": "trigger_protocol_violated", - "39P02": 
"srf_protocol_violated", - // Class 3B - Savepoint Exception - "3B000": "savepoint_exception", - "3B001": "invalid_savepoint_specification", - // Class 3D - Invalid Catalog Name - "3D000": "invalid_catalog_name", - // Class 3F - Invalid Schema Name - "3F000": "invalid_schema_name", - // Class 40 - Transaction Rollback - "40000": "transaction_rollback", - "40002": "transaction_integrity_constraint_violation", - "40001": "serialization_failure", - "40003": "statement_completion_unknown", - "40P01": "deadlock_detected", - // Class 42 - Syntax Error or Access Rule Violation - "42000": "syntax_error_or_access_rule_violation", - "42601": "syntax_error", - "42501": "insufficient_privilege", - "42846": "cannot_coerce", - "42803": "grouping_error", - "42P20": "windowing_error", - "42P19": "invalid_recursion", - "42830": "invalid_foreign_key", - "42602": "invalid_name", - "42622": "name_too_long", - "42939": "reserved_name", - "42804": "datatype_mismatch", - "42P18": "indeterminate_datatype", - "42P21": "collation_mismatch", - "42P22": "indeterminate_collation", - "42809": "wrong_object_type", - "42703": "undefined_column", - "42883": "undefined_function", - "42P01": "undefined_table", - "42P02": "undefined_parameter", - "42704": "undefined_object", - "42701": "duplicate_column", - "42P03": "duplicate_cursor", - "42P04": "duplicate_database", - "42723": "duplicate_function", - "42P05": "duplicate_prepared_statement", - "42P06": "duplicate_schema", - "42P07": "duplicate_table", - "42712": "duplicate_alias", - "42710": "duplicate_object", - "42702": "ambiguous_column", - "42725": "ambiguous_function", - "42P08": "ambiguous_parameter", - "42P09": "ambiguous_alias", - "42P10": "invalid_column_reference", - "42611": "invalid_column_definition", - "42P11": "invalid_cursor_definition", - "42P12": "invalid_database_definition", - "42P13": "invalid_function_definition", - "42P14": "invalid_prepared_statement_definition", - "42P15": "invalid_schema_definition", - "42P16": "invalid_table_definition", - "42P17": "invalid_object_definition", - // Class 44 - WITH CHECK OPTION Violation - "44000": "with_check_option_violation", - // Class 53 - Insufficient Resources - "53000": "insufficient_resources", - "53100": "disk_full", - "53200": "out_of_memory", - "53300": "too_many_connections", - "53400": "configuration_limit_exceeded", - // Class 54 - Program Limit Exceeded - "54000": "program_limit_exceeded", - "54001": "statement_too_complex", - "54011": "too_many_columns", - "54023": "too_many_arguments", - // Class 55 - Object Not In Prerequisite State - "55000": "object_not_in_prerequisite_state", - "55006": "object_in_use", - "55P02": "cant_change_runtime_param", - "55P03": "lock_not_available", - // Class 57 - Operator Intervention - "57000": "operator_intervention", - "57014": "query_canceled", - "57P01": "admin_shutdown", - "57P02": "crash_shutdown", - "57P03": "cannot_connect_now", - "57P04": "database_dropped", - // Class 58 - System Error (errors external to PostgreSQL itself) - "58000": "system_error", - "58030": "io_error", - "58P01": "undefined_file", - "58P02": "duplicate_file", - // Class F0 - Configuration File Error - "F0000": "config_file_error", - "F0001": "lock_file_exists", - // Class HV - Foreign Data Wrapper Error (SQL/MED) - "HV000": "fdw_error", - "HV005": "fdw_column_name_not_found", - "HV002": "fdw_dynamic_parameter_value_needed", - "HV010": "fdw_function_sequence_error", - "HV021": "fdw_inconsistent_descriptor_information", - "HV024": "fdw_invalid_attribute_value", - "HV007": 
"fdw_invalid_column_name", - "HV008": "fdw_invalid_column_number", - "HV004": "fdw_invalid_data_type", - "HV006": "fdw_invalid_data_type_descriptors", - "HV091": "fdw_invalid_descriptor_field_identifier", - "HV00B": "fdw_invalid_handle", - "HV00C": "fdw_invalid_option_index", - "HV00D": "fdw_invalid_option_name", - "HV090": "fdw_invalid_string_length_or_buffer_length", - "HV00A": "fdw_invalid_string_format", - "HV009": "fdw_invalid_use_of_null_pointer", - "HV014": "fdw_too_many_handles", - "HV001": "fdw_out_of_memory", - "HV00P": "fdw_no_schemas", - "HV00J": "fdw_option_name_not_found", - "HV00K": "fdw_reply_handle", - "HV00Q": "fdw_schema_not_found", - "HV00R": "fdw_table_not_found", - "HV00L": "fdw_unable_to_create_execution", - "HV00M": "fdw_unable_to_create_reply", - "HV00N": "fdw_unable_to_establish_connection", - // Class P0 - PL/pgSQL Error - "P0000": "plpgsql_error", - "P0001": "raise_exception", - "P0002": "no_data_found", - "P0003": "too_many_rows", - // Class XX - Internal Error - "XX000": "internal_error", - "XX001": "data_corrupted", - "XX002": "index_corrupted", -} - -func parseError(r *readBuf) *Error { - err := new(Error) - for t := r.byte(); t != 0; t = r.byte() { - msg := r.string() - switch t { - case 'S': - err.Severity = msg - case 'C': - err.Code = ErrorCode(msg) - case 'M': - err.Message = msg - case 'D': - err.Detail = msg - case 'H': - err.Hint = msg - case 'P': - err.Position = msg - case 'p': - err.InternalPosition = msg - case 'q': - err.InternalQuery = msg - case 'W': - err.Where = msg - case 's': - err.Schema = msg - case 't': - err.Table = msg - case 'c': - err.Column = msg - case 'd': - err.DataTypeName = msg - case 'n': - err.Constraint = msg - case 'F': - err.File = msg - case 'L': - err.Line = msg - case 'R': - err.Routine = msg - } - } - return err -} - -// Fatal returns true if the Error Severity is fatal. -func (err *Error) Fatal() bool { - return err.Severity == Efatal -} - -// SQLState returns the SQLState of the error. -func (err *Error) SQLState() string { - return string(err.Code) -} - -// Get implements the legacy PGError interface. New code should use the fields -// of the Error struct directly. -func (err *Error) Get(k byte) (v string) { - switch k { - case 'S': - return err.Severity - case 'C': - return string(err.Code) - case 'M': - return err.Message - case 'D': - return err.Detail - case 'H': - return err.Hint - case 'P': - return err.Position - case 'p': - return err.InternalPosition - case 'q': - return err.InternalQuery - case 'W': - return err.Where - case 's': - return err.Schema - case 't': - return err.Table - case 'c': - return err.Column - case 'd': - return err.DataTypeName - case 'n': - return err.Constraint - case 'F': - return err.File - case 'L': - return err.Line - case 'R': - return err.Routine - } - return "" -} - -func (err *Error) Error() string { - return "pq: " + err.Message -} - -// PGError is an interface used by previous versions of pq. It is provided -// only to support legacy code. New code should use the Error type. -type PGError interface { - Error() string - Fatal() bool - Get(k byte) (v string) -} - -func errorf(s string, args ...interface{}) { - panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) -} - -// TODO(ainar-g) Rename to errorf after removing panics. 
-func fmterrorf(s string, args ...interface{}) error { - return fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)) -} - -func errRecoverNoErrBadConn(err *error) { - e := recover() - if e == nil { - // Do nothing - return - } - var ok bool - *err, ok = e.(error) - if !ok { - *err = fmt.Errorf("pq: unexpected error: %#v", e) - } -} - -func (cn *conn) errRecover(err *error) { - e := recover() - switch v := e.(type) { - case nil: - // Do nothing - case runtime.Error: - cn.err.set(driver.ErrBadConn) - panic(v) - case *Error: - if v.Fatal() { - *err = driver.ErrBadConn - } else { - *err = v - } - case *net.OpError: - cn.err.set(driver.ErrBadConn) - *err = v - case *safeRetryError: - cn.err.set(driver.ErrBadConn) - *err = driver.ErrBadConn - case error: - if v == io.EOF || v.Error() == "remote error: handshake failure" { - *err = driver.ErrBadConn - } else { - *err = v - } - - default: - cn.err.set(driver.ErrBadConn) - panic(fmt.Sprintf("unknown error: %#v", e)) - } - - // Any time we return ErrBadConn, we need to remember it since *Tx doesn't - // mark the connection bad in database/sql. - if *err == driver.ErrBadConn { - cn.err.set(driver.ErrBadConn) - } -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/hstore/hstore.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/hstore/hstore.go deleted file mode 100644 index f1470db14..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/hstore/hstore.go +++ /dev/null @@ -1,118 +0,0 @@ -package hstore - -import ( - "database/sql" - "database/sql/driver" - "strings" -) - -// Hstore is a wrapper for transferring Hstore values back and forth easily. -type Hstore struct { - Map map[string]sql.NullString -} - -// escapes and quotes hstore keys/values -// s should be a sql.NullString or string -func hQuote(s interface{}) string { - var str string - switch v := s.(type) { - case sql.NullString: - if !v.Valid { - return "NULL" - } - str = v.String - case string: - str = v - default: - panic("not a string or sql.NullString") - } - - str = strings.Replace(str, "\\", "\\\\", -1) - return `"` + strings.Replace(str, "\"", "\\\"", -1) + `"` -} - -// Scan implements the Scanner interface. -// -// Note h.Map is reallocated before the scan to clear existing values. If the -// hstore column's database value is NULL, then h.Map is set to nil instead. 
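[The Scan and Value methods that follow make Hstore usable directly with database/sql; a rough usage sketch, assuming the hstore extension is enabled and with a placeholder table:

    package example

    import (
        "database/sql"
        "fmt"

        _ "github.com/lib/pq" // driver registration
        "github.com/lib/pq/hstore"
    )

    func pageAttrs(db *sql.DB) error {
        var attrs hstore.Hstore
        // Placeholder query; the selected column must be of type hstore.
        if err := db.QueryRow(`SELECT attrs FROM pages WHERE id = 1`).Scan(&attrs); err != nil {
            return err
        }
        for k, v := range attrs.Map {
            if v.Valid {
                fmt.Printf("%s => %s\n", k, v.String)
            }
        }
        return nil
    }
]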
-func (h *Hstore) Scan(value interface{}) error { - if value == nil { - h.Map = nil - return nil - } - h.Map = make(map[string]sql.NullString) - var b byte - pair := [][]byte{{}, {}} - pi := 0 - inQuote := false - didQuote := false - sawSlash := false - bindex := 0 - for bindex, b = range value.([]byte) { - if sawSlash { - pair[pi] = append(pair[pi], b) - sawSlash = false - continue - } - - switch b { - case '\\': - sawSlash = true - continue - case '"': - inQuote = !inQuote - if !didQuote { - didQuote = true - } - continue - default: - if !inQuote { - switch b { - case ' ', '\t', '\n', '\r': - continue - case '=': - continue - case '>': - pi = 1 - didQuote = false - continue - case ',': - s := string(pair[1]) - if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { - h.Map[string(pair[0])] = sql.NullString{String: "", Valid: false} - } else { - h.Map[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} - } - pair[0] = []byte{} - pair[1] = []byte{} - pi = 0 - continue - } - } - } - pair[pi] = append(pair[pi], b) - } - if bindex > 0 { - s := string(pair[1]) - if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { - h.Map[string(pair[0])] = sql.NullString{String: "", Valid: false} - } else { - h.Map[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} - } - } - return nil -} - -// Value implements the driver Valuer interface. Note if h.Map is nil, the -// database column value will be set to NULL. -func (h Hstore) Value() (driver.Value, error) { - if h.Map == nil { - return nil, nil - } - parts := []string{} - for key, val := range h.Map { - thispart := hQuote(key) + "=>" + hQuote(val) - parts = append(parts, thispart) - } - return []byte(strings.Join(parts, ",")), nil -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/krb.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/krb.go deleted file mode 100644 index 408ec01f9..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/krb.go +++ /dev/null @@ -1,27 +0,0 @@ -package pq - -// NewGSSFunc creates a GSS authentication provider, for use with -// RegisterGSSProvider. -type NewGSSFunc func() (GSS, error) - -var newGss NewGSSFunc - -// RegisterGSSProvider registers a GSS authentication provider. For example, if -// you need to use Kerberos to authenticate with your server, add this to your -// main package: -// -// import "github.com/lib/pq/auth/kerberos" -// -// func init() { -// pq.RegisterGSSProvider(func() (pq.GSS, error) { return kerberos.NewGSS() }) -// } -func RegisterGSSProvider(newGssArg NewGSSFunc) { - newGss = newGssArg -} - -// GSS provides GSSAPI authentication (e.g., Kerberos). -type GSS interface { - GetInitToken(host string, service string) ([]byte, error) - GetInitTokenFromSpn(spn string) ([]byte, error) - Continue(inToken []byte) (done bool, outToken []byte, err error) -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/notice.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/notice.go deleted file mode 100644 index 70ad122a7..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/notice.go +++ /dev/null @@ -1,72 +0,0 @@ -//go:build go1.10 -// +build go1.10 - -package pq - -import ( - "context" - "database/sql/driver" -) - -// NoticeHandler returns the notice handler on the given connection, if any. A -// runtime panic occurs if c is not a pq connection. This is rarely used -// directly, use ConnectorNoticeHandler and ConnectorWithNoticeHandler instead. 
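[A short sketch of wiring this notice-handler API into database/sql via the connector path (ConnectorWithNoticeHandler is defined just below); the DSN is a placeholder:

    package main

    import (
        "database/sql"
        "fmt"

        "github.com/lib/pq"
    )

    func main() {
        base, err := pq.NewConnector("dbname=app sslmode=disable") // placeholder DSN
        if err != nil {
            panic(err)
        }
        // Every connection opened through db now logs server notices,
        // e.g. those produced by RAISE NOTICE in PL/pgSQL.
        connector := pq.ConnectorWithNoticeHandler(base, func(n *pq.Error) {
            fmt.Println("notice:", n.Severity, n.Message)
        })
        db := sql.OpenDB(connector)
        defer db.Close()
    }
]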
-func NoticeHandler(c driver.Conn) func(*Error) { - return c.(*conn).noticeHandler -} - -// SetNoticeHandler sets the given notice handler on the given connection. A -// runtime panic occurs if c is not a pq connection. A nil handler may be used -// to unset it. This is rarely used directly, use ConnectorNoticeHandler and -// ConnectorWithNoticeHandler instead. -// -// Note: Notice handlers are executed synchronously by pq meaning commands -// won't continue to be processed until the handler returns. -func SetNoticeHandler(c driver.Conn, handler func(*Error)) { - c.(*conn).noticeHandler = handler -} - -// NoticeHandlerConnector wraps a regular connector and sets a notice handler -// on it. -type NoticeHandlerConnector struct { - driver.Connector - noticeHandler func(*Error) -} - -// Connect calls the underlying connector's connect method and then sets the -// notice handler. -func (n *NoticeHandlerConnector) Connect(ctx context.Context) (driver.Conn, error) { - c, err := n.Connector.Connect(ctx) - if err == nil { - SetNoticeHandler(c, n.noticeHandler) - } - return c, err -} - -// ConnectorNoticeHandler returns the currently set notice handler, if any. If -// the given connector is not a result of ConnectorWithNoticeHandler, nil is -// returned. -func ConnectorNoticeHandler(c driver.Connector) func(*Error) { - if c, ok := c.(*NoticeHandlerConnector); ok { - return c.noticeHandler - } - return nil -} - -// ConnectorWithNoticeHandler creates or sets the given handler for the given -// connector. If the given connector is a result of calling this function -// previously, it is simply set on the given connector and returned. Otherwise, -// this returns a new connector wrapping the given one and setting the notice -// handler. A nil notice handler may be used to unset it. -// -// The returned connector is intended to be used with database/sql.OpenDB. -// -// Note: Notice handlers are executed synchronously by pq meaning commands -// won't continue to be processed until the handler returns. -func ConnectorWithNoticeHandler(c driver.Connector, handler func(*Error)) *NoticeHandlerConnector { - if c, ok := c.(*NoticeHandlerConnector); ok { - c.noticeHandler = handler - return c - } - return &NoticeHandlerConnector{Connector: c, noticeHandler: handler} -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/notify.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/notify.go deleted file mode 100644 index 5c421fdb8..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/notify.go +++ /dev/null @@ -1,858 +0,0 @@ -package pq - -// Package pq is a pure Go Postgres driver for the database/sql package. -// This module contains support for Postgres LISTEN/NOTIFY. - -import ( - "context" - "database/sql/driver" - "errors" - "fmt" - "sync" - "sync/atomic" - "time" -) - -// Notification represents a single notification from the database. -type Notification struct { - // Process ID (PID) of the notifying postgres backend. - BePid int - // Name of the channel the notification was sent on. - Channel string - // Payload, or the empty string if unspecified. - Extra string -} - -func recvNotification(r *readBuf) *Notification { - bePid := r.int32() - channel := r.string() - extra := r.string() - - return &Notification{bePid, channel, extra} -} - -// SetNotificationHandler sets the given notification handler on the given -// connection. A runtime panic occurs if c is not a pq connection. A nil handler -// may be used to unset it. 
-// -// Note: Notification handlers are executed synchronously by pq meaning commands -// won't continue to be processed until the handler returns. -func SetNotificationHandler(c driver.Conn, handler func(*Notification)) { - c.(*conn).notificationHandler = handler -} - -// NotificationHandlerConnector wraps a regular connector and sets a notification handler -// on it. -type NotificationHandlerConnector struct { - driver.Connector - notificationHandler func(*Notification) -} - -// Connect calls the underlying connector's connect method and then sets the -// notification handler. -func (n *NotificationHandlerConnector) Connect(ctx context.Context) (driver.Conn, error) { - c, err := n.Connector.Connect(ctx) - if err == nil { - SetNotificationHandler(c, n.notificationHandler) - } - return c, err -} - -// ConnectorNotificationHandler returns the currently set notification handler, if any. If -// the given connector is not a result of ConnectorWithNotificationHandler, nil is -// returned. -func ConnectorNotificationHandler(c driver.Connector) func(*Notification) { - if c, ok := c.(*NotificationHandlerConnector); ok { - return c.notificationHandler - } - return nil -} - -// ConnectorWithNotificationHandler creates or sets the given handler for the given -// connector. If the given connector is a result of calling this function -// previously, it is simply set on the given connector and returned. Otherwise, -// this returns a new connector wrapping the given one and setting the notification -// handler. A nil notification handler may be used to unset it. -// -// The returned connector is intended to be used with database/sql.OpenDB. -// -// Note: Notification handlers are executed synchronously by pq meaning commands -// won't continue to be processed until the handler returns. -func ConnectorWithNotificationHandler(c driver.Connector, handler func(*Notification)) *NotificationHandlerConnector { - if c, ok := c.(*NotificationHandlerConnector); ok { - c.notificationHandler = handler - return c - } - return &NotificationHandlerConnector{Connector: c, notificationHandler: handler} -} - -const ( - connStateIdle int32 = iota - connStateExpectResponse - connStateExpectReadyForQuery -) - -type message struct { - typ byte - err error -} - -var errListenerConnClosed = errors.New("pq: ListenerConn has been closed") - -// ListenerConn is a low-level interface for waiting for notifications. You -// should use Listener instead. -type ListenerConn struct { - // guards cn and err - connectionLock sync.Mutex - cn *conn - err error - - connState int32 - - // the sending goroutine will be holding this lock - senderLock sync.Mutex - - notificationChan chan<- *Notification - - replyChan chan message -} - -// NewListenerConn creates a new ListenerConn. Use NewListener instead. -func NewListenerConn(name string, notificationChan chan<- *Notification) (*ListenerConn, error) { - return newDialListenerConn(defaultDialer{}, name, notificationChan) -} - -func newDialListenerConn(d Dialer, name string, c chan<- *Notification) (*ListenerConn, error) { - cn, err := DialOpen(d, name) - if err != nil { - return nil, err - } - - l := &ListenerConn{ - cn: cn.(*conn), - notificationChan: c, - connState: connStateIdle, - replyChan: make(chan message, 2), - } - - go l.listenerConnMain() - - return l, nil -} - -// We can only allow one goroutine at a time to be running a query on the -// connection for various reasons, so the goroutine sending on the connection -// must be holding senderLock. 
-// -// Returns an error if an unrecoverable error has occurred and the ListenerConn -// should be abandoned. -func (l *ListenerConn) acquireSenderLock() error { - // we must acquire senderLock first to avoid deadlocks; see ExecSimpleQuery - l.senderLock.Lock() - - l.connectionLock.Lock() - err := l.err - l.connectionLock.Unlock() - if err != nil { - l.senderLock.Unlock() - return err - } - return nil -} - -func (l *ListenerConn) releaseSenderLock() { - l.senderLock.Unlock() -} - -// setState advances the protocol state to newState. Returns false if moving -// to that state from the current state is not allowed. -func (l *ListenerConn) setState(newState int32) bool { - var expectedState int32 - - switch newState { - case connStateIdle: - expectedState = connStateExpectReadyForQuery - case connStateExpectResponse: - expectedState = connStateIdle - case connStateExpectReadyForQuery: - expectedState = connStateExpectResponse - default: - panic(fmt.Sprintf("unexpected listenerConnState %d", newState)) - } - - return atomic.CompareAndSwapInt32(&l.connState, expectedState, newState) -} - -// Main logic is here: receive messages from the postgres backend, forward -// notifications and query replies and keep the internal state in sync with the -// protocol state. Returns when the connection has been lost, is about to go -// away or should be discarded because we couldn't agree on the state with the -// server backend. -func (l *ListenerConn) listenerConnLoop() (err error) { - defer errRecoverNoErrBadConn(&err) - - r := &readBuf{} - for { - t, err := l.cn.recvMessage(r) - if err != nil { - return err - } - - switch t { - case 'A': - // recvNotification copies all the data so we don't need to worry - // about the scratch buffer being overwritten. - l.notificationChan <- recvNotification(r) - - case 'T', 'D': - // only used by tests; ignore - - case 'E': - // We might receive an ErrorResponse even when not in a query; it - // is expected that the server will close the connection after - // that, but we should make sure that the error we display is the - // one from the stray ErrorResponse, not io.ErrUnexpectedEOF. - if !l.setState(connStateExpectReadyForQuery) { - return parseError(r) - } - l.replyChan <- message{t, parseError(r)} - - case 'C', 'I': - if !l.setState(connStateExpectReadyForQuery) { - // protocol out of sync - return fmt.Errorf("unexpected CommandComplete") - } - // ExecSimpleQuery doesn't need to know about this message - - case 'Z': - if !l.setState(connStateIdle) { - // protocol out of sync - return fmt.Errorf("unexpected ReadyForQuery") - } - l.replyChan <- message{t, nil} - - case 'S': - // ignore - case 'N': - if n := l.cn.noticeHandler; n != nil { - n(parseError(r)) - } - default: - return fmt.Errorf("unexpected message %q from server in listenerConnLoop", t) - } - } -} - -// This is the main routine for the goroutine receiving on the database -// connection. Most of the main logic is in listenerConnLoop. -func (l *ListenerConn) listenerConnMain() { - err := l.listenerConnLoop() - - // listenerConnLoop terminated; we're done, but we still have to clean up. - // Make sure nobody tries to start any new queries by making sure the err - // pointer is set. 
It is important that we do not overwrite its value; a - // connection could be closed by either this goroutine or one sending on - // the connection -- whoever closes the connection is assumed to have the - // more meaningful error message (as the other one will probably get - // net.errClosed), so that goroutine sets the error we expose while the - // other error is discarded. If the connection is lost while two - // goroutines are operating on the socket, it probably doesn't matter which - // error we expose so we don't try to do anything more complex. - l.connectionLock.Lock() - if l.err == nil { - l.err = err - } - l.cn.Close() - l.connectionLock.Unlock() - - // There might be a query in-flight; make sure nobody's waiting for a - // response to it, since there's not going to be one. - close(l.replyChan) - - // let the listener know we're done - close(l.notificationChan) - - // this ListenerConn is done -} - -// Listen sends a LISTEN query to the server. See ExecSimpleQuery. -func (l *ListenerConn) Listen(channel string) (bool, error) { - return l.ExecSimpleQuery("LISTEN " + QuoteIdentifier(channel)) -} - -// Unlisten sends an UNLISTEN query to the server. See ExecSimpleQuery. -func (l *ListenerConn) Unlisten(channel string) (bool, error) { - return l.ExecSimpleQuery("UNLISTEN " + QuoteIdentifier(channel)) -} - -// UnlistenAll sends an `UNLISTEN *` query to the server. See ExecSimpleQuery. -func (l *ListenerConn) UnlistenAll() (bool, error) { - return l.ExecSimpleQuery("UNLISTEN *") -} - -// Ping the remote server to make sure it's alive. Non-nil error means the -// connection has failed and should be abandoned. -func (l *ListenerConn) Ping() error { - sent, err := l.ExecSimpleQuery("") - if !sent { - return err - } - if err != nil { - // shouldn't happen - panic(err) - } - return nil -} - -// Attempt to send a query on the connection. Returns an error if sending the -// query failed, and the caller should initiate closure of this connection. -// The caller must be holding senderLock (see acquireSenderLock and -// releaseSenderLock). -func (l *ListenerConn) sendSimpleQuery(q string) (err error) { - defer errRecoverNoErrBadConn(&err) - - // must set connection state before sending the query - if !l.setState(connStateExpectResponse) { - panic("two queries running at the same time") - } - - // Can't use l.cn.writeBuf here because it uses the scratch buffer which - // might get overwritten by listenerConnLoop. - b := &writeBuf{ - buf: []byte("Q\x00\x00\x00\x00"), - pos: 1, - } - b.string(q) - l.cn.send(b) - - return nil -} - -// ExecSimpleQuery executes a "simple query" (i.e. one with no bindable -// parameters) on the connection. The possible return values are: -// 1) "executed" is true; the query was executed to completion on the -// database server. If the query failed, err will be set to the error -// returned by the database, otherwise err will be nil. -// 2) If "executed" is false, the query could not be executed on the remote -// server. err will be non-nil. -// -// After a call to ExecSimpleQuery has returned an executed=false value, the -// connection has either been closed or will be closed shortly thereafter, and -// all subsequently executed queries will return an error. 
-func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) { - if err = l.acquireSenderLock(); err != nil { - return false, err - } - defer l.releaseSenderLock() - - err = l.sendSimpleQuery(q) - if err != nil { - // We can't know what state the protocol is in, so we need to abandon - // this connection. - l.connectionLock.Lock() - // Set the error pointer if it hasn't been set already; see - // listenerConnMain. - if l.err == nil { - l.err = err - } - l.connectionLock.Unlock() - l.cn.c.Close() - return false, err - } - - // now we just wait for a reply.. - for { - m, ok := <-l.replyChan - if !ok { - // We lost the connection to server, don't bother waiting for a - // a response. err should have been set already. - l.connectionLock.Lock() - err := l.err - l.connectionLock.Unlock() - return false, err - } - switch m.typ { - case 'Z': - // sanity check - if m.err != nil { - panic("m.err != nil") - } - // done; err might or might not be set - return true, err - - case 'E': - // sanity check - if m.err == nil { - panic("m.err == nil") - } - // server responded with an error; ReadyForQuery to follow - err = m.err - - default: - return false, fmt.Errorf("unknown response for simple query: %q", m.typ) - } - } -} - -// Close closes the connection. -func (l *ListenerConn) Close() error { - l.connectionLock.Lock() - if l.err != nil { - l.connectionLock.Unlock() - return errListenerConnClosed - } - l.err = errListenerConnClosed - l.connectionLock.Unlock() - // We can't send anything on the connection without holding senderLock. - // Simply close the net.Conn to wake up everyone operating on it. - return l.cn.c.Close() -} - -// Err returns the reason the connection was closed. It is not safe to call -// this function until l.Notify has been closed. -func (l *ListenerConn) Err() error { - return l.err -} - -var errListenerClosed = errors.New("pq: Listener has been closed") - -// ErrChannelAlreadyOpen is returned from Listen when a channel is already -// open. -var ErrChannelAlreadyOpen = errors.New("pq: channel is already open") - -// ErrChannelNotOpen is returned from Unlisten when a channel is not open. -var ErrChannelNotOpen = errors.New("pq: channel is not open") - -// ListenerEventType is an enumeration of listener event types. -type ListenerEventType int - -const ( - // ListenerEventConnected is emitted only when the database connection - // has been initially initialized. The err argument of the callback - // will always be nil. - ListenerEventConnected ListenerEventType = iota - - // ListenerEventDisconnected is emitted after a database connection has - // been lost, either because of an error or because Close has been - // called. The err argument will be set to the reason the database - // connection was lost. - ListenerEventDisconnected - - // ListenerEventReconnected is emitted after a database connection has - // been re-established after connection loss. The err argument of the - // callback will always be nil. After this event has been emitted, a - // nil pq.Notification is sent on the Listener.Notify channel. - ListenerEventReconnected - - // ListenerEventConnectionAttemptFailed is emitted after a connection - // to the database was attempted, but failed. The err argument will be - // set to an error describing why the connection attempt did not - // succeed. - ListenerEventConnectionAttemptFailed -) - -// EventCallbackType is the event callback type. See also ListenerEventType -// constants' documentation. 
-type EventCallbackType func(event ListenerEventType, err error) - -// Listener provides an interface for listening to notifications from a -// PostgreSQL database. For general usage information, see section -// "Notifications". -// -// Listener can safely be used from concurrently running goroutines. -type Listener struct { - // Channel for receiving notifications from the database. In some cases a - // nil value will be sent. See section "Notifications" above. - Notify chan *Notification - - name string - minReconnectInterval time.Duration - maxReconnectInterval time.Duration - dialer Dialer - eventCallback EventCallbackType - - lock sync.Mutex - isClosed bool - reconnectCond *sync.Cond - cn *ListenerConn - connNotificationChan <-chan *Notification - channels map[string]struct{} -} - -// NewListener creates a new database connection dedicated to LISTEN / NOTIFY. -// -// name should be set to a connection string to be used to establish the -// database connection (see section "Connection String Parameters" above). -// -// minReconnectInterval controls the duration to wait before trying to -// re-establish the database connection after connection loss. After each -// consecutive failure this interval is doubled, until maxReconnectInterval is -// reached. Successfully completing the connection establishment procedure -// resets the interval back to minReconnectInterval. -// -// The last parameter eventCallback can be set to a function which will be -// called by the Listener when the state of the underlying database connection -// changes. This callback will be called by the goroutine which dispatches the -// notifications over the Notify channel, so you should try to avoid doing -// potentially time-consuming operations from the callback. -func NewListener(name string, - minReconnectInterval time.Duration, - maxReconnectInterval time.Duration, - eventCallback EventCallbackType) *Listener { - return NewDialListener(defaultDialer{}, name, minReconnectInterval, maxReconnectInterval, eventCallback) -} - -// NewDialListener is like NewListener but it takes a Dialer. -func NewDialListener(d Dialer, - name string, - minReconnectInterval time.Duration, - maxReconnectInterval time.Duration, - eventCallback EventCallbackType) *Listener { - - l := &Listener{ - name: name, - minReconnectInterval: minReconnectInterval, - maxReconnectInterval: maxReconnectInterval, - dialer: d, - eventCallback: eventCallback, - - channels: make(map[string]struct{}), - - Notify: make(chan *Notification, 32), - } - l.reconnectCond = sync.NewCond(&l.lock) - - go l.listenerMain() - - return l -} - -// NotificationChannel returns the notification channel for this listener. -// This is the same channel as Notify, and will not be recreated during the -// life time of the Listener. -func (l *Listener) NotificationChannel() <-chan *Notification { - return l.Notify -} - -// Listen starts listening for notifications on a channel. Calls to this -// function will block until an acknowledgement has been received from the -// server. Note that Listener automatically re-establishes the connection -// after connection loss, so this function may block indefinitely if the -// connection can not be re-established. -// -// Listen will only fail in three conditions: -// 1) The channel is already open. The returned error will be -// ErrChannelAlreadyOpen. -// 2) The query was executed on the remote server, but PostgreSQL returned an -// error message in response to the query. 
The returned error will be a -// pq.Error containing the information the server supplied. -// 3) Close is called on the Listener before the request could be completed. -// -// The channel name is case-sensitive. -func (l *Listener) Listen(channel string) error { - l.lock.Lock() - defer l.lock.Unlock() - - if l.isClosed { - return errListenerClosed - } - - // The server allows you to issue a LISTEN on a channel which is already - // open, but it seems useful to be able to detect this case to spot for - // mistakes in application logic. If the application genuinely does't - // care, it can check the exported error and ignore it. - _, exists := l.channels[channel] - if exists { - return ErrChannelAlreadyOpen - } - - if l.cn != nil { - // If gotResponse is true but error is set, the query was executed on - // the remote server, but resulted in an error. This should be - // relatively rare, so it's fine if we just pass the error to our - // caller. However, if gotResponse is false, we could not complete the - // query on the remote server and our underlying connection is about - // to go away, so we only add relname to l.channels, and wait for - // resync() to take care of the rest. - gotResponse, err := l.cn.Listen(channel) - if gotResponse && err != nil { - return err - } - } - - l.channels[channel] = struct{}{} - for l.cn == nil { - l.reconnectCond.Wait() - // we let go of the mutex for a while - if l.isClosed { - return errListenerClosed - } - } - - return nil -} - -// Unlisten removes a channel from the Listener's channel list. Returns -// ErrChannelNotOpen if the Listener is not listening on the specified channel. -// Returns immediately with no error if there is no connection. Note that you -// might still get notifications for this channel even after Unlisten has -// returned. -// -// The channel name is case-sensitive. -func (l *Listener) Unlisten(channel string) error { - l.lock.Lock() - defer l.lock.Unlock() - - if l.isClosed { - return errListenerClosed - } - - // Similarly to LISTEN, this is not an error in Postgres, but it seems - // useful to distinguish from the normal conditions. - _, exists := l.channels[channel] - if !exists { - return ErrChannelNotOpen - } - - if l.cn != nil { - // Similarly to Listen (see comment in that function), the caller - // should only be bothered with an error if it came from the backend as - // a response to our query. - gotResponse, err := l.cn.Unlisten(channel) - if gotResponse && err != nil { - return err - } - } - - // Don't bother waiting for resync if there's no connection. - delete(l.channels, channel) - return nil -} - -// UnlistenAll removes all channels from the Listener's channel list. Returns -// immediately with no error if there is no connection. Note that you might -// still get notifications for any of the deleted channels even after -// UnlistenAll has returned. -func (l *Listener) UnlistenAll() error { - l.lock.Lock() - defer l.lock.Unlock() - - if l.isClosed { - return errListenerClosed - } - - if l.cn != nil { - // Similarly to Listen (see comment in that function), the caller - // should only be bothered with an error if it came from the backend as - // a response to our query. - gotResponse, err := l.cn.UnlistenAll() - if gotResponse && err != nil { - return err - } - } - - // Don't bother waiting for resync if there's no connection. - l.channels = make(map[string]struct{}) - return nil -} - -// Ping the remote server to make sure it's alive. Non-nil return value means -// that there is no active connection. 
-func (l *Listener) Ping() error { - l.lock.Lock() - defer l.lock.Unlock() - - if l.isClosed { - return errListenerClosed - } - if l.cn == nil { - return errors.New("no connection") - } - - return l.cn.Ping() -} - -// Clean up after losing the server connection. Returns l.cn.Err(), which -// should have the reason the connection was lost. -func (l *Listener) disconnectCleanup() error { - l.lock.Lock() - defer l.lock.Unlock() - - // sanity check; can't look at Err() until the channel has been closed - select { - case _, ok := <-l.connNotificationChan: - if ok { - panic("connNotificationChan not closed") - } - default: - panic("connNotificationChan not closed") - } - - err := l.cn.Err() - l.cn.Close() - l.cn = nil - return err -} - -// Synchronize the list of channels we want to be listening on with the server -// after the connection has been established. -func (l *Listener) resync(cn *ListenerConn, notificationChan <-chan *Notification) error { - doneChan := make(chan error) - go func(notificationChan <-chan *Notification) { - for channel := range l.channels { - // If we got a response, return that error to our caller as it's - // going to be more descriptive than cn.Err(). - gotResponse, err := cn.Listen(channel) - if gotResponse && err != nil { - doneChan <- err - return - } - - // If we couldn't reach the server, wait for notificationChan to - // close and then return the error message from the connection, as - // per ListenerConn's interface. - if err != nil { - for range notificationChan { - } - doneChan <- cn.Err() - return - } - } - doneChan <- nil - }(notificationChan) - - // Ignore notifications while synchronization is going on to avoid - // deadlocks. We have to send a nil notification over Notify anyway as - // we can't possibly know which notifications (if any) were lost while - // the connection was down, so there's no reason to try and process - // these messages at all. - for { - select { - case _, ok := <-notificationChan: - if !ok { - notificationChan = nil - } - - case err := <-doneChan: - return err - } - } -} - -// caller should NOT be holding l.lock -func (l *Listener) closed() bool { - l.lock.Lock() - defer l.lock.Unlock() - - return l.isClosed -} - -func (l *Listener) connect() error { - notificationChan := make(chan *Notification, 32) - cn, err := newDialListenerConn(l.dialer, l.name, notificationChan) - if err != nil { - return err - } - - l.lock.Lock() - defer l.lock.Unlock() - - err = l.resync(cn, notificationChan) - if err != nil { - cn.Close() - return err - } - - l.cn = cn - l.connNotificationChan = notificationChan - l.reconnectCond.Broadcast() - - return nil -} - -// Close disconnects the Listener from the database and shuts it down. -// Subsequent calls to its methods will return an error. Close returns an -// error if the connection has already been closed. -func (l *Listener) Close() error { - l.lock.Lock() - defer l.lock.Unlock() - - if l.isClosed { - return errListenerClosed - } - - if l.cn != nil { - l.cn.Close() - } - l.isClosed = true - - // Unblock calls to Listen() - l.reconnectCond.Broadcast() - - return nil -} - -func (l *Listener) emitEvent(event ListenerEventType, err error) { - if l.eventCallback != nil { - l.eventCallback(event, err) - } -} - -// Main logic here: maintain a connection to the server when possible, wait -// for notifications and emit events. 
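[Putting the exported pieces of this file together: the canonical consumer is a Listener draining its Notify channel, with the reconnect loop below keeping the connection alive. A minimal sketch; the DSN and channel name are placeholders:

    package main

    import (
        "fmt"
        "time"

        "github.com/lib/pq"
    )

    func main() {
        onEvent := func(ev pq.ListenerEventType, err error) {
            if err != nil {
                fmt.Println("listener event:", ev, err)
            }
        }

        listener := pq.NewListener("dbname=app sslmode=disable", // placeholder DSN
            10*time.Second, time.Minute, onEvent)
        if err := listener.Listen("jobs"); err != nil { // placeholder channel
            panic(err)
        }

        for n := range listener.Notify {
            if n == nil {
                // A nil notification is sent after a reconnect:
                // notifications may have been missed while disconnected.
                continue
            }
            fmt.Printf("NOTIFY %s: %s\n", n.Channel, n.Extra)
        }
    }
]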
-func (l *Listener) listenerConnLoop() { - var nextReconnect time.Time - - reconnectInterval := l.minReconnectInterval - for { - for { - err := l.connect() - if err == nil { - break - } - - if l.closed() { - return - } - l.emitEvent(ListenerEventConnectionAttemptFailed, err) - - time.Sleep(reconnectInterval) - reconnectInterval *= 2 - if reconnectInterval > l.maxReconnectInterval { - reconnectInterval = l.maxReconnectInterval - } - } - - if nextReconnect.IsZero() { - l.emitEvent(ListenerEventConnected, nil) - } else { - l.emitEvent(ListenerEventReconnected, nil) - l.Notify <- nil - } - - reconnectInterval = l.minReconnectInterval - nextReconnect = time.Now().Add(reconnectInterval) - - for { - notification, ok := <-l.connNotificationChan - if !ok { - // lost connection, loop again - break - } - l.Notify <- notification - } - - err := l.disconnectCleanup() - if l.closed() { - return - } - l.emitEvent(ListenerEventDisconnected, err) - - time.Sleep(time.Until(nextReconnect)) - } -} - -func (l *Listener) listenerMain() { - l.listenerConnLoop() - close(l.Notify) -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/oid/doc.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/oid/doc.go deleted file mode 100644 index caaede248..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/oid/doc.go +++ /dev/null @@ -1,6 +0,0 @@ -// Package oid contains OID constants -// as defined by the Postgres server. -package oid - -// Oid is a Postgres Object ID. -type Oid uint32 diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/oid/types.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/oid/types.go deleted file mode 100644 index ecc84c2c8..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/oid/types.go +++ /dev/null @@ -1,343 +0,0 @@ -// Code generated by gen.go. DO NOT EDIT. 
- -package oid - -const ( - T_bool Oid = 16 - T_bytea Oid = 17 - T_char Oid = 18 - T_name Oid = 19 - T_int8 Oid = 20 - T_int2 Oid = 21 - T_int2vector Oid = 22 - T_int4 Oid = 23 - T_regproc Oid = 24 - T_text Oid = 25 - T_oid Oid = 26 - T_tid Oid = 27 - T_xid Oid = 28 - T_cid Oid = 29 - T_oidvector Oid = 30 - T_pg_ddl_command Oid = 32 - T_pg_type Oid = 71 - T_pg_attribute Oid = 75 - T_pg_proc Oid = 81 - T_pg_class Oid = 83 - T_json Oid = 114 - T_xml Oid = 142 - T__xml Oid = 143 - T_pg_node_tree Oid = 194 - T__json Oid = 199 - T_smgr Oid = 210 - T_index_am_handler Oid = 325 - T_point Oid = 600 - T_lseg Oid = 601 - T_path Oid = 602 - T_box Oid = 603 - T_polygon Oid = 604 - T_line Oid = 628 - T__line Oid = 629 - T_cidr Oid = 650 - T__cidr Oid = 651 - T_float4 Oid = 700 - T_float8 Oid = 701 - T_abstime Oid = 702 - T_reltime Oid = 703 - T_tinterval Oid = 704 - T_unknown Oid = 705 - T_circle Oid = 718 - T__circle Oid = 719 - T_money Oid = 790 - T__money Oid = 791 - T_macaddr Oid = 829 - T_inet Oid = 869 - T__bool Oid = 1000 - T__bytea Oid = 1001 - T__char Oid = 1002 - T__name Oid = 1003 - T__int2 Oid = 1005 - T__int2vector Oid = 1006 - T__int4 Oid = 1007 - T__regproc Oid = 1008 - T__text Oid = 1009 - T__tid Oid = 1010 - T__xid Oid = 1011 - T__cid Oid = 1012 - T__oidvector Oid = 1013 - T__bpchar Oid = 1014 - T__varchar Oid = 1015 - T__int8 Oid = 1016 - T__point Oid = 1017 - T__lseg Oid = 1018 - T__path Oid = 1019 - T__box Oid = 1020 - T__float4 Oid = 1021 - T__float8 Oid = 1022 - T__abstime Oid = 1023 - T__reltime Oid = 1024 - T__tinterval Oid = 1025 - T__polygon Oid = 1027 - T__oid Oid = 1028 - T_aclitem Oid = 1033 - T__aclitem Oid = 1034 - T__macaddr Oid = 1040 - T__inet Oid = 1041 - T_bpchar Oid = 1042 - T_varchar Oid = 1043 - T_date Oid = 1082 - T_time Oid = 1083 - T_timestamp Oid = 1114 - T__timestamp Oid = 1115 - T__date Oid = 1182 - T__time Oid = 1183 - T_timestamptz Oid = 1184 - T__timestamptz Oid = 1185 - T_interval Oid = 1186 - T__interval Oid = 1187 - T__numeric Oid = 1231 - T_pg_database Oid = 1248 - T__cstring Oid = 1263 - T_timetz Oid = 1266 - T__timetz Oid = 1270 - T_bit Oid = 1560 - T__bit Oid = 1561 - T_varbit Oid = 1562 - T__varbit Oid = 1563 - T_numeric Oid = 1700 - T_refcursor Oid = 1790 - T__refcursor Oid = 2201 - T_regprocedure Oid = 2202 - T_regoper Oid = 2203 - T_regoperator Oid = 2204 - T_regclass Oid = 2205 - T_regtype Oid = 2206 - T__regprocedure Oid = 2207 - T__regoper Oid = 2208 - T__regoperator Oid = 2209 - T__regclass Oid = 2210 - T__regtype Oid = 2211 - T_record Oid = 2249 - T_cstring Oid = 2275 - T_any Oid = 2276 - T_anyarray Oid = 2277 - T_void Oid = 2278 - T_trigger Oid = 2279 - T_language_handler Oid = 2280 - T_internal Oid = 2281 - T_opaque Oid = 2282 - T_anyelement Oid = 2283 - T__record Oid = 2287 - T_anynonarray Oid = 2776 - T_pg_authid Oid = 2842 - T_pg_auth_members Oid = 2843 - T__txid_snapshot Oid = 2949 - T_uuid Oid = 2950 - T__uuid Oid = 2951 - T_txid_snapshot Oid = 2970 - T_fdw_handler Oid = 3115 - T_pg_lsn Oid = 3220 - T__pg_lsn Oid = 3221 - T_tsm_handler Oid = 3310 - T_anyenum Oid = 3500 - T_tsvector Oid = 3614 - T_tsquery Oid = 3615 - T_gtsvector Oid = 3642 - T__tsvector Oid = 3643 - T__gtsvector Oid = 3644 - T__tsquery Oid = 3645 - T_regconfig Oid = 3734 - T__regconfig Oid = 3735 - T_regdictionary Oid = 3769 - T__regdictionary Oid = 3770 - T_jsonb Oid = 3802 - T__jsonb Oid = 3807 - T_anyrange Oid = 3831 - T_event_trigger Oid = 3838 - T_int4range Oid = 3904 - T__int4range Oid = 3905 - T_numrange Oid = 3906 - T__numrange Oid = 3907 - T_tsrange Oid = 
3908 - T__tsrange Oid = 3909 - T_tstzrange Oid = 3910 - T__tstzrange Oid = 3911 - T_daterange Oid = 3912 - T__daterange Oid = 3913 - T_int8range Oid = 3926 - T__int8range Oid = 3927 - T_pg_shseclabel Oid = 4066 - T_regnamespace Oid = 4089 - T__regnamespace Oid = 4090 - T_regrole Oid = 4096 - T__regrole Oid = 4097 -) - -var TypeName = map[Oid]string{ - T_bool: "BOOL", - T_bytea: "BYTEA", - T_char: "CHAR", - T_name: "NAME", - T_int8: "INT8", - T_int2: "INT2", - T_int2vector: "INT2VECTOR", - T_int4: "INT4", - T_regproc: "REGPROC", - T_text: "TEXT", - T_oid: "OID", - T_tid: "TID", - T_xid: "XID", - T_cid: "CID", - T_oidvector: "OIDVECTOR", - T_pg_ddl_command: "PG_DDL_COMMAND", - T_pg_type: "PG_TYPE", - T_pg_attribute: "PG_ATTRIBUTE", - T_pg_proc: "PG_PROC", - T_pg_class: "PG_CLASS", - T_json: "JSON", - T_xml: "XML", - T__xml: "_XML", - T_pg_node_tree: "PG_NODE_TREE", - T__json: "_JSON", - T_smgr: "SMGR", - T_index_am_handler: "INDEX_AM_HANDLER", - T_point: "POINT", - T_lseg: "LSEG", - T_path: "PATH", - T_box: "BOX", - T_polygon: "POLYGON", - T_line: "LINE", - T__line: "_LINE", - T_cidr: "CIDR", - T__cidr: "_CIDR", - T_float4: "FLOAT4", - T_float8: "FLOAT8", - T_abstime: "ABSTIME", - T_reltime: "RELTIME", - T_tinterval: "TINTERVAL", - T_unknown: "UNKNOWN", - T_circle: "CIRCLE", - T__circle: "_CIRCLE", - T_money: "MONEY", - T__money: "_MONEY", - T_macaddr: "MACADDR", - T_inet: "INET", - T__bool: "_BOOL", - T__bytea: "_BYTEA", - T__char: "_CHAR", - T__name: "_NAME", - T__int2: "_INT2", - T__int2vector: "_INT2VECTOR", - T__int4: "_INT4", - T__regproc: "_REGPROC", - T__text: "_TEXT", - T__tid: "_TID", - T__xid: "_XID", - T__cid: "_CID", - T__oidvector: "_OIDVECTOR", - T__bpchar: "_BPCHAR", - T__varchar: "_VARCHAR", - T__int8: "_INT8", - T__point: "_POINT", - T__lseg: "_LSEG", - T__path: "_PATH", - T__box: "_BOX", - T__float4: "_FLOAT4", - T__float8: "_FLOAT8", - T__abstime: "_ABSTIME", - T__reltime: "_RELTIME", - T__tinterval: "_TINTERVAL", - T__polygon: "_POLYGON", - T__oid: "_OID", - T_aclitem: "ACLITEM", - T__aclitem: "_ACLITEM", - T__macaddr: "_MACADDR", - T__inet: "_INET", - T_bpchar: "BPCHAR", - T_varchar: "VARCHAR", - T_date: "DATE", - T_time: "TIME", - T_timestamp: "TIMESTAMP", - T__timestamp: "_TIMESTAMP", - T__date: "_DATE", - T__time: "_TIME", - T_timestamptz: "TIMESTAMPTZ", - T__timestamptz: "_TIMESTAMPTZ", - T_interval: "INTERVAL", - T__interval: "_INTERVAL", - T__numeric: "_NUMERIC", - T_pg_database: "PG_DATABASE", - T__cstring: "_CSTRING", - T_timetz: "TIMETZ", - T__timetz: "_TIMETZ", - T_bit: "BIT", - T__bit: "_BIT", - T_varbit: "VARBIT", - T__varbit: "_VARBIT", - T_numeric: "NUMERIC", - T_refcursor: "REFCURSOR", - T__refcursor: "_REFCURSOR", - T_regprocedure: "REGPROCEDURE", - T_regoper: "REGOPER", - T_regoperator: "REGOPERATOR", - T_regclass: "REGCLASS", - T_regtype: "REGTYPE", - T__regprocedure: "_REGPROCEDURE", - T__regoper: "_REGOPER", - T__regoperator: "_REGOPERATOR", - T__regclass: "_REGCLASS", - T__regtype: "_REGTYPE", - T_record: "RECORD", - T_cstring: "CSTRING", - T_any: "ANY", - T_anyarray: "ANYARRAY", - T_void: "VOID", - T_trigger: "TRIGGER", - T_language_handler: "LANGUAGE_HANDLER", - T_internal: "INTERNAL", - T_opaque: "OPAQUE", - T_anyelement: "ANYELEMENT", - T__record: "_RECORD", - T_anynonarray: "ANYNONARRAY", - T_pg_authid: "PG_AUTHID", - T_pg_auth_members: "PG_AUTH_MEMBERS", - T__txid_snapshot: "_TXID_SNAPSHOT", - T_uuid: "UUID", - T__uuid: "_UUID", - T_txid_snapshot: "TXID_SNAPSHOT", - T_fdw_handler: "FDW_HANDLER", - T_pg_lsn: "PG_LSN", - T__pg_lsn: "_PG_LSN", - 
T_tsm_handler: "TSM_HANDLER", - T_anyenum: "ANYENUM", - T_tsvector: "TSVECTOR", - T_tsquery: "TSQUERY", - T_gtsvector: "GTSVECTOR", - T__tsvector: "_TSVECTOR", - T__gtsvector: "_GTSVECTOR", - T__tsquery: "_TSQUERY", - T_regconfig: "REGCONFIG", - T__regconfig: "_REGCONFIG", - T_regdictionary: "REGDICTIONARY", - T__regdictionary: "_REGDICTIONARY", - T_jsonb: "JSONB", - T__jsonb: "_JSONB", - T_anyrange: "ANYRANGE", - T_event_trigger: "EVENT_TRIGGER", - T_int4range: "INT4RANGE", - T__int4range: "_INT4RANGE", - T_numrange: "NUMRANGE", - T__numrange: "_NUMRANGE", - T_tsrange: "TSRANGE", - T__tsrange: "_TSRANGE", - T_tstzrange: "TSTZRANGE", - T__tstzrange: "_TSTZRANGE", - T_daterange: "DATERANGE", - T__daterange: "_DATERANGE", - T_int8range: "INT8RANGE", - T__int8range: "_INT8RANGE", - T_pg_shseclabel: "PG_SHSECLABEL", - T_regnamespace: "REGNAMESPACE", - T__regnamespace: "_REGNAMESPACE", - T_regrole: "REGROLE", - T__regrole: "_REGROLE", -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/rows.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/rows.go deleted file mode 100644 index c6aa5b9a3..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/rows.go +++ /dev/null @@ -1,93 +0,0 @@ -package pq - -import ( - "math" - "reflect" - "time" - - "github.com/lib/pq/oid" -) - -const headerSize = 4 - -type fieldDesc struct { - // The object ID of the data type. - OID oid.Oid - // The data type size (see pg_type.typlen). - // Note that negative values denote variable-width types. - Len int - // The type modifier (see pg_attribute.atttypmod). - // The meaning of the modifier is type-specific. - Mod int -} - -func (fd fieldDesc) Type() reflect.Type { - switch fd.OID { - case oid.T_int8: - return reflect.TypeOf(int64(0)) - case oid.T_int4: - return reflect.TypeOf(int32(0)) - case oid.T_int2: - return reflect.TypeOf(int16(0)) - case oid.T_varchar, oid.T_text: - return reflect.TypeOf("") - case oid.T_bool: - return reflect.TypeOf(false) - case oid.T_date, oid.T_time, oid.T_timetz, oid.T_timestamp, oid.T_timestamptz: - return reflect.TypeOf(time.Time{}) - case oid.T_bytea: - return reflect.TypeOf([]byte(nil)) - default: - return reflect.TypeOf(new(interface{})).Elem() - } -} - -func (fd fieldDesc) Name() string { - return oid.TypeName[fd.OID] -} - -func (fd fieldDesc) Length() (length int64, ok bool) { - switch fd.OID { - case oid.T_text, oid.T_bytea: - return math.MaxInt64, true - case oid.T_varchar, oid.T_bpchar: - return int64(fd.Mod - headerSize), true - default: - return 0, false - } -} - -func (fd fieldDesc) PrecisionScale() (precision, scale int64, ok bool) { - switch fd.OID { - case oid.T_numeric, oid.T__numeric: - mod := fd.Mod - headerSize - precision = int64((mod >> 16) & 0xffff) - scale = int64(mod & 0xffff) - return precision, scale, true - default: - return 0, 0, false - } -} - -// ColumnTypeScanType returns the value type that can be used to scan types into. -func (rs *rows) ColumnTypeScanType(index int) reflect.Type { - return rs.colTyps[index].Type() -} - -// ColumnTypeDatabaseTypeName return the database system type name. -func (rs *rows) ColumnTypeDatabaseTypeName(index int) string { - return rs.colTyps[index].Name() -} - -// ColumnTypeLength returns the length of the column type if the column is a -// variable length type. If the column is not a variable length type ok -// should return false. 
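[These ColumnType* hooks surface through database/sql's ColumnType values; a brief sketch of reading them back (query and column names are placeholders):

    package example

    import (
        "database/sql"
        "fmt"

        _ "github.com/lib/pq" // driver registration
    )

    func describeColumns(db *sql.DB) error {
        rows, err := db.Query(`SELECT id, name, price FROM products LIMIT 1`) // placeholder
        if err != nil {
            return err
        }
        defer rows.Close()

        cols, err := rows.ColumnTypes()
        if err != nil {
            return err
        }
        for _, c := range cols {
            length, hasLen := c.Length()          // answered by ColumnTypeLength below
            prec, scale, hasPS := c.DecimalSize() // answered by ColumnTypePrecisionScale
            fmt.Println(c.Name(), c.DatabaseTypeName(), c.ScanType(),
                length, hasLen, prec, scale, hasPS)
        }
        return nil
    }
]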
-func (rs *rows) ColumnTypeLength(index int) (length int64, ok bool) { - return rs.colTyps[index].Length() -} - -// ColumnTypePrecisionScale should return the precision and scale for decimal -// types. If not applicable, ok should be false. -func (rs *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { - return rs.colTyps[index].PrecisionScale() -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/scram/scram.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/scram/scram.go deleted file mode 100644 index 477216b60..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/scram/scram.go +++ /dev/null @@ -1,264 +0,0 @@ -// Copyright (c) 2014 - Gustavo Niemeyer -// -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. -// 2. Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR -// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -// Package scram implements a SCRAM-{SHA-1,etc} client per RFC5802. -// -// http://tools.ietf.org/html/rfc5802 -// -package scram - -import ( - "bytes" - "crypto/hmac" - "crypto/rand" - "encoding/base64" - "fmt" - "hash" - "strconv" - "strings" -) - -// Client implements a SCRAM-* client (SCRAM-SHA-1, SCRAM-SHA-256, etc). -// -// A Client may be used within a SASL conversation with logic resembling: -// -// var in []byte -// var client = scram.NewClient(sha1.New, user, pass) -// for client.Step(in) { -// out := client.Out() -// // send out to server -// in := serverOut -// } -// if client.Err() != nil { -// // auth failed -// } -// -type Client struct { - newHash func() hash.Hash - - user string - pass string - step int - out bytes.Buffer - err error - - clientNonce []byte - serverNonce []byte - saltedPass []byte - authMsg bytes.Buffer -} - -// NewClient returns a new SCRAM-* client with the provided hash algorithm. -// -// For SCRAM-SHA-256, for example, use: -// -// client := scram.NewClient(sha256.New, user, pass) -// -func NewClient(newHash func() hash.Hash, user, pass string) *Client { - c := &Client{ - newHash: newHash, - user: user, - pass: pass, - } - c.out.Grow(256) - c.authMsg.Grow(256) - return c -} - -// Out returns the data to be sent to the server in the current step. -func (c *Client) Out() []byte { - if c.out.Len() == 0 { - return nil - } - return c.out.Bytes() -} - -// Err returns the error that occurred, or nil if there were no errors. 
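[The package's own doc comment sketches the SASL loop; spelled out here with SHA-256 and a hypothetical transport function (Err is defined just below):

    package main

    import (
        "crypto/sha256"
        "log"

        "github.com/lib/pq/scram"
    )

    // roundTrip is a hypothetical stand-in for sending the client message to
    // the server and reading its SASL reply off the wire.
    func roundTrip(out []byte) []byte { return nil }

    func main() {
        client := scram.NewClient(sha256.New, "user", "pass") // placeholder credentials
        var in []byte
        for !client.Step(in) { // Step reports true once the conversation is finished
            in = roundTrip(client.Out())
        }
        if err := client.Err(); err != nil {
            log.Fatal(err) // authentication failed
        }
    }
]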
-func (c *Client) Err() error { - return c.err -} - -// SetNonce sets the client nonce to the provided value. -// If not set, the nonce is generated automatically out of crypto/rand on the first step. -func (c *Client) SetNonce(nonce []byte) { - c.clientNonce = nonce -} - -var escaper = strings.NewReplacer("=", "=3D", ",", "=2C") - -// Step processes the incoming data from the server and makes the -// next round of data for the server available via Client.Out. -// Step returns false if there are no errors and more data is -// still expected. -func (c *Client) Step(in []byte) bool { - c.out.Reset() - if c.step > 2 || c.err != nil { - return false - } - c.step++ - switch c.step { - case 1: - c.err = c.step1(in) - case 2: - c.err = c.step2(in) - case 3: - c.err = c.step3(in) - } - return c.step > 2 || c.err != nil -} - -func (c *Client) step1(in []byte) error { - if len(c.clientNonce) == 0 { - const nonceLen = 16 - buf := make([]byte, nonceLen+b64.EncodedLen(nonceLen)) - if _, err := rand.Read(buf[:nonceLen]); err != nil { - return fmt.Errorf("cannot read random SCRAM-SHA-256 nonce from operating system: %v", err) - } - c.clientNonce = buf[nonceLen:] - b64.Encode(c.clientNonce, buf[:nonceLen]) - } - c.authMsg.WriteString("n=") - escaper.WriteString(&c.authMsg, c.user) - c.authMsg.WriteString(",r=") - c.authMsg.Write(c.clientNonce) - - c.out.WriteString("n,,") - c.out.Write(c.authMsg.Bytes()) - return nil -} - -var b64 = base64.StdEncoding - -func (c *Client) step2(in []byte) error { - c.authMsg.WriteByte(',') - c.authMsg.Write(in) - - fields := bytes.Split(in, []byte(",")) - if len(fields) != 3 { - return fmt.Errorf("expected 3 fields in first SCRAM-SHA-256 server message, got %d: %q", len(fields), in) - } - if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 { - return fmt.Errorf("server sent an invalid SCRAM-SHA-256 nonce: %q", fields[0]) - } - if !bytes.HasPrefix(fields[1], []byte("s=")) || len(fields[1]) < 6 { - return fmt.Errorf("server sent an invalid SCRAM-SHA-256 salt: %q", fields[1]) - } - if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 6 { - return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2]) - } - - c.serverNonce = fields[0][2:] - if !bytes.HasPrefix(c.serverNonce, c.clientNonce) { - return fmt.Errorf("server SCRAM-SHA-256 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce) - } - - salt := make([]byte, b64.DecodedLen(len(fields[1][2:]))) - n, err := b64.Decode(salt, fields[1][2:]) - if err != nil { - return fmt.Errorf("cannot decode SCRAM-SHA-256 salt sent by server: %q", fields[1]) - } - salt = salt[:n] - iterCount, err := strconv.Atoi(string(fields[2][2:])) - if err != nil { - return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2]) - } - c.saltPassword(salt, iterCount) - - c.authMsg.WriteString(",c=biws,r=") - c.authMsg.Write(c.serverNonce) - - c.out.WriteString("c=biws,r=") - c.out.Write(c.serverNonce) - c.out.WriteString(",p=") - c.out.Write(c.clientProof()) - return nil -} - -func (c *Client) step3(in []byte) error { - var isv, ise bool - var fields = bytes.Split(in, []byte(",")) - if len(fields) == 1 { - isv = bytes.HasPrefix(fields[0], []byte("v=")) - ise = bytes.HasPrefix(fields[0], []byte("e=")) - } - if ise { - return fmt.Errorf("SCRAM-SHA-256 authentication error: %s", fields[0][2:]) - } else if !isv { - return fmt.Errorf("unsupported SCRAM-SHA-256 final message from server: %q", in) - } - if !bytes.Equal(c.serverSignature(), 
fields[0][2:]) { - return fmt.Errorf("cannot authenticate SCRAM-SHA-256 server signature: %q", fields[0][2:]) - } - return nil -} - -func (c *Client) saltPassword(salt []byte, iterCount int) { - mac := hmac.New(c.newHash, []byte(c.pass)) - mac.Write(salt) - mac.Write([]byte{0, 0, 0, 1}) - ui := mac.Sum(nil) - hi := make([]byte, len(ui)) - copy(hi, ui) - for i := 1; i < iterCount; i++ { - mac.Reset() - mac.Write(ui) - mac.Sum(ui[:0]) - for j, b := range ui { - hi[j] ^= b - } - } - c.saltedPass = hi -} - -func (c *Client) clientProof() []byte { - mac := hmac.New(c.newHash, c.saltedPass) - mac.Write([]byte("Client Key")) - clientKey := mac.Sum(nil) - hash := c.newHash() - hash.Write(clientKey) - storedKey := hash.Sum(nil) - mac = hmac.New(c.newHash, storedKey) - mac.Write(c.authMsg.Bytes()) - clientProof := mac.Sum(nil) - for i, b := range clientKey { - clientProof[i] ^= b - } - clientProof64 := make([]byte, b64.EncodedLen(len(clientProof))) - b64.Encode(clientProof64, clientProof) - return clientProof64 -} - -func (c *Client) serverSignature() []byte { - mac := hmac.New(c.newHash, c.saltedPass) - mac.Write([]byte("Server Key")) - serverKey := mac.Sum(nil) - - mac = hmac.New(c.newHash, serverKey) - mac.Write(c.authMsg.Bytes()) - serverSignature := mac.Sum(nil) - - encoded := make([]byte, b64.EncodedLen(len(serverSignature))) - b64.Encode(encoded, serverSignature) - return encoded -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/ssl.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/ssl.go deleted file mode 100644 index 36b61ba45..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/ssl.go +++ /dev/null @@ -1,204 +0,0 @@ -package pq - -import ( - "crypto/tls" - "crypto/x509" - "io/ioutil" - "net" - "os" - "os/user" - "path/filepath" - "strings" -) - -// ssl generates a function to upgrade a net.Conn based on the "sslmode" and -// related settings. The function is nil when no upgrade should take place. -func ssl(o values) (func(net.Conn) (net.Conn, error), error) { - verifyCaOnly := false - tlsConf := tls.Config{} - switch mode := o["sslmode"]; mode { - // "require" is the default. - case "", "require": - // We must skip TLS's own verification since it requires full - // verification since Go 1.3. - tlsConf.InsecureSkipVerify = true - - // From http://www.postgresql.org/docs/current/static/libpq-ssl.html: - // - // Note: For backwards compatibility with earlier versions of - // PostgreSQL, if a root CA file exists, the behavior of - // sslmode=require will be the same as that of verify-ca, meaning the - // server certificate is validated against the CA. Relying on this - // behavior is discouraged, and applications that need certificate - // validation should always use verify-ca or verify-full. - if sslrootcert, ok := o["sslrootcert"]; ok { - if _, err := os.Stat(sslrootcert); err == nil { - verifyCaOnly = true - } else { - delete(o, "sslrootcert") - } - } - case "verify-ca": - // We must skip TLS's own verification since it requires full - // verification since Go 1.3. - tlsConf.InsecureSkipVerify = true - verifyCaOnly = true - case "verify-full": - tlsConf.ServerName = o["host"] - case "disable": - return nil, nil - default: - return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) - } - - // Set Server Name Indication (SNI), if enabled by connection parameters. 
- // By default SNI is on, any value which is not starting with "1" disables - // SNI -- that is the same check vanilla libpq uses. - if sslsni := o["sslsni"]; sslsni == "" || strings.HasPrefix(sslsni, "1") { - // RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4 - // or IPv6). This check is coded already crypto.tls.hostnameInSNI, so - // just always set ServerName here and let crypto/tls do the filtering. - tlsConf.ServerName = o["host"] - } - - err := sslClientCertificates(&tlsConf, o) - if err != nil { - return nil, err - } - err = sslCertificateAuthority(&tlsConf, o) - if err != nil { - return nil, err - } - - // Accept renegotiation requests initiated by the backend. - // - // Renegotiation was deprecated then removed from PostgreSQL 9.5, but - // the default configuration of older versions has it enabled. Redshift - // also initiates renegotiations and cannot be reconfigured. - tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient - - return func(conn net.Conn) (net.Conn, error) { - client := tls.Client(conn, &tlsConf) - if verifyCaOnly { - err := sslVerifyCertificateAuthority(client, &tlsConf) - if err != nil { - return nil, err - } - } - return client, nil - }, nil -} - -// sslClientCertificates adds the certificate specified in the "sslcert" and -// "sslkey" settings, or if they aren't set, from the .postgresql directory -// in the user's home directory. The configured files must exist and have -// the correct permissions. -func sslClientCertificates(tlsConf *tls.Config, o values) error { - sslinline := o["sslinline"] - if sslinline == "true" { - cert, err := tls.X509KeyPair([]byte(o["sslcert"]), []byte(o["sslkey"])) - if err != nil { - return err - } - tlsConf.Certificates = []tls.Certificate{cert} - return nil - } - - // user.Current() might fail when cross-compiling. We have to ignore the - // error and continue without home directory defaults, since we wouldn't - // know from where to load them. - user, _ := user.Current() - - // In libpq, the client certificate is only loaded if the setting is not blank. - // - // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1036-L1037 - sslcert := o["sslcert"] - if len(sslcert) == 0 && user != nil { - sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") - } - // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1045 - if len(sslcert) == 0 { - return nil - } - // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1050:L1054 - if _, err := os.Stat(sslcert); os.IsNotExist(err) { - return nil - } else if err != nil { - return err - } - - // In libpq, the ssl key is only loaded if the setting is not blank. - // - // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1123-L1222 - sslkey := o["sslkey"] - if len(sslkey) == 0 && user != nil { - sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") - } - - if len(sslkey) > 0 { - if err := sslKeyPermissions(sslkey); err != nil { - return err - } - } - - cert, err := tls.LoadX509KeyPair(sslcert, sslkey) - if err != nil { - return err - } - - tlsConf.Certificates = []tls.Certificate{cert} - return nil -} - -// sslCertificateAuthority adds the RootCA specified in the "sslrootcert" setting. -func sslCertificateAuthority(tlsConf *tls.Config, o values) error { - // In libpq, the root certificate is only loaded if the setting is not blank. 
- // - // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L950-L951 - if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 { - tlsConf.RootCAs = x509.NewCertPool() - - sslinline := o["sslinline"] - - var cert []byte - if sslinline == "true" { - cert = []byte(sslrootcert) - } else { - var err error - cert, err = ioutil.ReadFile(sslrootcert) - if err != nil { - return err - } - } - - if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { - return fmterrorf("couldn't parse pem in sslrootcert") - } - } - - return nil -} - -// sslVerifyCertificateAuthority carries out a TLS handshake to the server and -// verifies the presented certificate against the CA, i.e. the one specified in -// sslrootcert or the system CA if sslrootcert was not specified. -func sslVerifyCertificateAuthority(client *tls.Conn, tlsConf *tls.Config) error { - err := client.Handshake() - if err != nil { - return err - } - certs := client.ConnectionState().PeerCertificates - opts := x509.VerifyOptions{ - DNSName: client.ConnectionState().ServerName, - Intermediates: x509.NewCertPool(), - Roots: tlsConf.RootCAs, - } - for i, cert := range certs { - if i == 0 { - continue - } - opts.Intermediates.AddCert(cert) - } - _, err = certs[0].Verify(opts) - return err -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/ssl_permissions.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/ssl_permissions.go deleted file mode 100644 index d587f102e..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/ssl_permissions.go +++ /dev/null @@ -1,93 +0,0 @@ -//go:build !windows -// +build !windows - -package pq - -import ( - "errors" - "os" - "syscall" -) - -const ( - rootUserID = uint32(0) - - // The maximum permissions that a private key file owned by a regular user - // is allowed to have. This translates to u=rw. - maxUserOwnedKeyPermissions os.FileMode = 0600 - - // The maximum permissions that a private key file owned by root is allowed - // to have. This translates to u=rw,g=r. - maxRootOwnedKeyPermissions os.FileMode = 0640 -) - -var ( - errSSLKeyHasUnacceptableUserPermissions = errors.New("permissions for files not owned by root should be u=rw (0600) or less") - errSSLKeyHasUnacceptableRootPermissions = errors.New("permissions for root owned files should be u=rw,g=r (0640) or less") -) - -// sslKeyPermissions checks the permissions on user-supplied ssl key files. -// The key file should have very little access. -// -// libpq does not check key file permissions on Windows. -func sslKeyPermissions(sslkey string) error { - info, err := os.Stat(sslkey) - if err != nil { - return err - } - - err = hasCorrectPermissions(info) - - // return ErrSSLKeyHasWorldPermissions for backwards compatability with - // existing code. - if err == errSSLKeyHasUnacceptableUserPermissions || err == errSSLKeyHasUnacceptableRootPermissions { - err = ErrSSLKeyHasWorldPermissions - } - return err -} - -// hasCorrectPermissions checks the file info (and the unix-specific stat_t -// output) to verify that the permissions on the file are correct. -// -// If the file is owned by the same user the process is running as, -// the file should only have 0600 (u=rw). If the file is owned by root, -// and the group matches the group that the process is running in, the -// permissions cannot be more than 0640 (u=rw,g=r). The file should -// never have world permissions. -// -// Returns an error when the permission check fails. 
-func hasCorrectPermissions(info os.FileInfo) error { - // if file's permission matches 0600, allow access. - userPermissionMask := (os.FileMode(0777) ^ maxUserOwnedKeyPermissions) - - // regardless of if we're running as root or not, 0600 is acceptable, - // so we return if we match the regular user permission mask. - if info.Mode().Perm()&userPermissionMask == 0 { - return nil - } - - // We need to pull the Unix file information to get the file's owner. - // If we can't access it, there's some sort of operating system level error - // and we should fail rather than attempting to use faulty information. - sysInfo := info.Sys() - if sysInfo == nil { - return ErrSSLKeyUnknownOwnership - } - - unixStat, ok := sysInfo.(*syscall.Stat_t) - if !ok { - return ErrSSLKeyUnknownOwnership - } - - // if the file is owned by root, we allow 0640 (u=rw,g=r) to match what - // Postgres does. - if unixStat.Uid == rootUserID { - rootPermissionMask := (os.FileMode(0777) ^ maxRootOwnedKeyPermissions) - if info.Mode().Perm()&rootPermissionMask != 0 { - return errSSLKeyHasUnacceptableRootPermissions - } - return nil - } - - return errSSLKeyHasUnacceptableUserPermissions -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/ssl_windows.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/ssl_windows.go deleted file mode 100644 index 73663c8f1..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/ssl_windows.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build windows -// +build windows - -package pq - -// sslKeyPermissions checks the permissions on user-supplied ssl key files. -// The key file should have very little access. -// -// libpq does not check key file permissions on Windows. -func sslKeyPermissions(string) error { return nil } diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/url.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/url.go deleted file mode 100644 index aec6e95be..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/url.go +++ /dev/null @@ -1,76 +0,0 @@ -package pq - -import ( - "fmt" - "net" - nurl "net/url" - "sort" - "strings" -) - -// ParseURL no longer needs to be used by clients of this library since supplying a URL as a -// connection string to sql.Open() is now supported: -// -// sql.Open("postgres", "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full") -// -// It remains exported here for backwards-compatibility. -// -// ParseURL converts a url to a connection string for driver.Open. 
-// Example: -// -// "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full" -// -// converts to: -// -// "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full" -// -// A minimal example: -// -// "postgres://" -// -// This will be blank, causing driver.Open to use all of the defaults -func ParseURL(url string) (string, error) { - u, err := nurl.Parse(url) - if err != nil { - return "", err - } - - if u.Scheme != "postgres" && u.Scheme != "postgresql" { - return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) - } - - var kvs []string - escaper := strings.NewReplacer(`'`, `\'`, `\`, `\\`) - accrue := func(k, v string) { - if v != "" { - kvs = append(kvs, k+"='"+escaper.Replace(v)+"'") - } - } - - if u.User != nil { - v := u.User.Username() - accrue("user", v) - - v, _ = u.User.Password() - accrue("password", v) - } - - if host, port, err := net.SplitHostPort(u.Host); err != nil { - accrue("host", u.Host) - } else { - accrue("host", host) - accrue("port", port) - } - - if u.Path != "" { - accrue("dbname", u.Path[1:]) - } - - q := u.Query() - for k := range q { - accrue(k, q.Get(k)) - } - - sort.Strings(kvs) // Makes testing easier (not a performance concern) - return strings.Join(kvs, " "), nil -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/user_other.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/user_other.go deleted file mode 100644 index 3dae8f557..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/user_other.go +++ /dev/null @@ -1,10 +0,0 @@ -// Package pq is a pure Go Postgres driver for the database/sql package. - -//go:build js || android || hurd || zos -// +build js android hurd zos - -package pq - -func userCurrent() (string, error) { - return "", ErrCouldNotDetectUsername -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/user_posix.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/user_posix.go deleted file mode 100644 index 5f2d439bc..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/user_posix.go +++ /dev/null @@ -1,25 +0,0 @@ -// Package pq is a pure Go Postgres driver for the database/sql package. - -//go:build aix || darwin || dragonfly || freebsd || (linux && !android) || nacl || netbsd || openbsd || plan9 || solaris || rumprun || illumos -// +build aix darwin dragonfly freebsd linux,!android nacl netbsd openbsd plan9 solaris rumprun illumos - -package pq - -import ( - "os" - "os/user" -) - -func userCurrent() (string, error) { - u, err := user.Current() - if err == nil { - return u.Username, nil - } - - name := os.Getenv("USER") - if name != "" { - return name, nil - } - - return "", ErrCouldNotDetectUsername -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/user_windows.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/user_windows.go deleted file mode 100644 index 2b691267b..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/user_windows.go +++ /dev/null @@ -1,27 +0,0 @@ -// Package pq is a pure Go Postgres driver for the database/sql package. -package pq - -import ( - "path/filepath" - "syscall" -) - -// Perform Windows user name lookup identically to libpq. -// -// The PostgreSQL code makes use of the legacy Win32 function -// GetUserName, and that function has not been imported into stock Go. -// GetUserNameEx is available though, the difference being that a -// wider range of names are available. 
To get the output to be the -// same as GetUserName, only the base (or last) component of the -// result is returned. -func userCurrent() (string, error) { - pw_name := make([]uint16, 128) - pwname_size := uint32(len(pw_name)) - 1 - err := syscall.GetUserNameEx(syscall.NameSamCompatible, &pw_name[0], &pwname_size) - if err != nil { - return "", ErrCouldNotDetectUsername - } - s := syscall.UTF16ToString(pw_name) - u := filepath.Base(s) - return u, nil -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/uuid.go b/src/code.cloudfoundry.org/vendor/github.com/lib/pq/uuid.go deleted file mode 100644 index 9a1b9e074..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/lib/pq/uuid.go +++ /dev/null @@ -1,23 +0,0 @@ -package pq - -import ( - "encoding/hex" - "fmt" -) - -// decodeUUIDBinary interprets the binary format of a uuid, returning it in text format. -func decodeUUIDBinary(src []byte) ([]byte, error) { - if len(src) != 16 { - return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", len(src)) - } - - dst := make([]byte, 36) - dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-' - hex.Encode(dst[0:], src[0:4]) - hex.Encode(dst[9:], src[4:6]) - hex.Encode(dst[14:], src[6:8]) - hex.Encode(dst[19:], src[8:10]) - hex.Encode(dst[24:], src[10:16]) - - return dst, nil -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/.gitignore b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/.gitignore similarity index 61% rename from src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/.gitignore rename to src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/.gitignore index 117f92f52..45505cc93 100644 --- a/src/code.cloudfoundry.org/vendor/github.com/jinzhu/gorm/.gitignore +++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/.gitignore @@ -1,3 +1,6 @@ +TODO* documents coverage.txt _book +.idea +vendor \ No newline at end of file diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/License b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/License new file mode 100644 index 000000000..037e1653e --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/License @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2013-NOW Jinzhu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/README.md b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/README.md
new file mode 100644
index 000000000..b8f7a6c97
--- /dev/null
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/README.md
@@ -0,0 +1,52 @@
+# GORM MySQL Driver
+
+## Quick Start
+
+```go
+import (
+  "gorm.io/driver/mysql"
+  "gorm.io/gorm"
+)
+
+// https://github.com/go-sql-driver/mysql
+dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
+db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
+```
+
+## Configuration
+
+```go
+import (
+  "gorm.io/driver/mysql"
+  "gorm.io/gorm"
+)
+
+var datetimePrecision = 2
+
+db, err := gorm.Open(mysql.New(mysql.Config{
+  DSN: "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local", // data source name, refer https://github.com/go-sql-driver/mysql#dsn-data-source-name
+  DefaultStringSize: 256, // default size for string fields; a string field that has no size, is not a primary key, has no index and no default value maps to the db type `longtext`
+  DisableDatetimePrecision: true, // disable datetime precision support, which is not supported before MySQL 5.6
+  DefaultDatetimePrecision: &datetimePrecision, // default datetime precision
+  DontSupportRenameIndex: true, // drop & recreate the index when renaming; RENAME INDEX is not supported before MySQL 5.7 or by MariaDB
+  DontSupportRenameColumn: true, // use CHANGE when renaming a column; renaming columns is not supported before MySQL 8 or by MariaDB
+  SkipInitializeWithVersion: false, // smart configure based on the server version
+}), &gorm.Config{})
+```
+
+## Customized Driver
+
+```go
+import (
+  _ "example.com/my_mysql_driver"
+  "gorm.io/driver/mysql"
+  "gorm.io/gorm"
+)
+
+db, err := gorm.Open(mysql.New(mysql.Config{
+  DriverName: "my_mysql_driver_name",
+  DSN: "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local", // data source name, refer https://github.com/go-sql-driver/mysql#dsn-data-source-name
+}), &gorm.Config{})
+```
+
+Check out [https://gorm.io](https://gorm.io) for details.
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/error_translator.go b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/error_translator.go
new file mode 100644
index 000000000..79f6646e5
--- /dev/null
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/error_translator.go
@@ -0,0 +1,21 @@
+package mysql
+
+import (
+	"github.com/go-sql-driver/mysql"
+
+	"gorm.io/gorm"
+)
+
+var errCodes = map[string]uint16{
+	"uniqueConstraint": 1062,
+}
+
+func (dialector Dialector) Translate(err error) error {
+	if mysqlErr, ok := err.(*mysql.MySQLError); ok {
+		if mysqlErr.Number == errCodes["uniqueConstraint"] {
+			return gorm.ErrDuplicatedKey
+		}
+	}
+
+	return err
+}
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/migrator.go b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/migrator.go
new file mode 100644
index 000000000..d35a86e14
--- /dev/null
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/migrator.go
@@ -0,0 +1,408 @@
+package mysql
+
+import (
+	"database/sql"
+	"fmt"
+	"strconv"
+	"strings"
+
+	"gorm.io/gorm"
+	"gorm.io/gorm/clause"
+	"gorm.io/gorm/migrator"
+	"gorm.io/gorm/schema"
+)
+
+const indexSql = `
+SELECT
+	TABLE_NAME,
+	COLUMN_NAME,
+	INDEX_NAME,
+	NON_UNIQUE
+FROM
+	information_schema.STATISTICS
+WHERE
+	TABLE_SCHEMA = ?
+	AND TABLE_NAME = ?
+ORDER BY
+	INDEX_NAME,
+	SEQ_IN_INDEX`

+var typeAliasMap = map[string][]string{
+	"bool":    {"tinyint"},
+	"tinyint": {"bool"},
+}
+
+type Migrator struct {
+	migrator.Migrator
+	Dialector
+}
+
+func (m Migrator) FullDataTypeOf(field *schema.Field) clause.Expr {
+	expr := m.Migrator.FullDataTypeOf(field)
+
+	if value, ok := field.TagSettings["COMMENT"]; ok {
+		expr.SQL += " COMMENT " + m.Dialector.Explain("?", value)
+	}
+
+	return expr
+}
+
+func (m Migrator) AlterColumn(value interface{}, field string) error {
+	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		if stmt.Schema != nil {
+			if field := stmt.Schema.LookUpField(field); field != nil {
+				fullDataType := m.FullDataTypeOf(field)
+				if m.Dialector.DontSupportRenameColumnUnique {
+					fullDataType.SQL = strings.Replace(fullDataType.SQL, " UNIQUE ", " ", 1)
+				}
+
+				return m.DB.Exec(
+					"ALTER TABLE ? MODIFY COLUMN ? ?",
+					clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, fullDataType,
+				).Error
+			}
+		}
+		return fmt.Errorf("failed to look up field with name: %s", field)
+	})
+}
+
+func (m Migrator) TiDBVersion() (isTiDB bool, major, minor, patch int, err error) {
+	// TiDB version string looks like:
+	// "5.7.25-TiDB-v6.5.0" or "5.7.25-TiDB-v6.4.0-serverless"
+	tidbVersionArray := strings.Split(m.Dialector.ServerVersion, "-")
+	if len(tidbVersionArray) < 3 || tidbVersionArray[1] != "TiDB" {
+		// It isn't TiDB
+		return
+	}
+
+	rawVersion := strings.TrimPrefix(tidbVersionArray[2], "v")
+	realVersionArray := strings.Split(rawVersion, ".")
+	if major, err = strconv.Atoi(realVersionArray[0]); err != nil {
+		err = fmt.Errorf("failed to parse the version of TiDB, the major version is: %s", realVersionArray[0])
+		return
+	}
+
+	if minor, err = strconv.Atoi(realVersionArray[1]); err != nil {
+		err = fmt.Errorf("failed to parse the version of TiDB, the minor version is: %s", realVersionArray[1])
+		return
+	}
+
+	if patch, err = strconv.Atoi(realVersionArray[2]); err != nil {
+		err = fmt.Errorf("failed to parse the version of TiDB, the patch version is: %s", realVersionArray[2])
+		return
+	}
+
+	isTiDB = true
+	return
+}
+
+func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
+	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		if !m.Dialector.DontSupportRenameColumn {
+			return m.Migrator.RenameColumn(value, oldName, newName)
+		}
+
+		var field *schema.Field
+		if stmt.Schema != nil {
+			if f := stmt.Schema.LookUpField(oldName); f != nil {
+				oldName = f.DBName
+				field = f
+			}
+
+			if f := stmt.Schema.LookUpField(newName); f != nil {
+				newName = f.DBName
+				field = f
+			}
+		}
+
+		if field != nil {
+			return m.DB.Exec(
+				"ALTER TABLE ? CHANGE ? ? ?",
+				clause.Table{Name: stmt.Table}, clause.Column{Name: oldName},
+				clause.Column{Name: newName}, m.FullDataTypeOf(field),
+			).Error
+		}
+
+		return fmt.Errorf("failed to look up field with name: %s", newName)
+	})
+}
+
+func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
+	if !m.Dialector.DontSupportRenameIndex {
+		return m.RunWithValue(value, func(stmt *gorm.Statement) error {
+			return m.DB.Exec(
+				"ALTER TABLE ? RENAME INDEX ?
TO ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, + ).Error + }) + } + + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + err := m.DropIndex(value, oldName) + if err != nil { + return err + } + + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(newName); idx == nil { + if idx = stmt.Schema.LookIndex(oldName); idx != nil { + opts := m.BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: newName}, clause.Table{Name: stmt.Table}, opts} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ? ON ??" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + + return m.DB.Exec(createIndexSQL, values...).Error + } + } + } + + return m.CreateIndex(value, newName) + }) + +} + +func (m Migrator) DropTable(values ...interface{}) error { + values = m.ReorderModels(values, false) + return m.DB.Connection(func(tx *gorm.DB) error { + tx.Exec("SET FOREIGN_KEY_CHECKS = 0;") + for i := len(values) - 1; i >= 0; i-- { + if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error + }); err != nil { + return err + } + } + return tx.Exec("SET FOREIGN_KEY_CHECKS = 1;").Error + }) +} + +func (m Migrator) DropConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if chk != nil { + return m.DB.Exec("ALTER TABLE ? DROP CHECK ?", clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}).Error + } + if constraint != nil { + name = constraint.Name + } + + return m.DB.Exec( + "ALTER TABLE ? DROP FOREIGN KEY ?", clause.Table{Name: table}, clause.Column{Name: name}, + ).Error + }) +} + +// ColumnTypes column types return columnTypes,error +func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { + columnTypes := make([]gorm.ColumnType, 0) + err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + var ( + currentDatabase, table = m.CurrentSchema(stmt, stmt.Table) + columnTypeSQL = "SELECT column_name, column_default, is_nullable = 'YES', data_type, character_maximum_length, column_type, column_key, extra, column_comment, numeric_precision, numeric_scale " + rows, err = m.DB.Session(&gorm.Session{}).Table(table).Limit(1).Rows() + ) + + if err != nil { + return err + } + + rawColumnTypes, err := rows.ColumnTypes() + + if err != nil { + return err + } + + if err := rows.Close(); err != nil { + return err + } + + if !m.DisableDatetimePrecision { + columnTypeSQL += ", datetime_precision " + } + columnTypeSQL += "FROM information_schema.columns WHERE table_schema = ? AND table_name = ? 
ORDER BY ORDINAL_POSITION" + + columns, rowErr := m.DB.Table(table).Raw(columnTypeSQL, currentDatabase, table).Rows() + if rowErr != nil { + return rowErr + } + + defer columns.Close() + + for columns.Next() { + var ( + column migrator.ColumnType + datetimePrecision sql.NullInt64 + extraValue sql.NullString + columnKey sql.NullString + values = []interface{}{ + &column.NameValue, &column.DefaultValueValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.ColumnTypeValue, &columnKey, &extraValue, &column.CommentValue, &column.DecimalSizeValue, &column.ScaleValue, + } + ) + + if !m.DisableDatetimePrecision { + values = append(values, &datetimePrecision) + } + + if scanErr := columns.Scan(values...); scanErr != nil { + return scanErr + } + + column.PrimaryKeyValue = sql.NullBool{Bool: false, Valid: true} + column.UniqueValue = sql.NullBool{Bool: false, Valid: true} + switch columnKey.String { + case "PRI": + column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true} + case "UNI": + column.UniqueValue = sql.NullBool{Bool: true, Valid: true} + } + + if strings.Contains(extraValue.String, "auto_increment") { + column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true} + } + + column.DefaultValueValue.String = strings.Trim(column.DefaultValueValue.String, "'") + if m.Dialector.DontSupportNullAsDefaultValue { + // rewrite mariadb default value like other version + if column.DefaultValueValue.Valid && column.DefaultValueValue.String == "NULL" { + column.DefaultValueValue.Valid = false + column.DefaultValueValue.String = "" + } + } + + if datetimePrecision.Valid { + column.DecimalSizeValue = datetimePrecision + } + + for _, c := range rawColumnTypes { + if c.Name() == column.NameValue.String { + column.SQLColumnType = c + break + } + } + + columnTypes = append(columnTypes, column) + } + + return nil + }) + + return columnTypes, err +} + +func (m Migrator) CurrentDatabase() (name string) { + baseName := m.Migrator.CurrentDatabase() + m.DB.Raw( + "SELECT SCHEMA_NAME from Information_schema.SCHEMATA where SCHEMA_NAME LIKE ? ORDER BY SCHEMA_NAME=? DESC,SCHEMA_NAME limit 1", + baseName+"%", baseName).Scan(&name) + return +} + +func (m Migrator) GetTables() (tableList []string, err error) { + err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()). 
+ Scan(&tableList).Error + return +} + +func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) { + indexes := make([]gorm.Index, 0) + err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + + result := make([]*Index, 0) + schema, table := m.CurrentSchema(stmt, stmt.Table) + scanErr := m.DB.Table(table).Raw(indexSql, schema, table).Scan(&result).Error + if scanErr != nil { + return scanErr + } + indexMap, indexNames := groupByIndexName(result) + + for _, name := range indexNames { + idx := indexMap[name] + if len(idx) == 0 { + continue + } + tempIdx := &migrator.Index{ + TableName: idx[0].TableName, + NameValue: idx[0].IndexName, + PrimaryKeyValue: sql.NullBool{ + Bool: idx[0].IndexName == "PRIMARY", + Valid: true, + }, + UniqueValue: sql.NullBool{ + Bool: idx[0].NonUnique == 0, + Valid: true, + }, + } + for _, x := range idx { + tempIdx.ColumnList = append(tempIdx.ColumnList, x.ColumnName) + } + indexes = append(indexes, tempIdx) + } + return nil + }) + return indexes, err +} + +// Index table index info +type Index struct { + TableName string `gorm:"column:TABLE_NAME"` + ColumnName string `gorm:"column:COLUMN_NAME"` + IndexName string `gorm:"column:INDEX_NAME"` + NonUnique int32 `gorm:"column:NON_UNIQUE"` +} + +func groupByIndexName(indexList []*Index) (map[string][]*Index, []string) { + columnIndexMap := make(map[string][]*Index, len(indexList)) + indexNames := make([]string, 0, len(indexList)) + for _, idx := range indexList { + if _, ok := columnIndexMap[idx.IndexName]; !ok { + indexNames = append(indexNames, idx.IndexName) + } + columnIndexMap[idx.IndexName] = append(columnIndexMap[idx.IndexName], idx) + } + return columnIndexMap, indexNames +} + +func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (string, string) { + if tables := strings.Split(table, `.`); len(tables) == 2 { + return tables[0], tables[1] + } + m.DB = m.DB.Table(table) + return m.CurrentDatabase(), table +} + +func (m Migrator) GetTypeAliases(databaseTypeName string) []string { + return typeAliasMap[databaseTypeName] +} + +// TableType table type return tableType,error +func (m Migrator) TableType(value interface{}) (tableType gorm.TableType, err error) { + var table migrator.TableType + + err = m.RunWithValue(value, func(stmt *gorm.Statement) error { + var ( + values = []interface{}{ + &table.SchemaValue, &table.NameValue, &table.TypeValue, &table.CommentValue, + } + currentDatabase, tableName = m.CurrentSchema(stmt, stmt.Table) + tableTypeSQL = "SELECT table_schema, table_name, table_type, table_comment FROM information_schema.tables WHERE table_schema = ? AND table_name = ?" 
+ ) + + row := m.DB.Table(tableName).Raw(tableTypeSQL, currentDatabase, tableName).Row() + + if scanErr := row.Scan(values...); scanErr != nil { + return scanErr + } + + return nil + }) + + return table, err +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/mysql.go b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/mysql.go new file mode 100644 index 000000000..68d02e857 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/mysql.go @@ -0,0 +1,533 @@ +package mysql + +import ( + "context" + "database/sql" + "fmt" + "math" + "regexp" + "strconv" + "strings" + "time" + + "github.com/go-sql-driver/mysql" + + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +const ( + AutoRandomTag = "auto_random()" // Treated as an auto_random field for tidb +) + +type Config struct { + DriverName string + ServerVersion string + DSN string + DSNConfig *mysql.Config + Conn gorm.ConnPool + SkipInitializeWithVersion bool + DefaultStringSize uint + DefaultDatetimePrecision *int + DisableWithReturning bool + DisableDatetimePrecision bool + DontSupportRenameIndex bool + DontSupportRenameColumn bool + DontSupportForShareClause bool + DontSupportNullAsDefaultValue bool + DontSupportRenameColumnUnique bool +} + +type Dialector struct { + *Config +} + +var ( + // CreateClauses create clauses + CreateClauses = []string{"INSERT", "VALUES", "ON CONFLICT"} + // QueryClauses query clauses + QueryClauses = []string{} + // UpdateClauses update clauses + UpdateClauses = []string{"UPDATE", "SET", "WHERE", "ORDER BY", "LIMIT"} + // DeleteClauses delete clauses + DeleteClauses = []string{"DELETE", "FROM", "WHERE", "ORDER BY", "LIMIT"} + + defaultDatetimePrecision = 3 +) + +func Open(dsn string) gorm.Dialector { + dsnConf, _ := mysql.ParseDSN(dsn) + return &Dialector{Config: &Config{DSN: dsn, DSNConfig: dsnConf}} +} + +func New(config Config) gorm.Dialector { + switch { + case config.DSN == "" && config.DSNConfig != nil: + config.DSN = config.DSNConfig.FormatDSN() + case config.DSN != "" && config.DSNConfig == nil: + config.DSNConfig, _ = mysql.ParseDSN(config.DSN) + } + return &Dialector{Config: &config} +} + +func (dialector Dialector) Name() string { + return "mysql" +} + +// NowFunc return now func +func (dialector Dialector) NowFunc(n int) func() time.Time { + return func() time.Time { + round := time.Second / time.Duration(math.Pow10(n)) + return time.Now().Round(round) + } +} + +func (dialector Dialector) Apply(config *gorm.Config) error { + if config.NowFunc != nil { + return nil + } + + if dialector.DefaultDatetimePrecision == nil { + dialector.DefaultDatetimePrecision = &defaultDatetimePrecision + } + // while maintaining the readability of the code, separate the business logic from + // the general part and leave it to the function to do it here. 
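+	//
+	// Worked example, using the package default precision of 3:
+	// time.Second / time.Duration(math.Pow10(3)) == time.Millisecond,
+	// so the NowFunc configured below rounds timestamps to the millisecond.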
+ config.NowFunc = dialector.NowFunc(*dialector.DefaultDatetimePrecision) + return nil +} + +func (dialector Dialector) Initialize(db *gorm.DB) (err error) { + if dialector.DriverName == "" { + dialector.DriverName = "mysql" + } + + if dialector.DefaultDatetimePrecision == nil { + dialector.DefaultDatetimePrecision = &defaultDatetimePrecision + } + + if dialector.Conn != nil { + db.ConnPool = dialector.Conn + } else { + db.ConnPool, err = sql.Open(dialector.DriverName, dialector.DSN) + if err != nil { + return err + } + } + + withReturning := false + if !dialector.Config.SkipInitializeWithVersion { + err = db.ConnPool.QueryRowContext(context.Background(), "SELECT VERSION()").Scan(&dialector.ServerVersion) + if err != nil { + return err + } + + if strings.Contains(dialector.ServerVersion, "MariaDB") { + dialector.Config.DontSupportRenameIndex = true + dialector.Config.DontSupportRenameColumn = true + dialector.Config.DontSupportForShareClause = true + dialector.Config.DontSupportNullAsDefaultValue = true + withReturning = checkVersion(dialector.ServerVersion, "10.5") + } else if strings.HasPrefix(dialector.ServerVersion, "5.6.") { + dialector.Config.DontSupportRenameIndex = true + dialector.Config.DontSupportRenameColumn = true + dialector.Config.DontSupportForShareClause = true + } else if strings.HasPrefix(dialector.ServerVersion, "5.7.") { + dialector.Config.DontSupportRenameColumn = true + dialector.Config.DontSupportForShareClause = true + } else if strings.HasPrefix(dialector.ServerVersion, "5.") { + dialector.Config.DisableDatetimePrecision = true + dialector.Config.DontSupportRenameIndex = true + dialector.Config.DontSupportRenameColumn = true + dialector.Config.DontSupportForShareClause = true + } + + if strings.Contains(dialector.ServerVersion, "TiDB") { + dialector.Config.DontSupportRenameColumnUnique = true + } + } + + // register callbacks + callbackConfig := &callbacks.Config{ + CreateClauses: CreateClauses, + QueryClauses: QueryClauses, + UpdateClauses: UpdateClauses, + DeleteClauses: DeleteClauses, + } + + if !dialector.Config.DisableWithReturning && withReturning { + callbackConfig.LastInsertIDReversed = true + + if !utils.Contains(callbackConfig.CreateClauses, "RETURNING") { + callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING") + } + + if !utils.Contains(callbackConfig.UpdateClauses, "RETURNING") { + callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING") + } + + if !utils.Contains(callbackConfig.DeleteClauses, "RETURNING") { + callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING") + } + } + + callbacks.RegisterDefaultCallbacks(db, callbackConfig) + + for k, v := range dialector.ClauseBuilders() { + db.ClauseBuilders[k] = v + } + return +} + +const ( + // ClauseOnConflict for clause.ClauseBuilder ON CONFLICT key + ClauseOnConflict = "ON CONFLICT" + // ClauseValues for clause.ClauseBuilder VALUES key + ClauseValues = "VALUES" + // ClauseFor for clause.ClauseBuilder FOR key + ClauseFor = "FOR" +) + +func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { + clauseBuilders := map[string]clause.ClauseBuilder{ + ClauseOnConflict: func(c clause.Clause, builder clause.Builder) { + onConflict, ok := c.Expression.(clause.OnConflict) + if !ok { + c.Build(builder) + return + } + + builder.WriteString("ON DUPLICATE KEY UPDATE ") + if len(onConflict.DoUpdates) == 0 { + if s := builder.(*gorm.Statement).Schema; s != nil { + var column clause.Column + onConflict.DoNothing = false + + if 
s.PrioritizedPrimaryField != nil { + column = clause.Column{Name: s.PrioritizedPrimaryField.DBName} + } else if len(s.DBNames) > 0 { + column = clause.Column{Name: s.DBNames[0]} + } + + if column.Name != "" { + onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}} + } + + builder.(*gorm.Statement).AddClause(onConflict) + } + } + + for idx, assignment := range onConflict.DoUpdates { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(assignment.Column) + builder.WriteByte('=') + if column, ok := assignment.Value.(clause.Column); ok && column.Table == "excluded" { + column.Table = "" + builder.WriteString("VALUES(") + builder.WriteQuoted(column) + builder.WriteByte(')') + } else { + builder.AddVar(builder, assignment.Value) + } + } + }, + ClauseValues: func(c clause.Clause, builder clause.Builder) { + if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 { + builder.WriteString("VALUES()") + return + } + c.Build(builder) + }, + } + + if dialector.Config.DontSupportForShareClause { + clauseBuilders[ClauseFor] = func(c clause.Clause, builder clause.Builder) { + if values, ok := c.Expression.(clause.Locking); ok && strings.EqualFold(values.Strength, "SHARE") { + builder.WriteString("LOCK IN SHARE MODE") + return + } + c.Build(builder) + } + } + + return clauseBuilders +} + +func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression { + return clause.Expr{SQL: "DEFAULT"} +} + +func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { + return Migrator{ + Migrator: migrator.Migrator{ + Config: migrator.Config{ + DB: db, + Dialector: dialector, + }, + }, + Dialector: dialector, + } +} + +func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { + writer.WriteByte('?') +} + +func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + var ( + underQuoted, selfQuoted bool + continuousBacktick int8 + shiftDelimiter int8 + ) + + for _, v := range []byte(str) { + switch v { + case '`': + continuousBacktick++ + if continuousBacktick == 2 { + writer.WriteString("``") + continuousBacktick = 0 + } + case '.': + if continuousBacktick > 0 || !selfQuoted { + shiftDelimiter = 0 + underQuoted = false + continuousBacktick = 0 + writer.WriteByte('`') + } + writer.WriteByte(v) + continue + default: + if shiftDelimiter-continuousBacktick <= 0 && !underQuoted { + writer.WriteByte('`') + underQuoted = true + if selfQuoted = continuousBacktick > 0; selfQuoted { + continuousBacktick -= 1 + } + } + + for ; continuousBacktick > 0; continuousBacktick -= 1 { + writer.WriteString("``") + } + + writer.WriteByte(v) + } + shiftDelimiter++ + } + + if continuousBacktick > 0 && !selfQuoted { + writer.WriteString("``") + } + writer.WriteByte('`') +} + +type localTimeInterface interface { + In(loc *time.Location) time.Time +} + +func (dialector Dialector) Explain(sql string, vars ...interface{}) string { + if dialector.DSNConfig != nil && dialector.DSNConfig.Loc != nil { + for i, v := range vars { + if p, ok := v.(localTimeInterface); ok { + func(i int, t localTimeInterface) { + defer func() { + recover() + }() + vars[i] = t.In(dialector.DSNConfig.Loc) + }(i, p) + } + } + } + return logger.ExplainSQL(sql, nil, `'`, vars...) 
+} + +func (dialector Dialector) DataTypeOf(field *schema.Field) string { + switch field.DataType { + case schema.Bool: + return "boolean" + case schema.Int, schema.Uint: + return dialector.getSchemaIntAndUnitType(field) + case schema.Float: + return dialector.getSchemaFloatType(field) + case schema.String: + return dialector.getSchemaStringType(field) + case schema.Time: + return dialector.getSchemaTimeType(field) + case schema.Bytes: + return dialector.getSchemaBytesType(field) + default: + return dialector.getSchemaCustomType(field) + } +} + +func (dialector Dialector) getSchemaFloatType(field *schema.Field) string { + if field.Precision > 0 { + return fmt.Sprintf("decimal(%d, %d)", field.Precision, field.Scale) + } + + if field.Size <= 32 { + return "float" + } + + return "double" +} + +func (dialector Dialector) getSchemaStringType(field *schema.Field) string { + size := field.Size + if size == 0 { + if dialector.DefaultStringSize > 0 { + size = int(dialector.DefaultStringSize) + } else { + hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != "" + // TEXT, GEOMETRY or JSON column can't have a default value + if field.PrimaryKey || field.HasDefaultValue || hasIndex { + size = 191 // utf8mb4 + } + } + } + + if size >= 65536 && size <= int(math.Pow(2, 24)) { + return "mediumtext" + } + + if size > int(math.Pow(2, 24)) || size <= 0 { + return "longtext" + } + + return fmt.Sprintf("varchar(%d)", size) +} + +func (dialector Dialector) getSchemaTimeType(field *schema.Field) string { + if !dialector.DisableDatetimePrecision && field.Precision == 0 { + field.Precision = *dialector.DefaultDatetimePrecision + } + + var precision string + if field.Precision > 0 { + precision = fmt.Sprintf("(%d)", field.Precision) + } + + if field.NotNull || field.PrimaryKey { + return "datetime" + precision + } + return "datetime" + precision + " NULL" +} + +func (dialector Dialector) getSchemaBytesType(field *schema.Field) string { + if field.Size > 0 && field.Size < 65536 { + return fmt.Sprintf("varbinary(%d)", field.Size) + } + + if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) { + return "mediumblob" + } + + return "longblob" +} + +// autoRandomType +// field.DataType MUST be `schema.Int` or `schema.Uint` +// Judgement logic: +// 1. Is PrimaryKey; +// 2. Has default value; +// 3. Default value is "auto_random()"; +// 4. IGNORE the field.Size, it MUST be bigint; +// 5. CLEAR the default tag, and return true; +// 6. Otherwise, return false. 
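+//
+// Illustrative declaration (a hypothetical model, not part of this package)
+// that satisfies the checks above when running against TiDB:
+//
+//	type Product struct {
+//		ID uint64 `gorm:"primaryKey;default:auto_random()"`
+//	}
+//
+// which autoRandomType maps to the column type `bigint unsigned auto_random`.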
+func autoRandomType(field *schema.Field) (bool, string) { + if field.PrimaryKey && field.HasDefaultValue && + strings.ToLower(strings.TrimSpace(field.DefaultValue)) == AutoRandomTag { + field.DefaultValue = "" + + sqlType := "bigint" + if field.DataType == schema.Uint { + sqlType += " unsigned" + } + sqlType += " auto_random" + return true, sqlType + } + + return false, "" +} + +func (dialector Dialector) getSchemaIntAndUnitType(field *schema.Field) string { + if autoRandom, typeString := autoRandomType(field); autoRandom { + return typeString + } + + constraint := func(sqlType string) string { + if field.DataType == schema.Uint { + sqlType += " unsigned" + } + if field.AutoIncrement { + sqlType += " AUTO_INCREMENT" + } + return sqlType + } + + switch { + case field.Size <= 8: + return constraint("tinyint") + case field.Size <= 16: + return constraint("smallint") + case field.Size <= 24: + return constraint("mediumint") + case field.Size <= 32: + return constraint("int") + default: + return constraint("bigint") + } +} + +func (dialector Dialector) getSchemaCustomType(field *schema.Field) string { + sqlType := string(field.DataType) + + if field.AutoIncrement && !strings.Contains(strings.ToLower(sqlType), " auto_increment") { + sqlType += " AUTO_INCREMENT" + } + + return sqlType +} + +func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error { + return tx.Exec("SAVEPOINT " + name).Error +} + +func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error { + return tx.Exec("ROLLBACK TO SAVEPOINT " + name).Error +} + +// checkVersion newer or equal returns true, old returns false +func checkVersion(newVersion, oldVersion string) bool { + if newVersion == oldVersion { + return true + } + + var ( + versionTrimmerRegexp = regexp.MustCompile(`^(\d+).*$`) + + newVersions = strings.Split(newVersion, ".") + oldVersions = strings.Split(oldVersion, ".") + ) + for idx, nv := range newVersions { + if len(oldVersions) <= idx { + return true + } + + nvi, _ := strconv.Atoi(versionTrimmerRegexp.ReplaceAllString(nv, "$1")) + ovi, _ := strconv.Atoi(versionTrimmerRegexp.ReplaceAllString(oldVersions[idx], "$1")) + if nvi == ovi { + continue + } + return nvi > ovi + } + + return false +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/.gitignore b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/.gitignore new file mode 100644 index 000000000..485dee64b --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/.gitignore @@ -0,0 +1 @@ +.idea diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/License b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/License new file mode 100644 index 000000000..037e1653e --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/License @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2013-NOW Jinzhu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/README.md b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/README.md
new file mode 100644
index 000000000..01ba443e5
--- /dev/null
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/README.md
@@ -0,0 +1,31 @@
+# GORM PostgreSQL Driver
+
+## Quick Start
+
+```go
+import (
+  "gorm.io/driver/postgres"
+  "gorm.io/gorm"
+)
+
+// https://github.com/jackc/pgx
+dsn := "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai"
+db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
+```
+
+## Configuration
+
+```go
+import (
+  "gorm.io/driver/postgres"
+  "gorm.io/gorm"
+)
+
+db, err := gorm.Open(postgres.New(postgres.Config{
+  DSN: "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai", // data source name, refer https://github.com/jackc/pgx
+  PreferSimpleProtocol: true, // disables implicit prepared statement usage; by default pgx automatically uses the extended protocol
+}), &gorm.Config{})
+```
+
+
+Check out [https://gorm.io](https://gorm.io) for details.
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/error_translator.go b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/error_translator.go
new file mode 100644
index 000000000..285494c2d
--- /dev/null
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/error_translator.go
@@ -0,0 +1,44 @@
+package postgres
+
+import (
+	"encoding/json"
+	"github.com/jackc/pgx/v5/pgconn"
+	"gorm.io/gorm"
+)
+
+var errCodes = map[string]string{
+	"uniqueConstraint": "23505",
+}
+
+type ErrMessage struct {
+	Code     string `json:"Code"`
+	Severity string `json:"Severity"`
+	Message  string `json:"Message"`
+}
+
+// Translate converts a driver error into the equivalent native gorm error.
+// Since gorm currently supports both the pgx and pg drivers, checking only for the pgx PgError type is not enough to translate every error, so a JSON-marshal fallback handles the remaining cases.
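+//
+// Sketch of downstream usage (hypothetical model and variable names; assumes
+// the session was opened with gorm.Config{TranslateError: true} so that this
+// method is invoked):
+//
+//	err := db.Create(&user).Error
+//	if errors.Is(err, gorm.ErrDuplicatedKey) {
+//		// a row with this unique key already exists
+//	}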
+func (dialector Dialector) Translate(err error) error { + if pgErr, ok := err.(*pgconn.PgError); ok { + if pgErr.Code == errCodes["uniqueConstraint"] { + return gorm.ErrDuplicatedKey + } + return err + } + + parsedErr, marshalErr := json.Marshal(err) + if marshalErr != nil { + return err + } + + var errMsg ErrMessage + unmarshalErr := json.Unmarshal(parsedErr, &errMsg) + if unmarshalErr != nil { + return err + } + + if errMsg.Code == errCodes["uniqueConstraint"] { + return gorm.ErrDuplicatedKey + } + return err +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/migrator.go b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/migrator.go new file mode 100644 index 000000000..e4d8e9260 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/migrator.go @@ -0,0 +1,771 @@ +package postgres + +import ( + "database/sql" + "fmt" + "regexp" + "strings" + + "github.com/jackc/pgx/v5" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" +) + +const indexSql = ` +select + t.relname as table_name, + i.relname as index_name, + a.attname as column_name, + ix.indisunique as non_unique, + ix.indisprimary as primary +from + pg_class t, + pg_class i, + pg_index ix, + pg_attribute a +where + t.oid = ix.indrelid + and i.oid = ix.indexrelid + and a.attrelid = t.oid + and a.attnum = ANY(ix.indkey) + and t.relkind = 'r' + and t.relname = ? +` + +var typeAliasMap = map[string][]string{ + "int2": {"smallint"}, + "int4": {"integer"}, + "int8": {"bigint"}, + "smallint": {"int2"}, + "integer": {"int4"}, + "bigint": {"int8"}, + "decimal": {"numeric"}, + "numeric": {"decimal"}, +} + +type Migrator struct { + migrator.Migrator +} + +func (m Migrator) CurrentDatabase() (name string) { + m.DB.Raw("SELECT CURRENT_DATABASE()").Scan(&name) + return +} + +func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { + for _, opt := range opts { + str := stmt.Quote(opt.DBName) + if opt.Expression != "" { + str = opt.Expression + } + + if opt.Collate != "" { + str += " COLLATE " + opt.Collate + } + + if opt.Sort != "" { + str += " " + opt.Sort + } + results = append(results, clause.Expr{SQL: str}) + } + return +} + +func (m Migrator) HasIndex(value interface{}, name string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + } + currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) + return m.DB.Raw( + "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema, + ).Scan(&count).Error + }) + + return count > 0 +} + +func (m Migrator) CreateIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + opts := m.BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX " + + if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" { + createIndexSQL += "CONCURRENTLY " + } + + createIndexSQL += "IF NOT EXISTS ? ON ?" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + "(?)" + } else { + createIndexSQL += " ?" 
+ } + + if idx.Where != "" { + createIndexSQL += " WHERE " + idx.Where + } + + return m.DB.Exec(createIndexSQL, values...).Error + } + } + + return fmt.Errorf("failed to create index with name %v", name) + }) +} + +func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec( + "ALTER INDEX ? RENAME TO ?", + clause.Column{Name: oldName}, clause.Column{Name: newName}, + ).Error + }) +} + +func (m Migrator) DropIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + } + + return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error + }) +} + +func (m Migrator) GetTables() (tableList []string, err error) { + currentSchema, _ := m.CurrentSchema(m.DB.Statement, "") + return tableList, m.DB.Raw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error +} + +func (m Migrator) CreateTable(values ...interface{}) (err error) { + if err = m.Migrator.CreateTable(values...); err != nil { + return + } + for _, value := range m.ReorderModels(values, false) { + if err = m.RunWithValue(value, func(stmt *gorm.Statement) error { + if stmt.Schema != nil { + for _, field := range stmt.Schema.FieldsByDBName { + if field.Comment != "" { + if err := m.DB.Exec( + "COMMENT ON COLUMN ?.? IS ?", + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), + ).Error; err != nil { + return err + } + } + } + } + return nil + }); err != nil { + return + } + } + return +} + +func (m Migrator) HasTable(value interface{}) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) + return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error + }) + return count > 0 +} + +func (m Migrator) DropTable(values ...interface{}) error { + values = m.ReorderModels(values, false) + tx := m.DB.Session(&gorm.Session{}) + for i := len(values) - 1; i >= 0; i-- { + if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", m.CurrentTable(stmt)).Error + }); err != nil { + return err + } + } + return nil +} + +func (m Migrator) AddColumn(value interface{}, field string) error { + if err := m.Migrator.AddColumn(value, field); err != nil { + return err + } + m.resetPreparedStmts() + + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(field); field != nil { + if field.Comment != "" { + if err := m.DB.Exec( + "COMMENT ON COLUMN ?.? 
IS ?", + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), + ).Error; err != nil { + return err + } + } + } + } + return nil + }) +} + +func (m Migrator) HasColumn(value interface{}, field string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + name := field + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(field); field != nil { + name = field.DBName + } + } + + currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", + currentSchema, curTable, name, + ).Scan(&count).Error + }) + + return count > 0 +} + +func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + // skip primary field + if !field.PrimaryKey { + if err := m.Migrator.MigrateColumn(value, field, columnType); err != nil { + return err + } + } + + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + var description string + currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) + values := []interface{}{currentSchema, curTable, field.DBName, stmt.Table, currentSchema} + checkSQL := "SELECT description FROM pg_catalog.pg_description " + checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) " + checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = " + checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))" + m.DB.Raw(checkSQL, values...).Scan(&description) + + comment := strings.Trim(field.Comment, "'") + comment = strings.Trim(comment, `"`) + if field.Comment != "" && comment != description { + if err := m.DB.Exec( + "COMMENT ON COLUMN ?.? 
IS ?", + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), + ).Error; err != nil { + return err + } + } + return nil + }) +} + +// AlterColumn alter value's `field` column' type based on schema definition +func (m Migrator) AlterColumn(value interface{}, field string) error { + err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(field); field != nil { + var ( + columnTypes, _ = m.DB.Migrator().ColumnTypes(value) + fieldColumnType *migrator.ColumnType + ) + for _, columnType := range columnTypes { + if columnType.Name() == field.DBName { + fieldColumnType, _ = columnType.(*migrator.ColumnType) + } + } + + fileType := clause.Expr{SQL: m.DataTypeOf(field)} + // check for typeName and SQL name + isSameType := true + if fieldColumnType.DatabaseTypeName() != fileType.SQL { + isSameType = false + // if different, also check for aliases + aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName()) + for _, alias := range aliases { + if strings.HasPrefix(fileType.SQL, alias) { + isSameType = true + break + } + } + } + + // not same, migrate + if !isSameType { + filedColumnAutoIncrement, _ := fieldColumnType.AutoIncrement() + if field.AutoIncrement && filedColumnAutoIncrement { // update + serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL) + if t, _ := fieldColumnType.ColumnType(); t != serialDatabaseType { + if err := m.UpdateSequence(m.DB, stmt, field, serialDatabaseType); err != nil { + return err + } + } + } else if field.AutoIncrement && !filedColumnAutoIncrement { // create + serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL) + if err := m.CreateSequence(m.DB, stmt, field, serialDatabaseType); err != nil { + return err + } + } else if !field.AutoIncrement && filedColumnAutoIncrement { // delete + if err := m.DeleteSequence(m.DB, stmt, field, fileType); err != nil { + return err + } + } else { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?"+m.genUsingExpression(fileType.SQL, fieldColumnType.DatabaseTypeName()), + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, clause.Column{Name: field.DBName}, fileType).Error; err != nil { + return err + } + } + } + + if null, _ := fieldColumnType.Nullable(); null == field.NotNull { + if field.NotNull { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { + return err + } + } else { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { + return err + } + } + } + + if uniq, _ := fieldColumnType.Unique(); !uniq && field.Unique { + idxName := clause.Column{Name: m.DB.Config.NamingStrategy.IndexName(stmt.Table, field.DBName)} + // Not a unique constraint but a unique index + if !m.HasIndex(stmt.Table, idxName.Name) { + if err := m.DB.Exec("ALTER TABLE ? ADD CONSTRAINT ? 
UNIQUE(?)", m.CurrentTable(stmt), idxName, clause.Column{Name: field.DBName}).Error; err != nil { + return err + } + } + } + + if v, ok := fieldColumnType.DefaultValue(); (field.DefaultValueInterface == nil && ok) || v != field.DefaultValue { + if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { + if field.DefaultValueInterface != nil { + defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} + m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)}).Error; err != nil { + return err + } + } else if field.DefaultValue != "(-)" { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { + return err + } + } else { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { + return err + } + } + } + } + return nil + } + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) + + if err != nil { + return err + } + m.resetPreparedStmts() + return nil +} + +func (m Migrator) genUsingExpression(targetType, sourceType string) string { + if targetType == "boolean" { + switch sourceType { + case "int2", "int8", "numeric": + return " USING ?::INT::?" + } + } + return " USING ?::?" +} + +func (m Migrator) HasConstraint(value interface{}, name string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + currentSchema, curTable := m.CurrentSchema(stmt, table) + if constraint != nil { + name = constraint.Name + } else if chk != nil { + name = chk.Name + } + + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?", + currentSchema, curTable, name, + ).Scan(&count).Error + }) + + return count > 0 +} + +func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { + columnTypes = make([]gorm.ColumnType, 0) + err = m.RunWithValue(value, func(stmt *gorm.Statement) error { + var ( + currentDatabase = m.DB.Migrator().CurrentDatabase() + currentSchema, table = m.CurrentSchema(stmt, stmt.Table) + columns, err = m.DB.Raw( + "SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description, c.identity_increment FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? 
AND table_name = ?", + currentDatabase, currentSchema, table).Rows() + ) + + if err != nil { + return err + } + + for columns.Next() { + var ( + column = &migrator.ColumnType{ + PrimaryKeyValue: sql.NullBool{Valid: true}, + UniqueValue: sql.NullBool{Valid: true}, + } + datetimePrecision sql.NullInt64 + radixValue sql.NullInt64 + typeLenValue sql.NullInt64 + identityIncrement sql.NullString + ) + + err = columns.Scan( + &column.NameValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.DecimalSizeValue, + &radixValue, &column.ScaleValue, &datetimePrecision, &typeLenValue, &column.DefaultValueValue, &column.CommentValue, &identityIncrement, + ) + if err != nil { + return err + } + + if typeLenValue.Valid && typeLenValue.Int64 > 0 { + column.LengthValue = typeLenValue + } + + if (strings.HasPrefix(column.DefaultValueValue.String, "nextval('") && + strings.HasSuffix(column.DefaultValueValue.String, "seq'::regclass)")) || (identityIncrement.Valid && identityIncrement.String != "") { + column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true} + column.DefaultValueValue = sql.NullString{} + } + + if column.DefaultValueValue.Valid { + column.DefaultValueValue.String = regexp.MustCompile(`'?(.*)\b'?:+[\w\s]+$`).ReplaceAllString(column.DefaultValueValue.String, "$1") + } + + if datetimePrecision.Valid { + column.DecimalSizeValue = datetimePrecision + } + + columnTypes = append(columnTypes, column) + } + columns.Close() + + // assign sql column type + { + rows, rowsErr := m.GetRows(currentSchema, table) + if rowsErr != nil { + return rowsErr + } + rawColumnTypes, err := rows.ColumnTypes() + if err != nil { + return err + } + for _, columnType := range columnTypes { + for _, c := range rawColumnTypes { + if c.Name() == columnType.Name() { + columnType.(*migrator.ColumnType).SQLColumnType = c + break + } + } + } + rows.Close() + } + + // check primary, unique field + { + columnTypeRows, err := m.DB.Raw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows() + if err != nil { + return err + } + uniqueContraints := map[string]int{} + for columnTypeRows.Next() { + var constraintName string + columnTypeRows.Scan(&constraintName) + uniqueContraints[constraintName]++ + } + columnTypeRows.Close() + + columnTypeRows, err = m.DB.Raw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? 
AND c.table_name = ?", currentDatabase, currentSchema, table).Rows() + if err != nil { + return err + } + for columnTypeRows.Next() { + var name, constraintName, columnType string + columnTypeRows.Scan(&name, &constraintName, &columnType) + for _, c := range columnTypes { + mc := c.(*migrator.ColumnType) + if mc.NameValue.String == name { + switch columnType { + case "PRIMARY KEY": + mc.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true} + case "UNIQUE": + if uniqueContraints[constraintName] == 1 { + mc.UniqueValue = sql.NullBool{Bool: true, Valid: true} + } + } + break + } + } + } + columnTypeRows.Close() + } + + // check column type + { + dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type + FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.oid AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?) + WHERE a.attnum > 0 -- hide internal columns + AND NOT a.attisdropped -- hide deleted columns + AND b.relname = ?`, currentSchema, table).Rows() + if err != nil { + return err + } + + for dataTypeRows.Next() { + var name, dataType string + dataTypeRows.Scan(&name, &dataType) + for _, c := range columnTypes { + mc := c.(*migrator.ColumnType) + if mc.NameValue.String == name { + mc.ColumnTypeValue = sql.NullString{String: dataType, Valid: true} + // Handle array type: _text -> text[] , _int4 -> integer[] + // Not support array size limits and array size limits because: + // https://www.postgresql.org/docs/current/arrays.html#ARRAYS-DECLARATION + if strings.HasPrefix(mc.DataTypeValue.String, "_") { + mc.DataTypeValue = sql.NullString{String: dataType, Valid: true} + } + break + } + } + } + dataTypeRows.Close() + } + + return err + }) + return +} + +func (m Migrator) GetRows(currentSchema interface{}, table interface{}) (*sql.Rows, error) { + name := table.(string) + if _, ok := currentSchema.(string); ok { + name = fmt.Sprintf("%v.%v", currentSchema, table) + } + + return m.DB.Session(&gorm.Session{}).Table(name).Limit(1).Scopes(func(d *gorm.DB) *gorm.DB { + dialector, _ := m.Dialector.(Dialector) + // use simple protocol + if !m.DB.PrepareStmt && (dialector.Config != nil && (dialector.Config.DriverName == "" || dialector.Config.DriverName == "pgx")) { + d.Statement.Vars = append([]interface{}{pgx.QueryExecModeSimpleProtocol}, d.Statement.Vars...) + } + return d + }).Rows() +} + +func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (interface{}, interface{}) { + if strings.Contains(table, ".") { + if tables := strings.Split(table, `.`); len(tables) == 2 { + return tables[0], tables[1] + } + } + + if stmt.TableExpr != nil { + if tables := strings.Split(stmt.TableExpr.SQL, `"."`); len(tables) == 2 { + return strings.TrimPrefix(tables[0], `"`), table + } + } + return clause.Expr{SQL: "CURRENT_SCHEMA()"}, table +} + +func (m Migrator) CreateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field, + serialDatabaseType string) (err error) { + + _, table := m.CurrentSchema(stmt, stmt.Table) + tableName := table.(string) + + sequenceName := strings.Join([]string{tableName, field.DBName, "seq"}, "_") + if err = tx.Exec(`CREATE SEQUENCE IF NOT EXISTS ? AS ?`, clause.Expr{SQL: sequenceName}, + clause.Expr{SQL: serialDatabaseType}).Error; err != nil { + return err + } + + if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? 
SET DEFAULT nextval('?')", + clause.Expr{SQL: tableName}, clause.Expr{SQL: field.DBName}, clause.Expr{SQL: sequenceName}).Error; err != nil { + return err + } + + if err := tx.Exec("ALTER SEQUENCE ? OWNED BY ?.?", + clause.Expr{SQL: sequenceName}, clause.Expr{SQL: tableName}, clause.Expr{SQL: field.DBName}).Error; err != nil { + return err + } + return +} + +func (m Migrator) UpdateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field, + serialDatabaseType string) (err error) { + + sequenceName, err := m.getColumnSequenceName(tx, stmt, field) + if err != nil { + return err + } + + if err = tx.Exec(`ALTER SEQUENCE IF EXISTS ? AS ?`, clause.Expr{SQL: sequenceName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil { + return err + } + + if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", + m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil { + return err + } + return +} + +func (m Migrator) DeleteSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field, + fileType clause.Expr) (err error) { + + sequenceName, err := m.getColumnSequenceName(tx, stmt, field) + if err != nil { + return err + } + + if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error; err != nil { + return err + } + + if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", + m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}).Error; err != nil { + return err + } + + if err = tx.Exec(`DROP SEQUENCE IF EXISTS ?`, clause.Expr{SQL: sequenceName}).Error; err != nil { + return err + } + + return +} + +func (m Migrator) getColumnSequenceName(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field) ( + sequenceName string, err error) { + _, table := m.CurrentSchema(stmt, stmt.Table) + + // DefaultValueValue is reset by ColumnTypes, search again. + var columnDefault string + err = tx.Raw( + `SELECT column_default FROM information_schema.columns WHERE table_name = ? 
AND column_name = ?`, + table, field.DBName).Scan(&columnDefault).Error + + if err != nil { + return + } + + sequenceName = strings.TrimSuffix( + strings.TrimPrefix(columnDefault, `nextval('`), + `'::regclass)`, + ) + return +} + +func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) { + indexes := make([]gorm.Index, 0) + + err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + result := make([]*Index, 0) + scanErr := m.DB.Raw(indexSql, stmt.Table).Scan(&result).Error + if scanErr != nil { + return scanErr + } + indexMap := groupByIndexName(result) + for _, idx := range indexMap { + tempIdx := &migrator.Index{ + TableName: idx[0].TableName, + NameValue: idx[0].IndexName, + PrimaryKeyValue: sql.NullBool{ + Bool: idx[0].Primary, + Valid: true, + }, + UniqueValue: sql.NullBool{ + Bool: idx[0].NonUnique, + Valid: true, + }, + } + for _, x := range idx { + tempIdx.ColumnList = append(tempIdx.ColumnList, x.ColumnName) + } + indexes = append(indexes, tempIdx) + } + return nil + }) + return indexes, err +} + +// Index table index info +type Index struct { + TableName string `gorm:"column:table_name"` + ColumnName string `gorm:"column:column_name"` + IndexName string `gorm:"column:index_name"` + NonUnique bool `gorm:"column:non_unique"` + Primary bool `gorm:"column:primary"` +} + +func groupByIndexName(indexList []*Index) map[string][]*Index { + columnIndexMap := make(map[string][]*Index, len(indexList)) + for _, idx := range indexList { + columnIndexMap[idx.IndexName] = append(columnIndexMap[idx.IndexName], idx) + } + return columnIndexMap +} + +func (m Migrator) GetTypeAliases(databaseTypeName string) []string { + return typeAliasMap[databaseTypeName] +} + +// should reset prepared stmts when table changed +func (m Migrator) resetPreparedStmts() { + if m.DB.PrepareStmt { + if pdb, ok := m.DB.ConnPool.(*gorm.PreparedStmtDB); ok { + pdb.Reset() + } + } +} + +func (m Migrator) DropColumn(dst interface{}, field string) error { + if err := m.Migrator.DropColumn(dst, field); err != nil { + return err + } + + m.resetPreparedStmts() + return nil +} + +func (m Migrator) RenameColumn(dst interface{}, oldName, field string) error { + if err := m.Migrator.RenameColumn(dst, oldName, field); err != nil { + return err + } + + m.resetPreparedStmts() + return nil +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/postgres.go b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/postgres.go new file mode 100644 index 000000000..dbeabf561 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/postgres.go @@ -0,0 +1,249 @@ +package postgres + +import ( + "database/sql" + "fmt" + "github.com/jackc/pgx/v5" + "regexp" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/stdlib" + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" +) + +type Dialector struct { + *Config +} + +type Config struct { + DriverName string + DSN string + PreferSimpleProtocol bool + WithoutReturning bool + Conn gorm.ConnPool +} + +func Open(dsn string) gorm.Dialector { + return &Dialector{&Config{DSN: dsn}} +} + +func New(config Config) gorm.Dialector { + return &Dialector{Config: &config} +} + +func (dialector Dialector) Name() string { + return "postgres" +} + +var timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )") + +func (dialector Dialector) Initialize(db *gorm.DB) (err error) { + callbackConfig := &callbacks.Config{ + CreateClauses: []string{"INSERT", 
"VALUES", "ON CONFLICT"}, + UpdateClauses: []string{"UPDATE", "SET", "WHERE"}, + DeleteClauses: []string{"DELETE", "FROM", "WHERE"}, + } + // register callbacks + if !dialector.WithoutReturning { + callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING") + callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING") + callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING") + } + callbacks.RegisterDefaultCallbacks(db, callbackConfig) + + if dialector.Conn != nil { + db.ConnPool = dialector.Conn + } else if dialector.DriverName != "" { + db.ConnPool, err = sql.Open(dialector.DriverName, dialector.Config.DSN) + } else { + var config *pgx.ConnConfig + + config, err = pgx.ParseConfig(dialector.Config.DSN) + if err != nil { + return + } + if dialector.Config.PreferSimpleProtocol { + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol + } + result := timeZoneMatcher.FindStringSubmatch(dialector.Config.DSN) + if len(result) > 2 { + config.RuntimeParams["timezone"] = result[2] + } + db.ConnPool = stdlib.OpenDB(*config) + } + return +} + +func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { + return Migrator{migrator.Migrator{Config: migrator.Config{ + DB: db, + Dialector: dialector, + CreateIndexAfterCreateTable: true, + }}} +} + +func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression { + return clause.Expr{SQL: "DEFAULT"} +} + +func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { + writer.WriteByte('$') + writer.WriteString(strconv.Itoa(len(stmt.Vars))) +} + +func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + var ( + underQuoted, selfQuoted bool + continuousBacktick int8 + shiftDelimiter int8 + ) + + for _, v := range []byte(str) { + switch v { + case '"': + continuousBacktick++ + if continuousBacktick == 2 { + writer.WriteString(`""`) + continuousBacktick = 0 + } + case '.': + if continuousBacktick > 0 || !selfQuoted { + shiftDelimiter = 0 + underQuoted = false + continuousBacktick = 0 + writer.WriteByte('"') + } + writer.WriteByte(v) + continue + default: + if shiftDelimiter-continuousBacktick <= 0 && !underQuoted { + writer.WriteByte('"') + underQuoted = true + if selfQuoted = continuousBacktick > 0; selfQuoted { + continuousBacktick -= 1 + } + } + + for ; continuousBacktick > 0; continuousBacktick -= 1 { + writer.WriteString(`""`) + } + + writer.WriteByte(v) + } + shiftDelimiter++ + } + + if continuousBacktick > 0 && !selfQuoted { + writer.WriteString(`""`) + } + writer.WriteByte('"') +} + +var numericPlaceholder = regexp.MustCompile(`\$(\d+)`) + +func (dialector Dialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) 
+} + +func (dialector Dialector) DataTypeOf(field *schema.Field) string { + switch field.DataType { + case schema.Bool: + return "boolean" + case schema.Int, schema.Uint: + size := field.Size + if field.DataType == schema.Uint { + size++ + } + if field.AutoIncrement { + switch { + case size <= 16: + return "smallserial" + case size <= 32: + return "serial" + default: + return "bigserial" + } + } else { + switch { + case size <= 16: + return "smallint" + case size <= 32: + return "integer" + default: + return "bigint" + } + } + case schema.Float: + if field.Precision > 0 { + if field.Scale > 0 { + return fmt.Sprintf("numeric(%d, %d)", field.Precision, field.Scale) + } + return fmt.Sprintf("numeric(%d)", field.Precision) + } + return "decimal" + case schema.String: + if field.Size > 0 { + return fmt.Sprintf("varchar(%d)", field.Size) + } + return "text" + case schema.Time: + if field.Precision > 0 { + return fmt.Sprintf("timestamptz(%d)", field.Precision) + } + return "timestamptz" + case schema.Bytes: + return "bytea" + default: + return dialector.getSchemaCustomType(field) + } +} + +func (dialector Dialector) getSchemaCustomType(field *schema.Field) string { + sqlType := string(field.DataType) + + if field.AutoIncrement && !strings.Contains(strings.ToLower(sqlType), "serial") { + size := field.Size + if field.GORMDataType == schema.Uint { + size++ + } + switch { + case size <= 16: + sqlType = "smallserial" + case size <= 32: + sqlType = "serial" + default: + sqlType = "bigserial" + } + } + + return sqlType +} + +func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error { + tx.Exec("SAVEPOINT " + name) + return nil +} + +func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error { + tx.Exec("ROLLBACK TO SAVEPOINT " + name) + return nil +} + +func getSerialDatabaseType(s string) (dbType string, ok bool) { + switch s { + case "smallserial": + return "smallint", true + case "serial": + return "integer", true + case "bigserial": + return "bigint", true + default: + return "", false + } +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/.gitignore b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/.gitignore new file mode 100644 index 000000000..727333260 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/.gitignore @@ -0,0 +1,7 @@ +TODO* +documents +coverage.txt +_book +.idea +vendor +.vscode diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/.golangci.yml b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/.golangci.yml new file mode 100644 index 000000000..b88bf6722 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/.golangci.yml @@ -0,0 +1,20 @@ +linters: + enable: + - cyclop + - exportloopref + - gocritic + - gosec + - ineffassign + - misspell + - prealloc + - unconvert + - unparam + - goimports + - whitespace + +linters-settings: + whitespace: + multi-func: true + goimports: + local-prefixes: gorm.io/gorm + diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/LICENSE b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/LICENSE new file mode 100644 index 000000000..037e1653e --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2013-NOW Jinzhu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/README.md b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/README.md
new file mode 100644
index 000000000..85ad3050c
--- /dev/null
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/README.md
@@ -0,0 +1,44 @@
+# GORM
+
+The fantastic ORM library for Golang, which aims to be developer friendly.
+
+[![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm)
+[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions)
+[![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT)
+[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc)
+
+## Overview
+
+* Full-Featured ORM
+* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism, Single-table inheritance)
+* Hooks (Before/After Create/Save/Update/Delete/Find)
+* Eager loading with `Preload`, `Joins`
+* Transactions, Nested Transactions, Save Point, Rollback to Saved Point
+* Context, Prepared Statement Mode, DryRun Mode
+* Batch Insert, FindInBatches, Find To Map
+* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr
+* Composite Primary Key
+* Auto Migrations
+* Logger
+* Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus…
+* Every feature comes with tests
+* Developer Friendly
+
+## Getting Started
+
+* GORM Guides [https://gorm.io](https://gorm.io)
+* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
+
+## Contributing
+
+[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
+
+## Contributors
+
+[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework!
+
+## License
+
+© Jinzhu, 2013~time.Now
+
+Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/association.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/association.go
new file mode 100644
index 000000000..7c93ebea0
--- /dev/null
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/association.go
@@ -0,0 +1,579 @@
+package gorm
+
+import (
+	"fmt"
+	"reflect"
+	"strings"
+
+	"gorm.io/gorm/clause"
+	"gorm.io/gorm/schema"
+	"gorm.io/gorm/utils"
+)
+
+// Association Mode contains helper methods for working with relationships.
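+//
+// A minimal usage sketch (editor's addition; it assumes a User model with a
+// many-to-many "Languages" relation):
+//
+//	var languages []Language
+//	db.Model(&user).Association("Languages").Find(&languages)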
+type Association struct { + DB *DB + Relationship *schema.Relationship + Unscope bool + Error error +} + +func (db *DB) Association(column string) *Association { + association := &Association{DB: db} + table := db.Statement.Table + + if err := db.Statement.Parse(db.Statement.Model); err == nil { + db.Statement.Table = table + association.Relationship = db.Statement.Schema.Relationships.Relations[column] + + if association.Relationship == nil { + association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column) + } + + db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) + for db.Statement.ReflectValue.Kind() == reflect.Ptr { + db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() + } + } else { + association.Error = err + } + + return association +} + +func (association *Association) Unscoped() *Association { + return &Association{ + DB: association.DB, + Relationship: association.Relationship, + Error: association.Error, + Unscope: true, + } +} + +func (association *Association) Find(out interface{}, conds ...interface{}) error { + if association.Error == nil { + association.Error = association.buildCondition().Find(out, conds...).Error + } + return association.Error +} + +func (association *Association) Append(values ...interface{}) error { + if association.Error == nil { + switch association.Relationship.Type { + case schema.HasOne, schema.BelongsTo: + if len(values) > 0 { + association.Error = association.Replace(values...) + } + default: + association.saveAssociation( /*clear*/ false, values...) + } + } + + return association.Error +} + +func (association *Association) Replace(values ...interface{}) error { + if association.Error == nil { + reflectValue := association.DB.Statement.ReflectValue + rel := association.Relationship + + var oldBelongsToExpr clause.Expression + // we have to record the old BelongsTo value + if association.Unscope && rel.Type == schema.BelongsTo { + var foreignFields []*schema.Field + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs) + oldBelongsToExpr = clause.IN{Column: column, Values: values} + } + } + + // save associations + if association.saveAssociation( /*clear*/ true, values...); association.Error != nil { + return association.Error + } + + // set old associations's foreign key to null + switch rel.Type { + case schema.BelongsTo: + if len(values) == 0 { + updateMap := map[string]interface{}{} + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) + } + case reflect.Struct: + association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) + } + + for _, ref := range rel.References { + updateMap[ref.ForeignKey.DBName] = nil + } + + association.Error = association.DB.UpdateColumns(updateMap).Error + } + if association.Unscope && oldBelongsToExpr != nil { + association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error + } + case schema.HasOne, schema.HasMany: + var ( + primaryFields 
[]*schema.Field + foreignKeys []string + updateMap = map[string]interface{}{} + relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel}) + modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) + + if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { + if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { + tx.Not(clause.IN{Column: column, Values: values}) + } + } + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) + updateMap[ref.ForeignKey.DBName] = nil + } else if ref.PrimaryValue != "" { + tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } + } + + if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) + if association.Unscope { + association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error + } else { + association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error + } + } + case schema.Many2Many: + var ( + primaryFields, relPrimaryFields []*schema.Field + joinPrimaryKeys, joinRelPrimaryKeys []string + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) + + for _, ref := range rel.References { + if ref.PrimaryValue == "" { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) + } else { + relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) + joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) + } + } else { + tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } + } + + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) + if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { + tx.Where(clause.IN{Column: column, Values: values}) + } else { + return ErrPrimaryKeyRequired + } + + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) + if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { + tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) + } + + association.Error = tx.Delete(modelValue).Error + } + } + return association.Error +} + +func (association *Association) Delete(values ...interface{}) error { + if association.Error == nil { + var ( + reflectValue = association.DB.Statement.ReflectValue + rel = association.Relationship + primaryFields []*schema.Field + foreignKeys []string + updateAttrs = map[string]interface{}{} + conds []clause.Expression + ) + + for _, ref := range rel.References { + if ref.PrimaryValue == "" { + primaryFields = append(primaryFields, ref.PrimaryKey) + foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) + updateAttrs[ref.ForeignKey.DBName] = nil + } else { + conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } + } + + switch rel.Type { 
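+		// Editor's note: each relationship type is detached differently below:
+		// belongs-to nulls the owner's foreign key, has-one/has-many null the
+		// related rows' foreign keys (or delete the rows when unscoped), and
+		// many2many deletes the join-table rows.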
+ case schema.BelongsTo: + associationDB := association.DB.Session(&Session{}) + tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface()) + + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) + if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 { + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + } else { + return ErrPrimaryKeyRequired + } + + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields) + relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) + + association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + if association.Unscope { + var foreignFields []*schema.Field + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs) + association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error + } + } + case schema.HasOne, schema.HasMany: + model := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := association.DB.Model(model) + + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) + if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 { + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + } else { + return ErrPrimaryKeyRequired + } + + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) + relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) + + if association.Unscope { + association.Error = tx.Clauses(conds...).Delete(model).Error + } else { + association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + } + case schema.Many2Many: + var ( + primaryFields, relPrimaryFields []*schema.Field + joinPrimaryKeys, joinRelPrimaryKeys []string + joinValue = reflect.New(rel.JoinTable.ModelType).Interface() + ) + + for _, ref := range rel.References { + if ref.PrimaryValue == "" { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) + } else { + relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) + joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) + } + } else { + conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } + } + + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) + if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 { + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + } else { + return ErrPrimaryKeyRequired + } + + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) + 
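+				// Editor's note: rvs holds the primary-key values extracted from the
+				// arguments passed to Delete; they are converted into join-table
+				// conditions below so only the matching join rows are removed.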
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) + + association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error + } + + if association.Error == nil { + // clean up deleted values's foreign key + relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) + + cleanUpDeletedRelations := func(data reflect.Value) { + if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero { + fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data)) + primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) + + switch fieldValue.Kind() { + case reflect.Slice, reflect.Array: + validFieldValues := reflect.Zero(rel.Field.IndirectFieldType) + for i := 0; i < fieldValue.Len(); i++ { + for idx, field := range rel.FieldSchema.PrimaryFields { + primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i)) + } + + if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok { + validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i)) + } + } + + association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface()) + case reflect.Struct: + for idx, field := range rel.FieldSchema.PrimaryFields { + primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue) + } + + if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { + if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { + break + } + + if rel.JoinTable == nil { + for _, ref := range rel.References { + if ref.OwnPrimaryKey || ref.PrimaryValue != "" { + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } else { + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } + } + } + } + } + } + } + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i))) + } + case reflect.Struct: + cleanUpDeletedRelations(reflectValue) + } + } + } + + return association.Error +} + +func (association *Association) Clear() error { + return association.Replace() +} + +func (association *Association) Count() (count int64) { + if association.Error == nil { + association.Error = association.buildCondition().Count(&count).Error + } + return +} + +type assignBack struct { + Source reflect.Value + Index int + Dest reflect.Value +} + +func (association *Association) saveAssociation(clear bool, values ...interface{}) { + var ( + reflectValue = association.DB.Statement.ReflectValue + assignBacks []assignBack // assign association values back to arguments after save + ) + + appendToRelations := func(source, rv reflect.Value, clear bool) { + switch association.Relationship.Type { + case schema.HasOne, schema.BelongsTo: + switch rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() > 0 { + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface()) + + if association.Relationship.Field.FieldType.Kind() == 
reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) + } + } + case reflect.Struct: + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface()) + + if association.Relationship.Field.FieldType.Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) + } + } + case schema.HasMany, schema.Many2Many: + elemType := association.Relationship.Field.IndirectFieldType.Elem() + oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) + var fieldValue reflect.Value + if clear { + fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap()) + } else { + fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap()) + reflect.Copy(fieldValue, oldFieldValue) + } + + appendToFieldValues := func(ev reflect.Value) { + if ev.Type().AssignableTo(elemType) { + fieldValue = reflect.Append(fieldValue, ev) + } else if ev.Type().Elem().AssignableTo(elemType) { + fieldValue = reflect.Append(fieldValue, ev.Elem()) + } else { + association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name) + } + + if elemType.Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()}) + } + } + + switch rv.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) + } + case reflect.Struct: + appendToFieldValues(rv.Addr()) + } + + if association.Error == nil { + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface()) + } + } + } + + selectedSaveColumns := []string{association.Relationship.Name} + omitColumns := []string{} + selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false) + for name, ok := range selectColumns { + columnName := "" + if strings.HasPrefix(name, association.Relationship.Name) { + if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" { + columnName = name + } + } else if strings.HasPrefix(name, clause.Associations) { + columnName = name + } + + if columnName != "" { + if ok { + selectedSaveColumns = append(selectedSaveColumns, columnName) + } else { + omitColumns = append(omitColumns, columnName) + } + } + } + + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey { + selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name) + } + } + + associationDB := association.DB.Session(&Session{}).Model(nil) + if !association.DB.FullSaveAssociations { + associationDB.Select(selectedSaveColumns) + } + if len(omitColumns) > 0 { + associationDB.Omit(omitColumns...) 
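+		// Editor's note: restricting the session to the selected/omitted columns
+		// keeps association saves consistent with the caller's Select/Omit settings.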
+ } + associationDB = associationDB.Session(&Session{}) + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + if len(values) != reflectValue.Len() { + // clear old data + if clear && len(values) == 0 { + for i := 0; i < reflectValue.Len(); i++ { + if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { + association.Error = err + break + } + + if association.Relationship.JoinTable == nil { + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { + if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { + association.Error = err + break + } + } + } + } + } + break + } + + association.Error = ErrInvalidValueOfLength + return + } + + for i := 0; i < reflectValue.Len(); i++ { + appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) + + // TODO support save slice data, sql with case? + association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error + } + case reflect.Struct: + // clear old data + if clear && len(values) == 0 { + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + + if association.Relationship.JoinTable == nil && association.Error == nil { + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } + } + } + } + + for idx, value := range values { + rv := reflect.Indirect(reflect.ValueOf(value)) + appendToRelations(reflectValue, rv, clear && idx == 0) + } + + if len(values) > 0 { + association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error + } + } + + for _, assignBack := range assignBacks { + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source)) + if assignBack.Index > 0 { + reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) + } else { + reflect.Indirect(assignBack.Dest).Set(fieldValue) + } + } +} + +func (association *Association) buildCondition() *DB { + var ( + queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue) + modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) + + if association.Relationship.JoinTable != nil { + if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { + joinStmt := Statement{DB: tx, Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + joinStmt.AddClause(queryClause) + } + joinStmt.Build("WHERE") + if len(joinStmt.SQL.String()) > 0 { + tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) + } + } + + tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{ + Table: clause.Table{Name: 
association.Relationship.JoinTable.Table}, + ON: clause.Where{Exprs: queryConds}, + }}}) + } else { + tx.Clauses(clause.Where{Exprs: queryConds}) + } + + return tx +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks.go new file mode 100644 index 000000000..195d17203 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks.go @@ -0,0 +1,341 @@ +package gorm + +import ( + "context" + "errors" + "fmt" + "reflect" + "sort" + "time" + + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +func initializeCallbacks(db *DB) *callbacks { + return &callbacks{ + processors: map[string]*processor{ + "create": {db: db}, + "query": {db: db}, + "update": {db: db}, + "delete": {db: db}, + "row": {db: db}, + "raw": {db: db}, + }, + } +} + +// callbacks gorm callbacks manager +type callbacks struct { + processors map[string]*processor +} + +type processor struct { + db *DB + Clauses []string + fns []func(*DB) + callbacks []*callback +} + +type callback struct { + name string + before string + after string + remove bool + replace bool + match func(*DB) bool + handler func(*DB) + processor *processor +} + +func (cs *callbacks) Create() *processor { + return cs.processors["create"] +} + +func (cs *callbacks) Query() *processor { + return cs.processors["query"] +} + +func (cs *callbacks) Update() *processor { + return cs.processors["update"] +} + +func (cs *callbacks) Delete() *processor { + return cs.processors["delete"] +} + +func (cs *callbacks) Row() *processor { + return cs.processors["row"] +} + +func (cs *callbacks) Raw() *processor { + return cs.processors["raw"] +} + +func (p *processor) Execute(db *DB) *DB { + // call scopes + for len(db.Statement.scopes) > 0 { + db = db.executeScopes() + } + + var ( + curTime = time.Now() + stmt = db.Statement + resetBuildClauses bool + ) + + if len(stmt.BuildClauses) == 0 { + stmt.BuildClauses = p.Clauses + resetBuildClauses = true + } + + if optimizer, ok := db.Statement.Dest.(StatementModifier); ok { + optimizer.ModifyStatement(stmt) + } + + // assign model values + if stmt.Model == nil { + stmt.Model = stmt.Dest + } else if stmt.Dest == nil { + stmt.Dest = stmt.Model + } + + // parse model values + if stmt.Model != nil { + if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) { + if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil { + db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err)) + } else { + db.AddError(err) + } + } + } + + // assign stmt.ReflectValue + if stmt.Dest != nil { + stmt.ReflectValue = reflect.ValueOf(stmt.Dest) + for stmt.ReflectValue.Kind() == reflect.Ptr { + if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() { + stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem())) + } + + stmt.ReflectValue = stmt.ReflectValue.Elem() + } + if !stmt.ReflectValue.IsValid() { + db.AddError(ErrInvalidValue) + } + } + + for _, f := range p.fns { + f(db) + } + + if stmt.SQL.Len() > 0 { + db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { + sql, vars := stmt.SQL.String(), stmt.Vars + if filter, ok := db.Logger.(ParamsFilter); ok { + sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...) 
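+					// Editor's note: a logger implementing the ParamsFilter interface
+					// can redact sensitive bind parameters before the SQL is traced.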
+			}
+			return db.Dialector.Explain(sql, vars...), db.RowsAffected
+		}, db.Error)
+	}
+
+	if !stmt.DB.DryRun {
+		stmt.SQL.Reset()
+		stmt.Vars = nil
+	}
+
+	if resetBuildClauses {
+		stmt.BuildClauses = nil
+	}
+
+	return db
+}
+
+func (p *processor) Get(name string) func(*DB) {
+	for i := len(p.callbacks) - 1; i >= 0; i-- {
+		if v := p.callbacks[i]; v.name == name && !v.remove {
+			return v.handler
+		}
+	}
+	return nil
+}
+
+func (p *processor) Before(name string) *callback {
+	return &callback{before: name, processor: p}
+}
+
+func (p *processor) After(name string) *callback {
+	return &callback{after: name, processor: p}
+}
+
+func (p *processor) Match(fc func(*DB) bool) *callback {
+	return &callback{match: fc, processor: p}
+}
+
+func (p *processor) Register(name string, fn func(*DB)) error {
+	return (&callback{processor: p}).Register(name, fn)
+}
+
+func (p *processor) Remove(name string) error {
+	return (&callback{processor: p}).Remove(name)
+}
+
+func (p *processor) Replace(name string, fn func(*DB)) error {
+	return (&callback{processor: p}).Replace(name, fn)
+}
+
+func (p *processor) compile() (err error) {
+	var callbacks []*callback
+	for _, callback := range p.callbacks {
+		if callback.match == nil || callback.match(p.db) {
+			callbacks = append(callbacks, callback)
+		}
+	}
+	p.callbacks = callbacks
+
+	if p.fns, err = sortCallbacks(p.callbacks); err != nil {
+		p.db.Logger.Error(context.Background(), "Got error when compiling callbacks, got %v", err)
+	}
+	return
+}
+
+func (c *callback) Before(name string) *callback {
+	c.before = name
+	return c
+}
+
+func (c *callback) After(name string) *callback {
+	c.after = name
+	return c
+}
+
+func (c *callback) Register(name string, fn func(*DB)) error {
+	c.name = name
+	c.handler = fn
+	c.processor.callbacks = append(c.processor.callbacks, c)
+	return c.processor.compile()
+}
+
+func (c *callback) Remove(name string) error {
+	c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())
+	c.name = name
+	c.remove = true
+	c.processor.callbacks = append(c.processor.callbacks, c)
+	return c.processor.compile()
+}
+
+func (c *callback) Replace(name string, fn func(*DB)) error {
+	c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())
+	c.name = name
+	c.handler = fn
+	c.replace = true
+	c.processor.callbacks = append(c.processor.callbacks, c)
+	return c.processor.compile()
+}
+
+// getRIndex returns the rightmost index of str in strs, or -1 if it is absent
+func getRIndex(strs []string, str string) int {
+	for i := len(strs) - 1; i >= 0; i-- {
+		if strs[i] == str {
+			return i
+		}
+	}
+	return -1
+}
+
+func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
+	var (
+		names, sorted []string
+		sortCallback  func(*callback) error
+	)
+	sort.SliceStable(cs, func(i, j int) bool {
+		if cs[j].before == "*" && cs[i].before != "*" {
+			return true
+		}
+		if cs[j].after == "*" && cs[i].after != "*" {
+			return true
+		}
+		return false
+	})
+
+	for _, c := range cs {
+		// warn if the callback name already exists
+		if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
+			c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
+		}
+		names = append(names, c.name)
+	}
+
+	sortCallback = func(c *callback) error {
+		if c.before != "" { // if a before callback is defined
+			if c.before == "*" && len(sorted) > 0 {
+				if curIdx := getRIndex(sorted, c.name); curIdx == 
-1 { + sorted = append([]string{c.name}, sorted...) + } + } else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + // if before callback already sorted, append current callback just after it + sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) + } else if curIdx > sortedIdx { + return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before) + } + } else if idx := getRIndex(names, c.before); idx != -1 { + // if before callback exists + cs[idx].after = c.name + } + } + + if c.after != "" { // if defined after callback + if c.after == "*" && len(sorted) > 0 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + sorted = append(sorted, c.name) + } + } else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + // if after callback sorted, append current callback to last + sorted = append(sorted, c.name) + } else if curIdx < sortedIdx { + return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after) + } + } else if idx := getRIndex(names, c.after); idx != -1 { + // if after callback exists but haven't sorted + // set after callback's before callback to current callback + after := cs[idx] + + if after.before == "" { + after.before = c.name + } + + if err := sortCallback(after); err != nil { + return err + } + + if err := sortCallback(c); err != nil { + return err + } + } + } + + // if current callback haven't been sorted, append it to last + if getRIndex(sorted, c.name) == -1 { + sorted = append(sorted, c.name) + } + + return nil + } + + for _, c := range cs { + if err = sortCallback(c); err != nil { + return + } + } + + for _, name := range sorted { + if idx := getRIndex(names, name); !cs[idx].remove { + fns = append(fns, cs[idx].handler) + } + } + + return +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/associations.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/associations.go new file mode 100644 index 000000000..f3cd464ae --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/associations.go @@ -0,0 +1,453 @@ +package callbacks + +import ( + "reflect" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +func SaveBeforeAssociations(create bool) func(db *gorm.DB) { + return func(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create) + + // Save Belongs To associations + for _, rel := range db.Statement.Schema.Relationships.BelongsTo { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + setupReferences := func(obj reflect.Value, elem reflect.Value) { + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv)) + + if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { + dest[ref.ForeignKey.DBName] = pv + if _, ok := dest[rel.Name]; ok { + dest[rel.Name] = elem.Interface() + } + } + } + } + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + rValLen = db.Statement.ReflectValue.Len() + objs = make([]reflect.Value, 0, rValLen) + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) + + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + 
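+					// collect every non-zero belongs-to value below; each distinct
+					// value is saved exactly once, and setupReferences then copies
+					// the saved primary keys back onto the owners' foreign keys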
+ elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + identityMap := map[string]bool{} + for i := 0; i < rValLen; i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() != reflect.Struct { + break + } + if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value + if !isPtr { + rv = rv.Addr() + } + objs = append(objs, obj) + elems = reflect.Append(elems, rv) + + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } + } + cacheKey := utils.ToStringKey(relPrimaryValues...) + if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + + distinctElems = reflect.Append(distinctElems, rv) + } + } + } + + if elems.Len() > 0 { + if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil { + for i := 0; i < elems.Len(); i++ { + setupReferences(objs[i], elems.Index(i)) + } + } + } + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } + + if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil { + setupReferences(db.Statement.ReflectValue, rv) + } + } + } + } + } + } +} + +func SaveAfterAssociations(create bool) func(db *gorm.DB) { + return func(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create) + + // Save Has One associations + for _, rel := range db.Statement.Schema.Relationships.HasOne { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) + + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue)) + } + } + + elems = reflect.Append(elems, rv) + } + } + } + + if elems.Len() > 0 { + assignmentColumns := make([]string, 0, len(rel.References)) + for _, ref := range rel.References { + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) + } + case reflect.Struct: + if _, zero := 
rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) + if f.Kind() != reflect.Ptr { + f = f.Addr() + } + + assignmentColumns := make([]string, 0, len(rel.References)) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue)) + } + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns) + } + } + } + + // Save Has Many associations + for _, rel := range db.Statement.Schema.Relationships.HasMany { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := fieldType.Kind() == reflect.Ptr + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + identityMap := map[string]bool{} + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) + + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue)) + } + } + + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } + } + + cacheKey := utils.ToStringKey(relPrimaryValues...) 
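+						// stage only rows not seen before; rows with an incomplete
+						// primary key cannot be deduplicated and are always staged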
+ if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } + } + } + } + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } + } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) + } + + if elems.Len() > 0 { + assignmentColumns := make([]string, 0, len(rel.References)) + for _, ref := range rel.References { + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) + } + } + + // Save Many2Many associations + for _, rel := range db.Statement.Schema.Relationships.Many2Many { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := fieldType.Kind() == reflect.Ptr + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) + objs := []reflect.Value{} + + appendToJoins := func(obj reflect.Value, elem reflect.Value) { + joinValue := reflect.New(rel.JoinTable.ModelType) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue)) + } else { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)) + } + } + joins = reflect.Append(joins, joinValue) + } + + identityMap := map[string]bool{} + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) + if !isPtr { + elem = elem.Addr() + } + objs = append(objs, v) + elems = reflect.Append(elems, elem) + + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } + } + + cacheKey := utils.ToStringKey(relPrimaryValues...) 
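+						// same dedup rule as for has many: save each distinct
+						// join target once, keyed by its primary-key string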
+ if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + + distinctElems = reflect.Append(distinctElems, elem) + } + + } + } + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } + } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) + } + + // optimize elems of reflect value length + if elemLen := elems.Len(); elemLen > 0 { + if v, ok := selectColumns[rel.Name+".*"]; !ok || v { + saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) + } + + for i := 0; i < elemLen; i++ { + appendToJoins(objs[i], elems.Index(i)) + } + } + + if joins.Len() > 0 { + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{ + SkipHooks: db.Statement.SkipHooks, + DisableNestedTransaction: true, + }).Create(joins.Interface()).Error) + } + } + } + } +} + +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) { + if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations { + onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames)) + for _, dbName := range s.PrimaryFieldDBNames { + onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName}) + } + + onConflict.UpdateAll = stmt.DB.FullSaveAssociations + if !onConflict.UpdateAll { + onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns) + } + } else { + onConflict.DoNothing = true + } + + return +} + +func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { + // stop save association loop + if checkAssociationsSaved(db, rValues) { + return nil + } + + var ( + selects, omits []string + onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns) + refName = rel.Name + "." + values = rValues.Interface() + ) + + for name, ok := range selectColumns { + columnName := "" + if strings.HasPrefix(name, refName) { + columnName = strings.TrimPrefix(name, refName) + } + + if columnName != "" { + if ok { + selects = append(selects, columnName) + } else { + omits = append(omits, columnName) + } + } + } + + tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{ + FullSaveAssociations: db.FullSaveAssociations, + SkipHooks: db.Statement.SkipHooks, + DisableNestedTransaction: true, + }) + + db.Statement.Settings.Range(func(k, v interface{}) bool { + tx.Statement.Settings.Store(k, v) + return true + }) + + if tx.Statement.FullSaveAssociations { + tx = tx.Set("gorm:update_track_time", true) + } + + if len(selects) > 0 { + tx = tx.Select(selects) + } else if restricted && len(omits) == 0 { + tx = tx.Omit(clause.Associations) + } + + if len(omits) > 0 { + tx = tx.Omit(omits...) 
+	}
+
+	return db.AddError(tx.Create(values).Error)
+}
+
+// check whether the association values have already been saved:
+// if values is a Struct, check that single value;
+// if values is a Slice/Array, check every item
+var visitMapStoreKey = "gorm:saved_association_map"
+
+func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
+	if visit, ok := db.Get(visitMapStoreKey); ok {
+		if v, ok := visit.(*visitMap); ok {
+			if loadOrStoreVisitMap(v, values) {
+				return true
+			}
+		}
+	} else {
+		visited := make(visitMap)
+		loadOrStoreVisitMap(&visited, values)
+		db.Set(visitMapStoreKey, &visited)
+	}
+
+	return false
+}
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/callbacks.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/callbacks.go
new file mode 100644
index 000000000..d681aef36
--- /dev/null
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/callbacks.go
@@ -0,0 +1,83 @@
+package callbacks
+
+import (
+	"gorm.io/gorm"
+)
+
+var (
+	createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
+	queryClauses  = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
+	updateClauses = []string{"UPDATE", "SET", "WHERE"}
+	deleteClauses = []string{"DELETE", "FROM", "WHERE"}
+)
+
+type Config struct {
+	LastInsertIDReversed bool
+	CreateClauses        []string
+	QueryClauses         []string
+	UpdateClauses        []string
+	DeleteClauses        []string
+}
+
+func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
+	enableTransaction := func(db *gorm.DB) bool {
+		return !db.SkipDefaultTransaction
+	}
+
+	if len(config.CreateClauses) == 0 {
+		config.CreateClauses = createClauses
+	}
+	if len(config.QueryClauses) == 0 {
+		config.QueryClauses = queryClauses
+	}
+	if len(config.DeleteClauses) == 0 {
+		config.DeleteClauses = deleteClauses
+	}
+	if len(config.UpdateClauses) == 0 {
+		config.UpdateClauses = updateClauses
+	}
+
+	createCallback := db.Callback().Create()
+	createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
+	createCallback.Register("gorm:before_create", BeforeCreate)
+	createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
+	createCallback.Register("gorm:create", Create(config))
+	createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
+	createCallback.Register("gorm:after_create", AfterCreate)
+	createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
+	createCallback.Clauses = config.CreateClauses
+
+	queryCallback := db.Callback().Query()
+	queryCallback.Register("gorm:query", Query)
+	queryCallback.Register("gorm:preload", Preload)
+	queryCallback.Register("gorm:after_query", AfterQuery)
+	queryCallback.Clauses = config.QueryClauses
+
+	deleteCallback := db.Callback().Delete()
+	deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
+	deleteCallback.Register("gorm:before_delete", BeforeDelete)
+	deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
+	deleteCallback.Register("gorm:delete", Delete(config))
+	deleteCallback.Register("gorm:after_delete", AfterDelete)
+	deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
+	deleteCallback.Clauses = config.DeleteClauses
+
+	updateCallback := db.Callback().Update()
+	updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
+
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) + updateCallback.Register("gorm:before_update", BeforeUpdate) + updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) + updateCallback.Register("gorm:update", Update(config)) + updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) + updateCallback.Register("gorm:after_update", AfterUpdate) + updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + updateCallback.Clauses = config.UpdateClauses + + rowCallback := db.Callback().Row() + rowCallback.Register("gorm:row", RowQuery) + rowCallback.Clauses = config.QueryClauses + + rawCallback := db.Callback().Raw() + rawCallback.Register("gorm:raw", RawExec) + rawCallback.Clauses = config.QueryClauses +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/callmethod.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/callmethod.go new file mode 100644 index 000000000..fb9000379 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/callmethod.go @@ -0,0 +1,32 @@ +package callbacks + +import ( + "reflect" + + "gorm.io/gorm" +) + +func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { + tx := db.Session(&gorm.Session{NewDB: true}) + if called := fc(db.Statement.ReflectValue.Interface(), tx); !called { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + db.Statement.CurDestIndex = 0 + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() { + fc(value.Addr().Interface(), tx) + } else { + db.AddError(gorm.ErrInvalidValue) + return + } + db.Statement.CurDestIndex++ + } + case reflect.Struct: + if db.Statement.ReflectValue.CanAddr() { + fc(db.Statement.ReflectValue.Addr().Interface(), tx) + } else { + db.AddError(gorm.ErrInvalidValue) + } + } + } +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/create.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/create.go new file mode 100644 index 000000000..f0b781398 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/create.go @@ -0,0 +1,345 @@ +package callbacks + +import ( + "fmt" + "reflect" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +// BeforeCreate before create hooks +func BeforeCreate(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.BeforeSave { + if i, ok := value.(BeforeSaveInterface); ok { + called = true + db.AddError(i.BeforeSave(tx)) + } + } + + if db.Statement.Schema.BeforeCreate { + if i, ok := value.(BeforeCreateInterface); ok { + called = true + db.AddError(i.BeforeCreate(tx)) + } + } + return called + }) + } +} + +// Create create hook +func Create(config *Config) func(db *gorm.DB) { + supportReturning := utils.Contains(config.CreateClauses, "RETURNING") + + return func(db *gorm.DB) { + if db.Error != nil { + return + } + + if db.Statement.Schema != nil { + if !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } + } + + if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { + if _, ok := 
db.Statement.Clauses["RETURNING"]; !ok { + fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) + for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { + fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) + } + db.Statement.AddClause(clause.Returning{Columns: fromColumns}) + } + } + } + + if db.Statement.SQL.Len() == 0 { + db.Statement.SQL.Grow(180) + db.Statement.AddClauseIfNotExists(clause.Insert{}) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + + db.Statement.Build(db.Statement.BuildClauses...) + } + + isDryRun := !db.DryRun && db.Error == nil + if !isDryRun { + return + } + + ok, mode := hasReturning(db, supportReturning) + if ok { + if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { + if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { + mode |= gorm.ScanOnConflictDoNothing + } + } + + rows, err := db.Statement.ConnPool.QueryContext( + db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., + ) + if db.AddError(err) == nil { + defer func() { + db.AddError(rows.Close()) + }() + gorm.Scan(rows, db, mode) + } + + return + } + + result, err := db.Statement.ConnPool.ExecContext( + db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., + ) + if err != nil { + db.AddError(err) + return + } + + db.RowsAffected, _ = result.RowsAffected() + if db.RowsAffected != 0 && db.Statement.Schema != nil && + db.Statement.Schema.PrioritizedPrimaryField != nil && + db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + insertID, err := result.LastInsertId() + insertOk := err == nil && insertID > 0 + if !insertOk { + db.AddError(err) + return + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) + if isZero { + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) + insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + } + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) + insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + } + } + } + case reflect.Struct: + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) + if isZero { + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) + } + } + } + } +} + +// AfterCreate after create hooks +func AfterCreate(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.AfterCreate { + if i, ok := value.(AfterCreateInterface); ok { + called = true + db.AddError(i.AfterCreate(tx)) + } + } + + if db.Statement.Schema.AfterSave { + if i, ok := 
value.(AfterSaveInterface); ok { + called = true + db.AddError(i.AfterSave(tx)) + } + } + return called + }) + } +} + +// ConvertToCreateValues convert to create values +func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { + curTime := stmt.DB.NowFunc() + + switch value := stmt.Dest.(type) { + case map[string]interface{}: + values = ConvertMapToValuesForCreate(stmt, value) + case *map[string]interface{}: + values = ConvertMapToValuesForCreate(stmt, *value) + case []map[string]interface{}: + values = ConvertSliceOfMapToValuesForCreate(stmt, value) + case *[]map[string]interface{}: + values = ConvertSliceOfMapToValuesForCreate(stmt, *value) + default: + var ( + selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) + _, updateTrackTime = stmt.Get("gorm:update_track_time") + isZero bool + ) + stmt.Settings.Delete("gorm:update_track_time") + + values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} + + for _, db := range stmt.Schema.DBNames { + if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { + if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) { + values.Columns = append(values.Columns, clause.Column{Name: db}) + } + } + } + + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + rValLen := stmt.ReflectValue.Len() + if rValLen == 0 { + stmt.AddError(gorm.ErrEmptySlice) + return + } + + stmt.SQL.Grow(rValLen * 18) + stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns)) + values.Values = make([][]interface{}, rValLen) + + defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} + for i := 0; i < rValLen; i++ { + rv := reflect.Indirect(stmt.ReflectValue.Index(i)) + if !rv.IsValid() { + stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) + return + } + + values.Values[i] = make([]interface{}, len(values.Columns)) + for idx, column := range values.Columns { + field := stmt.Schema.FieldsByDBName[column.Name] + if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero { + if field.DefaultValueInterface != nil { + values.Values[i][idx] = field.DefaultValueInterface + stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface)) + } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { + stmt.AddError(field.Set(stmt.Context, rv, curTime)) + values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) + } + } else if field.AutoUpdateTime > 0 && updateTrackTime { + stmt.AddError(field.Set(stmt.Context, rv, curTime)) + values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) + } + } + + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero { + if len(defaultValueFieldsHavingValue[field]) == 0 { + defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen) + } + defaultValueFieldsHavingValue[field][i] = rvOfvalue + } + } + } + } + + for field, vs := range defaultValueFieldsHavingValue { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) + for idx := range values.Values { + if vs[idx] == nil { + values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) + } else { + values.Values[idx] = append(values.Values[idx], vs[idx]) + } + } + } + case reflect.Struct: + values.Values = 
[][]interface{}{make([]interface{}, len(values.Columns))} + for idx, column := range values.Columns { + field := stmt.Schema.FieldsByDBName[column.Name] + if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero { + if field.DefaultValueInterface != nil { + values.Values[0][idx] = field.DefaultValueInterface + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface)) + } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) + values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) + } + } else if field.AutoUpdateTime > 0 && updateTrackTime { + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) + values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) + } + } + + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) + values.Values[0] = append(values.Values[0], rvOfvalue) + } + } + } + default: + stmt.AddError(gorm.ErrInvalidData) + } + } + + if c, ok := stmt.Clauses["ON CONFLICT"]; ok { + if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll { + if stmt.Schema != nil && len(values.Columns) >= 1 { + selectColumns, restricted := stmt.SelectAndOmitColumns(true, true) + + columns := make([]string, 0, len(values.Columns)-1) + for _, column := range values.Columns { + if field := stmt.Schema.LookUpField(column.Name); field != nil { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil || + strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 { + if field.AutoUpdateTime > 0 { + assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime} + switch field.AutoUpdateTime { + case schema.UnixNanosecond: + assignment.Value = curTime.UnixNano() + case schema.UnixMillisecond: + assignment.Value = curTime.UnixNano() / 1e6 + case schema.UnixSecond: + assignment.Value = curTime.Unix() + } + + onConflict.DoUpdates = append(onConflict.DoUpdates, assignment) + } else { + columns = append(columns, column.Name) + } + } + } + } + } + + onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...) 
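+				// if no assignable columns remain, degrade the upsert to DO NOTHING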
+ if len(onConflict.DoUpdates) == 0 { + onConflict.DoNothing = true + } + + // use primary fields as default OnConflict columns + if len(onConflict.Columns) == 0 { + for _, field := range stmt.Schema.PrimaryFields { + onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName}) + } + } + stmt.AddClause(onConflict) + } + } + } + + return values +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/delete.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/delete.go new file mode 100644 index 000000000..84f446a3f --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/delete.go @@ -0,0 +1,185 @@ +package callbacks + +import ( + "reflect" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +func BeforeDelete(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete { + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(BeforeDeleteInterface); ok { + db.AddError(i.BeforeDelete(tx)) + return true + } + + return false + }) + } +} + +func DeleteBeforeAssociations(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) + if !restricted { + return + } + + for column, v := range selectColumns { + if !v { + continue + } + + rel, ok := db.Statement.Schema.Relationships.Relations[column] + if !ok { + continue + } + + switch rel.Type { + case schema.HasOne, schema.HasMany: + queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue) + modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) + withoutConditions := false + if db.Statement.Unscoped { + tx = tx.Unscoped() + } + + if len(db.Statement.Selects) > 0 { + selects := make([]string, 0, len(db.Statement.Selects)) + for _, s := range db.Statement.Selects { + if s == clause.Associations { + selects = append(selects, s) + } else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) { + selects = append(selects, strings.TrimPrefix(s, columnPrefix)) + } + } + + if len(selects) > 0 { + tx = tx.Select(selects) + } + } + + for _, cond := range queryConds { + if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { + withoutConditions = true + break + } + } + + if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + case schema.Many2Many: + var ( + queryConds = make([]clause.Expression, 0, len(rel.References)) + foreignFields = make([]*schema.Field, 0, len(rel.References)) + relForeignKeys = make([]string, 0, len(rel.References)) + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + table = rel.JoinTable.Table + tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + queryConds = append(queryConds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } + } + + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields) + column, values := schema.ToQueryValues(table, relForeignKeys, 
foreignValues) + queryConds = append(queryConds, clause.IN{Column: column, Values: values}) + + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + } + } + + } +} + +func Delete(config *Config) func(db *gorm.DB) { + supportReturning := utils.Contains(config.DeleteClauses, "RETURNING") + + return func(db *gorm.DB) { + if db.Error != nil { + return + } + + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.Len() == 0 { + db.Statement.SQL.Grow(100) + db.Statement.AddClauseIfNotExists(clause.Delete{}) + + if db.Statement.Schema != nil { + _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) + column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + + if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { + _, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) + column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + } + } + + db.Statement.AddClauseIfNotExists(clause.From{}) + + db.Statement.Build(db.Statement.BuildClauses...) + } + + checkMissingWhereConditions(db) + + if !db.DryRun && db.Error == nil { + ok, mode := hasReturning(db, supportReturning) + if !ok { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) 
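+			// without RETURNING support a plain exec suffices; only the
+			// affected-row count is reported back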
+ if db.AddError(err) == nil { + db.RowsAffected, _ = result.RowsAffected() + } + + return + } + + if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { + gorm.Scan(rows, db, mode) + db.AddError(rows.Close()) + } + } + } +} + +func AfterDelete(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete { + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(AfterDeleteInterface); ok { + db.AddError(i.AfterDelete(tx)) + return true + } + return false + }) + } +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/helper.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/helper.go new file mode 100644 index 000000000..ae9fd8c56 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/helper.go @@ -0,0 +1,152 @@ +package callbacks + +import ( + "reflect" + "sort" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +// ConvertMapToValuesForCreate convert map to values +func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { + values.Columns = make([]clause.Column, 0, len(mapValue)) + selectColumns, restricted := stmt.SelectAndOmitColumns(true, false) + + keys := make([]string, 0, len(mapValue)) + for k := range mapValue { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + value := mapValue[k] + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } + } + + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + values.Columns = append(values.Columns, clause.Column{Name: k}) + if len(values.Values) == 0 { + values.Values = [][]interface{}{{}} + } + + values.Values[0] = append(values.Values[0], value) + } + } + return +} + +// ConvertSliceOfMapToValuesForCreate convert slice of map to values +func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { + columns := make([]string, 0, len(mapValues)) + + // when the length of mapValues is zero,return directly here + // no need to call stmt.SelectAndOmitColumns method + if len(mapValues) == 0 { + stmt.AddError(gorm.ErrEmptySlice) + return + } + + var ( + result = make(map[string][]interface{}, len(mapValues)) + selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) + ) + + for idx, mapValue := range mapValues { + for k, v := range mapValue { + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } + } + + if _, ok := result[k]; !ok { + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + result[k] = make([]interface{}, len(mapValues)) + columns = append(columns, k) + } else { + continue + } + } + + result[k][idx] = v + } + } + + sort.Strings(columns) + values.Values = make([][]interface{}, len(mapValues)) + values.Columns = make([]clause.Column, len(columns)) + for idx, column := range columns { + values.Columns[idx] = clause.Column{Name: column} + + for i, v := range result[column] { + if len(values.Values[i]) == 0 { + values.Values[i] = make([]interface{}, len(columns)) + } + + values.Values[i][idx] = v + } + } + return +} + +func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) { + if supportReturning { + if c, ok := tx.Statement.Clauses["RETURNING"]; ok { + returning, _ := c.Expression.(clause.Returning) + if 
len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") { + return true, 0 + } + return true, gorm.ScanUpdate + } + } + return false, 0 +} + +func checkMissingWhereConditions(db *gorm.DB) { + if !db.AllowGlobalUpdate && db.Error == nil { + where, withCondition := db.Statement.Clauses["WHERE"] + if withCondition { + if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete { + whereClause, _ := where.Expression.(clause.Where) + withCondition = len(whereClause.Exprs) > 1 + } + } + if !withCondition { + db.AddError(gorm.ErrMissingWhereClause) + } + return + } +} + +type visitMap = map[reflect.Value]bool + +// Check if circular values, return true if loaded +func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) { + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + switch v.Kind() { + case reflect.Slice, reflect.Array: + loaded = true + for i := 0; i < v.Len(); i++ { + if !loadOrStoreVisitMap(visitMap, v.Index(i)) { + loaded = false + } + } + case reflect.Struct, reflect.Interface: + if v.CanAddr() { + p := v.Addr() + if _, ok := (*visitMap)[p]; ok { + return true + } + (*visitMap)[p] = true + } + } + + return +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/interfaces.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/interfaces.go new file mode 100644 index 000000000..2302470fc --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/interfaces.go @@ -0,0 +1,39 @@ +package callbacks + +import "gorm.io/gorm" + +type BeforeCreateInterface interface { + BeforeCreate(*gorm.DB) error +} + +type AfterCreateInterface interface { + AfterCreate(*gorm.DB) error +} + +type BeforeUpdateInterface interface { + BeforeUpdate(*gorm.DB) error +} + +type AfterUpdateInterface interface { + AfterUpdate(*gorm.DB) error +} + +type BeforeSaveInterface interface { + BeforeSave(*gorm.DB) error +} + +type AfterSaveInterface interface { + AfterSave(*gorm.DB) error +} + +type BeforeDeleteInterface interface { + BeforeDelete(*gorm.DB) error +} + +type AfterDeleteInterface interface { + AfterDelete(*gorm.DB) error +} + +type AfterFindInterface interface { + AfterFind(*gorm.DB) error +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/preload.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/preload.go new file mode 100644 index 000000000..15669c847 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/preload.go @@ -0,0 +1,266 @@ +package callbacks + +import ( + "fmt" + "reflect" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +// parsePreloadMap extracts nested preloads. e.g. 
+// +// // schema has a "k0" relation and a "k7.k8" embedded relation +// parsePreloadMap(schema, map[string][]interface{}{ +// clause.Associations: {"arg1"}, +// "k1": {"arg2"}, +// "k2.k3": {"arg3"}, +// "k4.k5.k6": {"arg4"}, +// }) +// // preloadMap is +// map[string]map[string][]interface{}{ +// "k0": {}, +// "k7": { +// "k8": {}, +// }, +// "k1": {}, +// "k2": { +// "k3": {"arg3"}, +// }, +// "k4": { +// "k5.k6": {"arg4"}, +// }, +// } +func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} { + preloadMap := map[string]map[string][]interface{}{} + setPreloadMap := func(name, value string, args []interface{}) { + if _, ok := preloadMap[name]; !ok { + preloadMap[name] = map[string][]interface{}{} + } + if value != "" { + preloadMap[name][value] = args + } + } + + for name, args := range preloads { + preloadFields := strings.Split(name, ".") + value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".") + if preloadFields[0] == clause.Associations { + for _, relation := range s.Relationships.Relations { + if relation.Schema == s { + setPreloadMap(relation.Name, value, args) + } + } + + for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations { + for _, value := range embeddedValues(embeddedRelations) { + setPreloadMap(embedded, value, args) + } + } + } else { + setPreloadMap(preloadFields[0], value, args) + } + } + return preloadMap +} + +func embeddedValues(embeddedRelations *schema.Relationships) []string { + if embeddedRelations == nil { + return nil + } + names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations)) + for _, relation := range embeddedRelations.Relations { + // skip first struct name + names = append(names, strings.Join(relation.Field.BindNames[1:], ".")) + } + for _, relations := range embeddedRelations.EmbeddedRelations { + names = append(names, embeddedValues(relations)...) 
+ } + return names +} + +func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error { + if relationships == nil { + return nil + } + preloadMap := parsePreloadMap(s, preloads) + for name := range preloadMap { + if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil { + if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil { + return err + } + } else if rel := relationships.Relations[name]; rel != nil { + if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil { + return err + } + } else { + return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name) + } + } + return nil +} + +func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { + var ( + reflectValue = tx.Statement.ReflectValue + relForeignKeys []string + relForeignFields []*schema.Field + foreignFields []*schema.Field + foreignValues [][]interface{} + identityMap = map[string][]reflect.Value{} + inlineConds []interface{} + ) + + if rel.JoinTable != nil { + var ( + joinForeignFields = make([]*schema.Field, 0, len(rel.References)) + joinRelForeignFields = make([]*schema.Field, 0, len(rel.References)) + joinForeignKeys = make([]string, 0, len(rel.References)) + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName) + joinForeignFields = append(joinForeignFields, ref.ForeignKey) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } else { + joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey) + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + relForeignFields = append(relForeignFields, ref.PrimaryKey) + } + } + + joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) + if len(joinForeignValues) == 0 { + return nil + } + + joinResults := rel.JoinTable.MakeSlice().Elem() + column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues) + if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil { + return err + } + + // convert join identity map to relation identity map + fieldValues := make([]interface{}, len(joinForeignFields)) + joinFieldValues := make([]interface{}, len(joinRelForeignFields)) + for i := 0; i < joinResults.Len(); i++ { + joinIndexValue := joinResults.Index(i) + for idx, field := range joinForeignFields { + fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) + } + + for idx, field := range joinRelForeignFields { + joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) + } + + if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { + joinKey := utils.ToStringKey(joinFieldValues...) + identityMap[joinKey] = append(identityMap[joinKey], results...) 
+ } + } + + _, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields) + } else { + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + relForeignFields = append(relForeignFields, ref.ForeignKey) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } else { + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + relForeignFields = append(relForeignFields, ref.PrimaryKey) + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + + identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) + if len(foreignValues) == 0 { + return nil + } + } + + // nested preload + for p, pvs := range preloads { + tx = tx.Preload(p, pvs...) + } + + reflectResults := rel.FieldSchema.MakeSlice().Elem() + column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) + + if len(values) != 0 { + for _, cond := range conds { + if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { + tx = fc(tx) + } else { + inlineConds = append(inlineConds, cond) + } + } + + if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil { + return err + } + } + + fieldValues := make([]interface{}, len(relForeignFields)) + + // clean up old values before preloading + switch reflectValue.Kind() { + case reflect.Struct: + switch rel.Type { + case schema.HasMany, schema.Many2Many: + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) + default: + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface())) + } + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + switch rel.Type { + case schema.HasMany, schema.Many2Many: + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) + default: + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())) + } + } + } + + for i := 0; i < reflectResults.Len(); i++ { + elem := reflectResults.Index(i) + for idx, field := range relForeignFields { + fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem) + } + + datas, ok := identityMap[utils.ToStringKey(fieldValues...)] + if !ok { + return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface()) + } + + for _, data := range datas { + reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data) + if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { + reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) + } + + reflectFieldValue = reflect.Indirect(reflectFieldValue) + switch reflectFieldValue.Kind() { + case reflect.Struct: + tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface())) + case reflect.Slice, reflect.Array: + if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { + tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface())) + } else { + tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())) + } + } + } + } + + return 
tx.Error +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/query.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/query.go new file mode 100644 index 000000000..e89dd1996 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/query.go @@ -0,0 +1,316 @@ +package callbacks + +import ( + "fmt" + "reflect" + "sort" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +func Query(db *gorm.DB) { + if db.Error == nil { + BuildQuerySQL(db) + + if !db.DryRun && db.Error == nil { + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if err != nil { + db.AddError(err) + return + } + defer func() { + db.AddError(rows.Close()) + }() + gorm.Scan(rows, db, 0) + } + } +} + +func BuildQuerySQL(db *gorm.DB) { + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.QueryClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.Len() == 0 { + db.Statement.SQL.Grow(100) + clauseSelect := clause.Select{Distinct: db.Statement.Distinct} + + if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { + var conds []clause.Expression + for _, primaryField := range db.Statement.Schema.PrimaryFields { + if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero { + conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) + } + } + + if len(conds) > 0 { + db.Statement.AddClause(clause.Where{Exprs: conds}) + } + } + + if len(db.Statement.Selects) > 0 { + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects)) + for idx, name := range db.Statement.Selects { + if db.Statement.Schema == nil { + clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} + } else if f := db.Statement.Schema.LookUpField(name); f != nil { + clauseSelect.Columns[idx] = clause.Column{Name: f.DBName} + } else { + clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} + } + } + } else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 { + selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false) + clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames)) + for _, dbName := range db.Statement.Schema.DBNames { + if v, ok := selectColumns[dbName]; (ok && v) || !ok { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName}) + } + } + } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() { + queryFields := db.QueryFields + if !queryFields { + switch db.Statement.ReflectValue.Kind() { + case reflect.Struct: + queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType + case reflect.Slice: + queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType + } + } + + if queryFields { + stmt := gorm.Statement{DB: db} + // smaller struct + if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) { + clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) + + for idx, dbName := range stmt.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} + } + } + } + } + + // inline joins + fromClause := clause.From{} + if v, ok := 
db.Statement.Clauses["FROM"].Expression.(clause.From); ok { + fromClause = v + } + + if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 { + if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil { + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) + for idx, dbName := range db.Statement.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} + } + } + + specifiedRelationsName := make(map[string]interface{}) + for _, join := range db.Statement.Joins { + if db.Statement.Schema != nil { + var isRelations bool // is relations or raw sql + var relations []*schema.Relationship + relation, ok := db.Statement.Schema.Relationships.Relations[join.Name] + if ok { + isRelations = true + relations = append(relations, relation) + } else { + // handle nested join like "Manager.Company" + nestedJoinNames := strings.Split(join.Name, ".") + if len(nestedJoinNames) > 1 { + isNestedJoin := true + gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) + currentRelations := db.Statement.Schema.Relationships.Relations + for _, relname := range nestedJoinNames { + // incomplete match, only treated as raw sql + if relation, ok = currentRelations[relname]; ok { + gussNestedRelations = append(gussNestedRelations, relation) + currentRelations = relation.FieldSchema.Relationships.Relations + } else { + isNestedJoin = false + break + } + } + + if isNestedJoin { + isRelations = true + relations = gussNestedRelations + } + } + } + + if isRelations { + genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { + tableAliasName := relation.Name + if parentTableName != clause.CurrentTable { + tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) + } + + columnStmt := gorm.Statement{ + Table: tableAliasName, DB: db, Schema: relation.FieldSchema, + Selects: join.Selects, Omits: join.Omits, + } + + selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) + for _, s := range relation.FieldSchema.DBNames { + if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: utils.NestedRelationName(tableAliasName, s), + }) + } + } + + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { + if ref.OwnPrimaryKey { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + } + } else { + if ref.PrimaryValue == "" { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + } + } else { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + } + } + } + } + + { + onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} + for _, c := range relation.FieldSchema.QueryClauses { + onStmt.AddClause(c) + } + + if join.On != nil { + onStmt.AddClause(join.On) + } + + if cs, ok := onStmt.Clauses["WHERE"]; ok { + if where, ok := cs.Expression.(clause.Where); ok { + where.Build(&onStmt) + + if onSQL := onStmt.SQL.String(); onSQL != "" { + vars := onStmt.Vars + for idx, v := range vars { + 
bindvar := strings.Builder{} + onStmt.Vars = vars[0 : idx+1] + db.Dialector.BindVarTo(&bindvar, &onStmt, v) + onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + } + + exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + } + } + } + } + + return clause.Join{ + Type: joinType, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + } + } + + parentTableName := clause.CurrentTable + for _, rel := range relations { + // joins table alias like "Manager, Company, Manager__Company" + nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) + if _, ok := specifiedRelationsName[nestedAlias]; !ok { + fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) + specifiedRelationsName[nestedAlias] = nil + } + + if parentTableName != clause.CurrentTable { + parentTableName = utils.NestedRelationName(parentTableName, rel.Name) + } else { + parentTableName = rel.Name + } + } + } else { + fromClause.Joins = append(fromClause.Joins, clause.Join{ + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, + }) + } + } else { + fromClause.Joins = append(fromClause.Joins, clause.Join{ + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, + }) + } + } + + db.Statement.AddClause(fromClause) + db.Statement.Joins = nil + } else { + db.Statement.AddClauseIfNotExists(clause.From{}) + } + + db.Statement.AddClauseIfNotExists(clauseSelect) + + db.Statement.Build(db.Statement.BuildClauses...) + } +} + +func Preload(db *gorm.DB) { + if db.Error == nil && len(db.Statement.Preloads) > 0 { + if db.Statement.Schema == nil { + db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired)) + return + } + + preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads) + preloadNames := make([]string, 0, len(preloadMap)) + for key := range preloadMap { + preloadNames = append(preloadNames, key) + } + sort.Strings(preloadNames) + + preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) + db.Statement.Settings.Range(func(k, v interface{}) bool { + preloadDB.Statement.Settings.Store(k, v) + return true + }) + + if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil { + return + } + preloadDB.Statement.ReflectValue = db.Statement.ReflectValue + preloadDB.Statement.Unscoped = db.Statement.Unscoped + + for _, name := range preloadNames { + if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil { + db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations])) + } else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { + db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) + } else { + db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) + } + } + } +} + +func AfterQuery(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := 
value.(AfterFindInterface); ok { + db.AddError(i.AfterFind(tx)) + return true + } + return false + }) + } +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/raw.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/raw.go new file mode 100644 index 000000000..013e638cb --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/raw.go @@ -0,0 +1,17 @@ +package callbacks + +import ( + "gorm.io/gorm" +) + +func RawExec(db *gorm.DB) { + if db.Error == nil && !db.DryRun { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if err != nil { + db.AddError(err) + return + } + + db.RowsAffected, _ = result.RowsAffected() + } +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/row.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/row.go new file mode 100644 index 000000000..beaa189e1 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/row.go @@ -0,0 +1,23 @@ +package callbacks + +import ( + "gorm.io/gorm" +) + +func RowQuery(db *gorm.DB) { + if db.Error == nil { + BuildQuerySQL(db) + if db.DryRun || db.Error != nil { + return + } + + if isRows, ok := db.Get("rows"); ok && isRows.(bool) { + db.Statement.Settings.Delete("rows") + db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } else { + db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } + + db.RowsAffected = -1 + } +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/transaction.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/transaction.go new file mode 100644 index 000000000..50887ccce --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/transaction.go @@ -0,0 +1,32 @@ +package callbacks + +import ( + "gorm.io/gorm" +) + +func BeginTransaction(db *gorm.DB) { + if !db.Config.SkipDefaultTransaction && db.Error == nil { + if tx := db.Begin(); tx.Error == nil { + db.Statement.ConnPool = tx.Statement.ConnPool + db.InstanceSet("gorm:started_transaction", true) + } else if tx.Error == gorm.ErrInvalidTransaction { + tx.Error = nil + } else { + db.Error = tx.Error + } + } +} + +func CommitOrRollbackTransaction(db *gorm.DB) { + if !db.Config.SkipDefaultTransaction { + if _, ok := db.InstanceGet("gorm:started_transaction"); ok { + if db.Error != nil { + db.Rollback() + } else { + db.Commit() + } + + db.Statement.ConnPool = db.ConnPool + } + } +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/update.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/update.go new file mode 100644 index 000000000..ff075dcf2 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/update.go @@ -0,0 +1,304 @@ +package callbacks + +import ( + "reflect" + "sort" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +func SetupUpdateReflectValue(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest { + db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) + for db.Statement.ReflectValue.Kind() == reflect.Ptr { + db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() + } + + if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { + for _, rel := range 
db.Statement.Schema.Relationships.BelongsTo { + if _, ok := dest[rel.Name]; ok { + db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name])) + } + } + } + } + } +} + +// BeforeUpdate before update hooks +func BeforeUpdate(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.BeforeSave { + if i, ok := value.(BeforeSaveInterface); ok { + called = true + db.AddError(i.BeforeSave(tx)) + } + } + + if db.Statement.Schema.BeforeUpdate { + if i, ok := value.(BeforeUpdateInterface); ok { + called = true + db.AddError(i.BeforeUpdate(tx)) + } + } + + return called + }) + } +} + +// Update update hook +func Update(config *Config) func(db *gorm.DB) { + supportReturning := utils.Contains(config.UpdateClauses, "RETURNING") + + return func(db *gorm.DB) { + if db.Error != nil { + return + } + + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.Len() == 0 { + db.Statement.SQL.Grow(180) + db.Statement.AddClauseIfNotExists(clause.Update{}) + if _, ok := db.Statement.Clauses["SET"]; !ok { + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + defer delete(db.Statement.Clauses, "SET") + db.Statement.AddClause(set) + } else { + return + } + } + + db.Statement.Build(db.Statement.BuildClauses...) + } + + checkMissingWhereConditions(db) + + if !db.DryRun && db.Error == nil { + if ok, mode := hasReturning(db, supportReturning); ok { + if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { + dest := db.Statement.Dest + db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface() + gorm.Scan(rows, db, mode) + db.Statement.Dest = dest + db.AddError(rows.Close()) + } + } else { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) 
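+ // Roughly: when the dialect's UpdateClauses include RETURNING and the statement + // carries a Returning clause, the Query/Scan branch above runs; everything else + // lands here, e.g. for a call like (hypothetical model): + // + // db.Model(&User{}).Where("role = ?", "admin").Update("age", 30)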
+ + if db.AddError(err) == nil { + db.RowsAffected, _ = result.RowsAffected() + } + } + } + } +} + +// AfterUpdate after update hooks +func AfterUpdate(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.AfterUpdate { + if i, ok := value.(AfterUpdateInterface); ok { + called = true + db.AddError(i.AfterUpdate(tx)) + } + } + + if db.Statement.Schema.AfterSave { + if i, ok := value.(AfterSaveInterface); ok { + called = true + db.AddError(i.AfterSave(tx)) + } + } + + return called + }) + } +} + +// ConvertToAssignments convert to update assignments +func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { + var ( + selectColumns, restricted = stmt.SelectAndOmitColumns(false, true) + assignValue func(field *schema.Field, value interface{}) + ) + + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + assignValue = func(field *schema.Field, value interface{}) { + for i := 0; i < stmt.ReflectValue.Len(); i++ { + if stmt.ReflectValue.CanAddr() { + field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) + } + } + } + case reflect.Struct: + assignValue = func(field *schema.Field, value interface{}) { + if stmt.ReflectValue.CanAddr() { + field.Set(stmt.Context, stmt.ReflectValue, value) + } + } + default: + assignValue = func(field *schema.Field, value interface{}) { + } + } + + updatingValue := reflect.ValueOf(stmt.Dest) + for updatingValue.Kind() == reflect.Ptr { + updatingValue = updatingValue.Elem() + } + + if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if size := stmt.ReflectValue.Len(); size > 0 { + var isZero bool + for i := 0; i < size; i++ { + for _, field := range stmt.Schema.PrimaryFields { + _, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i)) + if !isZero { + break + } + } + } + + if !isZero { + _, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) + column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues) + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + } + case reflect.Struct: + for _, field := range stmt.Schema.PrimaryFields { + if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } + } + } + } + + switch value := updatingValue.Interface().(type) { + case map[string]interface{}: + set = make([]clause.Assignment, 0, len(value)) + + keys := make([]string, 0, len(value)) + for k := range value { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + kv := value[k] + if _, ok := kv.(*gorm.DB); ok { + kv = []interface{}{kv} + } + + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + if field.DBName != "" { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv}) + assignValue(field, value[k]) + } + } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { + assignValue(field, value[k]) + } + continue + } + } + + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + set 
= append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv}) + } + } + + if !stmt.SkipHooks && stmt.Schema != nil { + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.LookUpField(dbName) + if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { + if v, ok := selectColumns[field.DBName]; (ok && v) || !ok { + now := stmt.DB.NowFunc() + assignValue(field, now) + + if field.AutoUpdateTime == schema.UnixNanosecond { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) + } else if field.AutoUpdateTime == schema.UnixMillisecond { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) + } else if field.AutoUpdateTime == schema.UnixSecond { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) + } else { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) + } + } + } + } + } + default: + updatingSchema := stmt.Schema + var isDiffSchema bool + if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + // different schema + updatingStmt := &gorm.Statement{DB: stmt.DB} + if err := updatingStmt.Parse(stmt.Dest); err == nil { + updatingSchema = updatingStmt.Schema + isDiffSchema = true + } + } + + switch updatingValue.Kind() { + case reflect.Struct: + set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) + for _, dbName := range stmt.Schema.DBNames { + if field := updatingSchema.LookUpField(dbName); field != nil { + if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { + value, isZero := field.ValueOf(stmt.Context, updatingValue) + if !stmt.SkipHooks && field.AutoUpdateTime > 0 { + if field.AutoUpdateTime == schema.UnixNanosecond { + value = stmt.DB.NowFunc().UnixNano() + } else if field.AutoUpdateTime == schema.UnixMillisecond { + value = stmt.DB.NowFunc().UnixNano() / 1e6 + } else if field.AutoUpdateTime == schema.UnixSecond { + value = stmt.DB.NowFunc().Unix() + } else { + value = stmt.DB.NowFunc() + } + isZero = false + } + + if (ok || !isZero) && field.Updatable { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) + assignField := field + if isDiffSchema { + if originField := stmt.Schema.LookUpField(dbName); originField != nil { + assignField = originField + } + } + assignValue(assignField, value) + } + } + } else { + if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } + } + } + } + default: + stmt.AddError(gorm.ErrInvalidData) + } + } + + return +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/chainable_api.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/chainable_api.go new file mode 100644 index 000000000..3dc7256e6 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/chainable_api.go @@ -0,0 +1,469 @@ +package gorm + +import ( + "fmt" + "regexp" + "strings" + + "gorm.io/gorm/clause" + "gorm.io/gorm/utils" +) + +// Model specify the model you would like to run db operations +// +// // update all users's name to `hello` +// db.Model(&User{}).Update("name", "hello") +// // if user's primary key is non-blank, will use it as condition, then will only update that user's name 
to `hello` +// db.Model(&user).Update("name", "hello") +func (db *DB) Model(value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Model = value + return +} + +// Clauses Add clauses +// +// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more +// advanced techniques like specifying lock strength and optimizer hints. See the +// [docs] for more depth. +// +// // add a simple limit clause +// db.Clauses(clause.Limit{Limit: 1}).Find(&User{}) +// // tell the optimizer to use the `idx_user_name` index +// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{}) +// // specify the lock strength to UPDATE +// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users) +// +// [docs]: https://gorm.io/docs/sql_builder.html#Clauses +func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { + tx = db.getInstance() + var whereConds []interface{} + + for _, cond := range conds { + if c, ok := cond.(clause.Interface); ok { + tx.Statement.AddClause(c) + } else if optimizer, ok := cond.(StatementModifier); ok { + optimizer.ModifyStatement(tx.Statement) + } else { + whereConds = append(whereConds, cond) + } + } + + if len(whereConds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)}) + } + return +} + +var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`) + +// Table specify the table you would like to run db operations +// +// // Get a user +// db.Table("users").Take(&result) +func (db *DB) Table(name string, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { + tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args} + if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 { + if results[1] != "" { + tx.Statement.Table = results[1] + } else { + tx.Statement.Table = results[2] + } + } + } else if tables := strings.Split(name, "."); len(tables) == 2 { + tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} + tx.Statement.Table = tables[1] + } else if name != "" { + tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} + tx.Statement.Table = name + } else { + tx.Statement.TableExpr = nil + tx.Statement.Table = "" + } + return +} + +// Distinct specify distinct fields that you want querying +// +// // Select distinct names of users +// db.Distinct("name").Find(&results) +// // Select distinct name/age pairs from users +// db.Distinct("name", "age").Find(&results) +func (db *DB) Distinct(args ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Distinct = true + if len(args) > 0 { + tx = tx.Select(args[0], args[1:]...) + } + return +} + +// Select specify fields that you want when querying, creating, updating +// +// Use Select when you only want a subset of the fields. By default, GORM will select all fields. +// Select accepts both string arguments and arrays. 
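+// When a string argument contains `?` or `@name` placeholders and args are given, the whole string is built as a raw SQL expression rather than a column list, for example: +// +// // select an expression with a bound variable +// db.Table("users").Select("COALESCE(age, ?)", 42).Find(&users)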
+// +// // Select name and age of user using multiple arguments +// db.Select("name", "age").Find(&users) +// // Select name and age of user using an array +// db.Select([]string{"name", "age"}).Find(&users) +func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + + switch v := query.(type) { + case []string: + tx.Statement.Selects = v + + for _, arg := range args { + switch arg := arg.(type) { + case string: + tx.Statement.Selects = append(tx.Statement.Selects, arg) + case []string: + tx.Statement.Selects = append(tx.Statement.Selects, arg...) + default: + tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) + return + } + } + + if clause, ok := tx.Statement.Clauses["SELECT"]; ok { + clause.Expression = nil + tx.Statement.Clauses["SELECT"] = clause + } + case string: + if strings.Count(v, "?") >= len(args) && len(args) > 0 { + tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, + Expression: clause.Expr{SQL: v, Vars: args}, + }) + } else if strings.Count(v, "@") > 0 && len(args) > 0 { + tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, + Expression: clause.NamedExpr{SQL: v, Vars: args}, + }) + } else { + tx.Statement.Selects = []string{v} + + for _, arg := range args { + switch arg := arg.(type) { + case string: + tx.Statement.Selects = append(tx.Statement.Selects, arg) + case []string: + tx.Statement.Selects = append(tx.Statement.Selects, arg...) + default: + tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, + Expression: clause.Expr{SQL: v, Vars: args}, + }) + return + } + } + + if clause, ok := tx.Statement.Clauses["SELECT"]; ok { + clause.Expression = nil + tx.Statement.Clauses["SELECT"] = clause + } + } + default: + tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) + } + + return +} + +// Omit specify fields that you want to ignore when creating, updating and querying +func (db *DB) Omit(columns ...string) (tx *DB) { + tx = db.getInstance() + + if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) + } else { + tx.Statement.Omits = columns + } + return +} + +// Where add conditions +// +// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND. +// +// // Find the first user with name jinzhu +// db.Where("name = ?", "jinzhu").First(&user) +// // Find the first user with name jinzhu and age 20 +// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user) +// // Find the first user with name jinzhu and age not equal to 20 +// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user) +// +// [docs]: https://gorm.io/docs/query.html#Conditions +func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: conds}) + } + return +} + +// Not add NOT conditions +// +// Not works similarly to where, and has the same syntax. 
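+// Map and struct conditions are negated as well, for instance: +// +// // find users whose name is not in the given list +// db.Not(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&users)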
+// +// // Find the first user with name not equal to jinzhu +// db.Not("name = ?", "jinzhu").First(&user) +func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}}) + } + return +} + +// Or add OR conditions +// +// Or is used to chain together queries with an OR. +// +// // Find the first user with name equal to jinzhu or john +// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user) +func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}}) + } + return +} + +// Joins specify Joins conditions +// +// db.Joins("Account").Find(&user) +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) +func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { + return joins(db, clause.LeftJoin, query, args...) +} + +// InnerJoins specify inner joins conditions +// db.InnerJoins("Account").Find(&user) +func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) { + return joins(db, clause.InnerJoin, query, args...) +} + +func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) { + tx = db.getInstance() + + if len(args) == 1 { + if db, ok := args[0].(*DB); ok { + j := join{ + Name: query, Conds: args, Selects: db.Statement.Selects, + Omits: db.Statement.Omits, JoinType: joinType, + } + if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { + j.On = &where + } + tx.Statement.Joins = append(tx.Statement.Joins, j) + return + } + } + + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType}) + return +} + +// Group specify the group method on the find +// +// // Select the sum age of users with given names +// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results) +func (db *DB) Group(name string) (tx *DB) { + tx = db.getInstance() + + fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) + tx.Statement.AddClause(clause.GroupBy{ + Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, + }) + return +} + +// Having specify HAVING conditions for GROUP BY +// +// // Select the sum age of users with name jinzhu +// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result) +func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.AddClause(clause.GroupBy{ + Having: tx.Statement.BuildCondition(query, args...), + }) + return +} + +// Order specify order when retrieving records from database +// +// db.Order("name DESC") +// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) +func (db *DB) Order(value interface{}) (tx *DB) { + tx = db.getInstance() + + switch v := value.(type) { + case clause.OrderByColumn: + tx.Statement.AddClause(clause.OrderBy{ + Columns: []clause.OrderByColumn{v}, + }) + case string: + if v != "" { + tx.Statement.AddClause(clause.OrderBy{ + Columns: []clause.OrderByColumn{{ + Column: clause.Column{Name: v, Raw: true}, + }}, + }) + 
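// Raw: true keeps the string unquoted, so multi-column input such as + // db.Order("age desc, name") passes through verbatim. + 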
} + } + return +} + +// Limit specify the number of records to be retrieved +// +// Limit conditions can be cancelled by using `Limit(-1)`. +// +// // retrieve 3 users +// db.Limit(3).Find(&users) +// // retrieve 3 users into users1, and all users into users2 +// db.Limit(3).Find(&users1).Limit(-1).Find(&users2) +func (db *DB) Limit(limit int) (tx *DB) { + tx = db.getInstance() + tx.Statement.AddClause(clause.Limit{Limit: &limit}) + return +} + +// Offset specify the number of records to skip before starting to return the records +// +// Offset conditions can be cancelled by using `Offset(-1)`. +// +// // select the third user +// db.Offset(2).First(&user) +// // select the first user by cancelling an earlier chained offset +// db.Offset(5).Offset(-1).First(&user) +func (db *DB) Offset(offset int) (tx *DB) { + tx = db.getInstance() + tx.Statement.AddClause(clause.Limit{Offset: offset}) + return +} + +// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically +// +// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { +// return db.Where("amount > ?", 1000) +// } +// +// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { +// return func (db *gorm.DB) *gorm.DB { +// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) +// } +// } +// +// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) +func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { + tx = db.getInstance() + tx.Statement.scopes = append(tx.Statement.scopes, funcs...) + return tx +} + +func (db *DB) executeScopes() (tx *DB) { + tx = db.getInstance() + scopes := db.Statement.scopes + if len(scopes) == 0 { + return tx + } + tx.Statement.scopes = nil + + conditions := make([]clause.Interface, 0, 4) + if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { + conditions = append(conditions, cs.Expression.(clause.Interface)) + cs.Expression = nil + tx.Statement.Clauses["WHERE"] = cs + } + + for _, scope := range scopes { + tx = scope(tx) + if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { + conditions = append(conditions, cs.Expression.(clause.Interface)) + cs.Expression = nil + tx.Statement.Clauses["WHERE"] = cs + } + } + + for _, condition := range conditions { + tx.Statement.AddClause(condition) + } + return tx +} + +// Preload preload associations with given conditions +// +// // get all users, and preload all non-cancelled orders +// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) +func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if tx.Statement.Preloads == nil { + tx.Statement.Preloads = map[string][]interface{}{} + } + tx.Statement.Preloads[query] = args + return +} + +// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit] +// +// Attrs only adds attributes if the record is not found. 
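+// With [FirstOrCreate] the same attributes end up on the newly created row: +// +// // create the user with a default email when no match exists +// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)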
+// +// // assign an email if the record is not found +// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign an email if the record is not found, otherwise ignore provided email +// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20} +// +// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate +// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit +func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.attrs = attrs + return +} + +// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit] +// +// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that +// records will be updated even if they are found. +// +// // assign an email regardless of if the record is not found +// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign email regardless of if record is found +// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} +// +// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate +// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit +func (db *DB) Assign(attrs ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.assigns = attrs + return +} + +func (db *DB) Unscoped() (tx *DB) { + tx = db.getInstance() + tx.Statement.Unscoped = true + return +} + +func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.SQL = strings.Builder{} + + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) + } else { + clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) + } + return +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/clause.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/clause.go new file mode 100644 index 000000000..1354fc057 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/clause.go @@ -0,0 +1,89 @@ +package clause + +// Interface clause interface +type Interface interface { + Name() string + Build(Builder) + MergeClause(*Clause) +} + +// ClauseBuilder clause builder, allows to customize how to build clause +type ClauseBuilder func(Clause, Builder) + +type Writer interface { + WriteByte(byte) error + WriteString(string) (int, error) +} + +// Builder builder interface +type Builder interface { + Writer + WriteQuoted(field interface{}) + AddVar(Writer, ...interface{}) + AddError(error) error +} + +// Clause +type Clause struct { + Name string // WHERE + BeforeExpression Expression + AfterNameExpression Expression + AfterExpression Expression + Expression Expression + Builder ClauseBuilder +} + +// Build build clause +func (c Clause) Build(builder Builder) { + if c.Builder != nil { + c.Builder(c, builder) + } else if c.Expression != nil { + if c.BeforeExpression != nil { + c.BeforeExpression.Build(builder) + builder.WriteByte(' ') + } + + if c.Name != "" { + builder.WriteString(c.Name) + builder.WriteByte(' ') + } + + if c.AfterNameExpression != nil { + c.AfterNameExpression.Build(builder) + builder.WriteByte(' ') + } + + c.Expression.Build(builder) + + if 
c.AfterExpression != nil { + builder.WriteByte(' ') + c.AfterExpression.Build(builder) + } + } +} + +const ( + PrimaryKey string = "~~~py~~~" // primary key + CurrentTable string = "~~~ct~~~" // current table + Associations string = "~~~as~~~" // associations +) + +var ( + currentTable = Table{Name: CurrentTable} + PrimaryColumn = Column{Table: CurrentTable, Name: PrimaryKey} +) + +// Column quote with name +type Column struct { + Table string + Name string + Alias string + Raw bool +} + +// Table quote with name +type Table struct { + Name string + Alias string + Raw bool +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/delete.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/delete.go new file mode 100644 index 000000000..fc462cd7f --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/delete.go @@ -0,0 +1,23 @@ +package clause + +type Delete struct { + Modifier string +} + +func (d Delete) Name() string { + return "DELETE" +} + +func (d Delete) Build(builder Builder) { + builder.WriteString("DELETE") + + if d.Modifier != "" { + builder.WriteByte(' ') + builder.WriteString(d.Modifier) + } +} + +func (d Delete) MergeClause(clause *Clause) { + clause.Name = "" + clause.Expression = d +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/expression.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/expression.go new file mode 100644 index 000000000..8d010522f --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/expression.go @@ -0,0 +1,385 @@ +package clause + +import ( + "database/sql" + "database/sql/driver" + "go/ast" + "reflect" +) + +// Expression expression interface +type Expression interface { + Build(builder Builder) +} + +// NegationExpressionBuilder negation expression builder +type NegationExpressionBuilder interface { + NegationBuild(builder Builder) +} + +// Expr raw expression +type Expr struct { + SQL string + Vars []interface{} + WithoutParentheses bool +} + +// Build build raw expression +func (expr Expr) Build(builder Builder) { + var ( + afterParenthesis bool + idx int + ) + + for _, v := range []byte(expr.SQL) { + if v == '?' 
&& len(expr.Vars) > idx { + if afterParenthesis || expr.WithoutParentheses { + if _, ok := expr.Vars[idx].(driver.Valuer); ok { + builder.AddVar(builder, expr.Vars[idx]) + } else { + switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + builder.AddVar(builder, nil) + } else { + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + } + default: + builder.AddVar(builder, expr.Vars[idx]) + } + } + } else { + builder.AddVar(builder, expr.Vars[idx]) + } + + idx++ + } else { + if v == '(' { + afterParenthesis = true + } else { + afterParenthesis = false + } + builder.WriteByte(v) + } + } + + if idx < len(expr.Vars) { + for _, v := range expr.Vars[idx:] { + builder.AddVar(builder, sql.NamedArg{Value: v}) + } + } +} + +// NamedExpr raw expression for named expr +type NamedExpr struct { + SQL string + Vars []interface{} +} + +// Build build raw expression +func (expr NamedExpr) Build(builder Builder) { + var ( + idx int + inName bool + afterParenthesis bool + namedMap = make(map[string]interface{}, len(expr.Vars)) + ) + + for _, v := range expr.Vars { + switch value := v.(type) { + case sql.NamedArg: + namedMap[value.Name] = value.Value + case map[string]interface{}: + for k, v := range value { + namedMap[k] = v + } + default: + var appendFieldsToMap func(reflect.Value) + appendFieldsToMap = func(reflectValue reflect.Value) { + reflectValue = reflect.Indirect(reflectValue) + switch reflectValue.Kind() { + case reflect.Struct: + modelType := reflectValue.Type() + for i := 0; i < modelType.NumField(); i++ { + if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { + namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface() + + if fieldStruct.Anonymous { + appendFieldsToMap(reflectValue.Field(i)) + } + } + } + } + } + + appendFieldsToMap(reflect.ValueOf(value)) + } + } + + name := make([]byte, 0, 10) + + for _, v := range []byte(expr.SQL) { + if v == '@' && !inName { + inName = true + name = []byte{} + } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' { + if inName { + if nv, ok := namedMap[string(name)]; ok { + builder.AddVar(builder, nv) + } else { + builder.WriteByte('@') + builder.WriteString(string(name)) + } + inName = false + } + + afterParenthesis = false + builder.WriteByte(v) + } else if v == '?' 
&& len(expr.Vars) > idx { + if afterParenthesis { + if _, ok := expr.Vars[idx].(driver.Valuer); ok { + builder.AddVar(builder, expr.Vars[idx]) + } else { + switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + builder.AddVar(builder, nil) + } else { + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + } + default: + builder.AddVar(builder, expr.Vars[idx]) + } + } + } else { + builder.AddVar(builder, expr.Vars[idx]) + } + + idx++ + } else if inName { + name = append(name, v) + } else { + if v == '(' { + afterParenthesis = true + } else { + afterParenthesis = false + } + builder.WriteByte(v) + } + } + + if inName { + if nv, ok := namedMap[string(name)]; ok { + builder.AddVar(builder, nv) + } else { + builder.WriteByte('@') + builder.WriteString(string(name)) + } + } +} + +// IN Whether a value is within a set of values +type IN struct { + Column interface{} + Values []interface{} +} + +func (in IN) Build(builder Builder) { + builder.WriteQuoted(in.Column) + + switch len(in.Values) { + case 0: + builder.WriteString(" IN (NULL)") + case 1: + if _, ok := in.Values[0].([]interface{}); !ok { + builder.WriteString(" = ") + builder.AddVar(builder, in.Values[0]) + break + } + + fallthrough + default: + builder.WriteString(" IN (") + builder.AddVar(builder, in.Values...) + builder.WriteByte(')') + } +} + +func (in IN) NegationBuild(builder Builder) { + builder.WriteQuoted(in.Column) + switch len(in.Values) { + case 0: + builder.WriteString(" IS NOT NULL") + case 1: + if _, ok := in.Values[0].([]interface{}); !ok { + builder.WriteString(" <> ") + builder.AddVar(builder, in.Values[0]) + break + } + + fallthrough + default: + builder.WriteString(" NOT IN (") + builder.AddVar(builder, in.Values...) 
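+ // mirrors Build above: an empty set negates to IS NOT NULL, a single scalar + // to <>, and anything else to NOT IN; e.g. (sketch) IN{Column: "id", + // Values: []interface{}{1, 2}} negates to `id NOT IN (?,?)`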
+ builder.WriteByte(')') + } +} + +// Eq equal to for where +type Eq struct { + Column interface{} + Value interface{} +} + +func (eq Eq) Build(builder Builder) { + builder.WriteQuoted(eq.Column) + + switch eq.Value.(type) { + case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: + rv := reflect.ValueOf(eq.Value) + if rv.Len() == 0 { + builder.WriteString(" IN (NULL)") + } else { + builder.WriteString(" IN (") + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + builder.WriteByte(')') + } + default: + if eqNil(eq.Value) { + builder.WriteString(" IS NULL") + } else { + builder.WriteString(" = ") + builder.AddVar(builder, eq.Value) + } + } +} + +func (eq Eq) NegationBuild(builder Builder) { + Neq(eq).Build(builder) +} + +// Neq not equal to for where +type Neq Eq + +func (neq Neq) Build(builder Builder) { + builder.WriteQuoted(neq.Column) + + switch neq.Value.(type) { + case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: + builder.WriteString(" NOT IN (") + rv := reflect.ValueOf(neq.Value) + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + builder.WriteByte(')') + default: + if eqNil(neq.Value) { + builder.WriteString(" IS NOT NULL") + } else { + builder.WriteString(" <> ") + builder.AddVar(builder, neq.Value) + } + } +} + +func (neq Neq) NegationBuild(builder Builder) { + Eq(neq).Build(builder) +} + +// Gt greater than for where +type Gt Eq + +func (gt Gt) Build(builder Builder) { + builder.WriteQuoted(gt.Column) + builder.WriteString(" > ") + builder.AddVar(builder, gt.Value) +} + +func (gt Gt) NegationBuild(builder Builder) { + Lte(gt).Build(builder) +} + +// Gte greater than or equal to for where +type Gte Eq + +func (gte Gte) Build(builder Builder) { + builder.WriteQuoted(gte.Column) + builder.WriteString(" >= ") + builder.AddVar(builder, gte.Value) +} + +func (gte Gte) NegationBuild(builder Builder) { + Lt(gte).Build(builder) +} + +// Lt less than for where +type Lt Eq + +func (lt Lt) Build(builder Builder) { + builder.WriteQuoted(lt.Column) + builder.WriteString(" < ") + builder.AddVar(builder, lt.Value) +} + +func (lt Lt) NegationBuild(builder Builder) { + Gte(lt).Build(builder) +} + +// Lte less than or equal to for where +type Lte Eq + +func (lte Lte) Build(builder Builder) { + builder.WriteQuoted(lte.Column) + builder.WriteString(" <= ") + builder.AddVar(builder, lte.Value) +} + +func (lte Lte) NegationBuild(builder Builder) { + Gt(lte).Build(builder) +} + +// Like whether string matches a LIKE pattern (`%`/`_` wildcards), not a regular expression +type Like Eq + +func (like Like) Build(builder Builder) { + builder.WriteQuoted(like.Column) + builder.WriteString(" LIKE ") + builder.AddVar(builder, like.Value) +} + +func (like Like) NegationBuild(builder Builder) { + builder.WriteQuoted(like.Column) + builder.WriteString(" NOT LIKE ") + builder.AddVar(builder, like.Value) +} + +func eqNil(value interface{}) bool { + if valuer, ok := value.(driver.Valuer); ok && !eqNilReflect(valuer) { + value, _ = valuer.Value() + } + + return value == nil || eqNilReflect(value) +} + +func eqNilReflect(value interface{}) bool { + reflectValue := reflect.ValueOf(value) + return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/from.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/from.go new file mode 100644 index 000000000..1ea2d5951 ---
/dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/from.go @@ -0,0 +1,37 @@ +package clause + +// From from clause +type From struct { + Tables []Table + Joins []Join +} + +// Name from clause name +func (from From) Name() string { + return "FROM" +} + +// Build build from clause +func (from From) Build(builder Builder) { + if len(from.Tables) > 0 { + for idx, table := range from.Tables { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(table) + } + } else { + builder.WriteQuoted(currentTable) + } + + for _, join := range from.Joins { + builder.WriteByte(' ') + join.Build(builder) + } +} + +// MergeClause merge from clause +func (from From) MergeClause(clause *Clause) { + clause.Expression = from +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/group_by.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/group_by.go new file mode 100644 index 000000000..84242fb8a --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/group_by.go @@ -0,0 +1,48 @@ +package clause + +// GroupBy group by clause +type GroupBy struct { + Columns []Column + Having []Expression +} + +// Name group by clause name +func (groupBy GroupBy) Name() string { + return "GROUP BY" +} + +// Build build group by clause +func (groupBy GroupBy) Build(builder Builder) { + for idx, column := range groupBy.Columns { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(column) + } + + if len(groupBy.Having) > 0 { + builder.WriteString(" HAVING ") + Where{Exprs: groupBy.Having}.Build(builder) + } +} + +// MergeClause merge group by clause +func (groupBy GroupBy) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(GroupBy); ok { + copiedColumns := make([]Column, len(v.Columns)) + copy(copiedColumns, v.Columns) + groupBy.Columns = append(copiedColumns, groupBy.Columns...) + + copiedHaving := make([]Expression, len(v.Having)) + copy(copiedHaving, v.Having) + groupBy.Having = append(copiedHaving, groupBy.Having...)
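+ // the previous clause's columns and conditions are kept in front, so chained + // calls accumulate; e.g. (sketch) db.Model(&User{}).Group("name").Having("sum(age) > ?", 100) + // merges into one GROUP BY `name` HAVING sum(age) > ? clause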
+ } + clause.Expression = groupBy + + if len(groupBy.Columns) == 0 { + clause.Name = "" + } else { + clause.Name = groupBy.Name() + } +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/insert.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/insert.go new file mode 100644 index 000000000..8efaa0352 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/insert.go @@ -0,0 +1,39 @@ +package clause + +type Insert struct { + Table Table + Modifier string +} + +// Name insert clause name +func (insert Insert) Name() string { + return "INSERT" +} + +// Build build insert clause +func (insert Insert) Build(builder Builder) { + if insert.Modifier != "" { + builder.WriteString(insert.Modifier) + builder.WriteByte(' ') + } + + builder.WriteString("INTO ") + if insert.Table.Name == "" { + builder.WriteQuoted(currentTable) + } else { + builder.WriteQuoted(insert.Table) + } +} + +// MergeClause merge insert clause +func (insert Insert) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Insert); ok { + if insert.Modifier == "" { + insert.Modifier = v.Modifier + } + if insert.Table.Name == "" { + insert.Table = v.Table + } + } + clause.Expression = insert +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/joins.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/joins.go new file mode 100644 index 000000000..879892be4 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/joins.go @@ -0,0 +1,47 @@ +package clause + +type JoinType string + +const ( + CrossJoin JoinType = "CROSS" + InnerJoin JoinType = "INNER" + LeftJoin JoinType = "LEFT" + RightJoin JoinType = "RIGHT" +) + +// Join clause for from +type Join struct { + Type JoinType + Table Table + ON Where + Using []string + Expression Expression +} + +func (join Join) Build(builder Builder) { + if join.Expression != nil { + join.Expression.Build(builder) + } else { + if join.Type != "" { + builder.WriteString(string(join.Type)) + builder.WriteByte(' ') + } + + builder.WriteString("JOIN ") + builder.WriteQuoted(join.Table) + + if len(join.ON.Exprs) > 0 { + builder.WriteString(" ON ") + join.ON.Build(builder) + } else if len(join.Using) > 0 { + builder.WriteString(" USING (") + for idx, c := range join.Using { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(c) + } + builder.WriteByte(')') + } + } +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/limit.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/limit.go new file mode 100644 index 000000000..abda00551 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/limit.go @@ -0,0 +1,48 @@ +package clause + +import "strconv" + +// Limit limit clause +type Limit struct { + Limit *int + Offset int +} + +// Name limit clause name +func (limit Limit) Name() string { + return "LIMIT" +} + +// Build build limit clause +func (limit Limit) Build(builder Builder) { + if limit.Limit != nil && *limit.Limit >= 0 { + builder.WriteString("LIMIT ") + builder.WriteString(strconv.Itoa(*limit.Limit)) + } + if limit.Offset > 0 { + if limit.Limit != nil && *limit.Limit >= 0 { + builder.WriteByte(' ') + } + builder.WriteString("OFFSET ") + builder.WriteString(strconv.Itoa(limit.Offset)) + } +} + +// MergeClause merge limit clauses +func (limit Limit) MergeClause(clause *Clause) { + clause.Name = "" + + if v, ok := clause.Expression.(Limit); ok { + if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil { + limit.Limit = v.Limit + } + + if limit.Offset == 0 && 
v.Offset > 0 { + limit.Offset = v.Offset + } else if limit.Offset < 0 { + limit.Offset = 0 + } + } + + clause.Expression = limit +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/locking.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/locking.go new file mode 100644 index 000000000..290aac92b --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/locking.go @@ -0,0 +1,31 @@ +package clause + +type Locking struct { + Strength string + Table Table + Options string +} + +// Name locking clause name +func (locking Locking) Name() string { + return "FOR" +} + +// Build build locking clause +func (locking Locking) Build(builder Builder) { + builder.WriteString(locking.Strength) + if locking.Table.Name != "" { + builder.WriteString(" OF ") + builder.WriteQuoted(locking.Table) + } + + if locking.Options != "" { + builder.WriteByte(' ') + builder.WriteString(locking.Options) + } +} + +// MergeClause merge locking clauses +func (locking Locking) MergeClause(clause *Clause) { + clause.Expression = locking +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/on_conflict.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/on_conflict.go new file mode 100644 index 000000000..032bf4a1c --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/on_conflict.go @@ -0,0 +1,59 @@ +package clause + +type OnConflict struct { + Columns []Column + Where Where + TargetWhere Where + OnConstraint string + DoNothing bool + DoUpdates Set + UpdateAll bool +} + +func (OnConflict) Name() string { + return "ON CONFLICT" +} + +// Build build onConflict clause +func (onConflict OnConflict) Build(builder Builder) { + if onConflict.OnConstraint != "" { + builder.WriteString("ON CONSTRAINT ") + builder.WriteString(onConflict.OnConstraint) + builder.WriteByte(' ') + } else { + if len(onConflict.Columns) > 0 { + builder.WriteByte('(') + for idx, column := range onConflict.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteString(`) `) + } + + if len(onConflict.TargetWhere.Exprs) > 0 { + builder.WriteString(" WHERE ") + onConflict.TargetWhere.Build(builder) + builder.WriteByte(' ') + } + } + + if onConflict.DoNothing { + builder.WriteString("DO NOTHING") + } else { + builder.WriteString("DO UPDATE SET ") + onConflict.DoUpdates.Build(builder) + } + + if len(onConflict.Where.Exprs) > 0 { + builder.WriteString(" WHERE ") + onConflict.Where.Build(builder) + builder.WriteByte(' ') + } +} + +// MergeClause merge onConflict clauses +func (onConflict OnConflict) MergeClause(clause *Clause) { + clause.Expression = onConflict +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/order_by.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/order_by.go new file mode 100644 index 000000000..412180255 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/order_by.go @@ -0,0 +1,54 @@ +package clause + +type OrderByColumn struct { + Column Column + Desc bool + Reorder bool +} + +type OrderBy struct { + Columns []OrderByColumn + Expression Expression +} + +// Name order by clause name +func (orderBy OrderBy) Name() string { + return "ORDER BY" +} + +// Build build order by clause +func (orderBy OrderBy) Build(builder Builder) { + if orderBy.Expression != nil { + orderBy.Expression.Build(builder) + } else { + for idx, column := range orderBy.Columns { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(column.Column) + if column.Desc { + builder.WriteString(" DESC") + 
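// entries render comma-separated, each with an optional DESC suffix; two + // accumulated columns yield e.g. (sketch): ORDER BY `name` DESC,`age` + 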
} + } + } +} + +// MergeClause merge order by clauses +func (orderBy OrderBy) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(OrderBy); ok { + for i := len(orderBy.Columns) - 1; i >= 0; i-- { + if orderBy.Columns[i].Reorder { + orderBy.Columns = orderBy.Columns[i:] + clause.Expression = orderBy + return + } + } + + copiedColumns := make([]OrderByColumn, len(v.Columns)) + copy(copiedColumns, v.Columns) + orderBy.Columns = append(copiedColumns, orderBy.Columns...) + } + + clause.Expression = orderBy +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/returning.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/returning.go new file mode 100644 index 000000000..d94b7a4ca --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/returning.go @@ -0,0 +1,34 @@ +package clause + +type Returning struct { + Columns []Column +} + +// Name returning clause name +func (returning Returning) Name() string { + return "RETURNING" +} + +// Build build returning clause +func (returning Returning) Build(builder Builder) { + if len(returning.Columns) > 0 { + for idx, column := range returning.Columns { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(column) + } + } else { + builder.WriteByte('*') + } +} + +// MergeClause merge returning clauses +func (returning Returning) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Returning); ok { + returning.Columns = append(v.Columns, returning.Columns...) + } + + clause.Expression = returning +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/select.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/select.go new file mode 100644 index 000000000..d8e9f8015 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/select.go @@ -0,0 +1,59 @@ +package clause + +// Select select attrs when querying, updating, creating +type Select struct { + Distinct bool + Columns []Column + Expression Expression +} + +func (s Select) Name() string { + return "SELECT" +} + +func (s Select) Build(builder Builder) { + if len(s.Columns) > 0 { + if s.Distinct { + builder.WriteString("DISTINCT ") + } + + for idx, column := range s.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + } else { + builder.WriteByte('*') + } +} + +func (s Select) MergeClause(clause *Clause) { + if s.Expression != nil { + if s.Distinct { + if expr, ok := s.Expression.(Expr); ok { + expr.SQL = "DISTINCT " + expr.SQL + clause.Expression = expr + return + } + } + + clause.Expression = s.Expression + } else { + clause.Expression = s + } +} + +// CommaExpression represents a group of expressions separated by commas. 
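+// For example (sketch), CommaExpression{Exprs: []Expression{Expr{SQL: "a"}, Expr{SQL: "b"}}} +// renders the SQL fragment `a, b`.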
+type CommaExpression struct { + Exprs []Expression +} + +func (comma CommaExpression) Build(builder Builder) { + for idx, expr := range comma.Exprs { + if idx > 0 { + _, _ = builder.WriteString(", ") + } + expr.Build(builder) + } +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/set.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/set.go new file mode 100644 index 000000000..75eb6bdda --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/set.go @@ -0,0 +1,60 @@ +package clause + +import "sort" + +type Set []Assignment + +type Assignment struct { + Column Column + Value interface{} +} + +func (set Set) Name() string { + return "SET" +} + +func (set Set) Build(builder Builder) { + if len(set) > 0 { + for idx, assignment := range set { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(assignment.Column) + builder.WriteByte('=') + builder.AddVar(builder, assignment.Value) + } + } else { + builder.WriteQuoted(Column{Name: PrimaryKey}) + builder.WriteByte('=') + builder.WriteQuoted(Column{Name: PrimaryKey}) + } +} + +// MergeClause merge assignments clauses +func (set Set) MergeClause(clause *Clause) { + copiedAssignments := make([]Assignment, len(set)) + copy(copiedAssignments, set) + clause.Expression = Set(copiedAssignments) +} + +func Assignments(values map[string]interface{}) Set { + keys := make([]string, 0, len(values)) + for key := range values { + keys = append(keys, key) + } + sort.Strings(keys) + + assignments := make([]Assignment, len(keys)) + for idx, key := range keys { + assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]} + } + return assignments +} + +func AssignmentColumns(values []string) Set { + assignments := make([]Assignment, len(values)) + for idx, value := range values { + assignments[idx] = Assignment{Column: Column{Name: value}, Value: Column{Table: "excluded", Name: value}} + } + return assignments +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/update.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/update.go new file mode 100644 index 000000000..f9d68ac67 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/update.go @@ -0,0 +1,38 @@ +package clause + +type Update struct { + Modifier string + Table Table +} + +// Name update clause name +func (update Update) Name() string { + return "UPDATE" +} + +// Build build update clause +func (update Update) Build(builder Builder) { + if update.Modifier != "" { + builder.WriteString(update.Modifier) + builder.WriteByte(' ') + } + + if update.Table.Name == "" { + builder.WriteQuoted(currentTable) + } else { + builder.WriteQuoted(update.Table) + } +} + +// MergeClause merge update clause +func (update Update) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Update); ok { + if update.Modifier == "" { + update.Modifier = v.Modifier + } + if update.Table.Name == "" { + update.Table = v.Table + } + } + clause.Expression = update +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/values.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/values.go new file mode 100644 index 000000000..b2f5421be --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/values.go @@ -0,0 +1,45 @@ +package clause + +type Values struct { + Columns []Column + Values [][]interface{} +} + +// Name values clause name +func (Values) Name() string { + return "VALUES" +} + +// Build build values clause +func (values Values) Build(builder Builder) { + if len(values.Columns) > 0 
{ + builder.WriteByte('(') + for idx, column := range values.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteByte(')') + + builder.WriteString(" VALUES ") + + for idx, value := range values.Values { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteByte('(') + builder.AddVar(builder, value...) + builder.WriteByte(')') + } + } else { + builder.WriteString("DEFAULT VALUES") + } +} + +// MergeClause merge values clauses +func (values Values) MergeClause(clause *Clause) { + clause.Name = "" + clause.Expression = values +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/where.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/where.go new file mode 100644 index 000000000..a29401cfe --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/where.go @@ -0,0 +1,190 @@ +package clause + +import ( + "strings" +) + +const ( + AndWithSpace = " AND " + OrWithSpace = " OR " +) + +// Where where clause +type Where struct { + Exprs []Expression +} + +// Name where clause name +func (where Where) Name() string { + return "WHERE" +} + +// Build build where clause +func (where Where) Build(builder Builder) { + // Switch position if the first query expression is a single Or condition + for idx, expr := range where.Exprs { + if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 { + if idx != 0 { + where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0] + } + break + } + } + + buildExprs(where.Exprs, builder, AndWithSpace) +} + +func buildExprs(exprs []Expression, builder Builder, joinCond string) { + wrapInParentheses := false + + for idx, expr := range exprs { + if idx > 0 { + if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { + builder.WriteString(OrWithSpace) + } else { + builder.WriteString(joinCond) + } + } + + if len(exprs) > 1 { + switch v := expr.(type) { + case OrConditions: + if len(v.Exprs) == 1 { + if e, ok := v.Exprs[0].(Expr); ok { + sql := strings.ToUpper(e.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) + } + } + case AndConditions: + if len(v.Exprs) == 1 { + if e, ok := v.Exprs[0].(Expr); ok { + sql := strings.ToUpper(e.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) + } + } + case Expr: + sql := strings.ToUpper(v.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) + case NamedExpr: + sql := strings.ToUpper(v.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) + } + } + + if wrapInParentheses { + builder.WriteByte('(') + expr.Build(builder) + builder.WriteByte(')') + wrapInParentheses = false + } else { + expr.Build(builder) + } + } +} + +// MergeClause merge where clauses +func (where Where) MergeClause(clause *Clause) { + if w, ok := clause.Expression.(Where); ok { + exprs := make([]Expression, len(w.Exprs)+len(where.Exprs)) + copy(exprs, w.Exprs) + copy(exprs[len(w.Exprs):], where.Exprs) + where.Exprs = exprs + } + + clause.Expression = where +} + +func And(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } + + if len(exprs) == 1 { + if _, ok := exprs[0].(OrConditions); !ok { + return exprs[0] + } + } + + return AndConditions{Exprs: exprs} +} + +type AndConditions struct { + Exprs []Expression +} + +func (and AndConditions) Build(builder Builder) { + if len(and.Exprs) > 1 { + builder.WriteByte('(') + buildExprs(and.Exprs, builder, 
AndWithSpace) + builder.WriteByte(')') + } else { + buildExprs(and.Exprs, builder, AndWithSpace) + } +} + +func Or(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } + return OrConditions{Exprs: exprs} +} + +type OrConditions struct { + Exprs []Expression +} + +func (or OrConditions) Build(builder Builder) { + if len(or.Exprs) > 1 { + builder.WriteByte('(') + buildExprs(or.Exprs, builder, OrWithSpace) + builder.WriteByte(')') + } else { + buildExprs(or.Exprs, builder, OrWithSpace) + } +} + +func Not(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } + return NotConditions{Exprs: exprs} +} + +type NotConditions struct { + Exprs []Expression +} + +func (not NotConditions) Build(builder Builder) { + if len(not.Exprs) > 1 { + builder.WriteByte('(') + } + + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(AndWithSpace) + } + + if negationBuilder, ok := c.(NegationExpressionBuilder); ok { + negationBuilder.NegationBuild(builder) + } else { + builder.WriteString("NOT ") + e, wrapInParentheses := c.(Expr) + if wrapInParentheses { + sql := strings.ToUpper(e.SQL) + if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses { + builder.WriteByte('(') + } + } + + c.Build(builder) + + if wrapInParentheses { + builder.WriteByte(')') + } + } + } + + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/with.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/with.go new file mode 100644 index 000000000..0768488e5 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/with.go @@ -0,0 +1,3 @@ +package clause + +type With struct{} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/errors.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/errors.go new file mode 100644 index 000000000..cd76f1f52 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/errors.go @@ -0,0 +1,52 @@ +package gorm + +import ( + "errors" + + "gorm.io/gorm/logger" +) + +var ( + // ErrRecordNotFound record not found error + ErrRecordNotFound = logger.ErrRecordNotFound + // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` + ErrInvalidTransaction = errors.New("invalid transaction") + // ErrNotImplemented not implemented + ErrNotImplemented = errors.New("not implemented") + // ErrMissingWhereClause missing where clause + ErrMissingWhereClause = errors.New("WHERE conditions required") + // ErrUnsupportedRelation unsupported relations + ErrUnsupportedRelation = errors.New("unsupported relations") + // ErrPrimaryKeyRequired primary keys required + ErrPrimaryKeyRequired = errors.New("primary key required") + // ErrModelValueRequired model value required + ErrModelValueRequired = errors.New("model value required") + // ErrModelAccessibleFieldsRequired model accessible fields required + ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required") + // ErrSubQueryRequired sub query required + ErrSubQueryRequired = errors.New("sub query required") + // ErrInvalidData unsupported data + ErrInvalidData = errors.New("unsupported data") + // ErrUnsupportedDriver unsupported driver + ErrUnsupportedDriver = errors.New("unsupported driver") + // ErrRegistered registered + ErrRegistered = errors.New("registered") + // ErrInvalidField invalid field + ErrInvalidField = errors.New("invalid field") + // ErrEmptySlice empty slice found + ErrEmptySlice = errors.New("empty 
slice found") + // ErrDryRunModeUnsupported dry run mode unsupported + ErrDryRunModeUnsupported = errors.New("dry run mode unsupported") + // ErrInvalidDB invalid db + ErrInvalidDB = errors.New("invalid db") + // ErrInvalidValue invalid value + ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice") + // ErrInvalidValueOfLength invalid values do not match length + ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") + // ErrPreloadNotAllowed preload is not allowed when count is used + ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used") + // ErrDuplicatedKey occurs when there is a unique key constraint violation + ErrDuplicatedKey = errors.New("duplicated key not allowed") + // ErrForeignKeyViolated occurs when there is a foreign key constraint violation + ErrForeignKeyViolated = errors.New("violates foreign key constraint") +) diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/finisher_api.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/finisher_api.go new file mode 100644 index 000000000..f80aa6c04 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/finisher_api.go @@ -0,0 +1,766 @@ +package gorm + +import ( + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +// Create inserts value, returning the inserted data's primary key in value's id +func (db *DB) Create(value interface{}) (tx *DB) { + if db.CreateBatchSize > 0 { + return db.CreateInBatches(value, db.CreateBatchSize) + } + + tx = db.getInstance() + tx.Statement.Dest = value + return tx.callbacks.Create().Execute(tx) +} + +// CreateInBatches inserts value in batches of batchSize +func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + var rowsAffected int64 + tx = db.getInstance() + + // the reflection length judgment of the optimized value + reflectLen := reflectValue.Len() + + callFc := func(tx *DB) error { + for i := 0; i < reflectLen; i += batchSize { + ends := i + batchSize + if ends > reflectLen { + ends = reflectLen + } + + subtx := tx.getInstance() + subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface() + subtx.callbacks.Create().Execute(subtx) + if subtx.Error != nil { + return subtx.Error + } + rowsAffected += subtx.RowsAffected + } + return nil + } + + if tx.SkipDefaultTransaction || reflectLen <= batchSize { + tx.AddError(callFc(tx.Session(&Session{}))) + } else { + tx.AddError(tx.Transaction(callFc)) + } + + tx.RowsAffected = rowsAffected + default: + tx = db.getInstance() + tx.Statement.Dest = value + tx = tx.callbacks.Create().Execute(tx) + } + return +} + +// Save updates value in database. If value doesn't contain a matching primary key, value is inserted. 
+func (db *DB) Save(value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = value + + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface { + reflectValue = reflect.Indirect(reflectValue) + } + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { + tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) + } + tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true)) + case reflect.Struct: + if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { + for _, pf := range tx.Statement.Schema.PrimaryFields { + if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero { + return tx.callbacks.Create().Execute(tx) + } + } + } + + fallthrough + default: + selectedUpdate := len(tx.Statement.Selects) != 0 + // when updating, use all fields including those zero-value fields + if !selectedUpdate { + tx.Statement.Selects = append(tx.Statement.Selects, "*") + } + + updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true})) + + if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate { + return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value) + } + + return updateTx + } + + return +} + +// First finds the first record ordered by primary key, matching given conditions conds +func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + if len(conds) > 0 { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } + } + tx.Statement.RaiseErrorOnNotFound = true + tx.Statement.Dest = dest + return tx.callbacks.Query().Execute(tx) +} + +// Take finds the first record returned by the database in no specified order, matching given conditions conds +func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.Limit(1) + if len(conds) > 0 { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } + } + tx.Statement.RaiseErrorOnNotFound = true + tx.Statement.Dest = dest + return tx.callbacks.Query().Execute(tx) +} + +// Last finds the last record ordered by primary key, matching given conditions conds +func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Desc: true, + }) + if len(conds) > 0 { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } + } + tx.Statement.RaiseErrorOnNotFound = true + tx.Statement.Dest = dest + return tx.callbacks.Query().Execute(tx) +} + +// Find finds all records matching given conditions conds +func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.getInstance() + if len(conds) > 0 { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } + } + tx.Statement.Dest = dest + return tx.callbacks.Query().Execute(tx) +} + +// FindInBatches finds all records in batches of batchSize +func (db *DB) 
FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { + var ( + tx = db.Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }).Session(&Session{}) + queryDB = tx + rowsAffected int64 + batch int + ) + + // user specified offset or limit + var totalSize int + if c, ok := tx.Statement.Clauses["LIMIT"]; ok { + if limit, ok := c.Expression.(clause.Limit); ok { + if limit.Limit != nil { + totalSize = *limit.Limit + } + + if totalSize > 0 && batchSize > totalSize { + batchSize = totalSize + } + + // reset to offset to 0 in next batch + tx = tx.Offset(-1).Session(&Session{}) + } + } + + for { + result := queryDB.Limit(batchSize).Find(dest) + rowsAffected += result.RowsAffected + batch++ + + if result.Error == nil && result.RowsAffected != 0 { + fcTx := result.Session(&Session{NewDB: true}) + fcTx.RowsAffected = result.RowsAffected + tx.AddError(fc(fcTx, batch)) + } else if result.Error != nil { + tx.AddError(result.Error) + } + + if tx.Error != nil || int(result.RowsAffected) < batchSize { + break + } + + if totalSize > 0 { + if totalSize <= int(rowsAffected) { + break + } + if totalSize/batchSize == batch { + batchSize = totalSize % batchSize + } + } + + // Optimize for-break + resultsValue := reflect.Indirect(reflect.ValueOf(dest)) + if result.Statement.Schema.PrioritizedPrimaryField == nil { + tx.AddError(ErrPrimaryKeyRequired) + break + } + + primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) + break + } + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) + } + + tx.RowsAffected = rowsAffected + return tx +} + +func (db *DB) assignInterfacesToValue(values ...interface{}) { + for _, value := range values { + switch v := value.(type) { + case []clause.Expression: + for _, expr := range v { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + if field := db.Statement.Schema.LookUpField(column); field != nil { + db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value)) + } + case clause.Column: + if field := db.Statement.Schema.LookUpField(column.Name); field != nil { + db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value)) + } + } + } else if andCond, ok := expr.(clause.AndConditions); ok { + db.assignInterfacesToValue(andCond.Exprs) + } + } + case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: + if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 { + db.assignInterfacesToValue(exprs) + } + default: + if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Struct: + for _, f := range s.Fields { + if f.Readable { + if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero { + if field := db.Statement.Schema.LookUpField(f.Name); field != nil { + db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v)) + } + } + } + } + } + } else if len(values) > 0 { + if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 { + db.assignInterfacesToValue(exprs) + } + return + } + } + } +} + +// FirstOrInit finds the first matching record, otherwise if not found initializes a new 
instance with given conds. +// Each conds must be a struct or map. +// +// FirstOrInit never modifies the database. It is often used with Assign and Attrs. +// +// // assign an email if the record is not found +// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign email regardless of if record is found +// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} +func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { + queryTx := db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + + if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 { + if c, ok := tx.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + tx.assignInterfacesToValue(where.Exprs) + } + } + + // initialize with attrs, conds + if len(tx.Statement.attrs) > 0 { + tx.assignInterfacesToValue(tx.Statement.attrs...) + } + } + + // initialize with attrs, conds + if len(tx.Statement.assigns) > 0 { + tx.assignInterfacesToValue(tx.Statement.assigns...) + } + return +} + +// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds. +// Each conds must be a struct or map. +// +// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists. +// +// // assign an email if the record is not found +// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// // result.RowsAffected -> 1 +// +// // assign email regardless of if record is found +// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} +// // result.RowsAffected -> 1 +func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.getInstance() + queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + + result := queryTx.Find(dest, conds...) + if result.Error != nil { + tx.Error = result.Error + return tx + } + + if result.RowsAffected == 0 { + if c, ok := result.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + result.assignInterfacesToValue(where.Exprs) + } + } + + // initialize with attrs, conds + if len(db.Statement.attrs) > 0 { + result.assignInterfacesToValue(db.Statement.attrs...) + } + + // initialize with attrs, conds + if len(db.Statement.assigns) > 0 { + result.assignInterfacesToValue(db.Statement.assigns...) + } + + return tx.Create(dest) + } else if len(db.Statement.assigns) > 0 { + exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) + assigns := map[string]interface{}{} + for _, expr := range exprs { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + assigns[column] = eq.Value + case clause.Column: + assigns[column.Name] = eq.Value + } + } + } + + return tx.Model(dest).Updates(assigns) + } + + return tx +} + +// Update updates column with value using callbacks. 
Reference: https://gorm.io/docs/update.html#Update-Changed-Fields +func (db *DB) Update(column string, value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = map[string]interface{}{column: value} + return tx.callbacks.Update().Execute(tx) +} + +// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields +func (db *DB) Updates(values interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = values + return tx.callbacks.Update().Execute(tx) +} + +func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = map[string]interface{}{column: value} + tx.Statement.SkipHooks = true + return tx.callbacks.Update().Execute(tx) +} + +func (db *DB) UpdateColumns(values interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = values + tx.Statement.SkipHooks = true + return tx.callbacks.Update().Execute(tx) +} + +// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If +// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current +// time if null. +func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { + tx = db.getInstance() + if len(conds) > 0 { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } + } + tx.Statement.Dest = value + return tx.callbacks.Delete().Execute(tx) +} + +func (db *DB) Count(count *int64) (tx *DB) { + tx = db.getInstance() + if tx.Statement.Model == nil { + tx.Statement.Model = tx.Statement.Dest + defer func() { + tx.Statement.Model = nil + }() + } + + if selectClause, ok := db.Statement.Clauses["SELECT"]; ok { + defer func() { + tx.Statement.Clauses["SELECT"] = selectClause + }() + } else { + defer delete(tx.Statement.Clauses, "SELECT") + } + + if len(tx.Statement.Selects) == 0 { + tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}}) + } else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") { + expr := clause.Expr{SQL: "count(*)"} + + if len(tx.Statement.Selects) == 1 { + dbName := tx.Statement.Selects[0] + fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) + if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) { + if tx.Statement.Parse(tx.Statement.Model) == nil { + if f := tx.Statement.Schema.LookUpField(dbName); f != nil { + dbName = f.DBName + } + } + + if tx.Statement.Distinct { + expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} + } else if dbName != "*" { + expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} + } + } + } + + tx.Statement.AddClause(clause.Select{Expression: expr}) + } + + if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok { + if _, ok := db.Statement.Clauses["GROUP BY"]; !ok { + delete(tx.Statement.Clauses, "ORDER BY") + defer func() { + tx.Statement.Clauses["ORDER BY"] = orderByClause + }() + } + } + + tx.Statement.Dest = count + tx = tx.callbacks.Query().Execute(tx) + + if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 { + *count = tx.RowsAffected + } + + return +} + +func (db *DB) Row() *sql.Row { + tx := db.getInstance().Set("rows", false) + tx = tx.callbacks.Row().Execute(tx) + row, ok := 
tx.Statement.Dest.(*sql.Row) + if !ok && tx.DryRun { + db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error()) + } + return row +} + +func (db *DB) Rows() (*sql.Rows, error) { + tx := db.getInstance().Set("rows", true) + tx = tx.callbacks.Row().Execute(tx) + rows, ok := tx.Statement.Dest.(*sql.Rows) + if !ok && tx.DryRun && tx.Error == nil { + tx.Error = ErrDryRunModeUnsupported + } + return rows, tx.Error +} + +// Scan scans selected value to the struct dest +func (db *DB) Scan(dest interface{}) (tx *DB) { + config := *db.Config + currentLogger, newLogger := config.Logger, logger.Recorder.New() + config.Logger = newLogger + + tx = db.getInstance() + tx.Config = &config + + if rows, err := tx.Rows(); err == nil { + if rows.Next() { + tx.ScanRows(rows, dest) + } else { + tx.RowsAffected = 0 + tx.AddError(rows.Err()) + } + tx.AddError(rows.Close()) + } + + currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) { + return newLogger.SQL, tx.RowsAffected + }, tx.Error) + tx.Logger = currentLogger + return +} + +// Pluck queries a single column from a model, returning in the slice dest. E.g.: +// +// var ages []int64 +// db.Model(&users).Pluck("age", &ages) +func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { + tx = db.getInstance() + if tx.Statement.Model != nil { + if tx.Statement.Parse(tx.Statement.Model) == nil { + if f := tx.Statement.Schema.LookUpField(column); f != nil { + column = f.DBName + } + } + } + + if len(tx.Statement.Selects) != 1 { + fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + tx.Statement.AddClauseIfNotExists(clause.Select{ + Distinct: tx.Statement.Distinct, + Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, + }) + } + tx.Statement.Dest = dest + return tx.callbacks.Query().Execute(tx) +} + +func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { + tx := db.getInstance() + if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) { + tx.AddError(err) + } + tx.Statement.Dest = dest + tx.Statement.ReflectValue = reflect.ValueOf(dest) + for tx.Statement.ReflectValue.Kind() == reflect.Ptr { + elem := tx.Statement.ReflectValue.Elem() + if !elem.IsValid() { + elem = reflect.New(tx.Statement.ReflectValue.Type().Elem()) + tx.Statement.ReflectValue.Set(elem) + } + tx.Statement.ReflectValue = elem + } + Scan(rows, tx, ScanInitialized) + return tx.Error +} + +// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is +// returned to the connection pool. +func (db *DB) Connection(fc func(tx *DB) error) (err error) { + if db.Error != nil { + return db.Error + } + + tx := db.getInstance() + sqlDB, err := tx.DB() + if err != nil { + return + } + + conn, err := sqlDB.Conn(tx.Statement.Context) + if err != nil { + return + } + + defer conn.Close() + tx.Statement.ConnPool = conn + return fc(tx) +} + +// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an +// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs +// they are rolled back. 
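+//
+// For example (illustrative only; user is assumed to be defined by the caller):
+//
+//	err := db.Transaction(func(tx *gorm.DB) error {
+//		if err := tx.Create(&user).Error; err != nil {
+//			return err // returning any error rolls the transaction back
+//		}
+//		return nil // returning nil commits
+//	})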
+func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
+	panicked := true
+
+	if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
+		// nested transaction
+		if !db.DisableNestedTransaction {
+			err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
+			if err != nil {
+				return
+			}
+			defer func() {
+				// Make sure to roll back on panic, block error, or commit error
+				if panicked || err != nil {
+					db.RollbackTo(fmt.Sprintf("sp%p", fc))
+				}
+			}()
+		}
+		err = fc(db.Session(&Session{NewDB: db.clone == 1}))
+	} else {
+		tx := db.Begin(opts...)
+		if tx.Error != nil {
+			return tx.Error
+		}
+
+		defer func() {
+			// Make sure to roll back on panic, block error, or commit error
+			if panicked || err != nil {
+				tx.Rollback()
+			}
+		}()
+
+		if err = fc(tx); err == nil {
+			panicked = false
+			return tx.Commit().Error
+		}
+	}
+
+	panicked = false
+	return
+}
+
+// Begin begins a transaction with any transaction options opts
+func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
+	var (
+		// clone statement
+		tx  = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})
+		opt *sql.TxOptions
+		err error
+	)
+
+	if len(opts) > 0 {
+		opt = opts[0]
+	}
+
+	switch beginner := tx.Statement.ConnPool.(type) {
+	case TxBeginner:
+		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
+	case ConnPoolBeginner:
+		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
+	default:
+		err = ErrInvalidTransaction
+	}
+
+	if err != nil {
+		tx.AddError(err)
+	}
+
+	return tx
+}
+
+// Commit commits the changes in a transaction
+func (db *DB) Commit() *DB {
+	if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
+		db.AddError(committer.Commit())
+	} else {
+		db.AddError(ErrInvalidTransaction)
+	}
+	return db
+}
+
+// Rollback rolls back the changes in a transaction
+func (db *DB) Rollback() *DB {
+	if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
+		if !reflect.ValueOf(committer).IsNil() {
+			db.AddError(committer.Rollback())
+		}
+	} else {
+		db.AddError(ErrInvalidTransaction)
+	}
+	return db
+}
+
+func (db *DB) SavePoint(name string) *DB {
+	if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
+		// close the prepared statement, because SavePoint does not support prepared statements.
+		// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
+		var (
+			preparedStmtTx   *PreparedStmtTX
+			isPreparedStmtTx bool
+		)
+		if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
+			db.Statement.ConnPool = preparedStmtTx.Tx
+		}
+		db.AddError(savePointer.SavePoint(db, name))
+		// restore prepared statement
+		if isPreparedStmtTx {
+			db.Statement.ConnPool = preparedStmtTx
+		}
+	} else {
+		db.AddError(ErrUnsupportedDriver)
+	}
+	return db
+}
+
+func (db *DB) RollbackTo(name string) *DB {
+	if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
+		// close the prepared statement, because RollbackTo does not support prepared statements.
+		// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
+		var (
+			preparedStmtTx   *PreparedStmtTX
+			isPreparedStmtTx bool
+		)
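+		// temporarily hand the raw transaction to the dialector so the
+		// ROLLBACK TO SAVEPOINT statement is not routed through a prepared statement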
+ if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx.Tx + } + db.AddError(savePointer.RollbackTo(db, name)) + // restore prepared statement + if isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx + } + } else { + db.AddError(ErrUnsupportedDriver) + } + return db +} + +// Exec executes raw sql +func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.SQL = strings.Builder{} + + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) + } else { + clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) + } + + return tx.callbacks.Raw().Execute(tx) +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/gorm.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/gorm.go new file mode 100644 index 000000000..203527af3 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/gorm.go @@ -0,0 +1,503 @@ +package gorm + +import ( + "context" + "database/sql" + "fmt" + "sort" + "sync" + "time" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" +) + +// for Config.cacheStore store PreparedStmtDB key +const preparedStmtDBKey = "preparedStmt" + +// Config GORM config +type Config struct { + // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity + // You can disable it by setting `SkipDefaultTransaction` to true + SkipDefaultTransaction bool + // NamingStrategy tables, columns naming strategy + NamingStrategy schema.Namer + // FullSaveAssociations full save associations + FullSaveAssociations bool + // Logger + Logger logger.Interface + // NowFunc the function to be used when creating a new timestamp + NowFunc func() time.Time + // DryRun generate sql without execute + DryRun bool + // PrepareStmt executes the given query in cached statement + PrepareStmt bool + // DisableAutomaticPing + DisableAutomaticPing bool + // DisableForeignKeyConstraintWhenMigrating + DisableForeignKeyConstraintWhenMigrating bool + // IgnoreRelationshipsWhenMigrating + IgnoreRelationshipsWhenMigrating bool + // DisableNestedTransaction disable nested transaction + DisableNestedTransaction bool + // AllowGlobalUpdate allow global update + AllowGlobalUpdate bool + // QueryFields executes the SQL query with all fields of the table + QueryFields bool + // CreateBatchSize default create batch size + CreateBatchSize int + // TranslateError enabling error translation + TranslateError bool + + // ClauseBuilders clause builder + ClauseBuilders map[string]clause.ClauseBuilder + // ConnPool db conn pool + ConnPool ConnPool + // Dialector database dialector + Dialector + // Plugins registered plugins + Plugins map[string]Plugin + + callbacks *callbacks + cacheStore *sync.Map +} + +// Apply update config to new config +func (c *Config) Apply(config *Config) error { + if config != c { + *config = *c + } + return nil +} + +// AfterInitialize initialize plugins after db connected +func (c *Config) AfterInitialize(db *DB) error { + if db != nil { + for _, plugin := range c.Plugins { + if err := plugin.Initialize(db); err != nil { + return err + } + } + } + return nil +} + +// Option gorm option interface +type Option interface { + Apply(*Config) error + AfterInitialize(*DB) error +} + +// DB GORM DB definition +type DB struct { + *Config + Error error + RowsAffected int64 + Statement *Statement + clone int +} + +// Session session config when create session with 
Session() method +type Session struct { + DryRun bool + PrepareStmt bool + NewDB bool + Initialized bool + SkipHooks bool + SkipDefaultTransaction bool + DisableNestedTransaction bool + AllowGlobalUpdate bool + FullSaveAssociations bool + QueryFields bool + Context context.Context + Logger logger.Interface + NowFunc func() time.Time + CreateBatchSize int +} + +// Open initialize db session based on dialector +func Open(dialector Dialector, opts ...Option) (db *DB, err error) { + config := &Config{} + + sort.Slice(opts, func(i, j int) bool { + _, isConfig := opts[i].(*Config) + _, isConfig2 := opts[j].(*Config) + return isConfig && !isConfig2 + }) + + for _, opt := range opts { + if opt != nil { + if applyErr := opt.Apply(config); applyErr != nil { + return nil, applyErr + } + defer func(opt Option) { + if errr := opt.AfterInitialize(db); errr != nil { + err = errr + } + }(opt) + } + } + + if d, ok := dialector.(interface{ Apply(*Config) error }); ok { + if err = d.Apply(config); err != nil { + return + } + } + + if config.NamingStrategy == nil { + config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64 + } + + if config.Logger == nil { + config.Logger = logger.Default + } + + if config.NowFunc == nil { + config.NowFunc = func() time.Time { return time.Now().Local() } + } + + if dialector != nil { + config.Dialector = dialector + } + + if config.Plugins == nil { + config.Plugins = map[string]Plugin{} + } + + if config.cacheStore == nil { + config.cacheStore = &sync.Map{} + } + + db = &DB{Config: config, clone: 1} + + db.callbacks = initializeCallbacks(db) + + if config.ClauseBuilders == nil { + config.ClauseBuilders = map[string]clause.ClauseBuilder{} + } + + if config.Dialector != nil { + err = config.Dialector.Initialize(db) + + if err != nil { + if db, err := db.DB(); err == nil { + _ = db.Close() + } + } + } + + if config.PrepareStmt { + preparedStmt := NewPreparedStmtDB(db.ConnPool) + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) + db.ConnPool = preparedStmt + } + + db.Statement = &Statement{ + DB: db, + ConnPool: db.ConnPool, + Context: context.Background(), + Clauses: map[string]clause.Clause{}, + } + + if err == nil && !config.DisableAutomaticPing { + if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok { + err = pinger.Ping() + } + } + + if err != nil { + config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err) + } + + return +} + +// Session create new db session +func (db *DB) Session(config *Session) *DB { + var ( + txConfig = *db.Config + tx = &DB{ + Config: &txConfig, + Statement: db.Statement, + Error: db.Error, + clone: 1, + } + ) + if config.CreateBatchSize > 0 { + tx.Config.CreateBatchSize = config.CreateBatchSize + } + + if config.SkipDefaultTransaction { + tx.Config.SkipDefaultTransaction = true + } + + if config.AllowGlobalUpdate { + txConfig.AllowGlobalUpdate = true + } + + if config.FullSaveAssociations { + txConfig.FullSaveAssociations = true + } + + if config.Context != nil || config.PrepareStmt || config.SkipHooks { + tx.Statement = tx.Statement.clone() + tx.Statement.DB = tx + } + + if config.Context != nil { + tx.Statement.Context = config.Context + } + + if config.PrepareStmt { + var preparedStmt *PreparedStmtDB + + if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { + preparedStmt = v.(*PreparedStmtDB) + } else { + preparedStmt = NewPreparedStmtDB(db.ConnPool) + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) + } + + switch t := 
tx.Statement.ConnPool.(type) { + case Tx: + tx.Statement.ConnPool = &PreparedStmtTX{ + Tx: t, + PreparedStmtDB: preparedStmt, + } + default: + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + Mux: preparedStmt.Mux, + Stmts: preparedStmt.Stmts, + } + } + txConfig.ConnPool = tx.Statement.ConnPool + txConfig.PrepareStmt = true + } + + if config.SkipHooks { + tx.Statement.SkipHooks = true + } + + if config.DisableNestedTransaction { + txConfig.DisableNestedTransaction = true + } + + if !config.NewDB { + tx.clone = 2 + } + + if config.DryRun { + tx.Config.DryRun = true + } + + if config.QueryFields { + tx.Config.QueryFields = true + } + + if config.Logger != nil { + tx.Config.Logger = config.Logger + } + + if config.NowFunc != nil { + tx.Config.NowFunc = config.NowFunc + } + + if config.Initialized { + tx = tx.getInstance() + } + + return tx +} + +// WithContext change current instance db's context to ctx +func (db *DB) WithContext(ctx context.Context) *DB { + return db.Session(&Session{Context: ctx}) +} + +// Debug start debug mode +func (db *DB) Debug() (tx *DB) { + tx = db.getInstance() + return tx.Session(&Session{ + Logger: db.Logger.LogMode(logger.Info), + }) +} + +// Set store value with key into current db instance's context +func (db *DB) Set(key string, value interface{}) *DB { + tx := db.getInstance() + tx.Statement.Settings.Store(key, value) + return tx +} + +// Get get value with key from current db instance's context +func (db *DB) Get(key string) (interface{}, bool) { + return db.Statement.Settings.Load(key) +} + +// InstanceSet store value with key into current db instance's context +func (db *DB) InstanceSet(key string, value interface{}) *DB { + tx := db.getInstance() + tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value) + return tx +} + +// InstanceGet get value with key from current db instance's context +func (db *DB) InstanceGet(key string) (interface{}, bool) { + return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) +} + +// Callback returns callback manager +func (db *DB) Callback() *callbacks { + return db.callbacks +} + +// AddError add error to db +func (db *DB) AddError(err error) error { + if err != nil { + if db.Config.TranslateError { + if errTranslator, ok := db.Dialector.(ErrorTranslator); ok { + err = errTranslator.Translate(err) + } + } + + if db.Error == nil { + db.Error = err + } else { + db.Error = fmt.Errorf("%v; %w", db.Error, err) + } + } + return db.Error +} + +// DB returns `*sql.DB` +func (db *DB) DB() (*sql.DB, error) { + connPool := db.ConnPool + + if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil { + return connector.GetDBConnWithContext(db) + } + + if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { + if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil { + return sqldb, err + } + } + + if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil { + return sqldb, nil + } + + return nil, ErrInvalidDB +} + +func (db *DB) getInstance() *DB { + if db.clone > 0 { + tx := &DB{Config: db.Config, Error: db.Error} + + if db.clone == 1 { + // clone with new statement + tx.Statement = &Statement{ + DB: tx, + ConnPool: db.Statement.ConnPool, + Context: db.Statement.Context, + Clauses: map[string]clause.Clause{}, + Vars: make([]interface{}, 0, 8), + SkipHooks: db.Statement.SkipHooks, + } + } else { + // with clone statement + tx.Statement = db.Statement.clone() + tx.Statement.DB = tx + } + + return tx + } + + return db +} + +// 
Expr returns clause.Expr, which can be used to pass SQL expression as params
+func Expr(expr string, args ...interface{}) clause.Expr {
+	return clause.Expr{SQL: expr, Vars: args}
+}
+
+// SetupJoinTable setup join table schema
+func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
+	var (
+		tx                      = db.getInstance()
+		stmt                    = tx.Statement
+		modelSchema, joinSchema *schema.Schema
+	)
+
+	err := stmt.Parse(model)
+	if err != nil {
+		return err
+	}
+	modelSchema = stmt.Schema
+
+	err = stmt.Parse(joinTable)
+	if err != nil {
+		return err
+	}
+	joinSchema = stmt.Schema
+
+	relation, ok := modelSchema.Relationships.Relations[field]
+	isRelation := ok && relation.JoinTable != nil
+	if !isRelation {
+		return fmt.Errorf("failed to find relation: %s", field)
+	}
+
+	for _, ref := range relation.References {
+		f := joinSchema.LookUpField(ref.ForeignKey.DBName)
+		if f == nil {
+			return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
+		}
+
+		f.DataType = ref.ForeignKey.DataType
+		f.GORMDataType = ref.ForeignKey.GORMDataType
+		if f.Size == 0 {
+			f.Size = ref.ForeignKey.Size
+		}
+		ref.ForeignKey = f
+	}
+
+	for name, rel := range relation.JoinTable.Relationships.Relations {
+		if _, ok := joinSchema.Relationships.Relations[name]; !ok {
+			rel.Schema = joinSchema
+			joinSchema.Relationships.Relations[name] = rel
+		}
+	}
+	relation.JoinTable = joinSchema
+
+	return nil
+}
+
+// Use registers a plugin
+func (db *DB) Use(plugin Plugin) error {
+	name := plugin.Name()
+	if _, ok := db.Plugins[name]; ok {
+		return ErrRegistered
+	}
+	if err := plugin.Initialize(db); err != nil {
+		return err
+	}
+	db.Plugins[name] = plugin
+	return nil
+}
+
+// ToSQL generates a SQL string for the query built in queryFn.
+//
+//	db.ToSQL(func(tx *gorm.DB) *gorm.DB {
+//		return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
+//			.Limit(10).Offset(5)
+//			.Order("name ASC")
+//			.First(&User{})
+//	})
+func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
+	tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
+	stmt := tx.Statement
+
+	return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
+} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/interfaces.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/interfaces.go new file mode 100644 index 000000000..1950d7400 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/interfaces.go @@ -0,0 +1,98 @@ +package gorm + +import ( + "context" + "database/sql" + + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +// Dialector GORM database dialector +type Dialector interface { + Name() string + Initialize(*DB) error + Migrator(db *DB) Migrator + DataTypeOf(*schema.Field) string + DefaultValueOf(*schema.Field) clause.Expression + BindVarTo(writer clause.Writer, stmt *Statement, v interface{}) + QuoteTo(clause.Writer, string) + Explain(sql string, vars ...interface{}) string +} + +// Plugin GORM plugin interface +type Plugin interface { + Name() string + Initialize(*DB) error +} + +type ParamsFilter interface { + ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) +} + +// ConnPool db conns pool interface +type ConnPool interface { + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} + +// SavePointerDialectorInterface save pointer interface +type SavePointerDialectorInterface interface { + SavePoint(tx *DB, name string) error + RollbackTo(tx *DB, name string) error +} + +// TxBeginner tx beginner +type TxBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +// ConnPoolBeginner conn pool beginner +type ConnPoolBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) +} + +// TxCommitter tx committer +type TxCommitter interface { + Commit() error + Rollback() error +} + +// Tx sql.Tx interface +type Tx interface { + ConnPool + TxCommitter + StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt +} + +// Valuer gorm valuer interface +type Valuer interface { + GormValue(context.Context, *DB) clause.Expr +} + +// GetDBConnector SQL db connector +type GetDBConnector interface { + GetDBConn() (*sql.DB, error) +} + +// GetDBConnectorWithContext represents SQL db connector which takes into +// account the current database context +type GetDBConnectorWithContext interface { + GetDBConnWithContext(db *DB) (*sql.DB, error) +} + +// Rows rows interface +type Rows interface { + Columns() ([]string, error) + ColumnTypes() ([]*sql.ColumnType, error) + Next() bool + Scan(dest ...interface{}) error + Err() error + Close() error +} + +type ErrorTranslator interface { + Translate(err error) error +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/logger/logger.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/logger/logger.go new file mode 100644 index 000000000..aa0060bc5 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/logger/logger.go @@ -0,0 +1,211 @@ +package logger + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "os" + "time" + + "gorm.io/gorm/utils" +) + +// ErrRecordNotFound record not found error +var ErrRecordNotFound = errors.New("record not found") + +// Colors +const ( + Reset = "\033[0m" + Red = "\033[31m" + Green = "\033[32m" + Yellow = "\033[33m" + Blue = "\033[34m" + Magenta = "\033[35m" + Cyan = "\033[36m" + White = "\033[37m" + BlueBold = "\033[34;1m" + MagentaBold = 
"\033[35;1m" + RedBold = "\033[31;1m" + YellowBold = "\033[33;1m" +) + +// LogLevel log level +type LogLevel int + +const ( + // Silent silent log level + Silent LogLevel = iota + 1 + // Error error log level + Error + // Warn warn log level + Warn + // Info info log level + Info +) + +// Writer log writer interface +type Writer interface { + Printf(string, ...interface{}) +} + +// Config logger config +type Config struct { + SlowThreshold time.Duration + Colorful bool + IgnoreRecordNotFoundError bool + ParameterizedQueries bool + LogLevel LogLevel +} + +// Interface logger interface +type Interface interface { + LogMode(LogLevel) Interface + Info(context.Context, string, ...interface{}) + Warn(context.Context, string, ...interface{}) + Error(context.Context, string, ...interface{}) + Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) +} + +var ( + // Discard Discard logger will print any log to io.Discard + Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{}) + // Default Default logger + Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ + SlowThreshold: 200 * time.Millisecond, + LogLevel: Warn, + IgnoreRecordNotFoundError: false, + Colorful: true, + }) + // Recorder Recorder logger records running SQL into a recorder instance + Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} +) + +// New initialize logger +func New(writer Writer, config Config) Interface { + var ( + infoStr = "%s\n[info] " + warnStr = "%s\n[warn] " + errStr = "%s\n[error] " + traceStr = "%s\n[%.3fms] [rows:%v] %s" + traceWarnStr = "%s %s\n[%.3fms] [rows:%v] %s" + traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s" + ) + + if config.Colorful { + infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset + warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset + errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset + traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" + traceWarnStr = Green + "%s " + Yellow + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset + traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" + } + + return &logger{ + Writer: writer, + Config: config, + infoStr: infoStr, + warnStr: warnStr, + errStr: errStr, + traceStr: traceStr, + traceWarnStr: traceWarnStr, + traceErrStr: traceErrStr, + } +} + +type logger struct { + Writer + Config + infoStr, warnStr, errStr string + traceStr, traceErrStr, traceWarnStr string +} + +// LogMode log mode +func (l *logger) LogMode(level LogLevel) Interface { + newlogger := *l + newlogger.LogLevel = level + return &newlogger +} + +// Info print info +func (l logger) Info(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Info { + l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) + } +} + +// Warn print warn messages +func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Warn { + l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) + } +} + +// Error print error messages +func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Error { + l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) 
+	}
+}
+
+// Trace print sql message
+func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
+	if l.LogLevel <= Silent {
+		return
+	}
+
+	elapsed := time.Since(begin)
+	switch {
+	case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError):
+		sql, rows := fc()
+		if rows == -1 {
+			l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql)
+		} else {
+			l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql)
+		}
+	case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
+		sql, rows := fc()
+		slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)
+		if rows == -1 {
+			l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql)
+		} else {
+			l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql)
+		}
+	case l.LogLevel == Info:
+		sql, rows := fc()
+		if rows == -1 {
+			l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
+		} else {
+			l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
+		}
+	}
+}
+
+// ParamsFilter filters SQL params out of the log when ParameterizedQueries is enabled
+func (l logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
+	if l.Config.ParameterizedQueries {
+		return sql, nil
+	}
+	return sql, params
+}
+
+type traceRecorder struct {
+	Interface
+	BeginAt      time.Time
+	SQL          string
+	RowsAffected int64
+	Err          error
+}
+
+// New creates a new trace recorder
+func (l traceRecorder) New() *traceRecorder {
+	return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
+}
+
+// Trace implement logger interface
+func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
+	l.BeginAt = begin
+	l.SQL, l.RowsAffected = fc()
+	l.Err = err
+}
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/logger/sql.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/logger/sql.go
new file mode 100644
index 000000000..13e5d957d
--- /dev/null
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/logger/sql.go
@@ -0,0 +1,162 @@
+package logger
+
+import (
+	"database/sql/driver"
+	"fmt"
+	"reflect"
+	"regexp"
+	"strconv"
+	"strings"
+	"time"
+	"unicode"
+
+	"gorm.io/gorm/utils"
+)
+
+const (
+	tmFmtWithMS = "2006-01-02 15:04:05.999"
+	tmFmtZero   = "0000-00-00 00:00:00"
+	nullStr     = "NULL"
+)
+
+func isPrintable(s string) bool {
+	for _, r := range s {
+		if !unicode.IsPrint(r) {
+			return false
+		}
+	}
+	return true
+}
+
+// A list of Go types that should be converted to SQL primitives
+var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
+
+// numericPlaceholderRe matches numeric placeholders of the form $1$, $2$, ...
+var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
+
+// ExplainSQL generates a SQL string with the given parameters. The generated SQL is meant for
+// logging only; executing it might introduce a SQL injection vulnerability
+func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
+	var (
+		convertParams func(interface{}, int)
+		vars          = make([]string, len(avars))
+	)
+
+	convertParams = func(v interface{}, idx int) {
+		switch v := v.(type) {
+		case bool:
+			vars[idx] = strconv.FormatBool(v)
+		case time.Time:
+			if v.IsZero() {
+				vars[idx] = escaper + tmFmtZero + escaper
+			} else {
+				vars[idx] = escaper + 
v.Format(tmFmtWithMS) + escaper + } + case *time.Time: + if v != nil { + if v.IsZero() { + vars[idx] = escaper + tmFmtZero + escaper + } else { + vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper + } + } else { + vars[idx] = nullStr + } + case driver.Valuer: + reflectValue := reflect.ValueOf(v) + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + r, _ := v.Value() + convertParams(r, idx) + } else { + vars[idx] = nullStr + } + case fmt.Stringer: + reflectValue := reflect.ValueOf(v) + switch reflectValue.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + vars[idx] = fmt.Sprintf("%d", reflectValue.Interface()) + case reflect.Float32, reflect.Float64: + vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface()) + case reflect.Bool: + vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) + case reflect.String: + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper + default: + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper + } else { + vars[idx] = nullStr + } + } + case []byte: + if s := string(v); isPrintable(s) { + vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper + } else { + vars[idx] = escaper + "" + escaper + } + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + vars[idx] = utils.ToString(v) + case float32: + vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + vars[idx] = strconv.FormatFloat(v, 'f', -1, 64) + case string: + vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper + default: + rv := reflect.ValueOf(v) + if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { + vars[idx] = nullStr + } else if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + convertParams(v, idx) + } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { + convertParams(reflect.Indirect(rv).Interface(), idx) + } else { + for _, t := range convertibleTypes { + if rv.Type().ConvertibleTo(t) { + convertParams(rv.Convert(t).Interface(), idx) + return + } + } + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper + } + } + } + + for idx, v := range avars { + convertParams(v, idx) + } + + if numericPlaceholder == nil { + var idx int + var newSQL strings.Builder + + for _, v := range []byte(sql) { + if v == '?' 
{
+				if len(vars) > idx {
+					newSQL.WriteString(vars[idx])
+					idx++
+					continue
+				}
+			}
+			newSQL.WriteByte(v)
+		}
+
+		sql = newSQL.String()
+	} else {
+		sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
+
+		sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string {
+			num := v[1 : len(v)-1]
+			n, _ := strconv.Atoi(num)
+
+			// positional vars start from 1 ($1, $2)
+			n -= 1
+			if n >= 0 && n <= len(vars)-1 {
+				return vars[n]
+			}
+			return v
+		})
+	}
+
+	return sql
+}
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator.go
new file mode 100644
index 000000000..0e01f567d
--- /dev/null
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator.go
@@ -0,0 +1,109 @@
+package gorm
+
+import (
+	"reflect"
+
+	"gorm.io/gorm/clause"
+	"gorm.io/gorm/schema"
+)
+
+// Migrator returns migrator
+func (db *DB) Migrator() Migrator {
+	tx := db.getInstance()
+
+	// apply scopes to migrator
+	for len(tx.Statement.scopes) > 0 {
+		tx = tx.executeScopes()
+	}
+
+	return tx.Dialector.Migrator(tx.Session(&Session{}))
+}
+
+// AutoMigrate runs auto migration for the given models
+func (db *DB) AutoMigrate(dst ...interface{}) error {
+	return db.Migrator().AutoMigrate(dst...)
+}
+
+// ViewOption view option
+type ViewOption struct {
+	Replace     bool   // If true, exec `CREATE OR REPLACE`; if false, exec plain `CREATE`
+	CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION`
+	Query       *DB    // required subquery.
+}
+
+// ColumnType column type interface
+type ColumnType interface {
+	Name() string
+	DatabaseTypeName() string                 // varchar
+	ColumnType() (columnType string, ok bool) // varchar(64)
+	PrimaryKey() (isPrimaryKey bool, ok bool)
+	AutoIncrement() (isAutoIncrement bool, ok bool)
+	Length() (length int64, ok bool)
+	DecimalSize() (precision int64, scale int64, ok bool)
+	Nullable() (nullable bool, ok bool)
+	Unique() (unique bool, ok bool)
+	ScanType() reflect.Type
+	Comment() (value string, ok bool)
+	DefaultValue() (value string, ok bool)
+}
+
+type Index interface {
+	Table() string
+	Name() string
+	Columns() []string
+	PrimaryKey() (isPrimaryKey bool, ok bool)
+	Unique() (unique bool, ok bool)
+	Option() string
+}
+
+// TableType table type interface
+type TableType interface {
+	Schema() string
+	Name() string
+	Type() string
+	Comment() (comment string, ok bool)
+}
+
+// Migrator migrator interface
+type Migrator interface {
+	// AutoMigrate
+	AutoMigrate(dst ...interface{}) error
+
+	// Database
+	CurrentDatabase() string
+	FullDataTypeOf(*schema.Field) clause.Expr
+	GetTypeAliases(databaseTypeName string) []string
+
+	// Tables
+	CreateTable(dst ...interface{}) error
+	DropTable(dst ...interface{}) error
+	HasTable(dst interface{}) bool
+	RenameTable(oldName, newName interface{}) error
+	GetTables() (tableList []string, err error)
+	TableType(dst interface{}) (TableType, error)
+
+	// Columns
+	AddColumn(dst interface{}, field string) error
+	DropColumn(dst interface{}, field string) error
+	AlterColumn(dst interface{}, field string) error
+	MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
+	HasColumn(dst interface{}, field string) bool
+	RenameColumn(dst interface{}, oldName, field string) error
+	ColumnTypes(dst interface{}) ([]ColumnType, error)
+
+	// Views
+	CreateView(name string, option ViewOption) error
+	DropView(name string) error
+
+	// Constraints
+	CreateConstraint(dst interface{}, name string) error
+	DropConstraint(dst interface{}, name string) error
+	
HasConstraint(dst interface{}, name string) bool + + // Indexes + CreateIndex(dst interface{}, name string) error + DropIndex(dst interface{}, name string) error + HasIndex(dst interface{}, name string) bool + RenameIndex(dst interface{}, oldName, newName string) error + GetIndexes(dst interface{}) ([]Index, error) +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/column_type.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/column_type.go new file mode 100644 index 000000000..c6fdd6b2d --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/column_type.go @@ -0,0 +1,107 @@ +package migrator + +import ( + "database/sql" + "reflect" +) + +// ColumnType column type implements ColumnType interface +type ColumnType struct { + SQLColumnType *sql.ColumnType + NameValue sql.NullString + DataTypeValue sql.NullString + ColumnTypeValue sql.NullString + PrimaryKeyValue sql.NullBool + UniqueValue sql.NullBool + AutoIncrementValue sql.NullBool + LengthValue sql.NullInt64 + DecimalSizeValue sql.NullInt64 + ScaleValue sql.NullInt64 + NullableValue sql.NullBool + ScanTypeValue reflect.Type + CommentValue sql.NullString + DefaultValueValue sql.NullString +} + +// Name returns the name or alias of the column. +func (ct ColumnType) Name() string { + if ct.NameValue.Valid { + return ct.NameValue.String + } + return ct.SQLColumnType.Name() +} + +// DatabaseTypeName returns the database system name of the column type. If an empty +// string is returned, then the driver type name is not supported. +// Consult your driver documentation for a list of driver data types. Length specifiers +// are not included. +// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", +// "INT", and "BIGINT". +func (ct ColumnType) DatabaseTypeName() string { + if ct.DataTypeValue.Valid { + return ct.DataTypeValue.String + } + return ct.SQLColumnType.DatabaseTypeName() +} + +// ColumnType returns the database type of the column. like `varchar(16)` +func (ct ColumnType) ColumnType() (columnType string, ok bool) { + return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid +} + +// PrimaryKey returns the column is primary key or not. +func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) { + return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid +} + +// AutoIncrement returns the column is auto increment or not. +func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) { + return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid +} + +// Length returns the column type length for variable length column types +func (ct ColumnType) Length() (length int64, ok bool) { + if ct.LengthValue.Valid { + return ct.LengthValue.Int64, true + } + return ct.SQLColumnType.Length() +} + +// DecimalSize returns the scale and precision of a decimal type. +func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) { + if ct.DecimalSizeValue.Valid { + return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true + } + return ct.SQLColumnType.DecimalSize() +} + +// Nullable reports whether the column may be null. +func (ct ColumnType) Nullable() (nullable bool, ok bool) { + if ct.NullableValue.Valid { + return ct.NullableValue.Bool, true + } + return ct.SQLColumnType.Nullable() +} + +// Unique reports whether the column may be unique. +func (ct ColumnType) Unique() (unique bool, ok bool) { + return ct.UniqueValue.Bool, ct.UniqueValue.Valid +} + +// ScanType returns a Go type suitable for scanning into using Rows.Scan. 
+func (ct ColumnType) ScanType() reflect.Type { + if ct.ScanTypeValue != nil { + return ct.ScanTypeValue + } + return ct.SQLColumnType.ScanType() +} + +// Comment returns the comment of current column. +func (ct ColumnType) Comment() (value string, ok bool) { + return ct.CommentValue.String, ct.CommentValue.Valid +} + +// DefaultValue returns the default value of current column. +func (ct ColumnType) DefaultValue() (value string, ok bool) { + return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/index.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/index.go new file mode 100644 index 000000000..8845da95b --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/index.go @@ -0,0 +1,43 @@ +package migrator + +import "database/sql" + +// Index implements gorm.Index interface +type Index struct { + TableName string + NameValue string + ColumnList []string + PrimaryKeyValue sql.NullBool + UniqueValue sql.NullBool + OptionValue string +} + +// Table return the table name of the index. +func (idx Index) Table() string { + return idx.TableName +} + +// Name return the name of the index. +func (idx Index) Name() string { + return idx.NameValue +} + +// Columns return the columns of the index +func (idx Index) Columns() []string { + return idx.ColumnList +} + +// PrimaryKey returns the index is primary key or not. +func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) { + return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid +} + +// Unique returns whether the index is unique or not. +func (idx Index) Unique() (unique bool, ok bool) { + return idx.UniqueValue.Bool, idx.UniqueValue.Valid +} + +// Option return the optional attribute of the index +func (idx Index) Option() string { + return idx.OptionValue +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/migrator.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/migrator.go new file mode 100644 index 000000000..b15a43ef2 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/migrator.go @@ -0,0 +1,965 @@ +package migrator + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "regexp" + "strings" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" +) + +// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*), +// with a possible trailing non-digit character (\D?). + +// For example, values that can pass this regular expression are: +// - "123" +// - "abc456" +// -"%$#@789" +var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) + +// TODO:? Create const vars for raw sql queries ? 
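+ +// Editorial sketch (not upstream GORM code): regFullDataType is used by +// MigrateColumn below to pull the declared size out of a full data type. +// Assuming a hypothetical "varchar(64)" column definition: +// +//	matches := regFullDataType.FindAllStringSubmatch("varchar(64)", -1) +//	// matches[0][1] == "64", which MigrateColumn compares against field.Size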
+ +// Migrator m struct +type Migrator struct { + Config +} + +// Config schema config +type Config struct { + CreateIndexAfterCreateTable bool + DB *gorm.DB + gorm.Dialector +} + +type printSQLLogger struct { + logger.Interface +} + +func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + fmt.Println(sql + ";") + l.Interface.Trace(ctx, begin, fc, err) +} + +// GormDataTypeInterface gorm data type interface +type GormDataTypeInterface interface { + GormDBDataType(*gorm.DB, *schema.Field) string +} + +// RunWithValue run migration with statement value +func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { + stmt := &gorm.Statement{DB: m.DB} + if m.DB.Statement != nil { + stmt.Table = m.DB.Statement.Table + stmt.TableExpr = m.DB.Statement.TableExpr + } + + if table, ok := value.(string); ok { + stmt.Table = table + } else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil { + return err + } + + return fc(stmt) +} + +// DataTypeOf return field's db data type +func (m Migrator) DataTypeOf(field *schema.Field) string { + fieldValue := reflect.New(field.IndirectFieldType) + if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { + if dataType := dataTyper.GormDBDataType(m.DB, field); dataType != "" { + return dataType + } + } + + return m.Dialector.DataTypeOf(field) +} + +// FullDataTypeOf returns field's db full data type +func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { + expr.SQL = m.DataTypeOf(field) + + if field.NotNull { + expr.SQL += " NOT NULL" + } + + if field.Unique { + expr.SQL += " UNIQUE" + } + + if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { + if field.DefaultValueInterface != nil { + defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} + m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) + expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface) + } else if field.DefaultValue != "(-)" { + expr.SQL += " DEFAULT " + field.DefaultValue + } + } + + return +} + +// AutoMigrate auto migrate values +func (m Migrator) AutoMigrate(values ...interface{}) error { + for _, value := range m.ReorderModels(values, true) { + queryTx := m.DB.Session(&gorm.Session{}) + execTx := queryTx + if m.DB.DryRun { + queryTx.DryRun = false + execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) + } + if !queryTx.Migrator().HasTable(value) { + if err := execTx.Migrator().CreateTable(value); err != nil { + return err + } + } else { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + columnTypes, err := queryTx.Migrator().ColumnTypes(value) + if err != nil { + return err + } + var ( + parseIndexes = stmt.Schema.ParseIndexes() + parseCheckConstraints = stmt.Schema.ParseCheckConstraints() + ) + for _, dbName := range stmt.Schema.DBNames { + var foundColumn gorm.ColumnType + + for _, columnType := range columnTypes { + if columnType.Name() == dbName { + foundColumn = columnType + break + } + } + + if foundColumn == nil { + // not found, add column + if err = execTx.Migrator().AddColumn(value, dbName); err != nil { + return err + } + } else { + // found, smartly migrate + field := stmt.Schema.FieldsByDBName[dbName] + if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil { + return err + } + } + } + + if 
!m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { + for _, rel := range stmt.Schema.Relationships.Relations { + if rel.Field.IgnoreMigration { + continue + } + if constraint := rel.ParseConstraint(); constraint != nil && + constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) { + if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil { + return err + } + } + } + } + + for _, chk := range parseCheckConstraints { + if !queryTx.Migrator().HasConstraint(value, chk.Name) { + if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil { + return err + } + } + } + + for _, idx := range parseIndexes { + if !queryTx.Migrator().HasIndex(value, idx.Name) { + if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil { + return err + } + } + } + + return nil + }); err != nil { + return err + } + } + } + + return nil +} + +// GetTables returns tables +func (m Migrator) GetTables() (tableList []string, err error) { + err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()). + Scan(&tableList).Error + return +} + +// CreateTable create table in database for values +func (m Migrator) CreateTable(values ...interface{}) error { + for _, value := range m.ReorderModels(values, false) { + tx := m.DB.Session(&gorm.Session{}) + if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { + var ( + createTableSQL = "CREATE TABLE ? (" + values = []interface{}{m.CurrentTable(stmt)} + hasPrimaryKeyInDataType bool + ) + + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.FieldsByDBName[dbName] + if !field.IgnoreMigration { + createTableSQL += "? ?" + hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") + values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) + createTableSQL += "," + } + } + + if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { + createTableSQL += "PRIMARY KEY ?," + primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields)) + for _, field := range stmt.Schema.PrimaryFields { + primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) + } + + values = append(values, primaryKeys) + } + + for _, idx := range stmt.Schema.ParseIndexes() { + if m.CreateIndexAfterCreateTable { + defer func(value interface{}, name string) { + if err == nil { + err = tx.Migrator().CreateIndex(value, name) + } + }(value, idx.Name) + } else { + if idx.Class != "" { + createTableSQL += idx.Class + " " + } + createTableSQL += "INDEX ? ?" + + if idx.Comment != "" { + createTableSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment) + } + + if idx.Option != "" { + createTableSQL += " " + idx.Option + } + + createTableSQL += "," + values = append(values, clause.Column{Name: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) + } + } + + if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { + for _, rel := range stmt.Schema.Relationships.Relations { + if rel.Field.IgnoreMigration { + continue + } + if constraint := rel.ParseConstraint(); constraint != nil { + if constraint.Schema == stmt.Schema { + sql, vars := buildConstraint(constraint) + createTableSQL += sql + "," + values = append(values, vars...) 
+ } + } + } + } + + for _, chk := range stmt.Schema.ParseCheckConstraints() { + createTableSQL += "CONSTRAINT ? CHECK (?)," + values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) + } + + createTableSQL = strings.TrimSuffix(createTableSQL, ",") + + createTableSQL += ")" + + if tableOption, ok := m.DB.Get("gorm:table_options"); ok { + createTableSQL += fmt.Sprint(tableOption) + } + + err = tx.Exec(createTableSQL, values...).Error + return err + }); err != nil { + return err + } + } + return nil +} + +// DropTable drop table for values +func (m Migrator) DropTable(values ...interface{}) error { + values = m.ReorderModels(values, false) + for i := len(values) - 1; i >= 0; i-- { + tx := m.DB.Session(&gorm.Session{}) + if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error + }); err != nil { + return err + } + } + return nil +} + +// HasTable returns table exists or not for value, value could be a struct or string +func (m Migrator) HasTable(value interface{}) bool { + var count int64 + + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count) + }) + + return count > 0 +} + +// RenameTable rename table from oldName to newName +func (m Migrator) RenameTable(oldName, newName interface{}) error { + var oldTable, newTable interface{} + if v, ok := oldName.(string); ok { + oldTable = clause.Table{Name: v} + } else { + stmt := &gorm.Statement{DB: m.DB} + if err := stmt.Parse(oldName); err == nil { + oldTable = m.CurrentTable(stmt) + } else { + return err + } + } + + if v, ok := newName.(string); ok { + newTable = clause.Table{Name: v} + } else { + stmt := &gorm.Statement{DB: m.DB} + if err := stmt.Parse(newName); err == nil { + newTable = m.CurrentTable(stmt) + } else { + return err + } + } + + return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error +} + +// AddColumn create `name` column for value +func (m Migrator) AddColumn(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + // avoid using the same name field + f := stmt.Schema.LookUpField(name) + if f == nil { + return fmt.Errorf("failed to look up field with name: %s", name) + } + + if !f.IgnoreMigration { + return m.DB.Exec( + "ALTER TABLE ? ADD ? ?", + m.CurrentTable(stmt), clause.Column{Name: f.DBName}, m.DB.Migrator().FullDataTypeOf(f), + ).Error + } + + return nil + }) +} + +// DropColumn drop value's `name` column +func (m Migrator) DropColumn(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName + } + + return m.DB.Exec( + "ALTER TABLE ? DROP COLUMN ?", m.CurrentTable(stmt), clause.Column{Name: name}, + ).Error + }) +} + +// AlterColumn alter value's `field` column' type based on schema definition +func (m Migrator) AlterColumn(value interface{}, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + fileType := m.FullDataTypeOf(field) + return m.DB.Exec( + "ALTER TABLE ? ALTER COLUMN ? 
TYPE ?", + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, + ).Error + + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + +// HasColumn check has column `field` for value or not +func (m Migrator) HasColumn(value interface{}, field string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + name := field + if field := stmt.Schema.LookUpField(field); field != nil { + name = field.DBName + } + + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", + currentDatabase, stmt.Table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} + +// RenameColumn rename value's field name from oldName to newName +func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(oldName); field != nil { + oldName = field.DBName + } + + if field := stmt.Schema.LookUpField(newName); field != nil { + newName = field.DBName + } + + return m.DB.Exec( + "ALTER TABLE ? RENAME COLUMN ? TO ?", + m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName}, + ).Error + }) +} + +// MigrateColumn migrate column +func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + // found, smart migrate + fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) + realDataType := strings.ToLower(columnType.DatabaseTypeName()) + + var ( + alterColumn bool + isSameType = fullDataType == realDataType + ) + + if !field.PrimaryKey { + // check type + if !strings.HasPrefix(fullDataType, realDataType) { + // check type aliases + aliases := m.DB.Migrator().GetTypeAliases(realDataType) + for _, alias := range aliases { + if strings.HasPrefix(fullDataType, alias) { + isSameType = true + break + } + } + + if !isSameType { + alterColumn = true + } + } + } + + if !isSameType { + // check size + if length, ok := columnType.Length(); length != int64(field.Size) { + if length > 0 && field.Size > 0 { + alterColumn = true + } else { + // has size in data type and not equal + // Since the following code is frequently called in the for loop, reg optimization is needed here + matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) + if !field.PrimaryKey && + (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { + alterColumn = true + } + } + } + + // check precision + if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { + if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { + alterColumn = true + } + } + } + + // check nullable + if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { + // not primary key & database is nullable + if !field.PrimaryKey && nullable { + alterColumn = true + } + } + + // check unique + if unique, ok := columnType.Unique(); ok && unique != field.Unique { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + + // check default value + if !field.PrimaryKey { + currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL")) + dv, dvNotNull := columnType.DefaultValue() + if dvNotNull && !currentDefaultNotNull { + // default value -> null + alterColumn = 
true + } else if !dvNotNull && currentDefaultNotNull { + // null -> default value + alterColumn = true + } else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) || + (field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) { + // default value not equal + // not both null + if currentDefaultNotNull || dvNotNull { + alterColumn = true + } + } + } + + // check comment + if comment, ok := columnType.Comment(); ok && comment != field.Comment { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + + if alterColumn && !field.IgnoreMigration { + return m.DB.Migrator().AlterColumn(value, field.DBName) + } + + return nil +} + +// ColumnTypes return columnTypes []gorm.ColumnType and execErr error +func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { + columnTypes := make([]gorm.ColumnType, 0) + execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { + rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() + if err != nil { + return err + } + + defer func() { + err = rows.Close() + }() + + var rawColumnTypes []*sql.ColumnType + rawColumnTypes, err = rows.ColumnTypes() + if err != nil { + return err + } + + for _, c := range rawColumnTypes { + columnTypes = append(columnTypes, ColumnType{SQLColumnType: c}) + } + + return + }) + + return columnTypes, execErr +} + +// CreateView create view from Query in gorm.ViewOption. +// Query in gorm.ViewOption is a [subquery] +// +// // CREATE VIEW `user_view` AS SELECT * FROM `users` WHERE age > 20 +// q := DB.Model(&User{}).Where("age > ?", 20) +// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q}) +// +// // CREATE OR REPLACE VIEW `users_view` AS SELECT * FROM `users` WITH CHECK OPTION +// q := DB.Model(&User{}) +// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q, Replace: true, CheckOption: "WITH CHECK OPTION"}) +// +// [subquery]: https://gorm.io/docs/advanced_query.html#SubQuery +func (m Migrator) CreateView(name string, option gorm.ViewOption) error { + if option.Query == nil { + return gorm.ErrSubQueryRequired + } + + sql := new(strings.Builder) + sql.WriteString("CREATE ") + if option.Replace { + sql.WriteString("OR REPLACE ") + } + sql.WriteString("VIEW ") + m.QuoteTo(sql, name) + sql.WriteString(" AS ") + + m.DB.Statement.AddVar(sql, option.Query) + + if option.CheckOption != "" { + sql.WriteString(" ") + sql.WriteString(option.CheckOption) + } + return m.DB.Exec(m.Explain(sql.String(), m.DB.Statement.Vars...)).Error +} + +// DropView drop view +func (m Migrator) DropView(name string) error { + return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error +} + +func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) { + sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" 
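+ // Editorial note (not upstream GORM): the four placeholders bind, in order, + // the constraint name, the local foreign key columns, the referenced table, + // and the referenced columns; see the results slice assembled below.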
+ if constraint.OnDelete != "" { + sql += " ON DELETE " + constraint.OnDelete + } + + if constraint.OnUpdate != "" { + sql += " ON UPDATE " + constraint.OnUpdate + } + + var foreignKeys, references []interface{} + for _, field := range constraint.ForeignKeys { + foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) + } + + for _, field := range constraint.References { + references = append(references, clause.Column{Name: field.DBName}) + } + results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) + return +} + +// GuessConstraintAndTable guess statement's constraint and it's table based on name +func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { + if stmt.Schema == nil { + return nil, nil, stmt.Table + } + + checkConstraints := stmt.Schema.ParseCheckConstraints() + if chk, ok := checkConstraints[name]; ok { + return nil, &chk, stmt.Table + } + + getTable := func(rel *schema.Relationship) string { + switch rel.Type { + case schema.HasOne, schema.HasMany: + return rel.FieldSchema.Table + case schema.Many2Many: + return rel.JoinTable.Table + } + return stmt.Table + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { + return constraint, nil, getTable(rel) + } + } + + if field := stmt.Schema.LookUpField(name); field != nil { + for k := range checkConstraints { + if checkConstraints[k].Field == field { + v := checkConstraints[k] + return nil, &v, stmt.Table + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field { + return constraint, nil, getTable(rel) + } + } + } + + return nil, nil, stmt.Schema.Table +} + +// CreateConstraint create constraint +func (m Migrator) CreateConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if chk != nil { + return m.DB.Exec( + "ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)", + m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, + ).Error + } + + if constraint != nil { + vars := []interface{}{clause.Table{Name: table}} + if stmt.TableExpr != nil { + vars[0] = stmt.TableExpr + } + sql, values := buildConstraint(constraint) + return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error + } + + return nil + }) +} + +// DropConstraint drop constraint +func (m Migrator) DropConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if constraint != nil { + name = constraint.Name + } else if chk != nil { + name = chk.Name + } + return m.DB.Exec("ALTER TABLE ? 
DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error + }) +} + +// HasConstraint check has constraint or not +func (m Migrator) HasConstraint(value interface{}, name string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if constraint != nil { + name = constraint.Name + } else if chk != nil { + name = chk.Name + } + + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?", + currentDatabase, table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} + +// BuildIndexOptions build index options +func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { + for _, opt := range opts { + str := stmt.Quote(opt.DBName) + if opt.Expression != "" { + str = opt.Expression + } else if opt.Length > 0 { + str += fmt.Sprintf("(%d)", opt.Length) + } + + if opt.Collate != "" { + str += " COLLATE " + opt.Collate + } + + if opt.Sort != "" { + str += " " + opt.Sort + } + results = append(results, clause.Expr{SQL: str}) + } + return +} + +// BuildIndexOptionsInterface build index options interface +type BuildIndexOptionsInterface interface { + BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{} +} + +// CreateIndex create index `name` +func (m Migrator) CreateIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ? ON ??" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + + if idx.Comment != "" { + createIndexSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment) + } + + if idx.Option != "" { + createIndexSQL += " " + idx.Option + } + + return m.DB.Exec(createIndexSQL, values...).Error + } + + return fmt.Errorf("failed to create index with name %s", name) + }) +} + +// DropIndex drop index `name` +func (m Migrator) DropIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error + }) +} + +// HasIndex check has index `name` or not +func (m Migrator) HasIndex(value interface{}, name string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + return m.DB.Raw( + "SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", + currentDatabase, stmt.Table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} + +// RenameIndex rename index from oldName to newName +func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec( + "ALTER TABLE ? RENAME INDEX ? 
TO ?", + m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName}, + ).Error + }) +} + +// CurrentDatabase returns current database name +func (m Migrator) CurrentDatabase() (name string) { + m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) + return +} + +// ReorderModels reorder models according to constraint dependencies +func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []interface{}) { + type Dependency struct { + *gorm.Statement + Depends []*schema.Schema + } + + var ( + modelNames, orderedModelNames []string + orderedModelNamesMap = map[string]bool{} + parsedSchemas = map[*schema.Schema]bool{} + valuesMap = map[string]Dependency{} + insertIntoOrderedList func(name string) + parseDependence func(value interface{}, addToList bool) + ) + + parseDependence = func(value interface{}, addToList bool) { + dep := Dependency{ + Statement: &gorm.Statement{DB: m.DB, Dest: value}, + } + beDependedOn := map[*schema.Schema]bool{} + // support for special table name + if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil { + m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) + } + if _, ok := parsedSchemas[dep.Statement.Schema]; ok { + return + } + parsedSchemas[dep.Statement.Schema] = true + + if !m.DB.IgnoreRelationshipsWhenMigrating { + for _, rel := range dep.Schema.Relationships.Relations { + if rel.Field.IgnoreMigration { + continue + } + if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { + dep.Depends = append(dep.Depends, c.ReferenceSchema) + } + + if rel.Type == schema.HasOne || rel.Type == schema.HasMany { + beDependedOn[rel.FieldSchema] = true + } + + if rel.JoinTable != nil { + // append join value + defer func(rel *schema.Relationship, joinValue interface{}) { + if !beDependedOn[rel.FieldSchema] { + dep.Depends = append(dep.Depends, rel.FieldSchema) + } else { + fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface() + parseDependence(fieldValue, autoAdd) + } + parseDependence(joinValue, autoAdd) + }(rel, reflect.New(rel.JoinTable.ModelType).Interface()) + } + } + } + + valuesMap[dep.Schema.Table] = dep + + if addToList { + modelNames = append(modelNames, dep.Schema.Table) + } + } + + insertIntoOrderedList = func(name string) { + if _, ok := orderedModelNamesMap[name]; ok { + return // avoid loop + } + orderedModelNamesMap[name] = true + + if autoAdd { + dep := valuesMap[name] + for _, d := range dep.Depends { + if _, ok := valuesMap[d.Table]; ok { + insertIntoOrderedList(d.Table) + } else { + parseDependence(reflect.New(d.ModelType).Interface(), autoAdd) + insertIntoOrderedList(d.Table) + } + } + } + + orderedModelNames = append(orderedModelNames, name) + } + + for _, value := range values { + if v, ok := value.(string); ok { + results = append(results, v) + } else { + parseDependence(value, true) + } + } + + for _, name := range modelNames { + insertIntoOrderedList(name) + } + + for _, name := range orderedModelNames { + results = append(results, valuesMap[name].Statement.Dest) + } + return +} + +// CurrentTable returns current statement's table expression +func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { + if stmt.TableExpr != nil { + return *stmt.TableExpr + } + return clause.Table{Name: stmt.Table} +} + +// GetIndexes return Indexes []gorm.Index and execErr error +func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) { + return nil, errors.New("not support") +} + 
+// GetTypeAliases return database type aliases +func (m Migrator) GetTypeAliases(databaseTypeName string) []string { + return nil +} + +// TableType return tableType gorm.TableType and execErr error +func (m Migrator) TableType(dst interface{}) (gorm.TableType, error) { + return nil, errors.New("not support") +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/table_type.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/table_type.go new file mode 100644 index 000000000..ed6e42a0e --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/table_type.go @@ -0,0 +1,33 @@ +package migrator + +import ( + "database/sql" +) + +// TableType table type implements TableType interface +type TableType struct { + SchemaValue string + NameValue string + TypeValue string + CommentValue sql.NullString +} + +// Schema returns the schema of the table. +func (ct TableType) Schema() string { + return ct.SchemaValue +} + +// Name returns the name of the table. +func (ct TableType) Name() string { + return ct.NameValue +} + +// Type returns the type of the table. +func (ct TableType) Type() string { + return ct.TypeValue +} + +// Comment returns the comment of current table. +func (ct TableType) Comment() (comment string, ok bool) { + return ct.CommentValue.String, ct.CommentValue.Valid +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/model.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/model.go new file mode 100644 index 000000000..fa705df1c --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/model.go @@ -0,0 +1,16 @@ +package gorm + +import "time" + +// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt +// It may be embedded into your model or you may build your own model without it +// +// type User struct { +// gorm.Model +// } +type Model struct { + ID uint `gorm:"primarykey"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt DeletedAt `gorm:"index"` +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/prepare_stmt.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/prepare_stmt.go new file mode 100644 index 000000000..9d98c86e0 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/prepare_stmt.go @@ -0,0 +1,229 @@ +package gorm + +import ( + "context" + "database/sql" + "reflect" + "sync" +) + +type Stmt struct { + *sql.Stmt + Transaction bool + prepared chan struct{} + prepareErr error +} + +type PreparedStmtDB struct { + Stmts map[string]*Stmt + PreparedSQL []string + Mux *sync.RWMutex + ConnPool +} + +func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { + return &PreparedStmtDB{ + ConnPool: connPool, + Stmts: make(map[string]*Stmt), + Mux: &sync.RWMutex{}, + PreparedSQL: make([]string, 0, 100), + } +} + +func (db *PreparedStmtDB) GetDBConnWithContext(gormdb *DB) (*sql.DB, error) { + if sqldb, ok := db.ConnPool.(*sql.DB); ok { + return sqldb, nil + } + + if connector, ok := db.ConnPool.(GetDBConnectorWithContext); ok && connector != nil { + return connector.GetDBConnWithContext(gormdb) + } + + if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() + } + + return nil, ErrInvalidDB +} + +func (db *PreparedStmtDB) Close() { + db.Mux.Lock() + defer db.Mux.Unlock() + + for _, query := range db.PreparedSQL { + if stmt, ok := db.Stmts[query]; ok { + delete(db.Stmts, query) + go stmt.Close() + } + } +} + +func (sdb *PreparedStmtDB) Reset() { + sdb.Mux.Lock() + defer sdb.Mux.Unlock() + + for _, 
stmt := range sdb.Stmts { + go stmt.Close() + } + sdb.PreparedSQL = make([]string, 0, 100) + sdb.Stmts = make(map[string]*Stmt) +} + +func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { + db.Mux.RLock() + if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { + db.Mux.RUnlock() + // wait for other goroutines prepared + <-stmt.prepared + if stmt.prepareErr != nil { + return Stmt{}, stmt.prepareErr + } + + return *stmt, nil + } + db.Mux.RUnlock() + + db.Mux.Lock() + // double check + if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { + db.Mux.Unlock() + // wait for other goroutines prepared + <-stmt.prepared + if stmt.prepareErr != nil { + return Stmt{}, stmt.prepareErr + } + + return *stmt, nil + } + + // cache preparing stmt first + cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})} + db.Stmts[query] = &cacheStmt + db.Mux.Unlock() + + // prepare completed + defer close(cacheStmt.prepared) + + // Reason why cannot lock conn.PrepareContext + // suppose the maxopen is 1, g1 is creating record and g2 is querying record. + // 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1. + // 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release. + // 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release. + stmt, err := conn.PrepareContext(ctx, query) + if err != nil { + cacheStmt.prepareErr = err + db.Mux.Lock() + delete(db.Stmts, query) + db.Mux.Unlock() + return Stmt{}, err + } + + db.Mux.Lock() + cacheStmt.Stmt = stmt + db.PreparedSQL = append(db.PreparedSQL, query) + db.Mux.Unlock() + + return cacheStmt, nil +} + +func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { + if beginner, ok := db.ConnPool.(TxBeginner); ok { + tx, err := beginner.BeginTx(ctx, opt) + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err + } + return nil, ErrInvalidTransaction +} + +func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { + stmt, err := db.prepare(ctx, db.ConnPool, false, query) + if err == nil { + result, err = stmt.ExecContext(ctx, args...) + if err != nil { + db.Mux.Lock() + defer db.Mux.Unlock() + go stmt.Close() + delete(db.Stmts, query) + } + } + return result, err +} + +func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { + stmt, err := db.prepare(ctx, db.ConnPool, false, query) + if err == nil { + rows, err = stmt.QueryContext(ctx, args...) + if err != nil { + db.Mux.Lock() + defer db.Mux.Unlock() + + go stmt.Close() + delete(db.Stmts, query) + } + } + return rows, err +} + +func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + stmt, err := db.prepare(ctx, db.ConnPool, false, query) + if err == nil { + return stmt.QueryRowContext(ctx, args...) 
+ } + return &sql.Row{} +} + +type PreparedStmtTX struct { + Tx + PreparedStmtDB *PreparedStmtDB +} + +func (tx *PreparedStmtTX) Commit() error { + if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { + return tx.Tx.Commit() + } + return ErrInvalidTransaction +} + +func (tx *PreparedStmtTX) Rollback() error { + if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { + return tx.Tx.Rollback() + } + return ErrInvalidTransaction +} + +func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) + if err == nil { + result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) + if err != nil { + tx.PreparedStmtDB.Mux.Lock() + defer tx.PreparedStmtDB.Mux.Unlock() + + go stmt.Close() + delete(tx.PreparedStmtDB.Stmts, query) + } + } + return result, err +} + +func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) + if err == nil { + rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) + if err != nil { + tx.PreparedStmtDB.Mux.Lock() + defer tx.PreparedStmtDB.Mux.Unlock() + + go stmt.Close() + delete(tx.PreparedStmtDB.Stmts, query) + } + } + return rows, err +} + +func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) + if err == nil { + return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...) + } + return &sql.Row{} +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/scan.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/scan.go new file mode 100644 index 000000000..736db4d3a --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/scan.go @@ -0,0 +1,342 @@ +package gorm + +import ( + "database/sql" + "database/sql/driver" + "reflect" + "time" + + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +// prepareValues prepare values slice +func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { + if db.Statement.Schema != nil { + for idx, name := range columns { + if field := db.Statement.Schema.LookUpField(name); field != nil { + values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + continue + } + values[idx] = new(interface{}) + } + } else if len(columnTypes) > 0 { + for idx, columnType := range columnTypes { + if columnType.ScanType() != nil { + values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface() + } else { + values[idx] = new(interface{}) + } + } + } else { + for idx := range columns { + values[idx] = new(interface{}) + } + } +} + +func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) { + for idx, column := range columns { + if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() { + mapValue[column] = reflectValue.Interface() + if valuer, ok := mapValue[column].(driver.Valuer); ok { + mapValue[column], _ = valuer.Value() + } else if b, ok := mapValue[column].(sql.RawBytes); ok { + mapValue[column] = string(b) + } + } else { + mapValue[column] = nil + } + } +} + +func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) { + for idx, field := range fields { + if field != nil { + values[idx] = 
field.NewValuePool.Get() + } else if len(fields) == 1 { + if reflectValue.CanAddr() { + values[idx] = reflectValue.Addr().Interface() + } else { + values[idx] = reflectValue.Interface() + } + } + } + + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + joinedNestedSchemaMap := make(map[string]interface{}) + for idx, field := range fields { + if field == nil { + continue + } + + if len(joinFields) == 0 || len(joinFields[idx]) == 0 { + db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) + } else { // joinFields count is larger than 2 when using join + var isNilPtrValue bool + var relValue reflect.Value + // does not contain raw dbname + nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1] + // current reflect value + currentReflectValue := reflectValue + fullRels := make([]string, 0, len(nestedJoinSchemas)) + for _, joinSchema := range nestedJoinSchemas { + fullRels = append(fullRels, joinSchema.Name) + relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue) + if relValue.Kind() == reflect.Ptr { + fullRelsName := utils.JoinNestedRelationNames(fullRels) + // same nested structure + if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + isNilPtrValue = true + break + } + + relValue.Set(reflect.New(relValue.Type().Elem())) + joinedNestedSchemaMap[fullRelsName] = nil + } + } + currentReflectValue = relValue + } + + if !isNilPtrValue { // ignore if value is nil + f := joinFields[idx][len(joinFields[idx])-1] + db.AddError(f.Set(db.Statement.Context, relValue, values[idx])) + } + } + + // release data to pool + field.NewValuePool.Put(values[idx]) + } +} + +// ScanMode scan data mode +type ScanMode uint8 + +// scan modes +const ( + ScanInitialized ScanMode = 1 << 0 // 1 + ScanUpdate ScanMode = 1 << 1 // 2 + ScanOnConflictDoNothing ScanMode = 1 << 2 // 4 +) + +// Scan scan rows into db statement +func Scan(rows Rows, db *DB, mode ScanMode) { + var ( + columns, _ = rows.Columns() + values = make([]interface{}, len(columns)) + initialized = mode&ScanInitialized != 0 + update = mode&ScanUpdate != 0 + onConflictDonothing = mode&ScanOnConflictDoNothing != 0 + ) + + db.RowsAffected = 0 + + switch dest := db.Statement.Dest.(type) { + case map[string]interface{}, *map[string]interface{}: + if initialized || rows.Next() { + columnTypes, _ := rows.ColumnTypes() + prepareValues(values, db, columnTypes, columns) + + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + + mapValue, ok := dest.(map[string]interface{}) + if !ok { + if v, ok := dest.(*map[string]interface{}); ok { + if *v == nil { + *v = map[string]interface{}{} + } + mapValue = *v + } + } + scanIntoMap(mapValue, values, columns) + } + case *[]map[string]interface{}: + columnTypes, _ := rows.ColumnTypes() + for initialized || rows.Next() { + prepareValues(values, db, columnTypes, columns) + + initialized = false + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + + mapValue := map[string]interface{}{} + scanIntoMap(mapValue, values, columns) + *dest = append(*dest, mapValue) + } + case *int, *int8, *int16, *int32, *int64, + *uint, *uint8, *uint16, *uint32, *uint64, *uintptr, + *float32, *float64, + *bool, *string, *time.Time, + *sql.NullInt32, *sql.NullInt64, *sql.NullFloat64, + *sql.NullBool, *sql.NullString, *sql.NullTime: + for initialized || rows.Next() { + initialized = false + db.RowsAffected++ + db.AddError(rows.Scan(dest)) + } + default: + var ( + fields = make([]*schema.Field, 
len(columns)) + joinFields [][]*schema.Field + sch = db.Statement.Schema + reflectValue = db.Statement.ReflectValue + ) + + if reflectValue.Kind() == reflect.Interface { + reflectValue = reflectValue.Elem() + } + + reflectValueType := reflectValue.Type() + switch reflectValueType.Kind() { + case reflect.Array, reflect.Slice: + reflectValueType = reflectValueType.Elem() + } + isPtr := reflectValueType.Kind() == reflect.Ptr + if isPtr { + reflectValueType = reflectValueType.Elem() + } + + if sch != nil { + if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct { + sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + } + + if len(columns) == 1 { + // Is Pluck + if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner + reflectValueType.Kind() != reflect.Struct || // is not struct + sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time + sch = nil + } + } + + // Not Pluck + if sch != nil { + matchedFieldCount := make(map[string]int, len(columns)) + for idx, column := range columns { + if field := sch.LookUpField(column); field != nil && field.Readable { + fields[idx] = field + if count, ok := matchedFieldCount[column]; ok { + // handle duplicate fields + for _, selectField := range sch.Fields { + if selectField.DBName == column && selectField.Readable { + if count == 0 { + matchedFieldCount[column]++ + fields[idx] = selectField + break + } + count-- + } + } + } else { + matchedFieldCount[column] = 1 + } + } else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation + if rel, ok := sch.Relationships.Relations[names[0]]; ok { + subNameCount := len(names) + // nested relation fields + relFields := make([]*schema.Field, 0, subNameCount-1) + relFields = append(relFields, rel.Field) + for _, name := range names[1 : subNameCount-1] { + rel = rel.FieldSchema.Relationships.Relations[name] + relFields = append(relFields, rel.Field) + } + // lastest name is raw dbname + dbName := names[subNameCount-1] + if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable { + fields[idx] = field + + if len(joinFields) == 0 { + joinFields = make([][]*schema.Field, len(columns)) + } + relFields = append(relFields, field) + joinFields[idx] = relFields + continue + } + } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} + } + } + } + } + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + elem reflect.Value + isArrayKind = reflectValue.Kind() == reflect.Array + ) + + if !update || reflectValue.Len() == 0 { + update = false + // if the slice cap is externally initialized, the externally initialized slice is directly used here + if reflectValue.Cap() == 0 { + db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + } else if !isArrayKind { + reflectValue.SetLen(0) + db.Statement.ReflectValue.Set(reflectValue) + } + } + + for initialized || rows.Next() { + BEGIN: + initialized = false + + if update { + if int(db.RowsAffected) >= reflectValue.Len() { + return + } + elem = reflectValue.Index(int(db.RowsAffected)) + if onConflictDonothing { + for _, field := range fields { + if _, ok := field.ValueOf(db.Statement.Context, elem); !ok { + db.RowsAffected++ + goto BEGIN + } + } + } + } else { + elem = reflect.New(reflectValueType) + } + + db.scanIntoStruct(rows, elem, values, fields, joinFields) + + if !update { + if !isPtr { + elem = elem.Elem() + } + if isArrayKind { 
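+ // Editorial note (not upstream GORM): arrays have a fixed length, so the + // scanned element is written into its row slot here instead of being + // appended as in the slice branch below.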
+ if reflectValue.Len() >= int(db.RowsAffected) { + reflectValue.Index(int(db.RowsAffected - 1)).Set(elem) + } + } else { + reflectValue = reflect.Append(reflectValue, elem) + } + } + } + + if !update { + db.Statement.ReflectValue.Set(reflectValue) + } + case reflect.Struct, reflect.Ptr: + if initialized || rows.Next() { + db.scanIntoStruct(rows, reflectValue, values, fields, joinFields) + } + default: + db.AddError(rows.Scan(dest)) + } + } + + if err := rows.Err(); err != nil && err != db.Error { + db.AddError(err) + } + + if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { + db.AddError(ErrRecordNotFound) + } +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/check.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/check.go new file mode 100644 index 000000000..89e732d36 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/check.go @@ -0,0 +1,35 @@ +package schema + +import ( + "regexp" + "strings" +) + +// reg match english letters and midline +var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$") + +type Check struct { + Name string + Constraint string // length(phone) >= 10 + *Field +} + +// ParseCheckConstraints parse schema check constraints +func (schema *Schema) ParseCheckConstraints() map[string]Check { + checks := map[string]Check{} + for _, field := range schema.FieldsByDBName { + if chk := field.TagSettings["CHECK"]; chk != "" { + names := strings.Split(chk, ",") + if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) { + checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} + } else { + if names[0] == "" { + chk = strings.Join(names[1:], ",") + } + name := schema.namer.CheckerName(schema.Table, field.DBName) + checks[name] = Check{Name: name, Constraint: chk, Field: field} + } + } + } + return checks +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/field.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/field.go new file mode 100644 index 000000000..dd08e056b --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/field.go @@ -0,0 +1,988 @@ +package schema + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/jinzhu/now" + "gorm.io/gorm/clause" + "gorm.io/gorm/utils" +) + +// special types' reflect type +var ( + TimeReflectType = reflect.TypeOf(time.Time{}) + TimePtrReflectType = reflect.TypeOf(&time.Time{}) + ByteReflectType = reflect.TypeOf(uint8(0)) +) + +type ( + // DataType GORM data type + DataType string + // TimeType GORM time type + TimeType int64 +) + +// GORM time types +const ( + UnixTime TimeType = 1 + UnixSecond TimeType = 2 + UnixMillisecond TimeType = 3 + UnixNanosecond TimeType = 4 +) + +// GORM fields types +const ( + Bool DataType = "bool" + Int DataType = "int" + Uint DataType = "uint" + Float DataType = "float" + String DataType = "string" + Time DataType = "time" + Bytes DataType = "bytes" +) + +// Field is the representation of model schema's field +type Field struct { + Name string + DBName string + BindNames []string + DataType DataType + GORMDataType DataType + PrimaryKey bool + AutoIncrement bool + AutoIncrementIncrement int64 + Creatable bool + Updatable bool + Readable bool + AutoCreateTime TimeType + AutoUpdateTime TimeType + HasDefaultValue bool + DefaultValue string + DefaultValueInterface interface{} + NotNull bool + Unique bool + Comment string + Size int + Precision int + 
Scale int + IgnoreMigration bool + FieldType reflect.Type + IndirectFieldType reflect.Type + StructField reflect.StructField + Tag reflect.StructTag + TagSettings map[string]string + Schema *Schema + EmbeddedSchema *Schema + OwnerSchema *Schema + ReflectValueOf func(context.Context, reflect.Value) reflect.Value + ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool) + Set func(context.Context, reflect.Value, interface{}) error + Serializer SerializerInterface + NewValuePool FieldNewValuePool +} + +func (field *Field) BindName() string { + return strings.Join(field.BindNames, ".") +} + +// ParseField parses reflect.StructField to Field +func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { + var ( + err error + tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";") + ) + + field := &Field{ + Name: fieldStruct.Name, + DBName: tagSetting["COLUMN"], + BindNames: []string{fieldStruct.Name}, + FieldType: fieldStruct.Type, + IndirectFieldType: fieldStruct.Type, + StructField: fieldStruct, + Tag: fieldStruct.Tag, + TagSettings: tagSetting, + Schema: schema, + Creatable: true, + Updatable: true, + Readable: true, + PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]), + AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]), + HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]), + NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), + Unique: utils.CheckTruth(tagSetting["UNIQUE"]), + Comment: tagSetting["COMMENT"], + AutoIncrementIncrement: 1, + } + + for field.IndirectFieldType.Kind() == reflect.Ptr { + field.IndirectFieldType = field.IndirectFieldType.Elem() + } + + fieldValue := reflect.New(field.IndirectFieldType) + // if field is valuer, used its value or first field as data type + valuer, isValuer := fieldValue.Interface().(driver.Valuer) + if isValuer { + if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { + if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil { + fieldValue = reflect.ValueOf(v) + } + + // Use the field struct's first field type as data type, e.g: use `string` for sql.NullString + var getRealFieldValue func(reflect.Value) + getRealFieldValue = func(v reflect.Value) { + var ( + rv = reflect.Indirect(v) + rvType = rv.Type() + ) + + if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) { + for i := 0; i < rvType.NumField(); i++ { + for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") { + if _, ok := field.TagSettings[key]; !ok { + field.TagSettings[key] = value + } + } + } + + for i := 0; i < rvType.NumField(); i++ { + newFieldType := rvType.Field(i).Type + for newFieldType.Kind() == reflect.Ptr { + newFieldType = newFieldType.Elem() + } + + fieldValue = reflect.New(newFieldType) + if rvType != reflect.Indirect(fieldValue).Type() { + getRealFieldValue(fieldValue) + } + + if fieldValue.IsValid() { + return + } + } + } + } + + getRealFieldValue(fieldValue) + } + } + + if v, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer { + field.DataType = String + field.Serializer = v + } else { + serializerName := field.TagSettings["JSON"] + if serializerName == "" { + serializerName = field.TagSettings["SERIALIZER"] + } + if serializerName != "" { + if serializer, ok := GetSerializer(serializerName); ok { + // Set default data type to string for serializer + field.DataType = String + field.Serializer = serializer + } else { + schema.err = fmt.Errorf("invalid serializer type 
%v", serializerName) + } + } + } + + if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok { + field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64) + } + + if v, ok := field.TagSettings["DEFAULT"]; ok { + field.HasDefaultValue = true + field.DefaultValue = v + } + + if num, ok := field.TagSettings["SIZE"]; ok { + if field.Size, err = strconv.Atoi(num); err != nil { + field.Size = -1 + } + } + + if p, ok := field.TagSettings["PRECISION"]; ok { + field.Precision, _ = strconv.Atoi(p) + } + + if s, ok := field.TagSettings["SCALE"]; ok { + field.Scale, _ = strconv.Atoi(s) + } + + // default value is function or null or blank (primary keys) + field.DefaultValue = strings.TrimSpace(field.DefaultValue) + skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && + strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == "" + switch reflect.Indirect(fieldValue).Kind() { + case reflect.Bool: + field.DataType = Bool + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil { + schema.err = fmt.Errorf("failed to parse %s as default value for bool, got error: %v", field.DefaultValue, err) + } + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.DataType = Int + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %s as default value for int, got error: %v", field.DefaultValue, err) + } + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.DataType = Uint + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %s as default value for uint, got error: %v", field.DefaultValue, err) + } + } + case reflect.Float32, reflect.Float64: + field.DataType = Float + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %s as default value for float, got error: %v", field.DefaultValue, err) + } + } + case reflect.String: + field.DataType = String + if field.HasDefaultValue && !skipParseDefaultValue { + field.DefaultValue = strings.Trim(field.DefaultValue, "'") + field.DefaultValue = strings.Trim(field.DefaultValue, `"`) + field.DefaultValueInterface = field.DefaultValue + } + case reflect.Struct: + if _, ok := fieldValue.Interface().(*time.Time); ok { + field.DataType = Time + } else if fieldValue.Type().ConvertibleTo(TimeReflectType) { + field.DataType = Time + } else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) { + field.DataType = Time + } + if field.HasDefaultValue && !skipParseDefaultValue && field.DataType == Time { + if t, err := now.Parse(field.DefaultValue); err == nil { + field.DefaultValueInterface = t + } + } + case reflect.Array, reflect.Slice: + if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" { + field.DataType = Bytes + } + } + + if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { + field.DataType = DataType(dataTyper.GormDataType()) + } + + if v, ok := field.TagSettings["AUTOCREATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "CreatedAt" && 
(field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { + if field.DataType == Time { + field.AutoCreateTime = UnixTime + } else if strings.ToUpper(v) == "NANO" { + field.AutoCreateTime = UnixNanosecond + } else if strings.ToUpper(v) == "MILLI" { + field.AutoCreateTime = UnixMillisecond + } else { + field.AutoCreateTime = UnixSecond + } + } + + if v, ok := field.TagSettings["AUTOUPDATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { + if field.DataType == Time { + field.AutoUpdateTime = UnixTime + } else if strings.ToUpper(v) == "NANO" { + field.AutoUpdateTime = UnixNanosecond + } else if strings.ToUpper(v) == "MILLI" { + field.AutoUpdateTime = UnixMillisecond + } else { + field.AutoUpdateTime = UnixSecond + } + } + + if field.GORMDataType == "" { + field.GORMDataType = field.DataType + } + + if val, ok := field.TagSettings["TYPE"]; ok { + switch DataType(strings.ToLower(val)) { + case Bool, Int, Uint, Float, String, Time, Bytes: + field.DataType = DataType(strings.ToLower(val)) + default: + field.DataType = DataType(val) + } + } + + if field.Size == 0 { + switch reflect.Indirect(fieldValue).Kind() { + case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: + field.Size = 64 + case reflect.Int8, reflect.Uint8: + field.Size = 8 + case reflect.Int16, reflect.Uint16: + field.Size = 16 + case reflect.Int32, reflect.Uint32, reflect.Float32: + field.Size = 32 + } + } + + // setup permission + if val, ok := field.TagSettings["-"]; ok { + val = strings.ToLower(strings.TrimSpace(val)) + switch val { + case "-": + field.Creatable = false + field.Updatable = false + field.Readable = false + field.DataType = "" + case "all": + field.Creatable = false + field.Updatable = false + field.Readable = false + field.DataType = "" + field.IgnoreMigration = true + case "migration": + field.IgnoreMigration = true + } + } + + if v, ok := field.TagSettings["->"]; ok { + field.Creatable = false + field.Updatable = false + if strings.ToLower(v) == "false" { + field.Readable = false + } else { + field.Readable = true + } + } + + if v, ok := field.TagSettings["<-"]; ok { + field.Creatable = true + field.Updatable = true + + if v != "<-" { + if !strings.Contains(v, "create") { + field.Creatable = false + } + + if !strings.Contains(v, "update") { + field.Updatable = false + } + } + } + + // Normal anonymous field or having `EMBEDDED` tag + if _, ok := field.TagSettings["EMBEDDED"]; ok || (field.GORMDataType != Time && field.GORMDataType != Bytes && !isValuer && + fieldStruct.Anonymous && (field.Creatable || field.Updatable || field.Readable)) { + kind := reflect.Indirect(fieldValue).Kind() + switch kind { + case reflect.Struct: + var err error + field.Creatable = false + field.Updatable = false + field.Readable = false + + cacheStore := &sync.Map{} + cacheStore.Store(embeddedCacheKey, true) + if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { + schema.err = err + } + + for _, ef := range field.EmbeddedSchema.Fields { + ef.Schema = schema + ef.OwnerSchema = field.EmbeddedSchema + ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) + // index is negative means is pointer + if field.FieldType.Kind() == reflect.Struct { + ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) 
+ } else { + ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...) + } + + if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" { + ef.DBName = prefix + ef.DBName + } + + if ef.PrimaryKey { + if !utils.CheckTruth(ef.TagSettings["PRIMARYKEY"], ef.TagSettings["PRIMARY_KEY"]) { + ef.PrimaryKey = false + + if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { + ef.AutoIncrement = false + } + + if !ef.AutoIncrement && ef.DefaultValue == "" { + ef.HasDefaultValue = false + } + } + } + + for k, v := range field.TagSettings { + ef.TagSettings[k] = v + } + } + case reflect.Invalid, reflect.Uintptr, reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, + reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer, reflect.Complex64, reflect.Complex128: + schema.err = fmt.Errorf("invalid embedded struct for %s's field %s, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) + } + } + + return field +} + +// create valuer, setter when parse struct +func (field *Field) setupValuerAndSetter() { + // Setup NewValuePool + field.setupNewValuePool() + + // ValueOf returns field's value and if it is zero + fieldIndex := field.StructField.Index[0] + switch { + case len(field.StructField.Index) == 1 && fieldIndex > 0: + field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { + fieldValue := reflect.Indirect(value).Field(fieldIndex) + return fieldValue.Interface(), fieldValue.IsZero() + } + default: + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + v = reflect.Indirect(v) + for _, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) + + if !v.IsNil() { + v = v.Elem() + } else { + return nil, true + } + } + } + + fv, zero := v.Interface(), v.IsZero() + return fv, zero + } + } + + if field.Serializer != nil { + oldValuerOf := field.ValueOf + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + value, zero := oldValuerOf(ctx, v) + + s, ok := value.(SerializerValuerInterface) + if !ok { + s = field.Serializer + } + + return &serializer{ + Field: field, + SerializeValuer: s, + Destination: v, + Context: ctx, + fieldValue: value, + }, zero + } + } + + // ReflectValueOf returns field's reflect value + switch { + case len(field.StructField.Index) == 1 && fieldIndex > 0: + field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { + return reflect.Indirect(value).Field(fieldIndex) + } + default: + field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { + v = reflect.Indirect(v) + for idx, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) + + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + + if idx < len(field.StructField.Index)-1 { + v = v.Elem() + } + } + } + return v + } + } + + fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) { + if v == nil { + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else { + reflectV := reflect.ValueOf(v) + // Optimal value type acquisition for v + reflectValType := reflectV.Type() + + if reflectValType.AssignableTo(field.FieldType) { + if reflectV.Kind() == reflect.Ptr && reflectV.Elem().Kind() == reflect.Ptr { + reflectV = 
reflect.Indirect(reflectV) + } + field.ReflectValueOf(ctx, value).Set(reflectV) + return + } else if reflectValType.ConvertibleTo(field.FieldType) { + field.ReflectValueOf(ctx, value).Set(reflectV.Convert(field.FieldType)) + return + } else if field.FieldType.Kind() == reflect.Ptr { + fieldValue := field.ReflectValueOf(ctx, value) + fieldType := field.FieldType.Elem() + + if reflectValType.AssignableTo(fieldType) { + if !fieldValue.IsValid() { + fieldValue = reflect.New(fieldType) + } else if fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldType)) + } + fieldValue.Elem().Set(reflectV) + return + } else if reflectValType.ConvertibleTo(fieldType) { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldType)) + } + + fieldValue.Elem().Set(reflectV.Convert(fieldType)) + return + } + } + + if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() { + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().Elem().AssignableTo(field.FieldType) { + field.ReflectValueOf(ctx, value).Set(reflectV.Elem()) + return + } else { + err = setter(ctx, value, reflectV.Elem().Interface()) + } + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = setter(ctx, value, v) + } + } else if _, ok := v.(clause.Expr); !ok { + return fmt.Errorf("failed to set value %#v to field %s", v, field.Name) + } + } + + return + } + + // Set + switch field.FieldType.Kind() { + case reflect.Bool: + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { + switch data := v.(type) { + case **bool: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetBool(**data) + } + case bool: + field.ReflectValueOf(ctx, value).SetBool(data) + case int64: + field.ReflectValueOf(ctx, value).SetBool(data > 0) + case string: + b, _ := strconv.ParseBool(data) + field.ReflectValueOf(ctx, value).SetBool(b) + default: + return fallbackSetter(ctx, value, v, field.Set) + } + return nil + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + switch data := v.(type) { + case **int64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(**data) + } + case **int: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case **int8: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case **int16: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case **int32: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case int64: + field.ReflectValueOf(ctx, value).SetInt(data) + case int: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case int8: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case int16: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case int32: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case uint: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case uint8: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case uint16: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case uint32: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case uint64: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case float32: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case float64: + field.ReflectValueOf(ctx, 
value).SetInt(int64(data)) + case []byte: + return field.Set(ctx, value, string(data)) + case string: + if i, err := strconv.ParseInt(data, 0, 64); err == nil { + field.ReflectValueOf(ctx, value).SetInt(i) + } else { + return err + } + case time.Time: + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) + } else { + field.ReflectValueOf(ctx, value).SetInt(data.Unix()) + } + case *time.Time: + if data != nil { + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) + } else { + field.ReflectValueOf(ctx, value).SetInt(data.Unix()) + } + } else { + field.ReflectValueOf(ctx, value).SetInt(0) + } + default: + return fallbackSetter(ctx, value, v, field.Set) + } + return err + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + switch data := v.(type) { + case **uint64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(**data) + } + case **uint: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case **uint8: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case **uint16: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case **uint32: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case uint64: + field.ReflectValueOf(ctx, value).SetUint(data) + case uint: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case uint8: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case uint16: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case uint32: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case int64: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case int: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case int8: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case int16: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case int32: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case float32: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case float64: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case []byte: + return field.Set(ctx, value, string(data)) + case time.Time: + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano())) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6)) + } else { + field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix())) + } + case string: + if i, err := strconv.ParseUint(data, 0, 64); err == nil { + field.ReflectValueOf(ctx, value).SetUint(i) + } else { + return err + } + default: + return fallbackSetter(ctx, value, v, field.Set) + } + return err + } + case reflect.Float32, 
reflect.Float64: + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + switch data := v.(type) { + case **float64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetFloat(**data) + } + case **float32: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetFloat(float64(**data)) + } + case float64: + field.ReflectValueOf(ctx, value).SetFloat(data) + case float32: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case int64: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case int: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case int8: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case int16: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case int32: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case uint: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case uint8: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case uint16: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case uint32: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case uint64: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case []byte: + return field.Set(ctx, value, string(data)) + case string: + if i, err := strconv.ParseFloat(data, 64); err == nil { + field.ReflectValueOf(ctx, value).SetFloat(i) + } else { + return err + } + default: + return fallbackSetter(ctx, value, v, field.Set) + } + return err + } + case reflect.String: + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + switch data := v.(type) { + case **string: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetString(**data) + } + case string: + field.ReflectValueOf(ctx, value).SetString(data) + case []byte: + field.ReflectValueOf(ctx, value).SetString(string(data)) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + field.ReflectValueOf(ctx, value).SetString(utils.ToString(data)) + case float64, float32: + field.ReflectValueOf(ctx, value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) + default: + return fallbackSetter(ctx, value, v, field.Set) + } + return err + } + default: + fieldValue := reflect.New(field.FieldType) + switch fieldValue.Elem().Interface().(type) { + case time.Time: + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { + switch data := v.(type) { + case **time.Time: + if data != nil && *data != nil { + field.Set(ctx, value, *data) + } + case time.Time: + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v)) + case *time.Time: + if data != nil { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data).Elem()) + } else { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(time.Time{})) + } + case string: + if t, err := now.Parse(data); err == nil { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(t)) + } else { + return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) + } + default: + return fallbackSetter(ctx, value, v, field.Set) + } + return nil + } + case *time.Time: + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { + switch data := v.(type) { + case **time.Time: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data)) + } + case time.Time: + fieldValue := field.ReflectValueOf(ctx, value) + if fieldValue.IsNil() { + 
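+					// the *time.Time field is nil here: allocate a value first, then assign through Elem() below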
fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflect.ValueOf(v)) + case *time.Time: + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v)) + case string: + if t, err := now.Parse(data); err == nil { + fieldValue := field.ReflectValueOf(ctx, value) + if fieldValue.IsNil() { + if v == "" { + return nil + } + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflect.ValueOf(t)) + } else { + return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) + } + default: + return fallbackSetter(ctx, value, v, field.Set) + } + return nil + } + default: + if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { + // pointer scanner + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if !reflectV.IsValid() { + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + return + } else if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(ctx, value).Set(reflectV) + } else if reflectV.Kind() == reflect.Ptr { + return field.Set(ctx, value, reflectV.Elem().Interface()) + } else { + fieldValue := field.ReflectValueOf(ctx, value) + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + err = fieldValue.Interface().(sql.Scanner).Scan(v) + } + return + } + } else if _, ok := fieldValue.Interface().(sql.Scanner); ok { + // struct scanner + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if !reflectV.IsValid() { + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + return + } else if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(ctx, value).Set(reflectV) + } else if reflectV.Kind() == reflect.Ptr { + return field.Set(ctx, value, reflectV.Elem().Interface()) + } else { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + err = field.ReflectValueOf(ctx, value).Addr().Interface().(sql.Scanner).Scan(v) + } + return + } + } else { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + return fallbackSetter(ctx, value, v, field.Set) + } + } + } + } + + if field.Serializer != nil { + var ( + oldFieldSetter = field.Set + sameElemType bool + sameType = field.FieldType == reflect.ValueOf(field.Serializer).Type() + ) + + if reflect.ValueOf(field.Serializer).Kind() == reflect.Ptr { + sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem() + } + + serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer)) + serializerType := serializerValue.Type() + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + if s, ok := v.(*serializer); ok { + if s.fieldValue != nil { + err = oldFieldSetter(ctx, value, s.fieldValue) + } else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { + if sameElemType { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem()) + } else if sameType { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer)) + } + si := reflect.New(serializerType) + si.Elem().Set(serializerValue) + s.Serializer = 
si.Interface().(SerializerInterface) + } + } else { + err = oldFieldSetter(ctx, value, v) + } + return + } + } +} + +func (field *Field) setupNewValuePool() { + if field.Serializer != nil { + serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer)) + serializerType := serializerValue.Type() + field.NewValuePool = &sync.Pool{ + New: func() interface{} { + si := reflect.New(serializerType) + si.Elem().Set(serializerValue) + return &serializer{ + Field: field, + Serializer: si.Interface().(SerializerInterface), + } + }, + } + } + + if field.NewValuePool == nil { + field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) + } +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/index.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/index.go new file mode 100644 index 000000000..f5ac5dd21 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/index.go @@ -0,0 +1,166 @@ +package schema + +import ( + "fmt" + "sort" + "strconv" + "strings" +) + +type Index struct { + Name string + Class string // UNIQUE | FULLTEXT | SPATIAL + Type string // btree, hash, gist, spgist, gin, and brin + Where string + Comment string + Option string // WITH PARSER parser_name + Fields []IndexOption +} + +type IndexOption struct { + *Field + Expression string + Sort string // DESC, ASC + Collate string + Length int + priority int +} + +// ParseIndexes parse schema indexes +func (schema *Schema) ParseIndexes() map[string]Index { + indexes := map[string]Index{} + + for _, field := range schema.Fields { + if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" { + fieldIndexes, err := parseFieldIndexes(field) + if err != nil { + schema.err = err + break + } + for _, index := range fieldIndexes { + idx := indexes[index.Name] + idx.Name = index.Name + if idx.Class == "" { + idx.Class = index.Class + } + if idx.Type == "" { + idx.Type = index.Type + } + if idx.Where == "" { + idx.Where = index.Where + } + if idx.Comment == "" { + idx.Comment = index.Comment + } + if idx.Option == "" { + idx.Option = index.Option + } + + idx.Fields = append(idx.Fields, index.Fields...) 
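+				// fields of a composite index are ordered by their `priority` tag setting (default 10)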
+ sort.Slice(idx.Fields, func(i, j int) bool { + return idx.Fields[i].priority < idx.Fields[j].priority + }) + + indexes[index.Name] = idx + } + } + } + for _, index := range indexes { + if index.Class == "UNIQUE" && len(index.Fields) == 1 { + index.Fields[0].Field.Unique = true + } + } + return indexes +} + +func (schema *Schema) LookIndex(name string) *Index { + if schema != nil { + indexes := schema.ParseIndexes() + for _, index := range indexes { + if index.Name == name { + return &index + } + + for _, field := range index.Fields { + if field.Name == name { + return &index + } + } + } + } + + return nil +} + +func parseFieldIndexes(field *Field) (indexes []Index, err error) { + for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { + if value != "" { + v := strings.Split(value, ":") + k := strings.TrimSpace(strings.ToUpper(v[0])) + if k == "INDEX" || k == "UNIQUEINDEX" { + var ( + name string + tag = strings.Join(v[1:], ":") + idx = strings.Index(tag, ",") + tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",") + settings = ParseTagSetting(tagSetting, ",") + length, _ = strconv.Atoi(settings["LENGTH"]) + ) + + if idx == -1 { + idx = len(tag) + } + + if idx != -1 { + name = tag[0:idx] + } + + if name == "" { + subName := field.Name + const key = "COMPOSITE" + if composite, found := settings[key]; found { + if len(composite) == 0 || composite == key { + err = fmt.Errorf( + "The composite tag of %s.%s cannot be empty", + field.Schema.Name, + field.Name) + return + } + subName = composite + } + name = field.Schema.namer.IndexName( + field.Schema.Table, subName) + } + + if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" { + settings["CLASS"] = "UNIQUE" + } + + priority, err := strconv.Atoi(settings["PRIORITY"]) + if err != nil { + priority = 10 + } + + indexes = append(indexes, Index{ + Name: name, + Class: settings["CLASS"], + Type: settings["TYPE"], + Where: settings["WHERE"], + Comment: settings["COMMENT"], + Option: settings["OPTION"], + Fields: []IndexOption{{ + Field: field, + Expression: settings["EXPRESSION"], + Sort: settings["SORT"], + Collate: settings["COLLATE"], + Length: length, + priority: priority, + }}, + }) + } + } + } + + err = nil + return +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/interfaces.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/interfaces.go new file mode 100644 index 000000000..a75a33c0d --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/interfaces.go @@ -0,0 +1,36 @@ +package schema + +import ( + "gorm.io/gorm/clause" +) + +// GormDataTypeInterface gorm data type interface +type GormDataTypeInterface interface { + GormDataType() string +} + +// FieldNewValuePool field new scan value pool +type FieldNewValuePool interface { + Get() interface{} + Put(interface{}) +} + +// CreateClausesInterface create clauses interface +type CreateClausesInterface interface { + CreateClauses(*Field) []clause.Interface +} + +// QueryClausesInterface query clauses interface +type QueryClausesInterface interface { + QueryClauses(*Field) []clause.Interface +} + +// UpdateClausesInterface update clauses interface +type UpdateClausesInterface interface { + UpdateClauses(*Field) []clause.Interface +} + +// DeleteClausesInterface delete clauses interface +type DeleteClausesInterface interface { + DeleteClauses(*Field) []clause.Interface +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/naming.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/naming.go new file mode 100644 index 
000000000..a2a0150a3 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/naming.go @@ -0,0 +1,186 @@ +package schema + +import ( + "crypto/sha1" + "encoding/hex" + "regexp" + "strings" + "unicode/utf8" + + "github.com/jinzhu/inflection" +) + +// Namer namer interface +type Namer interface { + TableName(table string) string + SchemaName(table string) string + ColumnName(table, column string) string + JoinTableName(joinTable string) string + RelationshipFKName(Relationship) string + CheckerName(table, column string) string + IndexName(table, column string) string +} + +// Replacer replacer interface like strings.Replacer +type Replacer interface { + Replace(name string) string +} + +// NamingStrategy tables, columns naming strategy +type NamingStrategy struct { + TablePrefix string + SingularTable bool + NameReplacer Replacer + NoLowerCase bool + IdentifierMaxLength int +} + +// TableName convert string to table name +func (ns NamingStrategy) TableName(str string) string { + if ns.SingularTable { + return ns.TablePrefix + ns.toDBName(str) + } + return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) +} + +// SchemaName generate schema name from table name, don't guarantee it is the reverse value of TableName +func (ns NamingStrategy) SchemaName(table string) string { + table = strings.TrimPrefix(table, ns.TablePrefix) + + if ns.SingularTable { + return ns.toSchemaName(table) + } + return ns.toSchemaName(inflection.Singular(table)) +} + +// ColumnName convert string to column name +func (ns NamingStrategy) ColumnName(table, column string) string { + return ns.toDBName(column) +} + +// JoinTableName convert string to join table name +func (ns NamingStrategy) JoinTableName(str string) string { + if !ns.NoLowerCase && strings.ToLower(str) == str { + return ns.TablePrefix + str + } + + if ns.SingularTable { + return ns.TablePrefix + ns.toDBName(str) + } + return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) +} + +// RelationshipFKName generate fk name for relation +func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { + return ns.formatName("fk", rel.Schema.Table, ns.toDBName(rel.Name)) +} + +// CheckerName generate checker name +func (ns NamingStrategy) CheckerName(table, column string) string { + return ns.formatName("chk", table, column) +} + +// IndexName generate index name +func (ns NamingStrategy) IndexName(table, column string) string { + return ns.formatName("idx", table, ns.toDBName(column)) +} + +func (ns NamingStrategy) formatName(prefix, table, name string) string { + formattedName := strings.ReplaceAll(strings.Join([]string{ + prefix, table, name, + }, "_"), ".", "_") + + if ns.IdentifierMaxLength == 0 { + ns.IdentifierMaxLength = 64 + } + + if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength { + h := sha1.New() + h.Write([]byte(formattedName)) + bs := h.Sum(nil) + + formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8] + } + return formattedName +} + +var ( + // https://github.com/golang/lint/blob/master/lint.go#L770 + commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} + commonInitialismsReplacer *strings.Replacer +) + +func init() { + commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms)) + for _, initialism := range commonInitialisms { + 
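+		// pair each initialism with its title-cased form (e.g. "ID" -> "Id") so toDBName later snake-cases it as a single word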
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) + } + commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) +} + +func (ns NamingStrategy) toDBName(name string) string { + if name == "" { + return "" + } + + if ns.NameReplacer != nil { + tmpName := ns.NameReplacer.Replace(name) + + if tmpName == "" { + return name + } + + name = tmpName + } + + if ns.NoLowerCase { + return name + } + + var ( + value = commonInitialismsReplacer.Replace(name) + buf strings.Builder + lastCase, nextCase, nextNumber bool // upper case == true + curCase = value[0] <= 'Z' && value[0] >= 'A' + ) + + for i, v := range value[:len(value)-1] { + nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A' + nextNumber = value[i+1] >= '0' && value[i+1] <= '9' + + if curCase { + if lastCase && (nextCase || nextNumber) { + buf.WriteRune(v + 32) + } else { + if i > 0 && value[i-1] != '_' && value[i+1] != '_' { + buf.WriteByte('_') + } + buf.WriteRune(v + 32) + } + } else { + buf.WriteRune(v) + } + + lastCase = curCase + curCase = nextCase + } + + if curCase { + if !lastCase && len(value) > 1 { + buf.WriteByte('_') + } + buf.WriteByte(value[len(value)-1] + 32) + } else { + buf.WriteByte(value[len(value)-1]) + } + ret := buf.String() + return ret +} + +func (ns NamingStrategy) toSchemaName(name string) string { + result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "") + for _, initialism := range commonInitialisms { + result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") + } + return result +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/pool.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/pool.go new file mode 100644 index 000000000..fa62fe223 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/pool.go @@ -0,0 +1,19 @@ +package schema + +import ( + "reflect" + "sync" +) + +// sync pools +var ( + normalPool sync.Map + poolInitializer = func(reflectType reflect.Type) FieldNewValuePool { + v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{ + New: func() interface{} { + return reflect.New(reflectType).Interface() + }, + }) + return v.(FieldNewValuePool) + } +) diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/relationship.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/relationship.go new file mode 100644 index 000000000..e03dcc520 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/relationship.go @@ -0,0 +1,699 @@ +package schema + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/jinzhu/inflection" + "gorm.io/gorm/clause" +) + +// RelationshipType relationship type +type RelationshipType string + +const ( + HasOne RelationshipType = "has_one" // HasOneRel has one relationship + HasMany RelationshipType = "has_many" // HasManyRel has many relationship + BelongsTo RelationshipType = "belongs_to" // BelongsToRel belongs to relationship + Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship + has RelationshipType = "has" +) + +type Relationships struct { + HasOne []*Relationship + BelongsTo []*Relationship + HasMany []*Relationship + Many2Many []*Relationship + Relations map[string]*Relationship + + EmbeddedRelations map[string]*Relationships +} + +type Relationship struct { + Name string + Type RelationshipType + Field *Field + Polymorphic *Polymorphic + References []*Reference 
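+	// Schema is the schema that declares the relation; FieldSchema is the related model's schema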
+ Schema *Schema + FieldSchema *Schema + JoinTable *Schema + foreignKeys, primaryKeys []string +} + +type Polymorphic struct { + PolymorphicID *Field + PolymorphicType *Field + Value string +} + +type Reference struct { + PrimaryKey *Field + PrimaryValue string + ForeignKey *Field + OwnPrimaryKey bool +} + +func (schema *Schema) parseRelation(field *Field) *Relationship { + var ( + err error + fieldValue = reflect.New(field.IndirectFieldType).Interface() + relation = &Relationship{ + Name: field.Name, + Field: field, + Schema: schema, + foreignKeys: toColumns(field.TagSettings["FOREIGNKEY"]), + primaryKeys: toColumns(field.TagSettings["REFERENCES"]), + } + ) + + cacheStore := schema.cacheStore + + if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { + schema.err = err + return nil + } + + if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + schema.buildPolymorphicRelation(relation, field, polymorphic) + } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + schema.buildMany2ManyRelation(relation, field, many2many) + } else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" { + schema.guessRelation(relation, field, guessBelongs) + } else { + switch field.IndirectFieldType.Kind() { + case reflect.Struct: + schema.guessRelation(relation, field, guessGuess) + case reflect.Slice: + schema.guessRelation(relation, field, guessHas) + default: + schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name) + } + } + + if relation.Type == has { + // don't add relations to embedded schema, which might be shared + if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil { + relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation + } + + switch field.IndirectFieldType.Kind() { + case reflect.Struct: + relation.Type = HasOne + case reflect.Slice: + relation.Type = HasMany + } + } + + if schema.err == nil { + schema.setRelation(relation) + switch relation.Type { + case HasOne: + schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) + case HasMany: + schema.Relationships.HasMany = append(schema.Relationships.HasMany, relation) + case BelongsTo: + schema.Relationships.BelongsTo = append(schema.Relationships.BelongsTo, relation) + case Many2Many: + schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation) + } + } + + return relation +} + +func (schema *Schema) setRelation(relation *Relationship) { + // set non-embedded relation + if rel := schema.Relationships.Relations[relation.Name]; rel != nil { + if len(rel.Field.BindNames) > 1 { + schema.Relationships.Relations[relation.Name] = relation + } + } else { + schema.Relationships.Relations[relation.Name] = relation + } + + // set embedded relation + if len(relation.Field.BindNames) <= 1 { + return + } + relationships := &schema.Relationships + for i, name := range relation.Field.BindNames { + if i < len(relation.Field.BindNames)-1 { + if relationships.EmbeddedRelations == nil { + relationships.EmbeddedRelations = map[string]*Relationships{} + } + if r := relationships.EmbeddedRelations[name]; r == nil { + relationships.EmbeddedRelations[name] = &Relationships{} + } + relationships = relationships.EmbeddedRelations[name] + } else { + if relationships.Relations == nil { + relationships.Relations = map[string]*Relationship{} + } + relationships.Relations[relation.Name] = relation + } + 
} +} + +// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` +// +// type User struct { +// Toys []Toy `gorm:"polymorphic:Owner;"` +// } +// type Pet struct { +// Toy Toy `gorm:"polymorphic:Owner;"` +// } +// type Toy struct { +// OwnerID int +// OwnerType string +// } +func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) { + relation.Polymorphic = &Polymorphic{ + Value: schema.Table, + PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"], + PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"], + } + + if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok { + relation.Polymorphic.Value = strings.TrimSpace(value) + } + + if relation.Polymorphic.PolymorphicType == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + } + + if relation.Polymorphic.PolymorphicID == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + } + + if schema.err == nil { + relation.References = append(relation.References, &Reference{ + PrimaryValue: relation.Polymorphic.Value, + ForeignKey: relation.Polymorphic.PolymorphicType, + }) + + primaryKeyField := schema.PrioritizedPrimaryField + if len(relation.foreignKeys) > 0 { + if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name) + } + } + + if primaryKeyField == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name) + return + } + + // use same data type for foreign keys + if copyableDataType(primaryKeyField.DataType) { + relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType + } + relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType + if relation.Polymorphic.PolymorphicID.Size == 0 { + relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size + } + + relation.References = append(relation.References, &Reference{ + PrimaryKey: primaryKeyField, + ForeignKey: relation.Polymorphic.PolymorphicID, + OwnPrimaryKey: true, + }) + } + + relation.Type = has +} + +func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) { + relation.Type = Many2Many + + var ( + err error + joinTableFields []reflect.StructField + fieldsMap = map[string]*Field{} + ownFieldsMap = map[string]*Field{} // fix self join many2many + referFieldsMap = map[string]*Field{} + joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) + joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) + ) + + ownForeignFields := schema.PrimaryFields + refForeignFields := relation.FieldSchema.PrimaryFields + + if len(relation.foreignKeys) > 0 { + ownForeignFields = []*Field{} + for _, foreignKey := range relation.foreignKeys { + if field := schema.LookUpField(foreignKey); field != nil { + ownForeignFields = append(ownForeignFields, field) + } else { + schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey) + return + } + } + } + + if len(relation.primaryKeys) > 0 { + refForeignFields = []*Field{} + for _, foreignKey := range relation.primaryKeys { + if field := 
relation.FieldSchema.LookUpField(foreignKey); field != nil { + refForeignFields = append(refForeignFields, field) + } else { + schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey) + return + } + } + } + + for idx, ownField := range ownForeignFields { + joinFieldName := strings.Title(schema.Name) + ownField.Name + if len(joinForeignKeys) > idx { + joinFieldName = strings.Title(joinForeignKeys[idx]) + } + + ownFieldsMap[joinFieldName] = ownField + fieldsMap[joinFieldName] = ownField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: ownField.StructField.PkgPath, + Type: ownField.StructField.Type, + Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), + }) + } + + for idx, relField := range refForeignFields { + joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name + + if _, ok := ownFieldsMap[joinFieldName]; ok { + if field.Name != relation.FieldSchema.Name { + joinFieldName = inflection.Singular(field.Name) + relField.Name + } else { + joinFieldName += "Reference" + } + } + + if len(joinReferences) > idx { + joinFieldName = strings.Title(joinReferences[idx]) + } + + referFieldsMap[joinFieldName] = relField + + if _, ok := fieldsMap[joinFieldName]; !ok { + fieldsMap[joinFieldName] = relField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: relField.StructField.PkgPath, + Type: relField.StructField.Type, + Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), + }) + } + } + + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: strings.Title(schema.Name) + field.Name, + Type: schema.ModelType, + Tag: `gorm:"-"`, + }) + + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { + schema.err = err + } + relation.JoinTable.Name = many2many + relation.JoinTable.Table = schema.namer.JoinTableName(many2many) + relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields)) + + relName := relation.Schema.Name + relRefName := relation.FieldSchema.Name + if relName == relRefName { + relRefName = relation.Field.Name + } + + if _, ok := relation.JoinTable.Relationships.Relations[relName]; !ok { + relation.JoinTable.Relationships.Relations[relName] = &Relationship{ + Name: relName, + Type: BelongsTo, + Schema: relation.JoinTable, + FieldSchema: relation.Schema, + } + } else { + relation.JoinTable.Relationships.Relations[relName].References = []*Reference{} + } + + if _, ok := relation.JoinTable.Relationships.Relations[relRefName]; !ok { + relation.JoinTable.Relationships.Relations[relRefName] = &Relationship{ + Name: relRefName, + Type: BelongsTo, + Schema: relation.JoinTable, + FieldSchema: relation.FieldSchema, + } + } else { + relation.JoinTable.Relationships.Relations[relRefName].References = []*Reference{} + } + + // build references + for _, f := range relation.JoinTable.Fields { + if f.Creatable || f.Readable || f.Updatable { + // use same data type for foreign keys + if copyableDataType(fieldsMap[f.Name].DataType) { + f.DataType = fieldsMap[f.Name].DataType + } + f.GORMDataType = fieldsMap[f.Name].GORMDataType + if f.Size == 0 { + f.Size = fieldsMap[f.Name].Size + } + relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) + + if of, 
ok := ownFieldsMap[f.Name]; ok { + joinRel := relation.JoinTable.Relationships.Relations[relName] + joinRel.Field = relation.Field + joinRel.References = append(joinRel.References, &Reference{ + PrimaryKey: of, + ForeignKey: f, + }) + + relation.References = append(relation.References, &Reference{ + PrimaryKey: of, + ForeignKey: f, + OwnPrimaryKey: true, + }) + } + + if rf, ok := referFieldsMap[f.Name]; ok { + joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] + if joinRefRel.Field == nil { + joinRefRel.Field = relation.Field + } + joinRefRel.References = append(joinRefRel.References, &Reference{ + PrimaryKey: rf, + ForeignKey: f, + }) + + relation.References = append(relation.References, &Reference{ + PrimaryKey: rf, + ForeignKey: f, + }) + } + } + } +} + +type guessLevel int + +const ( + guessGuess guessLevel = iota + guessBelongs + guessEmbeddedBelongs + guessHas + guessEmbeddedHas +) + +func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) { + var ( + primaryFields, foreignFields []*Field + primarySchema, foreignSchema = schema, relation.FieldSchema + gl = cgl + ) + + if gl == guessGuess { + if field.Schema == relation.FieldSchema { + gl = guessBelongs + } else { + gl = guessHas + } + } + + reguessOrErr := func() { + switch cgl { + case guessGuess: + schema.guessRelation(relation, field, guessBelongs) + case guessBelongs: + schema.guessRelation(relation, field, guessEmbeddedBelongs) + case guessEmbeddedBelongs: + schema.guessRelation(relation, field, guessHas) + case guessHas: + schema.guessRelation(relation, field, guessEmbeddedHas) + // case guessEmbeddedHas: + default: + schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name) + } + } + + switch gl { + case guessBelongs: + primarySchema, foreignSchema = relation.FieldSchema, schema + case guessEmbeddedBelongs: + if field.OwnerSchema == nil { + reguessOrErr() + return + } + primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema + case guessHas: + case guessEmbeddedHas: + if field.OwnerSchema == nil { + reguessOrErr() + return + } + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema + } + + if len(relation.foreignKeys) > 0 { + for _, foreignKey := range relation.foreignKeys { + f := foreignSchema.LookUpField(foreignKey) + if f == nil { + reguessOrErr() + return + } + foreignFields = append(foreignFields, f) + } + } else { + primarySchemaName := primarySchema.Name + if primarySchemaName == "" { + primarySchemaName = relation.FieldSchema.Name + } + + if len(relation.primaryKeys) > 0 { + for _, primaryKey := range relation.primaryKeys { + if f := primarySchema.LookUpField(primaryKey); f != nil { + primaryFields = append(primaryFields, f) + } + } + } else { + primaryFields = primarySchema.PrimaryFields + } + + primaryFieldLoop: + for _, primaryField := range primaryFields { + lookUpName := primarySchemaName + primaryField.Name + if gl == guessBelongs { + lookUpName = field.Name + primaryField.Name + } + + lookUpNames := []string{lookUpName} + if len(primaryFields) == 1 { + lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) + } + + for _, name := range lookUpNames { + if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != 
nil { + foreignFields = append(foreignFields, f) + primaryFields = append(primaryFields, primaryField) + continue primaryFieldLoop + } + } + for _, name := range lookUpNames { + if f := foreignSchema.LookUpField(name); f != nil { + foreignFields = append(foreignFields, f) + primaryFields = append(primaryFields, primaryField) + continue primaryFieldLoop + } + } + } + } + + switch { + case len(foreignFields) == 0: + reguessOrErr() + return + case len(relation.primaryKeys) > 0: + for idx, primaryKey := range relation.primaryKeys { + if f := primarySchema.LookUpField(primaryKey); f != nil { + if len(primaryFields) < idx+1 { + primaryFields = append(primaryFields, f) + } else if f != primaryFields[idx] { + reguessOrErr() + return + } + } else { + reguessOrErr() + return + } + } + case len(primaryFields) == 0: + if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil { + primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) + } else if len(primarySchema.PrimaryFields) == len(foreignFields) { + primaryFields = append(primaryFields, primarySchema.PrimaryFields...) + } else { + reguessOrErr() + return + } + } + + // build references + for idx, foreignField := range foreignFields { + // use same data type for foreign keys + if copyableDataType(primaryFields[idx].DataType) { + foreignField.DataType = primaryFields[idx].DataType + } + foreignField.GORMDataType = primaryFields[idx].GORMDataType + if foreignField.Size == 0 { + foreignField.Size = primaryFields[idx].Size + } + + relation.References = append(relation.References, &Reference{ + PrimaryKey: primaryFields[idx], + ForeignKey: foreignField, + OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas), + }) + } + + if gl == guessHas || gl == guessEmbeddedHas { + relation.Type = has + } else { + relation.Type = BelongsTo + } +} + +type Constraint struct { + Name string + Field *Field + Schema *Schema + ForeignKeys []*Field + ReferenceSchema *Schema + References []*Field + OnDelete string + OnUpdate string +} + +func (rel *Relationship) ParseConstraint() *Constraint { + str := rel.Field.TagSettings["CONSTRAINT"] + if str == "-" { + return nil + } + + if rel.Type == BelongsTo { + for _, r := range rel.FieldSchema.Relationships.Relations { + if r != rel && r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) { + matched := true + for idx, ref := range r.References { + if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey && + rel.References[idx].PrimaryValue == ref.PrimaryValue) { + matched = false + } + } + + if matched { + return nil + } + } + } + } + + var ( + name string + idx = strings.Index(str, ",") + settings = ParseTagSetting(str, ",") + ) + + // optimize match english letters and midline + // The following code is basically called in for. + // In order to avoid the performance problems caused by repeated compilation of regular expressions, + // it only needs to be done once outside, so optimization is done here. 
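+	// e.g. `constraint:OnUpdate:CASCADE,OnDelete:SET NULL` has no leading name and falls
+	// through to the generated name below, while `constraint:fk_users_pets,OnDelete:CASCADE` keeps "fk_users_pets"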
+ if idx != -1 && regEnLetterAndMidline.MatchString(str[0:idx]) { + name = str[0:idx] + } else { + name = rel.Schema.namer.RelationshipFKName(*rel) + } + + constraint := Constraint{ + Name: name, + Field: rel.Field, + OnUpdate: settings["ONUPDATE"], + OnDelete: settings["ONDELETE"], + } + + for _, ref := range rel.References { + if ref.PrimaryKey != nil && (rel.JoinTable == nil || ref.OwnPrimaryKey) { + constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey) + constraint.References = append(constraint.References, ref.PrimaryKey) + + if ref.OwnPrimaryKey { + constraint.Schema = ref.ForeignKey.Schema + constraint.ReferenceSchema = rel.Schema + } else { + constraint.Schema = rel.Schema + constraint.ReferenceSchema = ref.PrimaryKey.Schema + } + } + } + + return &constraint +} + +func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) { + table := rel.FieldSchema.Table + foreignFields := []*Field{} + relForeignKeys := []string{} + + if rel.JoinTable != nil { + table = rel.JoinTable.Table + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } else { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + }) + } + } + } else { + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } else { + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + } + + _, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields) + column, values := ToQueryValues(table, relForeignKeys, foreignValues) + + conds = append(conds, clause.IN{Column: column, Values: values}) + return +} + +func copyableDataType(str DataType) bool { + for _, s := range []string{"auto_increment", "primary key"} { + if strings.Contains(strings.ToLower(string(str)), s) { + return false + } + } + return true +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/schema.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/schema.go new file mode 100644 index 000000000..e13a5ed13 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/schema.go @@ -0,0 +1,370 @@ +package schema + +import ( + "context" + "errors" + "fmt" + "go/ast" + "reflect" + "strings" + "sync" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" +) + +// ErrUnsupportedDataType unsupported data type +var ErrUnsupportedDataType = errors.New("unsupported data type") + +type Schema struct { + Name string + ModelType reflect.Type + Table string + PrioritizedPrimaryField *Field + DBNames []string + PrimaryFields []*Field + PrimaryFieldDBNames []string + Fields []*Field + FieldsByName map[string]*Field + FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field' + FieldsByDBName map[string]*Field + 
FieldsWithDefaultDBValue []*Field // fields with a default value assigned by the database
+	Relationships            Relationships
+	CreateClauses            []clause.Interface
+	QueryClauses             []clause.Interface
+	UpdateClauses            []clause.Interface
+	DeleteClauses            []clause.Interface
+	BeforeCreate, AfterCreate bool
+	BeforeUpdate, AfterUpdate bool
+	BeforeDelete, AfterDelete bool
+	BeforeSave, AfterSave     bool
+	AfterFind                 bool
+	err                       error
+	initialized               chan struct{}
+	namer                     Namer
+	cacheStore                *sync.Map
+}
+
+func (schema Schema) String() string {
+	if schema.ModelType.Name() == "" {
+		return fmt.Sprintf("%s(%s)", schema.Name, schema.Table)
+	}
+	return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name())
+}
+
+func (schema Schema) MakeSlice() reflect.Value {
+	slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20)
+	results := reflect.New(slice.Type())
+	results.Elem().Set(slice)
+	return results
+}
+
+func (schema Schema) LookUpField(name string) *Field {
+	if field, ok := schema.FieldsByDBName[name]; ok {
+		return field
+	}
+	if field, ok := schema.FieldsByName[name]; ok {
+		return field
+	}
+	return nil
+}
+
+// LookUpFieldByBindName looks for the closest field in the embedded struct.
+//
+//	type Struct struct {
+//		Embedded struct {
+//			ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
+//		}
+//		ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
+//	}
+func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
+	if len(bindNames) == 0 {
+		return nil
+	}
+	for i := len(bindNames) - 1; i >= 0; i-- {
+		find := strings.Join(bindNames[:i], ".") + "." + name
+		if field, ok := schema.FieldsByBindName[find]; ok {
+			return field
+		}
+	}
+	return nil
+}
+
+type Tabler interface {
+	TableName() string
+}
+
+type TablerWithNamer interface {
+	TableName(Namer) string
+}
+
+// Parse gets the data type from the dialector
+func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
+	return ParseWithSpecialTableName(dest, cacheStore, namer, "")
+}
+
+// ParseWithSpecialTableName gets the data type from the dialector with an extra schema table
+func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
+	if dest == nil {
+		return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
+	}
+
+	value := reflect.ValueOf(dest)
+	if value.Kind() == reflect.Ptr && value.IsNil() {
+		value = reflect.New(value.Type().Elem())
+	}
+	modelType := reflect.Indirect(value).Type()
+
+	if modelType.Kind() == reflect.Interface {
+		modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
+	}
+
+	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
+		modelType = modelType.Elem()
+	}
+
+	if modelType.Kind() != reflect.Struct {
+		if modelType.PkgPath() == "" {
+			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
+		}
+		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
+	}
+
+	// Cache the Schema for performance,
+	// using the modelType, or modelType + schemaTable if one is present, as the cache key.
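+	// e.g. (illustrative; table name hypothetical): Parse(&User{}, cacheStore, namer)
+	// caches under the reflect.Type itself, while a ParseWithSpecialTableName call
+	// for "users_2024" caches under the string key "<modelType pointer>-users_2024",
+	// so the two schemas never overwrite each other.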
+	var schemaCacheKey interface{}
+	if specialTableName != "" {
+		schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
+	} else {
+		schemaCacheKey = modelType
+	}
+
+	// Load the existing schema cache, return if it exists
+	if v, ok := cacheStore.Load(schemaCacheKey); ok {
+		s := v.(*Schema)
+		// Wait for the initialization of other goroutines to complete
+		<-s.initialized
+		return s, s.err
+	}
+
+	modelValue := reflect.New(modelType)
+	tableName := namer.TableName(modelType.Name())
+	if tabler, ok := modelValue.Interface().(Tabler); ok {
+		tableName = tabler.TableName()
+	}
+	if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
+		tableName = tabler.TableName(namer)
+	}
+	if en, ok := namer.(embeddedNamer); ok {
+		tableName = en.Table
+	}
+	if specialTableName != "" && specialTableName != tableName {
+		tableName = specialTableName
+	}
+
+	schema := &Schema{
+		Name:             modelType.Name(),
+		ModelType:        modelType,
+		Table:            tableName,
+		FieldsByName:     map[string]*Field{},
+		FieldsByBindName: map[string]*Field{},
+		FieldsByDBName:   map[string]*Field{},
+		Relationships:    Relationships{Relations: map[string]*Relationship{}},
+		cacheStore:       cacheStore,
+		namer:            namer,
+		initialized:      make(chan struct{}),
+	}
+	// When the schema initialization is completed, the channel will be closed
+	defer close(schema.initialized)
+
+	// Load the existing schema cache, return if it exists
+	if v, ok := cacheStore.Load(schemaCacheKey); ok {
+		s := v.(*Schema)
+		// Wait for the initialization of other goroutines to complete
+		<-s.initialized
+		return s, s.err
+	}
+
+	for i := 0; i < modelType.NumField(); i++ {
+		if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
+			if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
+				schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
+			} else {
+				schema.Fields = append(schema.Fields, field)
+			}
+		}
+	}
+
+	for _, field := range schema.Fields {
+		if field.DBName == "" && field.DataType != "" {
+			field.DBName = namer.ColumnName(schema.Table, field.Name)
+		}
+
+		bindName := field.BindName()
+		if field.DBName != "" {
+			// a field wins the DB name if no field holds it yet, or if it has
+			// read/write permission and a shorter bind path than the current holder
+			if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
+				if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
+					schema.DBNames = append(schema.DBNames, field.DBName)
+				}
+				schema.FieldsByDBName[field.DBName] = field
+				schema.FieldsByName[field.Name] = field
+				schema.FieldsByBindName[bindName] = field
+
+				if v != nil && v.PrimaryKey {
+					for idx, f := range schema.PrimaryFields {
+						if f == v {
+							schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
+ } + } + } + + if field.PrimaryKey { + schema.PrimaryFields = append(schema.PrimaryFields, field) + } + } + } + + if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { + schema.FieldsByName[field.Name] = field + } + if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" { + schema.FieldsByBindName[bindName] = field + } + + field.setupValuerAndSetter() + } + + prioritizedPrimaryField := schema.LookUpField("id") + if prioritizedPrimaryField == nil { + prioritizedPrimaryField = schema.LookUpField("ID") + } + + if prioritizedPrimaryField != nil { + if prioritizedPrimaryField.PrimaryKey { + schema.PrioritizedPrimaryField = prioritizedPrimaryField + } else if len(schema.PrimaryFields) == 0 { + prioritizedPrimaryField.PrimaryKey = true + schema.PrioritizedPrimaryField = prioritizedPrimaryField + schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField) + } + } + + if schema.PrioritizedPrimaryField == nil { + if len(schema.PrimaryFields) == 1 { + schema.PrioritizedPrimaryField = schema.PrimaryFields[0] + } else if len(schema.PrimaryFields) > 1 { + // If there are multiple primary keys, the AUTOINCREMENT field is prioritized + for _, field := range schema.PrimaryFields { + if field.AutoIncrement { + schema.PrioritizedPrimaryField = field + break + } + } + } + } + + for _, field := range schema.PrimaryFields { + schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) + } + + for _, field := range schema.Fields { + if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil { + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) + } + } + + if field := schema.PrioritizedPrimaryField; field != nil { + switch field.GORMDataType { + case Int, Uint: + if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok { + if !field.HasDefaultValue || field.DefaultValueInterface != nil { + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) + } + + field.HasDefaultValue = true + field.AutoIncrement = true + } + } + } + + callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} + for _, name := range callbacks { + if methodValue := modelValue.MethodByName(name); methodValue.IsValid() { + switch methodValue.Type().String() { + case "func(*gorm.DB) error": // TODO hack + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) + default: + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. 
Please see https://gorm.io/docs/hooks.html", schema, name, name)
+			}
+		}
+	}
+
+	// Cache the schema
+	if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
+		s := v.(*Schema)
+		// Wait for the initialization of other goroutines to complete
+		<-s.initialized
+		return s, s.err
+	}
+
+	defer func() {
+		if schema.err != nil {
+			logger.Default.Error(context.Background(), schema.err.Error())
+			cacheStore.Delete(modelType)
+		}
+	}()
+
+	if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
+		for _, field := range schema.Fields {
+			if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
+				if schema.parseRelation(field); schema.err != nil {
+					return schema, schema.err
+				} else {
+					schema.FieldsByName[field.Name] = field
+					schema.FieldsByBindName[field.BindName()] = field
+				}
+			}
+
+			fieldValue := reflect.New(field.IndirectFieldType)
+			fieldInterface := fieldValue.Interface()
+			if fc, ok := fieldInterface.(CreateClausesInterface); ok {
+				field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
+			}
+
+			if fc, ok := fieldInterface.(QueryClausesInterface); ok {
+				field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
+			}
+
+			if fc, ok := fieldInterface.(UpdateClausesInterface); ok {
+				field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
+			}
+
+			if fc, ok := fieldInterface.(DeleteClausesInterface); ok {
+				field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
+			}
+		}
+	}
+
+	return schema, schema.err
+}
+
+func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
+	modelType := reflect.ValueOf(dest).Type()
+	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
+		modelType = modelType.Elem()
+	}
+
+	if modelType.Kind() != reflect.Struct {
+		if modelType.PkgPath() == "" {
+			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
+		}
+		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
+	}
+
+	if v, ok := cacheStore.Load(modelType); ok {
+		return v.(*Schema), nil
+	}
+
+	return Parse(dest, cacheStore, namer)
+}
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/serializer.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/serializer.go
new file mode 100644
index 000000000..397edff03
--- /dev/null
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/serializer.go
@@ -0,0 +1,170 @@
+package schema
+
+import (
+	"bytes"
+	"context"
+	"database/sql"
+	"database/sql/driver"
+	"encoding/gob"
+	"encoding/json"
+	"fmt"
+	"reflect"
+	"strings"
+	"sync"
+	"time"
+)
+
+var serializerMap = sync.Map{}
+
+// RegisterSerializer registers a serializer under its (case-insensitive) name
+func RegisterSerializer(name string, serializer SerializerInterface) {
+	serializerMap.Store(strings.ToLower(name), serializer)
+}
+
+// GetSerializer returns the serializer registered under the given name
+func GetSerializer(name string) (serializer SerializerInterface, ok bool) {
+	v, ok := serializerMap.Load(strings.ToLower(name))
+	if ok {
+		serializer, ok = v.(SerializerInterface)
+	}
+	return serializer, ok
+}
+
+func init() {
+	RegisterSerializer("json", JSONSerializer{})
+	RegisterSerializer("unixtime", UnixSecondSerializer{})
+	RegisterSerializer("gob", GobSerializer{})
+}
+
+// serializer wraps a field value together with its serializer
+type serializer struct {
+	Field           *Field
+	Serializer      SerializerInterface
+	SerializeValuer SerializerValuerInterface
+
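+	// Destination is the reflect.Value of the model that owns the field;
+	// Value() serializes the current fieldValue against it.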
Destination reflect.Value
+	Context         context.Context
+	value           interface{}
+	fieldValue      interface{}
+}
+
+// Scan implements sql.Scanner interface
+func (s *serializer) Scan(value interface{}) error {
+	s.value = value
+	return nil
+}
+
+// Value implements driver.Valuer interface
+func (s serializer) Value() (driver.Value, error) {
+	return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue)
+}
+
+// SerializerInterface serializer interface
+type SerializerInterface interface {
+	Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error
+	SerializerValuerInterface
+}
+
+// SerializerValuerInterface serializer valuer interface
+type SerializerValuerInterface interface {
+	Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error)
+}
+
+// JSONSerializer json serializer
+type JSONSerializer struct{}
+
+// Scan implements serializer interface
+func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
+	fieldValue := reflect.New(field.FieldType)
+
+	if dbValue != nil {
+		var bytes []byte
+		switch v := dbValue.(type) {
+		case []byte:
+			bytes = v
+		case string:
+			bytes = []byte(v)
+		default:
+			return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
+		}
+
+		if len(bytes) > 0 {
+			err = json.Unmarshal(bytes, fieldValue.Interface())
+		}
+	}
+
+	field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
+	return
+}
+
+// Value implements serializer interface
+func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
+	result, err := json.Marshal(fieldValue)
+	if string(result) == "null" {
+		if field.TagSettings["NOT NULL"] != "" {
+			return "", nil
+		}
+		return nil, err
+	}
+	return string(result), err
+}
+
+// UnixSecondSerializer unix second serializer
+type UnixSecondSerializer struct{}
+
+// Scan implements serializer interface
+func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
+	t := sql.NullTime{}
+	if err = t.Scan(dbValue); err == nil && t.Valid {
+		err = field.Set(ctx, dst, t.Time.Unix())
+	}
+
+	return
+}
+
+// Value implements serializer interface
+func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
+	rv := reflect.ValueOf(fieldValue)
+	switch v := fieldValue.(type) {
+	case int64, int, uint, uint64, int32, uint32, int16, uint16:
+		result = time.Unix(reflect.Indirect(rv).Int(), 0)
+	case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
+		if rv.IsZero() {
+			return nil, nil
+		}
+		result = time.Unix(reflect.Indirect(rv).Int(), 0)
+	default:
+		err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
+	}
+	return
+}
+
+// GobSerializer gob serializer
+type GobSerializer struct{}
+
+// Scan implements serializer interface
+func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
+	fieldValue := reflect.New(field.FieldType)
+
+	if dbValue != nil {
+		var bytesValue []byte
+		switch v := dbValue.(type) {
+		case []byte:
+			bytesValue = v
+		default:
+			return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
+		}
+		if len(bytesValue) > 0 {
+			decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
+			err = decoder.Decode(fieldValue.Interface())
+		}
+	}
+	field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
+
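+	// On success the destination field now holds the gob-decoded value, or
+	// the field type's zero value when dbValue was nil or empty.
+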
return
+}
+
+// Value implements serializer interface
+func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
+	buf := new(bytes.Buffer)
+	err := gob.NewEncoder(buf).Encode(fieldValue)
+	return buf.Bytes(), err
+}
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/utils.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/utils.go
new file mode 100644
index 000000000..65d012e54
--- /dev/null
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/utils.go
@@ -0,0 +1,208 @@
+package schema
+
+import (
+	"context"
+	"fmt"
+	"reflect"
+	"regexp"
+	"strings"
+
+	"gorm.io/gorm/clause"
+	"gorm.io/gorm/utils"
+)
+
+var embeddedCacheKey = "embedded_cache_store"
+
+func ParseTagSetting(str string, sep string) map[string]string {
+	settings := map[string]string{}
+	names := strings.Split(str, sep)
+
+	for i := 0; i < len(names); i++ {
+		j := i
+		if len(names[j]) > 0 {
+			for {
+				if names[j][len(names[j])-1] == '\\' {
+					i++
+					names[j] = names[j][0:len(names[j])-1] + sep + names[i]
+					names[i] = ""
+				} else {
+					break
+				}
+			}
+		}
+
+		values := strings.Split(names[j], ":")
+		k := strings.TrimSpace(strings.ToUpper(values[0]))
+
+		if len(values) >= 2 {
+			settings[k] = strings.Join(values[1:], ":")
+		} else if k != "" {
+			settings[k] = k
+		}
+	}
+
+	return settings
+}
+
+func toColumns(val string) (results []string) {
+	if val != "" {
+		for _, v := range strings.Split(val, ",") {
+			results = append(results, strings.TrimSpace(v))
+		}
+	}
+	return
+}
+
+func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag {
+	for _, name := range names {
+		tag = reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}"))
+	}
+	return tag
+}
+
+func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag {
+	t := tag.Get("gorm")
+	if strings.Contains(t, value) {
+		return tag
+	}
+	return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t))
+}
+
+// GetRelationsValues gets relations' values from a reflect value
+func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
+	for _, rel := range rels {
+		reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1)
+
+		appendToResults := func(value reflect.Value) {
+			if _, isZero := rel.Field.ValueOf(ctx, value); !isZero {
+				result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value))
+				switch result.Kind() {
+				case reflect.Struct:
+					reflectResults = reflect.Append(reflectResults, result.Addr())
+				case reflect.Slice, reflect.Array:
+					for i := 0; i < result.Len(); i++ {
+						if elem := result.Index(i); elem.Kind() == reflect.Ptr {
+							reflectResults = reflect.Append(reflectResults, elem)
+						} else {
+							reflectResults = reflect.Append(reflectResults, elem.Addr())
+						}
+					}
+				}
+			}
+		}
+
+		switch reflectValue.Kind() {
+		case reflect.Struct:
+			appendToResults(reflectValue)
+		case reflect.Slice:
+			for i := 0; i < reflectValue.Len(); i++ {
+				appendToResults(reflectValue.Index(i))
+			}
+		}
+
+		reflectValue = reflectResults
+	}
+
+	return
+}
+
+// GetIdentityFieldValuesMap gets the identity map from the given fields
+func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
+	var (
+		results     = [][]interface{}{}
+		dataResults = map[string][]reflect.Value{}
+		loaded      = map[interface{}]bool{}
+		notZero, zero bool
+
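+		// dataResults maps the string form of each identity tuple back to
+		// the reflect.Values that produced it; loaded de-duplicates slice
+		// elements so repeated rows contribute a single identity.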
) + + switch reflectValue.Kind() { + case reflect.Struct: + results = [][]interface{}{make([]interface{}, len(fields))} + + for idx, field := range fields { + results[0][idx], zero = field.ValueOf(ctx, reflectValue) + notZero = notZero || !zero + } + + if !notZero { + return nil, nil + } + + dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + elem := reflectValue.Index(i) + elemKey := elem.Interface() + if elem.Kind() != reflect.Ptr && elem.CanAddr() { + elemKey = elem.Addr().Interface() + } + + if _, ok := loaded[elemKey]; ok { + continue + } + loaded[elemKey] = true + + fieldValues := make([]interface{}, len(fields)) + notZero = false + for idx, field := range fields { + fieldValues[idx], zero = field.ValueOf(ctx, elem) + notZero = notZero || !zero + } + + if notZero { + dataKey := utils.ToStringKey(fieldValues...) + if _, ok := dataResults[dataKey]; !ok { + results = append(results, fieldValues) + dataResults[dataKey] = []reflect.Value{elem} + } else { + dataResults[dataKey] = append(dataResults[dataKey], elem) + } + } + } + } + + return dataResults, results +} + +// GetIdentityFieldValuesMapFromValues get identity map from fields +func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { + resultsMap := map[string][]reflect.Value{} + results := [][]interface{}{} + + for _, v := range values { + rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields) + for k, v := range rm { + resultsMap[k] = append(resultsMap[k], v...) + } + results = append(results, rs...) + } + return resultsMap, results +} + +// ToQueryValues to query values +func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) { + queryValues := make([]interface{}, len(foreignValues)) + if len(foreignKeys) == 1 { + for idx, r := range foreignValues { + queryValues[idx] = r[0] + } + + return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues + } + + columns := make([]clause.Column, len(foreignKeys)) + for idx, key := range foreignKeys { + columns[idx] = clause.Column{Table: table, Name: key} + } + + for idx, r := range foreignValues { + queryValues[idx] = r + } + + return columns, queryValues +} + +type embeddedNamer struct { + Table string + Namer +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/soft_delete.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/soft_delete.go new file mode 100644 index 000000000..5673d3b85 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/soft_delete.go @@ -0,0 +1,170 @@ +package gorm + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "reflect" + + "github.com/jinzhu/now" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +type DeletedAt sql.NullTime + +// Scan implements the Scanner interface. +func (n *DeletedAt) Scan(value interface{}) error { + return (*sql.NullTime)(n).Scan(value) +} + +// Value implements the driver Valuer interface. 
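+// It returns nil (SQL NULL) while the record has not been soft-deleted.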
+func (n DeletedAt) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Time, nil +} + +func (n DeletedAt) MarshalJSON() ([]byte, error) { + if n.Valid { + return json.Marshal(n.Time) + } + return json.Marshal(nil) +} + +func (n *DeletedAt) UnmarshalJSON(b []byte) error { + if string(b) == "null" { + n.Valid = false + return nil + } + err := json.Unmarshal(b, &n.Time) + if err == nil { + n.Valid = true + } + return err +} + +func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}} +} + +func parseZeroValueTag(f *schema.Field) sql.NullString { + if v, ok := f.TagSettings["ZEROVALUE"]; ok { + if _, err := now.Parse(v); err == nil { + return sql.NullString{String: v, Valid: true} + } + } + return sql.NullString{Valid: false} +} + +type SoftDeleteQueryClause struct { + ZeroValue sql.NullString + Field *schema.Field +} + +func (sd SoftDeleteQueryClause) Name() string { + return "" +} + +func (sd SoftDeleteQueryClause) Build(clause.Builder) { +} + +func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { + if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok && !stmt.Statement.Unscoped { + if c, ok := stmt.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) >= 1 { + for _, expr := range where.Exprs { + if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 { + where.Exprs = []clause.Expression{clause.And(where.Exprs...)} + c.Expression = where + stmt.Clauses["WHERE"] = c + break + } + } + } + } + + stmt.AddClause(clause.Where{Exprs: []clause.Expression{ + clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue}, + }}) + stmt.Clauses["soft_delete_enabled"] = clause.Clause{} + } +} + +func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}} +} + +type SoftDeleteUpdateClause struct { + ZeroValue sql.NullString + Field *schema.Field +} + +func (sd SoftDeleteUpdateClause) Name() string { + return "" +} + +func (sd SoftDeleteUpdateClause) Build(clause.Builder) { +} + +func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { + if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { + SoftDeleteQueryClause(sd).ModifyStatement(stmt) + } +} + +func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}} +} + +type SoftDeleteDeleteClause struct { + ZeroValue sql.NullString + Field *schema.Field +} + +func (sd SoftDeleteDeleteClause) Name() string { + return "" +} + +func (sd SoftDeleteDeleteClause) Build(clause.Builder) { +} + +func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { + if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { + curTime := stmt.DB.NowFunc() + stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}}) + stmt.SetColumn(sd.Field.DBName, curTime, true) + + if stmt.Schema != nil { + _, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) + column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, 
queryValues) + + if len(values) > 0 { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + + if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { + _, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) + column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + } + } + + SoftDeleteQueryClause(sd).ModifyStatement(stmt) + stmt.AddClauseIfNotExists(clause.Update{}) + stmt.Build(stmt.DB.Callback().Update().Clauses...) + } +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/statement.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/statement.go new file mode 100644 index 000000000..59c0b772c --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/statement.go @@ -0,0 +1,728 @@ +package gorm + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "regexp" + "sort" + "strconv" + "strings" + "sync" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +// Statement statement +type Statement struct { + *DB + TableExpr *clause.Expr + Table string + Model interface{} + Unscoped bool + Dest interface{} + ReflectValue reflect.Value + Clauses map[string]clause.Clause + BuildClauses []string + Distinct bool + Selects []string // selected columns + Omits []string // omit columns + Joins []join + Preloads map[string][]interface{} + Settings sync.Map + ConnPool ConnPool + Schema *schema.Schema + Context context.Context + RaiseErrorOnNotFound bool + SkipHooks bool + SQL strings.Builder + Vars []interface{} + CurDestIndex int + attrs []interface{} + assigns []interface{} + scopes []func(*DB) *DB +} + +type join struct { + Name string + Conds []interface{} + On *clause.Where + Selects []string + Omits []string + JoinType clause.JoinType +} + +// StatementModifier statement modifier interface +type StatementModifier interface { + ModifyStatement(*Statement) +} + +// WriteString write string +func (stmt *Statement) WriteString(str string) (int, error) { + return stmt.SQL.WriteString(str) +} + +// WriteByte write byte +func (stmt *Statement) WriteByte(c byte) error { + return stmt.SQL.WriteByte(c) +} + +// WriteQuoted write quoted value +func (stmt *Statement) WriteQuoted(value interface{}) { + stmt.QuoteTo(&stmt.SQL, value) +} + +// QuoteTo write quoted value to writer +func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { + write := func(raw bool, str string) { + if raw { + writer.WriteString(str) + } else { + stmt.DB.Dialector.QuoteTo(writer, str) + } + } + + switch v := field.(type) { + case clause.Table: + if v.Name == clause.CurrentTable { + if stmt.TableExpr != nil { + stmt.TableExpr.Build(stmt) + } else { + write(v.Raw, stmt.Table) + } + } else { + write(v.Raw, v.Name) + } + + if v.Alias != "" { + writer.WriteByte(' ') + write(v.Raw, v.Alias) + } + case clause.Column: + if v.Table != "" { + if v.Table == clause.CurrentTable { + write(v.Raw, stmt.Table) + } else { + write(v.Raw, v.Table) + } + writer.WriteByte('.') + } + + if v.Name == clause.PrimaryKey { + if stmt.Schema == nil { + stmt.DB.AddError(ErrModelValueRequired) + } else if stmt.Schema.PrioritizedPrimaryField != nil { + write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) + } else if len(stmt.Schema.DBNames) > 
0 { + write(v.Raw, stmt.Schema.DBNames[0]) + } else { + stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck + } + } else { + write(v.Raw, v.Name) + } + + if v.Alias != "" { + writer.WriteString(" AS ") + write(v.Raw, v.Alias) + } + case []clause.Column: + writer.WriteByte('(') + for idx, d := range v { + if idx > 0 { + writer.WriteByte(',') + } + stmt.QuoteTo(writer, d) + } + writer.WriteByte(')') + case clause.Expr: + v.Build(stmt) + case string: + stmt.DB.Dialector.QuoteTo(writer, v) + case []string: + writer.WriteByte('(') + for idx, d := range v { + if idx > 0 { + writer.WriteByte(',') + } + stmt.DB.Dialector.QuoteTo(writer, d) + } + writer.WriteByte(')') + default: + stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field)) + } +} + +// Quote returns quoted value +func (stmt *Statement) Quote(field interface{}) string { + var builder strings.Builder + stmt.QuoteTo(&builder, field) + return builder.String() +} + +// AddVar add var +func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { + for idx, v := range vars { + if idx > 0 { + writer.WriteByte(',') + } + + switch v := v.(type) { + case sql.NamedArg: + stmt.Vars = append(stmt.Vars, v.Value) + case clause.Column, clause.Table: + stmt.QuoteTo(writer, v) + case Valuer: + reflectValue := reflect.ValueOf(v) + if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() { + stmt.AddVar(writer, nil) + } else { + stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) + } + case clause.Interface: + c := clause.Clause{Name: v.Name()} + v.MergeClause(&c) + c.Build(stmt) + case clause.Expression: + v.Build(stmt) + case driver.Valuer: + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) + case []byte: + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) + case []interface{}: + if len(v) > 0 { + writer.WriteByte('(') + stmt.AddVar(writer, v...) + writer.WriteByte(')') + } else { + writer.WriteString("(NULL)") + } + case *DB: + subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() + if v.Statement.SQL.Len() > 0 { + var ( + vars = subdb.Statement.Vars + sql = v.Statement.SQL.String() + ) + + subdb.Statement.Vars = make([]interface{}, 0, len(vars)) + for _, vv := range vars { + subdb.Statement.Vars = append(subdb.Statement.Vars, vv) + bindvar := strings.Builder{} + v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv) + sql = strings.Replace(sql, bindvar.String(), "?", 1) + } + + subdb.Statement.SQL.Reset() + subdb.Statement.Vars = stmt.Vars + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } else { + clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } + } else { + subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...) 
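+				// the subquery has no prebuilt SQL yet: run the query callbacks
+				// under the dry-run session so its SQL and vars are generated
+				// without touching the database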
+ subdb.callbacks.Query().Execute(subdb) + } + + writer.WriteString(subdb.Statement.SQL.String()) + stmt.Vars = subdb.Statement.Vars + default: + switch rv := reflect.ValueOf(v); rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + writer.WriteString("(NULL)") + } else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) { + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) + } else { + writer.WriteByte('(') + for i := 0; i < rv.Len(); i++ { + if i > 0 { + writer.WriteByte(',') + } + stmt.AddVar(writer, rv.Index(i).Interface()) + } + writer.WriteByte(')') + } + default: + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) + } + } + } +} + +// AddClause add clause +func (stmt *Statement) AddClause(v clause.Interface) { + if optimizer, ok := v.(StatementModifier); ok { + optimizer.ModifyStatement(stmt) + } else { + name := v.Name() + c := stmt.Clauses[name] + c.Name = name + v.MergeClause(&c) + stmt.Clauses[name] = c + } +} + +// AddClauseIfNotExists add clause if not exists +func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { + if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil { + stmt.AddClause(v) + } +} + +// BuildCondition build condition +func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression { + if s, ok := query.(string); ok { + // if it is a number, then treats it as primary key + if _, err := strconv.Atoi(s); err != nil { + if s == "" && len(args) == 0 { + return nil + } + + if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { + // looks like a where condition + return []clause.Expression{clause.Expr{SQL: s, Vars: args}} + } + + if len(args) > 0 && strings.Contains(s, "@") { + // looks like a named query + return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} + } + + if strings.Contains(strings.TrimSpace(s), " ") { + // looks like a where condition + return []clause.Expression{clause.Expr{SQL: s, Vars: args}} + } + + if len(args) == 1 { + return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} + } + } + } + + conds := make([]clause.Expression, 0, 4) + args = append([]interface{}{query}, args...) 
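+	// not a bare string condition: fold the query in as the first argument
+	// and translate each argument (expression, subquery, map, struct, or
+	// primary-key value) into a clause below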
+ for idx, arg := range args { + if arg == nil { + continue + } + if valuer, ok := arg.(driver.Valuer); ok { + arg, _ = valuer.Value() + } + + switch v := arg.(type) { + case clause.Expression: + conds = append(conds, v) + case *DB: + v.executeScopes() + + if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { + if where, ok := cs.Expression.(clause.Where); ok { + if len(where.Exprs) == 1 { + if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { + where.Exprs[0] = clause.AndConditions(orConds) + } + } + conds = append(conds, clause.And(where.Exprs...)) + } else { + conds = append(conds, cs.Expression) + } + if v.Statement == stmt { + cs.Expression = nil + stmt.Statement.Clauses["WHERE"] = cs + } + } + case map[interface{}]interface{}: + for i, j := range v { + conds = append(conds, clause.Eq{Column: i, Value: j}) + } + case map[string]string: + keys := make([]string, 0, len(v)) + for i := range v { + keys = append(keys, i) + } + sort.Strings(keys) + + for _, key := range keys { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } + case map[string]interface{}: + keys := make([]string, 0, len(v)) + for i := range v { + keys = append(keys, i) + } + sort.Strings(keys) + + for _, key := range keys { + reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + if _, ok := v[key].(driver.Valuer); ok { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } else if _, ok := v[key].(Valuer); ok { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } else { + // optimize reflect value length + valueLen := reflectValue.Len() + values := make([]interface{}, valueLen) + for i := 0; i < valueLen; i++ { + values[i] = reflectValue.Index(i).Interface() + } + + conds = append(conds, clause.IN{Column: key, Values: values}) + } + default: + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } + } + default: + reflectValue := reflect.Indirect(reflect.ValueOf(arg)) + for reflectValue.Kind() == reflect.Ptr { + reflectValue = reflectValue.Elem() + } + + if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + selectedColumns := map[string]bool{} + if idx == 0 { + for _, v := range args[1:] { + if vs, ok := v.(string); ok { + selectedColumns[vs] = true + } + } + } + restricted := len(selectedColumns) != 0 + + switch reflectValue.Kind() { + case reflect.Struct: + for _, field := range s.Fields { + selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + if selected || (!restricted && field.Readable) { + if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { + if field.DBName != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + } + } + } + } + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + for _, field := range s.Fields { + selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + if selected || (!restricted && field.Readable) { + if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { + if field.DBName != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, 
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + } + } + } + } + } + } + + if restricted { + break + } + } else if !reflectValue.IsValid() { + stmt.AddError(ErrInvalidData) + } else if len(conds) == 0 { + if len(args) == 1 { + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + // optimize reflect value length + valueLen := reflectValue.Len() + values := make([]interface{}, valueLen) + for i := 0; i < valueLen; i++ { + values[i] = reflectValue.Index(i).Interface() + } + + if len(values) > 0 { + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) + } + return conds + } + } + + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) + } + } + } + + return conds +} + +// Build build sql with clauses names +func (stmt *Statement) Build(clauses ...string) { + var firstClauseWritten bool + + for _, name := range clauses { + if c, ok := stmt.Clauses[name]; ok { + if firstClauseWritten { + stmt.WriteByte(' ') + } + + firstClauseWritten = true + if b, ok := stmt.DB.ClauseBuilders[name]; ok { + b(c, stmt) + } else { + c.Build(stmt) + } + } + } +} + +func (stmt *Statement) Parse(value interface{}) (err error) { + return stmt.ParseWithSpecialTableName(value, "") +} + +func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) { + if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" { + if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { + stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} + stmt.Table = tables[1] + return + } + + stmt.Table = stmt.Schema.Table + } + return err +} + +func (stmt *Statement) clone() *Statement { + newStmt := &Statement{ + TableExpr: stmt.TableExpr, + Table: stmt.Table, + Model: stmt.Model, + Unscoped: stmt.Unscoped, + Dest: stmt.Dest, + ReflectValue: stmt.ReflectValue, + Clauses: map[string]clause.Clause{}, + Distinct: stmt.Distinct, + Selects: stmt.Selects, + Omits: stmt.Omits, + Preloads: map[string][]interface{}{}, + ConnPool: stmt.ConnPool, + Schema: stmt.Schema, + Context: stmt.Context, + RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, + SkipHooks: stmt.SkipHooks, + } + + if stmt.SQL.Len() > 0 { + newStmt.SQL.WriteString(stmt.SQL.String()) + newStmt.Vars = make([]interface{}, 0, len(stmt.Vars)) + newStmt.Vars = append(newStmt.Vars, stmt.Vars...) 
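+		// the clone gets its own SQL buffer and Vars backing array, so
+		// building it later cannot mutate the source statement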
+ } + + for k, c := range stmt.Clauses { + newStmt.Clauses[k] = c + } + + for k, p := range stmt.Preloads { + newStmt.Preloads[k] = p + } + + if len(stmt.Joins) > 0 { + newStmt.Joins = make([]join, len(stmt.Joins)) + copy(newStmt.Joins, stmt.Joins) + } + + if len(stmt.scopes) > 0 { + newStmt.scopes = make([]func(*DB) *DB, len(stmt.scopes)) + copy(newStmt.scopes, stmt.scopes) + } + + stmt.Settings.Range(func(k, v interface{}) bool { + newStmt.Settings.Store(k, v) + return true + }) + + return newStmt +} + +// SetColumn set column's value +// +// stmt.SetColumn("Name", "jinzhu") // Hooks Method +// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method +func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { + if v, ok := stmt.Dest.(map[string]interface{}); ok { + v[name] = value + } else if v, ok := stmt.Dest.([]map[string]interface{}); ok { + for _, m := range v { + m[name] = value + } + } else if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(name); field != nil { + destValue := reflect.ValueOf(stmt.Dest) + for destValue.Kind() == reflect.Ptr { + destValue = destValue.Elem() + } + + if stmt.ReflectValue != destValue { + if !destValue.CanAddr() { + destValueCanAddr := reflect.New(destValue.Type()) + destValueCanAddr.Elem().Set(destValue) + stmt.Dest = destValueCanAddr.Interface() + destValue = destValueCanAddr.Elem() + } + + switch destValue.Kind() { + case reflect.Struct: + stmt.AddError(field.Set(stmt.Context, destValue, value)) + default: + stmt.AddError(ErrInvalidData) + } + } + + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if len(fromCallbacks) > 0 { + for i := 0; i < stmt.ReflectValue.Len(); i++ { + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)) + } + } else { + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value)) + } + case reflect.Struct: + if !stmt.ReflectValue.CanAddr() { + stmt.AddError(ErrInvalidValue) + return + } + + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value)) + } + } else { + stmt.AddError(ErrInvalidField) + } + } else { + stmt.AddError(ErrInvalidField) + } +} + +// Changed check model changed or not when updating +func (stmt *Statement) Changed(fields ...string) bool { + modelValue := stmt.ReflectValue + switch modelValue.Kind() { + case reflect.Slice, reflect.Array: + modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex) + } + + selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) + changed := func(field *schema.Field) bool { + fieldValue, _ := field.ValueOf(stmt.Context, modelValue) + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if mv, mok := stmt.Dest.(map[string]interface{}); mok { + if fv, ok := mv[field.Name]; ok { + return !utils.AssertEqual(fv, fieldValue) + } else if fv, ok := mv[field.DBName]; ok { + return !utils.AssertEqual(fv, fieldValue) + } + } else { + destValue := reflect.ValueOf(stmt.Dest) + for destValue.Kind() == reflect.Ptr { + destValue = destValue.Elem() + } + + changedValue, zero := field.ValueOf(stmt.Context, destValue) + if v { + return !utils.AssertEqual(changedValue, fieldValue) + } + return !zero && !utils.AssertEqual(changedValue, fieldValue) + } + } + return false + } + + if len(fields) == 0 { + for _, field := range stmt.Schema.FieldsByDBName { + if changed(field) { + return true + } + } + } else { + for _, name := range fields { + if field := stmt.Schema.LookUpField(name); field != nil { + if changed(field) { + return true + } + } 
+ } + } + + return false +} + +var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`) + +// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false +func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { + results := map[string]bool{} + notRestricted := false + + processColumn := func(column string, result bool) { + if stmt.Schema == nil { + results[column] = result + } else if column == "*" { + notRestricted = result + for _, dbName := range stmt.Schema.DBNames { + results[dbName] = result + } + } else if column == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = result + } + } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { + results[field.DBName] = result + } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") { + if matches[2] == "*" { + for _, dbName := range stmt.Schema.DBNames { + results[dbName] = result + } + } else { + results[matches[2]] = result + } + } else { + results[column] = result + } + } + + // select columns + for _, column := range stmt.Selects { + processColumn(column, true) + } + + // omit columns + for _, column := range stmt.Omits { + processColumn(column, false) + } + + if stmt.Schema != nil { + for _, field := range stmt.Schema.FieldsByName { + name := field.DBName + if name == "" { + name = field.Name + } + + if requireCreate && !field.Creatable { + results[name] = false + } else if requireUpdate && !field.Updatable { + results[name] = false + } + } + } + + return results, !notRestricted && len(stmt.Selects) > 0 +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/utils/utils.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/utils/utils.go new file mode 100644 index 000000000..ddbca60a8 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/utils/utils.go @@ -0,0 +1,150 @@ +package utils + +import ( + "database/sql/driver" + "fmt" + "path/filepath" + "reflect" + "runtime" + "strconv" + "strings" + "unicode" +) + +var gormSourceDir string + +func init() { + _, file, _, _ := runtime.Caller(0) + // compatible solution to get gorm source directory with various operating systems + gormSourceDir = sourceDir(file) +} + +func sourceDir(file string) string { + dir := filepath.Dir(file) + dir = filepath.Dir(dir) + + s := filepath.Dir(dir) + if filepath.Base(s) != "gorm.io" { + s = dir + } + return filepath.ToSlash(s) + "/" +} + +// FileWithLineNum return the file name and line number of the current file +func FileWithLineNum() string { + // the second caller usually from gorm internal, so set i start from 2 + for i := 2; i < 15; i++ { + _, file, line, ok := runtime.Caller(i) + if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) { + return file + ":" + strconv.FormatInt(int64(line), 10) + } + } + + return "" +} + +func IsValidDBNameChar(c rune) bool { + return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' 
&& c != '*' && c != '_' && c != '$' && c != '@' +} + +// CheckTruth check string true or not +func CheckTruth(vals ...string) bool { + for _, val := range vals { + if val != "" && !strings.EqualFold(val, "false") { + return true + } + } + return false +} + +func ToStringKey(values ...interface{}) string { + results := make([]string, len(values)) + + for idx, value := range values { + if valuer, ok := value.(driver.Valuer); ok { + value, _ = valuer.Value() + } + + switch v := value.(type) { + case string: + results[idx] = v + case []byte: + results[idx] = string(v) + case uint: + results[idx] = strconv.FormatUint(uint64(v), 10) + default: + results[idx] = fmt.Sprint(reflect.Indirect(reflect.ValueOf(v)).Interface()) + } + } + + return strings.Join(results, "_") +} + +func Contains(elems []string, elem string) bool { + for _, e := range elems { + if elem == e { + return true + } + } + return false +} + +func AssertEqual(src, dst interface{}) bool { + if !reflect.DeepEqual(src, dst) { + if valuer, ok := src.(driver.Valuer); ok { + src, _ = valuer.Value() + } + + if valuer, ok := dst.(driver.Valuer); ok { + dst, _ = valuer.Value() + } + + return reflect.DeepEqual(src, dst) + } + return true +} + +func ToString(value interface{}) string { + switch v := value.(type) { + case string: + return v + case int: + return strconv.FormatInt(int64(v), 10) + case int8: + return strconv.FormatInt(int64(v), 10) + case int16: + return strconv.FormatInt(int64(v), 10) + case int32: + return strconv.FormatInt(int64(v), 10) + case int64: + return strconv.FormatInt(v, 10) + case uint: + return strconv.FormatUint(uint64(v), 10) + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint16: + return strconv.FormatUint(uint64(v), 10) + case uint32: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + } + return "" +} + +const nestedRelationSplit = "__" + +// NestedRelationName nested relationships like `Manager__Company` +func NestedRelationName(prefix, name string) string { + return prefix + nestedRelationSplit + name +} + +// SplitNestedRelationName Split nested relationships to `[]string{"Manager","Company"}` +func SplitNestedRelationName(name string) []string { + return strings.Split(name, nestedRelationSplit) +} + +// JoinNestedRelationNames nested relationships like `Manager__Company` +func JoinNestedRelationNames(relationNames []string) string { + return strings.Join(relationNames, nestedRelationSplit) +} From 2349bb765e75b88ec78865594c1cf6c97a82809a Mon Sep 17 00:00:00 2001 From: I539231 Date: Wed, 3 Apr 2024 23:21:25 +0200 Subject: [PATCH 2/2] feat: Update the vendors --- src/code.cloudfoundry.org/go.mod | 8 +- src/code.cloudfoundry.org/go.sum | 31 +- .../jackc/pgservicefile/.travis.yml | 9 - .../pgx/v5/internal/nbconn/bufferqueue.go | 70 --- .../jackc/pgx/v5/internal/nbconn/nbconn.go | 520 ------------------ .../internal/nbconn/nbconn_fake_non_block.go | 11 - .../internal/nbconn/nbconn_real_non_block.go | 81 --- .../vendor/gorm.io/driver/mysql/README.md | 1 + .../gorm.io/driver/mysql/error_translator.go | 12 +- .../vendor/gorm.io/driver/mysql/migrator.go | 148 ++++- .../vendor/gorm.io/driver/mysql/mysql.go | 11 +- .../driver/postgres/error_translator.go | 24 +- .../gorm.io/driver/postgres/migrator.go | 147 ++--- .../gorm.io/driver/postgres/postgres.go | 48 +- .../vendor/gorm.io/gorm/README.md | 2 +- .../vendor/gorm.io/gorm/callbacks.go | 19 + .../vendor/gorm.io/gorm/callbacks/create.go | 95 +++- .../vendor/gorm.io/gorm/callbacks/preload.go | 87 ++- 
.../vendor/gorm.io/gorm/callbacks/query.go | 35 +- .../vendor/gorm.io/gorm/callbacks/update.go | 4 +- .../vendor/gorm.io/gorm/chainable_api.go | 27 +- .../vendor/gorm.io/gorm/clause/expression.go | 2 +- .../vendor/gorm.io/gorm/clause/limit.go | 6 +- .../vendor/gorm.io/gorm/clause/locking.go | 7 + .../vendor/gorm.io/gorm/clause/where.go | 74 ++- .../vendor/gorm.io/gorm/finisher_api.go | 8 +- .../vendor/gorm.io/gorm/gorm.go | 11 +- .../vendor/gorm.io/gorm/interfaces.go | 6 - .../vendor/gorm.io/gorm/logger/logger.go | 22 +- .../vendor/gorm.io/gorm/logger/sql.go | 29 +- .../vendor/gorm.io/gorm/migrator.go | 2 + .../vendor/gorm.io/gorm/migrator/migrator.go | 181 +++--- .../vendor/gorm.io/gorm/prepare_stmt.go | 33 +- .../vendor/gorm.io/gorm/scan.go | 16 +- .../vendor/gorm.io/gorm/schema/check.go | 35 -- .../vendor/gorm.io/gorm/schema/constraint.go | 66 +++ .../vendor/gorm.io/gorm/schema/field.go | 16 +- .../vendor/gorm.io/gorm/schema/index.go | 6 +- .../vendor/gorm.io/gorm/schema/interfaces.go | 6 + .../vendor/gorm.io/gorm/schema/naming.go | 8 + .../gorm.io/gorm/schema/relationship.go | 93 +++- .../vendor/gorm.io/gorm/schema/schema.go | 63 ++- .../vendor/gorm.io/gorm/schema/serializer.go | 4 +- .../vendor/gorm.io/gorm/schema/utils.go | 5 + .../vendor/gorm.io/gorm/statement.go | 38 +- .../vendor/gorm.io/gorm/utils/utils.go | 38 +- src/code.cloudfoundry.org/vendor/modules.txt | 29 +- 47 files changed, 1060 insertions(+), 1134 deletions(-) delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jackc/pgservicefile/.travis.yml delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go delete mode 100644 src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go delete mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/check.go create mode 100644 src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/constraint.go diff --git a/src/code.cloudfoundry.org/go.mod b/src/code.cloudfoundry.org/go.mod index 05bbe158e..4e33071db 100644 --- a/src/code.cloudfoundry.org/go.mod +++ b/src/code.cloudfoundry.org/go.mod @@ -32,9 +32,8 @@ require ( github.com/codegangsta/cli v1.22.14 github.com/go-sql-driver/mysql v1.8.1 github.com/golang-jwt/jwt/v4 v4.5.0 - github.com/jinzhu/gorm v1.9.16 + github.com/jackc/pgx/v5 v5.5.5 github.com/kisielk/errcheck v1.7.0 - github.com/lib/pq v1.10.9 github.com/nats-io/nats-server/v2 v2.10.12 github.com/nats-io/nats.go v1.34.1 github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d @@ -55,6 +54,9 @@ require ( google.golang.org/grpc v1.62.1 google.golang.org/protobuf v1.33.0 gopkg.in/yaml.v2 v2.4.0 + gorm.io/driver/mysql v1.5.6 + gorm.io/driver/postgres v1.5.7 + gorm.io/gorm v1.25.9 ) require ( @@ -80,9 +82,9 @@ require ( github.com/honeycombio/libhoney-go v1.22.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect - github.com/jackc/pgx/v5 v5.5.5 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/klauspost/compress v1.17.7 // indirect github.com/minio/highwayhash v1.0.2 // indirect github.com/nats-io/jwt/v2 v2.5.5 // indirect diff --git a/src/code.cloudfoundry.org/go.sum 
b/src/code.cloudfoundry.org/go.sum index 88633429d..5651f3a31 100644 --- a/src/code.cloudfoundry.org/go.sum +++ b/src/code.cloudfoundry.org/go.sum @@ -635,13 +635,11 @@ github.com/DataDog/zstd v1.5.5 h1:oWf5W7GtOLgp6bciQYDmhHHjdhYkALu6S/5Ni9ZgSvQ= github.com/DataDog/zstd v1.5.5/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= -github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= github.com/ajstarks/deck v0.0.0-20200831202436-30c9fc6549a9/go.mod h1:JynElWSGnm/4RlzPXRlREEwqTHAN3T56Bv2ITsFT3gY= github.com/ajstarks/deck/generate v0.0.0-20210309230005-c3f852c02e19/go.mod h1:T13YZdzov6OU0A1+RfKZiZN9ca6VeKdBdyDV+BY97Tk= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= -github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apache/arrow/go/v10 v10.0.1/go.mod h1:YvhnlEePVnBS4+0z3fhPfUy7W1Ikj0Ih0vcRo/gZ1M0= github.com/apache/arrow/go/v11 v11.0.0/go.mod h1:Eg5OsL5H+e299f7u5ssuXsuHQVEGC4xei5aX110hRiI= @@ -704,8 +702,6 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd h1:83Wprp6ROGeiHFAP8WJdI2RoxALQYgdllERc3N5N2DM= -github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -725,8 +721,6 @@ github.com/envoyproxy/protoc-gen-validate v0.6.7/go.mod h1:dyJXwwfPK2VSqiB9Klm1J github.com/envoyproxy/protoc-gen-validate v0.9.1/go.mod h1:OKNgG7TCp5pF4d6XftA0++PMirau2/yoOwVac3AbF2w= github.com/envoyproxy/protoc-gen-validate v0.10.0/go.mod h1:DRjgyB0I43LtJapqN6NiRwroiAU2PaFuvk/vjgh61ss= github.com/envoyproxy/protoc-gen-validate v0.10.1/go.mod h1:DRjgyB0I43LtJapqN6NiRwroiAU2PaFuvk/vjgh61ss= -github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y= -github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a h1:yDWHCSQ40h88yih2JAcL6Ls/kVkSE8GFACTGVnMPruw= github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a/go.mod h1:7Ga40egUymuWXxAe151lTNnCv97MddSOVsjpPPkityA= github.com/facebookgo/ensure v0.0.0-20200202191622-63f1cf65ac4c h1:8ISkoahWXwZR41ois5lSJBSVw4D0OV19Ht/JSTzvSv0= @@ -761,7 +755,7 @@ github.com/go-logr/logr v1.4.1 
h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-pdf/fpdf v0.5.0/go.mod h1:HzcnA+A23uwogo0tp9yU+l3V+KXhiESpt1PMayhOh5M= github.com/go-pdf/fpdf v0.6.0/go.mod h1:HzcnA+A23uwogo0tp9yU+l3V+KXhiESpt1PMayhOh5M= -github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= @@ -775,8 +769,6 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= -github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= @@ -899,12 +891,10 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= -github.com/jinzhu/gorm v1.9.16 h1:+IyIjPEABKRpsu/F8OvDPy9fyQlgsg2luMV2ZIH5i5o= -github.com/jinzhu/gorm v1.9.16/go.mod h1:G3LB3wezTOWM2ITLzPxEXgSkOXAntiLHS7UdBefADcs= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M= -github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= @@ -931,9 +921,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq 
v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/loggregator/go-bindata v0.0.0-20190422223605-5f11cfb2d7d9/go.mod h1:PvsJfK9t/8OdGvSanpYlwJ1EPoJ/hwT3c52txAzqooY= github.com/lyft/protoc-gen-star v0.6.0/go.mod h1:TGAoBVkt8w7MPG72TrKIu85MIdXwDuzJYeZuUPFPNwA= github.com/lyft/protoc-gen-star v0.6.1/go.mod h1:TGAoBVkt8w7MPG72TrKIu85MIdXwDuzJYeZuUPFPNwA= @@ -941,8 +928,6 @@ github.com/lyft/protoc-gen-star/v2 v2.0.1/go.mod h1:RcCdONR2ScXaYnQC5tUzxzlpA3WV github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus= -github.com/mattn/go-sqlite3 v1.14.14 h1:qZgc/Rwetq+MtyE18WhzjokPD93dNqLGNT3QJuLvBGw= github.com/mattn/go-sqlite3 v1.14.14/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY= github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE= @@ -1115,12 +1100,10 @@ go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8= go.uber.org/automaxprocs v1.5.3/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0= golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= @@ -1190,7 +1173,6 @@ golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1808,6 +1790,13 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= 
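The dependency swap recorded above replaces jinzhu/gorm and lib/pq with gorm.io/gorm v1.25.9 plus the mysql and pgx-backed postgres drivers. A minimal, illustrative sketch of how code in this module opens a handle after the migration — the DSN and model are placeholders, not taken from this repository:

package main

import (
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)

type App struct {
	ID   uint
	Name string
}

func main() {
	// Placeholder DSN; real connection settings come from component config.
	dsn := "host=localhost user=gorm dbname=gorm port=5432 sslmode=disable"
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	// AutoMigrate drives the gorm.io migrator code vendored in this patch.
	if err := db.AutoMigrate(&App{}); err != nil {
		panic(err)
	}
}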
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= +gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= +gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= +gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.9 h1:wct0gxZIELDk8+ZqF/MVnHLkA1rvYlBWUMv2EdsK1g8= +gorm.io/gorm v1.25.9/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/src/code.cloudfoundry.org/vendor/github.com/jackc/pgservicefile/.travis.yml b/src/code.cloudfoundry.org/vendor/github.com/jackc/pgservicefile/.travis.yml deleted file mode 100644 index e176228e8..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jackc/pgservicefile/.travis.yml +++ /dev/null @@ -1,9 +0,0 @@ -language: go - -go: - - 1.x - - tip - -matrix: - allow_failures: - - go: tip diff --git a/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go b/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go deleted file mode 100644 index 4bf25481c..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go +++ /dev/null @@ -1,70 +0,0 @@ -package nbconn - -import ( - "sync" -) - -const minBufferQueueLen = 8 - -type bufferQueue struct { - lock sync.Mutex - queue []*[]byte - r, w int -} - -func (bq *bufferQueue) pushBack(buf *[]byte) { - bq.lock.Lock() - defer bq.lock.Unlock() - - if bq.w >= len(bq.queue) { - bq.growQueue() - } - bq.queue[bq.w] = buf - bq.w++ -} - -func (bq *bufferQueue) pushFront(buf *[]byte) { - bq.lock.Lock() - defer bq.lock.Unlock() - - if bq.w >= len(bq.queue) { - bq.growQueue() - } - copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w]) - bq.queue[bq.r] = buf - bq.w++ -} - -func (bq *bufferQueue) popFront() *[]byte { - bq.lock.Lock() - defer bq.lock.Unlock() - - if bq.r == bq.w { - return nil - } - - buf := bq.queue[bq.r] - bq.queue[bq.r] = nil // Clear reference so it can be garbage collected. - bq.r++ - - if bq.r == bq.w { - bq.r = 0 - bq.w = 0 - if len(bq.queue) > minBufferQueueLen { - bq.queue = make([]*[]byte, minBufferQueueLen) - } - } - - return buf -} - -func (bq *bufferQueue) growQueue() { - desiredLen := (len(bq.queue) + 1) * 3 / 2 - if desiredLen < minBufferQueueLen { - desiredLen = minBufferQueueLen - } - - newQueue := make([]*[]byte, desiredLen) - copy(newQueue, bq.queue) - bq.queue = newQueue -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go b/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go deleted file mode 100644 index 7a38383f0..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go +++ /dev/null @@ -1,520 +0,0 @@ -// Package nbconn implements a non-blocking net.Conn wrapper. 
-// -// It is designed to solve three problems. -// -// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all -// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion. -// -// The second is the inability to use a write deadline with a TLS.Conn without killing the connection. -// -// The third is to efficiently check if a connection has been closed via a non-blocking read. -package nbconn - -import ( - "crypto/tls" - "errors" - "net" - "os" - "sync" - "sync/atomic" - "syscall" - "time" - - "github.com/jackc/pgx/v5/internal/iobufpool" -) - -var errClosed = errors.New("closed") -var ErrWouldBlock = new(wouldBlockError) - -const fakeNonblockingWriteWaitDuration = 100 * time.Millisecond -const minNonblockingReadWaitDuration = time.Microsecond -const maxNonblockingReadWaitDuration = 100 * time.Millisecond - -// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read -// mode. -var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC) - -// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to -// ignore all future calls. -var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC) - -// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error. -type wouldBlockError struct{} - -func (*wouldBlockError) Error() string { - return "would block" -} - -func (*wouldBlockError) Timeout() bool { return true } -func (*wouldBlockError) Temporary() bool { return true } - -// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to -// the underlying connection. -type Conn interface { - net.Conn - - // Flush flushes any buffered writes. - Flush() error - - // BufferReadUntilBlock reads and buffers any successfully read bytes until the read would block. - BufferReadUntilBlock() error -} - -// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn. -type NetConn struct { - // 64 bit fields accessed with atomics must be at beginning of struct to guarantee alignment for certain 32-bit - // architectures. See BUGS section of https://pkg.go.dev/sync/atomic and https://github.com/jackc/pgx/issues/1288 and - // https://github.com/jackc/pgx/issues/1307. Only access with atomics - closed int64 // 0 = not closed, 1 = closed - - conn net.Conn - rawConn syscall.RawConn - - readQueue bufferQueue - writeQueue bufferQueue - - readFlushLock sync.Mutex - // non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the - // callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations. - nonblockWriteFunc func(fd uintptr) (done bool) - nonblockWriteBuf []byte - nonblockWriteErr error - nonblockWriteN int - - // non-blocking reads with syscall.RawConn are done with a callback function. By using these fields instead of the - // callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations. 
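The commented field groups above exist to avoid closure allocations around syscall.RawConn callbacks. For reference, a self-contained sketch of that callback pattern in isolation — the names are illustrative and not part of the deleted code, and like the build-tagged file below it assumes a unix target:

package main

import "syscall"

func readNonblocking(rc syscall.RawConn, buf []byte) (int, error) {
	var n int
	var readErr error
	// rc.Read hands the callback the raw fd; returning true ends the call
	// after a single attempt instead of waiting for the fd to become ready.
	err := rc.Read(func(fd uintptr) bool {
		n, readErr = syscall.Read(int(fd), buf)
		return true
	})
	if err == nil {
		err = readErr
	}
	return n, err
}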
- nonblockReadFunc func(fd uintptr) (done bool) - nonblockReadBuf []byte - nonblockReadErr error - nonblockReadN int - - readDeadlineLock sync.Mutex - readDeadline time.Time - readNonblocking bool - fakeNonBlockingShortReadCount int - fakeNonblockingReadWaitDuration time.Duration - - writeDeadlineLock sync.Mutex - writeDeadline time.Time -} - -func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn { - nc := &NetConn{ - conn: conn, - fakeNonblockingReadWaitDuration: maxNonblockingReadWaitDuration, - } - - if !fakeNonBlockingIO { - if sc, ok := conn.(syscall.Conn); ok { - if rawConn, err := sc.SyscallConn(); err == nil { - nc.rawConn = rawConn - } - } - } - - return nc -} - -// Read implements io.Reader. -func (c *NetConn) Read(b []byte) (n int, err error) { - if c.isClosed() { - return 0, errClosed - } - - c.readFlushLock.Lock() - defer c.readFlushLock.Unlock() - - err = c.flush() - if err != nil { - return 0, err - } - - for n < len(b) { - buf := c.readQueue.popFront() - if buf == nil { - break - } - copiedN := copy(b[n:], *buf) - if copiedN < len(*buf) { - *buf = (*buf)[copiedN:] - c.readQueue.pushFront(buf) - } else { - iobufpool.Put(buf) - } - n += copiedN - } - - // If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to - // Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block. - if n > 0 { - return n, nil - } - - var readNonblocking bool - c.readDeadlineLock.Lock() - readNonblocking = c.readNonblocking - c.readDeadlineLock.Unlock() - - var readN int - if readNonblocking { - readN, err = c.nonblockingRead(b[n:]) - } else { - readN, err = c.conn.Read(b[n:]) - } - n += readN - return n, err -} - -// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is -// closed. Call Flush to actually write to the underlying connection. -func (c *NetConn) Write(b []byte) (n int, err error) { - if c.isClosed() { - return 0, errClosed - } - - buf := iobufpool.Get(len(b)) - copy(*buf, b) - c.writeQueue.pushBack(buf) - return len(b), nil -} - -func (c *NetConn) Close() (err error) { - swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1) - if !swapped { - return errClosed - } - - defer func() { - closeErr := c.conn.Close() - if err == nil { - err = closeErr - } - }() - - c.readFlushLock.Lock() - defer c.readFlushLock.Unlock() - err = c.flush() - if err != nil { - return err - } - - return nil -} - -func (c *NetConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *NetConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t). -func (c *NetConn) SetDeadline(t time.Time) error { - err := c.SetReadDeadline(t) - if err != nil { - return err - } - return c.SetWriteDeadline(t) -} - -// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking. 
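NonBlockingDeadline and ErrWouldBlock together let callers toggle non-blocking reads through the plain net.Conn surface. Illustrative usage against this (now deleted) API, assuming an established *NetConn named conn and the usual errors/time imports:

conn.SetReadDeadline(nbconn.NonBlockingDeadline) // subsequent reads no longer block
n, err := conn.Read(buf)
if errors.Is(err, nbconn.ErrWouldBlock) {
	// nothing buffered yet; do other work and try again later
}
conn.SetReadDeadline(time.Time{}) // restore ordinary blocking reads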
-func (c *NetConn) SetReadDeadline(t time.Time) error { - if c.isClosed() { - return errClosed - } - - c.readDeadlineLock.Lock() - defer c.readDeadlineLock.Unlock() - if c.readDeadline == disableSetDeadlineDeadline { - return nil - } - if t == disableSetDeadlineDeadline { - c.readDeadline = t - return nil - } - - if t == NonBlockingDeadline { - c.readNonblocking = true - t = time.Time{} - } else { - c.readNonblocking = false - } - - c.readDeadline = t - - return c.conn.SetReadDeadline(t) -} - -func (c *NetConn) SetWriteDeadline(t time.Time) error { - if c.isClosed() { - return errClosed - } - - c.writeDeadlineLock.Lock() - defer c.writeDeadlineLock.Unlock() - if c.writeDeadline == disableSetDeadlineDeadline { - return nil - } - if t == disableSetDeadlineDeadline { - c.writeDeadline = t - return nil - } - - c.writeDeadline = t - - return c.conn.SetWriteDeadline(t) -} - -func (c *NetConn) Flush() error { - if c.isClosed() { - return errClosed - } - - c.readFlushLock.Lock() - defer c.readFlushLock.Unlock() - return c.flush() -} - -// flush does the actual work of flushing the writeQueue. readFlushLock must already be held. -func (c *NetConn) flush() error { - var stopChan chan struct{} - var errChan chan error - - defer func() { - if stopChan != nil { - select { - case stopChan <- struct{}{}: - case <-errChan: - } - } - }() - - for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() { - remainingBuf := *buf - for len(remainingBuf) > 0 { - n, err := c.nonblockingWrite(remainingBuf) - remainingBuf = remainingBuf[n:] - if err != nil { - if !errors.Is(err, ErrWouldBlock) { - *buf = (*buf)[:len(remainingBuf)] - copy(*buf, remainingBuf) - c.writeQueue.pushFront(buf) - return err - } - - // Writing was blocked. Reading might unblock it. - if stopChan == nil { - stopChan, errChan = c.bufferNonblockingRead() - } - - select { - case err := <-errChan: - stopChan = nil - return err - default: - } - - } - } - iobufpool.Put(buf) - } - - return nil -} - -func (c *NetConn) BufferReadUntilBlock() error { - for { - buf := iobufpool.Get(8 * 1024) - n, err := c.nonblockingRead(*buf) - if n > 0 { - *buf = (*buf)[:n] - c.readQueue.pushBack(buf) - } else if n == 0 { - iobufpool.Put(buf) - } - - if err != nil { - if errors.Is(err, ErrWouldBlock) { - return nil - } else { - return err - } - } - } -} - -func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) { - stopChan = make(chan struct{}) - errChan = make(chan error, 1) - - go func() { - for { - err := c.BufferReadUntilBlock() - if err != nil { - errChan <- err - return - } - - select { - case <-stopChan: - return - default: - } - } - }() - - return stopChan, errChan -} - -func (c *NetConn) isClosed() bool { - closed := atomic.LoadInt64(&c.closed) - return closed == 1 -} - -func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) { - if c.rawConn == nil { - return c.fakeNonblockingWrite(b) - } else { - return c.realNonblockingWrite(b) - } -} - -func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) { - c.writeDeadlineLock.Lock() - defer c.writeDeadlineLock.Unlock() - - deadline := time.Now().Add(fakeNonblockingWriteWaitDuration) - if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) { - err = c.conn.SetWriteDeadline(deadline) - if err != nil { - return 0, err - } - defer func() { - // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. 
- c.conn.SetWriteDeadline(c.writeDeadline) - - if err != nil { - if errors.Is(err, os.ErrDeadlineExceeded) { - err = ErrWouldBlock - } - } - }() - } - - return c.conn.Write(b) -} - -func (c *NetConn) nonblockingRead(b []byte) (n int, err error) { - if c.rawConn == nil { - return c.fakeNonblockingRead(b) - } else { - return c.realNonblockingRead(b) - } -} - -func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) { - c.readDeadlineLock.Lock() - defer c.readDeadlineLock.Unlock() - - // The first 5 reads only read 1 byte at a time. This should give us 4 chances to read when we are sure the bytes are - // already in Go or the OS's receive buffer. - if c.fakeNonBlockingShortReadCount < 5 && len(b) > 0 && c.fakeNonblockingReadWaitDuration < minNonblockingReadWaitDuration { - b = b[:1] - } - - startTime := time.Now() - deadline := startTime.Add(c.fakeNonblockingReadWaitDuration) - if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) { - err = c.conn.SetReadDeadline(deadline) - if err != nil { - return 0, err - } - defer func() { - // If the read was successful and the wait duration is not already the minimum - if err == nil && c.fakeNonblockingReadWaitDuration > minNonblockingReadWaitDuration { - endTime := time.Now() - - if n > 0 && c.fakeNonBlockingShortReadCount < 5 { - c.fakeNonBlockingShortReadCount++ - } - - // The wait duration should be 2x the fastest read that has occurred. This should give reasonable assurance that - // a Read deadline will not block a read before it has a chance to read data already in Go or the OS's receive - // buffer. - proposedWait := endTime.Sub(startTime) * 2 - if proposedWait < minNonblockingReadWaitDuration { - proposedWait = minNonblockingReadWaitDuration - } - if proposedWait < c.fakeNonblockingReadWaitDuration { - c.fakeNonblockingReadWaitDuration = proposedWait - } - } - - // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. - c.conn.SetReadDeadline(c.readDeadline) - - if err != nil { - if errors.Is(err, os.ErrDeadlineExceeded) { - err = ErrWouldBlock - } - } - }() - } - - return c.conn.Read(b) -} - -// syscall.Conn is interface - -// TLSClient establishes a TLS connection as a client over conn using config. -// -// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby -// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the -// *TLSConn is returned. -func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) { - tc := tls.Client(conn, config) - err := tc.Handshake() - if err != nil { - return nil, err - } - - // Ensure last written part of Handshake is actually sent. - err = conn.Flush() - if err != nil { - return nil, err - } - - return &TLSConn{ - tlsConn: tc, - nbConn: conn, - }, nil -} - -// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a -// tls.Conn. 
-type TLSConn struct { - tlsConn *tls.Conn - nbConn *NetConn -} - -func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) } -func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) } -func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() } -func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() } -func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() } -func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() } - -func (tc *TLSConn) Close() error { - // tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then - // sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our - // own 5 second deadline then make all set deadlines no-op. - tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5)) - tc.tlsConn.SetDeadline(disableSetDeadlineDeadline) - - return tc.tlsConn.Close() -} - -func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) } -func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) } -func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) } diff --git a/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go b/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go deleted file mode 100644 index 4915c6219..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build !unix - -package nbconn - -func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { - return c.fakeNonblockingWrite(b) -} - -func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { - return c.fakeNonblockingRead(b) -} diff --git a/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go b/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go deleted file mode 100644 index e93372f25..000000000 --- a/src/code.cloudfoundry.org/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go +++ /dev/null @@ -1,81 +0,0 @@ -//go:build unix - -package nbconn - -import ( - "errors" - "io" - "syscall" -) - -// realNonblockingWrite does a non-blocking write. readFlushLock must already be held. -func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { - if c.nonblockWriteFunc == nil { - c.nonblockWriteFunc = func(fd uintptr) (done bool) { - c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf) - return true - } - } - c.nonblockWriteBuf = b - c.nonblockWriteN = 0 - c.nonblockWriteErr = nil - - err = c.rawConn.Write(c.nonblockWriteFunc) - n = c.nonblockWriteN - c.nonblockWriteBuf = nil // ensure that no reference to b is kept. - if err == nil && c.nonblockWriteErr != nil { - if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) { - err = ErrWouldBlock - } else { - err = c.nonblockWriteErr - } - } - if err != nil { - // n may be -1 when an error occurs. 
- if n < 0 { - n = 0 - } - - return n, err - } - - return n, nil -} - -func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { - if c.nonblockReadFunc == nil { - c.nonblockReadFunc = func(fd uintptr) (done bool) { - c.nonblockReadN, c.nonblockReadErr = syscall.Read(int(fd), c.nonblockReadBuf) - return true - } - } - c.nonblockReadBuf = b - c.nonblockReadN = 0 - c.nonblockReadErr = nil - - err = c.rawConn.Read(c.nonblockReadFunc) - n = c.nonblockReadN - c.nonblockReadBuf = nil // ensure that no reference to b is kept. - if err == nil && c.nonblockReadErr != nil { - if errors.Is(c.nonblockReadErr, syscall.EWOULDBLOCK) { - err = ErrWouldBlock - } else { - err = c.nonblockReadErr - } - } - if err != nil { - // n may be -1 when an error occurs. - if n < 0 { - n = 0 - } - - return n, err - } - - // syscall read did not return an error and 0 bytes were read means EOF. - if n == 0 { - return 0, io.EOF - } - - return n, nil -} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/README.md b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/README.md index b8f7a6c97..9d75d5c12 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/README.md +++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/README.md @@ -40,6 +40,7 @@ db, err := gorm.Open(mysql.New(mysql.Config{ import ( _ "example.com/my_mysql_driver" "gorm.io/gorm" + "gorm.io/driver/mysql" ) db, err := gorm.Open(mysql.New(mysql.Config{ diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/error_translator.go b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/error_translator.go index 79f6646e5..44ab27753 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/error_translator.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/error_translator.go @@ -6,15 +6,19 @@ import ( "gorm.io/gorm" ) -var errCodes = map[string]uint16{ - "uniqueConstraint": 1062, +// The error codes to map mysql errors to gorm errors, here is the mysql error codes reference https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html. +var errCodes = map[uint16]error{ + 1062: gorm.ErrDuplicatedKey, + 1451: gorm.ErrForeignKeyViolated, + 1452: gorm.ErrForeignKeyViolated, } func (dialector Dialector) Translate(err error) error { if mysqlErr, ok := err.(*mysql.MySQLError); ok { - if mysqlErr.Number == errCodes["uniqueConstraint"] { - return gorm.ErrDuplicatedKey + if translatedErr, found := errCodes[mysqlErr.Number]; found { + return translatedErr } + return mysqlErr } return err diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/migrator.go b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/migrator.go index d35a86e14..fcdb63c5a 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/migrator.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/migrator.go @@ -47,6 +47,103 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) clause.Expr { return expr } +// MigrateColumnUnique migrate column's UNIQUE constraint. +// In MySQL, ColumnType's Unique is affected by UniqueIndex, so we have to take care of the UniqueIndex. 
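The Translate hook in error_translator.go above maps MySQL error codes onto gorm's exported sentinel errors (the postgres driver gains the matching table later in this patch), so caller code can branch portably across both databases. A hedged sketch of consuming it — the model is a placeholder, and translation only runs when the session is opened with gorm.Config{TranslateError: true}:

package main

import (
	"errors"

	"gorm.io/gorm"
)

type User struct {
	ID    uint
	Email string `gorm:"unique"`
}

func createUser(db *gorm.DB, u *User) error {
	if err := db.Create(u).Error; err != nil {
		// 1062 (MySQL) and 23505 (Postgres) both surface as ErrDuplicatedKey.
		if errors.Is(err, gorm.ErrDuplicatedKey) {
			return errors.New("user already exists")
		}
		return err
	}
	return nil
}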
+func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + unique, ok := columnType.Unique() + if !ok || field.PrimaryKey { + return nil // skip primary key + } + + queryTx, execTx := m.GetQueryAndExecTx() + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + // We're currently only receiving boolean values on `Unique` tag, + // so the UniqueConstraint name is fixed + constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName) + if unique { + // Clean up redundant unique indexes + indexes, _ := queryTx.Migrator().GetIndexes(value) + for _, index := range indexes { + if uni, ok := index.Unique(); !ok || !uni { + continue + } + if columns := index.Columns(); len(columns) != 1 || columns[0] != field.DBName { + continue + } + if name := index.Name(); name == constraint || name == field.UniqueIndex { + continue + } + if err := execTx.Migrator().DropIndex(value, index.Name()); err != nil { + return err + } + } + + hasConstraint := queryTx.Migrator().HasConstraint(value, constraint) + switch { + case field.Unique && !hasConstraint: + if field.Unique { + if err := execTx.Migrator().CreateConstraint(value, constraint); err != nil { + return err + } + } + // field isn't Unique but ColumnType's Unique is reported by UniqueConstraint. + case !field.Unique && hasConstraint: + if err := execTx.Migrator().DropConstraint(value, constraint); err != nil { + return err + } + if field.UniqueIndex != "" { + if err := execTx.Migrator().CreateIndex(value, field.UniqueIndex); err != nil { + return err + } + } + } + + if field.UniqueIndex != "" && !queryTx.Migrator().HasIndex(value, field.UniqueIndex) { + if err := execTx.Migrator().CreateIndex(value, field.UniqueIndex); err != nil { + return err + } + } + } else { + if field.Unique { + if err := execTx.Migrator().CreateConstraint(value, constraint); err != nil { + return err + } + } + if field.UniqueIndex != "" { + if err := execTx.Migrator().CreateIndex(value, field.UniqueIndex); err != nil { + return err + } + } + } + return nil + }) +} + +func (m Migrator) AddColumn(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + // avoid using the same name field + f := stmt.Schema.LookUpField(name) + if f == nil { + return fmt.Errorf("failed to look up field with name: %s", name) + } + + if !f.IgnoreMigration { + fieldType := m.FullDataTypeOf(f) + columnName := clause.Column{Name: f.DBName} + values := []interface{}{m.CurrentTable(stmt), columnName, fieldType} + var alterSql strings.Builder + alterSql.WriteString("ALTER TABLE ? ADD ? 
?") + if f.PrimaryKey || strings.Contains(strings.ToLower(fieldType.SQL), "auto_increment") { + alterSql.WriteString(", ADD PRIMARY KEY (?)") + values = append(values, columnName) + } + return m.DB.Exec(alterSql.String(), values...).Error + } + + return nil + }) +} + func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { @@ -83,12 +180,12 @@ func (m Migrator) TiDBVersion() (isTiDB bool, major, minor, patch int, err error } if minor, err = strconv.Atoi(realVersionArray[1]); err != nil { - err = fmt.Errorf("failed to parse the version of TiDB, the minor version is: %s", realVersionArray[0]) + err = fmt.Errorf("failed to parse the version of TiDB, the minor version is: %s", realVersionArray[1]) return } if patch, err = strconv.Atoi(realVersionArray[2]); err != nil { - err = fmt.Errorf("failed to parse the version of TiDB, the patch version is: %s", realVersionArray[0]) + err = fmt.Errorf("failed to parse the version of TiDB, the patch version is: %s", realVersionArray[2]) return } @@ -127,6 +224,29 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error }) } +func (m Migrator) DropConstraint(value interface{}, name string) error { + if !m.Dialector.Config.DontSupportDropConstraint { + return m.Migrator.DropConstraint(value, name) + } + + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) + if constraint != nil { + name = constraint.GetName() + switch constraint.(type) { + case *schema.Constraint: + return m.DB.Exec("ALTER TABLE ? DROP FOREIGN KEY ?", clause.Table{Name: table}, clause.Column{Name: name}).Error + case *schema.CheckConstraint: + return m.DB.Exec("ALTER TABLE ? DROP CHECK ?", clause.Table{Name: table}, clause.Column{Name: name}).Error + } + } + if m.HasIndex(value, name) { + return m.DB.Exec("ALTER TABLE ? DROP INDEX ?", clause.Table{Name: table}, clause.Column{Name: name}).Error + } + return nil + }) +} + func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { if !m.Dialector.DontSupportRenameIndex { return m.RunWithValue(value, func(stmt *gorm.Statement) error { @@ -184,22 +304,6 @@ func (m Migrator) DropTable(values ...interface{}) error { }) } -func (m Migrator) DropConstraint(value interface{}, name string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) - if chk != nil { - return m.DB.Exec("ALTER TABLE ? DROP CHECK ?", clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}).Error - } - if constraint != nil { - name = constraint.Name - } - - return m.DB.Exec( - "ALTER TABLE ? 
DROP FOREIGN KEY ?", clause.Table{Name: table}, clause.Column{Name: name}, - ).Error - }) -} - // ColumnTypes column types return columnTypes,error func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { columnTypes := make([]gorm.ColumnType, 0) @@ -268,7 +372,13 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true} } - column.DefaultValueValue.String = strings.Trim(column.DefaultValueValue.String, "'") + // only trim paired single-quotes + s := column.DefaultValueValue.String + for (len(s) >= 3 && s[0] == '\'' && s[len(s)-1] == '\'' && s[len(s)-2] != '\\') || + (len(s) == 2 && s == "''") { + s = s[1 : len(s)-1] + } + column.DefaultValueValue.String = s if m.Dialector.DontSupportNullAsDefaultValue { // rewrite mariadb default value like other version if column.DefaultValueValue.Valid && column.DefaultValueValue.String == "NULL" { diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/mysql.go b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/mysql.go index 68d02e857..bdef77bbf 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/mysql.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/mysql/mysql.go @@ -41,6 +41,10 @@ type Config struct { DontSupportForShareClause bool DontSupportNullAsDefaultValue bool DontSupportRenameColumnUnique bool + // As of MySQL 8.0.19, ALTER TABLE permits more general (and SQL standard) syntax + // for dropping and altering existing constraints of any type. + // see https://dev.mysql.com/doc/refman/8.0/en/alter-table.html + DontSupportDropConstraint bool } type Dialector struct { @@ -136,14 +140,17 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { dialector.Config.DontSupportRenameIndex = true dialector.Config.DontSupportRenameColumn = true dialector.Config.DontSupportForShareClause = true + dialector.Config.DontSupportDropConstraint = true } else if strings.HasPrefix(dialector.ServerVersion, "5.7.") { dialector.Config.DontSupportRenameColumn = true dialector.Config.DontSupportForShareClause = true + dialector.Config.DontSupportDropConstraint = true } else if strings.HasPrefix(dialector.ServerVersion, "5.") { dialector.Config.DisableDatetimePrecision = true dialector.Config.DontSupportRenameIndex = true dialector.Config.DontSupportRenameColumn = true dialector.Config.DontSupportForShareClause = true + dialector.Config.DontSupportDropConstraint = true } if strings.Contains(dialector.ServerVersion, "TiDB") { @@ -160,8 +167,6 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { } if !dialector.Config.DisableWithReturning && withReturning { - callbackConfig.LastInsertIDReversed = true - if !utils.Contains(callbackConfig.CreateClauses, "RETURNING") { callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING") } @@ -405,7 +410,7 @@ func (dialector Dialector) getSchemaStringType(field *schema.Field) string { } func (dialector Dialector) getSchemaTimeType(field *schema.Field) string { - if !dialector.DisableDatetimePrecision && field.Precision == 0 { + if !dialector.DisableDatetimePrecision && field.Precision == 0 && field.TagSettings["PRECISION"] == "" { field.Precision = *dialector.DefaultDatetimePrecision } diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/error_translator.go b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/error_translator.go index 285494c2d..9c0ef2534 100644 --- 
a/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/error_translator.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/error_translator.go @@ -2,26 +2,30 @@ package postgres import ( "encoding/json" - "github.com/jackc/pgx/v5/pgconn" + "gorm.io/gorm" + + "github.com/jackc/pgx/v5/pgconn" ) -var errCodes = map[string]string{ - "uniqueConstraint": "23505", +var errCodes = map[string]error{ + "23505": gorm.ErrDuplicatedKey, + "23503": gorm.ErrForeignKeyViolated, + "42703": gorm.ErrInvalidField, } type ErrMessage struct { - Code string `json:"Code"` - Severity string `json:"Severity"` - Message string `json:"Message"` + Code string + Severity string + Message string } // Translate it will translate the error to native gorm errors. // Since currently gorm supporting both pgx and pg drivers, only checking for pgx PgError types is not enough for translating errors, so we have additional error json marshal fallback. func (dialector Dialector) Translate(err error) error { if pgErr, ok := err.(*pgconn.PgError); ok { - if pgErr.Code == errCodes["uniqueConstraint"] { - return gorm.ErrDuplicatedKey + if translatedErr, found := errCodes[pgErr.Code]; found { + return translatedErr } return err } @@ -37,8 +41,8 @@ func (dialector Dialector) Translate(err error) error { return err } - if errMsg.Code == errCodes["uniqueConstraint"] { - return gorm.ErrDuplicatedKey + if translatedErr, found := errCodes[errMsg.Code]; found { + return translatedErr } return err } diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/migrator.go b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/migrator.go index e4d8e9260..6174e1c1b 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/migrator.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/migrator.go @@ -13,44 +13,61 @@ import ( "gorm.io/gorm/schema" ) +// See https://stackoverflow.com/questions/2204058/list-columns-with-indexes-in-postgresql +// Here are some changes: +// - use `LEFT JOIN` instead of `CROSS JOIN` +// - exclude indexes used to support constraints (they are auto-generated) const indexSql = ` -select - t.relname as table_name, - i.relname as index_name, - a.attname as column_name, - ix.indisunique as non_unique, - ix.indisprimary as primary -from - pg_class t, - pg_class i, - pg_index ix, - pg_attribute a -where - t.oid = ix.indrelid - and i.oid = ix.indexrelid - and a.attrelid = t.oid - and a.attnum = ANY(ix.indkey) - and t.relkind = 'r' - and t.relname = ? +SELECT + ct.relname AS table_name, + ci.relname AS index_name, + i.indisunique AS non_unique, + i.indisprimary AS primary, + a.attname AS column_name +FROM + pg_index i + LEFT JOIN pg_class ct ON ct.oid = i.indrelid + LEFT JOIN pg_class ci ON ci.oid = i.indexrelid + LEFT JOIN pg_attribute a ON a.attrelid = ct.oid + LEFT JOIN pg_constraint con ON con.conindid = i.indexrelid +WHERE + a.attnum = ANY(i.indkey) + AND con.oid IS NULL + AND ct.relkind = 'r' + AND ct.relname = ? 
` var typeAliasMap = map[string][]string{ - "int2": {"smallint"}, - "int4": {"integer"}, - "int8": {"bigint"}, - "smallint": {"int2"}, - "integer": {"int4"}, - "bigint": {"int8"}, - "decimal": {"numeric"}, - "numeric": {"decimal"}, + "int2": {"smallint"}, + "int4": {"integer"}, + "int8": {"bigint"}, + "smallint": {"int2"}, + "integer": {"int4"}, + "bigint": {"int8"}, + "decimal": {"numeric"}, + "numeric": {"decimal"}, + "timestamptz": {"timestamp with time zone"}, + "timestamp with time zone": {"timestamptz"}, + "bool": {"boolean"}, + "boolean": {"bool"}, } type Migrator struct { migrator.Migrator } +// select querys ignore dryrun +func (m Migrator) queryRaw(sql string, values ...interface{}) (tx *gorm.DB) { + queryTx := m.DB + if m.DB.DryRun { + queryTx = m.DB.Session(&gorm.Session{}) + queryTx.DryRun = false + } + return queryTx.Raw(sql, values...) +} + func (m Migrator) CurrentDatabase() (name string) { - m.DB.Raw("SELECT CURRENT_DATABASE()").Scan(&name) + m.queryRaw("SELECT CURRENT_DATABASE()").Scan(&name) return } @@ -82,7 +99,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { } } currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) - return m.DB.Raw( + return m.queryRaw( "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema, ).Scan(&count).Error }) @@ -150,7 +167,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error { func (m Migrator) GetTables() (tableList []string, err error) { currentSchema, _ := m.CurrentSchema(m.DB.Statement, "") - return tableList, m.DB.Raw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error + return tableList, m.queryRaw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error } func (m Migrator) CreateTable(values ...interface{}) (err error) { @@ -160,7 +177,8 @@ func (m Migrator) CreateTable(values ...interface{}) (err error) { for _, value := range m.ReorderModels(values, false) { if err = m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { - for _, field := range stmt.Schema.FieldsByDBName { + for _, fieldName := range stmt.Schema.DBNames { + field := stmt.Schema.FieldsByDBName[fieldName] if field.Comment != "" { if err := m.DB.Exec( "COMMENT ON COLUMN ?.? IS ?", @@ -183,7 +201,7 @@ func (m Migrator) HasTable(value interface{}) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) - return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error + return m.queryRaw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error }) return count > 0 } @@ -235,7 +253,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { } currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) - return m.DB.Raw( + return m.queryRaw( "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? 
AND column_name = ?", currentSchema, curTable, name, ).Scan(&count).Error @@ -260,7 +278,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) " checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = " checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))" - m.DB.Raw(checkSQL, values...).Scan(&description) + m.queryRaw(checkSQL, values...).Scan(&description) comment := strings.Trim(field.Comment, "'") comment = strings.Trim(comment, `"`) @@ -326,8 +344,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { return err } } else { - if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?"+m.genUsingExpression(fileType.SQL, fieldColumnType.DatabaseTypeName()), - m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, clause.Column{Name: field.DBName}, fileType).Error; err != nil { + if err := m.modifyColumn(stmt, field, fileType, fieldColumnType); err != nil { return err } } @@ -345,16 +362,6 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { } } - if uniq, _ := fieldColumnType.Unique(); !uniq && field.Unique { - idxName := clause.Column{Name: m.DB.Config.NamingStrategy.IndexName(stmt.Table, field.DBName)} - // Not a unique constraint but a unique index - if !m.HasIndex(stmt.Table, idxName.Name) { - if err := m.DB.Exec("ALTER TABLE ? ADD CONSTRAINT ? UNIQUE(?)", m.CurrentTable(stmt), idxName, clause.Column{Name: field.DBName}).Error; err != nil { - return err - } - } - } - if v, ok := fieldColumnType.DefaultValue(); (field.DefaultValueInterface == nil && ok) || v != field.DefaultValue { if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { if field.DefaultValueInterface != nil { @@ -387,28 +394,39 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { return nil } -func (m Migrator) genUsingExpression(targetType, sourceType string) string { - if targetType == "boolean" { - switch sourceType { +func (m Migrator) modifyColumn(stmt *gorm.Statement, field *schema.Field, targetType clause.Expr, existingColumn *migrator.ColumnType) error { + alterSQL := "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::?" + isUncastableDefaultValue := false + + if targetType.SQL == "boolean" { + switch existingColumn.DatabaseTypeName() { case "int2", "int8", "numeric": - return " USING ?::INT::?" + alterSQL = "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::int::?" + } + isUncastableDefaultValue = true + } + + if dv, _ := existingColumn.DefaultValue(); dv != "" && isUncastableDefaultValue { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { + return err } } - return " USING ?::?" 
+ if err := m.DB.Exec(alterSQL, m.CurrentTable(stmt), clause.Column{Name: field.DBName}, targetType, clause.Column{Name: field.DBName}, targetType).Error; err != nil { + return err + } + return nil } func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) - currentSchema, curTable := m.CurrentSchema(stmt, table) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { - name = constraint.Name - } else if chk != nil { - name = chk.Name + name = constraint.GetName() } + currentSchema, curTable := m.CurrentSchema(stmt, table) - return m.DB.Raw( + return m.queryRaw( "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?", currentSchema, curTable, name, ).Scan(&count).Error @@ -423,7 +441,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, var ( currentDatabase = m.DB.Migrator().CurrentDatabase() currentSchema, table = m.CurrentSchema(stmt, stmt.Table) - columns, err = m.DB.Raw( + columns, err = m.queryRaw( "SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description, c.identity_increment FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?", currentDatabase, currentSchema, table).Rows() ) @@ -463,7 +481,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, } if column.DefaultValueValue.Valid { - column.DefaultValueValue.String = regexp.MustCompile(`'?(.*)\b'?:+[\w\s]+$`).ReplaceAllString(column.DefaultValueValue.String, "$1") + column.DefaultValueValue.String = parseDefaultValueValue(column.DefaultValueValue.String) } if datetimePrecision.Valid { @@ -497,7 +515,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, // check primary, unique field { - columnTypeRows, err := m.DB.Raw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows() + columnTypeRows, err := m.queryRaw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? 
AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows() if err != nil { return err } @@ -509,7 +527,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, } columnTypeRows.Close() - columnTypeRows, err = m.DB.Raw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows() + columnTypeRows, err = m.queryRaw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows() if err != nil { return err } @@ -536,7 +554,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, // check column type { - dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type + dataTypeRows, err := m.queryRaw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.oid AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?) 
WHERE a.attnum > 0 -- hide internal columns AND NOT a.attisdropped -- hide deleted columns @@ -694,7 +712,7 @@ func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) { err := m.RunWithValue(value, func(stmt *gorm.Statement) error { result := make([]*Index, 0) - scanErr := m.DB.Raw(indexSql, stmt.Table).Scan(&result).Error + scanErr := m.queryRaw(indexSql, stmt.Table).Scan(&result).Error if scanErr != nil { return scanErr } @@ -769,3 +787,8 @@ func (m Migrator) RenameColumn(dst interface{}, oldName, field string) error { m.resetPreparedStmts() return nil } + +func parseDefaultValueValue(defaultValue string) string { + value := regexp.MustCompile(`^(.*?)(?:::.*)?$`).ReplaceAllString(defaultValue, "$1") + return strings.Trim(value, "'") +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/postgres.go b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/postgres.go index dbeabf561..e865b0f85 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/postgres.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/driver/postgres/postgres.go @@ -3,11 +3,11 @@ package postgres import ( "database/sql" "fmt" - "github.com/jackc/pgx/v5" "regexp" "strconv" "strings" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/stdlib" "gorm.io/gorm" "gorm.io/gorm/callbacks" @@ -24,11 +24,17 @@ type Dialector struct { type Config struct { DriverName string DSN string + WithoutQuotingCheck bool PreferSimpleProtocol bool WithoutReturning bool Conn gorm.ConnPool } +var ( + timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )") + defaultIdentifierLength = 63 //maximum identifier length for postgres +) + func Open(dsn string) gorm.Dialector { return &Dialector{&Config{DSN: dsn}} } @@ -41,12 +47,33 @@ func (dialector Dialector) Name() string { return "postgres" } -var timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )") +func (dialector Dialector) Apply(config *gorm.Config) error { + if config.NamingStrategy == nil { + config.NamingStrategy = schema.NamingStrategy{ + IdentifierMaxLength: defaultIdentifierLength, + } + return nil + } + + switch v := config.NamingStrategy.(type) { + case *schema.NamingStrategy: + if v.IdentifierMaxLength <= 0 { + v.IdentifierMaxLength = defaultIdentifierLength + } + case schema.NamingStrategy: + if v.IdentifierMaxLength <= 0 { + v.IdentifierMaxLength = defaultIdentifierLength + config.NamingStrategy = v + } + } + + return nil +} func (dialector Dialector) Initialize(db *gorm.DB) (err error) { callbackConfig := &callbacks.Config{ CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT"}, - UpdateClauses: []string{"UPDATE", "SET", "WHERE"}, + UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE"}, DeleteClauses: []string{"DELETE", "FROM", "WHERE"}, } // register callbacks @@ -94,10 +121,23 @@ func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { writer.WriteByte('$') - writer.WriteString(strconv.Itoa(len(stmt.Vars))) + index := 0 + varLen := len(stmt.Vars) + if varLen > 0 { + switch stmt.Vars[0].(type) { + case pgx.QueryExecMode: + index++ + } + } + writer.WriteString(strconv.Itoa(varLen - index)) } func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + if dialector.WithoutQuotingCheck { + writer.WriteString(str) + return + } + var ( underQuoted, selfQuoted bool continuousBacktick int8 diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/README.md 
b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/README.md index 85ad3050c..745dad60b 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/README.md +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/README.md @@ -41,4 +41,4 @@ The fantastic ORM library for Golang, aims to be developer friendly. © Jinzhu, 2013~time.Now -Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License) +Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE) diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks.go index 195d17203..50b5b0e93 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks.go @@ -187,10 +187,18 @@ func (p *processor) Replace(name string, fn func(*DB)) error { func (p *processor) compile() (err error) { var callbacks []*callback + removedMap := map[string]bool{} for _, callback := range p.callbacks { if callback.match == nil || callback.match(p.db) { callbacks = append(callbacks, callback) } + if callback.remove { + removedMap[callback.name] = true + } + } + + if len(removedMap) > 0 { + callbacks = removeCallbacks(callbacks, removedMap) } p.callbacks = callbacks @@ -339,3 +347,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { return } + +func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback { + callbacks := make([]*callback, 0, len(cs)) + for _, callback := range cs { + if nameMap[callback.name] { + continue + } + callbacks = append(callbacks, callback) + } + return callbacks +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/create.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/create.go index f0b781398..afea2ccac 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/create.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/create.go @@ -103,13 +103,62 @@ func Create(config *Config) func(db *gorm.DB) { } db.RowsAffected, _ = result.RowsAffected() - if db.RowsAffected != 0 && db.Statement.Schema != nil && - db.Statement.Schema.PrioritizedPrimaryField != nil && - db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - insertID, err := result.LastInsertId() - insertOk := err == nil && insertID > 0 - if !insertOk { + if db.RowsAffected == 0 { + return + } + + var ( + pkField *schema.Field + pkFieldName = "@id" + ) + + insertID, err := result.LastInsertId() + insertOk := err == nil && insertID > 0 + + if !insertOk { + if !supportReturning { db.AddError(err) + } + return + } + + if db.Statement.Schema != nil { + if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + return + } + pkField = db.Statement.Schema.PrioritizedPrimaryField + pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName + } + + // append @id column with value for auto-increment primary key + // the @id value is correct, when: 1. without setting auto-increment primary key, 2. 
database AutoIncrementIncrement = 1 + switch values := db.Statement.Dest.(type) { + case map[string]interface{}: + values[pkFieldName] = insertID + case *map[string]interface{}: + (*values)[pkFieldName] = insertID + case []map[string]interface{}, *[]map[string]interface{}: + mapValues, ok := values.([]map[string]interface{}) + if !ok { + if v, ok := values.(*[]map[string]interface{}); ok { + if *v != nil { + mapValues = *v + } + } + } + + if config.LastInsertIDReversed { + insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement + } + + for _, mapValue := range mapValues { + if mapValue != nil { + mapValue[pkFieldName] = insertID + } + insertID += schema.DefaultAutoIncrementIncrement + } + default: + if pkField == nil { return } @@ -122,10 +171,10 @@ func Create(config *Config) func(db *gorm.DB) { break } - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) + _, isZero := pkField.ValueOf(db.Statement.Context, rv) if isZero { - db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) - insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) + insertID -= pkField.AutoIncrementIncrement } } } else { @@ -135,16 +184,16 @@ func Create(config *Config) func(db *gorm.DB) { break } - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { - db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) - insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero { + db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) + insertID += pkField.AutoIncrementIncrement } } } case reflect.Struct: - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) + _, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) if isZero { - db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) + db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) } } } @@ -253,13 +302,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } } - for field, vs := range defaultValueFieldsHavingValue { - values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) - for idx := range values.Values { - if vs[idx] == nil { - values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) - } else { - values.Values[idx] = append(values.Values[idx], vs[idx]) + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if vs, ok := defaultValueFieldsHavingValue[field]; ok { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) + for idx := range values.Values { + if vs[idx] == nil { + values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) + } else { + values.Values[idx] = append(values.Values[idx], vs[idx]) + } } } } @@ -311,7 +362,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { case schema.UnixNanosecond: assignment.Value = curTime.UnixNano() case schema.UnixMillisecond: - assignment.Value = curTime.UnixNano() / 1e6 + assignment.Value = curTime.UnixMilli() case schema.UnixSecond: assignment.Value = curTime.Unix() } diff --git 
a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/preload.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/preload.go index 15669c847..cf7a0d2ba 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/preload.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/preload.go @@ -3,6 +3,7 @@ package callbacks import ( "fmt" "reflect" + "sort" "strings" "gorm.io/gorm" @@ -82,27 +83,93 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string { return names } -func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error { - if relationships == nil { - return nil +// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point. +// If the current relationship is embedded or joined, current query will be ignored. +// +//nolint:cyclop +func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error { + preloadMap := parsePreloadMap(db.Statement.Schema, preloads) + + // avoid random traversal of the map + preloadNames := make([]string, 0, len(preloadMap)) + for key := range preloadMap { + preloadNames = append(preloadNames, key) + } + sort.Strings(preloadNames) + + isJoined := func(name string) (joined bool, nestedJoins []string) { + for _, join := range joins { + if _, ok := relationships.Relations[join]; ok && name == join { + joined = true + continue + } + joinNames := strings.SplitN(join, ".", 2) + if len(joinNames) == 2 { + if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] { + joined = true + nestedJoins = append(nestedJoins, joinNames[1]) + } + } + } + return joined, nestedJoins } - preloadMap := parsePreloadMap(s, preloads) - for name := range preloadMap { - if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil { - if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil { + + for _, name := range preloadNames { + if relations := relationships.EmbeddedRelations[name]; relations != nil { + if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil { return err } } else if rel := relationships.Relations[name]; rel != nil { - if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil { - return err + if joined, nestedJoins := isJoined(name); joined { + switch rv := db.Statement.ReflectValue; rv.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i)) + tx := preloadDB(db, reflectValue, reflectValue.Interface()) + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + return err + } + } + case reflect.Struct: + reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv) + tx := preloadDB(db, reflectValue, reflectValue.Interface()) + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + return err + } + default: + return gorm.ErrInvalidData + } + } else { + tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) + tx.Statement.ReflectValue = db.Statement.ReflectValue + tx.Statement.Unscoped = db.Statement.Unscoped + if err := preload(tx, rel, 
append(preloads[name], associationsConds...), preloadMap[name]); err != nil { + return err + } } } else { - return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name) + return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name) } } return nil } +func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB { + tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) + db.Statement.Settings.Range(func(k, v interface{}) bool { + tx.Statement.Settings.Store(k, v) + return true + }) + + if err := tx.Statement.Parse(dest); err != nil { + tx.AddError(err) + return tx + } + tx.Statement.ReflectValue = reflectValue + tx.Statement.Unscoped = db.Statement.Unscoped + return tx +} + func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { var ( reflectValue = tx.Statement.ReflectValue diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/query.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/query.go index e89dd1996..2a82eaba1 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/query.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/query.go @@ -3,7 +3,6 @@ package callbacks import ( "fmt" "reflect" - "sort" "strings" "gorm.io/gorm" @@ -254,7 +253,6 @@ func BuildQuerySQL(db *gorm.DB) { } db.Statement.AddClause(fromClause) - db.Statement.Joins = nil } else { db.Statement.AddClauseIfNotExists(clause.From{}) } @@ -272,38 +270,23 @@ func Preload(db *gorm.DB) { return } - preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads) - preloadNames := make([]string, 0, len(preloadMap)) - for key := range preloadMap { - preloadNames = append(preloadNames, key) + joins := make([]string, 0, len(db.Statement.Joins)) + for _, join := range db.Statement.Joins { + joins = append(joins, join.Name) } - sort.Strings(preloadNames) - preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) - db.Statement.Settings.Range(func(k, v interface{}) bool { - preloadDB.Statement.Settings.Store(k, v) - return true - }) - - if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil { + tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest) + if tx.Error != nil { return } - preloadDB.Statement.ReflectValue = db.Statement.ReflectValue - preloadDB.Statement.Unscoped = db.Statement.Unscoped - - for _, name := range preloadNames { - if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil { - db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations])) - } else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { - db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) - } else { - db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) - } - } + + db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, 
db.Statement.Preloads[clause.Associations])) } } func AfterQuery(db *gorm.DB) { + // clear the joins after query because preload needs them + db.Statement.Joins = nil if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(AfterFindInterface); ok { diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/update.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/update.go index ff075dcf2..7cde7f619 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/update.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/callbacks/update.go @@ -234,7 +234,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime == schema.UnixNanosecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) } else if field.AutoUpdateTime == schema.UnixMillisecond { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()}) } else if field.AutoUpdateTime == schema.UnixSecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) } else { @@ -268,7 +268,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() } else if field.AutoUpdateTime == schema.UnixMillisecond { - value = stmt.DB.NowFunc().UnixNano() / 1e6 + value = stmt.DB.NowFunc().UnixMilli() } else if field.AutoUpdateTime == schema.UnixSecond { value = stmt.DB.NowFunc().Unix() } else { diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/chainable_api.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/chainable_api.go index 3dc7256e6..1ec9b865f 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/chainable_api.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/chainable_api.go @@ -367,33 +367,12 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { } func (db *DB) executeScopes() (tx *DB) { - tx = db.getInstance() scopes := db.Statement.scopes - if len(scopes) == 0 { - return tx - } - tx.Statement.scopes = nil - - conditions := make([]clause.Interface, 0, 4) - if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { - conditions = append(conditions, cs.Expression.(clause.Interface)) - cs.Expression = nil - tx.Statement.Clauses["WHERE"] = cs - } - + db.Statement.scopes = nil for _, scope := range scopes { - tx = scope(tx) - if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { - conditions = append(conditions, cs.Expression.(clause.Interface)) - cs.Expression = nil - tx.Statement.Clauses["WHERE"] = cs - } - } - - for _, condition := range conditions { - tx.Statement.AddClause(condition) + db = scope(db) } - return tx + return db } // Preload preload associations with given conditions diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/expression.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/expression.go index 8d010522f..3140846ef 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/expression.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/expression.go @@ -126,7 +126,7 @@ func (expr NamedExpr) Build(builder Builder) { for _, v := range []byte(expr.SQL) { if v == '@' && !inName
{ inName = true - name = []byte{} + name = name[:0] } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' { if inName { if nv, ok := namedMap[string(name)]; ok { diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/limit.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/limit.go index abda00551..3edde4346 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/limit.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/limit.go @@ -1,7 +1,5 @@ package clause -import "strconv" - // Limit limit clause type Limit struct { Limit *int @@ -17,14 +15,14 @@ func (limit Limit) Name() string { func (limit Limit) Build(builder Builder) { if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteString("LIMIT ") - builder.WriteString(strconv.Itoa(*limit.Limit)) + builder.AddVar(builder, *limit.Limit) } if limit.Offset > 0 { if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteByte(' ') } builder.WriteString("OFFSET ") - builder.WriteString(strconv.Itoa(limit.Offset)) + builder.AddVar(builder, limit.Offset) } } diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/locking.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/locking.go index 290aac92b..2bc48ceb4 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/locking.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/locking.go @@ -1,5 +1,12 @@ package clause +const ( + LockingStrengthUpdate = "UPDATE" + LockingStrengthShare = "SHARE" + LockingOptionsSkipLocked = "SKIP LOCKED" + LockingOptionsNoWait = "NOWAIT" +) + type Locking struct { Strength string Table Table diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/where.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/where.go index a29401cfe..9ac78578e 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/where.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/clause/where.go @@ -21,6 +21,12 @@ func (where Where) Name() string { // Build build where clause func (where Where) Build(builder Builder) { + if len(where.Exprs) == 1 { + if andCondition, ok := where.Exprs[0].(AndConditions); ok { + where.Exprs = andCondition.Exprs + } + } + // Switch position if the first query expression is a single Or condition for idx, expr := range where.Exprs { if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 { @@ -147,6 +153,11 @@ func Not(exprs ...Expression) Expression { if len(exprs) == 0 { return nil } + if len(exprs) == 1 { + if andCondition, ok := exprs[0].(AndConditions); ok { + exprs = andCondition.Exprs + } + } return NotConditions{Exprs: exprs} } @@ -155,19 +166,58 @@ type NotConditions struct { } func (not NotConditions) Build(builder Builder) { - if len(not.Exprs) > 1 { - builder.WriteByte('(') + anyNegationBuilder := false + for _, c := range not.Exprs { + if _, ok := c.(NegationExpressionBuilder); ok { + anyNegationBuilder = true + break + } } - for idx, c := range not.Exprs { - if idx > 0 { - builder.WriteString(AndWithSpace) + if anyNegationBuilder { + if len(not.Exprs) > 1 { + builder.WriteByte('(') } - if negationBuilder, ok := c.(NegationExpressionBuilder); ok { - negationBuilder.NegationBuild(builder) - } else { - builder.WriteString("NOT ") + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(AndWithSpace) + } + + if negationBuilder, ok := c.(NegationExpressionBuilder); ok { + negationBuilder.NegationBuild(builder) + } else { + builder.WriteString("NOT ") + e, wrapInParentheses := 
c.(Expr) + if wrapInParentheses { + sql := strings.ToUpper(e.SQL) + if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses { + builder.WriteByte('(') + } + } + + c.Build(builder) + + if wrapInParentheses { + builder.WriteByte(')') + } + } + } + + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } + } else { + builder.WriteString("NOT ") + if len(not.Exprs) > 1 { + builder.WriteByte('(') + } + + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(AndWithSpace) + } + e, wrapInParentheses := c.(Expr) if wrapInParentheses { sql := strings.ToUpper(e.SQL) @@ -182,9 +232,9 @@ func (not NotConditions) Build(builder Builder) { builder.WriteByte(')') } } - } - if len(not.Exprs) > 1 { - builder.WriteByte(')') + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } } } diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/finisher_api.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/finisher_api.go index f80aa6c04..f97571ed0 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/finisher_api.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/finisher_api.go @@ -376,8 +376,12 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { } else if len(db.Statement.assigns) > 0 { exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) assigns := map[string]interface{}{} - for _, expr := range exprs { - if eq, ok := expr.(clause.Eq); ok { + for i := 0; i < len(exprs); i++ { + expr := exprs[i] + + if eq, ok := expr.(clause.AndConditions); ok { + exprs = append(exprs, eq.Exprs...) + } else if eq, ok := expr.(clause.Eq); ok { switch column := eq.Column.(type) { case string: assigns[column] = eq.Value diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/gorm.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/gorm.go index 203527af3..775cd3de3 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/gorm.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/gorm.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "reflect" "sort" "sync" "time" @@ -181,7 +182,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { err = config.Dialector.Initialize(db) if err != nil { - if db, err := db.DB(); err == nil { + if db, _ := db.DB(); db != nil { _ = db.Close() } } @@ -374,9 +375,11 @@ func (db *DB) AddError(err error) error { // DB returns `*sql.DB` func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool - - if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil { - return connector.GetDBConnWithContext(db) + if db.Statement != nil && db.Statement.ConnPool != nil { + connPool = db.Statement.ConnPool + } + if tx, ok := connPool.(*sql.Tx); ok && tx != nil { + return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil } if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/interfaces.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/interfaces.go index 1950d7400..3bcc3d570 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/interfaces.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/interfaces.go @@ -77,12 +77,6 @@ type GetDBConnector interface { GetDBConn() (*sql.DB, error) } -// GetDBConnectorWithContext represents SQL db connector which takes into -// account the current database context -type GetDBConnectorWithContext interface { - GetDBConnWithContext(db *DB) (*sql.DB, error) -} - // 
Rows rows interface type Rows interface { Columns() ([]string, error) diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/logger/logger.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/logger/logger.go index aa0060bc5..253f03252 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/logger/logger.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/logger/logger.go @@ -69,7 +69,7 @@ type Interface interface { } var ( - // Discard Discard logger will print any log to io.Discard + // Discard logger will print any log to io.Discard Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{}) // Default Default logger Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ @@ -78,7 +78,7 @@ var ( IgnoreRecordNotFoundError: false, Colorful: true, }) - // Recorder Recorder logger records running SQL into a recorder instance + // Recorder logger records running SQL into a recorder instance Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} ) @@ -129,28 +129,30 @@ func (l *logger) LogMode(level LogLevel) Interface { } // Info print info -func (l logger) Info(ctx context.Context, msg string, data ...interface{}) { +func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Info { l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Warn print warn messages -func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) { +func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Warn { l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Error print error messages -func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { +func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Error { l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) 
} } // Trace print sql message -func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { +// +//nolint:cyclop +func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { if l.LogLevel <= Silent { return } @@ -182,8 +184,8 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i } } -// Trace print sql message -func (l logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { +// ParamsFilter filter params +func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { if l.Config.ParameterizedQueries { return sql, nil } @@ -198,8 +200,8 @@ type traceRecorder struct { Err error } -// New new trace recorder -func (l traceRecorder) New() *traceRecorder { +// New trace recorder +func (l *traceRecorder) New() *traceRecorder { return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} } diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/logger/sql.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/logger/sql.go index 13e5d957d..ad4787956 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/logger/sql.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/logger/sql.go @@ -34,6 +34,19 @@ var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO // RegEx matches only numeric values var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) +func isNumeric(k reflect.Kind) bool { + switch k { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return true + case reflect.Float32, reflect.Float64: + return true + default: + return false + } +} + // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { var ( @@ -79,17 +92,17 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case reflect.Bool: vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) case reflect.String: - vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper default: if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { - vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper } else { vars[idx] = nullStr } } case []byte: if s := string(v); isPrintable(s) { - vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper } else { vars[idx] = escaper + "" + escaper } @@ -100,7 +113,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case float64: vars[idx] = strconv.FormatFloat(v, 'f', -1, 64) case string: - vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper default: rv := reflect.ValueOf(v) if v == nil || !rv.IsValid() 
|| rv.Kind() == reflect.Ptr && rv.IsNil() { @@ -110,6 +123,12 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a convertParams(v, idx) } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { convertParams(reflect.Indirect(rv).Interface(), idx) + } else if isNumeric(rv.Kind()) { + if rv.CanInt() || rv.CanUint() { + vars[idx] = fmt.Sprintf("%d", rv.Interface()) + } else { + vars[idx] = fmt.Sprintf("%.6f", rv.Interface()) + } } else { for _, t := range convertibleTypes { if rv.Type().ConvertibleTo(t) { @@ -117,7 +136,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a return } } - vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper } } } diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator.go index 0e01f567d..3d2b032b0 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator.go @@ -87,6 +87,8 @@ type Migrator interface { DropColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error + // MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn. + MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error HasColumn(dst interface{}, field string) bool RenameColumn(dst interface{}, oldName, field string) error ColumnTypes(dst interface{}) ([]ColumnType, error) diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/migrator.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/migrator.go index b15a43ef2..acce5df21 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/migrator.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/migrator/migrator.go @@ -7,6 +7,7 @@ import ( "fmt" "reflect" "regexp" + "strconv" "strings" "time" @@ -27,6 +28,8 @@ var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) // TODO:? Create const vars for raw sql queries ? 
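// Editor's note, not part of the upstream patch: the `var _ gorm.Migrator = (*Migrator)(nil)`
// line added just below (and the matching `var _ Namer = (*NamingStrategy)(nil)` in
// schema/naming.go later in this patch) is Go's standard compile-time interface assertion.
// It has no runtime cost and turns a missing method, such as the newly required
// MigrateColumnUnique, into a build error instead of a failure at the call site. A minimal
// sketch of the idiom, using made-up names:
//
//	type Pinger interface{ Ping() error }
//
//	type tcpPinger struct{}
//
//	func (tcpPinger) Ping() error { return nil }
//
//	// Compilation fails as soon as *tcpPinger stops satisfying Pinger.
//	var _ Pinger = (*tcpPinger)(nil)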
+var _ gorm.Migrator = (*Migrator)(nil) + // Migrator m struct type Migrator struct { Config @@ -91,10 +94,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { expr.SQL += " NOT NULL" } - if field.Unique { - expr.SQL += " UNIQUE" - } - if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { if field.DefaultValueInterface != nil { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} @@ -108,15 +107,20 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { return } +func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) { + queryTx = m.DB.Session(&gorm.Session{}) + execTx = queryTx + if m.DB.DryRun { + queryTx.DryRun = false + execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) + } + return queryTx, execTx +} + // AutoMigrate auto migrate values func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { - queryTx := m.DB.Session(&gorm.Session{}) - execTx := queryTx - if m.DB.DryRun { - queryTx.DryRun = false - execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) - } + queryTx, execTx := m.GetQueryAndExecTx() if !queryTx.Migrator().HasTable(value) { if err := execTx.Migrator().CreateTable(value); err != nil { return err @@ -217,7 +221,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { field := stmt.Schema.FieldsByDBName[dbName] if !field.IgnoreMigration { createTableSQL += "? ?" - hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") + hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY") values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) createTableSQL += "," } @@ -266,7 +270,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { } if constraint := rel.ParseConstraint(); constraint != nil { if constraint.Schema == stmt.Schema { - sql, vars := buildConstraint(constraint) + sql, vars := constraint.Build() createTableSQL += sql + "," values = append(values, vars...) } @@ -274,6 +278,11 @@ func (m Migrator) CreateTable(values ...interface{}) error { } } + for _, uni := range stmt.Schema.ParseUniqueConstraints() { + createTableSQL += "CONSTRAINT ? UNIQUE (?)," + values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)}) + } + for _, chk := range stmt.Schema.ParseCheckConstraints() { createTableSQL += "CONSTRAINT ? 
CHECK (?)," values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) @@ -437,6 +446,10 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error // MigrateColumn migrate column func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + if field.IgnoreMigration { + return nil + } + // found, smart migrate fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) @@ -496,14 +509,6 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } - // check unique - if unique, ok := columnType.Unique(); ok && unique != field.Unique { - // not primary key - if !field.PrimaryKey { - alterColumn = true - } - } - // check default value if !field.PrimaryKey { currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL")) @@ -514,12 +519,18 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } else if !dvNotNull && currentDefaultNotNull { // null -> default value alterColumn = true - } else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) || - (field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) { - // default value not equal - // not both null - if currentDefaultNotNull || dvNotNull { - alterColumn = true + } else if currentDefaultNotNull || dvNotNull { + switch field.GORMDataType { + case schema.Time: + if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) { + alterColumn = true + } + case schema.Bool: + v1, _ := strconv.ParseBool(dv) + v2, _ := strconv.ParseBool(field.DefaultValue) + alterColumn = v1 != v2 + default: + alterColumn = dv != field.DefaultValue } } } @@ -532,13 +543,39 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } - if alterColumn && !field.IgnoreMigration { - return m.DB.Migrator().AlterColumn(value, field.DBName) + if alterColumn { + if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil { + return err + } + } + + if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil { + return err } return nil } +func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + unique, ok := columnType.Unique() + if !ok || field.PrimaryKey { + return nil // skip primary key + } + // By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex. 
+ return m.RunWithValue(value, func(stmt *gorm.Statement) error { + // We're currently only receiving boolean values on `Unique` tag, + // so the UniqueConstraint name is fixed + constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName) + if unique && !field.Unique { + return m.DB.Migrator().DropConstraint(value, constraint) + } + if !unique && field.Unique { + return m.DB.Migrator().CreateConstraint(value, constraint) + } + return nil + }) +} + // ColumnTypes return columnTypes []gorm.ColumnType and execErr error func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { columnTypes := make([]gorm.ColumnType, 0) @@ -608,37 +645,36 @@ func (m Migrator) DropView(name string) error { return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error } -func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) { - sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" - if constraint.OnDelete != "" { - sql += " ON DELETE " + constraint.OnDelete - } - - if constraint.OnUpdate != "" { - sql += " ON UPDATE " + constraint.OnUpdate - } - - var foreignKeys, references []interface{} - for _, field := range constraint.ForeignKeys { - foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) - } - - for _, field := range constraint.References { - references = append(references, clause.Column{Name: field.DBName}) - } - results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) - return -} - // GuessConstraintAndTable guess statement's constraint and it's table based on name -func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { +// +// Deprecated: use GuessConstraintInterfaceAndTable instead. 
+func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) { + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) + switch c := constraint.(type) { + case *schema.Constraint: + return c, nil, table + case *schema.CheckConstraint: + return nil, c, table + default: + return nil, nil, table + } +} + +// GuessConstraintInterfaceAndTable guesses the statement's constraint and its table based on name +// nolint:cyclop +func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) { if stmt.Schema == nil { - return nil, nil, stmt.Table + return nil, stmt.Table } checkConstraints := stmt.Schema.ParseCheckConstraints() if chk, ok := checkConstraints[name]; ok { - return nil, &chk, stmt.Table + return &chk, stmt.Table + } + + uniqueConstraints := stmt.Schema.ParseUniqueConstraints() + if uni, ok := uniqueConstraints[name]; ok { + return &uni, stmt.Table } getTable := func(rel *schema.Relationship) string { @@ -653,7 +689,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { - return constraint, nil, getTable(rel) + return constraint, getTable(rel) } } @@ -661,40 +697,39 @@ for k := range checkConstraints { if checkConstraints[k].Field == field { v := checkConstraints[k] - return nil, &v, stmt.Table + return &v, stmt.Table + } + + for k := range uniqueConstraints { + if uniqueConstraints[k].Field == field { + v := uniqueConstraints[k] + return &v, stmt.Table } } for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field { - return constraint, nil, getTable(rel) + return constraint, getTable(rel) } } } - return nil, nil, stmt.Schema.Table + return nil, stmt.Schema.Table } // CreateConstraint create constraint func (m Migrator) CreateConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) - if chk != nil { - return m.DB.Exec( - "ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)", - m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, - ).Error - } - + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { vars := []interface{}{clause.Table{Name: table}} if stmt.TableExpr != nil { vars[0] = stmt.TableExpr } - sql, values := buildConstraint(constraint) + sql, values := constraint.Build() return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error } - return nil }) } @@ -702,11 +737,9 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { // DropConstraint drop constraint func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { - name = constraint.Name - } else if chk != nil { - name = chk.Name + name = constraint.GetName() } return m.DB.Exec("ALTER TABLE ?
DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error }) @@ -717,11 +750,9 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { - name = constraint.Name - } else if chk != nil { - name = chk.Name + name = constraint.GetName() } return m.DB.Raw( diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/prepare_stmt.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/prepare_stmt.go index 9d98c86e0..c60b5db79 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/prepare_stmt.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/prepare_stmt.go @@ -3,6 +3,8 @@ package gorm import ( "context" "database/sql" + "database/sql/driver" + "errors" "reflect" "sync" ) @@ -30,15 +32,11 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { } } -func (db *PreparedStmtDB) GetDBConnWithContext(gormdb *DB) (*sql.DB, error) { +func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { if sqldb, ok := db.ConnPool.(*sql.DB); ok { return sqldb, nil } - if connector, ok := db.ConnPool.(GetDBConnectorWithContext); ok && connector != nil { - return connector.GetDBConnWithContext(gormdb) - } - if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { return dbConnector.GetDBConn() } @@ -131,6 +129,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err } + + beginner, ok := db.ConnPool.(ConnPoolBeginner) + if !ok { + return nil, ErrInvalidTransaction + } + + connPool, err := beginner.BeginTx(ctx, opt) + if err != nil { + return nil, err + } + if tx, ok := connPool.(Tx); ok { + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil + } return nil, ErrInvalidTransaction } @@ -138,7 +149,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { result, err = stmt.ExecContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { db.Mux.Lock() defer db.Mux.Unlock() go stmt.Close() @@ -152,7 +163,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { rows, err = stmt.QueryContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { db.Mux.Lock() defer db.Mux.Unlock() @@ -176,6 +187,10 @@ type PreparedStmtTX struct { PreparedStmtDB *PreparedStmtDB } +func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) { + return db.PreparedStmtDB.GetDBConn() +} + func (tx *PreparedStmtTX) Commit() error { if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { return tx.Tx.Commit() @@ -194,7 +209,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() @@ -209,7 +224,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . 
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/scan.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/scan.go index 736db4d3a..415b9f0d7 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/scan.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/scan.go @@ -274,12 +274,16 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if !update || reflectValue.Len() == 0 { update = false - // if the slice cap is externally initialized, the externally initialized slice is directly used here - if reflectValue.Cap() == 0 { - db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) - } else if !isArrayKind { - reflectValue.SetLen(0) - db.Statement.ReflectValue.Set(reflectValue) + if isArrayKind { + db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type())) + } else { + // if the slice cap is externally initialized, the externally initialized slice is directly used here + if reflectValue.Cap() == 0 { + db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + } else { + reflectValue.SetLen(0) + db.Statement.ReflectValue.Set(reflectValue) + } } } diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/check.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/check.go deleted file mode 100644 index 89e732d36..000000000 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/check.go +++ /dev/null @@ -1,35 +0,0 @@ -package schema - -import ( - "regexp" - "strings" -) - -// reg match english letters and midline -var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$") - -type Check struct { - Name string - Constraint string // length(phone) >= 10 - *Field -} - -// ParseCheckConstraints parse schema check constraints -func (schema *Schema) ParseCheckConstraints() map[string]Check { - checks := map[string]Check{} - for _, field := range schema.FieldsByDBName { - if chk := field.TagSettings["CHECK"]; chk != "" { - names := strings.Split(chk, ",") - if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) { - checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} - } else { - if names[0] == "" { - chk = strings.Join(names[1:], ",") - } - name := schema.namer.CheckerName(schema.Table, field.DBName) - checks[name] = Check{Name: name, Constraint: chk, Field: field} - } - } - } - return checks -} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/constraint.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/constraint.go new file mode 100644 index 000000000..80a743a83 --- /dev/null +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/constraint.go @@ -0,0 +1,66 @@ +package schema + +import ( + "regexp" + "strings" + + "gorm.io/gorm/clause" +) + +// reg match english letters and midline +var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`) + +type CheckConstraint struct { + Name string + Constraint string // length(phone) >= 10 + *Field +} + +func (chk *CheckConstraint) GetName() string { return chk.Name } + +func (chk *CheckConstraint) Build() (sql string, vars []interface{}) { + return "CONSTRAINT ? 
CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}} +} + +// ParseCheckConstraints parse schema check constraints +func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint { + checks := map[string]CheckConstraint{} + for _, field := range schema.FieldsByDBName { + if chk := field.TagSettings["CHECK"]; chk != "" { + names := strings.Split(chk, ",") + if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) { + checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} + } else { + if names[0] == "" { + chk = strings.Join(names[1:], ",") + } + name := schema.namer.CheckerName(schema.Table, field.DBName) + checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field} + } + } + } + return checks +} + +type UniqueConstraint struct { + Name string + Field *Field +} + +func (uni *UniqueConstraint) GetName() string { return uni.Name } + +func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) { + return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}} +} + +// ParseUniqueConstraints parse schema unique constraints +func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint { + uniques := make(map[string]UniqueConstraint) + for _, field := range schema.Fields { + if field.Unique { + name := schema.namer.UniqueName(schema.Table, field.DBName) + uniques[name] = UniqueConstraint{Name: name, Field: field} + } + } + return uniques +} diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/field.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/field.go index dd08e056b..ca2e11482 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/field.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/field.go @@ -49,6 +49,8 @@ const ( Bytes DataType = "bytes" ) +const DefaultAutoIncrementIncrement int64 = 1 + // Field is the representation of model schema's field type Field struct { Name string @@ -87,6 +89,12 @@ type Field struct { Set func(context.Context, reflect.Value, interface{}) error Serializer SerializerInterface NewValuePool FieldNewValuePool + + // In some db (e.g. MySQL), Unique and UniqueIndex are indistinguishable. + // When a column has a (not Mul) UniqueIndex, Migrator always reports its gorm.ColumnType is Unique. + // It causes field unnecessarily migration. + // Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique. 
+ UniqueIndex string } func (field *Field) BindName() string { @@ -119,7 +127,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), Unique: utils.CheckTruth(tagSetting["UNIQUE"]), Comment: tagSetting["COMMENT"], - AutoIncrementIncrement: 1, + AutoIncrementIncrement: DefaultAutoIncrementIncrement, } for field.IndirectFieldType.Kind() == reflect.Ptr { @@ -656,7 +664,7 @@ func (field *Field) setupValuerAndSetter() { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli()) } else { field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } @@ -665,7 +673,7 @@ func (field *Field) setupValuerAndSetter() { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli()) } else { field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } @@ -730,7 +738,7 @@ func (field *Field) setupValuerAndSetter() { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano())) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixMilli())) } else { field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix())) } diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/index.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/index.go index f5ac5dd21..f4f367510 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/index.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/index.go @@ -13,8 +13,8 @@ type Index struct { Type string // btree, hash, gist, spgist, gin, and brin Where string Comment string - Option string // WITH PARSER parser_name - Fields []IndexOption + Option string // WITH PARSER parser_name + Fields []IndexOption // Note: IndexOption's Field maybe the same } type IndexOption struct { @@ -67,7 +67,7 @@ func (schema *Schema) ParseIndexes() map[string]Index { } for _, index := range indexes { if index.Class == "UNIQUE" && len(index.Fields) == 1 { - index.Fields[0].Field.Unique = true + index.Fields[0].Field.UniqueIndex = index.Name } } return indexes diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/interfaces.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/interfaces.go index a75a33c0d..306d4f4e0 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/interfaces.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/interfaces.go @@ -4,6 +4,12 @@ import ( "gorm.io/gorm/clause" ) +// ConstraintInterface database constraint interface +type ConstraintInterface interface { + GetName() string + Build() (sql string, vars []interface{}) +} + // GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDataType() string diff --git 
a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/naming.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/naming.go index a2a0150a3..e6fb81b2b 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/naming.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/naming.go @@ -19,6 +19,7 @@ type Namer interface { RelationshipFKName(Relationship) string CheckerName(table, column string) string IndexName(table, column string) string + UniqueName(table, column string) string } // Replacer replacer interface like strings.Replacer @@ -26,6 +27,8 @@ type Replacer interface { Replace(name string) string } +var _ Namer = (*NamingStrategy)(nil) + // NamingStrategy tables, columns naming strategy type NamingStrategy struct { TablePrefix string @@ -85,6 +88,11 @@ func (ns NamingStrategy) IndexName(table, column string) string { return ns.formatName("idx", table, ns.toDBName(column)) } +// UniqueName generate unique constraint name +func (ns NamingStrategy) UniqueName(table, column string) string { + return ns.formatName("uni", table, ns.toDBName(column)) +} + func (ns NamingStrategy) formatName(prefix, table, name string) string { formattedName := strings.ReplaceAll(strings.Join([]string{ prefix, table, name, diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/relationship.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/relationship.go index e03dcc520..2e94fc2cb 100644 --- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/relationship.go +++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/relationship.go @@ -76,8 +76,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { return nil } - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { - schema.buildPolymorphicRelation(relation, field, polymorphic) + if hasPolymorphicRelation(field.TagSettings) { + schema.buildPolymorphicRelation(relation, field) } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { schema.buildMany2ManyRelation(relation, field, many2many) } else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" { @@ -89,7 +89,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { case reflect.Slice: schema.guessRelation(relation, field, guessHas) default: - schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name) + schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, + field.Name) } } @@ -124,6 +125,20 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { return relation } +// hasPolymorphicRelation reports whether the tag settings define a polymorphic relation, via either: +// 1. `POLYMORPHIC` tag +// 2.
`POLYMORPHICTYPE` and `POLYMORPHICID` tag +func hasPolymorphicRelation(tagSettings map[string]string) bool { + if _, ok := tagSettings["POLYMORPHIC"]; ok { + return true + } + + _, hasType := tagSettings["POLYMORPHICTYPE"] + _, hasId := tagSettings["POLYMORPHICID"] + + return hasType && hasId +} + func (schema *Schema) setRelation(relation *Relationship) { // set non-embedded relation if rel := schema.Relationships.Relations[relation.Name]; rel != nil { @@ -169,23 +184,41 @@ func (schema *Schema) setRelation(relation *Relationship) { // OwnerID int // OwnerType string // } -func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) { +func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) { + polymorphic := field.TagSettings["POLYMORPHIC"] + relation.Polymorphic = &Polymorphic{ - Value: schema.Table, - PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"], - PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"], + Value: schema.Table, + } + + var ( + typeName = polymorphic + "Type" + typeId = polymorphic + "ID" + ) + + if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok { + typeName = strings.TrimSpace(value) + } + + if value, ok := field.TagSettings["POLYMORPHICID"]; ok { + typeId = strings.TrimSpace(value) } + relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName] + relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId] + if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok { relation.Polymorphic.Value = strings.TrimSpace(value) } if relation.Polymorphic.PolymorphicType == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", + relation.FieldSchema, schema, field.Name, polymorphic+"Type") } if relation.Polymorphic.PolymorphicID == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", + relation.FieldSchema, schema, field.Name, polymorphic+"ID") } if schema.err == nil { @@ -197,12 +230,14 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi primaryKeyField := schema.PrioritizedPrimaryField if len(relation.foreignKeys) > 0 { if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { - schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name) + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, + schema, field.Name) } } if primaryKeyField == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name) + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", + relation.FieldSchema, schema, field.Name) return } @@ -317,7 +352,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Tag: `gorm:"-"`, }) - if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { + if 
relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, + schema.namer); err != nil { schema.err = err } relation.JoinTable.Name = many2many @@ -436,7 +472,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu schema.guessRelation(relation, field, guessEmbeddedHas) // case guessEmbeddedHas: default: - schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name) + schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", + schema, field.Name) } } @@ -492,7 +529,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu lookUpNames := []string{lookUpName} if len(primaryFields) == 1 { - lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) + lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", + strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, + strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) } for _, name := range lookUpNames { @@ -566,6 +605,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu } } +// Constraint is ForeignKey Constraint type Constraint struct { Name string Field *Field @@ -577,6 +617,31 @@ type Constraint struct { OnUpdate string } +func (constraint *Constraint) GetName() string { return constraint.Name } + +func (constraint *Constraint) Build() (sql string, vars []interface{}) { + sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" 
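Note on the relationship.go hunks above: parseRelation now routes through hasPolymorphicRelation, so a polymorphic association can be declared either with the single `polymorphic` prefix tag or with an explicit `polymorphicType`/`polymorphicId` pair (both must be present). A minimal sketch of the two spellings this code appears to accept; the model and field names are illustrative, not from this repo:

    package models

    // Prefix form: GORM derives the OwnerType / OwnerID fields on Toy
    // from the "Owner" prefix.
    type Dog struct {
        ID   uint
        Toys []Toy `gorm:"polymorphic:Owner"`
    }

    type Toy struct {
        ID        uint
        OwnerID   uint
        OwnerType string
    }

    // Explicit form newly recognized by hasPolymorphicRelation: both tags
    // are required, and they name the target fields directly.
    type Cat struct {
        ID    uint
        Items []Item `gorm:"polymorphicType:OwnerKind;polymorphicId:OwnerRef"`
    }

    type Item struct {
        ID        uint
        OwnerRef  uint
        OwnerKind string
    }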
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/schema.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/schema.go
index e13a5ed13..3e7459ce7 100644
--- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/schema.go
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/schema.go
@@ -13,6 +13,20 @@ import (
 	"gorm.io/gorm/logger"
 )
 
+type callbackType string
+
+const (
+	callbackTypeBeforeCreate callbackType = "BeforeCreate"
+	callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
+	callbackTypeAfterCreate  callbackType = "AfterCreate"
+	callbackTypeAfterUpdate  callbackType = "AfterUpdate"
+	callbackTypeBeforeSave   callbackType = "BeforeSave"
+	callbackTypeAfterSave    callbackType = "AfterSave"
+	callbackTypeBeforeDelete callbackType = "BeforeDelete"
+	callbackTypeAfterDelete  callbackType = "AfterDelete"
+	callbackTypeAfterFind    callbackType = "AfterFind"
+)
+
 // ErrUnsupportedDataType unsupported data type
 var ErrUnsupportedDataType = errors.New("unsupported data type")
 
@@ -288,14 +302,20 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
 		}
 	}
 
-	callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
-	for _, name := range callbacks {
-		if methodValue := modelValue.MethodByName(name); methodValue.IsValid() {
+	callbackTypes := []callbackType{
+		callbackTypeBeforeCreate, callbackTypeAfterCreate,
+		callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
+		callbackTypeBeforeSave, callbackTypeAfterSave,
+		callbackTypeBeforeDelete, callbackTypeAfterDelete,
+		callbackTypeAfterFind,
+	}
+	for _, cbName := range callbackTypes {
+		if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
 			switch methodValue.Type().String() {
 			case "func(*gorm.DB) error": // TODO hack
-				reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
+				reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
 			default:
-				logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name)
+				logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
 			}
 		}
 	}
@@ -349,6 +369,39 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
 	return schema, schema.err
 }
 
+// This unrolling is needed to show to the compiler the exact set of methods
+// that can be used on the modelType.
+// Prior to go1.22 any use of MethodByName would cause the linker to
+// abandon dead code elimination for the entire binary.
+// As of go1.22 the compiler supports one special case of a string constant
+// being passed to MethodByName. For enterprise customers or those building
+// large binaries, this gives a significant reduction in binary size.
+// https://github.com/golang/go/issues/62257
+func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
+	switch cbType {
+	case callbackTypeBeforeCreate:
+		return modelType.MethodByName(string(callbackTypeBeforeCreate))
+	case callbackTypeAfterCreate:
+		return modelType.MethodByName(string(callbackTypeAfterCreate))
+	case callbackTypeBeforeUpdate:
+		return modelType.MethodByName(string(callbackTypeBeforeUpdate))
+	case callbackTypeAfterUpdate:
+		return modelType.MethodByName(string(callbackTypeAfterUpdate))
+	case callbackTypeBeforeSave:
+		return modelType.MethodByName(string(callbackTypeBeforeSave))
+	case callbackTypeAfterSave:
+		return modelType.MethodByName(string(callbackTypeAfterSave))
+	case callbackTypeBeforeDelete:
+		return modelType.MethodByName(string(callbackTypeBeforeDelete))
+	case callbackTypeAfterDelete:
+		return modelType.MethodByName(string(callbackTypeAfterDelete))
+	case callbackTypeAfterFind:
+		return modelType.MethodByName(string(callbackTypeAfterFind))
+	default:
+		return reflect.ValueOf(nil)
+	}
+}
+
 func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
 	modelType := reflect.ValueOf(dest).Type()
 	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
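Note on the schema.go hunk above: callBackToMethodValue looks redundant, but the unrolling is deliberate, as its comment explains. A standalone sketch of the underlying linker behaviour, using a toy type that is not part of this patch:

    package main

    import (
        "fmt"
        "reflect"
    )

    type Model struct{}

    func (Model) BeforeCreate() error { return nil }

    func main() {
        v := reflect.ValueOf(Model{})

        // String-literal argument: since go1.22 the linker can still prove
        // which methods are reachable, so dead code elimination survives.
        m := v.MethodByName("BeforeCreate")
        fmt.Println(m.IsValid())

        // A computed name, e.g. v.MethodByName(prefix + "Create"), would
        // force the linker to keep every method in the binary, which is
        // exactly what the switch over callbackType constants avoids.
    }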
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/serializer.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/serializer.go
index 397edff03..f500521ef 100644
--- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/serializer.go
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/serializer.go
@@ -126,12 +126,12 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect
 	rv := reflect.ValueOf(fieldValue)
 	switch v := fieldValue.(type) {
 	case int64, int, uint, uint64, int32, uint32, int16, uint16:
-		result = time.Unix(reflect.Indirect(rv).Int(), 0)
+		result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
 	case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
 		if rv.IsZero() {
 			return nil, nil
 		}
-		result = time.Unix(reflect.Indirect(rv).Int(), 0)
+		result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
 	default:
 		err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
 	}
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/utils.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/utils.go
index 65d012e54..7fdda1855 100644
--- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/utils.go
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/schema/utils.go
@@ -115,6 +115,11 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
 		notZero, zero bool
 	)
 
+	if reflectValue.Kind() == reflect.Ptr ||
+		reflectValue.Kind() == reflect.Interface {
+		reflectValue = reflectValue.Elem()
+	}
+
 	switch reflectValue.Kind() {
 	case reflect.Struct:
 		results = [][]interface{}{make([]interface{}, len(fields))}
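Note on the serializer.go hunk above: time.Unix returns a time in the host's local zone, so appending .UTC() pins the deserialized value to a stable zone without changing the instant. An illustrative sketch; the epoch value is arbitrary:

    package main

    import (
        "fmt"
        "time"
    )

    func main() {
        const epoch = 1700000000
        local := time.Unix(epoch, 0)     // zone depends on the host's TZ setting
        utc := time.Unix(epoch, 0).UTC() // stable regardless of host TZ
        fmt.Println(local.Equal(utc))    // true: same instant, different zone
        fmt.Println(utc.Zone())          // UTC 0
    }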
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/statement.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/statement.go
index 59c0b772c..ae79aa321 100644
--- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/statement.go
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/statement.go
@@ -326,7 +326,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 	case *DB:
 		v.executeScopes()
 
-		if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
+		if cs, ok := v.Statement.Clauses["WHERE"]; ok {
 			if where, ok := cs.Expression.(clause.Where); ok {
 				if len(where.Exprs) == 1 {
 					if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
@@ -334,13 +334,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 				}
 			}
 			conds = append(conds, clause.And(where.Exprs...))
-			} else {
+			} else if cs.Expression != nil {
 				conds = append(conds, cs.Expression)
 			}
-			if v.Statement == stmt {
-				cs.Expression = nil
-				stmt.Statement.Clauses["WHERE"] = cs
-			}
 		}
 	case map[interface{}]interface{}:
 		for i, j := range v {
@@ -451,8 +447,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 
 				if len(values) > 0 {
 					conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
+					return []clause.Expression{clause.And(conds...)}
 				}
-				return conds
+				return nil
 			}
 		}
 
@@ -461,7 +458,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 		}
 	}
 
-	return conds
+	if len(conds) > 0 {
+		return []clause.Expression{clause.And(conds...)}
+	}
+	return nil
 }
 
 // Build build sql with clauses names
@@ -665,7 +665,21 @@ func (stmt *Statement) Changed(fields ...string) bool {
 	return false
 }
 
-var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`)
+var matchName = func() func(tableColumn string) (table, column string) {
+	nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
+	return func(tableColumn string) (table, column string) {
+		if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 {
+			table = matches[1]
+			star := matches[2]
+			columnName := matches[3]
+			if star != "" {
+				return table, star
+			}
+			return table, columnName
+		}
+		return "", ""
+	}
+}()
 
 // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
 func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
@@ -686,13 +700,13 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
 			}
 		} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
 			results[field.DBName] = result
-		} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") {
-			if matches[2] == "*" {
+		} else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
+			if col == "*" {
 				for _, dbName := range stmt.Schema.DBNames {
 					results[dbName] = result
 				}
 			} else {
-				results[matches[2]] = result
+				results[col] = result
 			}
 		} else {
 			results[column] = result
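Note on the statement.go hunks above: the old nameMatcher pattern could not match a bare or table-qualified `*` (its column group required a word character), so the `matches[2] == "*"` branch in SelectAndOmitColumns was effectively unreachable; the new matchName closure captures the star in its own group. A quick sketch of what the new pattern extracts; the inputs are illustrative:

    package main

    import (
        "fmt"
        "regexp"
    )

    // Same pattern as the vendored matchName above.
    var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)

    func main() {
        for _, s := range []string{"name", "`users`.`name`", "users.*", "*"} {
            m := nameMatcher.FindStringSubmatch(s)
            fmt.Printf("%-16q table=%q star=%q column=%q\n", s, m[1], m[2], m[3])
        }
    }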
diff --git a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/utils/utils.go b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/utils/utils.go
index ddbca60a8..347a331fb 100644
--- a/src/code.cloudfoundry.org/vendor/gorm.io/gorm/utils/utils.go
+++ b/src/code.cloudfoundry.org/vendor/gorm.io/gorm/utils/utils.go
@@ -35,7 +35,8 @@ func FileWithLineNum() string {
 	// the second caller usually from gorm internal, so set i start from 2
 	for i := 2; i < 15; i++ {
 		_, file, line, ok := runtime.Caller(i)
-		if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) {
+		if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) &&
+			!strings.HasSuffix(file, ".gen.go") {
 			return file + ":" + strconv.FormatInt(int64(line), 10)
 		}
 	}
@@ -73,7 +74,11 @@ func ToStringKey(values ...interface{}) string {
 		case uint:
 			results[idx] = strconv.FormatUint(uint64(v), 10)
 		default:
-			results[idx] = fmt.Sprint(reflect.Indirect(reflect.ValueOf(v)).Interface())
+			results[idx] = "nil"
+			vv := reflect.ValueOf(v)
+			if vv.IsValid() && !vv.IsZero() {
+				results[idx] = fmt.Sprint(reflect.Indirect(vv).Interface())
+			}
 		}
 	}
 
@@ -89,19 +94,28 @@ func Contains(elems []string, elem string) bool {
 	return false
 }
 
-func AssertEqual(src, dst interface{}) bool {
-	if !reflect.DeepEqual(src, dst) {
-		if valuer, ok := src.(driver.Valuer); ok {
-			src, _ = valuer.Value()
-		}
+func AssertEqual(x, y interface{}) bool {
+	if reflect.DeepEqual(x, y) {
+		return true
+	}
+	if x == nil || y == nil {
+		return false
+	}
 
-		if valuer, ok := dst.(driver.Valuer); ok {
-			dst, _ = valuer.Value()
-		}
+	xval := reflect.ValueOf(x)
+	yval := reflect.ValueOf(y)
+	if xval.Kind() == reflect.Ptr && xval.IsNil() ||
+		yval.Kind() == reflect.Ptr && yval.IsNil() {
+		return false
+	}
 
-		return reflect.DeepEqual(src, dst)
+	if valuer, ok := x.(driver.Valuer); ok {
+		x, _ = valuer.Value()
+	}
+	if valuer, ok := y.(driver.Valuer); ok {
+		y, _ = valuer.Value()
 	}
 
-	return true
+	return reflect.DeepEqual(x, y)
 }
 
 func ToString(value interface{}) string {
diff --git a/src/code.cloudfoundry.org/vendor/modules.txt b/src/code.cloudfoundry.org/vendor/modules.txt
index a58c6513f..7891d4f1d 100644
--- a/src/code.cloudfoundry.org/vendor/modules.txt
+++ b/src/code.cloudfoundry.org/vendor/modules.txt
@@ -205,14 +205,12 @@ github.com/jackc/pgx/v5/stdlib
 ## explicit; go 1.19
 github.com/jackc/puddle/v2
 github.com/jackc/puddle/v2/internal/genstack
-# github.com/jinzhu/gorm v1.9.16
-## explicit; go 1.12
-github.com/jinzhu/gorm
-github.com/jinzhu/gorm/dialects/mysql
-github.com/jinzhu/gorm/dialects/postgres
 # github.com/jinzhu/inflection v1.0.0
 ## explicit
 github.com/jinzhu/inflection
+# github.com/jinzhu/now v1.1.5
+## explicit; go 1.12
+github.com/jinzhu/now
 # github.com/kisielk/errcheck v1.7.0
 ## explicit; go 1.18
 github.com/kisielk/errcheck
@@ -229,12 +227,6 @@ github.com/klauspost/compress/internal/snapref
 github.com/klauspost/compress/s2
 github.com/klauspost/compress/zstd
 github.com/klauspost/compress/zstd/internal/xxhash
-# github.com/lib/pq v1.10.9
-## explicit; go 1.13
-github.com/lib/pq
-github.com/lib/pq/hstore
-github.com/lib/pq/oid
-github.com/lib/pq/scram
 # github.com/minio/highwayhash v1.0.2
 ## explicit; go 1.15
 github.com/minio/highwayhash
@@ -621,6 +613,21 @@ gopkg.in/yaml.v2
 # gopkg.in/yaml.v3 v3.0.1
 ## explicit
 gopkg.in/yaml.v3
+# gorm.io/driver/mysql v1.5.6
+## explicit; go 1.14
+gorm.io/driver/mysql
+# gorm.io/driver/postgres v1.5.7
+## explicit; go 1.18
+gorm.io/driver/postgres
+# gorm.io/gorm v1.25.9
+## explicit; go 1.18
+gorm.io/gorm
+gorm.io/gorm/callbacks
+gorm.io/gorm/clause
+gorm.io/gorm/logger
+gorm.io/gorm/migrator
+gorm.io/gorm/schema
+gorm.io/gorm/utils
 # github.com/uber-go/zap => github.com/uber-go/zap v0.0.0-20161222040304-a5783ee4b216
 # github.com/uber-go/atomic => github.com/uber-go/atomic v1.1.0
 # github.com/codegangsta/cli => github.com/codegangsta/cli v1.6.0
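Note on the modules.txt hunks above: this is the vendor-tree summary of the whole patch. github.com/jinzhu/gorm (GORM v1) and github.com/lib/pq drop out, while gorm.io/gorm v1.25.9, the gorm.io mysql/postgres drivers, and the transitive github.com/jinzhu/now come in. For calling code, the switch looks roughly like the sketch below; the DSN and variable names are placeholders, not taken from this patch:

    package main

    import (
        "gorm.io/driver/postgres"
        "gorm.io/gorm"
    )

    func main() {
        // GORM v1 style being removed:
        //   import "github.com/jinzhu/gorm"
        //   db, err := gorm.Open("postgres", dsn) // backed by lib/pq
        //
        // GORM v2 style this patch vendors, backed by pgx via gorm.io/driver/postgres:
        dsn := "host=localhost user=gorm dbname=gorm port=5432" // placeholder DSN
        db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
        if err != nil {
            panic(err)
        }
        _ = db
    }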