vendor papka bolmasa boljak dal
This commit is contained in:
parent
ded68e5b60
commit
9d352458d9
|
|
@ -1,5 +1,4 @@
|
|||
.env
|
||||
db_service.exe
|
||||
/vendor
|
||||
.vscode
|
||||
db_service
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
.DS_Store
|
||||
.DS_Store?
|
||||
._*
|
||||
.Spotlight-V100
|
||||
.Trashes
|
||||
Icon?
|
||||
ehthumbs.db
|
||||
Thumbs.db
|
||||
.idea
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
# This is the official list of Go-MySQL-Driver authors for copyright purposes.
|
||||
|
||||
# If you are submitting a patch, please add your name or the name of the
|
||||
# organization which holds the copyright to this list in alphabetical order.
|
||||
|
||||
# Names should be added to this file as
|
||||
# Name <email address>
|
||||
# The email address is not required for organizations.
|
||||
# Please keep the list sorted.
|
||||
|
||||
|
||||
# Individual Persons
|
||||
|
||||
Aaron Hopkins <go-sql-driver at die.net>
|
||||
Achille Roussel <achille.roussel at gmail.com>
|
||||
Alex Snast <alexsn at fb.com>
|
||||
Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
|
||||
Andrew Reid <andrew.reid at tixtrack.com>
|
||||
Animesh Ray <mail.rayanimesh at gmail.com>
|
||||
Arne Hormann <arnehormann at gmail.com>
|
||||
Ariel Mashraki <ariel at mashraki.co.il>
|
||||
Asta Xie <xiemengjun at gmail.com>
|
||||
Bulat Gaifullin <gaifullinbf at gmail.com>
|
||||
Caine Jette <jette at alum.mit.edu>
|
||||
Carlos Nieto <jose.carlos at menteslibres.net>
|
||||
Chris Moos <chris at tech9computers.com>
|
||||
Craig Wilson <craiggwilson at gmail.com>
|
||||
Daniel Montoya <dsmontoyam at gmail.com>
|
||||
Daniel Nichter <nil at codenode.com>
|
||||
Daniël van Eeden <git at myname.nl>
|
||||
Dave Protasowski <dprotaso at gmail.com>
|
||||
DisposaBoy <disposaboy at dby.me>
|
||||
Egor Smolyakov <egorsmkv at gmail.com>
|
||||
Erwan Martin <hello at erwan.io>
|
||||
Evan Shaw <evan at vendhq.com>
|
||||
Frederick Mayle <frederickmayle at gmail.com>
|
||||
Gustavo Kristic <gkristic at gmail.com>
|
||||
Hajime Nakagami <nakagami at gmail.com>
|
||||
Hanno Braun <mail at hannobraun.com>
|
||||
Henri Yandell <flamefew at gmail.com>
|
||||
Hirotaka Yamamoto <ymmt2005 at gmail.com>
|
||||
Huyiguang <hyg at webterren.com>
|
||||
ICHINOSE Shogo <shogo82148 at gmail.com>
|
||||
Ilia Cimpoes <ichimpoesh at gmail.com>
|
||||
INADA Naoki <songofacandy at gmail.com>
|
||||
Jacek Szwec <szwec.jacek at gmail.com>
|
||||
James Harr <james.harr at gmail.com>
|
||||
Jeff Hodges <jeff at somethingsimilar.com>
|
||||
Jeffrey Charles <jeffreycharles at gmail.com>
|
||||
Jerome Meyer <jxmeyer at gmail.com>
|
||||
Jiajia Zhong <zhong2plus at gmail.com>
|
||||
Jian Zhen <zhenjl at gmail.com>
|
||||
Joshua Prunier <joshua.prunier at gmail.com>
|
||||
Julien Lefevre <julien.lefevr at gmail.com>
|
||||
Julien Schmidt <go-sql-driver at julienschmidt.com>
|
||||
Justin Li <jli at j-li.net>
|
||||
Justin Nuß <nuss.justin at gmail.com>
|
||||
Kamil Dziedzic <kamil at klecza.pl>
|
||||
Kei Kamikawa <x00.x7f.x86 at gmail.com>
|
||||
Kevin Malachowski <kevin at chowski.com>
|
||||
Kieron Woodhouse <kieron.woodhouse at infosum.com>
|
||||
Lennart Rudolph <lrudolph at hmc.edu>
|
||||
Leonardo YongUk Kim <dalinaum at gmail.com>
|
||||
Linh Tran Tuan <linhduonggnu at gmail.com>
|
||||
Lion Yang <lion at aosc.xyz>
|
||||
Luca Looz <luca.looz92 at gmail.com>
|
||||
Lucas Liu <extrafliu at gmail.com>
|
||||
Luke Scott <luke at webconnex.com>
|
||||
Maciej Zimnoch <maciej.zimnoch at codilime.com>
|
||||
Michael Woolnough <michael.woolnough at gmail.com>
|
||||
Nathanial Murphy <nathanial.murphy at gmail.com>
|
||||
Nicola Peduzzi <thenikso at gmail.com>
|
||||
Olivier Mengué <dolmen at cpan.org>
|
||||
oscarzhao <oscarzhaosl at gmail.com>
|
||||
Paul Bonser <misterpib at gmail.com>
|
||||
Peter Schultz <peter.schultz at classmarkets.com>
|
||||
Rebecca Chin <rchin at pivotal.io>
|
||||
Reed Allman <rdallman10 at gmail.com>
|
||||
Richard Wilkes <wilkes at me.com>
|
||||
Robert Russell <robert at rrbrussell.com>
|
||||
Runrioter Wung <runrioter at gmail.com>
|
||||
Sho Iizuka <sho.i518 at gmail.com>
|
||||
Sho Ikeda <suicaicoca at gmail.com>
|
||||
Shuode Li <elemount at qq.com>
|
||||
Simon J Mudd <sjmudd at pobox.com>
|
||||
Soroush Pour <me at soroushjp.com>
|
||||
Stan Putrya <root.vagner at gmail.com>
|
||||
Stanley Gunawan <gunawan.stanley at gmail.com>
|
||||
Steven Hartland <steven.hartland at multiplay.co.uk>
|
||||
Tan Jinhua <312841925 at qq.com>
|
||||
Thomas Wodarek <wodarekwebpage at gmail.com>
|
||||
Tim Ruffles <timruffles at gmail.com>
|
||||
Tom Jenkinson <tom at tjenkinson.me>
|
||||
Vladimir Kovpak <cn007b at gmail.com>
|
||||
Vladyslav Zhelezniak <zhvladi at gmail.com>
|
||||
Xiangyu Hu <xiangyu.hu at outlook.com>
|
||||
Xiaobing Jiang <s7v7nislands at gmail.com>
|
||||
Xiuming Chen <cc at cxm.cc>
|
||||
Xuehong Chan <chanxuehong at gmail.com>
|
||||
Zhenye Xie <xiezhenye at gmail.com>
|
||||
Zhixin Wen <john.wenzhixin at gmail.com>
|
||||
|
||||
# Organizations
|
||||
|
||||
Barracuda Networks, Inc.
|
||||
Counting Ltd.
|
||||
DigitalOcean Inc.
|
||||
Facebook Inc.
|
||||
GitHub Inc.
|
||||
Google Inc.
|
||||
InfoSum Ltd.
|
||||
Keybase Inc.
|
||||
Multiplay Ltd.
|
||||
Percona LLC
|
||||
Pivotal Inc.
|
||||
Stripe Inc.
|
||||
Zendesk Inc.
|
||||
|
|
@ -0,0 +1,232 @@
|
|||
## Version 1.6 (2021-04-01)
|
||||
|
||||
Changes:
|
||||
|
||||
- Migrate the CI service from travis-ci to GitHub Actions (#1176, #1183, #1190)
|
||||
- `NullTime` is deprecated (#960, #1144)
|
||||
- Reduce allocations when building SET command (#1111)
|
||||
- Performance improvement for time formatting (#1118)
|
||||
- Performance improvement for time parsing (#1098, #1113)
|
||||
|
||||
New Features:
|
||||
|
||||
- Implement `driver.Validator` interface (#1106, #1174)
|
||||
- Support returning `uint64` from `Valuer` in `ConvertValue` (#1143)
|
||||
- Add `json.RawMessage` for converter and prepared statement (#1059)
|
||||
- Interpolate `json.RawMessage` as `string` (#1058)
|
||||
- Implements `CheckNamedValue` (#1090)
|
||||
|
||||
Bugfixes:
|
||||
|
||||
- Stop rounding times (#1121, #1172)
|
||||
- Put zero filler into the SSL handshake packet (#1066)
|
||||
- Fix checking cancelled connections back into the connection pool (#1095)
|
||||
- Fix remove last 0 byte for mysql_old_password when password is empty (#1133)
|
||||
|
||||
|
||||
## Version 1.5 (2020-01-07)
|
||||
|
||||
Changes:
|
||||
|
||||
- Dropped support Go 1.9 and lower (#823, #829, #886, #1016, #1017)
|
||||
- Improve buffer handling (#890)
|
||||
- Document potentially insecure TLS configs (#901)
|
||||
- Use a double-buffering scheme to prevent data races (#943)
|
||||
- Pass uint64 values without converting them to string (#838, #955)
|
||||
- Update collations and make utf8mb4 default (#877, #1054)
|
||||
- Make NullTime compatible with sql.NullTime in Go 1.13+ (#995)
|
||||
- Removed CloudSQL support (#993, #1007)
|
||||
- Add Go Module support (#1003)
|
||||
|
||||
New Features:
|
||||
|
||||
- Implement support of optional TLS (#900)
|
||||
- Check connection liveness (#934, #964, #997, #1048, #1051, #1052)
|
||||
- Implement Connector Interface (#941, #958, #1020, #1035)
|
||||
|
||||
Bugfixes:
|
||||
|
||||
- Mark connections as bad on error during ping (#875)
|
||||
- Mark connections as bad on error during dial (#867)
|
||||
- Fix connection leak caused by rapid context cancellation (#1024)
|
||||
- Mark connections as bad on error during Conn.Prepare (#1030)
|
||||
|
||||
|
||||
## Version 1.4.1 (2018-11-14)
|
||||
|
||||
Bugfixes:
|
||||
|
||||
- Fix TIME format for binary columns (#818)
|
||||
- Fix handling of empty auth plugin names (#835)
|
||||
- Fix caching_sha2_password with empty password (#826)
|
||||
- Fix canceled context broke mysqlConn (#862)
|
||||
- Fix OldAuthSwitchRequest support (#870)
|
||||
- Fix Auth Response packet for cleartext password (#887)
|
||||
|
||||
## Version 1.4 (2018-06-03)
|
||||
|
||||
Changes:
|
||||
|
||||
- Documentation fixes (#530, #535, #567)
|
||||
- Refactoring (#575, #579, #580, #581, #603, #615, #704)
|
||||
- Cache column names (#444)
|
||||
- Sort the DSN parameters in DSNs generated from a config (#637)
|
||||
- Allow native password authentication by default (#644)
|
||||
- Use the default port if it is missing in the DSN (#668)
|
||||
- Removed the `strict` mode (#676)
|
||||
- Do not query `max_allowed_packet` by default (#680)
|
||||
- Dropped support Go 1.6 and lower (#696)
|
||||
- Updated `ConvertValue()` to match the database/sql/driver implementation (#760)
|
||||
- Document the usage of `0000-00-00T00:00:00` as the time.Time zero value (#783)
|
||||
- Improved the compatibility of the authentication system (#807)
|
||||
|
||||
New Features:
|
||||
|
||||
- Multi-Results support (#537)
|
||||
- `rejectReadOnly` DSN option (#604)
|
||||
- `context.Context` support (#608, #612, #627, #761)
|
||||
- Transaction isolation level support (#619, #744)
|
||||
- Read-Only transactions support (#618, #634)
|
||||
- `NewConfig` function which initializes a config with default values (#679)
|
||||
- Implemented the `ColumnType` interfaces (#667, #724)
|
||||
- Support for custom string types in `ConvertValue` (#623)
|
||||
- Implemented `NamedValueChecker`, improving support for uint64 with high bit set (#690, #709, #710)
|
||||
- `caching_sha2_password` authentication plugin support (#794, #800, #801, #802)
|
||||
- Implemented `driver.SessionResetter` (#779)
|
||||
- `sha256_password` authentication plugin support (#808)
|
||||
|
||||
Bugfixes:
|
||||
|
||||
- Use the DSN hostname as TLS default ServerName if `tls=true` (#564, #718)
|
||||
- Fixed LOAD LOCAL DATA INFILE for empty files (#590)
|
||||
- Removed columns definition cache since it sometimes cached invalid data (#592)
|
||||
- Don't mutate registered TLS configs (#600)
|
||||
- Make RegisterTLSConfig concurrency-safe (#613)
|
||||
- Handle missing auth data in the handshake packet correctly (#646)
|
||||
- Do not retry queries when data was written to avoid data corruption (#302, #736)
|
||||
- Cache the connection pointer for error handling before invalidating it (#678)
|
||||
- Fixed imports for appengine/cloudsql (#700)
|
||||
- Fix sending STMT_LONG_DATA for 0 byte data (#734)
|
||||
- Set correct capacity for []bytes read from length-encoded strings (#766)
|
||||
- Make RegisterDial concurrency-safe (#773)
|
||||
|
||||
|
||||
## Version 1.3 (2016-12-01)
|
||||
|
||||
Changes:
|
||||
|
||||
- Go 1.1 is no longer supported
|
||||
- Use decimals fields in MySQL to format time types (#249)
|
||||
- Buffer optimizations (#269)
|
||||
- TLS ServerName defaults to the host (#283)
|
||||
- Refactoring (#400, #410, #437)
|
||||
- Adjusted documentation for second generation CloudSQL (#485)
|
||||
- Documented DSN system var quoting rules (#502)
|
||||
- Made statement.Close() calls idempotent to avoid errors in Go 1.6+ (#512)
|
||||
|
||||
New Features:
|
||||
|
||||
- Enable microsecond resolution on TIME, DATETIME and TIMESTAMP (#249)
|
||||
- Support for returning table alias on Columns() (#289, #359, #382)
|
||||
- Placeholder interpolation, can be actived with the DSN parameter `interpolateParams=true` (#309, #318, #490)
|
||||
- Support for uint64 parameters with high bit set (#332, #345)
|
||||
- Cleartext authentication plugin support (#327)
|
||||
- Exported ParseDSN function and the Config struct (#403, #419, #429)
|
||||
- Read / Write timeouts (#401)
|
||||
- Support for JSON field type (#414)
|
||||
- Support for multi-statements and multi-results (#411, #431)
|
||||
- DSN parameter to set the driver-side max_allowed_packet value manually (#489)
|
||||
- Native password authentication plugin support (#494, #524)
|
||||
|
||||
Bugfixes:
|
||||
|
||||
- Fixed handling of queries without columns and rows (#255)
|
||||
- Fixed a panic when SetKeepAlive() failed (#298)
|
||||
- Handle ERR packets while reading rows (#321)
|
||||
- Fixed reading NULL length-encoded integers in MySQL 5.6+ (#349)
|
||||
- Fixed absolute paths support in LOAD LOCAL DATA INFILE (#356)
|
||||
- Actually zero out bytes in handshake response (#378)
|
||||
- Fixed race condition in registering LOAD DATA INFILE handler (#383)
|
||||
- Fixed tests with MySQL 5.7.9+ (#380)
|
||||
- QueryUnescape TLS config names (#397)
|
||||
- Fixed "broken pipe" error by writing to closed socket (#390)
|
||||
- Fixed LOAD LOCAL DATA INFILE buffering (#424)
|
||||
- Fixed parsing of floats into float64 when placeholders are used (#434)
|
||||
- Fixed DSN tests with Go 1.7+ (#459)
|
||||
- Handle ERR packets while waiting for EOF (#473)
|
||||
- Invalidate connection on error while discarding additional results (#513)
|
||||
- Allow terminating packets of length 0 (#516)
|
||||
|
||||
|
||||
## Version 1.2 (2014-06-03)
|
||||
|
||||
Changes:
|
||||
|
||||
- We switched back to a "rolling release". `go get` installs the current master branch again
|
||||
- Version v1 of the driver will not be maintained anymore. Go 1.0 is no longer supported by this driver
|
||||
- Exported errors to allow easy checking from application code
|
||||
- Enabled TCP Keepalives on TCP connections
|
||||
- Optimized INFILE handling (better buffer size calculation, lazy init, ...)
|
||||
- The DSN parser also checks for a missing separating slash
|
||||
- Faster binary date / datetime to string formatting
|
||||
- Also exported the MySQLWarning type
|
||||
- mysqlConn.Close returns the first error encountered instead of ignoring all errors
|
||||
- writePacket() automatically writes the packet size to the header
|
||||
- readPacket() uses an iterative approach instead of the recursive approach to merge splitted packets
|
||||
|
||||
New Features:
|
||||
|
||||
- `RegisterDial` allows the usage of a custom dial function to establish the network connection
|
||||
- Setting the connection collation is possible with the `collation` DSN parameter. This parameter should be preferred over the `charset` parameter
|
||||
- Logging of critical errors is configurable with `SetLogger`
|
||||
- Google CloudSQL support
|
||||
|
||||
Bugfixes:
|
||||
|
||||
- Allow more than 32 parameters in prepared statements
|
||||
- Various old_password fixes
|
||||
- Fixed TestConcurrent test to pass Go's race detection
|
||||
- Fixed appendLengthEncodedInteger for large numbers
|
||||
- Renamed readLengthEnodedString to readLengthEncodedString and skipLengthEnodedString to skipLengthEncodedString (fixed typo)
|
||||
|
||||
|
||||
## Version 1.1 (2013-11-02)
|
||||
|
||||
Changes:
|
||||
|
||||
- Go-MySQL-Driver now requires Go 1.1
|
||||
- Connections now use the collation `utf8_general_ci` by default. Adding `&charset=UTF8` to the DSN should not be necessary anymore
|
||||
- Made closing rows and connections error tolerant. This allows for example deferring rows.Close() without checking for errors
|
||||
- `[]byte(nil)` is now treated as a NULL value. Before, it was treated like an empty string / `[]byte("")`
|
||||
- DSN parameter values must now be url.QueryEscape'ed. This allows text values to contain special characters, such as '&'.
|
||||
- Use the IO buffer also for writing. This results in zero allocations (by the driver) for most queries
|
||||
- Optimized the buffer for reading
|
||||
- stmt.Query now caches column metadata
|
||||
- New Logo
|
||||
- Changed the copyright header to include all contributors
|
||||
- Improved the LOAD INFILE documentation
|
||||
- The driver struct is now exported to make the driver directly accessible
|
||||
- Refactored the driver tests
|
||||
- Added more benchmarks and moved all to a separate file
|
||||
- Other small refactoring
|
||||
|
||||
New Features:
|
||||
|
||||
- Added *old_passwords* support: Required in some cases, but must be enabled by adding `allowOldPasswords=true` to the DSN since it is insecure
|
||||
- Added a `clientFoundRows` parameter: Return the number of matching rows instead of the number of rows changed on UPDATEs
|
||||
- Added TLS/SSL support: Use a TLS/SSL encrypted connection to the server. Custom TLS configs can be registered and used
|
||||
|
||||
Bugfixes:
|
||||
|
||||
- Fixed MySQL 4.1 support: MySQL 4.1 sends packets with lengths which differ from the specification
|
||||
- Convert to DB timezone when inserting `time.Time`
|
||||
- Splitted packets (more than 16MB) are now merged correctly
|
||||
- Fixed false positive `io.EOF` errors when the data was fully read
|
||||
- Avoid panics on reuse of closed connections
|
||||
- Fixed empty string producing false nil values
|
||||
- Fixed sign byte for positive TIME fields
|
||||
|
||||
|
||||
## Version 1.0 (2013-05-14)
|
||||
|
||||
Initial Release
|
||||
|
|
@ -0,0 +1,373 @@
|
|||
Mozilla Public License Version 2.0
|
||||
==================================
|
||||
|
||||
1. Definitions
|
||||
--------------
|
||||
|
||||
1.1. "Contributor"
|
||||
means each individual or legal entity that creates, contributes to
|
||||
the creation of, or owns Covered Software.
|
||||
|
||||
1.2. "Contributor Version"
|
||||
means the combination of the Contributions of others (if any) used
|
||||
by a Contributor and that particular Contributor's Contribution.
|
||||
|
||||
1.3. "Contribution"
|
||||
means Covered Software of a particular Contributor.
|
||||
|
||||
1.4. "Covered Software"
|
||||
means Source Code Form to which the initial Contributor has attached
|
||||
the notice in Exhibit A, the Executable Form of such Source Code
|
||||
Form, and Modifications of such Source Code Form, in each case
|
||||
including portions thereof.
|
||||
|
||||
1.5. "Incompatible With Secondary Licenses"
|
||||
means
|
||||
|
||||
(a) that the initial Contributor has attached the notice described
|
||||
in Exhibit B to the Covered Software; or
|
||||
|
||||
(b) that the Covered Software was made available under the terms of
|
||||
version 1.1 or earlier of the License, but not also under the
|
||||
terms of a Secondary License.
|
||||
|
||||
1.6. "Executable Form"
|
||||
means any form of the work other than Source Code Form.
|
||||
|
||||
1.7. "Larger Work"
|
||||
means a work that combines Covered Software with other material, in
|
||||
a separate file or files, that is not Covered Software.
|
||||
|
||||
1.8. "License"
|
||||
means this document.
|
||||
|
||||
1.9. "Licensable"
|
||||
means having the right to grant, to the maximum extent possible,
|
||||
whether at the time of the initial grant or subsequently, any and
|
||||
all of the rights conveyed by this License.
|
||||
|
||||
1.10. "Modifications"
|
||||
means any of the following:
|
||||
|
||||
(a) any file in Source Code Form that results from an addition to,
|
||||
deletion from, or modification of the contents of Covered
|
||||
Software; or
|
||||
|
||||
(b) any new file in Source Code Form that contains any Covered
|
||||
Software.
|
||||
|
||||
1.11. "Patent Claims" of a Contributor
|
||||
means any patent claim(s), including without limitation, method,
|
||||
process, and apparatus claims, in any patent Licensable by such
|
||||
Contributor that would be infringed, but for the grant of the
|
||||
License, by the making, using, selling, offering for sale, having
|
||||
made, import, or transfer of either its Contributions or its
|
||||
Contributor Version.
|
||||
|
||||
1.12. "Secondary License"
|
||||
means either the GNU General Public License, Version 2.0, the GNU
|
||||
Lesser General Public License, Version 2.1, the GNU Affero General
|
||||
Public License, Version 3.0, or any later versions of those
|
||||
licenses.
|
||||
|
||||
1.13. "Source Code Form"
|
||||
means the form of the work preferred for making modifications.
|
||||
|
||||
1.14. "You" (or "Your")
|
||||
means an individual or a legal entity exercising rights under this
|
||||
License. For legal entities, "You" includes any entity that
|
||||
controls, is controlled by, or is under common control with You. For
|
||||
purposes of this definition, "control" means (a) the power, direct
|
||||
or indirect, to cause the direction or management of such entity,
|
||||
whether by contract or otherwise, or (b) ownership of more than
|
||||
fifty percent (50%) of the outstanding shares or beneficial
|
||||
ownership of such entity.
|
||||
|
||||
2. License Grants and Conditions
|
||||
--------------------------------
|
||||
|
||||
2.1. Grants
|
||||
|
||||
Each Contributor hereby grants You a world-wide, royalty-free,
|
||||
non-exclusive license:
|
||||
|
||||
(a) under intellectual property rights (other than patent or trademark)
|
||||
Licensable by such Contributor to use, reproduce, make available,
|
||||
modify, display, perform, distribute, and otherwise exploit its
|
||||
Contributions, either on an unmodified basis, with Modifications, or
|
||||
as part of a Larger Work; and
|
||||
|
||||
(b) under Patent Claims of such Contributor to make, use, sell, offer
|
||||
for sale, have made, import, and otherwise transfer either its
|
||||
Contributions or its Contributor Version.
|
||||
|
||||
2.2. Effective Date
|
||||
|
||||
The licenses granted in Section 2.1 with respect to any Contribution
|
||||
become effective for each Contribution on the date the Contributor first
|
||||
distributes such Contribution.
|
||||
|
||||
2.3. Limitations on Grant Scope
|
||||
|
||||
The licenses granted in this Section 2 are the only rights granted under
|
||||
this License. No additional rights or licenses will be implied from the
|
||||
distribution or licensing of Covered Software under this License.
|
||||
Notwithstanding Section 2.1(b) above, no patent license is granted by a
|
||||
Contributor:
|
||||
|
||||
(a) for any code that a Contributor has removed from Covered Software;
|
||||
or
|
||||
|
||||
(b) for infringements caused by: (i) Your and any other third party's
|
||||
modifications of Covered Software, or (ii) the combination of its
|
||||
Contributions with other software (except as part of its Contributor
|
||||
Version); or
|
||||
|
||||
(c) under Patent Claims infringed by Covered Software in the absence of
|
||||
its Contributions.
|
||||
|
||||
This License does not grant any rights in the trademarks, service marks,
|
||||
or logos of any Contributor (except as may be necessary to comply with
|
||||
the notice requirements in Section 3.4).
|
||||
|
||||
2.4. Subsequent Licenses
|
||||
|
||||
No Contributor makes additional grants as a result of Your choice to
|
||||
distribute the Covered Software under a subsequent version of this
|
||||
License (see Section 10.2) or under the terms of a Secondary License (if
|
||||
permitted under the terms of Section 3.3).
|
||||
|
||||
2.5. Representation
|
||||
|
||||
Each Contributor represents that the Contributor believes its
|
||||
Contributions are its original creation(s) or it has sufficient rights
|
||||
to grant the rights to its Contributions conveyed by this License.
|
||||
|
||||
2.6. Fair Use
|
||||
|
||||
This License is not intended to limit any rights You have under
|
||||
applicable copyright doctrines of fair use, fair dealing, or other
|
||||
equivalents.
|
||||
|
||||
2.7. Conditions
|
||||
|
||||
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
|
||||
in Section 2.1.
|
||||
|
||||
3. Responsibilities
|
||||
-------------------
|
||||
|
||||
3.1. Distribution of Source Form
|
||||
|
||||
All distribution of Covered Software in Source Code Form, including any
|
||||
Modifications that You create or to which You contribute, must be under
|
||||
the terms of this License. You must inform recipients that the Source
|
||||
Code Form of the Covered Software is governed by the terms of this
|
||||
License, and how they can obtain a copy of this License. You may not
|
||||
attempt to alter or restrict the recipients' rights in the Source Code
|
||||
Form.
|
||||
|
||||
3.2. Distribution of Executable Form
|
||||
|
||||
If You distribute Covered Software in Executable Form then:
|
||||
|
||||
(a) such Covered Software must also be made available in Source Code
|
||||
Form, as described in Section 3.1, and You must inform recipients of
|
||||
the Executable Form how they can obtain a copy of such Source Code
|
||||
Form by reasonable means in a timely manner, at a charge no more
|
||||
than the cost of distribution to the recipient; and
|
||||
|
||||
(b) You may distribute such Executable Form under the terms of this
|
||||
License, or sublicense it under different terms, provided that the
|
||||
license for the Executable Form does not attempt to limit or alter
|
||||
the recipients' rights in the Source Code Form under this License.
|
||||
|
||||
3.3. Distribution of a Larger Work
|
||||
|
||||
You may create and distribute a Larger Work under terms of Your choice,
|
||||
provided that You also comply with the requirements of this License for
|
||||
the Covered Software. If the Larger Work is a combination of Covered
|
||||
Software with a work governed by one or more Secondary Licenses, and the
|
||||
Covered Software is not Incompatible With Secondary Licenses, this
|
||||
License permits You to additionally distribute such Covered Software
|
||||
under the terms of such Secondary License(s), so that the recipient of
|
||||
the Larger Work may, at their option, further distribute the Covered
|
||||
Software under the terms of either this License or such Secondary
|
||||
License(s).
|
||||
|
||||
3.4. Notices
|
||||
|
||||
You may not remove or alter the substance of any license notices
|
||||
(including copyright notices, patent notices, disclaimers of warranty,
|
||||
or limitations of liability) contained within the Source Code Form of
|
||||
the Covered Software, except that You may alter any license notices to
|
||||
the extent required to remedy known factual inaccuracies.
|
||||
|
||||
3.5. Application of Additional Terms
|
||||
|
||||
You may choose to offer, and to charge a fee for, warranty, support,
|
||||
indemnity or liability obligations to one or more recipients of Covered
|
||||
Software. However, You may do so only on Your own behalf, and not on
|
||||
behalf of any Contributor. You must make it absolutely clear that any
|
||||
such warranty, support, indemnity, or liability obligation is offered by
|
||||
You alone, and You hereby agree to indemnify every Contributor for any
|
||||
liability incurred by such Contributor as a result of warranty, support,
|
||||
indemnity or liability terms You offer. You may include additional
|
||||
disclaimers of warranty and limitations of liability specific to any
|
||||
jurisdiction.
|
||||
|
||||
4. Inability to Comply Due to Statute or Regulation
|
||||
---------------------------------------------------
|
||||
|
||||
If it is impossible for You to comply with any of the terms of this
|
||||
License with respect to some or all of the Covered Software due to
|
||||
statute, judicial order, or regulation then You must: (a) comply with
|
||||
the terms of this License to the maximum extent possible; and (b)
|
||||
describe the limitations and the code they affect. Such description must
|
||||
be placed in a text file included with all distributions of the Covered
|
||||
Software under this License. Except to the extent prohibited by statute
|
||||
or regulation, such description must be sufficiently detailed for a
|
||||
recipient of ordinary skill to be able to understand it.
|
||||
|
||||
5. Termination
|
||||
--------------
|
||||
|
||||
5.1. The rights granted under this License will terminate automatically
|
||||
if You fail to comply with any of its terms. However, if You become
|
||||
compliant, then the rights granted under this License from a particular
|
||||
Contributor are reinstated (a) provisionally, unless and until such
|
||||
Contributor explicitly and finally terminates Your grants, and (b) on an
|
||||
ongoing basis, if such Contributor fails to notify You of the
|
||||
non-compliance by some reasonable means prior to 60 days after You have
|
||||
come back into compliance. Moreover, Your grants from a particular
|
||||
Contributor are reinstated on an ongoing basis if such Contributor
|
||||
notifies You of the non-compliance by some reasonable means, this is the
|
||||
first time You have received notice of non-compliance with this License
|
||||
from such Contributor, and You become compliant prior to 30 days after
|
||||
Your receipt of the notice.
|
||||
|
||||
5.2. If You initiate litigation against any entity by asserting a patent
|
||||
infringement claim (excluding declaratory judgment actions,
|
||||
counter-claims, and cross-claims) alleging that a Contributor Version
|
||||
directly or indirectly infringes any patent, then the rights granted to
|
||||
You by any and all Contributors for the Covered Software under Section
|
||||
2.1 of this License shall terminate.
|
||||
|
||||
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
|
||||
end user license agreements (excluding distributors and resellers) which
|
||||
have been validly granted by You or Your distributors under this License
|
||||
prior to termination shall survive termination.
|
||||
|
||||
************************************************************************
|
||||
* *
|
||||
* 6. Disclaimer of Warranty *
|
||||
* ------------------------- *
|
||||
* *
|
||||
* Covered Software is provided under this License on an "as is" *
|
||||
* basis, without warranty of any kind, either expressed, implied, or *
|
||||
* statutory, including, without limitation, warranties that the *
|
||||
* Covered Software is free of defects, merchantable, fit for a *
|
||||
* particular purpose or non-infringing. The entire risk as to the *
|
||||
* quality and performance of the Covered Software is with You. *
|
||||
* Should any Covered Software prove defective in any respect, You *
|
||||
* (not any Contributor) assume the cost of any necessary servicing, *
|
||||
* repair, or correction. This disclaimer of warranty constitutes an *
|
||||
* essential part of this License. No use of any Covered Software is *
|
||||
* authorized under this License except under this disclaimer. *
|
||||
* *
|
||||
************************************************************************
|
||||
|
||||
************************************************************************
|
||||
* *
|
||||
* 7. Limitation of Liability *
|
||||
* -------------------------- *
|
||||
* *
|
||||
* Under no circumstances and under no legal theory, whether tort *
|
||||
* (including negligence), contract, or otherwise, shall any *
|
||||
* Contributor, or anyone who distributes Covered Software as *
|
||||
* permitted above, be liable to You for any direct, indirect, *
|
||||
* special, incidental, or consequential damages of any character *
|
||||
* including, without limitation, damages for lost profits, loss of *
|
||||
* goodwill, work stoppage, computer failure or malfunction, or any *
|
||||
* and all other commercial damages or losses, even if such party *
|
||||
* shall have been informed of the possibility of such damages. This *
|
||||
* limitation of liability shall not apply to liability for death or *
|
||||
* personal injury resulting from such party's negligence to the *
|
||||
* extent applicable law prohibits such limitation. Some *
|
||||
* jurisdictions do not allow the exclusion or limitation of *
|
||||
* incidental or consequential damages, so this exclusion and *
|
||||
* limitation may not apply to You. *
|
||||
* *
|
||||
************************************************************************
|
||||
|
||||
8. Litigation
|
||||
-------------
|
||||
|
||||
Any litigation relating to this License may be brought only in the
|
||||
courts of a jurisdiction where the defendant maintains its principal
|
||||
place of business and such litigation shall be governed by laws of that
|
||||
jurisdiction, without reference to its conflict-of-law provisions.
|
||||
Nothing in this Section shall prevent a party's ability to bring
|
||||
cross-claims or counter-claims.
|
||||
|
||||
9. Miscellaneous
|
||||
----------------
|
||||
|
||||
This License represents the complete agreement concerning the subject
|
||||
matter hereof. If any provision of this License is held to be
|
||||
unenforceable, such provision shall be reformed only to the extent
|
||||
necessary to make it enforceable. Any law or regulation which provides
|
||||
that the language of a contract shall be construed against the drafter
|
||||
shall not be used to construe this License against a Contributor.
|
||||
|
||||
10. Versions of the License
|
||||
---------------------------
|
||||
|
||||
10.1. New Versions
|
||||
|
||||
Mozilla Foundation is the license steward. Except as provided in Section
|
||||
10.3, no one other than the license steward has the right to modify or
|
||||
publish new versions of this License. Each version will be given a
|
||||
distinguishing version number.
|
||||
|
||||
10.2. Effect of New Versions
|
||||
|
||||
You may distribute the Covered Software under the terms of the version
|
||||
of the License under which You originally received the Covered Software,
|
||||
or under the terms of any subsequent version published by the license
|
||||
steward.
|
||||
|
||||
10.3. Modified Versions
|
||||
|
||||
If you create software not governed by this License, and you want to
|
||||
create a new license for such software, you may create and use a
|
||||
modified version of this License if you rename the license and remove
|
||||
any references to the name of the license steward (except to note that
|
||||
such modified license differs from this License).
|
||||
|
||||
10.4. Distributing Source Code Form that is Incompatible With Secondary
|
||||
Licenses
|
||||
|
||||
If You choose to distribute Source Code Form that is Incompatible With
|
||||
Secondary Licenses under the terms of this version of the License, the
|
||||
notice described in Exhibit B of this License must be attached.
|
||||
|
||||
Exhibit A - Source Code Form License Notice
|
||||
-------------------------------------------
|
||||
|
||||
This Source Code Form is subject to the terms of the Mozilla Public
|
||||
License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
If it is not possible or desirable to put the notice in a particular
|
||||
file, then You may include the notice in a location (such as a LICENSE
|
||||
file in a relevant directory) where a recipient would be likely to look
|
||||
for such a notice.
|
||||
|
||||
You may add additional accurate notices of copyright ownership.
|
||||
|
||||
Exhibit B - "Incompatible With Secondary Licenses" Notice
|
||||
---------------------------------------------------------
|
||||
|
||||
This Source Code Form is "Incompatible With Secondary Licenses", as
|
||||
defined by the Mozilla Public License, v. 2.0.
|
||||
|
|
@ -0,0 +1,520 @@
|
|||
# Go-MySQL-Driver
|
||||
|
||||
A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) package
|
||||
|
||||

|
||||
|
||||
---------------------------------------
|
||||
* [Features](#features)
|
||||
* [Requirements](#requirements)
|
||||
* [Installation](#installation)
|
||||
* [Usage](#usage)
|
||||
* [DSN (Data Source Name)](#dsn-data-source-name)
|
||||
* [Password](#password)
|
||||
* [Protocol](#protocol)
|
||||
* [Address](#address)
|
||||
* [Parameters](#parameters)
|
||||
* [Examples](#examples)
|
||||
* [Connection pool and timeouts](#connection-pool-and-timeouts)
|
||||
* [context.Context Support](#contextcontext-support)
|
||||
* [ColumnType Support](#columntype-support)
|
||||
* [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support)
|
||||
* [time.Time support](#timetime-support)
|
||||
* [Unicode support](#unicode-support)
|
||||
* [Testing / Development](#testing--development)
|
||||
* [License](#license)
|
||||
|
||||
---------------------------------------
|
||||
|
||||
## Features
|
||||
* Lightweight and [fast](https://github.com/go-sql-driver/sql-benchmark "golang MySQL-Driver performance")
|
||||
* Native Go implementation. No C-bindings, just pure Go
|
||||
* Connections over TCP/IPv4, TCP/IPv6, Unix domain sockets or [custom protocols](https://godoc.org/github.com/go-sql-driver/mysql#DialFunc)
|
||||
* Automatic handling of broken connections
|
||||
* Automatic Connection Pooling *(by database/sql package)*
|
||||
* Supports queries larger than 16MB
|
||||
* Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support.
|
||||
* Intelligent `LONG DATA` handling in prepared statements
|
||||
* Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support
|
||||
* Optional `time.Time` parsing
|
||||
* Optional placeholder interpolation
|
||||
|
||||
## Requirements
|
||||
* Go 1.10 or higher. We aim to support the 3 latest versions of Go.
|
||||
* MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+)
|
||||
|
||||
---------------------------------------
|
||||
|
||||
## Installation
|
||||
Simple install the package to your [$GOPATH](https://github.com/golang/go/wiki/GOPATH "GOPATH") with the [go tool](https://golang.org/cmd/go/ "go command") from shell:
|
||||
```bash
|
||||
$ go get -u github.com/go-sql-driver/mysql
|
||||
```
|
||||
Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`.
|
||||
|
||||
## Usage
|
||||
_Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](https://golang.org/pkg/database/sql/) API then.
|
||||
|
||||
Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`:
|
||||
|
||||
```go
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
// ...
|
||||
|
||||
db, err := sql.Open("mysql", "user:password@/dbname")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// See "Important settings" section.
|
||||
db.SetConnMaxLifetime(time.Minute * 3)
|
||||
db.SetMaxOpenConns(10)
|
||||
db.SetMaxIdleConns(10)
|
||||
```
|
||||
|
||||
[Examples are available in our Wiki](https://github.com/go-sql-driver/mysql/wiki/Examples "Go-MySQL-Driver Examples").
|
||||
|
||||
### Important settings
|
||||
|
||||
`db.SetConnMaxLifetime()` is required to ensure connections are closed by the driver safely before connection is closed by MySQL server, OS, or other middlewares. Since some middlewares close idle connections by 5 minutes, we recommend timeout shorter than 5 minutes. This setting helps load balancing and changing system variables too.
|
||||
|
||||
`db.SetMaxOpenConns()` is highly recommended to limit the number of connection used by the application. There is no recommended limit number because it depends on application and MySQL server.
|
||||
|
||||
`db.SetMaxIdleConns()` is recommended to be set same to (or greater than) `db.SetMaxOpenConns()`. When it is smaller than `SetMaxOpenConns()`, connections can be opened and closed very frequently than you expect. Idle connections can be closed by the `db.SetConnMaxLifetime()`. If you want to close idle connections more rapidly, you can use `db.SetConnMaxIdleTime()` since Go 1.15.
|
||||
|
||||
|
||||
### DSN (Data Source Name)
|
||||
|
||||
The Data Source Name has a common format, like e.g. [PEAR DB](http://pear.php.net/manual/en/package.database.db.intro-dsn.php) uses it, but without type-prefix (optional parts marked by squared brackets):
|
||||
```
|
||||
[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN]
|
||||
```
|
||||
|
||||
A DSN in its fullest form:
|
||||
```
|
||||
username:password@protocol(address)/dbname?param=value
|
||||
```
|
||||
|
||||
Except for the databasename, all values are optional. So the minimal DSN is:
|
||||
```
|
||||
/dbname
|
||||
```
|
||||
|
||||
If you do not want to preselect a database, leave `dbname` empty:
|
||||
```
|
||||
/
|
||||
```
|
||||
This has the same effect as an empty DSN string:
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct.
|
||||
|
||||
#### Password
|
||||
Passwords can consist of any character. Escaping is **not** necessary.
|
||||
|
||||
#### Protocol
|
||||
See [net.Dial](https://golang.org/pkg/net/#Dial) for more information which networks are available.
|
||||
In general you should use an Unix domain socket if available and TCP otherwise for best performance.
|
||||
|
||||
#### Address
|
||||
For TCP and UDP networks, addresses have the form `host[:port]`.
|
||||
If `port` is omitted, the default port will be used.
|
||||
If `host` is a literal IPv6 address, it must be enclosed in square brackets.
|
||||
The functions [net.JoinHostPort](https://golang.org/pkg/net/#JoinHostPort) and [net.SplitHostPort](https://golang.org/pkg/net/#SplitHostPort) manipulate addresses in this form.
|
||||
|
||||
For Unix domain sockets the address is the absolute path to the MySQL-Server-socket, e.g. `/var/run/mysqld/mysqld.sock` or `/tmp/mysql.sock`.
|
||||
|
||||
#### Parameters
|
||||
*Parameters are case-sensitive!*
|
||||
|
||||
Notice that any of `true`, `TRUE`, `True` or `1` is accepted to stand for a true boolean value. Not surprisingly, false can be specified as any of: `false`, `FALSE`, `False` or `0`.
|
||||
|
||||
##### `allowAllFiles`
|
||||
|
||||
```
|
||||
Type: bool
|
||||
Valid Values: true, false
|
||||
Default: false
|
||||
```
|
||||
|
||||
`allowAllFiles=true` disables the file allowlist for `LOAD DATA LOCAL INFILE` and allows *all* files.
|
||||
[*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)
|
||||
|
||||
##### `allowCleartextPasswords`
|
||||
|
||||
```
|
||||
Type: bool
|
||||
Valid Values: true, false
|
||||
Default: false
|
||||
```
|
||||
|
||||
`allowCleartextPasswords=true` allows using the [cleartext client side plugin](https://dev.mysql.com/doc/en/cleartext-pluggable-authentication.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network.
|
||||
|
||||
##### `allowNativePasswords`
|
||||
|
||||
```
|
||||
Type: bool
|
||||
Valid Values: true, false
|
||||
Default: true
|
||||
```
|
||||
`allowNativePasswords=false` disallows the usage of MySQL native password method.
|
||||
|
||||
##### `allowOldPasswords`
|
||||
|
||||
```
|
||||
Type: bool
|
||||
Valid Values: true, false
|
||||
Default: false
|
||||
```
|
||||
`allowOldPasswords=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords).
|
||||
|
||||
##### `charset`
|
||||
|
||||
```
|
||||
Type: string
|
||||
Valid Values: <name>
|
||||
Default: none
|
||||
```
|
||||
|
||||
Sets the charset used for client-server interaction (`"SET NAMES <value>"`). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables for example support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
|
||||
|
||||
Usage of the `charset` parameter is discouraged because it issues additional queries to the server.
|
||||
Unless you need the fallback behavior, please use `collation` instead.
|
||||
|
||||
##### `checkConnLiveness`
|
||||
|
||||
```
|
||||
Type: bool
|
||||
Valid Values: true, false
|
||||
Default: true
|
||||
```
|
||||
|
||||
On supported platforms connections retrieved from the connection pool are checked for liveness before using them. If the check fails, the respective connection is marked as bad and the query retried with another connection.
|
||||
`checkConnLiveness=false` disables this liveness check of connections.
|
||||
|
||||
##### `collation`
|
||||
|
||||
```
|
||||
Type: string
|
||||
Valid Values: <name>
|
||||
Default: utf8mb4_general_ci
|
||||
```
|
||||
|
||||
Sets the collation used for client-server interaction on connection. In contrast to `charset`, `collation` does not issue additional queries. If the specified collation is unavailable on the target server, the connection will fail.
|
||||
|
||||
A list of valid charsets for a server is retrievable with `SHOW COLLATION`.
|
||||
|
||||
The default collation (`utf8mb4_general_ci`) is supported from MySQL 5.5. You should use an older collation (e.g. `utf8_general_ci`) for older MySQL.
|
||||
|
||||
Collations for charset "ucs2", "utf16", "utf16le", and "utf32" can not be used ([ref](https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset)).
|
||||
|
||||
|
||||
##### `clientFoundRows`
|
||||
|
||||
```
|
||||
Type: bool
|
||||
Valid Values: true, false
|
||||
Default: false
|
||||
```
|
||||
|
||||
`clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed.
|
||||
|
||||
##### `columnsWithAlias`
|
||||
|
||||
```
|
||||
Type: bool
|
||||
Valid Values: true, false
|
||||
Default: false
|
||||
```
|
||||
|
||||
When `columnsWithAlias` is true, calls to `sql.Rows.Columns()` will return the table alias and the column name separated by a dot. For example:
|
||||
|
||||
```
|
||||
SELECT u.id FROM users as u
|
||||
```
|
||||
|
||||
will return `u.id` instead of just `id` if `columnsWithAlias=true`.
|
||||
|
||||
##### `interpolateParams`
|
||||
|
||||
```
|
||||
Type: bool
|
||||
Valid Values: true, false
|
||||
Default: false
|
||||
```
|
||||
|
||||
If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with given parameters and close the statement again with `interpolateParams=false`.
|
||||
|
||||
*This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are rejected as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!*
|
||||
|
||||
##### `loc`
|
||||
|
||||
```
|
||||
Type: string
|
||||
Valid Values: <escaped name>
|
||||
Default: UTC
|
||||
```
|
||||
|
||||
Sets the location for time.Time values (when using `parseTime=true`). *"Local"* sets the system's location. See [time.LoadLocation](https://golang.org/pkg/time/#LoadLocation) for details.
|
||||
|
||||
Note that this sets the location for time.Time values but does not change MySQL's [time_zone setting](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html). For that see the [time_zone system variable](#system-variables), which can also be set as a DSN parameter.
|
||||
|
||||
Please keep in mind, that param values must be [url.QueryEscape](https://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`.
|
||||
|
||||
##### `maxAllowedPacket`
|
||||
```
|
||||
Type: decimal number
|
||||
Default: 4194304
|
||||
```
|
||||
|
||||
Max packet size allowed in bytes. The default value is 4 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*.
|
||||
|
||||
##### `multiStatements`
|
||||
|
||||
```
|
||||
Type: bool
|
||||
Valid Values: true, false
|
||||
Default: false
|
||||
```
|
||||
|
||||
Allow multiple statements in one query. While this allows batch queries, it also greatly increases the risk of SQL injections. Only the result of the first query is returned, all other results are silently discarded.
|
||||
|
||||
When `multiStatements` is used, `?` parameters must only be used in the first statement.
|
||||
|
||||
##### `parseTime`
|
||||
|
||||
```
|
||||
Type: bool
|
||||
Valid Values: true, false
|
||||
Default: false
|
||||
```
|
||||
|
||||
`parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
|
||||
The date or datetime like `0000-00-00 00:00:00` is converted into zero value of `time.Time`.
|
||||
|
||||
|
||||
##### `readTimeout`
|
||||
|
||||
```
|
||||
Type: duration
|
||||
Default: 0
|
||||
```
|
||||
|
||||
I/O read timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.
|
||||
|
||||
##### `rejectReadOnly`
|
||||
|
||||
```
|
||||
Type: bool
|
||||
Valid Values: true, false
|
||||
Default: false
|
||||
```
|
||||
|
||||
|
||||
`rejectReadOnly=true` causes the driver to reject read-only connections. This
|
||||
is for a possible race condition during an automatic failover, where the mysql
|
||||
client gets connected to a read-only replica after the failover.
|
||||
|
||||
Note that this should be a fairly rare case, as an automatic failover normally
|
||||
happens when the primary is down, and the race condition shouldn't happen
|
||||
unless it comes back up online as soon as the failover is kicked off. On the
|
||||
other hand, when this happens, a MySQL application can get stuck on a
|
||||
read-only connection until restarted. It is however fairly easy to reproduce,
|
||||
for example, using a manual failover on AWS Aurora's MySQL-compatible cluster.
|
||||
|
||||
If you are not relying on read-only transactions to reject writes that aren't
|
||||
supposed to happen, setting this on some MySQL providers (such as AWS Aurora)
|
||||
is safer for failovers.
|
||||
|
||||
Note that ERROR 1290 can be returned for a `read-only` server and this option will
|
||||
cause a retry for that error. However the same error number is used for some
|
||||
other cases. You should ensure your application will never cause an ERROR 1290
|
||||
except for `read-only` mode when enabling this option.
|
||||
|
||||
|
||||
##### `serverPubKey`
|
||||
|
||||
```
|
||||
Type: string
|
||||
Valid Values: <name>
|
||||
Default: none
|
||||
```
|
||||
|
||||
Server public keys can be registered with [`mysql.RegisterServerPubKey`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterServerPubKey), which can then be used by the assigned name in the DSN.
|
||||
Public keys are used to transmit encrypted data, e.g. for authentication.
|
||||
If the server's public key is known, it should be set manually to avoid expensive and potentially insecure transmissions of the public key from the server to the client each time it is required.
|
||||
|
||||
|
||||
##### `timeout`
|
||||
|
||||
```
|
||||
Type: duration
|
||||
Default: OS default
|
||||
```
|
||||
|
||||
Timeout for establishing connections, aka dial timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.
|
||||
|
||||
|
||||
##### `tls`
|
||||
|
||||
```
|
||||
Type: bool / string
|
||||
Valid Values: true, false, skip-verify, preferred, <name>
|
||||
Default: false
|
||||
```
|
||||
|
||||
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side) or use `preferred` to use TLS only when advertised by the server. This is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Neither `skip-verify` nor `preferred` add any reliable security. You can use a custom TLS config after registering it with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).
|
||||
|
||||
|
||||
##### `writeTimeout`
|
||||
|
||||
```
|
||||
Type: duration
|
||||
Default: 0
|
||||
```
|
||||
|
||||
I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.
|
||||
|
||||
|
||||
##### System Variables
|
||||
|
||||
Any other parameters are interpreted as system variables:
|
||||
* `<boolean_var>=<value>`: `SET <boolean_var>=<value>`
|
||||
* `<enum_var>=<value>`: `SET <enum_var>=<value>`
|
||||
* `<string_var>=%27<value>%27`: `SET <string_var>='<value>'`
|
||||
|
||||
Rules:
|
||||
* The values for string variables must be quoted with `'`.
|
||||
* The values must also be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed!
|
||||
(which implies values of string variables must be wrapped with `%27`).
|
||||
|
||||
Examples:
|
||||
* `autocommit=1`: `SET autocommit=1`
|
||||
* [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'`
|
||||
* [`transaction_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation): `SET transaction_isolation='REPEATABLE-READ'`
|
||||
|
||||
|
||||
#### Examples
|
||||
```
|
||||
user@unix(/path/to/socket)/dbname
|
||||
```
|
||||
|
||||
```
|
||||
root:pw@unix(/tmp/mysql.sock)/myDatabase?loc=Local
|
||||
```
|
||||
|
||||
```
|
||||
user:password@tcp(localhost:5555)/dbname?tls=skip-verify&autocommit=true
|
||||
```
|
||||
|
||||
Treat warnings as errors by setting the system variable [`sql_mode`](https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html):
|
||||
```
|
||||
user:password@/dbname?sql_mode=TRADITIONAL
|
||||
```
|
||||
|
||||
TCP via IPv6:
|
||||
```
|
||||
user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?timeout=90s&collation=utf8mb4_unicode_ci
|
||||
```
|
||||
|
||||
TCP on a remote host, e.g. Amazon RDS:
|
||||
```
|
||||
id:password@tcp(your-amazonaws-uri.com:3306)/dbname
|
||||
```
|
||||
|
||||
Google Cloud SQL on App Engine:
|
||||
```
|
||||
user:password@unix(/cloudsql/project-id:region-name:instance-name)/dbname
|
||||
```
|
||||
|
||||
TCP using default port (3306) on localhost:
|
||||
```
|
||||
user:password@tcp/dbname?charset=utf8mb4,utf8&sys_var=esc%40ped
|
||||
```
|
||||
|
||||
Use the default protocol (tcp) and host (localhost:3306):
|
||||
```
|
||||
user:password@/dbname
|
||||
```
|
||||
|
||||
No Database preselected:
|
||||
```
|
||||
user:password@/
|
||||
```
|
||||
|
||||
|
||||
### Connection pool and timeouts
|
||||
The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively.
|
||||
|
||||
## `ColumnType` Support
|
||||
This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported.
|
||||
|
||||
## `context.Context` Support
|
||||
Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts.
|
||||
See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details.
|
||||
|
||||
|
||||
### `LOAD DATA LOCAL INFILE` support
|
||||
For this feature you need direct access to the package. Therefore you must change the import path (no `_`):
|
||||
```go
|
||||
import "github.com/go-sql-driver/mysql"
|
||||
```
|
||||
|
||||
Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)).
|
||||
|
||||
To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::<name>` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore.
|
||||
|
||||
See the [godoc of Go-MySQL-Driver](https://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation") for details.
|
||||
|
||||
|
||||
### `time.Time` support
|
||||
The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your program.
|
||||
|
||||
However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical equivalent in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](https://golang.org/pkg/time/#Location) with the `loc` DSN parameter.
|
||||
|
||||
**Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes).
|
||||
|
||||
|
||||
### Unicode support
|
||||
Since version 1.5 Go-MySQL-Driver automatically uses the collation ` utf8mb4_general_ci` by default.
|
||||
|
||||
Other collations / charsets can be set using the [`collation`](#collation) DSN parameter.
|
||||
|
||||
Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAMES utf8`) to the DSN to enable proper UTF-8 support. This is not necessary anymore. The [`collation`](#collation) parameter should be preferred to set another collation / charset than the default.
|
||||
|
||||
See http://dev.mysql.com/doc/refman/8.0/en/charset-unicode.html for more details on MySQL's Unicode support.
|
||||
|
||||
## Testing / Development
|
||||
To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details.
|
||||
|
||||
Go-MySQL-Driver is not feature-complete yet. Your help is very appreciated.
|
||||
If you want to contribute, you can work on an [open issue](https://github.com/go-sql-driver/mysql/issues?state=open) or review a [pull request](https://github.com/go-sql-driver/mysql/pulls).
|
||||
|
||||
See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/.github/CONTRIBUTING.md) for details.
|
||||
|
||||
---------------------------------------
|
||||
|
||||
## License
|
||||
Go-MySQL-Driver is licensed under the [Mozilla Public License Version 2.0](https://raw.github.com/go-sql-driver/mysql/master/LICENSE)
|
||||
|
||||
Mozilla summarizes the license scope as follows:
|
||||
> MPL: The copyleft applies to any files containing MPLed code.
|
||||
|
||||
|
||||
That means:
|
||||
* You can **use** the **unchanged** source code both in private and commercially.
|
||||
* When distributing, you **must publish** the source code of any **changed files** licensed under the MPL 2.0 under a) the MPL 2.0 itself or b) a compatible license (e.g. GPL 3.0 or Apache License 2.0).
|
||||
* You **needn't publish** the source code of your library as long as the files licensed under the MPL 2.0 are **unchanged**.
|
||||
|
||||
Please read the [MPL 2.0 FAQ](https://www.mozilla.org/en-US/MPL/2.0/FAQ/) if you have further questions regarding the license.
|
||||
|
||||
You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE).
|
||||
|
||||

|
||||
|
|
@ -0,0 +1,425 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// server pub keys registry
|
||||
var (
|
||||
serverPubKeyLock sync.RWMutex
|
||||
serverPubKeyRegistry map[string]*rsa.PublicKey
|
||||
)
|
||||
|
||||
// RegisterServerPubKey registers a server RSA public key which can be used to
|
||||
// send data in a secure manner to the server without receiving the public key
|
||||
// in a potentially insecure way from the server first.
|
||||
// Registered keys can afterwards be used adding serverPubKey=<name> to the DSN.
|
||||
//
|
||||
// Note: The provided rsa.PublicKey instance is exclusively owned by the driver
|
||||
// after registering it and may not be modified.
|
||||
//
|
||||
// data, err := ioutil.ReadFile("mykey.pem")
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// block, _ := pem.Decode(data)
|
||||
// if block == nil || block.Type != "PUBLIC KEY" {
|
||||
// log.Fatal("failed to decode PEM block containing public key")
|
||||
// }
|
||||
//
|
||||
// pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok {
|
||||
// mysql.RegisterServerPubKey("mykey", rsaPubKey)
|
||||
// } else {
|
||||
// log.Fatal("not a RSA public key")
|
||||
// }
|
||||
//
|
||||
func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) {
|
||||
serverPubKeyLock.Lock()
|
||||
if serverPubKeyRegistry == nil {
|
||||
serverPubKeyRegistry = make(map[string]*rsa.PublicKey)
|
||||
}
|
||||
|
||||
serverPubKeyRegistry[name] = pubKey
|
||||
serverPubKeyLock.Unlock()
|
||||
}
|
||||
|
||||
// DeregisterServerPubKey removes the public key registered with the given name.
|
||||
func DeregisterServerPubKey(name string) {
|
||||
serverPubKeyLock.Lock()
|
||||
if serverPubKeyRegistry != nil {
|
||||
delete(serverPubKeyRegistry, name)
|
||||
}
|
||||
serverPubKeyLock.Unlock()
|
||||
}
|
||||
|
||||
func getServerPubKey(name string) (pubKey *rsa.PublicKey) {
|
||||
serverPubKeyLock.RLock()
|
||||
if v, ok := serverPubKeyRegistry[name]; ok {
|
||||
pubKey = v
|
||||
}
|
||||
serverPubKeyLock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Hash password using pre 4.1 (old password) method
|
||||
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
|
||||
type myRnd struct {
|
||||
seed1, seed2 uint32
|
||||
}
|
||||
|
||||
const myRndMaxVal = 0x3FFFFFFF
|
||||
|
||||
// Pseudo random number generator
|
||||
func newMyRnd(seed1, seed2 uint32) *myRnd {
|
||||
return &myRnd{
|
||||
seed1: seed1 % myRndMaxVal,
|
||||
seed2: seed2 % myRndMaxVal,
|
||||
}
|
||||
}
|
||||
|
||||
// Tested to be equivalent to MariaDB's floating point variant
|
||||
// http://play.golang.org/p/QHvhd4qved
|
||||
// http://play.golang.org/p/RG0q4ElWDx
|
||||
func (r *myRnd) NextByte() byte {
|
||||
r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
|
||||
r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal
|
||||
|
||||
return byte(uint64(r.seed1) * 31 / myRndMaxVal)
|
||||
}
|
||||
|
||||
// Generate binary hash from byte string using insecure pre 4.1 method
|
||||
func pwHash(password []byte) (result [2]uint32) {
|
||||
var add uint32 = 7
|
||||
var tmp uint32
|
||||
|
||||
result[0] = 1345345333
|
||||
result[1] = 0x12345671
|
||||
|
||||
for _, c := range password {
|
||||
// skip spaces and tabs in password
|
||||
if c == ' ' || c == '\t' {
|
||||
continue
|
||||
}
|
||||
|
||||
tmp = uint32(c)
|
||||
result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
|
||||
result[1] += (result[1] << 8) ^ result[0]
|
||||
add += tmp
|
||||
}
|
||||
|
||||
// Remove sign bit (1<<31)-1)
|
||||
result[0] &= 0x7FFFFFFF
|
||||
result[1] &= 0x7FFFFFFF
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Hash password using insecure pre 4.1 method
|
||||
func scrambleOldPassword(scramble []byte, password string) []byte {
|
||||
scramble = scramble[:8]
|
||||
|
||||
hashPw := pwHash([]byte(password))
|
||||
hashSc := pwHash(scramble)
|
||||
|
||||
r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])
|
||||
|
||||
var out [8]byte
|
||||
for i := range out {
|
||||
out[i] = r.NextByte() + 64
|
||||
}
|
||||
|
||||
mask := r.NextByte()
|
||||
for i := range out {
|
||||
out[i] ^= mask
|
||||
}
|
||||
|
||||
return out[:]
|
||||
}
|
||||
|
||||
// Hash password using 4.1+ method (SHA1)
|
||||
func scramblePassword(scramble []byte, password string) []byte {
|
||||
if len(password) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// stage1Hash = SHA1(password)
|
||||
crypt := sha1.New()
|
||||
crypt.Write([]byte(password))
|
||||
stage1 := crypt.Sum(nil)
|
||||
|
||||
// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
|
||||
// inner Hash
|
||||
crypt.Reset()
|
||||
crypt.Write(stage1)
|
||||
hash := crypt.Sum(nil)
|
||||
|
||||
// outer Hash
|
||||
crypt.Reset()
|
||||
crypt.Write(scramble)
|
||||
crypt.Write(hash)
|
||||
scramble = crypt.Sum(nil)
|
||||
|
||||
// token = scrambleHash XOR stage1Hash
|
||||
for i := range scramble {
|
||||
scramble[i] ^= stage1[i]
|
||||
}
|
||||
return scramble
|
||||
}
|
||||
|
||||
// Hash password using MySQL 8+ method (SHA256)
|
||||
func scrambleSHA256Password(scramble []byte, password string) []byte {
|
||||
if len(password) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
|
||||
|
||||
crypt := sha256.New()
|
||||
crypt.Write([]byte(password))
|
||||
message1 := crypt.Sum(nil)
|
||||
|
||||
crypt.Reset()
|
||||
crypt.Write(message1)
|
||||
message1Hash := crypt.Sum(nil)
|
||||
|
||||
crypt.Reset()
|
||||
crypt.Write(message1Hash)
|
||||
crypt.Write(scramble)
|
||||
message2 := crypt.Sum(nil)
|
||||
|
||||
for i := range message1 {
|
||||
message1[i] ^= message2[i]
|
||||
}
|
||||
|
||||
return message1
|
||||
}
|
||||
|
||||
func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) {
|
||||
plain := make([]byte, len(password)+1)
|
||||
copy(plain, password)
|
||||
for i := range plain {
|
||||
j := i % len(seed)
|
||||
plain[i] ^= seed[j]
|
||||
}
|
||||
sha1 := sha1.New()
|
||||
return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error {
|
||||
enc, err := encryptPassword(mc.cfg.Passwd, seed, pub)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return mc.writeAuthSwitchPacket(enc)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
|
||||
switch plugin {
|
||||
case "caching_sha2_password":
|
||||
authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
|
||||
return authResp, nil
|
||||
|
||||
case "mysql_old_password":
|
||||
if !mc.cfg.AllowOldPasswords {
|
||||
return nil, ErrOldPassword
|
||||
}
|
||||
if len(mc.cfg.Passwd) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
// Note: there are edge cases where this should work but doesn't;
|
||||
// this is currently "wontfix":
|
||||
// https://github.com/go-sql-driver/mysql/issues/184
|
||||
authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0)
|
||||
return authResp, nil
|
||||
|
||||
case "mysql_clear_password":
|
||||
if !mc.cfg.AllowCleartextPasswords {
|
||||
return nil, ErrCleartextPassword
|
||||
}
|
||||
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
|
||||
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
|
||||
return append([]byte(mc.cfg.Passwd), 0), nil
|
||||
|
||||
case "mysql_native_password":
|
||||
if !mc.cfg.AllowNativePasswords {
|
||||
return nil, ErrNativePassword
|
||||
}
|
||||
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
|
||||
// Native password authentication only need and will need 20-byte challenge.
|
||||
authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
|
||||
return authResp, nil
|
||||
|
||||
case "sha256_password":
|
||||
if len(mc.cfg.Passwd) == 0 {
|
||||
return []byte{0}, nil
|
||||
}
|
||||
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
|
||||
// write cleartext auth packet
|
||||
return append([]byte(mc.cfg.Passwd), 0), nil
|
||||
}
|
||||
|
||||
pubKey := mc.cfg.pubKey
|
||||
if pubKey == nil {
|
||||
// request public key from server
|
||||
return []byte{1}, nil
|
||||
}
|
||||
|
||||
// encrypted password
|
||||
enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
|
||||
return enc, err
|
||||
|
||||
default:
|
||||
errLog.Print("unknown auth plugin:", plugin)
|
||||
return nil, ErrUnknownPlugin
|
||||
}
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
|
||||
// Read Result Packet
|
||||
authData, newPlugin, err := mc.readAuthResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// handle auth plugin switch, if requested
|
||||
if newPlugin != "" {
|
||||
// If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is
|
||||
// sent and we have to keep using the cipher sent in the init packet.
|
||||
if authData == nil {
|
||||
authData = oldAuthData
|
||||
} else {
|
||||
// copy data from read buffer to owned slice
|
||||
copy(oldAuthData, authData)
|
||||
}
|
||||
|
||||
plugin = newPlugin
|
||||
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = mc.writeAuthSwitchPacket(authResp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read Result Packet
|
||||
authData, newPlugin, err = mc.readAuthResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Do not allow to change the auth plugin more than once
|
||||
if newPlugin != "" {
|
||||
return ErrMalformPkt
|
||||
}
|
||||
}
|
||||
|
||||
switch plugin {
|
||||
|
||||
// https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/
|
||||
case "caching_sha2_password":
|
||||
switch len(authData) {
|
||||
case 0:
|
||||
return nil // auth successful
|
||||
case 1:
|
||||
switch authData[0] {
|
||||
case cachingSha2PasswordFastAuthSuccess:
|
||||
if err = mc.readResultOK(); err == nil {
|
||||
return nil // auth successful
|
||||
}
|
||||
|
||||
case cachingSha2PasswordPerformFullAuthentication:
|
||||
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
|
||||
// write cleartext auth packet
|
||||
err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
pubKey := mc.cfg.pubKey
|
||||
if pubKey == nil {
|
||||
// request public key from server
|
||||
data, err := mc.buf.takeSmallBuffer(4 + 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data[4] = cachingSha2PasswordRequestPublicKey
|
||||
mc.writePacket(data)
|
||||
|
||||
// parse public key
|
||||
if data, err = mc.readPacket(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
block, rest := pem.Decode(data[1:])
|
||||
if block == nil {
|
||||
return fmt.Errorf("No Pem data found, data: %s", rest)
|
||||
}
|
||||
pkix, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pubKey = pkix.(*rsa.PublicKey)
|
||||
}
|
||||
|
||||
// send encrypted password
|
||||
err = mc.sendEncryptedPassword(oldAuthData, pubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return mc.readResultOK()
|
||||
|
||||
default:
|
||||
return ErrMalformPkt
|
||||
}
|
||||
default:
|
||||
return ErrMalformPkt
|
||||
}
|
||||
|
||||
case "sha256_password":
|
||||
switch len(authData) {
|
||||
case 0:
|
||||
return nil // auth successful
|
||||
default:
|
||||
block, _ := pem.Decode(authData)
|
||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// send encrypted password
|
||||
err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return mc.readResultOK()
|
||||
}
|
||||
|
||||
default:
|
||||
return nil // auth successful
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
@ -0,0 +1,182 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultBufSize = 4096
|
||||
const maxCachedBufSize = 256 * 1024
|
||||
|
||||
// A buffer which is used for both reading and writing.
|
||||
// This is possible since communication on each connection is synchronous.
|
||||
// In other words, we can't write and read simultaneously on the same connection.
|
||||
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
|
||||
// Also highly optimized for this particular use case.
|
||||
// This buffer is backed by two byte slices in a double-buffering scheme
|
||||
type buffer struct {
|
||||
buf []byte // buf is a byte buffer who's length and capacity are equal.
|
||||
nc net.Conn
|
||||
idx int
|
||||
length int
|
||||
timeout time.Duration
|
||||
dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer
|
||||
flipcnt uint // flipccnt is the current buffer counter for double-buffering
|
||||
}
|
||||
|
||||
// newBuffer allocates and returns a new buffer.
|
||||
func newBuffer(nc net.Conn) buffer {
|
||||
fg := make([]byte, defaultBufSize)
|
||||
return buffer{
|
||||
buf: fg,
|
||||
nc: nc,
|
||||
dbuf: [2][]byte{fg, nil},
|
||||
}
|
||||
}
|
||||
|
||||
// flip replaces the active buffer with the background buffer
|
||||
// this is a delayed flip that simply increases the buffer counter;
|
||||
// the actual flip will be performed the next time we call `buffer.fill`
|
||||
func (b *buffer) flip() {
|
||||
b.flipcnt += 1
|
||||
}
|
||||
|
||||
// fill reads into the buffer until at least _need_ bytes are in it
|
||||
func (b *buffer) fill(need int) error {
|
||||
n := b.length
|
||||
// fill data into its double-buffering target: if we've called
|
||||
// flip on this buffer, we'll be copying to the background buffer,
|
||||
// and then filling it with network data; otherwise we'll just move
|
||||
// the contents of the current buffer to the front before filling it
|
||||
dest := b.dbuf[b.flipcnt&1]
|
||||
|
||||
// grow buffer if necessary to fit the whole packet.
|
||||
if need > len(dest) {
|
||||
// Round up to the next multiple of the default size
|
||||
dest = make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
|
||||
|
||||
// if the allocated buffer is not too large, move it to backing storage
|
||||
// to prevent extra allocations on applications that perform large reads
|
||||
if len(dest) <= maxCachedBufSize {
|
||||
b.dbuf[b.flipcnt&1] = dest
|
||||
}
|
||||
}
|
||||
|
||||
// if we're filling the fg buffer, move the existing data to the start of it.
|
||||
// if we're filling the bg buffer, copy over the data
|
||||
if n > 0 {
|
||||
copy(dest[:n], b.buf[b.idx:])
|
||||
}
|
||||
|
||||
b.buf = dest
|
||||
b.idx = 0
|
||||
|
||||
for {
|
||||
if b.timeout > 0 {
|
||||
if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
nn, err := b.nc.Read(b.buf[n:])
|
||||
n += nn
|
||||
|
||||
switch err {
|
||||
case nil:
|
||||
if n < need {
|
||||
continue
|
||||
}
|
||||
b.length = n
|
||||
return nil
|
||||
|
||||
case io.EOF:
|
||||
if n >= need {
|
||||
b.length = n
|
||||
return nil
|
||||
}
|
||||
return io.ErrUnexpectedEOF
|
||||
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// returns next N bytes from buffer.
|
||||
// The returned slice is only guaranteed to be valid until the next read
|
||||
func (b *buffer) readNext(need int) ([]byte, error) {
|
||||
if b.length < need {
|
||||
// refill
|
||||
if err := b.fill(need); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
offset := b.idx
|
||||
b.idx += need
|
||||
b.length -= need
|
||||
return b.buf[offset:b.idx], nil
|
||||
}
|
||||
|
||||
// takeBuffer returns a buffer with the requested size.
|
||||
// If possible, a slice from the existing buffer is returned.
|
||||
// Otherwise a bigger buffer is made.
|
||||
// Only one buffer (total) can be used at a time.
|
||||
func (b *buffer) takeBuffer(length int) ([]byte, error) {
|
||||
if b.length > 0 {
|
||||
return nil, ErrBusyBuffer
|
||||
}
|
||||
|
||||
// test (cheap) general case first
|
||||
if length <= cap(b.buf) {
|
||||
return b.buf[:length], nil
|
||||
}
|
||||
|
||||
if length < maxPacketSize {
|
||||
b.buf = make([]byte, length)
|
||||
return b.buf, nil
|
||||
}
|
||||
|
||||
// buffer is larger than we want to store.
|
||||
return make([]byte, length), nil
|
||||
}
|
||||
|
||||
// takeSmallBuffer is shortcut which can be used if length is
|
||||
// known to be smaller than defaultBufSize.
|
||||
// Only one buffer (total) can be used at a time.
|
||||
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
|
||||
if b.length > 0 {
|
||||
return nil, ErrBusyBuffer
|
||||
}
|
||||
return b.buf[:length], nil
|
||||
}
|
||||
|
||||
// takeCompleteBuffer returns the complete existing buffer.
|
||||
// This can be used if the necessary buffer size is unknown.
|
||||
// cap and len of the returned buffer will be equal.
|
||||
// Only one buffer (total) can be used at a time.
|
||||
func (b *buffer) takeCompleteBuffer() ([]byte, error) {
|
||||
if b.length > 0 {
|
||||
return nil, ErrBusyBuffer
|
||||
}
|
||||
return b.buf, nil
|
||||
}
|
||||
|
||||
// store stores buf, an updated buffer, if its suitable to do so.
|
||||
func (b *buffer) store(buf []byte) error {
|
||||
if b.length > 0 {
|
||||
return ErrBusyBuffer
|
||||
} else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
|
||||
b.buf = buf[:cap(buf)]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,265 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2014 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
const defaultCollation = "utf8mb4_general_ci"
|
||||
const binaryCollation = "binary"
|
||||
|
||||
// A list of available collations mapped to the internal ID.
|
||||
// To update this map use the following MySQL query:
|
||||
// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID<256 ORDER BY ID
|
||||
//
|
||||
// Handshake packet have only 1 byte for collation_id. So we can't use collations with ID > 255.
|
||||
//
|
||||
// ucs2, utf16, and utf32 can't be used for connection charset.
|
||||
// https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset
|
||||
// They are commented out to reduce this map.
|
||||
var collations = map[string]byte{
|
||||
"big5_chinese_ci": 1,
|
||||
"latin2_czech_cs": 2,
|
||||
"dec8_swedish_ci": 3,
|
||||
"cp850_general_ci": 4,
|
||||
"latin1_german1_ci": 5,
|
||||
"hp8_english_ci": 6,
|
||||
"koi8r_general_ci": 7,
|
||||
"latin1_swedish_ci": 8,
|
||||
"latin2_general_ci": 9,
|
||||
"swe7_swedish_ci": 10,
|
||||
"ascii_general_ci": 11,
|
||||
"ujis_japanese_ci": 12,
|
||||
"sjis_japanese_ci": 13,
|
||||
"cp1251_bulgarian_ci": 14,
|
||||
"latin1_danish_ci": 15,
|
||||
"hebrew_general_ci": 16,
|
||||
"tis620_thai_ci": 18,
|
||||
"euckr_korean_ci": 19,
|
||||
"latin7_estonian_cs": 20,
|
||||
"latin2_hungarian_ci": 21,
|
||||
"koi8u_general_ci": 22,
|
||||
"cp1251_ukrainian_ci": 23,
|
||||
"gb2312_chinese_ci": 24,
|
||||
"greek_general_ci": 25,
|
||||
"cp1250_general_ci": 26,
|
||||
"latin2_croatian_ci": 27,
|
||||
"gbk_chinese_ci": 28,
|
||||
"cp1257_lithuanian_ci": 29,
|
||||
"latin5_turkish_ci": 30,
|
||||
"latin1_german2_ci": 31,
|
||||
"armscii8_general_ci": 32,
|
||||
"utf8_general_ci": 33,
|
||||
"cp1250_czech_cs": 34,
|
||||
//"ucs2_general_ci": 35,
|
||||
"cp866_general_ci": 36,
|
||||
"keybcs2_general_ci": 37,
|
||||
"macce_general_ci": 38,
|
||||
"macroman_general_ci": 39,
|
||||
"cp852_general_ci": 40,
|
||||
"latin7_general_ci": 41,
|
||||
"latin7_general_cs": 42,
|
||||
"macce_bin": 43,
|
||||
"cp1250_croatian_ci": 44,
|
||||
"utf8mb4_general_ci": 45,
|
||||
"utf8mb4_bin": 46,
|
||||
"latin1_bin": 47,
|
||||
"latin1_general_ci": 48,
|
||||
"latin1_general_cs": 49,
|
||||
"cp1251_bin": 50,
|
||||
"cp1251_general_ci": 51,
|
||||
"cp1251_general_cs": 52,
|
||||
"macroman_bin": 53,
|
||||
//"utf16_general_ci": 54,
|
||||
//"utf16_bin": 55,
|
||||
//"utf16le_general_ci": 56,
|
||||
"cp1256_general_ci": 57,
|
||||
"cp1257_bin": 58,
|
||||
"cp1257_general_ci": 59,
|
||||
//"utf32_general_ci": 60,
|
||||
//"utf32_bin": 61,
|
||||
//"utf16le_bin": 62,
|
||||
"binary": 63,
|
||||
"armscii8_bin": 64,
|
||||
"ascii_bin": 65,
|
||||
"cp1250_bin": 66,
|
||||
"cp1256_bin": 67,
|
||||
"cp866_bin": 68,
|
||||
"dec8_bin": 69,
|
||||
"greek_bin": 70,
|
||||
"hebrew_bin": 71,
|
||||
"hp8_bin": 72,
|
||||
"keybcs2_bin": 73,
|
||||
"koi8r_bin": 74,
|
||||
"koi8u_bin": 75,
|
||||
"utf8_tolower_ci": 76,
|
||||
"latin2_bin": 77,
|
||||
"latin5_bin": 78,
|
||||
"latin7_bin": 79,
|
||||
"cp850_bin": 80,
|
||||
"cp852_bin": 81,
|
||||
"swe7_bin": 82,
|
||||
"utf8_bin": 83,
|
||||
"big5_bin": 84,
|
||||
"euckr_bin": 85,
|
||||
"gb2312_bin": 86,
|
||||
"gbk_bin": 87,
|
||||
"sjis_bin": 88,
|
||||
"tis620_bin": 89,
|
||||
//"ucs2_bin": 90,
|
||||
"ujis_bin": 91,
|
||||
"geostd8_general_ci": 92,
|
||||
"geostd8_bin": 93,
|
||||
"latin1_spanish_ci": 94,
|
||||
"cp932_japanese_ci": 95,
|
||||
"cp932_bin": 96,
|
||||
"eucjpms_japanese_ci": 97,
|
||||
"eucjpms_bin": 98,
|
||||
"cp1250_polish_ci": 99,
|
||||
//"utf16_unicode_ci": 101,
|
||||
//"utf16_icelandic_ci": 102,
|
||||
//"utf16_latvian_ci": 103,
|
||||
//"utf16_romanian_ci": 104,
|
||||
//"utf16_slovenian_ci": 105,
|
||||
//"utf16_polish_ci": 106,
|
||||
//"utf16_estonian_ci": 107,
|
||||
//"utf16_spanish_ci": 108,
|
||||
//"utf16_swedish_ci": 109,
|
||||
//"utf16_turkish_ci": 110,
|
||||
//"utf16_czech_ci": 111,
|
||||
//"utf16_danish_ci": 112,
|
||||
//"utf16_lithuanian_ci": 113,
|
||||
//"utf16_slovak_ci": 114,
|
||||
//"utf16_spanish2_ci": 115,
|
||||
//"utf16_roman_ci": 116,
|
||||
//"utf16_persian_ci": 117,
|
||||
//"utf16_esperanto_ci": 118,
|
||||
//"utf16_hungarian_ci": 119,
|
||||
//"utf16_sinhala_ci": 120,
|
||||
//"utf16_german2_ci": 121,
|
||||
//"utf16_croatian_ci": 122,
|
||||
//"utf16_unicode_520_ci": 123,
|
||||
//"utf16_vietnamese_ci": 124,
|
||||
//"ucs2_unicode_ci": 128,
|
||||
//"ucs2_icelandic_ci": 129,
|
||||
//"ucs2_latvian_ci": 130,
|
||||
//"ucs2_romanian_ci": 131,
|
||||
//"ucs2_slovenian_ci": 132,
|
||||
//"ucs2_polish_ci": 133,
|
||||
//"ucs2_estonian_ci": 134,
|
||||
//"ucs2_spanish_ci": 135,
|
||||
//"ucs2_swedish_ci": 136,
|
||||
//"ucs2_turkish_ci": 137,
|
||||
//"ucs2_czech_ci": 138,
|
||||
//"ucs2_danish_ci": 139,
|
||||
//"ucs2_lithuanian_ci": 140,
|
||||
//"ucs2_slovak_ci": 141,
|
||||
//"ucs2_spanish2_ci": 142,
|
||||
//"ucs2_roman_ci": 143,
|
||||
//"ucs2_persian_ci": 144,
|
||||
//"ucs2_esperanto_ci": 145,
|
||||
//"ucs2_hungarian_ci": 146,
|
||||
//"ucs2_sinhala_ci": 147,
|
||||
//"ucs2_german2_ci": 148,
|
||||
//"ucs2_croatian_ci": 149,
|
||||
//"ucs2_unicode_520_ci": 150,
|
||||
//"ucs2_vietnamese_ci": 151,
|
||||
//"ucs2_general_mysql500_ci": 159,
|
||||
//"utf32_unicode_ci": 160,
|
||||
//"utf32_icelandic_ci": 161,
|
||||
//"utf32_latvian_ci": 162,
|
||||
//"utf32_romanian_ci": 163,
|
||||
//"utf32_slovenian_ci": 164,
|
||||
//"utf32_polish_ci": 165,
|
||||
//"utf32_estonian_ci": 166,
|
||||
//"utf32_spanish_ci": 167,
|
||||
//"utf32_swedish_ci": 168,
|
||||
//"utf32_turkish_ci": 169,
|
||||
//"utf32_czech_ci": 170,
|
||||
//"utf32_danish_ci": 171,
|
||||
//"utf32_lithuanian_ci": 172,
|
||||
//"utf32_slovak_ci": 173,
|
||||
//"utf32_spanish2_ci": 174,
|
||||
//"utf32_roman_ci": 175,
|
||||
//"utf32_persian_ci": 176,
|
||||
//"utf32_esperanto_ci": 177,
|
||||
//"utf32_hungarian_ci": 178,
|
||||
//"utf32_sinhala_ci": 179,
|
||||
//"utf32_german2_ci": 180,
|
||||
//"utf32_croatian_ci": 181,
|
||||
//"utf32_unicode_520_ci": 182,
|
||||
//"utf32_vietnamese_ci": 183,
|
||||
"utf8_unicode_ci": 192,
|
||||
"utf8_icelandic_ci": 193,
|
||||
"utf8_latvian_ci": 194,
|
||||
"utf8_romanian_ci": 195,
|
||||
"utf8_slovenian_ci": 196,
|
||||
"utf8_polish_ci": 197,
|
||||
"utf8_estonian_ci": 198,
|
||||
"utf8_spanish_ci": 199,
|
||||
"utf8_swedish_ci": 200,
|
||||
"utf8_turkish_ci": 201,
|
||||
"utf8_czech_ci": 202,
|
||||
"utf8_danish_ci": 203,
|
||||
"utf8_lithuanian_ci": 204,
|
||||
"utf8_slovak_ci": 205,
|
||||
"utf8_spanish2_ci": 206,
|
||||
"utf8_roman_ci": 207,
|
||||
"utf8_persian_ci": 208,
|
||||
"utf8_esperanto_ci": 209,
|
||||
"utf8_hungarian_ci": 210,
|
||||
"utf8_sinhala_ci": 211,
|
||||
"utf8_german2_ci": 212,
|
||||
"utf8_croatian_ci": 213,
|
||||
"utf8_unicode_520_ci": 214,
|
||||
"utf8_vietnamese_ci": 215,
|
||||
"utf8_general_mysql500_ci": 223,
|
||||
"utf8mb4_unicode_ci": 224,
|
||||
"utf8mb4_icelandic_ci": 225,
|
||||
"utf8mb4_latvian_ci": 226,
|
||||
"utf8mb4_romanian_ci": 227,
|
||||
"utf8mb4_slovenian_ci": 228,
|
||||
"utf8mb4_polish_ci": 229,
|
||||
"utf8mb4_estonian_ci": 230,
|
||||
"utf8mb4_spanish_ci": 231,
|
||||
"utf8mb4_swedish_ci": 232,
|
||||
"utf8mb4_turkish_ci": 233,
|
||||
"utf8mb4_czech_ci": 234,
|
||||
"utf8mb4_danish_ci": 235,
|
||||
"utf8mb4_lithuanian_ci": 236,
|
||||
"utf8mb4_slovak_ci": 237,
|
||||
"utf8mb4_spanish2_ci": 238,
|
||||
"utf8mb4_roman_ci": 239,
|
||||
"utf8mb4_persian_ci": 240,
|
||||
"utf8mb4_esperanto_ci": 241,
|
||||
"utf8mb4_hungarian_ci": 242,
|
||||
"utf8mb4_sinhala_ci": 243,
|
||||
"utf8mb4_german2_ci": 244,
|
||||
"utf8mb4_croatian_ci": 245,
|
||||
"utf8mb4_unicode_520_ci": 246,
|
||||
"utf8mb4_vietnamese_ci": 247,
|
||||
"gb18030_chinese_ci": 248,
|
||||
"gb18030_bin": 249,
|
||||
"gb18030_unicode_520_ci": 250,
|
||||
"utf8mb4_0900_ai_ci": 255,
|
||||
}
|
||||
|
||||
// A denylist of collations which is unsafe to interpolate parameters.
|
||||
// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes.
|
||||
var unsafeCollations = map[string]bool{
|
||||
"big5_chinese_ci": true,
|
||||
"sjis_japanese_ci": true,
|
||||
"gbk_chinese_ci": true,
|
||||
"big5_bin": true,
|
||||
"gb2312_bin": true,
|
||||
"gbk_bin": true,
|
||||
"sjis_bin": true,
|
||||
"cp932_japanese_ci": true,
|
||||
"cp932_bin": true,
|
||||
"gb18030_chinese_ci": true,
|
||||
"gb18030_bin": true,
|
||||
"gb18030_unicode_520_ci": true,
|
||||
}
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
var errUnexpectedRead = errors.New("unexpected read from socket")
|
||||
|
||||
func connCheck(conn net.Conn) error {
|
||||
var sysErr error
|
||||
|
||||
sysConn, ok := conn.(syscall.Conn)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
rawConn, err := sysConn.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = rawConn.Read(func(fd uintptr) bool {
|
||||
var buf [1]byte
|
||||
n, err := syscall.Read(int(fd), buf[:])
|
||||
switch {
|
||||
case n == 0 && err == nil:
|
||||
sysErr = io.EOF
|
||||
case n > 0:
|
||||
sysErr = errUnexpectedRead
|
||||
case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK:
|
||||
sysErr = nil
|
||||
default:
|
||||
sysErr = err
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return sysErr
|
||||
}
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos
|
||||
|
||||
package mysql
|
||||
|
||||
import "net"
|
||||
|
||||
func connCheck(conn net.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,650 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mysqlConn struct {
|
||||
buf buffer
|
||||
netConn net.Conn
|
||||
rawConn net.Conn // underlying connection when netConn is TLS connection.
|
||||
affectedRows uint64
|
||||
insertId uint64
|
||||
cfg *Config
|
||||
maxAllowedPacket int
|
||||
maxWriteSize int
|
||||
writeTimeout time.Duration
|
||||
flags clientFlag
|
||||
status statusFlag
|
||||
sequence uint8
|
||||
parseTime bool
|
||||
reset bool // set when the Go SQL package calls ResetSession
|
||||
|
||||
// for context support (Go 1.8+)
|
||||
watching bool
|
||||
watcher chan<- context.Context
|
||||
closech chan struct{}
|
||||
finished chan<- struct{}
|
||||
canceled atomicError // set non-nil if conn is canceled
|
||||
closed atomicBool // set when conn is closed, before closech is closed
|
||||
}
|
||||
|
||||
// Handles parameters set in DSN after the connection is established
|
||||
func (mc *mysqlConn) handleParams() (err error) {
|
||||
var cmdSet strings.Builder
|
||||
for param, val := range mc.cfg.Params {
|
||||
switch param {
|
||||
// Charset: character_set_connection, character_set_client, character_set_results
|
||||
case "charset":
|
||||
charsets := strings.Split(val, ",")
|
||||
for i := range charsets {
|
||||
// ignore errors here - a charset may not exist
|
||||
err = mc.exec("SET NAMES " + charsets[i])
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Other system vars accumulated in a single SET command
|
||||
default:
|
||||
if cmdSet.Len() == 0 {
|
||||
// Heuristic: 29 chars for each other key=value to reduce reallocations
|
||||
cmdSet.Grow(4 + len(param) + 1 + len(val) + 30*(len(mc.cfg.Params)-1))
|
||||
cmdSet.WriteString("SET ")
|
||||
} else {
|
||||
cmdSet.WriteByte(',')
|
||||
}
|
||||
cmdSet.WriteString(param)
|
||||
cmdSet.WriteByte('=')
|
||||
cmdSet.WriteString(val)
|
||||
}
|
||||
}
|
||||
|
||||
if cmdSet.Len() > 0 {
|
||||
err = mc.exec(cmdSet.String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) markBadConn(err error) error {
|
||||
if mc == nil {
|
||||
return err
|
||||
}
|
||||
if err != errBadConnNoWrite {
|
||||
return err
|
||||
}
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Begin() (driver.Tx, error) {
|
||||
return mc.begin(false)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
|
||||
if mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
var q string
|
||||
if readOnly {
|
||||
q = "START TRANSACTION READ ONLY"
|
||||
} else {
|
||||
q = "START TRANSACTION"
|
||||
}
|
||||
err := mc.exec(q)
|
||||
if err == nil {
|
||||
return &mysqlTx{mc}, err
|
||||
}
|
||||
return nil, mc.markBadConn(err)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Close() (err error) {
|
||||
// Makes Close idempotent
|
||||
if !mc.closed.IsSet() {
|
||||
err = mc.writeCommandPacket(comQuit)
|
||||
}
|
||||
|
||||
mc.cleanup()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Closes the network connection and unsets internal variables. Do not call this
|
||||
// function after successfully authentication, call Close instead. This function
|
||||
// is called before auth or on auth failure because MySQL will have already
|
||||
// closed the network connection.
|
||||
func (mc *mysqlConn) cleanup() {
|
||||
if !mc.closed.TrySet(true) {
|
||||
return
|
||||
}
|
||||
|
||||
// Makes cleanup idempotent
|
||||
close(mc.closech)
|
||||
if mc.netConn == nil {
|
||||
return
|
||||
}
|
||||
if err := mc.netConn.Close(); err != nil {
|
||||
errLog.Print(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) error() error {
|
||||
if mc.closed.IsSet() {
|
||||
if err := mc.canceled.Value(); err != nil {
|
||||
return err
|
||||
}
|
||||
return ErrInvalidConn
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
|
||||
if mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
// Send command
|
||||
err := mc.writeCommandPacketStr(comStmtPrepare, query)
|
||||
if err != nil {
|
||||
// STMT_PREPARE is safe to retry. So we can return ErrBadConn here.
|
||||
errLog.Print(err)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
stmt := &mysqlStmt{
|
||||
mc: mc,
|
||||
}
|
||||
|
||||
// Read Result
|
||||
columnCount, err := stmt.readPrepareResultPacket()
|
||||
if err == nil {
|
||||
if stmt.paramCount > 0 {
|
||||
if err = mc.readUntilEOF(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if columnCount > 0 {
|
||||
err = mc.readUntilEOF()
|
||||
}
|
||||
}
|
||||
|
||||
return stmt, err
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
|
||||
// Number of ? should be same to len(args)
|
||||
if strings.Count(query, "?") != len(args) {
|
||||
return "", driver.ErrSkip
|
||||
}
|
||||
|
||||
buf, err := mc.buf.takeCompleteBuffer()
|
||||
if err != nil {
|
||||
// can not take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(err)
|
||||
return "", ErrInvalidConn
|
||||
}
|
||||
buf = buf[:0]
|
||||
argPos := 0
|
||||
|
||||
for i := 0; i < len(query); i++ {
|
||||
q := strings.IndexByte(query[i:], '?')
|
||||
if q == -1 {
|
||||
buf = append(buf, query[i:]...)
|
||||
break
|
||||
}
|
||||
buf = append(buf, query[i:i+q]...)
|
||||
i += q
|
||||
|
||||
arg := args[argPos]
|
||||
argPos++
|
||||
|
||||
if arg == nil {
|
||||
buf = append(buf, "NULL"...)
|
||||
continue
|
||||
}
|
||||
|
||||
switch v := arg.(type) {
|
||||
case int64:
|
||||
buf = strconv.AppendInt(buf, v, 10)
|
||||
case uint64:
|
||||
// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
|
||||
buf = strconv.AppendUint(buf, v, 10)
|
||||
case float64:
|
||||
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
|
||||
case bool:
|
||||
if v {
|
||||
buf = append(buf, '1')
|
||||
} else {
|
||||
buf = append(buf, '0')
|
||||
}
|
||||
case time.Time:
|
||||
if v.IsZero() {
|
||||
buf = append(buf, "'0000-00-00'"...)
|
||||
} else {
|
||||
buf = append(buf, '\'')
|
||||
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
buf = append(buf, '\'')
|
||||
}
|
||||
case json.RawMessage:
|
||||
buf = append(buf, '\'')
|
||||
if mc.status&statusNoBackslashEscapes == 0 {
|
||||
buf = escapeBytesBackslash(buf, v)
|
||||
} else {
|
||||
buf = escapeBytesQuotes(buf, v)
|
||||
}
|
||||
buf = append(buf, '\'')
|
||||
case []byte:
|
||||
if v == nil {
|
||||
buf = append(buf, "NULL"...)
|
||||
} else {
|
||||
buf = append(buf, "_binary'"...)
|
||||
if mc.status&statusNoBackslashEscapes == 0 {
|
||||
buf = escapeBytesBackslash(buf, v)
|
||||
} else {
|
||||
buf = escapeBytesQuotes(buf, v)
|
||||
}
|
||||
buf = append(buf, '\'')
|
||||
}
|
||||
case string:
|
||||
buf = append(buf, '\'')
|
||||
if mc.status&statusNoBackslashEscapes == 0 {
|
||||
buf = escapeStringBackslash(buf, v)
|
||||
} else {
|
||||
buf = escapeStringQuotes(buf, v)
|
||||
}
|
||||
buf = append(buf, '\'')
|
||||
default:
|
||||
return "", driver.ErrSkip
|
||||
}
|
||||
|
||||
if len(buf)+4 > mc.maxAllowedPacket {
|
||||
return "", driver.ErrSkip
|
||||
}
|
||||
}
|
||||
if argPos != len(args) {
|
||||
return "", driver.ErrSkip
|
||||
}
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
|
||||
if mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
if len(args) != 0 {
|
||||
if !mc.cfg.InterpolateParams {
|
||||
return nil, driver.ErrSkip
|
||||
}
|
||||
// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
|
||||
prepared, err := mc.interpolateParams(query, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
query = prepared
|
||||
}
|
||||
mc.affectedRows = 0
|
||||
mc.insertId = 0
|
||||
|
||||
err := mc.exec(query)
|
||||
if err == nil {
|
||||
return &mysqlResult{
|
||||
affectedRows: int64(mc.affectedRows),
|
||||
insertId: int64(mc.insertId),
|
||||
}, err
|
||||
}
|
||||
return nil, mc.markBadConn(err)
|
||||
}
|
||||
|
||||
// Internal function to execute commands
|
||||
func (mc *mysqlConn) exec(query string) error {
|
||||
// Send command
|
||||
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
|
||||
return mc.markBadConn(err)
|
||||
}
|
||||
|
||||
// Read Result
|
||||
resLen, err := mc.readResultSetHeaderPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resLen > 0 {
|
||||
// columns
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// rows
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return mc.discardResults()
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
|
||||
return mc.query(query, args)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
|
||||
if mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
if len(args) != 0 {
|
||||
if !mc.cfg.InterpolateParams {
|
||||
return nil, driver.ErrSkip
|
||||
}
|
||||
// try client-side prepare to reduce roundtrip
|
||||
prepared, err := mc.interpolateParams(query, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
query = prepared
|
||||
}
|
||||
// Send command
|
||||
err := mc.writeCommandPacketStr(comQuery, query)
|
||||
if err == nil {
|
||||
// Read Result
|
||||
var resLen int
|
||||
resLen, err = mc.readResultSetHeaderPacket()
|
||||
if err == nil {
|
||||
rows := new(textRows)
|
||||
rows.mc = mc
|
||||
|
||||
if resLen == 0 {
|
||||
rows.rs.done = true
|
||||
|
||||
switch err := rows.NextResultSet(); err {
|
||||
case nil, io.EOF:
|
||||
return rows, nil
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Columns
|
||||
rows.rs.columns, err = mc.readColumns(resLen)
|
||||
return rows, err
|
||||
}
|
||||
}
|
||||
return nil, mc.markBadConn(err)
|
||||
}
|
||||
|
||||
// Gets the value of the given MySQL System Variable
|
||||
// The returned byte slice is only valid until the next read
|
||||
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
|
||||
// Send command
|
||||
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read Result
|
||||
resLen, err := mc.readResultSetHeaderPacket()
|
||||
if err == nil {
|
||||
rows := new(textRows)
|
||||
rows.mc = mc
|
||||
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
|
||||
|
||||
if resLen > 0 {
|
||||
// Columns
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
dest := make([]driver.Value, resLen)
|
||||
if err = rows.readRow(dest); err == nil {
|
||||
return dest[0].([]byte), mc.readUntilEOF()
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// finish is called when the query has canceled.
|
||||
func (mc *mysqlConn) cancel(err error) {
|
||||
mc.canceled.Set(err)
|
||||
mc.cleanup()
|
||||
}
|
||||
|
||||
// finish is called when the query has succeeded.
|
||||
func (mc *mysqlConn) finish() {
|
||||
if !mc.watching || mc.finished == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case mc.finished <- struct{}{}:
|
||||
mc.watching = false
|
||||
case <-mc.closech:
|
||||
}
|
||||
}
|
||||
|
||||
// Ping implements driver.Pinger interface
|
||||
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
|
||||
if mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
if err = mc.watchCancel(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
defer mc.finish()
|
||||
|
||||
if err = mc.writeCommandPacket(comPing); err != nil {
|
||||
return mc.markBadConn(err)
|
||||
}
|
||||
|
||||
return mc.readResultOK()
|
||||
}
|
||||
|
||||
// BeginTx implements driver.ConnBeginTx interface
|
||||
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
if mc.closed.IsSet() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
if err := mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mc.finish()
|
||||
|
||||
if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
|
||||
level, err := mapIsolationLevel(opts.Isolation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return mc.begin(opts.ReadOnly)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||
dargs, err := namedValueToValue(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := mc.query(query, dargs)
|
||||
if err != nil {
|
||||
mc.finish()
|
||||
return nil, err
|
||||
}
|
||||
rows.finish = mc.finish
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
dargs, err := namedValueToValue(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mc.finish()
|
||||
|
||||
return mc.Exec(query, dargs)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||
if err := mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stmt, err := mc.Prepare(query)
|
||||
mc.finish()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
stmt.Close()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
dargs, err := namedValueToValue(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := stmt.mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := stmt.query(dargs)
|
||||
if err != nil {
|
||||
stmt.mc.finish()
|
||||
return nil, err
|
||||
}
|
||||
rows.finish = stmt.mc.finish
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
dargs, err := namedValueToValue(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := stmt.mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.mc.finish()
|
||||
|
||||
return stmt.Exec(dargs)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) watchCancel(ctx context.Context) error {
|
||||
if mc.watching {
|
||||
// Reach here if canceled,
|
||||
// so the connection is already invalid
|
||||
mc.cleanup()
|
||||
return nil
|
||||
}
|
||||
// When ctx is already cancelled, don't watch it.
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
// When ctx is not cancellable, don't watch it.
|
||||
if ctx.Done() == nil {
|
||||
return nil
|
||||
}
|
||||
// When watcher is not alive, can't watch it.
|
||||
if mc.watcher == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mc.watching = true
|
||||
mc.watcher <- ctx
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) startWatcher() {
|
||||
watcher := make(chan context.Context, 1)
|
||||
mc.watcher = watcher
|
||||
finished := make(chan struct{})
|
||||
mc.finished = finished
|
||||
go func() {
|
||||
for {
|
||||
var ctx context.Context
|
||||
select {
|
||||
case ctx = <-watcher:
|
||||
case <-mc.closech:
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
mc.cancel(ctx.Err())
|
||||
case <-finished:
|
||||
case <-mc.closech:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
|
||||
nv.Value, err = converter{}.ConvertValue(nv.Value)
|
||||
return
|
||||
}
|
||||
|
||||
// ResetSession implements driver.SessionResetter.
|
||||
// (From Go 1.10)
|
||||
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
|
||||
if mc.closed.IsSet() {
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
mc.reset = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsValid implements driver.Validator interface
|
||||
// (From Go 1.15)
|
||||
func (mc *mysqlConn) IsValid() bool {
|
||||
return !mc.closed.IsSet()
|
||||
}
|
||||
|
|
@ -0,0 +1,146 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"net"
|
||||
)
|
||||
|
||||
type connector struct {
|
||||
cfg *Config // immutable private copy.
|
||||
}
|
||||
|
||||
// Connect implements driver.Connector interface.
|
||||
// Connect returns a connection to the database.
|
||||
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
var err error
|
||||
|
||||
// New mysqlConn
|
||||
mc := &mysqlConn{
|
||||
maxAllowedPacket: maxPacketSize,
|
||||
maxWriteSize: maxPacketSize - 1,
|
||||
closech: make(chan struct{}),
|
||||
cfg: c.cfg,
|
||||
}
|
||||
mc.parseTime = mc.cfg.ParseTime
|
||||
|
||||
// Connect to Server
|
||||
dialsLock.RLock()
|
||||
dial, ok := dials[mc.cfg.Net]
|
||||
dialsLock.RUnlock()
|
||||
if ok {
|
||||
dctx := ctx
|
||||
if mc.cfg.Timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
mc.netConn, err = dial(dctx, mc.cfg.Addr)
|
||||
} else {
|
||||
nd := net.Dialer{Timeout: mc.cfg.Timeout}
|
||||
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Enable TCP Keepalives on TCP connections
|
||||
if tc, ok := mc.netConn.(*net.TCPConn); ok {
|
||||
if err := tc.SetKeepAlive(true); err != nil {
|
||||
// Don't send COM_QUIT before handshake.
|
||||
mc.netConn.Close()
|
||||
mc.netConn = nil
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Call startWatcher for context support (From Go 1.8)
|
||||
mc.startWatcher()
|
||||
if err := mc.watchCancel(ctx); err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
defer mc.finish()
|
||||
|
||||
mc.buf = newBuffer(mc.netConn)
|
||||
|
||||
// Set I/O timeouts
|
||||
mc.buf.timeout = mc.cfg.ReadTimeout
|
||||
mc.writeTimeout = mc.cfg.WriteTimeout
|
||||
|
||||
// Reading Handshake Initialization Packet
|
||||
authData, plugin, err := mc.readHandshakePacket()
|
||||
if err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if plugin == "" {
|
||||
plugin = defaultAuthPlugin
|
||||
}
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
// try the default auth plugin, if using the requested plugin failed
|
||||
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
|
||||
plugin = defaultAuthPlugin
|
||||
authResp, err = mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Handle response to auth packet, switch methods if possible
|
||||
if err = mc.handleAuthResult(authData, plugin); err != nil {
|
||||
// Authentication failed and MySQL has already closed the connection
|
||||
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
|
||||
// Do not send COM_QUIT, just cleanup and return the error.
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if mc.cfg.MaxAllowedPacket > 0 {
|
||||
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
|
||||
} else {
|
||||
// Get max allowed packet size
|
||||
maxap, err := mc.getSystemVar("max_allowed_packet")
|
||||
if err != nil {
|
||||
mc.Close()
|
||||
return nil, err
|
||||
}
|
||||
mc.maxAllowedPacket = stringToInt(maxap) - 1
|
||||
}
|
||||
if mc.maxAllowedPacket < maxPacketSize {
|
||||
mc.maxWriteSize = mc.maxAllowedPacket
|
||||
}
|
||||
|
||||
// Handle DSN Params
|
||||
err = mc.handleParams()
|
||||
if err != nil {
|
||||
mc.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mc, nil
|
||||
}
|
||||
|
||||
// Driver implements driver.Connector interface.
|
||||
// Driver returns &MySQLDriver{}.
|
||||
func (c *connector) Driver() driver.Driver {
|
||||
return &MySQLDriver{}
|
||||
}
|
||||
|
|
@ -0,0 +1,174 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
const (
|
||||
defaultAuthPlugin = "mysql_native_password"
|
||||
defaultMaxAllowedPacket = 4 << 20 // 4 MiB
|
||||
minProtocolVersion = 10
|
||||
maxPacketSize = 1<<24 - 1
|
||||
timeFormat = "2006-01-02 15:04:05.999999"
|
||||
)
|
||||
|
||||
// MySQL constants documentation:
|
||||
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
|
||||
|
||||
const (
|
||||
iOK byte = 0x00
|
||||
iAuthMoreData byte = 0x01
|
||||
iLocalInFile byte = 0xfb
|
||||
iEOF byte = 0xfe
|
||||
iERR byte = 0xff
|
||||
)
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
|
||||
type clientFlag uint32
|
||||
|
||||
const (
|
||||
clientLongPassword clientFlag = 1 << iota
|
||||
clientFoundRows
|
||||
clientLongFlag
|
||||
clientConnectWithDB
|
||||
clientNoSchema
|
||||
clientCompress
|
||||
clientODBC
|
||||
clientLocalFiles
|
||||
clientIgnoreSpace
|
||||
clientProtocol41
|
||||
clientInteractive
|
||||
clientSSL
|
||||
clientIgnoreSIGPIPE
|
||||
clientTransactions
|
||||
clientReserved
|
||||
clientSecureConn
|
||||
clientMultiStatements
|
||||
clientMultiResults
|
||||
clientPSMultiResults
|
||||
clientPluginAuth
|
||||
clientConnectAttrs
|
||||
clientPluginAuthLenEncClientData
|
||||
clientCanHandleExpiredPasswords
|
||||
clientSessionTrack
|
||||
clientDeprecateEOF
|
||||
)
|
||||
|
||||
const (
|
||||
comQuit byte = iota + 1
|
||||
comInitDB
|
||||
comQuery
|
||||
comFieldList
|
||||
comCreateDB
|
||||
comDropDB
|
||||
comRefresh
|
||||
comShutdown
|
||||
comStatistics
|
||||
comProcessInfo
|
||||
comConnect
|
||||
comProcessKill
|
||||
comDebug
|
||||
comPing
|
||||
comTime
|
||||
comDelayedInsert
|
||||
comChangeUser
|
||||
comBinlogDump
|
||||
comTableDump
|
||||
comConnectOut
|
||||
comRegisterSlave
|
||||
comStmtPrepare
|
||||
comStmtExecute
|
||||
comStmtSendLongData
|
||||
comStmtClose
|
||||
comStmtReset
|
||||
comSetOption
|
||||
comStmtFetch
|
||||
)
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType
|
||||
type fieldType byte
|
||||
|
||||
const (
|
||||
fieldTypeDecimal fieldType = iota
|
||||
fieldTypeTiny
|
||||
fieldTypeShort
|
||||
fieldTypeLong
|
||||
fieldTypeFloat
|
||||
fieldTypeDouble
|
||||
fieldTypeNULL
|
||||
fieldTypeTimestamp
|
||||
fieldTypeLongLong
|
||||
fieldTypeInt24
|
||||
fieldTypeDate
|
||||
fieldTypeTime
|
||||
fieldTypeDateTime
|
||||
fieldTypeYear
|
||||
fieldTypeNewDate
|
||||
fieldTypeVarChar
|
||||
fieldTypeBit
|
||||
)
|
||||
const (
|
||||
fieldTypeJSON fieldType = iota + 0xf5
|
||||
fieldTypeNewDecimal
|
||||
fieldTypeEnum
|
||||
fieldTypeSet
|
||||
fieldTypeTinyBLOB
|
||||
fieldTypeMediumBLOB
|
||||
fieldTypeLongBLOB
|
||||
fieldTypeBLOB
|
||||
fieldTypeVarString
|
||||
fieldTypeString
|
||||
fieldTypeGeometry
|
||||
)
|
||||
|
||||
type fieldFlag uint16
|
||||
|
||||
const (
|
||||
flagNotNULL fieldFlag = 1 << iota
|
||||
flagPriKey
|
||||
flagUniqueKey
|
||||
flagMultipleKey
|
||||
flagBLOB
|
||||
flagUnsigned
|
||||
flagZeroFill
|
||||
flagBinary
|
||||
flagEnum
|
||||
flagAutoIncrement
|
||||
flagTimestamp
|
||||
flagSet
|
||||
flagUnknown1
|
||||
flagUnknown2
|
||||
flagUnknown3
|
||||
flagUnknown4
|
||||
)
|
||||
|
||||
// http://dev.mysql.com/doc/internals/en/status-flags.html
|
||||
type statusFlag uint16
|
||||
|
||||
const (
|
||||
statusInTrans statusFlag = 1 << iota
|
||||
statusInAutocommit
|
||||
statusReserved // Not in documentation
|
||||
statusMoreResultsExists
|
||||
statusNoGoodIndexUsed
|
||||
statusNoIndexUsed
|
||||
statusCursorExists
|
||||
statusLastRowSent
|
||||
statusDbDropped
|
||||
statusNoBackslashEscapes
|
||||
statusMetadataChanged
|
||||
statusQueryWasSlow
|
||||
statusPsOutParams
|
||||
statusInTransReadonly
|
||||
statusSessionStateChanged
|
||||
)
|
||||
|
||||
const (
|
||||
cachingSha2PasswordRequestPublicKey = 2
|
||||
cachingSha2PasswordFastAuthSuccess = 3
|
||||
cachingSha2PasswordPerformFullAuthentication = 4
|
||||
)
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// Package mysql provides a MySQL driver for Go's database/sql package.
|
||||
//
|
||||
// The driver should be used via the database/sql package:
|
||||
//
|
||||
// import "database/sql"
|
||||
// import _ "github.com/go-sql-driver/mysql"
|
||||
//
|
||||
// db, err := sql.Open("mysql", "user:password@/dbname")
|
||||
//
|
||||
// See https://github.com/go-sql-driver/mysql#usage for details
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// MySQLDriver is exported to make the driver directly accessible.
|
||||
// In general the driver is used via the database/sql package.
|
||||
type MySQLDriver struct{}
|
||||
|
||||
// DialFunc is a function which can be used to establish the network connection.
|
||||
// Custom dial functions must be registered with RegisterDial
|
||||
//
|
||||
// Deprecated: users should register a DialContextFunc instead
|
||||
type DialFunc func(addr string) (net.Conn, error)
|
||||
|
||||
// DialContextFunc is a function which can be used to establish the network connection.
|
||||
// Custom dial functions must be registered with RegisterDialContext
|
||||
type DialContextFunc func(ctx context.Context, addr string) (net.Conn, error)
|
||||
|
||||
var (
|
||||
dialsLock sync.RWMutex
|
||||
dials map[string]DialContextFunc
|
||||
)
|
||||
|
||||
// RegisterDialContext registers a custom dial function. It can then be used by the
|
||||
// network address mynet(addr), where mynet is the registered new network.
|
||||
// The current context for the connection and its address is passed to the dial function.
|
||||
func RegisterDialContext(net string, dial DialContextFunc) {
|
||||
dialsLock.Lock()
|
||||
defer dialsLock.Unlock()
|
||||
if dials == nil {
|
||||
dials = make(map[string]DialContextFunc)
|
||||
}
|
||||
dials[net] = dial
|
||||
}
|
||||
|
||||
// RegisterDial registers a custom dial function. It can then be used by the
|
||||
// network address mynet(addr), where mynet is the registered new network.
|
||||
// addr is passed as a parameter to the dial function.
|
||||
//
|
||||
// Deprecated: users should call RegisterDialContext instead
|
||||
func RegisterDial(network string, dial DialFunc) {
|
||||
RegisterDialContext(network, func(_ context.Context, addr string) (net.Conn, error) {
|
||||
return dial(addr)
|
||||
})
|
||||
}
|
||||
|
||||
// Open new Connection.
|
||||
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
|
||||
// the DSN string is formatted
|
||||
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
|
||||
cfg, err := ParseDSN(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := &connector{
|
||||
cfg: cfg,
|
||||
}
|
||||
return c.Connect(context.Background())
|
||||
}
|
||||
|
||||
func init() {
|
||||
sql.Register("mysql", &MySQLDriver{})
|
||||
}
|
||||
|
||||
// NewConnector returns new driver.Connector.
|
||||
func NewConnector(cfg *Config) (driver.Connector, error) {
|
||||
cfg = cfg.Clone()
|
||||
// normalize the contents of cfg so calls to NewConnector have the same
|
||||
// behavior as MySQLDriver.OpenConnector
|
||||
if err := cfg.normalize(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &connector{cfg: cfg}, nil
|
||||
}
|
||||
|
||||
// OpenConnector implements driver.DriverContext.
|
||||
func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
|
||||
cfg, err := ParseDSN(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &connector{
|
||||
cfg: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,560 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")
|
||||
errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")
|
||||
errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")
|
||||
errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
|
||||
)
|
||||
|
||||
// Config is a configuration parsed from a DSN string.
|
||||
// If a new Config is created instead of being parsed from a DSN string,
|
||||
// the NewConfig function should be used, which sets default values.
|
||||
type Config struct {
|
||||
User string // Username
|
||||
Passwd string // Password (requires User)
|
||||
Net string // Network type
|
||||
Addr string // Network address (requires Net)
|
||||
DBName string // Database name
|
||||
Params map[string]string // Connection parameters
|
||||
Collation string // Connection collation
|
||||
Loc *time.Location // Location for time.Time values
|
||||
MaxAllowedPacket int // Max packet size allowed
|
||||
ServerPubKey string // Server public key name
|
||||
pubKey *rsa.PublicKey // Server public key
|
||||
TLSConfig string // TLS configuration name
|
||||
tls *tls.Config // TLS configuration
|
||||
Timeout time.Duration // Dial timeout
|
||||
ReadTimeout time.Duration // I/O read timeout
|
||||
WriteTimeout time.Duration // I/O write timeout
|
||||
|
||||
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
|
||||
AllowCleartextPasswords bool // Allows the cleartext client side plugin
|
||||
AllowNativePasswords bool // Allows the native password authentication method
|
||||
AllowOldPasswords bool // Allows the old insecure password method
|
||||
CheckConnLiveness bool // Check connections for liveness before using them
|
||||
ClientFoundRows bool // Return number of matching rows instead of rows changed
|
||||
ColumnsWithAlias bool // Prepend table alias to column names
|
||||
InterpolateParams bool // Interpolate placeholders into query string
|
||||
MultiStatements bool // Allow multiple statements in one query
|
||||
ParseTime bool // Parse time values to time.Time
|
||||
RejectReadOnly bool // Reject read-only connections
|
||||
}
|
||||
|
||||
// NewConfig creates a new Config and sets default values.
|
||||
func NewConfig() *Config {
|
||||
return &Config{
|
||||
Collation: defaultCollation,
|
||||
Loc: time.UTC,
|
||||
MaxAllowedPacket: defaultMaxAllowedPacket,
|
||||
AllowNativePasswords: true,
|
||||
CheckConnLiveness: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *Config) Clone() *Config {
|
||||
cp := *cfg
|
||||
if cp.tls != nil {
|
||||
cp.tls = cfg.tls.Clone()
|
||||
}
|
||||
if len(cp.Params) > 0 {
|
||||
cp.Params = make(map[string]string, len(cfg.Params))
|
||||
for k, v := range cfg.Params {
|
||||
cp.Params[k] = v
|
||||
}
|
||||
}
|
||||
if cfg.pubKey != nil {
|
||||
cp.pubKey = &rsa.PublicKey{
|
||||
N: new(big.Int).Set(cfg.pubKey.N),
|
||||
E: cfg.pubKey.E,
|
||||
}
|
||||
}
|
||||
return &cp
|
||||
}
|
||||
|
||||
func (cfg *Config) normalize() error {
|
||||
if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
|
||||
return errInvalidDSNUnsafeCollation
|
||||
}
|
||||
|
||||
// Set default network if empty
|
||||
if cfg.Net == "" {
|
||||
cfg.Net = "tcp"
|
||||
}
|
||||
|
||||
// Set default address if empty
|
||||
if cfg.Addr == "" {
|
||||
switch cfg.Net {
|
||||
case "tcp":
|
||||
cfg.Addr = "127.0.0.1:3306"
|
||||
case "unix":
|
||||
cfg.Addr = "/tmp/mysql.sock"
|
||||
default:
|
||||
return errors.New("default addr for network '" + cfg.Net + "' unknown")
|
||||
}
|
||||
} else if cfg.Net == "tcp" {
|
||||
cfg.Addr = ensureHavePort(cfg.Addr)
|
||||
}
|
||||
|
||||
switch cfg.TLSConfig {
|
||||
case "false", "":
|
||||
// don't set anything
|
||||
case "true":
|
||||
cfg.tls = &tls.Config{}
|
||||
case "skip-verify", "preferred":
|
||||
cfg.tls = &tls.Config{InsecureSkipVerify: true}
|
||||
default:
|
||||
cfg.tls = getTLSConfigClone(cfg.TLSConfig)
|
||||
if cfg.tls == nil {
|
||||
return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
|
||||
host, _, err := net.SplitHostPort(cfg.Addr)
|
||||
if err == nil {
|
||||
cfg.tls.ServerName = host
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.ServerPubKey != "" {
|
||||
cfg.pubKey = getServerPubKey(cfg.ServerPubKey)
|
||||
if cfg.pubKey == nil {
|
||||
return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) {
|
||||
buf.Grow(1 + len(name) + 1 + len(value))
|
||||
if !*hasParam {
|
||||
*hasParam = true
|
||||
buf.WriteByte('?')
|
||||
} else {
|
||||
buf.WriteByte('&')
|
||||
}
|
||||
buf.WriteString(name)
|
||||
buf.WriteByte('=')
|
||||
buf.WriteString(value)
|
||||
}
|
||||
|
||||
// FormatDSN formats the given Config into a DSN string which can be passed to
|
||||
// the driver.
|
||||
func (cfg *Config) FormatDSN() string {
|
||||
var buf bytes.Buffer
|
||||
|
||||
// [username[:password]@]
|
||||
if len(cfg.User) > 0 {
|
||||
buf.WriteString(cfg.User)
|
||||
if len(cfg.Passwd) > 0 {
|
||||
buf.WriteByte(':')
|
||||
buf.WriteString(cfg.Passwd)
|
||||
}
|
||||
buf.WriteByte('@')
|
||||
}
|
||||
|
||||
// [protocol[(address)]]
|
||||
if len(cfg.Net) > 0 {
|
||||
buf.WriteString(cfg.Net)
|
||||
if len(cfg.Addr) > 0 {
|
||||
buf.WriteByte('(')
|
||||
buf.WriteString(cfg.Addr)
|
||||
buf.WriteByte(')')
|
||||
}
|
||||
}
|
||||
|
||||
// /dbname
|
||||
buf.WriteByte('/')
|
||||
buf.WriteString(cfg.DBName)
|
||||
|
||||
// [?param1=value1&...¶mN=valueN]
|
||||
hasParam := false
|
||||
|
||||
if cfg.AllowAllFiles {
|
||||
hasParam = true
|
||||
buf.WriteString("?allowAllFiles=true")
|
||||
}
|
||||
|
||||
if cfg.AllowCleartextPasswords {
|
||||
writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true")
|
||||
}
|
||||
|
||||
if !cfg.AllowNativePasswords {
|
||||
writeDSNParam(&buf, &hasParam, "allowNativePasswords", "false")
|
||||
}
|
||||
|
||||
if cfg.AllowOldPasswords {
|
||||
writeDSNParam(&buf, &hasParam, "allowOldPasswords", "true")
|
||||
}
|
||||
|
||||
if !cfg.CheckConnLiveness {
|
||||
writeDSNParam(&buf, &hasParam, "checkConnLiveness", "false")
|
||||
}
|
||||
|
||||
if cfg.ClientFoundRows {
|
||||
writeDSNParam(&buf, &hasParam, "clientFoundRows", "true")
|
||||
}
|
||||
|
||||
if col := cfg.Collation; col != defaultCollation && len(col) > 0 {
|
||||
writeDSNParam(&buf, &hasParam, "collation", col)
|
||||
}
|
||||
|
||||
if cfg.ColumnsWithAlias {
|
||||
writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true")
|
||||
}
|
||||
|
||||
if cfg.InterpolateParams {
|
||||
writeDSNParam(&buf, &hasParam, "interpolateParams", "true")
|
||||
}
|
||||
|
||||
if cfg.Loc != time.UTC && cfg.Loc != nil {
|
||||
writeDSNParam(&buf, &hasParam, "loc", url.QueryEscape(cfg.Loc.String()))
|
||||
}
|
||||
|
||||
if cfg.MultiStatements {
|
||||
writeDSNParam(&buf, &hasParam, "multiStatements", "true")
|
||||
}
|
||||
|
||||
if cfg.ParseTime {
|
||||
writeDSNParam(&buf, &hasParam, "parseTime", "true")
|
||||
}
|
||||
|
||||
if cfg.ReadTimeout > 0 {
|
||||
writeDSNParam(&buf, &hasParam, "readTimeout", cfg.ReadTimeout.String())
|
||||
}
|
||||
|
||||
if cfg.RejectReadOnly {
|
||||
writeDSNParam(&buf, &hasParam, "rejectReadOnly", "true")
|
||||
}
|
||||
|
||||
if len(cfg.ServerPubKey) > 0 {
|
||||
writeDSNParam(&buf, &hasParam, "serverPubKey", url.QueryEscape(cfg.ServerPubKey))
|
||||
}
|
||||
|
||||
if cfg.Timeout > 0 {
|
||||
writeDSNParam(&buf, &hasParam, "timeout", cfg.Timeout.String())
|
||||
}
|
||||
|
||||
if len(cfg.TLSConfig) > 0 {
|
||||
writeDSNParam(&buf, &hasParam, "tls", url.QueryEscape(cfg.TLSConfig))
|
||||
}
|
||||
|
||||
if cfg.WriteTimeout > 0 {
|
||||
writeDSNParam(&buf, &hasParam, "writeTimeout", cfg.WriteTimeout.String())
|
||||
}
|
||||
|
||||
if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {
|
||||
writeDSNParam(&buf, &hasParam, "maxAllowedPacket", strconv.Itoa(cfg.MaxAllowedPacket))
|
||||
}
|
||||
|
||||
// other params
|
||||
if cfg.Params != nil {
|
||||
var params []string
|
||||
for param := range cfg.Params {
|
||||
params = append(params, param)
|
||||
}
|
||||
sort.Strings(params)
|
||||
for _, param := range params {
|
||||
writeDSNParam(&buf, &hasParam, param, url.QueryEscape(cfg.Params[param]))
|
||||
}
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// ParseDSN parses the DSN string to a Config
|
||||
func ParseDSN(dsn string) (cfg *Config, err error) {
|
||||
// New config with some default values
|
||||
cfg = NewConfig()
|
||||
|
||||
// [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN]
|
||||
// Find the last '/' (since the password or the net addr might contain a '/')
|
||||
foundSlash := false
|
||||
for i := len(dsn) - 1; i >= 0; i-- {
|
||||
if dsn[i] == '/' {
|
||||
foundSlash = true
|
||||
var j, k int
|
||||
|
||||
// left part is empty if i <= 0
|
||||
if i > 0 {
|
||||
// [username[:password]@][protocol[(address)]]
|
||||
// Find the last '@' in dsn[:i]
|
||||
for j = i; j >= 0; j-- {
|
||||
if dsn[j] == '@' {
|
||||
// username[:password]
|
||||
// Find the first ':' in dsn[:j]
|
||||
for k = 0; k < j; k++ {
|
||||
if dsn[k] == ':' {
|
||||
cfg.Passwd = dsn[k+1 : j]
|
||||
break
|
||||
}
|
||||
}
|
||||
cfg.User = dsn[:k]
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// [protocol[(address)]]
|
||||
// Find the first '(' in dsn[j+1:i]
|
||||
for k = j + 1; k < i; k++ {
|
||||
if dsn[k] == '(' {
|
||||
// dsn[i-1] must be == ')' if an address is specified
|
||||
if dsn[i-1] != ')' {
|
||||
if strings.ContainsRune(dsn[k+1:i], ')') {
|
||||
return nil, errInvalidDSNUnescaped
|
||||
}
|
||||
return nil, errInvalidDSNAddr
|
||||
}
|
||||
cfg.Addr = dsn[k+1 : i-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
cfg.Net = dsn[j+1 : k]
|
||||
}
|
||||
|
||||
// dbname[?param1=value1&...¶mN=valueN]
|
||||
// Find the first '?' in dsn[i+1:]
|
||||
for j = i + 1; j < len(dsn); j++ {
|
||||
if dsn[j] == '?' {
|
||||
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
cfg.DBName = dsn[i+1 : j]
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundSlash && len(dsn) > 0 {
|
||||
return nil, errInvalidDSNNoSlash
|
||||
}
|
||||
|
||||
if err = cfg.normalize(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// parseDSNParams parses the DSN "query string"
|
||||
// Values must be url.QueryEscape'ed
|
||||
func parseDSNParams(cfg *Config, params string) (err error) {
|
||||
for _, v := range strings.Split(params, "&") {
|
||||
param := strings.SplitN(v, "=", 2)
|
||||
if len(param) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// cfg params
|
||||
switch value := param[1]; param[0] {
|
||||
// Disable INFILE allowlist / enable all files
|
||||
case "allowAllFiles":
|
||||
var isBool bool
|
||||
cfg.AllowAllFiles, isBool = readBool(value)
|
||||
if !isBool {
|
||||
return errors.New("invalid bool value: " + value)
|
||||
}
|
||||
|
||||
// Use cleartext authentication mode (MySQL 5.5.10+)
|
||||
case "allowCleartextPasswords":
|
||||
var isBool bool
|
||||
cfg.AllowCleartextPasswords, isBool = readBool(value)
|
||||
if !isBool {
|
||||
return errors.New("invalid bool value: " + value)
|
||||
}
|
||||
|
||||
// Use native password authentication
|
||||
case "allowNativePasswords":
|
||||
var isBool bool
|
||||
cfg.AllowNativePasswords, isBool = readBool(value)
|
||||
if !isBool {
|
||||
return errors.New("invalid bool value: " + value)
|
||||
}
|
||||
|
||||
// Use old authentication mode (pre MySQL 4.1)
|
||||
case "allowOldPasswords":
|
||||
var isBool bool
|
||||
cfg.AllowOldPasswords, isBool = readBool(value)
|
||||
if !isBool {
|
||||
return errors.New("invalid bool value: " + value)
|
||||
}
|
||||
|
||||
// Check connections for Liveness before using them
|
||||
case "checkConnLiveness":
|
||||
var isBool bool
|
||||
cfg.CheckConnLiveness, isBool = readBool(value)
|
||||
if !isBool {
|
||||
return errors.New("invalid bool value: " + value)
|
||||
}
|
||||
|
||||
// Switch "rowsAffected" mode
|
||||
case "clientFoundRows":
|
||||
var isBool bool
|
||||
cfg.ClientFoundRows, isBool = readBool(value)
|
||||
if !isBool {
|
||||
return errors.New("invalid bool value: " + value)
|
||||
}
|
||||
|
||||
// Collation
|
||||
case "collation":
|
||||
cfg.Collation = value
|
||||
break
|
||||
|
||||
case "columnsWithAlias":
|
||||
var isBool bool
|
||||
cfg.ColumnsWithAlias, isBool = readBool(value)
|
||||
if !isBool {
|
||||
return errors.New("invalid bool value: " + value)
|
||||
}
|
||||
|
||||
// Compression
|
||||
case "compress":
|
||||
return errors.New("compression not implemented yet")
|
||||
|
||||
// Enable client side placeholder substitution
|
||||
case "interpolateParams":
|
||||
var isBool bool
|
||||
cfg.InterpolateParams, isBool = readBool(value)
|
||||
if !isBool {
|
||||
return errors.New("invalid bool value: " + value)
|
||||
}
|
||||
|
||||
// Time Location
|
||||
case "loc":
|
||||
if value, err = url.QueryUnescape(value); err != nil {
|
||||
return
|
||||
}
|
||||
cfg.Loc, err = time.LoadLocation(value)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// multiple statements in one query
|
||||
case "multiStatements":
|
||||
var isBool bool
|
||||
cfg.MultiStatements, isBool = readBool(value)
|
||||
if !isBool {
|
||||
return errors.New("invalid bool value: " + value)
|
||||
}
|
||||
|
||||
// time.Time parsing
|
||||
case "parseTime":
|
||||
var isBool bool
|
||||
cfg.ParseTime, isBool = readBool(value)
|
||||
if !isBool {
|
||||
return errors.New("invalid bool value: " + value)
|
||||
}
|
||||
|
||||
// I/O read Timeout
|
||||
case "readTimeout":
|
||||
cfg.ReadTimeout, err = time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Reject read-only connections
|
||||
case "rejectReadOnly":
|
||||
var isBool bool
|
||||
cfg.RejectReadOnly, isBool = readBool(value)
|
||||
if !isBool {
|
||||
return errors.New("invalid bool value: " + value)
|
||||
}
|
||||
|
||||
// Server public key
|
||||
case "serverPubKey":
|
||||
name, err := url.QueryUnescape(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value for server pub key name: %v", err)
|
||||
}
|
||||
cfg.ServerPubKey = name
|
||||
|
||||
// Strict mode
|
||||
case "strict":
|
||||
panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")
|
||||
|
||||
// Dial Timeout
|
||||
case "timeout":
|
||||
cfg.Timeout, err = time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// TLS-Encryption
|
||||
case "tls":
|
||||
boolValue, isBool := readBool(value)
|
||||
if isBool {
|
||||
if boolValue {
|
||||
cfg.TLSConfig = "true"
|
||||
} else {
|
||||
cfg.TLSConfig = "false"
|
||||
}
|
||||
} else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
|
||||
cfg.TLSConfig = vl
|
||||
} else {
|
||||
name, err := url.QueryUnescape(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value for TLS config name: %v", err)
|
||||
}
|
||||
cfg.TLSConfig = name
|
||||
}
|
||||
|
||||
// I/O write Timeout
|
||||
case "writeTimeout":
|
||||
cfg.WriteTimeout, err = time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case "maxAllowedPacket":
|
||||
cfg.MaxAllowedPacket, err = strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
default:
|
||||
// lazy init
|
||||
if cfg.Params == nil {
|
||||
cfg.Params = make(map[string]string)
|
||||
}
|
||||
|
||||
if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func ensureHavePort(addr string) string {
|
||||
if _, _, err := net.SplitHostPort(addr); err != nil {
|
||||
return net.JoinHostPort(addr, "3306")
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Various errors the driver might return. Can change between driver versions.
|
||||
var (
|
||||
ErrInvalidConn = errors.New("invalid connection")
|
||||
ErrMalformPkt = errors.New("malformed packet")
|
||||
ErrNoTLS = errors.New("TLS requested but server does not support TLS")
|
||||
ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
|
||||
ErrNativePassword = errors.New("this user requires mysql native password authentication.")
|
||||
ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
|
||||
ErrUnknownPlugin = errors.New("this authentication plugin is not supported")
|
||||
ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+")
|
||||
ErrPktSync = errors.New("commands out of sync. You can't run this command now")
|
||||
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
|
||||
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
|
||||
ErrBusyBuffer = errors.New("busy buffer")
|
||||
|
||||
// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
|
||||
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
|
||||
// to trigger a resend.
|
||||
// See https://github.com/go-sql-driver/mysql/pull/302
|
||||
errBadConnNoWrite = errors.New("bad connection")
|
||||
)
|
||||
|
||||
var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))
|
||||
|
||||
// Logger is used to log critical error messages.
|
||||
type Logger interface {
|
||||
Print(v ...interface{})
|
||||
}
|
||||
|
||||
// SetLogger is used to set the logger for critical errors.
|
||||
// The initial logger is os.Stderr.
|
||||
func SetLogger(logger Logger) error {
|
||||
if logger == nil {
|
||||
return errors.New("logger is nil")
|
||||
}
|
||||
errLog = logger
|
||||
return nil
|
||||
}
|
||||
|
||||
// MySQLError is an error type which represents a single MySQL error
|
||||
type MySQLError struct {
|
||||
Number uint16
|
||||
Message string
|
||||
}
|
||||
|
||||
func (me *MySQLError) Error() string {
|
||||
return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
|
||||
}
|
||||
|
|
@ -0,0 +1,194 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func (mf *mysqlField) typeDatabaseName() string {
|
||||
switch mf.fieldType {
|
||||
case fieldTypeBit:
|
||||
return "BIT"
|
||||
case fieldTypeBLOB:
|
||||
if mf.charSet != collations[binaryCollation] {
|
||||
return "TEXT"
|
||||
}
|
||||
return "BLOB"
|
||||
case fieldTypeDate:
|
||||
return "DATE"
|
||||
case fieldTypeDateTime:
|
||||
return "DATETIME"
|
||||
case fieldTypeDecimal:
|
||||
return "DECIMAL"
|
||||
case fieldTypeDouble:
|
||||
return "DOUBLE"
|
||||
case fieldTypeEnum:
|
||||
return "ENUM"
|
||||
case fieldTypeFloat:
|
||||
return "FLOAT"
|
||||
case fieldTypeGeometry:
|
||||
return "GEOMETRY"
|
||||
case fieldTypeInt24:
|
||||
return "MEDIUMINT"
|
||||
case fieldTypeJSON:
|
||||
return "JSON"
|
||||
case fieldTypeLong:
|
||||
return "INT"
|
||||
case fieldTypeLongBLOB:
|
||||
if mf.charSet != collations[binaryCollation] {
|
||||
return "LONGTEXT"
|
||||
}
|
||||
return "LONGBLOB"
|
||||
case fieldTypeLongLong:
|
||||
return "BIGINT"
|
||||
case fieldTypeMediumBLOB:
|
||||
if mf.charSet != collations[binaryCollation] {
|
||||
return "MEDIUMTEXT"
|
||||
}
|
||||
return "MEDIUMBLOB"
|
||||
case fieldTypeNewDate:
|
||||
return "DATE"
|
||||
case fieldTypeNewDecimal:
|
||||
return "DECIMAL"
|
||||
case fieldTypeNULL:
|
||||
return "NULL"
|
||||
case fieldTypeSet:
|
||||
return "SET"
|
||||
case fieldTypeShort:
|
||||
return "SMALLINT"
|
||||
case fieldTypeString:
|
||||
if mf.charSet == collations[binaryCollation] {
|
||||
return "BINARY"
|
||||
}
|
||||
return "CHAR"
|
||||
case fieldTypeTime:
|
||||
return "TIME"
|
||||
case fieldTypeTimestamp:
|
||||
return "TIMESTAMP"
|
||||
case fieldTypeTiny:
|
||||
return "TINYINT"
|
||||
case fieldTypeTinyBLOB:
|
||||
if mf.charSet != collations[binaryCollation] {
|
||||
return "TINYTEXT"
|
||||
}
|
||||
return "TINYBLOB"
|
||||
case fieldTypeVarChar:
|
||||
if mf.charSet == collations[binaryCollation] {
|
||||
return "VARBINARY"
|
||||
}
|
||||
return "VARCHAR"
|
||||
case fieldTypeVarString:
|
||||
if mf.charSet == collations[binaryCollation] {
|
||||
return "VARBINARY"
|
||||
}
|
||||
return "VARCHAR"
|
||||
case fieldTypeYear:
|
||||
return "YEAR"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
scanTypeFloat32 = reflect.TypeOf(float32(0))
|
||||
scanTypeFloat64 = reflect.TypeOf(float64(0))
|
||||
scanTypeInt8 = reflect.TypeOf(int8(0))
|
||||
scanTypeInt16 = reflect.TypeOf(int16(0))
|
||||
scanTypeInt32 = reflect.TypeOf(int32(0))
|
||||
scanTypeInt64 = reflect.TypeOf(int64(0))
|
||||
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
|
||||
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
|
||||
scanTypeNullTime = reflect.TypeOf(nullTime{})
|
||||
scanTypeUint8 = reflect.TypeOf(uint8(0))
|
||||
scanTypeUint16 = reflect.TypeOf(uint16(0))
|
||||
scanTypeUint32 = reflect.TypeOf(uint32(0))
|
||||
scanTypeUint64 = reflect.TypeOf(uint64(0))
|
||||
scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{})
|
||||
scanTypeUnknown = reflect.TypeOf(new(interface{}))
|
||||
)
|
||||
|
||||
type mysqlField struct {
|
||||
tableName string
|
||||
name string
|
||||
length uint32
|
||||
flags fieldFlag
|
||||
fieldType fieldType
|
||||
decimals byte
|
||||
charSet uint8
|
||||
}
|
||||
|
||||
func (mf *mysqlField) scanType() reflect.Type {
|
||||
switch mf.fieldType {
|
||||
case fieldTypeTiny:
|
||||
if mf.flags&flagNotNULL != 0 {
|
||||
if mf.flags&flagUnsigned != 0 {
|
||||
return scanTypeUint8
|
||||
}
|
||||
return scanTypeInt8
|
||||
}
|
||||
return scanTypeNullInt
|
||||
|
||||
case fieldTypeShort, fieldTypeYear:
|
||||
if mf.flags&flagNotNULL != 0 {
|
||||
if mf.flags&flagUnsigned != 0 {
|
||||
return scanTypeUint16
|
||||
}
|
||||
return scanTypeInt16
|
||||
}
|
||||
return scanTypeNullInt
|
||||
|
||||
case fieldTypeInt24, fieldTypeLong:
|
||||
if mf.flags&flagNotNULL != 0 {
|
||||
if mf.flags&flagUnsigned != 0 {
|
||||
return scanTypeUint32
|
||||
}
|
||||
return scanTypeInt32
|
||||
}
|
||||
return scanTypeNullInt
|
||||
|
||||
case fieldTypeLongLong:
|
||||
if mf.flags&flagNotNULL != 0 {
|
||||
if mf.flags&flagUnsigned != 0 {
|
||||
return scanTypeUint64
|
||||
}
|
||||
return scanTypeInt64
|
||||
}
|
||||
return scanTypeNullInt
|
||||
|
||||
case fieldTypeFloat:
|
||||
if mf.flags&flagNotNULL != 0 {
|
||||
return scanTypeFloat32
|
||||
}
|
||||
return scanTypeNullFloat
|
||||
|
||||
case fieldTypeDouble:
|
||||
if mf.flags&flagNotNULL != 0 {
|
||||
return scanTypeFloat64
|
||||
}
|
||||
return scanTypeNullFloat
|
||||
|
||||
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
|
||||
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
|
||||
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
|
||||
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON,
|
||||
fieldTypeTime:
|
||||
return scanTypeRawBytes
|
||||
|
||||
case fieldTypeDate, fieldTypeNewDate,
|
||||
fieldTypeTimestamp, fieldTypeDateTime:
|
||||
// NullTime is always returned for more consistent behavior as it can
|
||||
// handle both cases of parseTime regardless if the field is nullable.
|
||||
return scanTypeNullTime
|
||||
|
||||
default:
|
||||
return scanTypeUnknown
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
|
||||
//
|
||||
// Copyright 2020 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// +build gofuzz
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func Fuzz(data []byte) int {
|
||||
db, err := sql.Open("mysql", string(data))
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
db.Close()
|
||||
return 1
|
||||
}
|
||||
|
|
@ -0,0 +1,182 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
fileRegister map[string]bool
|
||||
fileRegisterLock sync.RWMutex
|
||||
readerRegister map[string]func() io.Reader
|
||||
readerRegisterLock sync.RWMutex
|
||||
)
|
||||
|
||||
// RegisterLocalFile adds the given file to the file allowlist,
|
||||
// so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
|
||||
// Alternatively you can allow the use of all local files with
|
||||
// the DSN parameter 'allowAllFiles=true'
|
||||
//
|
||||
// filePath := "/home/gopher/data.csv"
|
||||
// mysql.RegisterLocalFile(filePath)
|
||||
// err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo")
|
||||
// if err != nil {
|
||||
// ...
|
||||
//
|
||||
func RegisterLocalFile(filePath string) {
|
||||
fileRegisterLock.Lock()
|
||||
// lazy map init
|
||||
if fileRegister == nil {
|
||||
fileRegister = make(map[string]bool)
|
||||
}
|
||||
|
||||
fileRegister[strings.Trim(filePath, `"`)] = true
|
||||
fileRegisterLock.Unlock()
|
||||
}
|
||||
|
||||
// DeregisterLocalFile removes the given filepath from the allowlist.
|
||||
func DeregisterLocalFile(filePath string) {
|
||||
fileRegisterLock.Lock()
|
||||
delete(fileRegister, strings.Trim(filePath, `"`))
|
||||
fileRegisterLock.Unlock()
|
||||
}
|
||||
|
||||
// RegisterReaderHandler registers a handler function which is used
|
||||
// to receive a io.Reader.
|
||||
// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::<name>".
|
||||
// If the handler returns a io.ReadCloser Close() is called when the
|
||||
// request is finished.
|
||||
//
|
||||
// mysql.RegisterReaderHandler("data", func() io.Reader {
|
||||
// var csvReader io.Reader // Some Reader that returns CSV data
|
||||
// ... // Open Reader here
|
||||
// return csvReader
|
||||
// })
|
||||
// err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo")
|
||||
// if err != nil {
|
||||
// ...
|
||||
//
|
||||
func RegisterReaderHandler(name string, handler func() io.Reader) {
|
||||
readerRegisterLock.Lock()
|
||||
// lazy map init
|
||||
if readerRegister == nil {
|
||||
readerRegister = make(map[string]func() io.Reader)
|
||||
}
|
||||
|
||||
readerRegister[name] = handler
|
||||
readerRegisterLock.Unlock()
|
||||
}
|
||||
|
||||
// DeregisterReaderHandler removes the ReaderHandler function with
|
||||
// the given name from the registry.
|
||||
func DeregisterReaderHandler(name string) {
|
||||
readerRegisterLock.Lock()
|
||||
delete(readerRegister, name)
|
||||
readerRegisterLock.Unlock()
|
||||
}
|
||||
|
||||
func deferredClose(err *error, closer io.Closer) {
|
||||
closeErr := closer.Close()
|
||||
if *err == nil {
|
||||
*err = closeErr
|
||||
}
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
|
||||
var rdr io.Reader
|
||||
var data []byte
|
||||
packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
|
||||
if mc.maxWriteSize < packetSize {
|
||||
packetSize = mc.maxWriteSize
|
||||
}
|
||||
|
||||
if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader
|
||||
// The server might return an an absolute path. See issue #355.
|
||||
name = name[idx+8:]
|
||||
|
||||
readerRegisterLock.RLock()
|
||||
handler, inMap := readerRegister[name]
|
||||
readerRegisterLock.RUnlock()
|
||||
|
||||
if inMap {
|
||||
rdr = handler()
|
||||
if rdr != nil {
|
||||
if cl, ok := rdr.(io.Closer); ok {
|
||||
defer deferredClose(&err, cl)
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("Reader '%s' is <nil>", name)
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("Reader '%s' is not registered", name)
|
||||
}
|
||||
} else { // File
|
||||
name = strings.Trim(name, `"`)
|
||||
fileRegisterLock.RLock()
|
||||
fr := fileRegister[name]
|
||||
fileRegisterLock.RUnlock()
|
||||
if mc.cfg.AllowAllFiles || fr {
|
||||
var file *os.File
|
||||
var fi os.FileInfo
|
||||
|
||||
if file, err = os.Open(name); err == nil {
|
||||
defer deferredClose(&err, file)
|
||||
|
||||
// get file size
|
||||
if fi, err = file.Stat(); err == nil {
|
||||
rdr = file
|
||||
if fileSize := int(fi.Size()); fileSize < packetSize {
|
||||
packetSize = fileSize
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("local file '%s' is not registered", name)
|
||||
}
|
||||
}
|
||||
|
||||
// send content packets
|
||||
// if packetSize == 0, the Reader contains no data
|
||||
if err == nil && packetSize > 0 {
|
||||
data := make([]byte, 4+packetSize)
|
||||
var n int
|
||||
for err == nil {
|
||||
n, err = rdr.Read(data[4:])
|
||||
if n > 0 {
|
||||
if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
|
||||
return ioErr
|
||||
}
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
// send empty packet (termination)
|
||||
if data == nil {
|
||||
data = make([]byte, 4)
|
||||
}
|
||||
if ioErr := mc.writePacket(data[:4]); ioErr != nil {
|
||||
return ioErr
|
||||
}
|
||||
|
||||
// read OK packet
|
||||
if err == nil {
|
||||
return mc.readResultOK()
|
||||
}
|
||||
|
||||
mc.readPacket()
|
||||
return err
|
||||
}
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
// The value type must be time.Time or string / []byte (formatted time-string),
|
||||
// otherwise Scan fails.
|
||||
func (nt *NullTime) Scan(value interface{}) (err error) {
|
||||
if value == nil {
|
||||
nt.Time, nt.Valid = time.Time{}, false
|
||||
return
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case time.Time:
|
||||
nt.Time, nt.Valid = v, true
|
||||
return
|
||||
case []byte:
|
||||
nt.Time, err = parseDateTime(v, time.UTC)
|
||||
nt.Valid = (err == nil)
|
||||
return
|
||||
case string:
|
||||
nt.Time, err = parseDateTime([]byte(v), time.UTC)
|
||||
nt.Valid = (err == nil)
|
||||
return
|
||||
}
|
||||
|
||||
nt.Valid = false
|
||||
return fmt.Errorf("Can't convert %T to time.Time", value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (nt NullTime) Value() (driver.Value, error) {
|
||||
if !nt.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return nt.Time, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// +build go1.13
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// NullTime represents a time.Time that may be NULL.
|
||||
// NullTime implements the Scanner interface so
|
||||
// it can be used as a scan destination:
|
||||
//
|
||||
// var nt NullTime
|
||||
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
|
||||
// ...
|
||||
// if nt.Valid {
|
||||
// // use nt.Time
|
||||
// } else {
|
||||
// // NULL value
|
||||
// }
|
||||
//
|
||||
// This NullTime implementation is not driver-specific
|
||||
//
|
||||
// Deprecated: NullTime doesn't honor the loc DSN parameter.
|
||||
// NullTime.Scan interprets a time as UTC, not the loc DSN parameter.
|
||||
// Use sql.NullTime instead.
|
||||
type NullTime sql.NullTime
|
||||
|
||||
// for internal use.
|
||||
// the mysql package uses sql.NullTime if it is available.
|
||||
// if not, the package uses mysql.NullTime.
|
||||
type nullTime = sql.NullTime // sql.NullTime is available
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// +build !go1.13
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// NullTime represents a time.Time that may be NULL.
|
||||
// NullTime implements the Scanner interface so
|
||||
// it can be used as a scan destination:
|
||||
//
|
||||
// var nt NullTime
|
||||
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
|
||||
// ...
|
||||
// if nt.Valid {
|
||||
// // use nt.Time
|
||||
// } else {
|
||||
// // NULL value
|
||||
// }
|
||||
//
|
||||
// This NullTime implementation is not driver-specific
|
||||
type NullTime struct {
|
||||
Time time.Time
|
||||
Valid bool // Valid is true if Time is not NULL
|
||||
}
|
||||
|
||||
// for internal use.
|
||||
// the mysql package uses sql.NullTime if it is available.
|
||||
// if not, the package uses mysql.NullTime.
|
||||
type nullTime = NullTime // sql.NullTime is not available
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,22 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
type mysqlResult struct {
|
||||
affectedRows int64
|
||||
insertId int64
|
||||
}
|
||||
|
||||
func (res *mysqlResult) LastInsertId() (int64, error) {
|
||||
return res.insertId, nil
|
||||
}
|
||||
|
||||
func (res *mysqlResult) RowsAffected() (int64, error) {
|
||||
return res.affectedRows, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,223 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"io"
|
||||
"math"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type resultSet struct {
|
||||
columns []mysqlField
|
||||
columnNames []string
|
||||
done bool
|
||||
}
|
||||
|
||||
type mysqlRows struct {
|
||||
mc *mysqlConn
|
||||
rs resultSet
|
||||
finish func()
|
||||
}
|
||||
|
||||
type binaryRows struct {
|
||||
mysqlRows
|
||||
}
|
||||
|
||||
type textRows struct {
|
||||
mysqlRows
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) Columns() []string {
|
||||
if rows.rs.columnNames != nil {
|
||||
return rows.rs.columnNames
|
||||
}
|
||||
|
||||
columns := make([]string, len(rows.rs.columns))
|
||||
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
|
||||
for i := range columns {
|
||||
if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 {
|
||||
columns[i] = tableName + "." + rows.rs.columns[i].name
|
||||
} else {
|
||||
columns[i] = rows.rs.columns[i].name
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := range columns {
|
||||
columns[i] = rows.rs.columns[i].name
|
||||
}
|
||||
}
|
||||
|
||||
rows.rs.columnNames = columns
|
||||
return columns
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string {
|
||||
return rows.rs.columns[i].typeDatabaseName()
|
||||
}
|
||||
|
||||
// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) {
|
||||
// return int64(rows.rs.columns[i].length), true
|
||||
// }
|
||||
|
||||
func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) {
|
||||
return rows.rs.columns[i].flags&flagNotNULL == 0, true
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) {
|
||||
column := rows.rs.columns[i]
|
||||
decimals := int64(column.decimals)
|
||||
|
||||
switch column.fieldType {
|
||||
case fieldTypeDecimal, fieldTypeNewDecimal:
|
||||
if decimals > 0 {
|
||||
return int64(column.length) - 2, decimals, true
|
||||
}
|
||||
return int64(column.length) - 1, decimals, true
|
||||
case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime:
|
||||
return decimals, decimals, true
|
||||
case fieldTypeFloat, fieldTypeDouble:
|
||||
if decimals == 0x1f {
|
||||
return math.MaxInt64, math.MaxInt64, true
|
||||
}
|
||||
return math.MaxInt64, decimals, true
|
||||
}
|
||||
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type {
|
||||
return rows.rs.columns[i].scanType()
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) Close() (err error) {
|
||||
if f := rows.finish; f != nil {
|
||||
f()
|
||||
rows.finish = nil
|
||||
}
|
||||
|
||||
mc := rows.mc
|
||||
if mc == nil {
|
||||
return nil
|
||||
}
|
||||
if err := mc.error(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// flip the buffer for this connection if we need to drain it.
|
||||
// note that for a successful query (i.e. one where rows.next()
|
||||
// has been called until it returns false), `rows.mc` will be nil
|
||||
// by the time the user calls `(*Rows).Close`, so we won't reach this
|
||||
// see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47
|
||||
mc.buf.flip()
|
||||
|
||||
// Remove unread packets from stream
|
||||
if !rows.rs.done {
|
||||
err = mc.readUntilEOF()
|
||||
}
|
||||
if err == nil {
|
||||
if err = mc.discardResults(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
rows.mc = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) HasNextResultSet() (b bool) {
|
||||
if rows.mc == nil {
|
||||
return false
|
||||
}
|
||||
return rows.mc.status&statusMoreResultsExists != 0
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) nextResultSet() (int, error) {
|
||||
if rows.mc == nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if err := rows.mc.error(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Remove unread packets from stream
|
||||
if !rows.rs.done {
|
||||
if err := rows.mc.readUntilEOF(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
rows.rs.done = true
|
||||
}
|
||||
|
||||
if !rows.HasNextResultSet() {
|
||||
rows.mc = nil
|
||||
return 0, io.EOF
|
||||
}
|
||||
rows.rs = resultSet{}
|
||||
return rows.mc.readResultSetHeaderPacket()
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) {
|
||||
for {
|
||||
resLen, err := rows.nextResultSet()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if resLen > 0 {
|
||||
return resLen, nil
|
||||
}
|
||||
|
||||
rows.rs.done = true
|
||||
}
|
||||
}
|
||||
|
||||
func (rows *binaryRows) NextResultSet() error {
|
||||
resLen, err := rows.nextNotEmptyResultSet()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows.rs.columns, err = rows.mc.readColumns(resLen)
|
||||
return err
|
||||
}
|
||||
|
||||
func (rows *binaryRows) Next(dest []driver.Value) error {
|
||||
if mc := rows.mc; mc != nil {
|
||||
if err := mc.error(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Fetch next row from stream
|
||||
return rows.readRow(dest)
|
||||
}
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
func (rows *textRows) NextResultSet() (err error) {
|
||||
resLen, err := rows.nextNotEmptyResultSet()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows.rs.columns, err = rows.mc.readColumns(resLen)
|
||||
return err
|
||||
}
|
||||
|
||||
func (rows *textRows) Next(dest []driver.Value) error {
|
||||
if mc := rows.mc; mc != nil {
|
||||
if err := mc.error(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Fetch next row from stream
|
||||
return rows.readRow(dest)
|
||||
}
|
||||
return io.EOF
|
||||
}
|
||||
|
|
@ -0,0 +1,220 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type mysqlStmt struct {
|
||||
mc *mysqlConn
|
||||
id uint32
|
||||
paramCount int
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) Close() error {
|
||||
if stmt.mc == nil || stmt.mc.closed.IsSet() {
|
||||
// driver.Stmt.Close can be called more than once, thus this function
|
||||
// has to be idempotent.
|
||||
// See also Issue #450 and golang/go#16019.
|
||||
//errLog.Print(ErrInvalidConn)
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
|
||||
stmt.mc = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) NumInput() int {
|
||||
return stmt.paramCount
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
|
||||
return converter{}
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) {
|
||||
nv.Value, err = converter{}.ConvertValue(nv.Value)
|
||||
return
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
if stmt.mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
// Send command
|
||||
err := stmt.writeExecutePacket(args)
|
||||
if err != nil {
|
||||
return nil, stmt.mc.markBadConn(err)
|
||||
}
|
||||
|
||||
mc := stmt.mc
|
||||
|
||||
mc.affectedRows = 0
|
||||
mc.insertId = 0
|
||||
|
||||
// Read Result
|
||||
resLen, err := mc.readResultSetHeaderPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resLen > 0 {
|
||||
// Columns
|
||||
if err = mc.readUntilEOF(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Rows
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := mc.discardResults(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &mysqlResult{
|
||||
affectedRows: int64(mc.affectedRows),
|
||||
insertId: int64(mc.insertId),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
return stmt.query(args)
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
|
||||
if stmt.mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
// Send command
|
||||
err := stmt.writeExecutePacket(args)
|
||||
if err != nil {
|
||||
return nil, stmt.mc.markBadConn(err)
|
||||
}
|
||||
|
||||
mc := stmt.mc
|
||||
|
||||
// Read Result
|
||||
resLen, err := mc.readResultSetHeaderPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows := new(binaryRows)
|
||||
|
||||
if resLen > 0 {
|
||||
rows.mc = mc
|
||||
rows.rs.columns, err = mc.readColumns(resLen)
|
||||
} else {
|
||||
rows.rs.done = true
|
||||
|
||||
switch err := rows.NextResultSet(); err {
|
||||
case nil, io.EOF:
|
||||
return rows, nil
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return rows, err
|
||||
}
|
||||
|
||||
var jsonType = reflect.TypeOf(json.RawMessage{})
|
||||
|
||||
type converter struct{}
|
||||
|
||||
// ConvertValue mirrors the reference/default converter in database/sql/driver
|
||||
// with _one_ exception. We support uint64 with their high bit and the default
|
||||
// implementation does not. This function should be kept in sync with
|
||||
// database/sql/driver defaultConverter.ConvertValue() except for that
|
||||
// deliberate difference.
|
||||
func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
|
||||
if driver.IsValue(v) {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
if vr, ok := v.(driver.Valuer); ok {
|
||||
sv, err := callValuerValue(vr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if driver.IsValue(sv) {
|
||||
return sv, nil
|
||||
}
|
||||
// A value returend from the Valuer interface can be "a type handled by
|
||||
// a database driver's NamedValueChecker interface" so we should accept
|
||||
// uint64 here as well.
|
||||
if u, ok := sv.(uint64); ok {
|
||||
return u, nil
|
||||
}
|
||||
return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
|
||||
}
|
||||
rv := reflect.ValueOf(v)
|
||||
switch rv.Kind() {
|
||||
case reflect.Ptr:
|
||||
// indirect pointers
|
||||
if rv.IsNil() {
|
||||
return nil, nil
|
||||
} else {
|
||||
return c.ConvertValue(rv.Elem().Interface())
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return rv.Int(), nil
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return rv.Uint(), nil
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return rv.Float(), nil
|
||||
case reflect.Bool:
|
||||
return rv.Bool(), nil
|
||||
case reflect.Slice:
|
||||
switch t := rv.Type(); {
|
||||
case t == jsonType:
|
||||
return v, nil
|
||||
case t.Elem().Kind() == reflect.Uint8:
|
||||
return rv.Bytes(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind())
|
||||
}
|
||||
case reflect.String:
|
||||
return rv.String(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
|
||||
}
|
||||
|
||||
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
|
||||
|
||||
// callValuerValue returns vr.Value(), with one exception:
|
||||
// If vr.Value is an auto-generated method on a pointer type and the
|
||||
// pointer is nil, it would panic at runtime in the panicwrap
|
||||
// method. Treat it like nil instead.
|
||||
//
|
||||
// This is so people can implement driver.Value on value types and
|
||||
// still use nil pointers to those types to mean nil/NULL, just like
|
||||
// string/*string.
|
||||
//
|
||||
// This is an exact copy of the same-named unexported function from the
|
||||
// database/sql package.
|
||||
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
|
||||
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
|
||||
rv.IsNil() &&
|
||||
rv.Type().Elem().Implements(valuerReflectType) {
|
||||
return nil, nil
|
||||
}
|
||||
return vr.Value()
|
||||
}
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
type mysqlTx struct {
|
||||
mc *mysqlConn
|
||||
}
|
||||
|
||||
func (tx *mysqlTx) Commit() (err error) {
|
||||
if tx.mc == nil || tx.mc.closed.IsSet() {
|
||||
return ErrInvalidConn
|
||||
}
|
||||
err = tx.mc.exec("COMMIT")
|
||||
tx.mc = nil
|
||||
return
|
||||
}
|
||||
|
||||
func (tx *mysqlTx) Rollback() (err error) {
|
||||
if tx.mc == nil || tx.mc.closed.IsSet() {
|
||||
return ErrInvalidConn
|
||||
}
|
||||
err = tx.mc.exec("ROLLBACK")
|
||||
tx.mc = nil
|
||||
return
|
||||
}
|
||||
|
|
@ -0,0 +1,868 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Registry for custom tls.Configs
|
||||
var (
|
||||
tlsConfigLock sync.RWMutex
|
||||
tlsConfigRegistry map[string]*tls.Config
|
||||
)
|
||||
|
||||
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
|
||||
// Use the key as a value in the DSN where tls=value.
|
||||
//
|
||||
// Note: The provided tls.Config is exclusively owned by the driver after
|
||||
// registering it.
|
||||
//
|
||||
// rootCertPool := x509.NewCertPool()
|
||||
// pem, err := ioutil.ReadFile("/path/ca-cert.pem")
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
|
||||
// log.Fatal("Failed to append PEM.")
|
||||
// }
|
||||
// clientCert := make([]tls.Certificate, 0, 1)
|
||||
// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem")
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// clientCert = append(clientCert, certs)
|
||||
// mysql.RegisterTLSConfig("custom", &tls.Config{
|
||||
// RootCAs: rootCertPool,
|
||||
// Certificates: clientCert,
|
||||
// })
|
||||
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
|
||||
//
|
||||
func RegisterTLSConfig(key string, config *tls.Config) error {
|
||||
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" {
|
||||
return fmt.Errorf("key '%s' is reserved", key)
|
||||
}
|
||||
|
||||
tlsConfigLock.Lock()
|
||||
if tlsConfigRegistry == nil {
|
||||
tlsConfigRegistry = make(map[string]*tls.Config)
|
||||
}
|
||||
|
||||
tlsConfigRegistry[key] = config
|
||||
tlsConfigLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeregisterTLSConfig removes the tls.Config associated with key.
|
||||
func DeregisterTLSConfig(key string) {
|
||||
tlsConfigLock.Lock()
|
||||
if tlsConfigRegistry != nil {
|
||||
delete(tlsConfigRegistry, key)
|
||||
}
|
||||
tlsConfigLock.Unlock()
|
||||
}
|
||||
|
||||
func getTLSConfigClone(key string) (config *tls.Config) {
|
||||
tlsConfigLock.RLock()
|
||||
if v, ok := tlsConfigRegistry[key]; ok {
|
||||
config = v.Clone()
|
||||
}
|
||||
tlsConfigLock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Returns the bool value of the input.
|
||||
// The 2nd return value indicates if the input was a valid bool value
|
||||
func readBool(input string) (value bool, valid bool) {
|
||||
switch input {
|
||||
case "1", "true", "TRUE", "True":
|
||||
return true, true
|
||||
case "0", "false", "FALSE", "False":
|
||||
return false, true
|
||||
}
|
||||
|
||||
// Not a valid bool value
|
||||
return
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* Time related utils *
|
||||
******************************************************************************/
|
||||
|
||||
func parseDateTime(b []byte, loc *time.Location) (time.Time, error) {
|
||||
const base = "0000-00-00 00:00:00.000000"
|
||||
switch len(b) {
|
||||
case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM"
|
||||
if string(b) == base[:len(b)] {
|
||||
return time.Time{}, nil
|
||||
}
|
||||
|
||||
year, err := parseByteYear(b)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if year <= 0 {
|
||||
year = 1
|
||||
}
|
||||
|
||||
if b[4] != '-' {
|
||||
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4])
|
||||
}
|
||||
|
||||
m, err := parseByte2Digits(b[5], b[6])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if m <= 0 {
|
||||
m = 1
|
||||
}
|
||||
month := time.Month(m)
|
||||
|
||||
if b[7] != '-' {
|
||||
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[7])
|
||||
}
|
||||
|
||||
day, err := parseByte2Digits(b[8], b[9])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if day <= 0 {
|
||||
day = 1
|
||||
}
|
||||
if len(b) == 10 {
|
||||
return time.Date(year, month, day, 0, 0, 0, 0, loc), nil
|
||||
}
|
||||
|
||||
if b[10] != ' ' {
|
||||
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[10])
|
||||
}
|
||||
|
||||
hour, err := parseByte2Digits(b[11], b[12])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if b[13] != ':' {
|
||||
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[13])
|
||||
}
|
||||
|
||||
min, err := parseByte2Digits(b[14], b[15])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if b[16] != ':' {
|
||||
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[16])
|
||||
}
|
||||
|
||||
sec, err := parseByte2Digits(b[17], b[18])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if len(b) == 19 {
|
||||
return time.Date(year, month, day, hour, min, sec, 0, loc), nil
|
||||
}
|
||||
|
||||
if b[19] != '.' {
|
||||
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[19])
|
||||
}
|
||||
nsec, err := parseByteNanoSec(b[20:])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return time.Date(year, month, day, hour, min, sec, nsec, loc), nil
|
||||
default:
|
||||
return time.Time{}, fmt.Errorf("invalid time bytes: %s", b)
|
||||
}
|
||||
}
|
||||
|
||||
func parseByteYear(b []byte) (int, error) {
|
||||
year, n := 0, 1000
|
||||
for i := 0; i < 4; i++ {
|
||||
v, err := bToi(b[i])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
year += v * n
|
||||
n = n / 10
|
||||
}
|
||||
return year, nil
|
||||
}
|
||||
|
||||
func parseByte2Digits(b1, b2 byte) (int, error) {
|
||||
d1, err := bToi(b1)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
d2, err := bToi(b2)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return d1*10 + d2, nil
|
||||
}
|
||||
|
||||
func parseByteNanoSec(b []byte) (int, error) {
|
||||
ns, digit := 0, 100000 // max is 6-digits
|
||||
for i := 0; i < len(b); i++ {
|
||||
v, err := bToi(b[i])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
ns += v * digit
|
||||
digit /= 10
|
||||
}
|
||||
// nanoseconds has 10-digits. (needs to scale digits)
|
||||
// 10 - 6 = 4, so we have to multiple 1000.
|
||||
return ns * 1000, nil
|
||||
}
|
||||
|
||||
func bToi(b byte) (int, error) {
|
||||
if b < '0' || b > '9' {
|
||||
return 0, errors.New("not [0-9]")
|
||||
}
|
||||
return int(b - '0'), nil
|
||||
}
|
||||
|
||||
func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
|
||||
switch num {
|
||||
case 0:
|
||||
return time.Time{}, nil
|
||||
case 4:
|
||||
return time.Date(
|
||||
int(binary.LittleEndian.Uint16(data[:2])), // year
|
||||
time.Month(data[2]), // month
|
||||
int(data[3]), // day
|
||||
0, 0, 0, 0,
|
||||
loc,
|
||||
), nil
|
||||
case 7:
|
||||
return time.Date(
|
||||
int(binary.LittleEndian.Uint16(data[:2])), // year
|
||||
time.Month(data[2]), // month
|
||||
int(data[3]), // day
|
||||
int(data[4]), // hour
|
||||
int(data[5]), // minutes
|
||||
int(data[6]), // seconds
|
||||
0,
|
||||
loc,
|
||||
), nil
|
||||
case 11:
|
||||
return time.Date(
|
||||
int(binary.LittleEndian.Uint16(data[:2])), // year
|
||||
time.Month(data[2]), // month
|
||||
int(data[3]), // day
|
||||
int(data[4]), // hour
|
||||
int(data[5]), // minutes
|
||||
int(data[6]), // seconds
|
||||
int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds
|
||||
loc,
|
||||
), nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid DATETIME packet length %d", num)
|
||||
}
|
||||
|
||||
func appendDateTime(buf []byte, t time.Time) ([]byte, error) {
|
||||
year, month, day := t.Date()
|
||||
hour, min, sec := t.Clock()
|
||||
nsec := t.Nanosecond()
|
||||
|
||||
if year < 1 || year > 9999 {
|
||||
return buf, errors.New("year is not in the range [1, 9999]: " + strconv.Itoa(year)) // use errors.New instead of fmt.Errorf to avoid year escape to heap
|
||||
}
|
||||
year100 := year / 100
|
||||
year1 := year % 100
|
||||
|
||||
var localBuf [len("2006-01-02T15:04:05.999999999")]byte // does not escape
|
||||
localBuf[0], localBuf[1], localBuf[2], localBuf[3] = digits10[year100], digits01[year100], digits10[year1], digits01[year1]
|
||||
localBuf[4] = '-'
|
||||
localBuf[5], localBuf[6] = digits10[month], digits01[month]
|
||||
localBuf[7] = '-'
|
||||
localBuf[8], localBuf[9] = digits10[day], digits01[day]
|
||||
|
||||
if hour == 0 && min == 0 && sec == 0 && nsec == 0 {
|
||||
return append(buf, localBuf[:10]...), nil
|
||||
}
|
||||
|
||||
localBuf[10] = ' '
|
||||
localBuf[11], localBuf[12] = digits10[hour], digits01[hour]
|
||||
localBuf[13] = ':'
|
||||
localBuf[14], localBuf[15] = digits10[min], digits01[min]
|
||||
localBuf[16] = ':'
|
||||
localBuf[17], localBuf[18] = digits10[sec], digits01[sec]
|
||||
|
||||
if nsec == 0 {
|
||||
return append(buf, localBuf[:19]...), nil
|
||||
}
|
||||
nsec100000000 := nsec / 100000000
|
||||
nsec1000000 := (nsec / 1000000) % 100
|
||||
nsec10000 := (nsec / 10000) % 100
|
||||
nsec100 := (nsec / 100) % 100
|
||||
nsec1 := nsec % 100
|
||||
localBuf[19] = '.'
|
||||
|
||||
// milli second
|
||||
localBuf[20], localBuf[21], localBuf[22] =
|
||||
digits01[nsec100000000], digits10[nsec1000000], digits01[nsec1000000]
|
||||
// micro second
|
||||
localBuf[23], localBuf[24], localBuf[25] =
|
||||
digits10[nsec10000], digits01[nsec10000], digits10[nsec100]
|
||||
// nano second
|
||||
localBuf[26], localBuf[27], localBuf[28] =
|
||||
digits01[nsec100], digits10[nsec1], digits01[nsec1]
|
||||
|
||||
// trim trailing zeros
|
||||
n := len(localBuf)
|
||||
for n > 0 && localBuf[n-1] == '0' {
|
||||
n--
|
||||
}
|
||||
|
||||
return append(buf, localBuf[:n]...), nil
|
||||
}
|
||||
|
||||
// zeroDateTime is used in formatBinaryDateTime to avoid an allocation
|
||||
// if the DATE or DATETIME has the zero value.
|
||||
// It must never be changed.
|
||||
// The current behavior depends on database/sql copying the result.
|
||||
var zeroDateTime = []byte("0000-00-00 00:00:00.000000")
|
||||
|
||||
const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
|
||||
const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
|
||||
|
||||
func appendMicrosecs(dst, src []byte, decimals int) []byte {
|
||||
if decimals <= 0 {
|
||||
return dst
|
||||
}
|
||||
if len(src) == 0 {
|
||||
return append(dst, ".000000"[:decimals+1]...)
|
||||
}
|
||||
|
||||
microsecs := binary.LittleEndian.Uint32(src[:4])
|
||||
p1 := byte(microsecs / 10000)
|
||||
microsecs -= 10000 * uint32(p1)
|
||||
p2 := byte(microsecs / 100)
|
||||
microsecs -= 100 * uint32(p2)
|
||||
p3 := byte(microsecs)
|
||||
|
||||
switch decimals {
|
||||
default:
|
||||
return append(dst, '.',
|
||||
digits10[p1], digits01[p1],
|
||||
digits10[p2], digits01[p2],
|
||||
digits10[p3], digits01[p3],
|
||||
)
|
||||
case 1:
|
||||
return append(dst, '.',
|
||||
digits10[p1],
|
||||
)
|
||||
case 2:
|
||||
return append(dst, '.',
|
||||
digits10[p1], digits01[p1],
|
||||
)
|
||||
case 3:
|
||||
return append(dst, '.',
|
||||
digits10[p1], digits01[p1],
|
||||
digits10[p2],
|
||||
)
|
||||
case 4:
|
||||
return append(dst, '.',
|
||||
digits10[p1], digits01[p1],
|
||||
digits10[p2], digits01[p2],
|
||||
)
|
||||
case 5:
|
||||
return append(dst, '.',
|
||||
digits10[p1], digits01[p1],
|
||||
digits10[p2], digits01[p2],
|
||||
digits10[p3],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) {
|
||||
// length expects the deterministic length of the zero value,
|
||||
// negative time and 100+ hours are automatically added if needed
|
||||
if len(src) == 0 {
|
||||
return zeroDateTime[:length], nil
|
||||
}
|
||||
var dst []byte // return value
|
||||
var p1, p2, p3 byte // current digit pair
|
||||
|
||||
switch length {
|
||||
case 10, 19, 21, 22, 23, 24, 25, 26:
|
||||
default:
|
||||
t := "DATE"
|
||||
if length > 10 {
|
||||
t += "TIME"
|
||||
}
|
||||
return nil, fmt.Errorf("illegal %s length %d", t, length)
|
||||
}
|
||||
switch len(src) {
|
||||
case 4, 7, 11:
|
||||
default:
|
||||
t := "DATE"
|
||||
if length > 10 {
|
||||
t += "TIME"
|
||||
}
|
||||
return nil, fmt.Errorf("illegal %s packet length %d", t, len(src))
|
||||
}
|
||||
dst = make([]byte, 0, length)
|
||||
// start with the date
|
||||
year := binary.LittleEndian.Uint16(src[:2])
|
||||
pt := year / 100
|
||||
p1 = byte(year - 100*uint16(pt))
|
||||
p2, p3 = src[2], src[3]
|
||||
dst = append(dst,
|
||||
digits10[pt], digits01[pt],
|
||||
digits10[p1], digits01[p1], '-',
|
||||
digits10[p2], digits01[p2], '-',
|
||||
digits10[p3], digits01[p3],
|
||||
)
|
||||
if length == 10 {
|
||||
return dst, nil
|
||||
}
|
||||
if len(src) == 4 {
|
||||
return append(dst, zeroDateTime[10:length]...), nil
|
||||
}
|
||||
dst = append(dst, ' ')
|
||||
p1 = src[4] // hour
|
||||
src = src[5:]
|
||||
|
||||
// p1 is 2-digit hour, src is after hour
|
||||
p2, p3 = src[0], src[1]
|
||||
dst = append(dst,
|
||||
digits10[p1], digits01[p1], ':',
|
||||
digits10[p2], digits01[p2], ':',
|
||||
digits10[p3], digits01[p3],
|
||||
)
|
||||
return appendMicrosecs(dst, src[2:], int(length)-20), nil
|
||||
}
|
||||
|
||||
func formatBinaryTime(src []byte, length uint8) (driver.Value, error) {
|
||||
// length expects the deterministic length of the zero value,
|
||||
// negative time and 100+ hours are automatically added if needed
|
||||
if len(src) == 0 {
|
||||
return zeroDateTime[11 : 11+length], nil
|
||||
}
|
||||
var dst []byte // return value
|
||||
|
||||
switch length {
|
||||
case
|
||||
8, // time (can be up to 10 when negative and 100+ hours)
|
||||
10, 11, 12, 13, 14, 15: // time with fractional seconds
|
||||
default:
|
||||
return nil, fmt.Errorf("illegal TIME length %d", length)
|
||||
}
|
||||
switch len(src) {
|
||||
case 8, 12:
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid TIME packet length %d", len(src))
|
||||
}
|
||||
// +2 to enable negative time and 100+ hours
|
||||
dst = make([]byte, 0, length+2)
|
||||
if src[0] == 1 {
|
||||
dst = append(dst, '-')
|
||||
}
|
||||
days := binary.LittleEndian.Uint32(src[1:5])
|
||||
hours := int64(days)*24 + int64(src[5])
|
||||
|
||||
if hours >= 100 {
|
||||
dst = strconv.AppendInt(dst, hours, 10)
|
||||
} else {
|
||||
dst = append(dst, digits10[hours], digits01[hours])
|
||||
}
|
||||
|
||||
min, sec := src[6], src[7]
|
||||
dst = append(dst, ':',
|
||||
digits10[min], digits01[min], ':',
|
||||
digits10[sec], digits01[sec],
|
||||
)
|
||||
return appendMicrosecs(dst, src[8:], int(length)-9), nil
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* Convert from and to bytes *
|
||||
******************************************************************************/
|
||||
|
||||
func uint64ToBytes(n uint64) []byte {
|
||||
return []byte{
|
||||
byte(n),
|
||||
byte(n >> 8),
|
||||
byte(n >> 16),
|
||||
byte(n >> 24),
|
||||
byte(n >> 32),
|
||||
byte(n >> 40),
|
||||
byte(n >> 48),
|
||||
byte(n >> 56),
|
||||
}
|
||||
}
|
||||
|
||||
func uint64ToString(n uint64) []byte {
|
||||
var a [20]byte
|
||||
i := 20
|
||||
|
||||
// U+0030 = 0
|
||||
// ...
|
||||
// U+0039 = 9
|
||||
|
||||
var q uint64
|
||||
for n >= 10 {
|
||||
i--
|
||||
q = n / 10
|
||||
a[i] = uint8(n-q*10) + 0x30
|
||||
n = q
|
||||
}
|
||||
|
||||
i--
|
||||
a[i] = uint8(n) + 0x30
|
||||
|
||||
return a[i:]
|
||||
}
|
||||
|
||||
// treats string value as unsigned integer representation
|
||||
func stringToInt(b []byte) int {
|
||||
val := 0
|
||||
for i := range b {
|
||||
val *= 10
|
||||
val += int(b[i] - 0x30)
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// returns the string read as a bytes slice, wheter the value is NULL,
|
||||
// the number of bytes read and an error, in case the string is longer than
|
||||
// the input slice
|
||||
func readLengthEncodedString(b []byte) ([]byte, bool, int, error) {
|
||||
// Get length
|
||||
num, isNull, n := readLengthEncodedInteger(b)
|
||||
if num < 1 {
|
||||
return b[n:n], isNull, n, nil
|
||||
}
|
||||
|
||||
n += int(num)
|
||||
|
||||
// Check data length
|
||||
if len(b) >= n {
|
||||
return b[n-int(num) : n : n], false, n, nil
|
||||
}
|
||||
return nil, false, n, io.EOF
|
||||
}
|
||||
|
||||
// returns the number of bytes skipped and an error, in case the string is
|
||||
// longer than the input slice
|
||||
func skipLengthEncodedString(b []byte) (int, error) {
|
||||
// Get length
|
||||
num, _, n := readLengthEncodedInteger(b)
|
||||
if num < 1 {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
n += int(num)
|
||||
|
||||
// Check data length
|
||||
if len(b) >= n {
|
||||
return n, nil
|
||||
}
|
||||
return n, io.EOF
|
||||
}
|
||||
|
||||
// returns the number read, whether the value is NULL and the number of bytes read
|
||||
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
|
||||
// See issue #349
|
||||
if len(b) == 0 {
|
||||
return 0, true, 1
|
||||
}
|
||||
|
||||
switch b[0] {
|
||||
// 251: NULL
|
||||
case 0xfb:
|
||||
return 0, true, 1
|
||||
|
||||
// 252: value of following 2
|
||||
case 0xfc:
|
||||
return uint64(b[1]) | uint64(b[2])<<8, false, 3
|
||||
|
||||
// 253: value of following 3
|
||||
case 0xfd:
|
||||
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4
|
||||
|
||||
// 254: value of following 8
|
||||
case 0xfe:
|
||||
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
|
||||
uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
|
||||
uint64(b[7])<<48 | uint64(b[8])<<56,
|
||||
false, 9
|
||||
}
|
||||
|
||||
// 0-250: value of first byte
|
||||
return uint64(b[0]), false, 1
|
||||
}
|
||||
|
||||
// encodes a uint64 value and appends it to the given bytes slice
|
||||
func appendLengthEncodedInteger(b []byte, n uint64) []byte {
|
||||
switch {
|
||||
case n <= 250:
|
||||
return append(b, byte(n))
|
||||
|
||||
case n <= 0xffff:
|
||||
return append(b, 0xfc, byte(n), byte(n>>8))
|
||||
|
||||
case n <= 0xffffff:
|
||||
return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
|
||||
}
|
||||
return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
|
||||
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
|
||||
}
|
||||
|
||||
// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.
|
||||
// If cap(buf) is not enough, reallocate new buffer.
|
||||
func reserveBuffer(buf []byte, appendSize int) []byte {
|
||||
newSize := len(buf) + appendSize
|
||||
if cap(buf) < newSize {
|
||||
// Grow buffer exponentially
|
||||
newBuf := make([]byte, len(buf)*2+appendSize)
|
||||
copy(newBuf, buf)
|
||||
buf = newBuf
|
||||
}
|
||||
return buf[:newSize]
|
||||
}
|
||||
|
||||
// escapeBytesBackslash escapes []byte with backslashes (\)
|
||||
// This escapes the contents of a string (provided as []byte) by adding backslashes before special
|
||||
// characters, and turning others into specific escape sequences, such as
|
||||
// turning newlines into \n and null bytes into \0.
|
||||
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
|
||||
func escapeBytesBackslash(buf, v []byte) []byte {
|
||||
pos := len(buf)
|
||||
buf = reserveBuffer(buf, len(v)*2)
|
||||
|
||||
for _, c := range v {
|
||||
switch c {
|
||||
case '\x00':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = '0'
|
||||
pos += 2
|
||||
case '\n':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = 'n'
|
||||
pos += 2
|
||||
case '\r':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = 'r'
|
||||
pos += 2
|
||||
case '\x1a':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = 'Z'
|
||||
pos += 2
|
||||
case '\'':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = '\''
|
||||
pos += 2
|
||||
case '"':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = '"'
|
||||
pos += 2
|
||||
case '\\':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = '\\'
|
||||
pos += 2
|
||||
default:
|
||||
buf[pos] = c
|
||||
pos++
|
||||
}
|
||||
}
|
||||
|
||||
return buf[:pos]
|
||||
}
|
||||
|
||||
// escapeStringBackslash is similar to escapeBytesBackslash but for string.
|
||||
func escapeStringBackslash(buf []byte, v string) []byte {
|
||||
pos := len(buf)
|
||||
buf = reserveBuffer(buf, len(v)*2)
|
||||
|
||||
for i := 0; i < len(v); i++ {
|
||||
c := v[i]
|
||||
switch c {
|
||||
case '\x00':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = '0'
|
||||
pos += 2
|
||||
case '\n':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = 'n'
|
||||
pos += 2
|
||||
case '\r':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = 'r'
|
||||
pos += 2
|
||||
case '\x1a':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = 'Z'
|
||||
pos += 2
|
||||
case '\'':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = '\''
|
||||
pos += 2
|
||||
case '"':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = '"'
|
||||
pos += 2
|
||||
case '\\':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = '\\'
|
||||
pos += 2
|
||||
default:
|
||||
buf[pos] = c
|
||||
pos++
|
||||
}
|
||||
}
|
||||
|
||||
return buf[:pos]
|
||||
}
|
||||
|
||||
// escapeBytesQuotes escapes apostrophes in []byte by doubling them up.
|
||||
// This escapes the contents of a string by doubling up any apostrophes that
|
||||
// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
|
||||
// effect on the server.
|
||||
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
|
||||
func escapeBytesQuotes(buf, v []byte) []byte {
|
||||
pos := len(buf)
|
||||
buf = reserveBuffer(buf, len(v)*2)
|
||||
|
||||
for _, c := range v {
|
||||
if c == '\'' {
|
||||
buf[pos] = '\''
|
||||
buf[pos+1] = '\''
|
||||
pos += 2
|
||||
} else {
|
||||
buf[pos] = c
|
||||
pos++
|
||||
}
|
||||
}
|
||||
|
||||
return buf[:pos]
|
||||
}
|
||||
|
||||
// escapeStringQuotes is similar to escapeBytesQuotes but for string.
|
||||
func escapeStringQuotes(buf []byte, v string) []byte {
|
||||
pos := len(buf)
|
||||
buf = reserveBuffer(buf, len(v)*2)
|
||||
|
||||
for i := 0; i < len(v); i++ {
|
||||
c := v[i]
|
||||
if c == '\'' {
|
||||
buf[pos] = '\''
|
||||
buf[pos+1] = '\''
|
||||
pos += 2
|
||||
} else {
|
||||
buf[pos] = c
|
||||
pos++
|
||||
}
|
||||
}
|
||||
|
||||
return buf[:pos]
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* Sync utils *
|
||||
******************************************************************************/
|
||||
|
||||
// noCopy may be embedded into structs which must not be copied
|
||||
// after the first use.
|
||||
//
|
||||
// See https://github.com/golang/go/issues/8005#issuecomment-190753527
|
||||
// for details.
|
||||
type noCopy struct{}
|
||||
|
||||
// Lock is a no-op used by -copylocks checker from `go vet`.
|
||||
func (*noCopy) Lock() {}
|
||||
|
||||
// atomicBool is a wrapper around uint32 for usage as a boolean value with
|
||||
// atomic access.
|
||||
type atomicBool struct {
|
||||
_noCopy noCopy
|
||||
value uint32
|
||||
}
|
||||
|
||||
// IsSet returns whether the current boolean value is true
|
||||
func (ab *atomicBool) IsSet() bool {
|
||||
return atomic.LoadUint32(&ab.value) > 0
|
||||
}
|
||||
|
||||
// Set sets the value of the bool regardless of the previous value
|
||||
func (ab *atomicBool) Set(value bool) {
|
||||
if value {
|
||||
atomic.StoreUint32(&ab.value, 1)
|
||||
} else {
|
||||
atomic.StoreUint32(&ab.value, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// TrySet sets the value of the bool and returns whether the value changed
|
||||
func (ab *atomicBool) TrySet(value bool) bool {
|
||||
if value {
|
||||
return atomic.SwapUint32(&ab.value, 1) == 0
|
||||
}
|
||||
return atomic.SwapUint32(&ab.value, 0) > 0
|
||||
}
|
||||
|
||||
// atomicError is a wrapper for atomically accessed error values
|
||||
type atomicError struct {
|
||||
_noCopy noCopy
|
||||
value atomic.Value
|
||||
}
|
||||
|
||||
// Set sets the error value regardless of the previous value.
|
||||
// The value must not be nil
|
||||
func (ae *atomicError) Set(value error) {
|
||||
ae.value.Store(value)
|
||||
}
|
||||
|
||||
// Value returns the current error value
|
||||
func (ae *atomicError) Value() error {
|
||||
if v := ae.value.Load(); v != nil {
|
||||
// this will panic if the value doesn't implement the error interface
|
||||
return v.(error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
|
||||
dargs := make([]driver.Value, len(named))
|
||||
for n, param := range named {
|
||||
if len(param.Name) > 0 {
|
||||
// TODO: support the use of Named Parameters #561
|
||||
return nil, errors.New("mysql: driver does not support the use of Named Parameters")
|
||||
}
|
||||
dargs[n] = param.Value
|
||||
}
|
||||
return dargs, nil
|
||||
}
|
||||
|
||||
func mapIsolationLevel(level driver.IsolationLevel) (string, error) {
|
||||
switch sql.IsolationLevel(level) {
|
||||
case sql.LevelRepeatableRead:
|
||||
return "REPEATABLE READ", nil
|
||||
case sql.LevelReadCommitted:
|
||||
return "READ COMMITTED", nil
|
||||
case sql.LevelReadUncommitted:
|
||||
return "READ UNCOMMITTED", nil
|
||||
case sql.LevelSerializable:
|
||||
return "SERIALIZABLE", nil
|
||||
default:
|
||||
return "", fmt.Errorf("mysql: unsupported isolation level: %v", level)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
_*
|
||||
cover*.out
|
||||
cover*.txt
|
||||
|
|
@ -0,0 +1,373 @@
|
|||
Mozilla Public License Version 2.0
|
||||
==================================
|
||||
|
||||
1. Definitions
|
||||
--------------
|
||||
|
||||
1.1. "Contributor"
|
||||
means each individual or legal entity that creates, contributes to
|
||||
the creation of, or owns Covered Software.
|
||||
|
||||
1.2. "Contributor Version"
|
||||
means the combination of the Contributions of others (if any) used
|
||||
by a Contributor and that particular Contributor's Contribution.
|
||||
|
||||
1.3. "Contribution"
|
||||
means Covered Software of a particular Contributor.
|
||||
|
||||
1.4. "Covered Software"
|
||||
means Source Code Form to which the initial Contributor has attached
|
||||
the notice in Exhibit A, the Executable Form of such Source Code
|
||||
Form, and Modifications of such Source Code Form, in each case
|
||||
including portions thereof.
|
||||
|
||||
1.5. "Incompatible With Secondary Licenses"
|
||||
means
|
||||
|
||||
(a) that the initial Contributor has attached the notice described
|
||||
in Exhibit B to the Covered Software; or
|
||||
|
||||
(b) that the Covered Software was made available under the terms of
|
||||
version 1.1 or earlier of the License, but not also under the
|
||||
terms of a Secondary License.
|
||||
|
||||
1.6. "Executable Form"
|
||||
means any form of the work other than Source Code Form.
|
||||
|
||||
1.7. "Larger Work"
|
||||
means a work that combines Covered Software with other material, in
|
||||
a separate file or files, that is not Covered Software.
|
||||
|
||||
1.8. "License"
|
||||
means this document.
|
||||
|
||||
1.9. "Licensable"
|
||||
means having the right to grant, to the maximum extent possible,
|
||||
whether at the time of the initial grant or subsequently, any and
|
||||
all of the rights conveyed by this License.
|
||||
|
||||
1.10. "Modifications"
|
||||
means any of the following:
|
||||
|
||||
(a) any file in Source Code Form that results from an addition to,
|
||||
deletion from, or modification of the contents of Covered
|
||||
Software; or
|
||||
|
||||
(b) any new file in Source Code Form that contains any Covered
|
||||
Software.
|
||||
|
||||
1.11. "Patent Claims" of a Contributor
|
||||
means any patent claim(s), including without limitation, method,
|
||||
process, and apparatus claims, in any patent Licensable by such
|
||||
Contributor that would be infringed, but for the grant of the
|
||||
License, by the making, using, selling, offering for sale, having
|
||||
made, import, or transfer of either its Contributions or its
|
||||
Contributor Version.
|
||||
|
||||
1.12. "Secondary License"
|
||||
means either the GNU General Public License, Version 2.0, the GNU
|
||||
Lesser General Public License, Version 2.1, the GNU Affero General
|
||||
Public License, Version 3.0, or any later versions of those
|
||||
licenses.
|
||||
|
||||
1.13. "Source Code Form"
|
||||
means the form of the work preferred for making modifications.
|
||||
|
||||
1.14. "You" (or "Your")
|
||||
means an individual or a legal entity exercising rights under this
|
||||
License. For legal entities, "You" includes any entity that
|
||||
controls, is controlled by, or is under common control with You. For
|
||||
purposes of this definition, "control" means (a) the power, direct
|
||||
or indirect, to cause the direction or management of such entity,
|
||||
whether by contract or otherwise, or (b) ownership of more than
|
||||
fifty percent (50%) of the outstanding shares or beneficial
|
||||
ownership of such entity.
|
||||
|
||||
2. License Grants and Conditions
|
||||
--------------------------------
|
||||
|
||||
2.1. Grants
|
||||
|
||||
Each Contributor hereby grants You a world-wide, royalty-free,
|
||||
non-exclusive license:
|
||||
|
||||
(a) under intellectual property rights (other than patent or trademark)
|
||||
Licensable by such Contributor to use, reproduce, make available,
|
||||
modify, display, perform, distribute, and otherwise exploit its
|
||||
Contributions, either on an unmodified basis, with Modifications, or
|
||||
as part of a Larger Work; and
|
||||
|
||||
(b) under Patent Claims of such Contributor to make, use, sell, offer
|
||||
for sale, have made, import, and otherwise transfer either its
|
||||
Contributions or its Contributor Version.
|
||||
|
||||
2.2. Effective Date
|
||||
|
||||
The licenses granted in Section 2.1 with respect to any Contribution
|
||||
become effective for each Contribution on the date the Contributor first
|
||||
distributes such Contribution.
|
||||
|
||||
2.3. Limitations on Grant Scope
|
||||
|
||||
The licenses granted in this Section 2 are the only rights granted under
|
||||
this License. No additional rights or licenses will be implied from the
|
||||
distribution or licensing of Covered Software under this License.
|
||||
Notwithstanding Section 2.1(b) above, no patent license is granted by a
|
||||
Contributor:
|
||||
|
||||
(a) for any code that a Contributor has removed from Covered Software;
|
||||
or
|
||||
|
||||
(b) for infringements caused by: (i) Your and any other third party's
|
||||
modifications of Covered Software, or (ii) the combination of its
|
||||
Contributions with other software (except as part of its Contributor
|
||||
Version); or
|
||||
|
||||
(c) under Patent Claims infringed by Covered Software in the absence of
|
||||
its Contributions.
|
||||
|
||||
This License does not grant any rights in the trademarks, service marks,
|
||||
or logos of any Contributor (except as may be necessary to comply with
|
||||
the notice requirements in Section 3.4).
|
||||
|
||||
2.4. Subsequent Licenses
|
||||
|
||||
No Contributor makes additional grants as a result of Your choice to
|
||||
distribute the Covered Software under a subsequent version of this
|
||||
License (see Section 10.2) or under the terms of a Secondary License (if
|
||||
permitted under the terms of Section 3.3).
|
||||
|
||||
2.5. Representation
|
||||
|
||||
Each Contributor represents that the Contributor believes its
|
||||
Contributions are its original creation(s) or it has sufficient rights
|
||||
to grant the rights to its Contributions conveyed by this License.
|
||||
|
||||
2.6. Fair Use
|
||||
|
||||
This License is not intended to limit any rights You have under
|
||||
applicable copyright doctrines of fair use, fair dealing, or other
|
||||
equivalents.
|
||||
|
||||
2.7. Conditions
|
||||
|
||||
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
|
||||
in Section 2.1.
|
||||
|
||||
3. Responsibilities
|
||||
-------------------
|
||||
|
||||
3.1. Distribution of Source Form
|
||||
|
||||
All distribution of Covered Software in Source Code Form, including any
|
||||
Modifications that You create or to which You contribute, must be under
|
||||
the terms of this License. You must inform recipients that the Source
|
||||
Code Form of the Covered Software is governed by the terms of this
|
||||
License, and how they can obtain a copy of this License. You may not
|
||||
attempt to alter or restrict the recipients' rights in the Source Code
|
||||
Form.
|
||||
|
||||
3.2. Distribution of Executable Form
|
||||
|
||||
If You distribute Covered Software in Executable Form then:
|
||||
|
||||
(a) such Covered Software must also be made available in Source Code
|
||||
Form, as described in Section 3.1, and You must inform recipients of
|
||||
the Executable Form how they can obtain a copy of such Source Code
|
||||
Form by reasonable means in a timely manner, at a charge no more
|
||||
than the cost of distribution to the recipient; and
|
||||
|
||||
(b) You may distribute such Executable Form under the terms of this
|
||||
License, or sublicense it under different terms, provided that the
|
||||
license for the Executable Form does not attempt to limit or alter
|
||||
the recipients' rights in the Source Code Form under this License.
|
||||
|
||||
3.3. Distribution of a Larger Work
|
||||
|
||||
You may create and distribute a Larger Work under terms of Your choice,
|
||||
provided that You also comply with the requirements of this License for
|
||||
the Covered Software. If the Larger Work is a combination of Covered
|
||||
Software with a work governed by one or more Secondary Licenses, and the
|
||||
Covered Software is not Incompatible With Secondary Licenses, this
|
||||
License permits You to additionally distribute such Covered Software
|
||||
under the terms of such Secondary License(s), so that the recipient of
|
||||
the Larger Work may, at their option, further distribute the Covered
|
||||
Software under the terms of either this License or such Secondary
|
||||
License(s).
|
||||
|
||||
3.4. Notices
|
||||
|
||||
You may not remove or alter the substance of any license notices
|
||||
(including copyright notices, patent notices, disclaimers of warranty,
|
||||
or limitations of liability) contained within the Source Code Form of
|
||||
the Covered Software, except that You may alter any license notices to
|
||||
the extent required to remedy known factual inaccuracies.
|
||||
|
||||
3.5. Application of Additional Terms
|
||||
|
||||
You may choose to offer, and to charge a fee for, warranty, support,
|
||||
indemnity or liability obligations to one or more recipients of Covered
|
||||
Software. However, You may do so only on Your own behalf, and not on
|
||||
behalf of any Contributor. You must make it absolutely clear that any
|
||||
such warranty, support, indemnity, or liability obligation is offered by
|
||||
You alone, and You hereby agree to indemnify every Contributor for any
|
||||
liability incurred by such Contributor as a result of warranty, support,
|
||||
indemnity or liability terms You offer. You may include additional
|
||||
disclaimers of warranty and limitations of liability specific to any
|
||||
jurisdiction.
|
||||
|
||||
4. Inability to Comply Due to Statute or Regulation
|
||||
---------------------------------------------------
|
||||
|
||||
If it is impossible for You to comply with any of the terms of this
|
||||
License with respect to some or all of the Covered Software due to
|
||||
statute, judicial order, or regulation then You must: (a) comply with
|
||||
the terms of this License to the maximum extent possible; and (b)
|
||||
describe the limitations and the code they affect. Such description must
|
||||
be placed in a text file included with all distributions of the Covered
|
||||
Software under this License. Except to the extent prohibited by statute
|
||||
or regulation, such description must be sufficiently detailed for a
|
||||
recipient of ordinary skill to be able to understand it.
|
||||
|
||||
5. Termination
|
||||
--------------
|
||||
|
||||
5.1. The rights granted under this License will terminate automatically
|
||||
if You fail to comply with any of its terms. However, if You become
|
||||
compliant, then the rights granted under this License from a particular
|
||||
Contributor are reinstated (a) provisionally, unless and until such
|
||||
Contributor explicitly and finally terminates Your grants, and (b) on an
|
||||
ongoing basis, if such Contributor fails to notify You of the
|
||||
non-compliance by some reasonable means prior to 60 days after You have
|
||||
come back into compliance. Moreover, Your grants from a particular
|
||||
Contributor are reinstated on an ongoing basis if such Contributor
|
||||
notifies You of the non-compliance by some reasonable means, this is the
|
||||
first time You have received notice of non-compliance with this License
|
||||
from such Contributor, and You become compliant prior to 30 days after
|
||||
Your receipt of the notice.
|
||||
|
||||
5.2. If You initiate litigation against any entity by asserting a patent
|
||||
infringement claim (excluding declaratory judgment actions,
|
||||
counter-claims, and cross-claims) alleging that a Contributor Version
|
||||
directly or indirectly infringes any patent, then the rights granted to
|
||||
You by any and all Contributors for the Covered Software under Section
|
||||
2.1 of this License shall terminate.
|
||||
|
||||
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
|
||||
end user license agreements (excluding distributors and resellers) which
|
||||
have been validly granted by You or Your distributors under this License
|
||||
prior to termination shall survive termination.
|
||||
|
||||
************************************************************************
|
||||
* *
|
||||
* 6. Disclaimer of Warranty *
|
||||
* ------------------------- *
|
||||
* *
|
||||
* Covered Software is provided under this License on an "as is" *
|
||||
* basis, without warranty of any kind, either expressed, implied, or *
|
||||
* statutory, including, without limitation, warranties that the *
|
||||
* Covered Software is free of defects, merchantable, fit for a *
|
||||
* particular purpose or non-infringing. The entire risk as to the *
|
||||
* quality and performance of the Covered Software is with You. *
|
||||
* Should any Covered Software prove defective in any respect, You *
|
||||
* (not any Contributor) assume the cost of any necessary servicing, *
|
||||
* repair, or correction. This disclaimer of warranty constitutes an *
|
||||
* essential part of this License. No use of any Covered Software is *
|
||||
* authorized under this License except under this disclaimer. *
|
||||
* *
|
||||
************************************************************************
|
||||
|
||||
************************************************************************
|
||||
* *
|
||||
* 7. Limitation of Liability *
|
||||
* -------------------------- *
|
||||
* *
|
||||
* Under no circumstances and under no legal theory, whether tort *
|
||||
* (including negligence), contract, or otherwise, shall any *
|
||||
* Contributor, or anyone who distributes Covered Software as *
|
||||
* permitted above, be liable to You for any direct, indirect, *
|
||||
* special, incidental, or consequential damages of any character *
|
||||
* including, without limitation, damages for lost profits, loss of *
|
||||
* goodwill, work stoppage, computer failure or malfunction, or any *
|
||||
* and all other commercial damages or losses, even if such party *
|
||||
* shall have been informed of the possibility of such damages. This *
|
||||
* limitation of liability shall not apply to liability for death or *
|
||||
* personal injury resulting from such party's negligence to the *
|
||||
* extent applicable law prohibits such limitation. Some *
|
||||
* jurisdictions do not allow the exclusion or limitation of *
|
||||
* incidental or consequential damages, so this exclusion and *
|
||||
* limitation may not apply to You. *
|
||||
* *
|
||||
************************************************************************
|
||||
|
||||
8. Litigation
|
||||
-------------
|
||||
|
||||
Any litigation relating to this License may be brought only in the
|
||||
courts of a jurisdiction where the defendant maintains its principal
|
||||
place of business and such litigation shall be governed by laws of that
|
||||
jurisdiction, without reference to its conflict-of-law provisions.
|
||||
Nothing in this Section shall prevent a party's ability to bring
|
||||
cross-claims or counter-claims.
|
||||
|
||||
9. Miscellaneous
|
||||
----------------
|
||||
|
||||
This License represents the complete agreement concerning the subject
|
||||
matter hereof. If any provision of this License is held to be
|
||||
unenforceable, such provision shall be reformed only to the extent
|
||||
necessary to make it enforceable. Any law or regulation which provides
|
||||
that the language of a contract shall be construed against the drafter
|
||||
shall not be used to construe this License against a Contributor.
|
||||
|
||||
10. Versions of the License
|
||||
---------------------------
|
||||
|
||||
10.1. New Versions
|
||||
|
||||
Mozilla Foundation is the license steward. Except as provided in Section
|
||||
10.3, no one other than the license steward has the right to modify or
|
||||
publish new versions of this License. Each version will be given a
|
||||
distinguishing version number.
|
||||
|
||||
10.2. Effect of New Versions
|
||||
|
||||
You may distribute the Covered Software under the terms of the version
|
||||
of the License under which You originally received the Covered Software,
|
||||
or under the terms of any subsequent version published by the license
|
||||
steward.
|
||||
|
||||
10.3. Modified Versions
|
||||
|
||||
If you create software not governed by this License, and you want to
|
||||
create a new license for such software, you may create and use a
|
||||
modified version of this License if you rename the license and remove
|
||||
any references to the name of the license steward (except to note that
|
||||
such modified license differs from this License).
|
||||
|
||||
10.4. Distributing Source Code Form that is Incompatible With Secondary
|
||||
Licenses
|
||||
|
||||
If You choose to distribute Source Code Form that is Incompatible With
|
||||
Secondary Licenses under the terms of this version of the License, the
|
||||
notice described in Exhibit B of this License must be attached.
|
||||
|
||||
Exhibit A - Source Code Form License Notice
|
||||
-------------------------------------------
|
||||
|
||||
This Source Code Form is subject to the terms of the Mozilla Public
|
||||
License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
If it is not possible or desirable to put the notice in a particular
|
||||
file, then You may include the notice in a location (such as a LICENSE
|
||||
file in a relevant directory) where a recipient would be likely to look
|
||||
for such a notice.
|
||||
|
||||
You may add additional accurate notices of copyright ownership.
|
||||
|
||||
Exhibit B - "Incompatible With Secondary Licenses" Notice
|
||||
---------------------------------------------------------
|
||||
|
||||
This Source Code Form is "Incompatible With Secondary Licenses", as
|
||||
defined by the Mozilla Public License, v. 2.0.
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
# slug
|
||||
|
||||
Package `slug` generate slug from Unicode string, URL-friendly slugify with
|
||||
multiple languages support.
|
||||
|
||||
[](https://pkg.go.dev/github.com/gosimple/slug)
|
||||
[](https://github.com/gosimple/slug/actions/workflows/tests.yml)
|
||||
[](https://codecov.io/gh/gosimple/slug)
|
||||
[](https://github.com/gosimple/slug/releases)
|
||||
|
||||
## Example
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gosimple/slug"
|
||||
)
|
||||
|
||||
func main() {
|
||||
text := slug.Make("Hellö Wörld хелло ворлд")
|
||||
fmt.Println(text) // Will print: "hello-world-khello-vorld"
|
||||
|
||||
someText := slug.Make("影師")
|
||||
fmt.Println(someText) // Will print: "ying-shi"
|
||||
|
||||
enText := slug.MakeLang("This & that", "en")
|
||||
fmt.Println(enText) // Will print: "this-and-that"
|
||||
|
||||
deText := slug.MakeLang("Diese & Dass", "de")
|
||||
fmt.Println(deText) // Will print: "diese-und-dass"
|
||||
|
||||
slug.Lowercase = false // Keep uppercase characters
|
||||
deUppercaseText := slug.MakeLang("Diese & Dass", "de")
|
||||
fmt.Println(deUppercaseText) // Will print: "Diese-und-Dass"
|
||||
|
||||
slug.CustomSub = map[string]string{
|
||||
"water": "sand",
|
||||
}
|
||||
textSub := slug.Make("water is hot")
|
||||
fmt.Println(textSub) // Will print: "sand-is-hot"
|
||||
}
|
||||
```
|
||||
|
||||
## Design
|
||||
|
||||
This library will always returns clean output from any Unicode string
|
||||
containing only the following ASCII characters:
|
||||
|
||||
* numbers: `0-9`
|
||||
* small letters: `a-z`
|
||||
* big letters: `A-Z` (only if you set `Lowercase` to `false`)
|
||||
* minus sign: `-`
|
||||
* underscore: `_`
|
||||
|
||||
Minus sign and underscore characters will never appear at the beginning or
|
||||
the end of the returned string.
|
||||
|
||||
Thanks to context-insensitive transliteration of Unicode characters to ASCII
|
||||
output returned string is safe for URL slugs and filenames.
|
||||
|
||||
## Requests or bugs?
|
||||
|
||||
<https://github.com/gosimple/slug/issues>
|
||||
|
||||
If your language is missing you could add it in `languages_substitution.go`
|
||||
file.
|
||||
|
||||
In case of missing proper Unicode characters transliteration to ASCII you could
|
||||
add them to underlying library:
|
||||
<https://github.com/gosimple/unidecode>.
|
||||
|
||||
## Installation
|
||||
|
||||
```shell
|
||||
go get -u github.com/gosimple/slug
|
||||
```
|
||||
|
||||
## Benchmarking
|
||||
|
||||
```shell
|
||||
go test -run=NONE -bench=. -benchmem -count=6 ./... > old.txt
|
||||
# make changes
|
||||
go test -run=NONE -bench=. -benchmem -count=6 ./... > new.txt
|
||||
|
||||
go install golang.org/x/perf/cmd/benchstat@latest
|
||||
|
||||
benchstat old.txt new.txt
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
The source files are distributed under the
|
||||
[Mozilla Public License, version 2.0](http://mozilla.org/MPL/2.0/),
|
||||
unless otherwise noted.
|
||||
Please read the [FAQ](http://www.mozilla.org/MPL/2.0/FAQ.html)
|
||||
if you have further questions regarding the license.
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
comment:
|
||||
layout: "diff, files"
|
||||
behavior: default
|
||||
require_changes: false # if true: only post the comment if coverage changes
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
// Copyright 2013 by Dobrosław Żybort. All rights reserved.
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
/*
|
||||
Package slug generate slug from unicode string, URL-friendly slugify with
|
||||
multiple languages support.
|
||||
|
||||
Example:
|
||||
|
||||
package main
|
||||
|
||||
import(
|
||||
"github.com/gosimple/slug"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func main () {
|
||||
text := slug.Make("Hellö Wörld хелло ворлд")
|
||||
fmt.Println(text) // Will print: "hello-world-khello-vorld"
|
||||
|
||||
someText := slug.Make("影師")
|
||||
fmt.Println(someText) // Will print: "ying-shi"
|
||||
|
||||
enText := slug.MakeLang("This & that", "en")
|
||||
fmt.Println(enText) // Will print: "this-and-that"
|
||||
|
||||
deText := slug.MakeLang("Diese & Dass", "de")
|
||||
fmt.Println(deText) // Will print: "diese-und-dass"
|
||||
|
||||
slug.Lowercase = false // Keep uppercase characters
|
||||
deUppercaseText := slug.MakeLang("Diese & Dass", "de")
|
||||
fmt.Println(deUppercaseText) // Will print: "Diese-und-Dass"
|
||||
|
||||
slug.CustomSub = map[string]string{
|
||||
"water": "sand",
|
||||
}
|
||||
textSub := slug.Make("water is hot")
|
||||
fmt.Println(textSub) // Will print: "sand-is-hot"
|
||||
}
|
||||
|
||||
Requests or bugs?
|
||||
|
||||
https://github.com/gosimple/slug/issues
|
||||
*/
|
||||
package slug
|
||||
|
|
@ -0,0 +1,191 @@
|
|||
// Copyright 2013 by Dobrosław Żybort. All rights reserved.
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package slug
|
||||
|
||||
func init() {
|
||||
// Merge language subs with the default one.
|
||||
// TODO: Find better way so all langs are merged automatically and better
|
||||
// tested.
|
||||
for _, sub := range []*map[rune]string{
|
||||
&csSub,
|
||||
&deSub,
|
||||
&enSub,
|
||||
&esSub,
|
||||
&fiSub,
|
||||
&frSub,
|
||||
&grSub,
|
||||
&huSub,
|
||||
&idSub,
|
||||
&kkSub,
|
||||
&nbSub,
|
||||
&nlSub,
|
||||
&nnSub,
|
||||
&plSub,
|
||||
&slSub,
|
||||
&svSub,
|
||||
&trSub,
|
||||
} {
|
||||
for key, value := range defaultSub {
|
||||
(*sub)[key] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var defaultSub = map[rune]string{
|
||||
'"': "",
|
||||
'\'': "",
|
||||
'’': "",
|
||||
'‒': "-", // figure dash
|
||||
'–': "-", // en dash
|
||||
'—': "-", // em dash
|
||||
'―': "-", // horizontal bar
|
||||
}
|
||||
|
||||
var csSub = map[rune]string{
|
||||
'&': "a",
|
||||
'@': "zavinac",
|
||||
}
|
||||
|
||||
var deSub = map[rune]string{
|
||||
'&': "und",
|
||||
'@': "an",
|
||||
'ä': "ae",
|
||||
'Ä': "Ae",
|
||||
'ö': "oe",
|
||||
'Ö': "Oe",
|
||||
'ü': "ue",
|
||||
'Ü': "Ue",
|
||||
}
|
||||
|
||||
var enSub = map[rune]string{
|
||||
'&': "and",
|
||||
'@': "at",
|
||||
}
|
||||
|
||||
var esSub = map[rune]string{
|
||||
'&': "y",
|
||||
'@': "en",
|
||||
}
|
||||
|
||||
var fiSub = map[rune]string{
|
||||
'&': "ja",
|
||||
'@': "at",
|
||||
}
|
||||
|
||||
var frSub = map[rune]string{
|
||||
'&': "et",
|
||||
'@': "arobase",
|
||||
}
|
||||
|
||||
var grSub = map[rune]string{
|
||||
'&': "kai",
|
||||
'η': "i",
|
||||
'ή': "i",
|
||||
'Η': "i",
|
||||
'ι': "i",
|
||||
'ί': "i",
|
||||
'ϊ': "i",
|
||||
'Ι': "i",
|
||||
'χ': "x",
|
||||
'Χ': "x",
|
||||
'ω': "w",
|
||||
'ώ': "w",
|
||||
'Ω': "w",
|
||||
'ϋ': "u",
|
||||
}
|
||||
|
||||
var huSub = map[rune]string{
|
||||
'á': "a",
|
||||
'Á': "A",
|
||||
'é': "e",
|
||||
'É': "E",
|
||||
'í': "i",
|
||||
'Í': "I",
|
||||
'ó': "o",
|
||||
'Ó': "O",
|
||||
'ö': "o",
|
||||
'Ö': "O",
|
||||
'ő': "o",
|
||||
'Ő': "O",
|
||||
'ú': "u",
|
||||
'Ú': "U",
|
||||
'ü': "u",
|
||||
'Ü': "U",
|
||||
'ű': "u",
|
||||
'Ű': "U",
|
||||
}
|
||||
|
||||
var idSub = map[rune]string{
|
||||
'&': "dan",
|
||||
}
|
||||
|
||||
var kkSub = map[rune]string{
|
||||
'&': "jane",
|
||||
'ә': "a",
|
||||
'ғ': "g",
|
||||
'қ': "q",
|
||||
'ң': "n",
|
||||
'ө': "o",
|
||||
'ұ': "u",
|
||||
'Ә': "A",
|
||||
'Ғ': "G",
|
||||
'Қ': "Q",
|
||||
'Ң': "N",
|
||||
'Ө': "O",
|
||||
'Ұ': "U",
|
||||
}
|
||||
|
||||
var nbSub = map[rune]string{
|
||||
'&': "og",
|
||||
'@': "at",
|
||||
'æ': "ae",
|
||||
'ø': "oe",
|
||||
'å': "aa",
|
||||
'Æ': "Ae",
|
||||
'Ø': "Oe",
|
||||
'Å': "Aa",
|
||||
}
|
||||
|
||||
// Norwegian Nynorsk has the same rules
|
||||
var nnSub = nbSub
|
||||
|
||||
var nlSub = map[rune]string{
|
||||
'&': "en",
|
||||
'@': "at",
|
||||
}
|
||||
|
||||
var plSub = map[rune]string{
|
||||
'&': "i",
|
||||
'@': "na",
|
||||
}
|
||||
|
||||
var slSub = map[rune]string{
|
||||
'&': "in",
|
||||
'Đ': "DZ",
|
||||
'đ': "dz",
|
||||
}
|
||||
|
||||
var svSub = map[rune]string{
|
||||
'&': "och",
|
||||
'@': "snabel a",
|
||||
}
|
||||
|
||||
var trSub = map[rune]string{
|
||||
'&': "ve",
|
||||
'@': "et",
|
||||
'ş': "s",
|
||||
'Ş': "S",
|
||||
'ü': "u",
|
||||
'Ü': "U",
|
||||
'ö': "o",
|
||||
'Ö': "O",
|
||||
'İ': "I",
|
||||
'ı': "i",
|
||||
'ğ': "g",
|
||||
'Ğ': "G",
|
||||
'ç': "c",
|
||||
'Ç': "C",
|
||||
}
|
||||
|
|
@ -0,0 +1,188 @@
|
|||
// Copyright 2013 by Dobrosław Żybort. All rights reserved.
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package slug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/gosimple/unidecode"
|
||||
)
|
||||
|
||||
var (
|
||||
// CustomSub stores custom substitution map
|
||||
CustomSub map[string]string
|
||||
// CustomRuneSub stores custom rune substitution map
|
||||
CustomRuneSub map[rune]string
|
||||
|
||||
// MaxLength stores maximum slug length.
|
||||
// It's smart so it will cat slug after full word.
|
||||
// By default slugs aren't shortened.
|
||||
// If MaxLength is smaller than length of the first word, then returned
|
||||
// slug will contain only substring from the first word truncated
|
||||
// after MaxLength.
|
||||
MaxLength int
|
||||
|
||||
// Lowercase defines if the resulting slug is transformed to lowercase.
|
||||
// Default is true.
|
||||
Lowercase = true
|
||||
|
||||
regexpNonAuthorizedChars = regexp.MustCompile("[^a-zA-Z0-9-_]")
|
||||
regexpMultipleDashes = regexp.MustCompile("-+")
|
||||
)
|
||||
|
||||
//=============================================================================
|
||||
|
||||
// Make returns slug generated from provided string. Will use "en" as language
|
||||
// substitution.
|
||||
func Make(s string) (slug string) {
|
||||
return MakeLang(s, "en")
|
||||
}
|
||||
|
||||
// MakeLang returns slug generated from provided string and will use provided
|
||||
// language for chars substitution.
|
||||
func MakeLang(s string, lang string) (slug string) {
|
||||
slug = strings.TrimSpace(s)
|
||||
|
||||
// Custom substitutions
|
||||
// Always substitute runes first
|
||||
slug = SubstituteRune(slug, CustomRuneSub)
|
||||
slug = Substitute(slug, CustomSub)
|
||||
|
||||
// Process string with selected substitution language.
|
||||
// Catch ISO 3166-1, ISO 639-1:2002 and ISO 639-3:2007.
|
||||
switch strings.ToLower(lang) {
|
||||
case "cs", "ces":
|
||||
slug = SubstituteRune(slug, csSub)
|
||||
case "de", "deu":
|
||||
slug = SubstituteRune(slug, deSub)
|
||||
case "en", "eng":
|
||||
slug = SubstituteRune(slug, enSub)
|
||||
case "es", "spa":
|
||||
slug = SubstituteRune(slug, esSub)
|
||||
case "fi", "fin":
|
||||
slug = SubstituteRune(slug, fiSub)
|
||||
case "fr", "fra":
|
||||
slug = SubstituteRune(slug, frSub)
|
||||
case "gr", "el", "ell":
|
||||
slug = SubstituteRune(slug, grSub)
|
||||
case "hu", "hun":
|
||||
slug = SubstituteRune(slug, huSub)
|
||||
case "id", "idn", "ind":
|
||||
slug = SubstituteRune(slug, idSub)
|
||||
case "kz", "kk", "kaz":
|
||||
slug = SubstituteRune(slug, kkSub)
|
||||
case "nb", "nob":
|
||||
slug = SubstituteRune(slug, nbSub)
|
||||
case "nl", "nld":
|
||||
slug = SubstituteRune(slug, nlSub)
|
||||
case "nn", "nno":
|
||||
slug = SubstituteRune(slug, nnSub)
|
||||
case "pl", "pol":
|
||||
slug = SubstituteRune(slug, plSub)
|
||||
case "sl", "slv":
|
||||
slug = SubstituteRune(slug, slSub)
|
||||
case "sv", "swe":
|
||||
slug = SubstituteRune(slug, svSub)
|
||||
case "tr", "tur":
|
||||
slug = SubstituteRune(slug, trSub)
|
||||
default: // fallback to "en" if lang not found
|
||||
slug = SubstituteRune(slug, enSub)
|
||||
}
|
||||
|
||||
// Process all non ASCII symbols
|
||||
slug = unidecode.Unidecode(slug)
|
||||
|
||||
if Lowercase {
|
||||
slug = strings.ToLower(slug)
|
||||
}
|
||||
|
||||
// Process all remaining symbols
|
||||
slug = regexpNonAuthorizedChars.ReplaceAllString(slug, "-")
|
||||
slug = regexpMultipleDashes.ReplaceAllString(slug, "-")
|
||||
slug = strings.Trim(slug, "-_")
|
||||
|
||||
if MaxLength > 0 {
|
||||
slug = smartTruncate(slug)
|
||||
}
|
||||
|
||||
return slug
|
||||
}
|
||||
|
||||
// Substitute returns string with superseded all substrings from
|
||||
// provided substitution map. Substitution map will be applied in alphabetic
|
||||
// order. Many passes, on one substitution another one could apply.
|
||||
func Substitute(s string, sub map[string]string) (buf string) {
|
||||
buf = s
|
||||
var keys []string
|
||||
for k := range sub {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
buf = strings.Replace(buf, key, sub[key], -1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SubstituteRune substitutes string chars with provided rune
|
||||
// substitution map. One pass.
|
||||
func SubstituteRune(s string, sub map[rune]string) string {
|
||||
var buf bytes.Buffer
|
||||
for _, c := range s {
|
||||
if d, ok := sub[c]; ok {
|
||||
buf.WriteString(d)
|
||||
} else {
|
||||
buf.WriteRune(c)
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func smartTruncate(text string) string {
|
||||
if len(text) < MaxLength {
|
||||
return text
|
||||
}
|
||||
|
||||
var truncated string
|
||||
words := strings.SplitAfter(text, "-")
|
||||
// If MaxLength is smaller than length of the first word return word
|
||||
// truncated after MaxLength.
|
||||
if len(words[0]) > MaxLength {
|
||||
return words[0][:MaxLength]
|
||||
}
|
||||
for _, word := range words {
|
||||
if len(truncated)+len(word)-1 <= MaxLength {
|
||||
truncated = truncated + word
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return strings.Trim(truncated, "-")
|
||||
}
|
||||
|
||||
// IsSlug returns True if provided text does not contain white characters,
|
||||
// punctuation, all letters are lower case and only from ASCII range.
|
||||
// It could contain `-` and `_` but not at the beginning or end of the text.
|
||||
// It should be in range of the MaxLength var if specified.
|
||||
// All output from slug.Make(text) should pass this test.
|
||||
func IsSlug(text string) bool {
|
||||
if text == "" ||
|
||||
(MaxLength > 0 && len(text) > MaxLength) ||
|
||||
text[0] == '-' || text[0] == '_' ||
|
||||
text[len(text)-1] == '-' || text[len(text)-1] == '_' {
|
||||
return false
|
||||
}
|
||||
for _, c := range text {
|
||||
if (c < 'a' || c > 'z') && c != '-' && c != '_' && (c < '0' || c > '9') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||
*.o
|
||||
*.a
|
||||
*.so
|
||||
|
||||
# Folders
|
||||
_obj
|
||||
_test
|
||||
|
||||
# Architecture specific extensions/prefixes
|
||||
*.[568vq]
|
||||
[568vq].out
|
||||
|
||||
*.cgo1.go
|
||||
*.cgo2.c
|
||||
_cgo_defun.c
|
||||
_cgo_gotypes.go
|
||||
_cgo_export.*
|
||||
|
||||
_testmain.go
|
||||
|
||||
*.exe
|
||||
*.test
|
||||
|
|
@ -0,0 +1,203 @@
|
|||
Copyright 2014 Rainy Cape S.L. <hello@rainycape.com>
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright {yyyy} {name of copyright owner}
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
# unidecode
|
||||
|
||||
[](https://pkg.go.dev/github.com/gosimple/unidecode)
|
||||
[](https://github.com/gosimple/unidecode/actions/workflows/tests.yml)
|
||||
|
||||
Unicode transliterator in Golang - Replaces non-ASCII characters with their
|
||||
ASCII approximations.
|
||||
|
||||
Fork of https://github.com/rainycape/unidecode
|
||||
|
||||
## Example
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gosimple/unidecode"
|
||||
)
|
||||
|
||||
func main() {
|
||||
decoded := unidecode.Unidecode("Łódź")
|
||||
fmt.Println(decoded)
|
||||
// Output: Lodz
|
||||
}
|
||||
```
|
||||
|
||||
### Requests or bugs?
|
||||
|
||||
<https://github.com/gosimple/unidecode/issues>
|
||||
|
||||
## Installation
|
||||
|
||||
```shell
|
||||
go get -u github.com/gosimple/unidecode
|
||||
```
|
||||
|
||||
## Benchmark
|
||||
|
||||
```shell
|
||||
go test -run=NONE -bench=. -benchmem -count=6 ./... > old.txt
|
||||
# make changes
|
||||
go test -run=NONE -bench=. -benchmem -count=6 ./... > new.txt
|
||||
|
||||
go install golang.org/x/perf/cmd/benchstat@latest
|
||||
|
||||
benchstat old.txt new.txt
|
||||
```
|
||||
|
||||
## Add new characters
|
||||
|
||||
1. Edit `table.txt` file.
|
||||
2. Rebuild `table.go` file:
|
||||
|
||||
```go
|
||||
go run ./make_table.go
|
||||
```
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
package unidecode
|
||||
|
||||
import (
|
||||
"compress/zlib"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
dummyLenght = byte(0xff)
|
||||
)
|
||||
|
||||
var (
|
||||
transliterations [65536][]rune
|
||||
transCount = rune(len(transliterations))
|
||||
)
|
||||
|
||||
func decodeTransliterations() {
|
||||
r, err := zlib.NewReader(strings.NewReader(tableData))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer r.Close()
|
||||
b := make([]byte, 0, 13) // 13 = longest transliteration, adjust if needed
|
||||
lenB := b[:1]
|
||||
chr := uint16(0xffff) // char counter, rely on overflow on first pass
|
||||
for {
|
||||
chr++
|
||||
if _, err := io.ReadFull(r, lenB); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
panic(err)
|
||||
}
|
||||
if lenB[0] == dummyLenght {
|
||||
continue
|
||||
}
|
||||
b = b[:lenB[0]] // resize, preserving allocation
|
||||
if _, err := io.ReadFull(r, b); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
transliterations[int(chr)] = []rune(string(b))
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
//go:build none
|
||||
// +build none
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io/ioutil"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func main() {
|
||||
data, err := ioutil.ReadFile("table.txt")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
previousVal := int64(-1)
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
if strings.HasPrefix(line, "/*") || line == "" {
|
||||
continue
|
||||
}
|
||||
sep := strings.IndexByte(line, ':')
|
||||
if sep == -1 {
|
||||
panic(line)
|
||||
}
|
||||
val, err := strconv.ParseInt(line[:sep], 0, 32)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if previousVal+1 != val {
|
||||
rangechars := 0
|
||||
for i := previousVal + 1; i <= val-1; i++ {
|
||||
if err := binary.Write(&buf, binary.LittleEndian, uint8(0xff)); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
rangechars++
|
||||
}
|
||||
fmt.Printf("Filled dummy range: 0x%04x - 0x%04x (%4d chars)\n", previousVal+1, val-1, rangechars)
|
||||
}
|
||||
|
||||
s, err := strconv.Unquote(line[sep+2:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := binary.Write(&buf, binary.LittleEndian, uint8(len(s))); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
previousVal = val
|
||||
buf.WriteString(s)
|
||||
}
|
||||
var cbuf bytes.Buffer
|
||||
w, err := zlib.NewWriterLevel(&cbuf, zlib.BestCompression)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if _, err := w.Write(buf.Bytes()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
buf.Reset()
|
||||
buf.WriteString("package unidecode\n")
|
||||
buf.WriteString("// AUTOGENERATED - DO NOT EDIT!\n\n")
|
||||
fmt.Fprintf(&buf, "const tableData = %q;\n", cbuf.String())
|
||||
dst, err := format.Source(buf.Bytes())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := ioutil.WriteFile("table.go", dst, 0644); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,58 @@
|
|||
// Package unidecode implements a unicode transliterator
|
||||
// which replaces non-ASCII characters with their ASCII
|
||||
// approximations.
|
||||
package unidecode
|
||||
|
||||
//go:generate go run make_table.go
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
const pooledCapacity = 64
|
||||
|
||||
var (
|
||||
slicePool sync.Pool
|
||||
decodingOnce sync.Once
|
||||
)
|
||||
|
||||
// Unidecode implements a unicode transliterator, which
|
||||
// replaces non-ASCII characters with their ASCII
|
||||
// counterparts.
|
||||
// Given an unicode encoded string, returns
|
||||
// another string with non-ASCII characters replaced
|
||||
// with their closest ASCII counterparts.
|
||||
// e.g. Unicode("áéíóú") => "aeiou"
|
||||
func Unidecode(s string) string {
|
||||
decodingOnce.Do(decodeTransliterations)
|
||||
l := len(s)
|
||||
var r []rune
|
||||
if l > pooledCapacity {
|
||||
r = make([]rune, 0, len(s))
|
||||
} else {
|
||||
if x := slicePool.Get(); x != nil {
|
||||
r = x.([]rune)[:0]
|
||||
} else {
|
||||
r = make([]rune, 0, pooledCapacity)
|
||||
}
|
||||
}
|
||||
for _, c := range s {
|
||||
if c <= unicode.MaxASCII {
|
||||
r = append(r, c)
|
||||
continue
|
||||
}
|
||||
if c > unicode.MaxRune || c >= transCount {
|
||||
/* Ignore reserved chars */
|
||||
continue
|
||||
}
|
||||
if d := transliterations[c]; d != nil {
|
||||
r = append(r, d...)
|
||||
}
|
||||
}
|
||||
res := string(r)
|
||||
if l <= pooledCapacity {
|
||||
slicePool.Put(r)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
documents
|
||||
coverage.txt
|
||||
_book
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# GORM
|
||||
|
||||
GORM V2 moved to https://github.com/go-gorm/gorm
|
||||
|
||||
GORM V1 Doc https://v1.gorm.io/
|
||||
|
|
@ -0,0 +1,377 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Association Mode contains some helper methods to handle relationship things easily.
|
||||
type Association struct {
|
||||
Error error
|
||||
scope *Scope
|
||||
column string
|
||||
field *Field
|
||||
}
|
||||
|
||||
// Find find out all related associations
|
||||
func (association *Association) Find(value interface{}) *Association {
|
||||
association.scope.related(value, association.column)
|
||||
return association.setErr(association.scope.db.Error)
|
||||
}
|
||||
|
||||
// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to
|
||||
func (association *Association) Append(values ...interface{}) *Association {
|
||||
if association.Error != nil {
|
||||
return association
|
||||
}
|
||||
|
||||
if relationship := association.field.Relationship; relationship.Kind == "has_one" {
|
||||
return association.Replace(values...)
|
||||
}
|
||||
return association.saveAssociations(values...)
|
||||
}
|
||||
|
||||
// Replace replace current associations with new one
|
||||
func (association *Association) Replace(values ...interface{}) *Association {
|
||||
if association.Error != nil {
|
||||
return association
|
||||
}
|
||||
|
||||
var (
|
||||
relationship = association.field.Relationship
|
||||
scope = association.scope
|
||||
field = association.field.Field
|
||||
newDB = scope.NewDB()
|
||||
)
|
||||
|
||||
// Append new values
|
||||
association.field.Set(reflect.Zero(association.field.Field.Type()))
|
||||
association.saveAssociations(values...)
|
||||
|
||||
// Belongs To
|
||||
if relationship.Kind == "belongs_to" {
|
||||
// Set foreign key to be null when clearing value (length equals 0)
|
||||
if len(values) == 0 {
|
||||
// Set foreign key to be nil
|
||||
var foreignKeyMap = map[string]interface{}{}
|
||||
for _, foreignKey := range relationship.ForeignDBNames {
|
||||
foreignKeyMap[foreignKey] = nil
|
||||
}
|
||||
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
|
||||
}
|
||||
} else {
|
||||
// Polymorphic Relations
|
||||
if relationship.PolymorphicDBName != "" {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue)
|
||||
}
|
||||
|
||||
// Delete Relations except new created
|
||||
if len(values) > 0 {
|
||||
var associationForeignFieldNames, associationForeignDBNames []string
|
||||
if relationship.Kind == "many_to_many" {
|
||||
// if many to many relations, get association fields name from association foreign keys
|
||||
associationScope := scope.New(reflect.New(field.Type()).Interface())
|
||||
for idx, dbName := range relationship.AssociationForeignFieldNames {
|
||||
if field, ok := associationScope.FieldByName(dbName); ok {
|
||||
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||
associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If has one/many relations, use primary keys
|
||||
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
|
||||
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||
associationForeignDBNames = append(associationForeignDBNames, field.DBName)
|
||||
}
|
||||
}
|
||||
|
||||
newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface())
|
||||
|
||||
if len(newPrimaryKeys) > 0 {
|
||||
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys))
|
||||
newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...)
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
// if many to many relations, delete related relations from join table
|
||||
var sourceForeignFieldNames []string
|
||||
|
||||
for _, dbName := range relationship.ForeignFieldNames {
|
||||
if field, ok := scope.FieldByName(dbName); ok {
|
||||
sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
|
||||
|
||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
|
||||
}
|
||||
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
|
||||
var foreignKeyMap = map[string]interface{}{}
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
foreignKeyMap[foreignKey] = nil
|
||||
if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
fieldValue := reflect.New(association.field.Field.Type()).Interface()
|
||||
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
||||
}
|
||||
}
|
||||
return association
|
||||
}
|
||||
|
||||
// Delete remove relationship between source & passed arguments, but won't delete those arguments
|
||||
func (association *Association) Delete(values ...interface{}) *Association {
|
||||
if association.Error != nil {
|
||||
return association
|
||||
}
|
||||
|
||||
var (
|
||||
relationship = association.field.Relationship
|
||||
scope = association.scope
|
||||
field = association.field.Field
|
||||
newDB = scope.NewDB()
|
||||
)
|
||||
|
||||
if len(values) == 0 {
|
||||
return association
|
||||
}
|
||||
|
||||
var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
|
||||
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
|
||||
deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
|
||||
deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
|
||||
}
|
||||
|
||||
deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...)
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
// source value's foreign keys
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
// get association's foreign fields name
|
||||
var associationScope = scope.New(reflect.New(field.Type()).Interface())
|
||||
var associationForeignFieldNames []string
|
||||
for _, associationDBName := range relationship.AssociationForeignFieldNames {
|
||||
if field, ok := associationScope.FieldByName(associationDBName); ok {
|
||||
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// association value's foreign keys
|
||||
deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...)
|
||||
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
|
||||
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
|
||||
|
||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
|
||||
} else {
|
||||
var foreignKeyMap = map[string]interface{}{}
|
||||
for _, foreignKey := range relationship.ForeignDBNames {
|
||||
foreignKeyMap[foreignKey] = nil
|
||||
}
|
||||
|
||||
if relationship.Kind == "belongs_to" {
|
||||
// find with deleting relation's foreign keys
|
||||
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...)
|
||||
newDB = newDB.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||
toQueryValues(primaryKeys)...,
|
||||
)
|
||||
|
||||
// set foreign key to be null if there are some records affected
|
||||
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
|
||||
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
|
||||
if results.RowsAffected > 0 {
|
||||
scope.updatedAttrsWithValues(foreignKeyMap)
|
||||
}
|
||||
} else {
|
||||
association.setErr(results.Error)
|
||||
}
|
||||
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||
// find all relations
|
||||
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
|
||||
newDB = newDB.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||
toQueryValues(primaryKeys)...,
|
||||
)
|
||||
|
||||
// only include those deleting relations
|
||||
newDB = newDB.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)),
|
||||
toQueryValues(deletingPrimaryKeys)...,
|
||||
)
|
||||
|
||||
// set matched relation's foreign key to be null
|
||||
fieldValue := reflect.New(association.field.Field.Type()).Interface()
|
||||
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove deleted records from source's field
|
||||
if association.Error == nil {
|
||||
if field.Kind() == reflect.Slice {
|
||||
leftValues := reflect.Zero(field.Type())
|
||||
|
||||
for i := 0; i < field.Len(); i++ {
|
||||
reflectValue := field.Index(i)
|
||||
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
|
||||
var isDeleted = false
|
||||
for _, pk := range deletingPrimaryKeys {
|
||||
if equalAsString(primaryKey, pk) {
|
||||
isDeleted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isDeleted {
|
||||
leftValues = reflect.Append(leftValues, reflectValue)
|
||||
}
|
||||
}
|
||||
|
||||
association.field.Set(leftValues)
|
||||
} else if field.Kind() == reflect.Struct {
|
||||
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0]
|
||||
for _, pk := range deletingPrimaryKeys {
|
||||
if equalAsString(primaryKey, pk) {
|
||||
association.field.Set(reflect.Zero(field.Type()))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return association
|
||||
}
|
||||
|
||||
// Clear remove relationship between source & current associations, won't delete those associations
|
||||
func (association *Association) Clear() *Association {
|
||||
return association.Replace()
|
||||
}
|
||||
|
||||
// Count return the count of current associations
|
||||
func (association *Association) Count() int {
|
||||
var (
|
||||
count = 0
|
||||
relationship = association.field.Relationship
|
||||
scope = association.scope
|
||||
fieldValue = association.field.Field.Interface()
|
||||
query = scope.DB()
|
||||
)
|
||||
|
||||
switch relationship.Kind {
|
||||
case "many_to_many":
|
||||
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
|
||||
case "has_many", "has_one":
|
||||
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
|
||||
query = query.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||
toQueryValues(primaryKeys)...,
|
||||
)
|
||||
case "belongs_to":
|
||||
primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
|
||||
query = query.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
|
||||
toQueryValues(primaryKeys)...,
|
||||
)
|
||||
}
|
||||
|
||||
if relationship.PolymorphicType != "" {
|
||||
query = query.Where(
|
||||
fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)),
|
||||
relationship.PolymorphicValue,
|
||||
)
|
||||
}
|
||||
|
||||
if err := query.Model(fieldValue).Count(&count).Error; err != nil {
|
||||
association.Error = err
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// saveAssociations save passed values as associations
|
||||
func (association *Association) saveAssociations(values ...interface{}) *Association {
|
||||
var (
|
||||
scope = association.scope
|
||||
field = association.field
|
||||
relationship = field.Relationship
|
||||
)
|
||||
|
||||
saveAssociation := func(reflectValue reflect.Value) {
|
||||
// value has to been pointer
|
||||
if reflectValue.Kind() != reflect.Ptr {
|
||||
reflectPtr := reflect.New(reflectValue.Type())
|
||||
reflectPtr.Elem().Set(reflectValue)
|
||||
reflectValue = reflectPtr
|
||||
}
|
||||
|
||||
// value has to been saved for many2many
|
||||
if relationship.Kind == "many_to_many" {
|
||||
if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
|
||||
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
|
||||
}
|
||||
}
|
||||
|
||||
// Assign Fields
|
||||
var fieldType = field.Field.Type()
|
||||
var setFieldBackToValue, setSliceFieldBackToValue bool
|
||||
if reflectValue.Type().AssignableTo(fieldType) {
|
||||
field.Set(reflectValue)
|
||||
} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
|
||||
// if field's type is struct, then need to set value back to argument after save
|
||||
setFieldBackToValue = true
|
||||
field.Set(reflectValue.Elem())
|
||||
} else if fieldType.Kind() == reflect.Slice {
|
||||
if reflectValue.Type().AssignableTo(fieldType.Elem()) {
|
||||
field.Set(reflect.Append(field.Field, reflectValue))
|
||||
} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
|
||||
// if field's type is slice of struct, then need to set value back to argument after save
|
||||
setSliceFieldBackToValue = true
|
||||
field.Set(reflect.Append(field.Field, reflectValue.Elem()))
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
|
||||
} else {
|
||||
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
|
||||
|
||||
if setFieldBackToValue {
|
||||
reflectValue.Elem().Set(field.Field)
|
||||
} else if setSliceFieldBackToValue {
|
||||
reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
reflectValue := reflect.ValueOf(value)
|
||||
indirectReflectValue := reflect.Indirect(reflectValue)
|
||||
if indirectReflectValue.Kind() == reflect.Struct {
|
||||
saveAssociation(reflectValue)
|
||||
} else if indirectReflectValue.Kind() == reflect.Slice {
|
||||
for i := 0; i < indirectReflectValue.Len(); i++ {
|
||||
saveAssociation(indirectReflectValue.Index(i))
|
||||
}
|
||||
} else {
|
||||
association.setErr(errors.New("invalid value type"))
|
||||
}
|
||||
}
|
||||
return association
|
||||
}
|
||||
|
||||
// setErr set error when the error is not nil. And return Association.
|
||||
func (association *Association) setErr(err error) *Association {
|
||||
if err != nil {
|
||||
association.Error = err
|
||||
}
|
||||
return association
|
||||
}
|
||||
|
|
@ -0,0 +1,250 @@
|
|||
package gorm
|
||||
|
||||
import "fmt"
|
||||
|
||||
// DefaultCallback default callbacks defined by gorm
|
||||
var DefaultCallback = &Callback{logger: nopLogger{}}
|
||||
|
||||
// Callback is a struct that contains all CRUD callbacks
|
||||
// Field `creates` contains callbacks will be call when creating object
|
||||
// Field `updates` contains callbacks will be call when updating object
|
||||
// Field `deletes` contains callbacks will be call when deleting object
|
||||
// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association...
|
||||
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
|
||||
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
|
||||
type Callback struct {
|
||||
logger logger
|
||||
creates []*func(scope *Scope)
|
||||
updates []*func(scope *Scope)
|
||||
deletes []*func(scope *Scope)
|
||||
queries []*func(scope *Scope)
|
||||
rowQueries []*func(scope *Scope)
|
||||
processors []*CallbackProcessor
|
||||
}
|
||||
|
||||
// CallbackProcessor contains callback informations
|
||||
type CallbackProcessor struct {
|
||||
logger logger
|
||||
name string // current callback's name
|
||||
before string // register current callback before a callback
|
||||
after string // register current callback after a callback
|
||||
replace bool // replace callbacks with same name
|
||||
remove bool // delete callbacks with same name
|
||||
kind string // callback type: create, update, delete, query, row_query
|
||||
processor *func(scope *Scope) // callback handler
|
||||
parent *Callback
|
||||
}
|
||||
|
||||
func (c *Callback) clone(logger logger) *Callback {
|
||||
return &Callback{
|
||||
logger: logger,
|
||||
creates: c.creates,
|
||||
updates: c.updates,
|
||||
deletes: c.deletes,
|
||||
queries: c.queries,
|
||||
rowQueries: c.rowQueries,
|
||||
processors: c.processors,
|
||||
}
|
||||
}
|
||||
|
||||
// Create could be used to register callbacks for creating object
|
||||
// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
|
||||
// // business logic
|
||||
// ...
|
||||
//
|
||||
// // set error if some thing wrong happened, will rollback the creating
|
||||
// scope.Err(errors.New("error"))
|
||||
// })
|
||||
func (c *Callback) Create() *CallbackProcessor {
|
||||
return &CallbackProcessor{logger: c.logger, kind: "create", parent: c}
|
||||
}
|
||||
|
||||
// Update could be used to register callbacks for updating object, refer `Create` for usage
|
||||
func (c *Callback) Update() *CallbackProcessor {
|
||||
return &CallbackProcessor{logger: c.logger, kind: "update", parent: c}
|
||||
}
|
||||
|
||||
// Delete could be used to register callbacks for deleting object, refer `Create` for usage
|
||||
func (c *Callback) Delete() *CallbackProcessor {
|
||||
return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c}
|
||||
}
|
||||
|
||||
// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
|
||||
// Refer `Create` for usage
|
||||
func (c *Callback) Query() *CallbackProcessor {
|
||||
return &CallbackProcessor{logger: c.logger, kind: "query", parent: c}
|
||||
}
|
||||
|
||||
// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
|
||||
func (c *Callback) RowQuery() *CallbackProcessor {
|
||||
return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c}
|
||||
}
|
||||
|
||||
// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
|
||||
func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
|
||||
cp.after = callbackName
|
||||
return cp
|
||||
}
|
||||
|
||||
// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
|
||||
func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
|
||||
cp.before = callbackName
|
||||
return cp
|
||||
}
|
||||
|
||||
// Register a new callback, refer `Callbacks.Create`
|
||||
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
||||
if cp.kind == "row_query" {
|
||||
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
||||
cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
|
||||
cp.before = "gorm:row_query"
|
||||
}
|
||||
}
|
||||
|
||||
cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
|
||||
cp.name = callbackName
|
||||
cp.processor = &callback
|
||||
cp.parent.processors = append(cp.parent.processors, cp)
|
||||
cp.parent.reorder()
|
||||
}
|
||||
|
||||
// Remove a registered callback
|
||||
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
|
||||
func (cp *CallbackProcessor) Remove(callbackName string) {
|
||||
cp.logger.Print("info", fmt.Sprintf("[info] removing callback `%v` from %v", callbackName, fileWithLineNum()))
|
||||
cp.name = callbackName
|
||||
cp.remove = true
|
||||
cp.parent.processors = append(cp.parent.processors, cp)
|
||||
cp.parent.reorder()
|
||||
}
|
||||
|
||||
// Replace a registered callback with new callback
|
||||
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
|
||||
// scope.SetColumn("CreatedAt", now)
|
||||
// scope.SetColumn("UpdatedAt", now)
|
||||
// })
|
||||
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
||||
cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
|
||||
cp.name = callbackName
|
||||
cp.processor = &callback
|
||||
cp.replace = true
|
||||
cp.parent.processors = append(cp.parent.processors, cp)
|
||||
cp.parent.reorder()
|
||||
}
|
||||
|
||||
// Get registered callback
|
||||
// db.Callback().Create().Get("gorm:create")
|
||||
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
||||
for _, p := range cp.parent.processors {
|
||||
if p.name == callbackName && p.kind == cp.kind {
|
||||
if p.remove {
|
||||
callback = nil
|
||||
} else {
|
||||
callback = *p.processor
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// getRIndex get right index from string slice
|
||||
func getRIndex(strs []string, str string) int {
|
||||
for i := len(strs) - 1; i >= 0; i-- {
|
||||
if strs[i] == str {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// sortProcessors sort callback processors based on its before, after, remove, replace
|
||||
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
||||
var (
|
||||
allNames, sortedNames []string
|
||||
sortCallbackProcessor func(c *CallbackProcessor)
|
||||
)
|
||||
|
||||
for _, cp := range cps {
|
||||
// show warning message the callback name already exists
|
||||
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
|
||||
cp.logger.Print("warning", fmt.Sprintf("[warning] duplicated callback `%v` from %v", cp.name, fileWithLineNum()))
|
||||
}
|
||||
allNames = append(allNames, cp.name)
|
||||
}
|
||||
|
||||
sortCallbackProcessor = func(c *CallbackProcessor) {
|
||||
if getRIndex(sortedNames, c.name) == -1 { // if not sorted
|
||||
if c.before != "" { // if defined before callback
|
||||
if index := getRIndex(sortedNames, c.before); index != -1 {
|
||||
// if before callback already sorted, append current callback just after it
|
||||
sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
|
||||
} else if index := getRIndex(allNames, c.before); index != -1 {
|
||||
// if before callback exists but haven't sorted, append current callback to last
|
||||
sortedNames = append(sortedNames, c.name)
|
||||
sortCallbackProcessor(cps[index])
|
||||
}
|
||||
}
|
||||
|
||||
if c.after != "" { // if defined after callback
|
||||
if index := getRIndex(sortedNames, c.after); index != -1 {
|
||||
// if after callback already sorted, append current callback just before it
|
||||
sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
|
||||
} else if index := getRIndex(allNames, c.after); index != -1 {
|
||||
// if after callback exists but haven't sorted
|
||||
cp := cps[index]
|
||||
// set after callback's before callback to current callback
|
||||
if cp.before == "" {
|
||||
cp.before = c.name
|
||||
}
|
||||
sortCallbackProcessor(cp)
|
||||
}
|
||||
}
|
||||
|
||||
// if current callback haven't been sorted, append it to last
|
||||
if getRIndex(sortedNames, c.name) == -1 {
|
||||
sortedNames = append(sortedNames, c.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, cp := range cps {
|
||||
sortCallbackProcessor(cp)
|
||||
}
|
||||
|
||||
var sortedFuncs []*func(scope *Scope)
|
||||
for _, name := range sortedNames {
|
||||
if index := getRIndex(allNames, name); !cps[index].remove {
|
||||
sortedFuncs = append(sortedFuncs, cps[index].processor)
|
||||
}
|
||||
}
|
||||
|
||||
return sortedFuncs
|
||||
}
|
||||
|
||||
// reorder all registered processors, and reset CRUD callbacks
|
||||
func (c *Callback) reorder() {
|
||||
var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
|
||||
|
||||
for _, processor := range c.processors {
|
||||
if processor.name != "" {
|
||||
switch processor.kind {
|
||||
case "create":
|
||||
creates = append(creates, processor)
|
||||
case "update":
|
||||
updates = append(updates, processor)
|
||||
case "delete":
|
||||
deletes = append(deletes, processor)
|
||||
case "query":
|
||||
queries = append(queries, processor)
|
||||
case "row_query":
|
||||
rowQueries = append(rowQueries, processor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.creates = sortProcessors(creates)
|
||||
c.updates = sortProcessors(updates)
|
||||
c.deletes = sortProcessors(deletes)
|
||||
c.queries = sortProcessors(queries)
|
||||
c.rowQueries = sortProcessors(rowQueries)
|
||||
}
|
||||
|
|
@ -0,0 +1,197 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Define callbacks for creating
|
||||
func init() {
|
||||
DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||
DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
|
||||
DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
||||
DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
|
||||
DefaultCallback.Create().Register("gorm:create", createCallback)
|
||||
DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
|
||||
DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
|
||||
DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
|
||||
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||
}
|
||||
|
||||
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
|
||||
func beforeCreateCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeSave")
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeCreate")
|
||||
}
|
||||
}
|
||||
|
||||
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
|
||||
func updateTimeStampForCreateCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
now := scope.db.nowFunc()
|
||||
|
||||
if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
|
||||
if createdAtField.IsBlank {
|
||||
createdAtField.Set(now)
|
||||
}
|
||||
}
|
||||
|
||||
if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok {
|
||||
if updatedAtField.IsBlank {
|
||||
updatedAtField.Set(now)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createCallback the callback used to insert data into database
|
||||
func createCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
defer scope.trace(NowFunc())
|
||||
|
||||
var (
|
||||
columns, placeholders []string
|
||||
blankColumnsWithDefaultValue []string
|
||||
)
|
||||
|
||||
for _, field := range scope.Fields() {
|
||||
if scope.changeableField(field) {
|
||||
if field.IsNormal && !field.IsIgnored {
|
||||
if field.IsBlank && field.HasDefaultValue {
|
||||
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
|
||||
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
|
||||
} else if !field.IsPrimaryKey || !field.IsBlank {
|
||||
columns = append(columns, scope.Quote(field.DBName))
|
||||
placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
|
||||
}
|
||||
} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
|
||||
for _, foreignKey := range field.Relationship.ForeignDBNames {
|
||||
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
||||
columns = append(columns, scope.Quote(foreignField.DBName))
|
||||
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
returningColumn = "*"
|
||||
quotedTableName = scope.QuotedTableName()
|
||||
primaryField = scope.PrimaryField()
|
||||
extraOption string
|
||||
insertModifier string
|
||||
)
|
||||
|
||||
if str, ok := scope.Get("gorm:insert_option"); ok {
|
||||
extraOption = fmt.Sprint(str)
|
||||
}
|
||||
if str, ok := scope.Get("gorm:insert_modifier"); ok {
|
||||
insertModifier = strings.ToUpper(fmt.Sprint(str))
|
||||
if insertModifier == "INTO" {
|
||||
insertModifier = ""
|
||||
}
|
||||
}
|
||||
|
||||
if primaryField != nil {
|
||||
returningColumn = scope.Quote(primaryField.DBName)
|
||||
}
|
||||
|
||||
lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns)
|
||||
var lastInsertIDReturningSuffix string
|
||||
if lastInsertIDOutputInterstitial == "" {
|
||||
lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
|
||||
}
|
||||
|
||||
if len(columns) == 0 {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"INSERT%v INTO %v %v%v%v",
|
||||
addExtraSpaceIfExist(insertModifier),
|
||||
quotedTableName,
|
||||
scope.Dialect().DefaultValueStr(),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||
))
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"INSERT%v INTO %v (%v)%v VALUES (%v)%v%v",
|
||||
addExtraSpaceIfExist(insertModifier),
|
||||
scope.QuotedTableName(),
|
||||
strings.Join(columns, ","),
|
||||
addExtraSpaceIfExist(lastInsertIDOutputInterstitial),
|
||||
strings.Join(placeholders, ","),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||
))
|
||||
}
|
||||
|
||||
// execute create sql: no primaryField
|
||||
if primaryField == nil {
|
||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||
// set rows affected count
|
||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
// set primary value to primary field
|
||||
if primaryField != nil && primaryField.IsBlank {
|
||||
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
|
||||
scope.Err(primaryField.Set(primaryValue))
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// execute create sql: lastInsertID implemention for majority of dialects
|
||||
if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
|
||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||
// set rows affected count
|
||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
// set primary value to primary field
|
||||
if primaryField != nil && primaryField.IsBlank {
|
||||
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
|
||||
scope.Err(primaryField.Set(primaryValue))
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
|
||||
if primaryField.Field.CanAddr() {
|
||||
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
||||
primaryField.IsBlank = false
|
||||
scope.db.RowsAffected = 1
|
||||
}
|
||||
} else {
|
||||
scope.Err(ErrUnaddressable)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
|
||||
func forceReloadAfterCreateCallback(scope *Scope) {
|
||||
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
|
||||
db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
|
||||
for _, field := range scope.Fields() {
|
||||
if field.IsPrimaryKey && !field.IsBlank {
|
||||
db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
db.Scan(scope.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
|
||||
func afterCreateCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterCreate")
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterSave")
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Define callbacks for deleting
|
||||
func init() {
|
||||
DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||
DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
|
||||
DefaultCallback.Delete().Register("gorm:delete", deleteCallback)
|
||||
DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
|
||||
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||
}
|
||||
|
||||
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
|
||||
func beforeDeleteCallback(scope *Scope) {
|
||||
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
||||
scope.Err(errors.New("missing WHERE clause while deleting"))
|
||||
return
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeDelete")
|
||||
}
|
||||
}
|
||||
|
||||
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
|
||||
func deleteCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
var extraOption string
|
||||
if str, ok := scope.Get("gorm:delete_option"); ok {
|
||||
extraOption = fmt.Sprint(str)
|
||||
}
|
||||
|
||||
deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt")
|
||||
|
||||
if !scope.Search.Unscoped && hasDeletedAtField {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"UPDATE %v SET %v=%v%v%v",
|
||||
scope.QuotedTableName(),
|
||||
scope.Quote(deletedAtField.DBName),
|
||||
scope.AddToVars(scope.db.nowFunc()),
|
||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
)).Exec()
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"DELETE FROM %v%v%v",
|
||||
scope.QuotedTableName(),
|
||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
)).Exec()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// afterDeleteCallback will invoke `AfterDelete` method after deleting
|
||||
func afterDeleteCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterDelete")
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Define callbacks for querying
|
||||
func init() {
|
||||
DefaultCallback.Query().Register("gorm:query", queryCallback)
|
||||
DefaultCallback.Query().Register("gorm:preload", preloadCallback)
|
||||
DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
|
||||
}
|
||||
|
||||
// queryCallback used to query data from database
|
||||
func queryCallback(scope *Scope) {
|
||||
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
|
||||
return
|
||||
}
|
||||
|
||||
//we are only preloading relations, dont touch base model
|
||||
if _, skip := scope.InstanceGet("gorm:only_preload"); skip {
|
||||
return
|
||||
}
|
||||
|
||||
defer scope.trace(NowFunc())
|
||||
|
||||
var (
|
||||
isSlice, isPtr bool
|
||||
resultType reflect.Type
|
||||
results = scope.IndirectValue()
|
||||
)
|
||||
|
||||
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
|
||||
if primaryField := scope.PrimaryField(); primaryField != nil {
|
||||
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
|
||||
}
|
||||
}
|
||||
|
||||
if value, ok := scope.Get("gorm:query_destination"); ok {
|
||||
results = indirect(reflect.ValueOf(value))
|
||||
}
|
||||
|
||||
if kind := results.Kind(); kind == reflect.Slice {
|
||||
isSlice = true
|
||||
resultType = results.Type().Elem()
|
||||
results.Set(reflect.MakeSlice(results.Type(), 0, 0))
|
||||
|
||||
if resultType.Kind() == reflect.Ptr {
|
||||
isPtr = true
|
||||
resultType = resultType.Elem()
|
||||
}
|
||||
} else if kind != reflect.Struct {
|
||||
scope.Err(errors.New("unsupported destination, should be slice or struct"))
|
||||
return
|
||||
}
|
||||
|
||||
scope.prepareQuerySQL()
|
||||
|
||||
if !scope.HasError() {
|
||||
scope.db.RowsAffected = 0
|
||||
|
||||
if str, ok := scope.Get("gorm:query_hint"); ok {
|
||||
scope.SQL = fmt.Sprint(str) + scope.SQL
|
||||
}
|
||||
|
||||
if str, ok := scope.Get("gorm:query_option"); ok {
|
||||
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
|
||||
}
|
||||
|
||||
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||
defer rows.Close()
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
for rows.Next() {
|
||||
scope.db.RowsAffected++
|
||||
|
||||
elem := results
|
||||
if isSlice {
|
||||
elem = reflect.New(resultType).Elem()
|
||||
}
|
||||
|
||||
scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
|
||||
|
||||
if isSlice {
|
||||
if isPtr {
|
||||
results.Set(reflect.Append(results, elem.Addr()))
|
||||
} else {
|
||||
results.Set(reflect.Append(results, elem))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
scope.Err(err)
|
||||
} else if scope.db.RowsAffected == 0 && !isSlice {
|
||||
scope.Err(ErrRecordNotFound)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// afterQueryCallback will invoke `AfterFind` method after querying
|
||||
func afterQueryCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterFind")
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,410 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// preloadCallback used to preload associations
|
||||
func preloadCallback(scope *Scope) {
|
||||
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
|
||||
return
|
||||
}
|
||||
|
||||
if ap, ok := scope.Get("gorm:auto_preload"); ok {
|
||||
// If gorm:auto_preload IS NOT a bool then auto preload.
|
||||
// Else if it IS a bool, use the value
|
||||
if apb, ok := ap.(bool); !ok {
|
||||
autoPreload(scope)
|
||||
} else if apb {
|
||||
autoPreload(scope)
|
||||
}
|
||||
}
|
||||
|
||||
if scope.Search.preload == nil || scope.HasError() {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
preloadedMap = map[string]bool{}
|
||||
fields = scope.Fields()
|
||||
)
|
||||
|
||||
for _, preload := range scope.Search.preload {
|
||||
var (
|
||||
preloadFields = strings.Split(preload.schema, ".")
|
||||
currentScope = scope
|
||||
currentFields = fields
|
||||
)
|
||||
|
||||
for idx, preloadField := range preloadFields {
|
||||
var currentPreloadConditions []interface{}
|
||||
|
||||
if currentScope == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// if not preloaded
|
||||
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
|
||||
|
||||
// assign search conditions to last preload
|
||||
if idx == len(preloadFields)-1 {
|
||||
currentPreloadConditions = preload.conditions
|
||||
}
|
||||
|
||||
for _, field := range currentFields {
|
||||
if field.Name != preloadField || field.Relationship == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch field.Relationship.Kind {
|
||||
case "has_one":
|
||||
currentScope.handleHasOnePreload(field, currentPreloadConditions)
|
||||
case "has_many":
|
||||
currentScope.handleHasManyPreload(field, currentPreloadConditions)
|
||||
case "belongs_to":
|
||||
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
|
||||
case "many_to_many":
|
||||
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
|
||||
default:
|
||||
scope.Err(errors.New("unsupported relation"))
|
||||
}
|
||||
|
||||
preloadedMap[preloadKey] = true
|
||||
break
|
||||
}
|
||||
|
||||
if !preloadedMap[preloadKey] {
|
||||
scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// preload next level
|
||||
if idx < len(preloadFields)-1 {
|
||||
currentScope = currentScope.getColumnAsScope(preloadField)
|
||||
if currentScope != nil {
|
||||
currentFields = currentScope.Fields()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func autoPreload(scope *Scope) {
|
||||
for _, field := range scope.Fields() {
|
||||
if field.Relationship == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if val, ok := field.TagSettingsGet("PRELOAD"); ok {
|
||||
if preload, err := strconv.ParseBool(val); err != nil {
|
||||
scope.Err(errors.New("invalid preload option"))
|
||||
return
|
||||
} else if !preload {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
scope.Search.Preload(field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
|
||||
var (
|
||||
preloadDB = scope.NewDB()
|
||||
preloadConditions []interface{}
|
||||
)
|
||||
|
||||
for _, condition := range conditions {
|
||||
if scopes, ok := condition.(func(*DB) *DB); ok {
|
||||
preloadDB = scopes(preloadDB)
|
||||
} else {
|
||||
preloadConditions = append(preloadConditions, condition)
|
||||
}
|
||||
}
|
||||
|
||||
return preloadDB, preloadConditions
|
||||
}
|
||||
|
||||
// handleHasOnePreload used to preload has one associations
|
||||
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
|
||||
// get relations's primary keys
|
||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// find relations
|
||||
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
|
||||
values := toQueryValues(primaryKeys)
|
||||
if relation.PolymorphicType != "" {
|
||||
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
|
||||
values = append(values, relation.PolymorphicValue)
|
||||
}
|
||||
|
||||
results := makeSlice(field.Struct.Type)
|
||||
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
|
||||
|
||||
// assign find results
|
||||
var (
|
||||
resultsValue = indirect(reflect.ValueOf(results))
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
)
|
||||
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
foreignValuesToResults := make(map[string]reflect.Value)
|
||||
for i := 0; i < resultsValue.Len(); i++ {
|
||||
result := resultsValue.Index(i)
|
||||
foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames))
|
||||
foreignValuesToResults[foreignValues] = result
|
||||
}
|
||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||
indirectValue := indirect(indirectScopeValue.Index(j))
|
||||
valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames))
|
||||
if result, found := foreignValuesToResults[valueString]; found {
|
||||
indirectValue.FieldByName(field.Name).Set(result)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < resultsValue.Len(); i++ {
|
||||
result := resultsValue.Index(i)
|
||||
scope.Err(field.Set(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleHasManyPreload used to preload has many associations
|
||||
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
|
||||
// get relations's primary keys
|
||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// find relations
|
||||
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
|
||||
values := toQueryValues(primaryKeys)
|
||||
if relation.PolymorphicType != "" {
|
||||
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
|
||||
values = append(values, relation.PolymorphicValue)
|
||||
}
|
||||
|
||||
results := makeSlice(field.Struct.Type)
|
||||
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
|
||||
|
||||
// assign find results
|
||||
var (
|
||||
resultsValue = indirect(reflect.ValueOf(results))
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
)
|
||||
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
preloadMap := make(map[string][]reflect.Value)
|
||||
for i := 0; i < resultsValue.Len(); i++ {
|
||||
result := resultsValue.Index(i)
|
||||
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
||||
preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
|
||||
}
|
||||
|
||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||
object := indirect(indirectScopeValue.Index(j))
|
||||
objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
|
||||
f := object.FieldByName(field.Name)
|
||||
if results, ok := preloadMap[toString(objectRealValue)]; ok {
|
||||
f.Set(reflect.Append(f, results...))
|
||||
} else {
|
||||
f.Set(reflect.MakeSlice(f.Type(), 0, 0))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.Err(field.Set(resultsValue))
|
||||
}
|
||||
}
|
||||
|
||||
// handleBelongsToPreload used to preload belongs to associations
|
||||
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// get relations's primary keys
|
||||
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// find relations
|
||||
results := makeSlice(field.Struct.Type)
|
||||
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
|
||||
|
||||
// assign find results
|
||||
var (
|
||||
resultsValue = indirect(reflect.ValueOf(results))
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
)
|
||||
|
||||
foreignFieldToObjects := make(map[string][]*reflect.Value)
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||
object := indirect(indirectScopeValue.Index(j))
|
||||
valueString := toString(getValueFromFields(object, relation.ForeignFieldNames))
|
||||
foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < resultsValue.Len(); i++ {
|
||||
result := resultsValue.Index(i)
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames))
|
||||
if objects, found := foreignFieldToObjects[valueString]; found {
|
||||
for _, object := range objects {
|
||||
object.FieldByName(field.Name).Set(result)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.Err(field.Set(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleManyToManyPreload used to preload many to many associations
|
||||
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
|
||||
var (
|
||||
relation = field.Relationship
|
||||
joinTableHandler = relation.JoinTableHandler
|
||||
fieldType = field.Struct.Type.Elem()
|
||||
foreignKeyValue interface{}
|
||||
foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
|
||||
linkHash = map[string][]reflect.Value{}
|
||||
isPtr bool
|
||||
)
|
||||
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
isPtr = true
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
var sourceKeys = []string{}
|
||||
for _, key := range joinTableHandler.SourceForeignKeys() {
|
||||
sourceKeys = append(sourceKeys, key.DBName)
|
||||
}
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// generate query with join table
|
||||
newScope := scope.New(reflect.New(fieldType).Interface())
|
||||
preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
|
||||
|
||||
if len(preloadDB.search.selects) == 0 {
|
||||
preloadDB = preloadDB.Select("*")
|
||||
}
|
||||
|
||||
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
|
||||
|
||||
// preload inline conditions
|
||||
if len(preloadConditions) > 0 {
|
||||
preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
|
||||
}
|
||||
|
||||
rows, err := preloadDB.Rows()
|
||||
|
||||
if scope.Err(err) != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
for rows.Next() {
|
||||
var (
|
||||
elem = reflect.New(fieldType).Elem()
|
||||
fields = scope.New(elem.Addr().Interface()).Fields()
|
||||
)
|
||||
|
||||
// register foreign keys in join tables
|
||||
var joinTableFields []*Field
|
||||
for _, sourceKey := range sourceKeys {
|
||||
joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
|
||||
}
|
||||
|
||||
scope.scan(rows, columns, append(fields, joinTableFields...))
|
||||
|
||||
scope.New(elem.Addr().Interface()).
|
||||
InstanceSet("gorm:skip_query_callback", true).
|
||||
callCallbacks(scope.db.parent.callbacks.queries)
|
||||
|
||||
var foreignKeys = make([]interface{}, len(sourceKeys))
|
||||
// generate hashed forkey keys in join table
|
||||
for idx, joinTableField := range joinTableFields {
|
||||
if !joinTableField.Field.IsNil() {
|
||||
foreignKeys[idx] = joinTableField.Field.Elem().Interface()
|
||||
}
|
||||
}
|
||||
hashedSourceKeys := toString(foreignKeys)
|
||||
|
||||
if isPtr {
|
||||
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
|
||||
} else {
|
||||
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
scope.Err(err)
|
||||
}
|
||||
|
||||
// assign find results
|
||||
var (
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
fieldsSourceMap = map[string][]reflect.Value{}
|
||||
foreignFieldNames = []string{}
|
||||
)
|
||||
|
||||
for _, dbName := range relation.ForeignFieldNames {
|
||||
if field, ok := scope.FieldByName(dbName); ok {
|
||||
foreignFieldNames = append(foreignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||
object := indirect(indirectScopeValue.Index(j))
|
||||
key := toString(getValueFromFields(object, foreignFieldNames))
|
||||
fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
|
||||
}
|
||||
} else if indirectScopeValue.IsValid() {
|
||||
key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
|
||||
fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
|
||||
}
|
||||
|
||||
for source, fields := range fieldsSourceMap {
|
||||
for _, f := range fields {
|
||||
//If not 0 this means Value is a pointer and we already added preloaded models to it
|
||||
if f.Len() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
v := reflect.MakeSlice(f.Type(), 0, 0)
|
||||
if len(linkHash[source]) > 0 {
|
||||
v = reflect.Append(f, linkHash[source]...)
|
||||
}
|
||||
|
||||
f.Set(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Define callbacks for row query
|
||||
func init() {
|
||||
DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback)
|
||||
}
|
||||
|
||||
type RowQueryResult struct {
|
||||
Row *sql.Row
|
||||
}
|
||||
|
||||
type RowsQueryResult struct {
|
||||
Rows *sql.Rows
|
||||
Error error
|
||||
}
|
||||
|
||||
// queryCallback used to query data from database
|
||||
func rowQueryCallback(scope *Scope) {
|
||||
if result, ok := scope.InstanceGet("row_query_result"); ok {
|
||||
scope.prepareQuerySQL()
|
||||
|
||||
if str, ok := scope.Get("gorm:query_hint"); ok {
|
||||
scope.SQL = fmt.Sprint(str) + scope.SQL
|
||||
}
|
||||
|
||||
if str, ok := scope.Get("gorm:query_option"); ok {
|
||||
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
|
||||
}
|
||||
|
||||
if rowResult, ok := result.(*RowQueryResult); ok {
|
||||
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
||||
} else if rowsResult, ok := result.(*RowsQueryResult); ok {
|
||||
rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,170 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func beginTransactionCallback(scope *Scope) {
|
||||
scope.Begin()
|
||||
}
|
||||
|
||||
func commitOrRollbackTransactionCallback(scope *Scope) {
|
||||
scope.CommitOrRollback()
|
||||
}
|
||||
|
||||
func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) {
|
||||
checkTruth := func(value interface{}) bool {
|
||||
if v, ok := value.(bool); ok && !v {
|
||||
return false
|
||||
}
|
||||
|
||||
if v, ok := value.(string); ok {
|
||||
v = strings.ToLower(v)
|
||||
return v == "true"
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
|
||||
if r = field.Relationship; r != nil {
|
||||
autoUpdate, autoCreate, saveReference = true, true, true
|
||||
|
||||
if value, ok := scope.Get("gorm:save_associations"); ok {
|
||||
autoUpdate = checkTruth(value)
|
||||
autoCreate = autoUpdate
|
||||
saveReference = autoUpdate
|
||||
} else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok {
|
||||
autoUpdate = checkTruth(value)
|
||||
autoCreate = autoUpdate
|
||||
saveReference = autoUpdate
|
||||
}
|
||||
|
||||
if value, ok := scope.Get("gorm:association_autoupdate"); ok {
|
||||
autoUpdate = checkTruth(value)
|
||||
} else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok {
|
||||
autoUpdate = checkTruth(value)
|
||||
}
|
||||
|
||||
if value, ok := scope.Get("gorm:association_autocreate"); ok {
|
||||
autoCreate = checkTruth(value)
|
||||
} else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok {
|
||||
autoCreate = checkTruth(value)
|
||||
}
|
||||
|
||||
if value, ok := scope.Get("gorm:association_save_reference"); ok {
|
||||
saveReference = checkTruth(value)
|
||||
} else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok {
|
||||
saveReference = checkTruth(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func saveBeforeAssociationsCallback(scope *Scope) {
|
||||
for _, field := range scope.Fields() {
|
||||
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
|
||||
|
||||
if relationship != nil && relationship.Kind == "belongs_to" {
|
||||
fieldValue := field.Field.Addr().Interface()
|
||||
newScope := scope.New(fieldValue)
|
||||
|
||||
if newScope.PrimaryKeyZero() {
|
||||
if autoCreate {
|
||||
scope.Err(scope.NewDB().Save(fieldValue).Error)
|
||||
}
|
||||
} else if autoUpdate {
|
||||
scope.Err(scope.NewDB().Save(fieldValue).Error)
|
||||
}
|
||||
|
||||
if saveReference {
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
// set value's foreign key
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||
if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok {
|
||||
scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func saveAfterAssociationsCallback(scope *Scope) {
|
||||
for _, field := range scope.Fields() {
|
||||
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
|
||||
|
||||
if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
|
||||
value := field.Field
|
||||
|
||||
switch value.Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
newDB := scope.NewDB()
|
||||
elem := value.Index(i).Addr().Interface()
|
||||
newScope := newDB.NewScope(elem)
|
||||
|
||||
if saveReference {
|
||||
if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||
if f, ok := scope.FieldByName(associationForeignName); ok {
|
||||
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.PolymorphicType != "" {
|
||||
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
|
||||
}
|
||||
}
|
||||
|
||||
if newScope.PrimaryKeyZero() {
|
||||
if autoCreate {
|
||||
scope.Err(newDB.Save(elem).Error)
|
||||
}
|
||||
} else if autoUpdate {
|
||||
scope.Err(newDB.Save(elem).Error)
|
||||
}
|
||||
|
||||
if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference {
|
||||
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
|
||||
scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
elem := value.Addr().Interface()
|
||||
newScope := scope.New(elem)
|
||||
|
||||
if saveReference {
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||
if f, ok := scope.FieldByName(associationForeignName); ok {
|
||||
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.PolymorphicType != "" {
|
||||
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
|
||||
}
|
||||
}
|
||||
|
||||
if newScope.PrimaryKeyZero() {
|
||||
if autoCreate {
|
||||
scope.Err(scope.NewDB().Save(elem).Error)
|
||||
}
|
||||
} else if autoUpdate {
|
||||
scope.Err(scope.NewDB().Save(elem).Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Define callbacks for updating
|
||||
func init() {
|
||||
DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
|
||||
DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||
DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
|
||||
DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
||||
DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
|
||||
DefaultCallback.Update().Register("gorm:update", updateCallback)
|
||||
DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
|
||||
DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
|
||||
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||
}
|
||||
|
||||
// assignUpdatingAttributesCallback assign updating attributes to model
|
||||
func assignUpdatingAttributesCallback(scope *Scope) {
|
||||
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
||||
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
|
||||
scope.InstanceSet("gorm:update_attrs", updateMaps)
|
||||
} else {
|
||||
scope.SkipLeft()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
|
||||
func beforeUpdateCallback(scope *Scope) {
|
||||
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
||||
scope.Err(errors.New("missing WHERE clause while updating"))
|
||||
return
|
||||
}
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeSave")
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeUpdate")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
|
||||
func updateTimeStampForUpdateCallback(scope *Scope) {
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
scope.SetColumn("UpdatedAt", scope.db.nowFunc())
|
||||
}
|
||||
}
|
||||
|
||||
// updateCallback the callback used to update data to database
|
||||
func updateCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
var sqls []string
|
||||
|
||||
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||
// Sort the column names so that the generated SQL is the same every time.
|
||||
updateMap := updateAttrs.(map[string]interface{})
|
||||
var columns []string
|
||||
for c := range updateMap {
|
||||
columns = append(columns, c)
|
||||
}
|
||||
sort.Strings(columns)
|
||||
|
||||
for _, column := range columns {
|
||||
value := updateMap[column]
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
|
||||
}
|
||||
} else {
|
||||
for _, field := range scope.Fields() {
|
||||
if scope.changeableField(field) {
|
||||
if !field.IsPrimaryKey && field.IsNormal && (field.Name != "CreatedAt" || !field.IsBlank) {
|
||||
if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||
}
|
||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||
for _, foreignKey := range relationship.ForeignDBNames {
|
||||
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
||||
sqls = append(sqls,
|
||||
fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var extraOption string
|
||||
if str, ok := scope.Get("gorm:update_option"); ok {
|
||||
extraOption = fmt.Sprint(str)
|
||||
}
|
||||
|
||||
if len(sqls) > 0 {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"UPDATE %v SET %v%v%v",
|
||||
scope.QuotedTableName(),
|
||||
strings.Join(sqls, ", "),
|
||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
)).Exec()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
|
||||
func afterUpdateCallback(scope *Scope) {
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterUpdate")
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterSave")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,147 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Dialect interface contains behaviors that differ across SQL database
|
||||
type Dialect interface {
|
||||
// GetName get dialect's name
|
||||
GetName() string
|
||||
|
||||
// SetDB set db for dialect
|
||||
SetDB(db SQLCommon)
|
||||
|
||||
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
|
||||
BindVar(i int) string
|
||||
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
|
||||
Quote(key string) string
|
||||
// DataTypeOf return data's sql type
|
||||
DataTypeOf(field *StructField) string
|
||||
|
||||
// HasIndex check has index or not
|
||||
HasIndex(tableName string, indexName string) bool
|
||||
// HasForeignKey check has foreign key or not
|
||||
HasForeignKey(tableName string, foreignKeyName string) bool
|
||||
// RemoveIndex remove index
|
||||
RemoveIndex(tableName string, indexName string) error
|
||||
// HasTable check has table or not
|
||||
HasTable(tableName string) bool
|
||||
// HasColumn check has column or not
|
||||
HasColumn(tableName string, columnName string) bool
|
||||
// ModifyColumn modify column's type
|
||||
ModifyColumn(tableName string, columnName string, typ string) error
|
||||
|
||||
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
|
||||
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
|
||||
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
||||
SelectFromDummyTable() string
|
||||
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
|
||||
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
|
||||
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
||||
LastInsertIDReturningSuffix(tableName, columnName string) string
|
||||
// DefaultValueStr
|
||||
DefaultValueStr() string
|
||||
|
||||
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
|
||||
BuildKeyName(kind, tableName string, fields ...string) string
|
||||
|
||||
// NormalizeIndexAndColumn returns valid index name and column name depending on each dialect
|
||||
NormalizeIndexAndColumn(indexName, columnName string) (string, string)
|
||||
|
||||
// CurrentDatabase return current database name
|
||||
CurrentDatabase() string
|
||||
}
|
||||
|
||||
var dialectsMap = map[string]Dialect{}
|
||||
|
||||
func newDialect(name string, db SQLCommon) Dialect {
|
||||
if value, ok := dialectsMap[name]; ok {
|
||||
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
|
||||
dialect.SetDB(db)
|
||||
return dialect
|
||||
}
|
||||
|
||||
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
|
||||
commontDialect := &commonDialect{}
|
||||
commontDialect.SetDB(db)
|
||||
return commontDialect
|
||||
}
|
||||
|
||||
// RegisterDialect register new dialect
|
||||
func RegisterDialect(name string, dialect Dialect) {
|
||||
dialectsMap[name] = dialect
|
||||
}
|
||||
|
||||
// GetDialect gets the dialect for the specified dialect name
|
||||
func GetDialect(name string) (dialect Dialect, ok bool) {
|
||||
dialect, ok = dialectsMap[name]
|
||||
return
|
||||
}
|
||||
|
||||
// ParseFieldStructForDialect get field's sql data type
|
||||
var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
|
||||
// Get redirected field type
|
||||
var (
|
||||
reflectType = field.Struct.Type
|
||||
dataType, _ = field.TagSettingsGet("TYPE")
|
||||
)
|
||||
|
||||
for reflectType.Kind() == reflect.Ptr {
|
||||
reflectType = reflectType.Elem()
|
||||
}
|
||||
|
||||
// Get redirected field value
|
||||
fieldValue = reflect.Indirect(reflect.New(reflectType))
|
||||
|
||||
if gormDataType, ok := fieldValue.Interface().(interface {
|
||||
GormDataType(Dialect) string
|
||||
}); ok {
|
||||
dataType = gormDataType.GormDataType(dialect)
|
||||
}
|
||||
|
||||
// Get scanner's real value
|
||||
if dataType == "" {
|
||||
var getScannerValue func(reflect.Value)
|
||||
getScannerValue = func(value reflect.Value) {
|
||||
fieldValue = value
|
||||
if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct {
|
||||
getScannerValue(fieldValue.Field(0))
|
||||
}
|
||||
}
|
||||
getScannerValue(fieldValue)
|
||||
}
|
||||
|
||||
// Default Size
|
||||
if num, ok := field.TagSettingsGet("SIZE"); ok {
|
||||
size, _ = strconv.Atoi(num)
|
||||
} else {
|
||||
size = 255
|
||||
}
|
||||
|
||||
// Default type from tag setting
|
||||
notNull, _ := field.TagSettingsGet("NOT NULL")
|
||||
unique, _ := field.TagSettingsGet("UNIQUE")
|
||||
additionalType = notNull + " " + unique
|
||||
if value, ok := field.TagSettingsGet("DEFAULT"); ok {
|
||||
additionalType = additionalType + " DEFAULT " + value
|
||||
}
|
||||
|
||||
if value, ok := field.TagSettingsGet("COMMENT"); ok {
|
||||
additionalType = additionalType + " COMMENT " + value
|
||||
}
|
||||
|
||||
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
|
||||
}
|
||||
|
||||
func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) {
|
||||
if strings.Contains(tableName, ".") {
|
||||
splitStrings := strings.SplitN(tableName, ".", 2)
|
||||
return splitStrings[0], splitStrings[1]
|
||||
}
|
||||
return dialect.CurrentDatabase(), tableName
|
||||
}
|
||||
|
|
@ -0,0 +1,196 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+")
|
||||
|
||||
// DefaultForeignKeyNamer contains the default foreign key name generator method
|
||||
type DefaultForeignKeyNamer struct {
|
||||
}
|
||||
|
||||
type commonDialect struct {
|
||||
db SQLCommon
|
||||
DefaultForeignKeyNamer
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterDialect("common", &commonDialect{})
|
||||
}
|
||||
|
||||
func (commonDialect) GetName() string {
|
||||
return "common"
|
||||
}
|
||||
|
||||
func (s *commonDialect) SetDB(db SQLCommon) {
|
||||
s.db = db
|
||||
}
|
||||
|
||||
func (commonDialect) BindVar(i int) string {
|
||||
return "$$$" // ?
|
||||
}
|
||||
|
||||
func (commonDialect) Quote(key string) string {
|
||||
return fmt.Sprintf(`"%s"`, key)
|
||||
}
|
||||
|
||||
func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool {
|
||||
if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
|
||||
return strings.ToLower(value) != "false"
|
||||
}
|
||||
return field.IsPrimaryKey
|
||||
}
|
||||
|
||||
func (s *commonDialect) DataTypeOf(field *StructField) string {
|
||||
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
|
||||
|
||||
if sqlType == "" {
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
sqlType = "BOOLEAN"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if s.fieldCanAutoIncrement(field) {
|
||||
sqlType = "INTEGER AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "INTEGER"
|
||||
}
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if s.fieldCanAutoIncrement(field) {
|
||||
sqlType = "BIGINT AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "BIGINT"
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
sqlType = "FLOAT"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("VARCHAR(%d)", size)
|
||||
} else {
|
||||
sqlType = "VARCHAR(65532)"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
sqlType = "TIMESTAMP"
|
||||
}
|
||||
default:
|
||||
if _, ok := dataValue.Interface().([]byte); ok {
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("BINARY(%d)", size)
|
||||
} else {
|
||||
sqlType = "BINARY(65532)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
}
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
|
||||
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
|
||||
var count int
|
||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
|
||||
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s commonDialect) HasTable(tableName string) bool {
|
||||
var count int
|
||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
|
||||
var count int
|
||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
|
||||
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s commonDialect) CurrentDatabase() (name string) {
|
||||
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
// LimitAndOffsetSQL return generated SQL with Limit and Offset
|
||||
func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
|
||||
if limit != nil {
|
||||
if parsedLimit, err := s.parseInt(limit); err != nil {
|
||||
return "", err
|
||||
} else if parsedLimit >= 0 {
|
||||
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
|
||||
}
|
||||
}
|
||||
if offset != nil {
|
||||
if parsedOffset, err := s.parseInt(offset); err != nil {
|
||||
return "", err
|
||||
} else if parsedOffset >= 0 {
|
||||
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (commonDialect) SelectFromDummyTable() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (commonDialect) DefaultValueStr() string {
|
||||
return "DEFAULT VALUES"
|
||||
}
|
||||
|
||||
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
|
||||
func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
|
||||
keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
|
||||
keyName = keyNameRegex.ReplaceAllString(keyName, "_")
|
||||
return keyName
|
||||
}
|
||||
|
||||
// NormalizeIndexAndColumn returns argument's index name and column name without doing anything
|
||||
func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
|
||||
return indexName, columnName
|
||||
}
|
||||
|
||||
func (commonDialect) parseInt(value interface{}) (int64, error) {
|
||||
return strconv.ParseInt(fmt.Sprint(value), 0, 0)
|
||||
}
|
||||
|
||||
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
|
||||
func IsByteArrayOrSlice(value reflect.Value) bool {
|
||||
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
|
||||
}
|
||||
|
|
@ -0,0 +1,246 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
var mysqlIndexRegex = regexp.MustCompile(`^(.+)\((\d+)\)$`)
|
||||
|
||||
type mysql struct {
|
||||
commonDialect
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterDialect("mysql", &mysql{})
|
||||
}
|
||||
|
||||
func (mysql) GetName() string {
|
||||
return "mysql"
|
||||
}
|
||||
|
||||
func (mysql) Quote(key string) string {
|
||||
return fmt.Sprintf("`%s`", key)
|
||||
}
|
||||
|
||||
// Get Data Type for MySQL Dialect
|
||||
func (s *mysql) DataTypeOf(field *StructField) string {
|
||||
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
|
||||
|
||||
// MySQL allows only one auto increment column per table, and it must
|
||||
// be a KEY column.
|
||||
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
|
||||
if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey {
|
||||
field.TagSettingsDelete("AUTO_INCREMENT")
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
sqlType = "boolean"
|
||||
case reflect.Int8:
|
||||
if s.fieldCanAutoIncrement(field) {
|
||||
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||
sqlType = "tinyint AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "tinyint"
|
||||
}
|
||||
case reflect.Int, reflect.Int16, reflect.Int32:
|
||||
if s.fieldCanAutoIncrement(field) {
|
||||
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||
sqlType = "int AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "int"
|
||||
}
|
||||
case reflect.Uint8:
|
||||
if s.fieldCanAutoIncrement(field) {
|
||||
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||
sqlType = "tinyint unsigned AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "tinyint unsigned"
|
||||
}
|
||||
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if s.fieldCanAutoIncrement(field) {
|
||||
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||
sqlType = "int unsigned AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "int unsigned"
|
||||
}
|
||||
case reflect.Int64:
|
||||
if s.fieldCanAutoIncrement(field) {
|
||||
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||
sqlType = "bigint AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "bigint"
|
||||
}
|
||||
case reflect.Uint64:
|
||||
if s.fieldCanAutoIncrement(field) {
|
||||
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||
sqlType = "bigint unsigned AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "bigint unsigned"
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
sqlType = "double"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
sqlType = "longtext"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
precision := ""
|
||||
if p, ok := field.TagSettingsGet("PRECISION"); ok {
|
||||
precision = fmt.Sprintf("(%s)", p)
|
||||
}
|
||||
|
||||
if _, ok := field.TagSettings["NOT NULL"]; ok || field.IsPrimaryKey {
|
||||
sqlType = fmt.Sprintf("DATETIME%v", precision)
|
||||
} else {
|
||||
sqlType = fmt.Sprintf("DATETIME%v NULL", precision)
|
||||
}
|
||||
}
|
||||
default:
|
||||
if IsByteArrayOrSlice(dataValue) {
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("varbinary(%d)", size)
|
||||
} else {
|
||||
sqlType = "longblob"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
}
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
|
||||
func (s mysql) RemoveIndex(tableName string, indexName string) error {
|
||||
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error {
|
||||
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
|
||||
if limit != nil {
|
||||
parsedLimit, err := s.parseInt(limit)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsedLimit >= 0 {
|
||||
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
|
||||
|
||||
if offset != nil {
|
||||
parsedOffset, err := s.parseInt(offset)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsedOffset >= 0 {
|
||||
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||
var count int
|
||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s mysql) HasTable(tableName string) bool {
|
||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||
var name string
|
||||
// allow mysql database name with '-' character
|
||||
if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false
|
||||
}
|
||||
panic(err)
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (s mysql) HasIndex(tableName string, indexName string) bool {
|
||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||
if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil {
|
||||
panic(err)
|
||||
} else {
|
||||
defer rows.Close()
|
||||
return rows.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (s mysql) HasColumn(tableName string, columnName string) bool {
|
||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||
if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil {
|
||||
panic(err)
|
||||
} else {
|
||||
defer rows.Close()
|
||||
return rows.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (s mysql) CurrentDatabase() (name string) {
|
||||
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql) SelectFromDummyTable() string {
|
||||
return "FROM DUAL"
|
||||
}
|
||||
|
||||
func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string {
|
||||
keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...)
|
||||
if utf8.RuneCountInString(keyName) <= 64 {
|
||||
return keyName
|
||||
}
|
||||
h := sha1.New()
|
||||
h.Write([]byte(keyName))
|
||||
bs := h.Sum(nil)
|
||||
|
||||
// sha1 is 40 characters, keep first 24 characters of destination
|
||||
destRunes := []rune(keyNameRegex.ReplaceAllString(fields[0], "_"))
|
||||
if len(destRunes) > 24 {
|
||||
destRunes = destRunes[:24]
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%x", string(destRunes), bs)
|
||||
}
|
||||
|
||||
// NormalizeIndexAndColumn returns index name and column name for specify an index prefix length if needed
|
||||
func (mysql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
|
||||
submatch := mysqlIndexRegex.FindStringSubmatch(indexName)
|
||||
if len(submatch) != 3 {
|
||||
return indexName, columnName
|
||||
}
|
||||
indexName = submatch[1]
|
||||
columnName = fmt.Sprintf("%s(%s)", columnName, submatch[2])
|
||||
return indexName, columnName
|
||||
}
|
||||
|
||||
func (mysql) DefaultValueStr() string {
|
||||
return "VALUES()"
|
||||
}
|
||||
|
|
@ -0,0 +1,147 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type postgres struct {
|
||||
commonDialect
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterDialect("postgres", &postgres{})
|
||||
RegisterDialect("cloudsqlpostgres", &postgres{})
|
||||
}
|
||||
|
||||
func (postgres) GetName() string {
|
||||
return "postgres"
|
||||
}
|
||||
|
||||
func (postgres) BindVar(i int) string {
|
||||
return fmt.Sprintf("$%v", i)
|
||||
}
|
||||
|
||||
func (s *postgres) DataTypeOf(field *StructField) string {
|
||||
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
|
||||
|
||||
if sqlType == "" {
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
sqlType = "boolean"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
|
||||
if s.fieldCanAutoIncrement(field) {
|
||||
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||
sqlType = "serial"
|
||||
} else {
|
||||
sqlType = "integer"
|
||||
}
|
||||
case reflect.Int64, reflect.Uint32, reflect.Uint64:
|
||||
if s.fieldCanAutoIncrement(field) {
|
||||
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||
sqlType = "bigserial"
|
||||
} else {
|
||||
sqlType = "bigint"
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
sqlType = "numeric"
|
||||
case reflect.String:
|
||||
if _, ok := field.TagSettingsGet("SIZE"); !ok {
|
||||
size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
|
||||
}
|
||||
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
sqlType = "text"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
sqlType = "timestamp with time zone"
|
||||
}
|
||||
case reflect.Map:
|
||||
if dataValue.Type().Name() == "Hstore" {
|
||||
sqlType = "hstore"
|
||||
}
|
||||
default:
|
||||
if IsByteArrayOrSlice(dataValue) {
|
||||
sqlType = "bytea"
|
||||
|
||||
if isUUID(dataValue) {
|
||||
sqlType = "uuid"
|
||||
}
|
||||
|
||||
if isJSON(dataValue) {
|
||||
sqlType = "jsonb"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
}
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
|
||||
func (s postgres) HasIndex(tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s postgres) HasTable(tableName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s postgres) HasColumn(tableName string, columnName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s postgres) CurrentDatabase() (name string) {
|
||||
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
|
||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||
}
|
||||
|
||||
func (postgres) SupportLastInsertID() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func isUUID(value reflect.Value) bool {
|
||||
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
|
||||
return false
|
||||
}
|
||||
typename := value.Type().Name()
|
||||
lower := strings.ToLower(typename)
|
||||
return "uuid" == lower || "guid" == lower
|
||||
}
|
||||
|
||||
func isJSON(value reflect.Value) bool {
|
||||
_, ok := value.Interface().(json.RawMessage)
|
||||
return ok
|
||||
}
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type sqlite3 struct {
|
||||
commonDialect
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterDialect("sqlite3", &sqlite3{})
|
||||
}
|
||||
|
||||
func (sqlite3) GetName() string {
|
||||
return "sqlite3"
|
||||
}
|
||||
|
||||
// Get Data Type for Sqlite Dialect
|
||||
func (s *sqlite3) DataTypeOf(field *StructField) string {
|
||||
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
|
||||
|
||||
if sqlType == "" {
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
sqlType = "bool"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if s.fieldCanAutoIncrement(field) {
|
||||
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||
sqlType = "integer primary key autoincrement"
|
||||
} else {
|
||||
sqlType = "integer"
|
||||
}
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if s.fieldCanAutoIncrement(field) {
|
||||
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||
sqlType = "integer primary key autoincrement"
|
||||
} else {
|
||||
sqlType = "bigint"
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
sqlType = "real"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
sqlType = "text"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
sqlType = "datetime"
|
||||
}
|
||||
default:
|
||||
if IsByteArrayOrSlice(dataValue) {
|
||||
sqlType = "blob"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
}
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
|
||||
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s sqlite3) HasTable(tableName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s sqlite3) CurrentDatabase() (name string) {
|
||||
var (
|
||||
ifaces = make([]interface{}, 3)
|
||||
pointers = make([]*string, 3)
|
||||
i int
|
||||
)
|
||||
for i = 0; i < 3; i++ {
|
||||
ifaces[i] = &pointers[i]
|
||||
}
|
||||
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
|
||||
return
|
||||
}
|
||||
if pointers[1] != nil {
|
||||
name = *pointers[1]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
package mysql
|
||||
|
||||
import _ "github.com/go-sql-driver/mysql"
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
version: '3'
|
||||
|
||||
services:
|
||||
mysql:
|
||||
image: 'mysql:latest'
|
||||
ports:
|
||||
- 9910:3306
|
||||
environment:
|
||||
- MYSQL_DATABASE=gorm
|
||||
- MYSQL_USER=gorm
|
||||
- MYSQL_PASSWORD=gorm
|
||||
- MYSQL_RANDOM_ROOT_PASSWORD="yes"
|
||||
postgres:
|
||||
image: 'postgres:latest'
|
||||
ports:
|
||||
- 9920:5432
|
||||
environment:
|
||||
- POSTGRES_USER=gorm
|
||||
- POSTGRES_DB=gorm
|
||||
- POSTGRES_PASSWORD=gorm
|
||||
mssql:
|
||||
image: 'mcmoe/mssqldocker:latest'
|
||||
ports:
|
||||
- 9930:1433
|
||||
environment:
|
||||
- ACCEPT_EULA=Y
|
||||
- SA_PASSWORD=LoremIpsum86
|
||||
- MSSQL_DB=gorm
|
||||
- MSSQL_USER=gorm
|
||||
- MSSQL_PASSWORD=LoremIpsum86
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrRecordNotFound returns a "record not found error". Occurs only when attempting to query the database with a struct; querying with a slice won't return this error
|
||||
ErrRecordNotFound = errors.New("record not found")
|
||||
// ErrInvalidSQL occurs when you attempt a query with invalid SQL
|
||||
ErrInvalidSQL = errors.New("invalid SQL")
|
||||
// ErrInvalidTransaction occurs when you are trying to `Commit` or `Rollback`
|
||||
ErrInvalidTransaction = errors.New("no valid transaction")
|
||||
// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
|
||||
ErrCantStartTransaction = errors.New("can't start transaction")
|
||||
// ErrUnaddressable unaddressable value
|
||||
ErrUnaddressable = errors.New("using unaddressable value")
|
||||
)
|
||||
|
||||
// Errors contains all happened errors
|
||||
type Errors []error
|
||||
|
||||
// IsRecordNotFoundError returns true if error contains a RecordNotFound error
|
||||
func IsRecordNotFoundError(err error) bool {
|
||||
if errs, ok := err.(Errors); ok {
|
||||
for _, err := range errs {
|
||||
if err == ErrRecordNotFound {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return err == ErrRecordNotFound
|
||||
}
|
||||
|
||||
// GetErrors gets all errors that have occurred and returns a slice of errors (Error type)
|
||||
func (errs Errors) GetErrors() []error {
|
||||
return errs
|
||||
}
|
||||
|
||||
// Add adds an error to a given slice of errors
|
||||
func (errs Errors) Add(newErrors ...error) Errors {
|
||||
for _, err := range newErrors {
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if errors, ok := err.(Errors); ok {
|
||||
errs = errs.Add(errors...)
|
||||
} else {
|
||||
ok = true
|
||||
for _, e := range errs {
|
||||
if err == e {
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
if ok {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
// Error takes a slice of all errors that have occurred and returns it as a formatted string
|
||||
func (errs Errors) Error() string {
|
||||
var errors = []string{}
|
||||
for _, e := range errs {
|
||||
errors = append(errors, e.Error())
|
||||
}
|
||||
return strings.Join(errors, "; ")
|
||||
}
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Field model field definition
|
||||
type Field struct {
|
||||
*StructField
|
||||
IsBlank bool
|
||||
Field reflect.Value
|
||||
}
|
||||
|
||||
// Set set a value to the field
|
||||
func (field *Field) Set(value interface{}) (err error) {
|
||||
if !field.Field.IsValid() {
|
||||
return errors.New("field value not valid")
|
||||
}
|
||||
|
||||
if !field.Field.CanAddr() {
|
||||
return ErrUnaddressable
|
||||
}
|
||||
|
||||
reflectValue, ok := value.(reflect.Value)
|
||||
if !ok {
|
||||
reflectValue = reflect.ValueOf(value)
|
||||
}
|
||||
|
||||
fieldValue := field.Field
|
||||
if reflectValue.IsValid() {
|
||||
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
||||
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
||||
} else {
|
||||
if fieldValue.Kind() == reflect.Ptr {
|
||||
if fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(field.Struct.Type.Elem()))
|
||||
}
|
||||
fieldValue = fieldValue.Elem()
|
||||
}
|
||||
|
||||
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
||||
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
||||
} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
|
||||
v := reflectValue.Interface()
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
if v, err = valuer.Value(); err == nil {
|
||||
err = scanner.Scan(v)
|
||||
}
|
||||
} else {
|
||||
err = scanner.Scan(v)
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
field.Field.Set(reflect.Zero(field.Field.Type()))
|
||||
}
|
||||
|
||||
field.IsBlank = isBlank(field.Field)
|
||||
return err
|
||||
}
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
|
||||
type SQLCommon interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(query string) (*sql.Stmt, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
type sqlDb interface {
|
||||
Begin() (*sql.Tx, error)
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||
}
|
||||
|
||||
type sqlTx interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
|
@ -0,0 +1,211 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// JoinTableHandlerInterface is an interface for how to handle many2many relations
|
||||
type JoinTableHandlerInterface interface {
|
||||
// initialize join table handler
|
||||
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
|
||||
// Table return join table's table name
|
||||
Table(db *DB) string
|
||||
// Add create relationship in join table for source and destination
|
||||
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
|
||||
// Delete delete relationship in join table for sources
|
||||
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
|
||||
// JoinWith query with `Join` conditions
|
||||
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
||||
// SourceForeignKeys return source foreign keys
|
||||
SourceForeignKeys() []JoinTableForeignKey
|
||||
// DestinationForeignKeys return destination foreign keys
|
||||
DestinationForeignKeys() []JoinTableForeignKey
|
||||
}
|
||||
|
||||
// JoinTableForeignKey join table foreign key struct
|
||||
type JoinTableForeignKey struct {
|
||||
DBName string
|
||||
AssociationDBName string
|
||||
}
|
||||
|
||||
// JoinTableSource is a struct that contains model type and foreign keys
|
||||
type JoinTableSource struct {
|
||||
ModelType reflect.Type
|
||||
ForeignKeys []JoinTableForeignKey
|
||||
}
|
||||
|
||||
// JoinTableHandler default join table handler
|
||||
type JoinTableHandler struct {
|
||||
TableName string `sql:"-"`
|
||||
Source JoinTableSource `sql:"-"`
|
||||
Destination JoinTableSource `sql:"-"`
|
||||
}
|
||||
|
||||
// SourceForeignKeys return source foreign keys
|
||||
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
|
||||
return s.Source.ForeignKeys
|
||||
}
|
||||
|
||||
// DestinationForeignKeys return destination foreign keys
|
||||
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
|
||||
return s.Destination.ForeignKeys
|
||||
}
|
||||
|
||||
// Setup initialize a default join table handler
|
||||
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
|
||||
s.TableName = tableName
|
||||
|
||||
s.Source = JoinTableSource{ModelType: source}
|
||||
s.Source.ForeignKeys = []JoinTableForeignKey{}
|
||||
for idx, dbName := range relationship.ForeignFieldNames {
|
||||
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
|
||||
DBName: relationship.ForeignDBNames[idx],
|
||||
AssociationDBName: dbName,
|
||||
})
|
||||
}
|
||||
|
||||
s.Destination = JoinTableSource{ModelType: destination}
|
||||
s.Destination.ForeignKeys = []JoinTableForeignKey{}
|
||||
for idx, dbName := range relationship.AssociationForeignFieldNames {
|
||||
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
|
||||
DBName: relationship.AssociationForeignDBNames[idx],
|
||||
AssociationDBName: dbName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Table return join table's table name
|
||||
func (s JoinTableHandler) Table(db *DB) string {
|
||||
return DefaultTableNameHandler(db, s.TableName)
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) {
|
||||
for _, source := range sources {
|
||||
scope := db.NewScope(source)
|
||||
modelType := scope.GetModelStruct().ModelType
|
||||
|
||||
for _, joinTableSource := range joinTableSources {
|
||||
if joinTableSource.ModelType == modelType {
|
||||
for _, foreignKey := range joinTableSource.ForeignKeys {
|
||||
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||
conditionMap[foreignKey.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add create relationship in join table for source and destination
|
||||
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
|
||||
var (
|
||||
scope = db.NewScope("")
|
||||
conditionMap = map[string]interface{}{}
|
||||
)
|
||||
|
||||
// Update condition map for source
|
||||
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source)
|
||||
|
||||
// Update condition map for destination
|
||||
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination)
|
||||
|
||||
var assignColumns, binVars, conditions []string
|
||||
var values []interface{}
|
||||
for key, value := range conditionMap {
|
||||
assignColumns = append(assignColumns, scope.Quote(key))
|
||||
binVars = append(binVars, `?`)
|
||||
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
quotedTable := scope.Quote(handler.Table(db))
|
||||
sql := fmt.Sprintf(
|
||||
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
|
||||
quotedTable,
|
||||
strings.Join(assignColumns, ","),
|
||||
strings.Join(binVars, ","),
|
||||
scope.Dialect().SelectFromDummyTable(),
|
||||
quotedTable,
|
||||
strings.Join(conditions, " AND "),
|
||||
)
|
||||
|
||||
return db.Exec(sql, values...).Error
|
||||
}
|
||||
|
||||
// Delete delete relationship in join table for sources
|
||||
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
|
||||
var (
|
||||
scope = db.NewScope(nil)
|
||||
conditions []string
|
||||
values []interface{}
|
||||
conditionMap = map[string]interface{}{}
|
||||
)
|
||||
|
||||
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...)
|
||||
|
||||
for key, value := range conditionMap {
|
||||
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
|
||||
}
|
||||
|
||||
// JoinWith query with `Join` conditions
|
||||
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
|
||||
var (
|
||||
scope = db.NewScope(source)
|
||||
tableName = handler.Table(db)
|
||||
quotedTableName = scope.Quote(tableName)
|
||||
joinConditions []string
|
||||
values []interface{}
|
||||
)
|
||||
|
||||
if s.Source.ModelType == scope.GetModelStruct().ModelType {
|
||||
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
|
||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
||||
}
|
||||
|
||||
var foreignDBNames []string
|
||||
var foreignFieldNames []string
|
||||
|
||||
for _, foreignKey := range s.Source.ForeignKeys {
|
||||
foreignDBNames = append(foreignDBNames, foreignKey.DBName)
|
||||
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||
foreignFieldNames = append(foreignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
|
||||
|
||||
var condString string
|
||||
if len(foreignFieldValues) > 0 {
|
||||
var quotedForeignDBNames []string
|
||||
for _, dbName := range foreignDBNames {
|
||||
quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName)
|
||||
}
|
||||
|
||||
condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
|
||||
|
||||
keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
|
||||
values = append(values, toQueryValues(keys))
|
||||
} else {
|
||||
condString = fmt.Sprintf("1 <> 1")
|
||||
}
|
||||
|
||||
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
|
||||
Where(condString, toQueryValues(foreignFieldValues)...)
|
||||
}
|
||||
|
||||
db.Error = errors.New("wrong source type for join table handler")
|
||||
return db
|
||||
}
|
||||
|
|
@ -0,0 +1,141 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"time"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
|
||||
sqlRegexp = regexp.MustCompile(`\?`)
|
||||
numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`)
|
||||
)
|
||||
|
||||
func isPrintable(s string) bool {
|
||||
for _, r := range s {
|
||||
if !unicode.IsPrint(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var LogFormatter = func(values ...interface{}) (messages []interface{}) {
|
||||
if len(values) > 1 {
|
||||
var (
|
||||
sql string
|
||||
formattedValues []string
|
||||
level = values[0]
|
||||
currentTime = "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m"
|
||||
source = fmt.Sprintf("\033[35m(%v)\033[0m", values[1])
|
||||
)
|
||||
|
||||
messages = []interface{}{source, currentTime}
|
||||
|
||||
if len(values) == 2 {
|
||||
//remove the line break
|
||||
currentTime = currentTime[1:]
|
||||
//remove the brackets
|
||||
source = fmt.Sprintf("\033[35m%v\033[0m", values[1])
|
||||
|
||||
messages = []interface{}{currentTime, source}
|
||||
}
|
||||
|
||||
if level == "sql" {
|
||||
// duration
|
||||
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
|
||||
// sql
|
||||
|
||||
for _, value := range values[4].([]interface{}) {
|
||||
indirectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
if indirectValue.IsValid() {
|
||||
value = indirectValue.Interface()
|
||||
if t, ok := value.(time.Time); ok {
|
||||
if t.IsZero() {
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00"))
|
||||
} else {
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05")))
|
||||
}
|
||||
} else if b, ok := value.([]byte); ok {
|
||||
if str := string(b); isPrintable(str) {
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
|
||||
} else {
|
||||
formattedValues = append(formattedValues, "'<binary>'")
|
||||
}
|
||||
} else if r, ok := value.(driver.Valuer); ok {
|
||||
if value, err := r.Value(); err == nil && value != nil {
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
|
||||
} else {
|
||||
formattedValues = append(formattedValues, "NULL")
|
||||
}
|
||||
} else {
|
||||
switch value.(type) {
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("%v", value))
|
||||
default:
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
formattedValues = append(formattedValues, "NULL")
|
||||
}
|
||||
}
|
||||
|
||||
// differentiate between $n placeholders or else treat like ?
|
||||
if numericPlaceHolderRegexp.MatchString(values[3].(string)) {
|
||||
sql = values[3].(string)
|
||||
for index, value := range formattedValues {
|
||||
placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1)
|
||||
sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1")
|
||||
}
|
||||
} else {
|
||||
formattedValuesLength := len(formattedValues)
|
||||
for index, value := range sqlRegexp.Split(values[3].(string), -1) {
|
||||
sql += value
|
||||
if index < formattedValuesLength {
|
||||
sql += formattedValues[index]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, sql)
|
||||
messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned "))
|
||||
} else {
|
||||
messages = append(messages, "\033[31;1m")
|
||||
messages = append(messages, values[2:]...)
|
||||
messages = append(messages, "\033[0m")
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type logger interface {
|
||||
Print(v ...interface{})
|
||||
}
|
||||
|
||||
// LogWriter log writer interface
|
||||
type LogWriter interface {
|
||||
Println(v ...interface{})
|
||||
}
|
||||
|
||||
// Logger default logger
|
||||
type Logger struct {
|
||||
LogWriter
|
||||
}
|
||||
|
||||
// Print format & print log
|
||||
func (logger Logger) Print(values ...interface{}) {
|
||||
logger.Println(LogFormatter(values...)...)
|
||||
}
|
||||
|
||||
type nopLogger struct{}
|
||||
|
||||
func (nopLogger) Print(values ...interface{}) {}
|
||||
|
|
@ -0,0 +1,886 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DB contains information for current db connection
|
||||
type DB struct {
|
||||
sync.RWMutex
|
||||
Value interface{}
|
||||
Error error
|
||||
RowsAffected int64
|
||||
|
||||
// single db
|
||||
db SQLCommon
|
||||
blockGlobalUpdate bool
|
||||
logMode logModeValue
|
||||
logger logger
|
||||
search *search
|
||||
values sync.Map
|
||||
|
||||
// global db
|
||||
parent *DB
|
||||
callbacks *Callback
|
||||
dialect Dialect
|
||||
singularTable bool
|
||||
|
||||
// function to be used to override the creating of a new timestamp
|
||||
nowFuncOverride func() time.Time
|
||||
}
|
||||
|
||||
type logModeValue int
|
||||
|
||||
const (
|
||||
defaultLogMode logModeValue = iota
|
||||
noLogMode
|
||||
detailedLogMode
|
||||
)
|
||||
|
||||
// Open initialize a new db connection, need to import driver first, e.g:
|
||||
//
|
||||
// import _ "github.com/go-sql-driver/mysql"
|
||||
// func main() {
|
||||
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
|
||||
// }
|
||||
// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with
|
||||
// import _ "github.com/jinzhu/gorm/dialects/mysql"
|
||||
// // import _ "github.com/jinzhu/gorm/dialects/postgres"
|
||||
// // import _ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||
// // import _ "github.com/jinzhu/gorm/dialects/mssql"
|
||||
func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
||||
if len(args) == 0 {
|
||||
err = errors.New("invalid database source")
|
||||
return nil, err
|
||||
}
|
||||
var source string
|
||||
var dbSQL SQLCommon
|
||||
var ownDbSQL bool
|
||||
|
||||
switch value := args[0].(type) {
|
||||
case string:
|
||||
var driver = dialect
|
||||
if len(args) == 1 {
|
||||
source = value
|
||||
} else if len(args) >= 2 {
|
||||
driver = value
|
||||
source = args[1].(string)
|
||||
}
|
||||
dbSQL, err = sql.Open(driver, source)
|
||||
ownDbSQL = true
|
||||
case SQLCommon:
|
||||
dbSQL = value
|
||||
ownDbSQL = false
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
|
||||
}
|
||||
|
||||
db = &DB{
|
||||
db: dbSQL,
|
||||
logger: defaultLogger,
|
||||
callbacks: DefaultCallback,
|
||||
dialect: newDialect(dialect, dbSQL),
|
||||
}
|
||||
db.parent = db
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Send a ping to make sure the database connection is alive.
|
||||
if d, ok := dbSQL.(*sql.DB); ok {
|
||||
if err = d.Ping(); err != nil && ownDbSQL {
|
||||
d.Close()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// New clone a new db connection without search conditions
|
||||
func (s *DB) New() *DB {
|
||||
clone := s.clone()
|
||||
clone.search = nil
|
||||
clone.Value = nil
|
||||
return clone
|
||||
}
|
||||
|
||||
type closer interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Close close current db connection. If database connection is not an io.Closer, returns an error.
|
||||
func (s *DB) Close() error {
|
||||
if db, ok := s.parent.db.(closer); ok {
|
||||
return db.Close()
|
||||
}
|
||||
return errors.New("can't close current db")
|
||||
}
|
||||
|
||||
// DB get `*sql.DB` from current connection
|
||||
// If the underlying database connection is not a *sql.DB, returns nil
|
||||
func (s *DB) DB() *sql.DB {
|
||||
db, ok := s.db.(*sql.DB)
|
||||
if !ok {
|
||||
panic("can't support full GORM on currently status, maybe this is a TX instance.")
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
|
||||
func (s *DB) CommonDB() SQLCommon {
|
||||
return s.db
|
||||
}
|
||||
|
||||
// Dialect get dialect
|
||||
func (s *DB) Dialect() Dialect {
|
||||
return s.dialect
|
||||
}
|
||||
|
||||
// Callback return `Callbacks` container, you could add/change/delete callbacks with it
|
||||
// db.Callback().Create().Register("update_created_at", updateCreated)
|
||||
// Refer https://jinzhu.github.io/gorm/development.html#callbacks
|
||||
func (s *DB) Callback() *Callback {
|
||||
s.parent.callbacks = s.parent.callbacks.clone(s.logger)
|
||||
return s.parent.callbacks
|
||||
}
|
||||
|
||||
// SetLogger replace default logger
|
||||
func (s *DB) SetLogger(log logger) {
|
||||
s.logger = log
|
||||
}
|
||||
|
||||
// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
|
||||
func (s *DB) LogMode(enable bool) *DB {
|
||||
if enable {
|
||||
s.logMode = detailedLogMode
|
||||
} else {
|
||||
s.logMode = noLogMode
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// SetNowFuncOverride set the function to be used when creating a new timestamp
|
||||
func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB {
|
||||
s.nowFuncOverride = nowFuncOverride
|
||||
return s
|
||||
}
|
||||
|
||||
// Get a new timestamp, using the provided nowFuncOverride on the DB instance if set,
|
||||
// otherwise defaults to the global NowFunc()
|
||||
func (s *DB) nowFunc() time.Time {
|
||||
if s.nowFuncOverride != nil {
|
||||
return s.nowFuncOverride()
|
||||
}
|
||||
|
||||
return NowFunc()
|
||||
}
|
||||
|
||||
// BlockGlobalUpdate if true, generates an error on update/delete without where clause.
|
||||
// This is to prevent eventual error with empty objects updates/deletions
|
||||
func (s *DB) BlockGlobalUpdate(enable bool) *DB {
|
||||
s.blockGlobalUpdate = enable
|
||||
return s
|
||||
}
|
||||
|
||||
// HasBlockGlobalUpdate return state of block
|
||||
func (s *DB) HasBlockGlobalUpdate() bool {
|
||||
return s.blockGlobalUpdate
|
||||
}
|
||||
|
||||
// SingularTable use singular table by default
|
||||
func (s *DB) SingularTable(enable bool) {
|
||||
s.parent.Lock()
|
||||
defer s.parent.Unlock()
|
||||
s.parent.singularTable = enable
|
||||
}
|
||||
|
||||
// NewScope create a scope for current operation
|
||||
func (s *DB) NewScope(value interface{}) *Scope {
|
||||
dbClone := s.clone()
|
||||
dbClone.Value = value
|
||||
scope := &Scope{db: dbClone, Value: value}
|
||||
if s.search != nil {
|
||||
scope.Search = s.search.clone()
|
||||
} else {
|
||||
scope.Search = &search{}
|
||||
}
|
||||
return scope
|
||||
}
|
||||
|
||||
// QueryExpr returns the query as SqlExpr object
|
||||
func (s *DB) QueryExpr() *SqlExpr {
|
||||
scope := s.NewScope(s.Value)
|
||||
scope.InstanceSet("skip_bindvar", true)
|
||||
scope.prepareQuerySQL()
|
||||
|
||||
return Expr(scope.SQL, scope.SQLVars...)
|
||||
}
|
||||
|
||||
// SubQuery returns the query as sub query
|
||||
func (s *DB) SubQuery() *SqlExpr {
|
||||
scope := s.NewScope(s.Value)
|
||||
scope.InstanceSet("skip_bindvar", true)
|
||||
scope.prepareQuerySQL()
|
||||
|
||||
return Expr(fmt.Sprintf("(%v)", scope.SQL), scope.SQLVars...)
|
||||
}
|
||||
|
||||
// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query
|
||||
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
|
||||
return s.clone().search.Where(query, args...).db
|
||||
}
|
||||
|
||||
// Or filter records that match before conditions or this one, similar to `Where`
|
||||
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
|
||||
return s.clone().search.Or(query, args...).db
|
||||
}
|
||||
|
||||
// Not filter records that don't match current conditions, similar to `Where`
|
||||
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
|
||||
return s.clone().search.Not(query, args...).db
|
||||
}
|
||||
|
||||
// Limit specify the number of records to be retrieved
|
||||
func (s *DB) Limit(limit interface{}) *DB {
|
||||
return s.clone().search.Limit(limit).db
|
||||
}
|
||||
|
||||
// Offset specify the number of records to skip before starting to return the records
|
||||
func (s *DB) Offset(offset interface{}) *DB {
|
||||
return s.clone().search.Offset(offset).db
|
||||
}
|
||||
|
||||
// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions
|
||||
// db.Order("name DESC")
|
||||
// db.Order("name DESC", true) // reorder
|
||||
// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
|
||||
func (s *DB) Order(value interface{}, reorder ...bool) *DB {
|
||||
return s.clone().search.Order(value, reorder...).db
|
||||
}
|
||||
|
||||
// Select specify fields that you want to retrieve from database when querying, by default, will select all fields;
|
||||
// When creating/updating, specify fields that you want to save to database
|
||||
func (s *DB) Select(query interface{}, args ...interface{}) *DB {
|
||||
return s.clone().search.Select(query, args...).db
|
||||
}
|
||||
|
||||
// Omit specify fields that you want to ignore when saving to database for creating, updating
|
||||
func (s *DB) Omit(columns ...string) *DB {
|
||||
return s.clone().search.Omit(columns...).db
|
||||
}
|
||||
|
||||
// Group specify the group method on the find
|
||||
func (s *DB) Group(query string) *DB {
|
||||
return s.clone().search.Group(query).db
|
||||
}
|
||||
|
||||
// Having specify HAVING conditions for GROUP BY
|
||||
func (s *DB) Having(query interface{}, values ...interface{}) *DB {
|
||||
return s.clone().search.Having(query, values...).db
|
||||
}
|
||||
|
||||
// Joins specify Joins conditions
|
||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||
func (s *DB) Joins(query string, args ...interface{}) *DB {
|
||||
return s.clone().search.Joins(query, args...).db
|
||||
}
|
||||
|
||||
// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically
|
||||
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
|
||||
// return db.Where("amount > ?", 1000)
|
||||
// }
|
||||
//
|
||||
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
|
||||
// return func (db *gorm.DB) *gorm.DB {
|
||||
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
|
||||
// Refer https://jinzhu.github.io/gorm/crud.html#scopes
|
||||
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
|
||||
for _, f := range funcs {
|
||||
s = f(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/crud.html#soft-delete
|
||||
func (s *DB) Unscoped() *DB {
|
||||
return s.clone().search.unscoped().db
|
||||
}
|
||||
|
||||
// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate
|
||||
func (s *DB) Attrs(attrs ...interface{}) *DB {
|
||||
return s.clone().search.Attrs(attrs...).db
|
||||
}
|
||||
|
||||
// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate
|
||||
func (s *DB) Assign(attrs ...interface{}) *DB {
|
||||
return s.clone().search.Assign(attrs...).db
|
||||
}
|
||||
|
||||
// First find first record that match given conditions, order by primary key
|
||||
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
||||
newScope := s.NewScope(out)
|
||||
newScope.Search.Limit(1)
|
||||
|
||||
return newScope.Set("gorm:order_by_primary_key", "ASC").
|
||||
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||
}
|
||||
|
||||
// Take return a record that match given conditions, the order will depend on the database implementation
|
||||
func (s *DB) Take(out interface{}, where ...interface{}) *DB {
|
||||
newScope := s.NewScope(out)
|
||||
newScope.Search.Limit(1)
|
||||
return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||
}
|
||||
|
||||
// Last find last record that match given conditions, order by primary key
|
||||
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
||||
newScope := s.NewScope(out)
|
||||
newScope.Search.Limit(1)
|
||||
return newScope.Set("gorm:order_by_primary_key", "DESC").
|
||||
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||
}
|
||||
|
||||
// Find find records that match given conditions
|
||||
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
||||
return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||
}
|
||||
|
||||
//Preloads preloads relations, don`t touch out
|
||||
func (s *DB) Preloads(out interface{}) *DB {
|
||||
return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db
|
||||
}
|
||||
|
||||
// Scan scan value to a struct
|
||||
func (s *DB) Scan(dest interface{}) *DB {
|
||||
return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
|
||||
}
|
||||
|
||||
// Row return `*sql.Row` with given conditions
|
||||
func (s *DB) Row() *sql.Row {
|
||||
return s.NewScope(s.Value).row()
|
||||
}
|
||||
|
||||
// Rows return `*sql.Rows` with given conditions
|
||||
func (s *DB) Rows() (*sql.Rows, error) {
|
||||
return s.NewScope(s.Value).rows()
|
||||
}
|
||||
|
||||
// ScanRows scan `*sql.Rows` to give struct
|
||||
func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
|
||||
var (
|
||||
scope = s.NewScope(result)
|
||||
clone = scope.db
|
||||
columns, err = rows.Columns()
|
||||
)
|
||||
|
||||
if clone.AddError(err) == nil {
|
||||
scope.scan(rows, columns, scope.Fields())
|
||||
}
|
||||
|
||||
return clone.Error
|
||||
}
|
||||
|
||||
// Pluck used to query single column from a model as a map
|
||||
// var ages []int64
|
||||
// db.Find(&users).Pluck("age", &ages)
|
||||
func (s *DB) Pluck(column string, value interface{}) *DB {
|
||||
return s.NewScope(s.Value).pluck(column, value).db
|
||||
}
|
||||
|
||||
// Count get how many records for a model
|
||||
func (s *DB) Count(value interface{}) *DB {
|
||||
return s.NewScope(s.Value).count(value).db
|
||||
}
|
||||
|
||||
// Related get related associations
|
||||
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
|
||||
return s.NewScope(s.Value).related(value, foreignKeys...).db
|
||||
}
|
||||
|
||||
// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)
|
||||
// https://jinzhu.github.io/gorm/crud.html#firstorinit
|
||||
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
||||
c := s.clone()
|
||||
if result := c.First(out, where...); result.Error != nil {
|
||||
if !result.RecordNotFound() {
|
||||
return result
|
||||
}
|
||||
c.NewScope(out).inlineCondition(where...).initialize()
|
||||
} else {
|
||||
c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions)
|
||||
// https://jinzhu.github.io/gorm/crud.html#firstorcreate
|
||||
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
||||
c := s.clone()
|
||||
if result := s.First(out, where...); result.Error != nil {
|
||||
if !result.RecordNotFound() {
|
||||
return result
|
||||
}
|
||||
return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db
|
||||
} else if len(c.search.assignAttrs) > 0 {
|
||||
return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
||||
// WARNING when update with struct, GORM will not update fields that with zero value
|
||||
func (s *DB) Update(attrs ...interface{}) *DB {
|
||||
return s.Updates(toSearchableMap(attrs...), true)
|
||||
}
|
||||
|
||||
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
||||
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
|
||||
return s.NewScope(s.Value).
|
||||
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
|
||||
InstanceSet("gorm:update_interface", values).
|
||||
callCallbacks(s.parent.callbacks.updates).db
|
||||
}
|
||||
|
||||
// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
||||
func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
|
||||
return s.UpdateColumns(toSearchableMap(attrs...))
|
||||
}
|
||||
|
||||
// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
||||
func (s *DB) UpdateColumns(values interface{}) *DB {
|
||||
return s.NewScope(s.Value).
|
||||
Set("gorm:update_column", true).
|
||||
Set("gorm:save_associations", false).
|
||||
InstanceSet("gorm:update_interface", values).
|
||||
callCallbacks(s.parent.callbacks.updates).db
|
||||
}
|
||||
|
||||
// Save update value in database, if the value doesn't have primary key, will insert it
|
||||
func (s *DB) Save(value interface{}) *DB {
|
||||
scope := s.NewScope(value)
|
||||
if !scope.PrimaryKeyZero() {
|
||||
newDB := scope.callCallbacks(s.parent.callbacks.updates).db
|
||||
if newDB.Error == nil && newDB.RowsAffected == 0 {
|
||||
return s.New().Table(scope.TableName()).FirstOrCreate(value)
|
||||
}
|
||||
return newDB
|
||||
}
|
||||
return scope.callCallbacks(s.parent.callbacks.creates).db
|
||||
}
|
||||
|
||||
// Create insert the value into database
|
||||
func (s *DB) Create(value interface{}) *DB {
|
||||
scope := s.NewScope(value)
|
||||
return scope.callCallbacks(s.parent.callbacks.creates).db
|
||||
}
|
||||
|
||||
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
|
||||
// WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time
|
||||
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
|
||||
return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
|
||||
}
|
||||
|
||||
// Raw use raw sql as conditions, won't run it unless invoked by other methods
|
||||
// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result)
|
||||
func (s *DB) Raw(sql string, values ...interface{}) *DB {
|
||||
return s.clone().search.Raw(true).Where(sql, values...).db
|
||||
}
|
||||
|
||||
// Exec execute raw sql
|
||||
func (s *DB) Exec(sql string, values ...interface{}) *DB {
|
||||
scope := s.NewScope(nil)
|
||||
generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true)
|
||||
generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
|
||||
scope.Raw(generatedSQL)
|
||||
return scope.Exec().db
|
||||
}
|
||||
|
||||
// Model specify the model you would like to run db operations
|
||||
// // update all users's name to `hello`
|
||||
// db.Model(&User{}).Update("name", "hello")
|
||||
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
|
||||
// db.Model(&user).Update("name", "hello")
|
||||
func (s *DB) Model(value interface{}) *DB {
|
||||
c := s.clone()
|
||||
c.Value = value
|
||||
return c
|
||||
}
|
||||
|
||||
// Table specify the table you would like to run db operations
|
||||
func (s *DB) Table(name string) *DB {
|
||||
clone := s.clone()
|
||||
clone.search.Table(name)
|
||||
clone.Value = nil
|
||||
return clone
|
||||
}
|
||||
|
||||
// Debug start debug mode
|
||||
func (s *DB) Debug() *DB {
|
||||
return s.clone().LogMode(true)
|
||||
}
|
||||
|
||||
// Transaction start a transaction as a block,
|
||||
// return error will rollback, otherwise to commit.
|
||||
func (s *DB) Transaction(fc func(tx *DB) error) (err error) {
|
||||
|
||||
if _, ok := s.db.(*sql.Tx); ok {
|
||||
return fc(s)
|
||||
}
|
||||
|
||||
panicked := true
|
||||
tx := s.Begin()
|
||||
defer func() {
|
||||
// Make sure to rollback when panic, Block error or Commit error
|
||||
if panicked || err != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
err = fc(tx)
|
||||
|
||||
if err == nil {
|
||||
err = tx.Commit().Error
|
||||
}
|
||||
|
||||
panicked = false
|
||||
return
|
||||
}
|
||||
|
||||
// Begin begins a transaction
|
||||
func (s *DB) Begin() *DB {
|
||||
return s.BeginTx(context.Background(), &sql.TxOptions{})
|
||||
}
|
||||
|
||||
// BeginTx begins a transaction with options
|
||||
func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
|
||||
c := s.clone()
|
||||
if db, ok := c.db.(sqlDb); ok && db != nil {
|
||||
tx, err := db.BeginTx(ctx, opts)
|
||||
c.db = interface{}(tx).(SQLCommon)
|
||||
|
||||
c.dialect.SetDB(c.db)
|
||||
c.AddError(err)
|
||||
} else {
|
||||
c.AddError(ErrCantStartTransaction)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// Commit commit a transaction
|
||||
func (s *DB) Commit() *DB {
|
||||
var emptySQLTx *sql.Tx
|
||||
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
||||
s.AddError(db.Commit())
|
||||
} else {
|
||||
s.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Rollback rollback a transaction
|
||||
func (s *DB) Rollback() *DB {
|
||||
var emptySQLTx *sql.Tx
|
||||
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
||||
if err := db.Rollback(); err != nil && err != sql.ErrTxDone {
|
||||
s.AddError(err)
|
||||
}
|
||||
} else {
|
||||
s.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// RollbackUnlessCommitted rollback a transaction if it has not yet been
|
||||
// committed.
|
||||
func (s *DB) RollbackUnlessCommitted() *DB {
|
||||
var emptySQLTx *sql.Tx
|
||||
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
||||
err := db.Rollback()
|
||||
// Ignore the error indicating that the transaction has already
|
||||
// been committed.
|
||||
if err != sql.ErrTxDone {
|
||||
s.AddError(err)
|
||||
}
|
||||
} else {
|
||||
s.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// NewRecord check if value's primary key is blank
|
||||
func (s *DB) NewRecord(value interface{}) bool {
|
||||
return s.NewScope(value).PrimaryKeyZero()
|
||||
}
|
||||
|
||||
// RecordNotFound check if returning ErrRecordNotFound error
|
||||
func (s *DB) RecordNotFound() bool {
|
||||
for _, err := range s.GetErrors() {
|
||||
if err == ErrRecordNotFound {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CreateTable create table for models
|
||||
func (s *DB) CreateTable(models ...interface{}) *DB {
|
||||
db := s.Unscoped()
|
||||
for _, model := range models {
|
||||
db = db.NewScope(model).createTable().db
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// DropTable drop table for models
|
||||
func (s *DB) DropTable(values ...interface{}) *DB {
|
||||
db := s.clone()
|
||||
for _, value := range values {
|
||||
if tableName, ok := value.(string); ok {
|
||||
db = db.Table(tableName)
|
||||
}
|
||||
|
||||
db = db.NewScope(value).dropTable().db
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// DropTableIfExists drop table if it is exist
|
||||
func (s *DB) DropTableIfExists(values ...interface{}) *DB {
|
||||
db := s.clone()
|
||||
for _, value := range values {
|
||||
if s.HasTable(value) {
|
||||
db.AddError(s.DropTable(value).Error)
|
||||
}
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// HasTable check has table or not
|
||||
func (s *DB) HasTable(value interface{}) bool {
|
||||
var (
|
||||
scope = s.NewScope(value)
|
||||
tableName string
|
||||
)
|
||||
|
||||
if name, ok := value.(string); ok {
|
||||
tableName = name
|
||||
} else {
|
||||
tableName = scope.TableName()
|
||||
}
|
||||
|
||||
has := scope.Dialect().HasTable(tableName)
|
||||
s.AddError(scope.db.Error)
|
||||
return has
|
||||
}
|
||||
|
||||
// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
|
||||
func (s *DB) AutoMigrate(values ...interface{}) *DB {
|
||||
db := s.Unscoped()
|
||||
for _, value := range values {
|
||||
db = db.NewScope(value).autoMigrate().db
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// ModifyColumn modify column to type
|
||||
func (s *DB) ModifyColumn(column string, typ string) *DB {
|
||||
scope := s.NewScope(s.Value)
|
||||
scope.modifyColumn(column, typ)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
// DropColumn drop a column
|
||||
func (s *DB) DropColumn(column string) *DB {
|
||||
scope := s.NewScope(s.Value)
|
||||
scope.dropColumn(column)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
// AddIndex add index for columns with given name
|
||||
func (s *DB) AddIndex(indexName string, columns ...string) *DB {
|
||||
scope := s.Unscoped().NewScope(s.Value)
|
||||
scope.addIndex(false, indexName, columns...)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
// AddUniqueIndex add unique index for columns with given name
|
||||
func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB {
|
||||
scope := s.Unscoped().NewScope(s.Value)
|
||||
scope.addIndex(true, indexName, columns...)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
// RemoveIndex remove index with name
|
||||
func (s *DB) RemoveIndex(indexName string) *DB {
|
||||
scope := s.NewScope(s.Value)
|
||||
scope.removeIndex(indexName)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
// AddForeignKey Add foreign key to the given scope, e.g:
|
||||
// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
|
||||
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
|
||||
scope := s.NewScope(s.Value)
|
||||
scope.addForeignKey(field, dest, onDelete, onUpdate)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
// RemoveForeignKey Remove foreign key from the given scope, e.g:
|
||||
// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)")
|
||||
func (s *DB) RemoveForeignKey(field string, dest string) *DB {
|
||||
scope := s.clone().NewScope(s.Value)
|
||||
scope.removeForeignKey(field, dest)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode
|
||||
func (s *DB) Association(column string) *Association {
|
||||
var err error
|
||||
var scope = s.Set("gorm:association:source", s.Value).NewScope(s.Value)
|
||||
|
||||
if primaryField := scope.PrimaryField(); primaryField.IsBlank {
|
||||
err = errors.New("primary key can't be nil")
|
||||
} else {
|
||||
if field, ok := scope.FieldByName(column); ok {
|
||||
if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
|
||||
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
|
||||
} else {
|
||||
return &Association{scope: scope, column: column, field: field}
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
|
||||
}
|
||||
}
|
||||
|
||||
return &Association{Error: err}
|
||||
}
|
||||
|
||||
// Preload preload associations with given conditions
|
||||
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
|
||||
func (s *DB) Preload(column string, conditions ...interface{}) *DB {
|
||||
return s.clone().search.Preload(column, conditions...).db
|
||||
}
|
||||
|
||||
// Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting
|
||||
func (s *DB) Set(name string, value interface{}) *DB {
|
||||
return s.clone().InstantSet(name, value)
|
||||
}
|
||||
|
||||
// InstantSet instant set setting, will affect current db
|
||||
func (s *DB) InstantSet(name string, value interface{}) *DB {
|
||||
s.values.Store(name, value)
|
||||
return s
|
||||
}
|
||||
|
||||
// Get get setting by name
|
||||
func (s *DB) Get(name string) (value interface{}, ok bool) {
|
||||
value, ok = s.values.Load(name)
|
||||
return
|
||||
}
|
||||
|
||||
// SetJoinTableHandler set a model's join table handler for a relation
|
||||
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
|
||||
scope := s.NewScope(source)
|
||||
for _, field := range scope.GetModelStruct().StructFields {
|
||||
if field.Name == column || field.DBName == column {
|
||||
if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" {
|
||||
source := (&Scope{Value: source}).GetModelStruct().ModelType
|
||||
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
||||
handler.Setup(field.Relationship, many2many, source, destination)
|
||||
field.Relationship.JoinTableHandler = handler
|
||||
if table := handler.Table(s); scope.Dialect().HasTable(table) {
|
||||
s.Table(table).AutoMigrate(handler)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddError add error to the db
|
||||
func (s *DB) AddError(err error) error {
|
||||
if err != nil {
|
||||
if err != ErrRecordNotFound {
|
||||
if s.logMode == defaultLogMode {
|
||||
go s.print("error", fileWithLineNum(), err)
|
||||
} else {
|
||||
s.log(err)
|
||||
}
|
||||
|
||||
errors := Errors(s.GetErrors())
|
||||
errors = errors.Add(err)
|
||||
if len(errors) > 1 {
|
||||
err = errors
|
||||
}
|
||||
}
|
||||
|
||||
s.Error = err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// GetErrors get happened errors from the db
|
||||
func (s *DB) GetErrors() []error {
|
||||
if errs, ok := s.Error.(Errors); ok {
|
||||
return errs
|
||||
} else if s.Error != nil {
|
||||
return []error{s.Error}
|
||||
}
|
||||
return []error{}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Private Methods For DB
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
func (s *DB) clone() *DB {
|
||||
db := &DB{
|
||||
db: s.db,
|
||||
parent: s.parent,
|
||||
logger: s.logger,
|
||||
logMode: s.logMode,
|
||||
Value: s.Value,
|
||||
Error: s.Error,
|
||||
blockGlobalUpdate: s.blockGlobalUpdate,
|
||||
dialect: newDialect(s.dialect.GetName(), s.db),
|
||||
nowFuncOverride: s.nowFuncOverride,
|
||||
}
|
||||
|
||||
s.values.Range(func(k, v interface{}) bool {
|
||||
db.values.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if s.search == nil {
|
||||
db.search = &search{limit: -1, offset: -1}
|
||||
} else {
|
||||
db.search = s.search.clone()
|
||||
}
|
||||
|
||||
db.search.db = db
|
||||
return db
|
||||
}
|
||||
|
||||
func (s *DB) print(v ...interface{}) {
|
||||
s.logger.Print(v...)
|
||||
}
|
||||
|
||||
func (s *DB) log(v ...interface{}) {
|
||||
if s != nil && s.logMode == detailedLogMode {
|
||||
s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
|
||||
if s.logMode == detailedLogMode {
|
||||
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
package gorm
|
||||
|
||||
import "time"
|
||||
|
||||
// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models
|
||||
// type User struct {
|
||||
// gorm.Model
|
||||
// }
|
||||
type Model struct {
|
||||
ID uint `gorm:"primary_key"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt *time.Time `sql:"index"`
|
||||
}
|
||||
|
|
@ -0,0 +1,677 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/inflection"
|
||||
)
|
||||
|
||||
// DefaultTableNameHandler default table name handler
|
||||
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
|
||||
return defaultTableName
|
||||
}
|
||||
|
||||
// lock for mutating global cached model metadata
|
||||
var structsLock sync.Mutex
|
||||
|
||||
// global cache of model metadata
|
||||
var modelStructsMap sync.Map
|
||||
|
||||
// ModelStruct model definition
|
||||
type ModelStruct struct {
|
||||
PrimaryFields []*StructField
|
||||
StructFields []*StructField
|
||||
ModelType reflect.Type
|
||||
|
||||
defaultTableName string
|
||||
l sync.Mutex
|
||||
}
|
||||
|
||||
// TableName returns model's table name
|
||||
func (s *ModelStruct) TableName(db *DB) string {
|
||||
s.l.Lock()
|
||||
defer s.l.Unlock()
|
||||
|
||||
if s.defaultTableName == "" && db != nil && s.ModelType != nil {
|
||||
// Set default table name
|
||||
if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
|
||||
s.defaultTableName = tabler.TableName()
|
||||
} else {
|
||||
tableName := ToTableName(s.ModelType.Name())
|
||||
db.parent.RLock()
|
||||
if db == nil || (db.parent != nil && !db.parent.singularTable) {
|
||||
tableName = inflection.Plural(tableName)
|
||||
}
|
||||
db.parent.RUnlock()
|
||||
s.defaultTableName = tableName
|
||||
}
|
||||
}
|
||||
|
||||
return DefaultTableNameHandler(db, s.defaultTableName)
|
||||
}
|
||||
|
||||
// StructField model field's struct definition
|
||||
type StructField struct {
|
||||
DBName string
|
||||
Name string
|
||||
Names []string
|
||||
IsPrimaryKey bool
|
||||
IsNormal bool
|
||||
IsIgnored bool
|
||||
IsScanner bool
|
||||
HasDefaultValue bool
|
||||
Tag reflect.StructTag
|
||||
TagSettings map[string]string
|
||||
Struct reflect.StructField
|
||||
IsForeignKey bool
|
||||
Relationship *Relationship
|
||||
|
||||
tagSettingsLock sync.RWMutex
|
||||
}
|
||||
|
||||
// TagSettingsSet Sets a tag in the tag settings map
|
||||
func (sf *StructField) TagSettingsSet(key, val string) {
|
||||
sf.tagSettingsLock.Lock()
|
||||
defer sf.tagSettingsLock.Unlock()
|
||||
sf.TagSettings[key] = val
|
||||
}
|
||||
|
||||
// TagSettingsGet returns a tag from the tag settings
|
||||
func (sf *StructField) TagSettingsGet(key string) (string, bool) {
|
||||
sf.tagSettingsLock.RLock()
|
||||
defer sf.tagSettingsLock.RUnlock()
|
||||
val, ok := sf.TagSettings[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// TagSettingsDelete deletes a tag
|
||||
func (sf *StructField) TagSettingsDelete(key string) {
|
||||
sf.tagSettingsLock.Lock()
|
||||
defer sf.tagSettingsLock.Unlock()
|
||||
delete(sf.TagSettings, key)
|
||||
}
|
||||
|
||||
func (sf *StructField) clone() *StructField {
|
||||
clone := &StructField{
|
||||
DBName: sf.DBName,
|
||||
Name: sf.Name,
|
||||
Names: sf.Names,
|
||||
IsPrimaryKey: sf.IsPrimaryKey,
|
||||
IsNormal: sf.IsNormal,
|
||||
IsIgnored: sf.IsIgnored,
|
||||
IsScanner: sf.IsScanner,
|
||||
HasDefaultValue: sf.HasDefaultValue,
|
||||
Tag: sf.Tag,
|
||||
TagSettings: map[string]string{},
|
||||
Struct: sf.Struct,
|
||||
IsForeignKey: sf.IsForeignKey,
|
||||
}
|
||||
|
||||
if sf.Relationship != nil {
|
||||
relationship := *sf.Relationship
|
||||
clone.Relationship = &relationship
|
||||
}
|
||||
|
||||
// copy the struct field tagSettings, they should be read-locked while they are copied
|
||||
sf.tagSettingsLock.Lock()
|
||||
defer sf.tagSettingsLock.Unlock()
|
||||
for key, value := range sf.TagSettings {
|
||||
clone.TagSettings[key] = value
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
||||
|
||||
// Relationship described the relationship between models
|
||||
type Relationship struct {
|
||||
Kind string
|
||||
PolymorphicType string
|
||||
PolymorphicDBName string
|
||||
PolymorphicValue string
|
||||
ForeignFieldNames []string
|
||||
ForeignDBNames []string
|
||||
AssociationForeignFieldNames []string
|
||||
AssociationForeignDBNames []string
|
||||
JoinTableHandler JoinTableHandlerInterface
|
||||
}
|
||||
|
||||
func getForeignField(column string, fields []*StructField) *StructField {
|
||||
for _, field := range fields {
|
||||
if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) {
|
||||
return field
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModelStruct get value's model struct, relationships based on struct and tag definition
|
||||
func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
return scope.getModelStruct(scope, make([]*StructField, 0))
|
||||
}
|
||||
|
||||
func (scope *Scope) getModelStruct(rootScope *Scope, allFields []*StructField) *ModelStruct {
|
||||
var modelStruct ModelStruct
|
||||
// Scope value can't be nil
|
||||
if scope.Value == nil {
|
||||
return &modelStruct
|
||||
}
|
||||
|
||||
reflectType := reflect.ValueOf(scope.Value).Type()
|
||||
for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr {
|
||||
reflectType = reflectType.Elem()
|
||||
}
|
||||
|
||||
// Scope value need to be a struct
|
||||
if reflectType.Kind() != reflect.Struct {
|
||||
return &modelStruct
|
||||
}
|
||||
|
||||
// Get Cached model struct
|
||||
isSingularTable := false
|
||||
if scope.db != nil && scope.db.parent != nil {
|
||||
scope.db.parent.RLock()
|
||||
isSingularTable = scope.db.parent.singularTable
|
||||
scope.db.parent.RUnlock()
|
||||
}
|
||||
|
||||
hashKey := struct {
|
||||
singularTable bool
|
||||
reflectType reflect.Type
|
||||
}{isSingularTable, reflectType}
|
||||
if value, ok := modelStructsMap.Load(hashKey); ok && value != nil {
|
||||
return value.(*ModelStruct)
|
||||
}
|
||||
|
||||
modelStruct.ModelType = reflectType
|
||||
|
||||
// Get all fields
|
||||
for i := 0; i < reflectType.NumField(); i++ {
|
||||
if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) {
|
||||
field := &StructField{
|
||||
Struct: fieldStruct,
|
||||
Name: fieldStruct.Name,
|
||||
Names: []string{fieldStruct.Name},
|
||||
Tag: fieldStruct.Tag,
|
||||
TagSettings: parseTagSetting(fieldStruct.Tag),
|
||||
}
|
||||
|
||||
// is ignored field
|
||||
if _, ok := field.TagSettingsGet("-"); ok {
|
||||
field.IsIgnored = true
|
||||
} else {
|
||||
if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok {
|
||||
field.IsPrimaryKey = true
|
||||
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
||||
}
|
||||
|
||||
if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey {
|
||||
field.HasDefaultValue = true
|
||||
}
|
||||
|
||||
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey {
|
||||
field.HasDefaultValue = true
|
||||
}
|
||||
|
||||
indirectType := fieldStruct.Type
|
||||
for indirectType.Kind() == reflect.Ptr {
|
||||
indirectType = indirectType.Elem()
|
||||
}
|
||||
|
||||
fieldValue := reflect.New(indirectType).Interface()
|
||||
if _, isScanner := fieldValue.(sql.Scanner); isScanner {
|
||||
// is scanner
|
||||
field.IsScanner, field.IsNormal = true, true
|
||||
if indirectType.Kind() == reflect.Struct {
|
||||
for i := 0; i < indirectType.NumField(); i++ {
|
||||
for key, value := range parseTagSetting(indirectType.Field(i).Tag) {
|
||||
if _, ok := field.TagSettingsGet(key); !ok {
|
||||
field.TagSettingsSet(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if _, isTime := fieldValue.(*time.Time); isTime {
|
||||
// is time
|
||||
field.IsNormal = true
|
||||
} else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous {
|
||||
// is embedded struct
|
||||
for _, subField := range scope.New(fieldValue).getModelStruct(rootScope, allFields).StructFields {
|
||||
subField = subField.clone()
|
||||
subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
|
||||
if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok {
|
||||
subField.DBName = prefix + subField.DBName
|
||||
}
|
||||
|
||||
if subField.IsPrimaryKey {
|
||||
if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok {
|
||||
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField)
|
||||
} else {
|
||||
subField.IsPrimaryKey = false
|
||||
}
|
||||
}
|
||||
|
||||
if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil {
|
||||
if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok {
|
||||
newJoinTableHandler := &JoinTableHandler{}
|
||||
newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType)
|
||||
subField.Relationship.JoinTableHandler = newJoinTableHandler
|
||||
}
|
||||
}
|
||||
|
||||
modelStruct.StructFields = append(modelStruct.StructFields, subField)
|
||||
allFields = append(allFields, subField)
|
||||
}
|
||||
continue
|
||||
} else {
|
||||
// build relationships
|
||||
switch indirectType.Kind() {
|
||||
case reflect.Slice:
|
||||
defer func(field *StructField) {
|
||||
var (
|
||||
relationship = &Relationship{}
|
||||
toScope = scope.New(reflect.New(field.Struct.Type).Interface())
|
||||
foreignKeys []string
|
||||
associationForeignKeys []string
|
||||
elemType = field.Struct.Type
|
||||
)
|
||||
|
||||
if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" {
|
||||
foreignKeys = strings.Split(foreignKey, ",")
|
||||
}
|
||||
|
||||
if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" {
|
||||
associationForeignKeys = strings.Split(foreignKey, ",")
|
||||
} else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" {
|
||||
associationForeignKeys = strings.Split(foreignKey, ",")
|
||||
}
|
||||
|
||||
for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" {
|
||||
relationship.Kind = "many_to_many"
|
||||
|
||||
{ // Foreign Keys for Source
|
||||
joinTableDBNames := []string{}
|
||||
|
||||
if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" {
|
||||
joinTableDBNames = strings.Split(foreignKey, ",")
|
||||
}
|
||||
|
||||
// if no foreign keys defined with tag
|
||||
if len(foreignKeys) == 0 {
|
||||
for _, field := range modelStruct.PrimaryFields {
|
||||
foreignKeys = append(foreignKeys, field.DBName)
|
||||
}
|
||||
}
|
||||
|
||||
for idx, foreignKey := range foreignKeys {
|
||||
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
|
||||
// source foreign keys (db names)
|
||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName)
|
||||
|
||||
// setup join table foreign keys for source
|
||||
if len(joinTableDBNames) > idx {
|
||||
// if defined join table's foreign key
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
|
||||
} else {
|
||||
defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{ // Foreign Keys for Association (Destination)
|
||||
associationJoinTableDBNames := []string{}
|
||||
|
||||
if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" {
|
||||
associationJoinTableDBNames = strings.Split(foreignKey, ",")
|
||||
}
|
||||
|
||||
// if no association foreign keys defined with tag
|
||||
if len(associationForeignKeys) == 0 {
|
||||
for _, field := range toScope.PrimaryFields() {
|
||||
associationForeignKeys = append(associationForeignKeys, field.DBName)
|
||||
}
|
||||
}
|
||||
|
||||
for idx, name := range associationForeignKeys {
|
||||
if field, ok := toScope.FieldByName(name); ok {
|
||||
// association foreign keys (db names)
|
||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
|
||||
|
||||
// setup join table foreign keys for association
|
||||
if len(associationJoinTableDBNames) > idx {
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
|
||||
} else {
|
||||
// join table foreign keys for association
|
||||
joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
joinTableHandler := JoinTableHandler{}
|
||||
joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
|
||||
relationship.JoinTableHandler = &joinTableHandler
|
||||
field.Relationship = relationship
|
||||
} else {
|
||||
// User has many comments, associationType is User, comment use UserID as foreign key
|
||||
var associationType = reflectType.Name()
|
||||
var toFields = toScope.GetStructFields()
|
||||
relationship.Kind = "has_many"
|
||||
|
||||
if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" {
|
||||
// Dog has many toys, tag polymorphic is Owner, then associationType is Owner
|
||||
// Toy use OwnerID, OwnerType ('dogs') as foreign key
|
||||
if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
|
||||
associationType = polymorphic
|
||||
relationship.PolymorphicType = polymorphicType.Name
|
||||
relationship.PolymorphicDBName = polymorphicType.DBName
|
||||
// if Dog has multiple set of toys set name of the set (instead of default 'dogs')
|
||||
if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok {
|
||||
relationship.PolymorphicValue = value
|
||||
} else {
|
||||
relationship.PolymorphicValue = scope.TableName()
|
||||
}
|
||||
polymorphicType.IsForeignKey = true
|
||||
}
|
||||
}
|
||||
|
||||
// if no foreign keys defined with tag
|
||||
if len(foreignKeys) == 0 {
|
||||
// if no association foreign keys defined with tag
|
||||
if len(associationForeignKeys) == 0 {
|
||||
for _, field := range modelStruct.PrimaryFields {
|
||||
foreignKeys = append(foreignKeys, associationType+field.Name)
|
||||
associationForeignKeys = append(associationForeignKeys, field.Name)
|
||||
}
|
||||
} else {
|
||||
// generate foreign keys from defined association foreign keys
|
||||
for _, scopeFieldName := range associationForeignKeys {
|
||||
if foreignField := getForeignField(scopeFieldName, allFields); foreignField != nil {
|
||||
foreignKeys = append(foreignKeys, associationType+foreignField.Name)
|
||||
associationForeignKeys = append(associationForeignKeys, foreignField.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// generate association foreign keys from foreign keys
|
||||
if len(associationForeignKeys) == 0 {
|
||||
for _, foreignKey := range foreignKeys {
|
||||
if strings.HasPrefix(foreignKey, associationType) {
|
||||
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
|
||||
if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil {
|
||||
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
||||
associationForeignKeys = []string{rootScope.PrimaryKey()}
|
||||
}
|
||||
} else if len(foreignKeys) != len(associationForeignKeys) {
|
||||
scope.Err(errors.New("invalid foreign keys, should have same length"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for idx, foreignKey := range foreignKeys {
|
||||
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
|
||||
if associationField := getForeignField(associationForeignKeys[idx], allFields); associationField != nil {
|
||||
// mark field as foreignkey, use global lock to avoid race
|
||||
structsLock.Lock()
|
||||
foreignField.IsForeignKey = true
|
||||
structsLock.Unlock()
|
||||
|
||||
// association foreign keys
|
||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
|
||||
|
||||
// association foreign keys
|
||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
field.Relationship = relationship
|
||||
}
|
||||
}
|
||||
} else {
|
||||
field.IsNormal = true
|
||||
}
|
||||
}(field)
|
||||
case reflect.Struct:
|
||||
defer func(field *StructField) {
|
||||
var (
|
||||
// user has one profile, associationType is User, profile use UserID as foreign key
|
||||
// user belongs to profile, associationType is Profile, user use ProfileID as foreign key
|
||||
associationType = reflectType.Name()
|
||||
relationship = &Relationship{}
|
||||
toScope = scope.New(reflect.New(field.Struct.Type).Interface())
|
||||
toFields = toScope.GetStructFields()
|
||||
tagForeignKeys []string
|
||||
tagAssociationForeignKeys []string
|
||||
)
|
||||
|
||||
if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" {
|
||||
tagForeignKeys = strings.Split(foreignKey, ",")
|
||||
}
|
||||
|
||||
if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" {
|
||||
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
|
||||
} else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" {
|
||||
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
|
||||
}
|
||||
|
||||
if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" {
|
||||
// Cat has one toy, tag polymorphic is Owner, then associationType is Owner
|
||||
// Toy use OwnerID, OwnerType ('cats') as foreign key
|
||||
if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
|
||||
associationType = polymorphic
|
||||
relationship.PolymorphicType = polymorphicType.Name
|
||||
relationship.PolymorphicDBName = polymorphicType.DBName
|
||||
// if Cat has several different types of toys set name for each (instead of default 'cats')
|
||||
if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok {
|
||||
relationship.PolymorphicValue = value
|
||||
} else {
|
||||
relationship.PolymorphicValue = scope.TableName()
|
||||
}
|
||||
polymorphicType.IsForeignKey = true
|
||||
}
|
||||
}
|
||||
|
||||
// Has One
|
||||
{
|
||||
var foreignKeys = tagForeignKeys
|
||||
var associationForeignKeys = tagAssociationForeignKeys
|
||||
// if no foreign keys defined with tag
|
||||
if len(foreignKeys) == 0 {
|
||||
// if no association foreign keys defined with tag
|
||||
if len(associationForeignKeys) == 0 {
|
||||
for _, primaryField := range modelStruct.PrimaryFields {
|
||||
foreignKeys = append(foreignKeys, associationType+primaryField.Name)
|
||||
associationForeignKeys = append(associationForeignKeys, primaryField.Name)
|
||||
}
|
||||
} else {
|
||||
// generate foreign keys form association foreign keys
|
||||
for _, associationForeignKey := range tagAssociationForeignKeys {
|
||||
if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil {
|
||||
foreignKeys = append(foreignKeys, associationType+foreignField.Name)
|
||||
associationForeignKeys = append(associationForeignKeys, foreignField.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// generate association foreign keys from foreign keys
|
||||
if len(associationForeignKeys) == 0 {
|
||||
for _, foreignKey := range foreignKeys {
|
||||
if strings.HasPrefix(foreignKey, associationType) {
|
||||
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
|
||||
if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil {
|
||||
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
||||
associationForeignKeys = []string{rootScope.PrimaryKey()}
|
||||
}
|
||||
} else if len(foreignKeys) != len(associationForeignKeys) {
|
||||
scope.Err(errors.New("invalid foreign keys, should have same length"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for idx, foreignKey := range foreignKeys {
|
||||
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
|
||||
if scopeField := getForeignField(associationForeignKeys[idx], allFields); scopeField != nil {
|
||||
// mark field as foreignkey, use global lock to avoid race
|
||||
structsLock.Lock()
|
||||
foreignField.IsForeignKey = true
|
||||
structsLock.Unlock()
|
||||
|
||||
// association foreign keys
|
||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name)
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName)
|
||||
|
||||
// association foreign keys
|
||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
relationship.Kind = "has_one"
|
||||
field.Relationship = relationship
|
||||
} else {
|
||||
var foreignKeys = tagForeignKeys
|
||||
var associationForeignKeys = tagAssociationForeignKeys
|
||||
|
||||
if len(foreignKeys) == 0 {
|
||||
// generate foreign keys & association foreign keys
|
||||
if len(associationForeignKeys) == 0 {
|
||||
for _, primaryField := range toScope.PrimaryFields() {
|
||||
foreignKeys = append(foreignKeys, field.Name+primaryField.Name)
|
||||
associationForeignKeys = append(associationForeignKeys, primaryField.Name)
|
||||
}
|
||||
} else {
|
||||
// generate foreign keys with association foreign keys
|
||||
for _, associationForeignKey := range associationForeignKeys {
|
||||
if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil {
|
||||
foreignKeys = append(foreignKeys, field.Name+foreignField.Name)
|
||||
associationForeignKeys = append(associationForeignKeys, foreignField.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// generate foreign keys & association foreign keys
|
||||
if len(associationForeignKeys) == 0 {
|
||||
for _, foreignKey := range foreignKeys {
|
||||
if strings.HasPrefix(foreignKey, field.Name) {
|
||||
associationForeignKey := strings.TrimPrefix(foreignKey, field.Name)
|
||||
if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil {
|
||||
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
||||
associationForeignKeys = []string{toScope.PrimaryKey()}
|
||||
}
|
||||
} else if len(foreignKeys) != len(associationForeignKeys) {
|
||||
scope.Err(errors.New("invalid foreign keys, should have same length"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for idx, foreignKey := range foreignKeys {
|
||||
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
|
||||
if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil {
|
||||
// mark field as foreignkey, use global lock to avoid race
|
||||
structsLock.Lock()
|
||||
foreignField.IsForeignKey = true
|
||||
structsLock.Unlock()
|
||||
|
||||
// association foreign keys
|
||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
|
||||
|
||||
// source foreign keys
|
||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
relationship.Kind = "belongs_to"
|
||||
field.Relationship = relationship
|
||||
}
|
||||
}
|
||||
}(field)
|
||||
default:
|
||||
field.IsNormal = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Even it is ignored, also possible to decode db value into the field
|
||||
if value, ok := field.TagSettingsGet("COLUMN"); ok {
|
||||
field.DBName = value
|
||||
} else {
|
||||
field.DBName = ToColumnName(fieldStruct.Name)
|
||||
}
|
||||
|
||||
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
||||
allFields = append(allFields, field)
|
||||
}
|
||||
}
|
||||
|
||||
if len(modelStruct.PrimaryFields) == 0 {
|
||||
if field := getForeignField("id", modelStruct.StructFields); field != nil {
|
||||
field.IsPrimaryKey = true
|
||||
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
||||
}
|
||||
}
|
||||
|
||||
modelStructsMap.Store(hashKey, &modelStruct)
|
||||
|
||||
return &modelStruct
|
||||
}
|
||||
|
||||
// GetStructFields get model's field structs
|
||||
func (scope *Scope) GetStructFields() (fields []*StructField) {
|
||||
return scope.GetModelStruct().StructFields
|
||||
}
|
||||
|
||||
func parseTagSetting(tags reflect.StructTag) map[string]string {
|
||||
setting := map[string]string{}
|
||||
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
|
||||
if str == "" {
|
||||
continue
|
||||
}
|
||||
tags := strings.Split(str, ";")
|
||||
for _, value := range tags {
|
||||
v := strings.Split(value, ":")
|
||||
k := strings.TrimSpace(strings.ToUpper(v[0]))
|
||||
if len(v) >= 2 {
|
||||
setting[k] = strings.Join(v[1:], ":")
|
||||
} else {
|
||||
setting[k] = k
|
||||
}
|
||||
}
|
||||
}
|
||||
return setting
|
||||
}
|
||||
|
|
@ -0,0 +1,124 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Namer is a function type which is given a string and return a string
|
||||
type Namer func(string) string
|
||||
|
||||
// NamingStrategy represents naming strategies
|
||||
type NamingStrategy struct {
|
||||
DB Namer
|
||||
Table Namer
|
||||
Column Namer
|
||||
}
|
||||
|
||||
// TheNamingStrategy is being initialized with defaultNamingStrategy
|
||||
var TheNamingStrategy = &NamingStrategy{
|
||||
DB: defaultNamer,
|
||||
Table: defaultNamer,
|
||||
Column: defaultNamer,
|
||||
}
|
||||
|
||||
// AddNamingStrategy sets the naming strategy
|
||||
func AddNamingStrategy(ns *NamingStrategy) {
|
||||
if ns.DB == nil {
|
||||
ns.DB = defaultNamer
|
||||
}
|
||||
if ns.Table == nil {
|
||||
ns.Table = defaultNamer
|
||||
}
|
||||
if ns.Column == nil {
|
||||
ns.Column = defaultNamer
|
||||
}
|
||||
TheNamingStrategy = ns
|
||||
}
|
||||
|
||||
// DBName alters the given name by DB
|
||||
func (ns *NamingStrategy) DBName(name string) string {
|
||||
return ns.DB(name)
|
||||
}
|
||||
|
||||
// TableName alters the given name by Table
|
||||
func (ns *NamingStrategy) TableName(name string) string {
|
||||
return ns.Table(name)
|
||||
}
|
||||
|
||||
// ColumnName alters the given name by Column
|
||||
func (ns *NamingStrategy) ColumnName(name string) string {
|
||||
return ns.Column(name)
|
||||
}
|
||||
|
||||
// ToDBName convert string to db name
|
||||
func ToDBName(name string) string {
|
||||
return TheNamingStrategy.DBName(name)
|
||||
}
|
||||
|
||||
// ToTableName convert string to table name
|
||||
func ToTableName(name string) string {
|
||||
return TheNamingStrategy.TableName(name)
|
||||
}
|
||||
|
||||
// ToColumnName convert string to db name
|
||||
func ToColumnName(name string) string {
|
||||
return TheNamingStrategy.ColumnName(name)
|
||||
}
|
||||
|
||||
var smap = newSafeMap()
|
||||
|
||||
func defaultNamer(name string) string {
|
||||
const (
|
||||
lower = false
|
||||
upper = true
|
||||
)
|
||||
|
||||
if v := smap.Get(name); v != "" {
|
||||
return v
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var (
|
||||
value = commonInitialismsReplacer.Replace(name)
|
||||
buf = bytes.NewBufferString("")
|
||||
lastCase, currCase, nextCase, nextNumber bool
|
||||
)
|
||||
|
||||
for i, v := range value[:len(value)-1] {
|
||||
nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z')
|
||||
nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9')
|
||||
|
||||
if i > 0 {
|
||||
if currCase == upper {
|
||||
if lastCase == upper && (nextCase == upper || nextNumber == upper) {
|
||||
buf.WriteRune(v)
|
||||
} else {
|
||||
if value[i-1] != '_' && value[i+1] != '_' {
|
||||
buf.WriteRune('_')
|
||||
}
|
||||
buf.WriteRune(v)
|
||||
}
|
||||
} else {
|
||||
buf.WriteRune(v)
|
||||
if i == len(value)-2 && (nextCase == upper && nextNumber == lower) {
|
||||
buf.WriteRune('_')
|
||||
}
|
||||
}
|
||||
} else {
|
||||
currCase = upper
|
||||
buf.WriteRune(v)
|
||||
}
|
||||
lastCase = currCase
|
||||
currCase = nextCase
|
||||
}
|
||||
|
||||
buf.WriteByte(value[len(value)-1])
|
||||
|
||||
s := strings.ToLower(buf.String())
|
||||
smap.Set(name, s)
|
||||
return s
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,203 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type search struct {
|
||||
db *DB
|
||||
whereConditions []map[string]interface{}
|
||||
orConditions []map[string]interface{}
|
||||
notConditions []map[string]interface{}
|
||||
havingConditions []map[string]interface{}
|
||||
joinConditions []map[string]interface{}
|
||||
initAttrs []interface{}
|
||||
assignAttrs []interface{}
|
||||
selects map[string]interface{}
|
||||
omits []string
|
||||
orders []interface{}
|
||||
preload []searchPreload
|
||||
offset interface{}
|
||||
limit interface{}
|
||||
group string
|
||||
tableName string
|
||||
raw bool
|
||||
Unscoped bool
|
||||
ignoreOrderQuery bool
|
||||
}
|
||||
|
||||
type searchPreload struct {
|
||||
schema string
|
||||
conditions []interface{}
|
||||
}
|
||||
|
||||
func (s *search) clone() *search {
|
||||
clone := search{
|
||||
db: s.db,
|
||||
whereConditions: make([]map[string]interface{}, len(s.whereConditions)),
|
||||
orConditions: make([]map[string]interface{}, len(s.orConditions)),
|
||||
notConditions: make([]map[string]interface{}, len(s.notConditions)),
|
||||
havingConditions: make([]map[string]interface{}, len(s.havingConditions)),
|
||||
joinConditions: make([]map[string]interface{}, len(s.joinConditions)),
|
||||
initAttrs: make([]interface{}, len(s.initAttrs)),
|
||||
assignAttrs: make([]interface{}, len(s.assignAttrs)),
|
||||
selects: s.selects,
|
||||
omits: make([]string, len(s.omits)),
|
||||
orders: make([]interface{}, len(s.orders)),
|
||||
preload: make([]searchPreload, len(s.preload)),
|
||||
offset: s.offset,
|
||||
limit: s.limit,
|
||||
group: s.group,
|
||||
tableName: s.tableName,
|
||||
raw: s.raw,
|
||||
Unscoped: s.Unscoped,
|
||||
ignoreOrderQuery: s.ignoreOrderQuery,
|
||||
}
|
||||
for i, value := range s.whereConditions {
|
||||
clone.whereConditions[i] = value
|
||||
}
|
||||
for i, value := range s.orConditions {
|
||||
clone.orConditions[i] = value
|
||||
}
|
||||
for i, value := range s.notConditions {
|
||||
clone.notConditions[i] = value
|
||||
}
|
||||
for i, value := range s.havingConditions {
|
||||
clone.havingConditions[i] = value
|
||||
}
|
||||
for i, value := range s.joinConditions {
|
||||
clone.joinConditions[i] = value
|
||||
}
|
||||
for i, value := range s.initAttrs {
|
||||
clone.initAttrs[i] = value
|
||||
}
|
||||
for i, value := range s.assignAttrs {
|
||||
clone.assignAttrs[i] = value
|
||||
}
|
||||
for i, value := range s.omits {
|
||||
clone.omits[i] = value
|
||||
}
|
||||
for i, value := range s.orders {
|
||||
clone.orders[i] = value
|
||||
}
|
||||
for i, value := range s.preload {
|
||||
clone.preload[i] = value
|
||||
}
|
||||
return &clone
|
||||
}
|
||||
|
||||
func (s *search) Where(query interface{}, values ...interface{}) *search {
|
||||
s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values})
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Not(query interface{}, values ...interface{}) *search {
|
||||
s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values})
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Or(query interface{}, values ...interface{}) *search {
|
||||
s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values})
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Attrs(attrs ...interface{}) *search {
|
||||
s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...))
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Assign(attrs ...interface{}) *search {
|
||||
s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...))
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Order(value interface{}, reorder ...bool) *search {
|
||||
if len(reorder) > 0 && reorder[0] {
|
||||
s.orders = []interface{}{}
|
||||
}
|
||||
|
||||
if value != nil && value != "" {
|
||||
s.orders = append(s.orders, value)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Select(query interface{}, args ...interface{}) *search {
|
||||
s.selects = map[string]interface{}{"query": query, "args": args}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Omit(columns ...string) *search {
|
||||
s.omits = columns
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Limit(limit interface{}) *search {
|
||||
s.limit = limit
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Offset(offset interface{}) *search {
|
||||
s.offset = offset
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Group(query string) *search {
|
||||
s.group = s.getInterfaceAsSQL(query)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Having(query interface{}, values ...interface{}) *search {
|
||||
if val, ok := query.(*SqlExpr); ok {
|
||||
s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args})
|
||||
} else {
|
||||
s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Joins(query string, values ...interface{}) *search {
|
||||
s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values})
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Preload(schema string, values ...interface{}) *search {
|
||||
var preloads []searchPreload
|
||||
for _, preload := range s.preload {
|
||||
if preload.schema != schema {
|
||||
preloads = append(preloads, preload)
|
||||
}
|
||||
}
|
||||
preloads = append(preloads, searchPreload{schema, values})
|
||||
s.preload = preloads
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Raw(b bool) *search {
|
||||
s.raw = b
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) unscoped() *search {
|
||||
s.Unscoped = true
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Table(name string) *search {
|
||||
s.tableName = name
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) getInterfaceAsSQL(value interface{}) (str string) {
|
||||
switch value.(type) {
|
||||
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
str = fmt.Sprintf("%v", value)
|
||||
default:
|
||||
s.db.AddError(ErrInvalidSQL)
|
||||
}
|
||||
|
||||
if str == "-1" {
|
||||
return ""
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
dialects=("postgres" "mysql" "mssql" "sqlite")
|
||||
|
||||
for dialect in "${dialects[@]}" ; do
|
||||
DEBUG=false GORM_DIALECT=${dialect} go test
|
||||
done
|
||||
|
|
@ -0,0 +1,226 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NowFunc returns current time, this function is exported in order to be able
|
||||
// to give the flexibility to the developer to customize it according to their
|
||||
// needs, e.g:
|
||||
// gorm.NowFunc = func() time.Time {
|
||||
// return time.Now().UTC()
|
||||
// }
|
||||
var NowFunc = func() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
// Copied from golint
|
||||
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
||||
var commonInitialismsReplacer *strings.Replacer
|
||||
|
||||
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`)
|
||||
var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`)
|
||||
|
||||
func init() {
|
||||
var commonInitialismsForReplacer []string
|
||||
for _, initialism := range commonInitialisms {
|
||||
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
|
||||
}
|
||||
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
|
||||
}
|
||||
|
||||
type safeMap struct {
|
||||
m map[string]string
|
||||
l *sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *safeMap) Set(key string, value string) {
|
||||
s.l.Lock()
|
||||
defer s.l.Unlock()
|
||||
s.m[key] = value
|
||||
}
|
||||
|
||||
func (s *safeMap) Get(key string) string {
|
||||
s.l.RLock()
|
||||
defer s.l.RUnlock()
|
||||
return s.m[key]
|
||||
}
|
||||
|
||||
func newSafeMap() *safeMap {
|
||||
return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
|
||||
}
|
||||
|
||||
// SQL expression
|
||||
type SqlExpr struct {
|
||||
expr string
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
// Expr generate raw SQL expression, for example:
|
||||
// DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
|
||||
func Expr(expression string, args ...interface{}) *SqlExpr {
|
||||
return &SqlExpr{expr: expression, args: args}
|
||||
}
|
||||
|
||||
func indirect(reflectValue reflect.Value) reflect.Value {
|
||||
for reflectValue.Kind() == reflect.Ptr {
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
return reflectValue
|
||||
}
|
||||
|
||||
func toQueryMarks(primaryValues [][]interface{}) string {
|
||||
var results []string
|
||||
|
||||
for _, primaryValue := range primaryValues {
|
||||
var marks []string
|
||||
for range primaryValue {
|
||||
marks = append(marks, "?")
|
||||
}
|
||||
|
||||
if len(marks) > 1 {
|
||||
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
|
||||
} else {
|
||||
results = append(results, strings.Join(marks, ""))
|
||||
}
|
||||
}
|
||||
return strings.Join(results, ",")
|
||||
}
|
||||
|
||||
func toQueryCondition(scope *Scope, columns []string) string {
|
||||
var newColumns []string
|
||||
for _, column := range columns {
|
||||
newColumns = append(newColumns, scope.Quote(column))
|
||||
}
|
||||
|
||||
if len(columns) > 1 {
|
||||
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
|
||||
}
|
||||
return strings.Join(newColumns, ",")
|
||||
}
|
||||
|
||||
func toQueryValues(values [][]interface{}) (results []interface{}) {
|
||||
for _, value := range values {
|
||||
for _, v := range value {
|
||||
results = append(results, v)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fileWithLineNum() string {
|
||||
for i := 2; i < 15; i++ {
|
||||
_, file, line, ok := runtime.Caller(i)
|
||||
if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) {
|
||||
return fmt.Sprintf("%v:%v", file, line)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func isBlank(value reflect.Value) bool {
|
||||
switch value.Kind() {
|
||||
case reflect.String:
|
||||
return value.Len() == 0
|
||||
case reflect.Bool:
|
||||
return !value.Bool()
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return value.Int() == 0
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return value.Uint() == 0
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return value.Float() == 0
|
||||
case reflect.Interface, reflect.Ptr:
|
||||
return value.IsNil()
|
||||
}
|
||||
|
||||
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
|
||||
}
|
||||
|
||||
func toSearchableMap(attrs ...interface{}) (result interface{}) {
|
||||
if len(attrs) > 1 {
|
||||
if str, ok := attrs[0].(string); ok {
|
||||
result = map[string]interface{}{str: attrs[1]}
|
||||
}
|
||||
} else if len(attrs) == 1 {
|
||||
if attr, ok := attrs[0].(map[string]interface{}); ok {
|
||||
result = attr
|
||||
}
|
||||
|
||||
if attr, ok := attrs[0].(interface{}); ok {
|
||||
result = attr
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func equalAsString(a interface{}, b interface{}) bool {
|
||||
return toString(a) == toString(b)
|
||||
}
|
||||
|
||||
func toString(str interface{}) string {
|
||||
if values, ok := str.([]interface{}); ok {
|
||||
var results []string
|
||||
for _, value := range values {
|
||||
results = append(results, toString(value))
|
||||
}
|
||||
return strings.Join(results, "_")
|
||||
} else if bytes, ok := str.([]byte); ok {
|
||||
return string(bytes)
|
||||
} else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
|
||||
return fmt.Sprintf("%v", reflectValue.Interface())
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func makeSlice(elemType reflect.Type) interface{} {
|
||||
if elemType.Kind() == reflect.Slice {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
sliceType := reflect.SliceOf(elemType)
|
||||
slice := reflect.New(sliceType)
|
||||
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
|
||||
return slice.Interface()
|
||||
}
|
||||
|
||||
func strInSlice(a string, list []string) bool {
|
||||
for _, b := range list {
|
||||
if b == a {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getValueFromFields return given fields's value
|
||||
func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) {
|
||||
// If value is a nil pointer, Indirect returns a zero Value!
|
||||
// Therefor we need to check for a zero value,
|
||||
// as FieldByName could panic
|
||||
if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
|
||||
for _, fieldName := range fieldNames {
|
||||
if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() {
|
||||
result := fieldValue.Interface()
|
||||
if r, ok := result.(driver.Valuer); ok {
|
||||
result, _ = r.Value()
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func addExtraSpaceIfExist(str string) string {
|
||||
if str != "" {
|
||||
return " " + str
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
|
@ -0,0 +1,149 @@
|
|||
# use the default golang container from Docker Hub
|
||||
box: golang
|
||||
|
||||
services:
|
||||
- name: mariadb
|
||||
id: mariadb:latest
|
||||
env:
|
||||
MYSQL_DATABASE: gorm
|
||||
MYSQL_USER: gorm
|
||||
MYSQL_PASSWORD: gorm
|
||||
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||
- name: mysql
|
||||
id: mysql:latest
|
||||
env:
|
||||
MYSQL_DATABASE: gorm
|
||||
MYSQL_USER: gorm
|
||||
MYSQL_PASSWORD: gorm
|
||||
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||
- name: mysql57
|
||||
id: mysql:5.7
|
||||
env:
|
||||
MYSQL_DATABASE: gorm
|
||||
MYSQL_USER: gorm
|
||||
MYSQL_PASSWORD: gorm
|
||||
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||
- name: mysql56
|
||||
id: mysql:5.6
|
||||
env:
|
||||
MYSQL_DATABASE: gorm
|
||||
MYSQL_USER: gorm
|
||||
MYSQL_PASSWORD: gorm
|
||||
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||
- name: postgres
|
||||
id: postgres:latest
|
||||
env:
|
||||
POSTGRES_USER: gorm
|
||||
POSTGRES_PASSWORD: gorm
|
||||
POSTGRES_DB: gorm
|
||||
- name: postgres96
|
||||
id: postgres:9.6
|
||||
env:
|
||||
POSTGRES_USER: gorm
|
||||
POSTGRES_PASSWORD: gorm
|
||||
POSTGRES_DB: gorm
|
||||
- name: postgres95
|
||||
id: postgres:9.5
|
||||
env:
|
||||
POSTGRES_USER: gorm
|
||||
POSTGRES_PASSWORD: gorm
|
||||
POSTGRES_DB: gorm
|
||||
- name: postgres94
|
||||
id: postgres:9.4
|
||||
env:
|
||||
POSTGRES_USER: gorm
|
||||
POSTGRES_PASSWORD: gorm
|
||||
POSTGRES_DB: gorm
|
||||
- name: postgres93
|
||||
id: postgres:9.3
|
||||
env:
|
||||
POSTGRES_USER: gorm
|
||||
POSTGRES_PASSWORD: gorm
|
||||
POSTGRES_DB: gorm
|
||||
- name: mssql
|
||||
id: mcmoe/mssqldocker:latest
|
||||
env:
|
||||
ACCEPT_EULA: Y
|
||||
SA_PASSWORD: LoremIpsum86
|
||||
MSSQL_DB: gorm
|
||||
MSSQL_USER: gorm
|
||||
MSSQL_PASSWORD: LoremIpsum86
|
||||
|
||||
# The steps that will be executed in the build pipeline
|
||||
build:
|
||||
# The steps that will be executed on build
|
||||
steps:
|
||||
# Sets the go workspace and places you package
|
||||
# at the right place in the workspace tree
|
||||
- setup-go-workspace
|
||||
|
||||
# Gets the dependencies
|
||||
- script:
|
||||
name: go get
|
||||
code: |
|
||||
cd $WERCKER_SOURCE_DIR
|
||||
go version
|
||||
go get -t -v ./...
|
||||
|
||||
# Build the project
|
||||
- script:
|
||||
name: go build
|
||||
code: |
|
||||
go build ./...
|
||||
|
||||
# Test the project
|
||||
- script:
|
||||
name: test sqlite
|
||||
code: |
|
||||
go test -race -v ./...
|
||||
|
||||
- script:
|
||||
name: test mariadb
|
||||
code: |
|
||||
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
|
||||
|
||||
- script:
|
||||
name: test mysql
|
||||
code: |
|
||||
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
|
||||
|
||||
- script:
|
||||
name: test mysql5.7
|
||||
code: |
|
||||
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
|
||||
|
||||
- script:
|
||||
name: test mysql5.6
|
||||
code: |
|
||||
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
|
||||
|
||||
- script:
|
||||
name: test postgres
|
||||
code: |
|
||||
GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
|
||||
|
||||
- script:
|
||||
name: test postgres96
|
||||
code: |
|
||||
GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
|
||||
|
||||
- script:
|
||||
name: test postgres95
|
||||
code: |
|
||||
GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
|
||||
|
||||
- script:
|
||||
name: test postgres94
|
||||
code: |
|
||||
GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
|
||||
|
||||
- script:
|
||||
name: test postgres93
|
||||
code: |
|
||||
GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
|
||||
|
||||
- script:
|
||||
name: codecov
|
||||
code: |
|
||||
go test -race -coverprofile=coverage.txt -covermode=atomic ./...
|
||||
bash <(curl -s https://codecov.io/bash)
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015 - Jinzhu
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
# Inflection
|
||||
|
||||
Inflection pluralizes and singularizes English nouns
|
||||
|
||||
[](https://app.wercker.com/project/byKey/f8c7432b097d1f4ce636879670be0930)
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```go
|
||||
inflection.Plural("person") => "people"
|
||||
inflection.Plural("Person") => "People"
|
||||
inflection.Plural("PERSON") => "PEOPLE"
|
||||
inflection.Plural("bus") => "buses"
|
||||
inflection.Plural("BUS") => "BUSES"
|
||||
inflection.Plural("Bus") => "Buses"
|
||||
|
||||
inflection.Singular("people") => "person"
|
||||
inflection.Singular("People") => "Person"
|
||||
inflection.Singular("PEOPLE") => "PERSON"
|
||||
inflection.Singular("buses") => "bus"
|
||||
inflection.Singular("BUSES") => "BUS"
|
||||
inflection.Singular("Buses") => "Bus"
|
||||
|
||||
inflection.Plural("FancyPerson") => "FancyPeople"
|
||||
inflection.Singular("FancyPeople") => "FancyPerson"
|
||||
```
|
||||
|
||||
## Register Rules
|
||||
|
||||
Standard rules are from Rails's ActiveSupport (https://github.com/rails/rails/blob/master/activesupport/lib/active_support/inflections.rb)
|
||||
|
||||
If you want to register more rules, follow:
|
||||
|
||||
```
|
||||
inflection.AddUncountable("fish")
|
||||
inflection.AddIrregular("person", "people")
|
||||
inflection.AddPlural("(bu)s$", "${1}ses") # "bus" => "buses" / "BUS" => "BUSES" / "Bus" => "Buses"
|
||||
inflection.AddSingular("(bus)(es)?$", "${1}") # "buses" => "bus" / "Buses" => "Bus" / "BUSES" => "BUS"
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
You can help to make the project better, check out [http://gorm.io/contribute.html](http://gorm.io/contribute.html) for things you can do.
|
||||
|
||||
## Author
|
||||
|
||||
**jinzhu**
|
||||
|
||||
* <http://github.com/jinzhu>
|
||||
* <wosmvp@gmail.com>
|
||||
* <http://twitter.com/zhangjinzhu>
|
||||
|
||||
## License
|
||||
|
||||
Released under the [MIT License](http://www.opensource.org/licenses/MIT).
|
||||
|
|
@ -0,0 +1,273 @@
|
|||
/*
|
||||
Package inflection pluralizes and singularizes English nouns.
|
||||
|
||||
inflection.Plural("person") => "people"
|
||||
inflection.Plural("Person") => "People"
|
||||
inflection.Plural("PERSON") => "PEOPLE"
|
||||
|
||||
inflection.Singular("people") => "person"
|
||||
inflection.Singular("People") => "Person"
|
||||
inflection.Singular("PEOPLE") => "PERSON"
|
||||
|
||||
inflection.Plural("FancyPerson") => "FancydPeople"
|
||||
inflection.Singular("FancyPeople") => "FancydPerson"
|
||||
|
||||
Standard rules are from Rails's ActiveSupport (https://github.com/rails/rails/blob/master/activesupport/lib/active_support/inflections.rb)
|
||||
|
||||
If you want to register more rules, follow:
|
||||
|
||||
inflection.AddUncountable("fish")
|
||||
inflection.AddIrregular("person", "people")
|
||||
inflection.AddPlural("(bu)s$", "${1}ses") # "bus" => "buses" / "BUS" => "BUSES" / "Bus" => "Buses"
|
||||
inflection.AddSingular("(bus)(es)?$", "${1}") # "buses" => "bus" / "Buses" => "Bus" / "BUSES" => "BUS"
|
||||
*/
|
||||
package inflection
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type inflection struct {
|
||||
regexp *regexp.Regexp
|
||||
replace string
|
||||
}
|
||||
|
||||
// Regular is a regexp find replace inflection
|
||||
type Regular struct {
|
||||
find string
|
||||
replace string
|
||||
}
|
||||
|
||||
// Irregular is a hard replace inflection,
|
||||
// containing both singular and plural forms
|
||||
type Irregular struct {
|
||||
singular string
|
||||
plural string
|
||||
}
|
||||
|
||||
// RegularSlice is a slice of Regular inflections
|
||||
type RegularSlice []Regular
|
||||
|
||||
// IrregularSlice is a slice of Irregular inflections
|
||||
type IrregularSlice []Irregular
|
||||
|
||||
var pluralInflections = RegularSlice{
|
||||
{"([a-z])$", "${1}s"},
|
||||
{"s$", "s"},
|
||||
{"^(ax|test)is$", "${1}es"},
|
||||
{"(octop|vir)us$", "${1}i"},
|
||||
{"(octop|vir)i$", "${1}i"},
|
||||
{"(alias|status)$", "${1}es"},
|
||||
{"(bu)s$", "${1}ses"},
|
||||
{"(buffal|tomat)o$", "${1}oes"},
|
||||
{"([ti])um$", "${1}a"},
|
||||
{"([ti])a$", "${1}a"},
|
||||
{"sis$", "ses"},
|
||||
{"(?:([^f])fe|([lr])f)$", "${1}${2}ves"},
|
||||
{"(hive)$", "${1}s"},
|
||||
{"([^aeiouy]|qu)y$", "${1}ies"},
|
||||
{"(x|ch|ss|sh)$", "${1}es"},
|
||||
{"(matr|vert|ind)(?:ix|ex)$", "${1}ices"},
|
||||
{"^(m|l)ouse$", "${1}ice"},
|
||||
{"^(m|l)ice$", "${1}ice"},
|
||||
{"^(ox)$", "${1}en"},
|
||||
{"^(oxen)$", "${1}"},
|
||||
{"(quiz)$", "${1}zes"},
|
||||
}
|
||||
|
||||
var singularInflections = RegularSlice{
|
||||
{"s$", ""},
|
||||
{"(ss)$", "${1}"},
|
||||
{"(n)ews$", "${1}ews"},
|
||||
{"([ti])a$", "${1}um"},
|
||||
{"((a)naly|(b)a|(d)iagno|(p)arenthe|(p)rogno|(s)ynop|(t)he)(sis|ses)$", "${1}sis"},
|
||||
{"(^analy)(sis|ses)$", "${1}sis"},
|
||||
{"([^f])ves$", "${1}fe"},
|
||||
{"(hive)s$", "${1}"},
|
||||
{"(tive)s$", "${1}"},
|
||||
{"([lr])ves$", "${1}f"},
|
||||
{"([^aeiouy]|qu)ies$", "${1}y"},
|
||||
{"(s)eries$", "${1}eries"},
|
||||
{"(m)ovies$", "${1}ovie"},
|
||||
{"(c)ookies$", "${1}ookie"},
|
||||
{"(x|ch|ss|sh)es$", "${1}"},
|
||||
{"^(m|l)ice$", "${1}ouse"},
|
||||
{"(bus)(es)?$", "${1}"},
|
||||
{"(o)es$", "${1}"},
|
||||
{"(shoe)s$", "${1}"},
|
||||
{"(cris|test)(is|es)$", "${1}is"},
|
||||
{"^(a)x[ie]s$", "${1}xis"},
|
||||
{"(octop|vir)(us|i)$", "${1}us"},
|
||||
{"(alias|status)(es)?$", "${1}"},
|
||||
{"^(ox)en", "${1}"},
|
||||
{"(vert|ind)ices$", "${1}ex"},
|
||||
{"(matr)ices$", "${1}ix"},
|
||||
{"(quiz)zes$", "${1}"},
|
||||
{"(database)s$", "${1}"},
|
||||
}
|
||||
|
||||
var irregularInflections = IrregularSlice{
|
||||
{"person", "people"},
|
||||
{"man", "men"},
|
||||
{"child", "children"},
|
||||
{"sex", "sexes"},
|
||||
{"move", "moves"},
|
||||
{"mombie", "mombies"},
|
||||
}
|
||||
|
||||
var uncountableInflections = []string{"equipment", "information", "rice", "money", "species", "series", "fish", "sheep", "jeans", "police"}
|
||||
|
||||
var compiledPluralMaps []inflection
|
||||
var compiledSingularMaps []inflection
|
||||
|
||||
func compile() {
|
||||
compiledPluralMaps = []inflection{}
|
||||
compiledSingularMaps = []inflection{}
|
||||
for _, uncountable := range uncountableInflections {
|
||||
inf := inflection{
|
||||
regexp: regexp.MustCompile("^(?i)(" + uncountable + ")$"),
|
||||
replace: "${1}",
|
||||
}
|
||||
compiledPluralMaps = append(compiledPluralMaps, inf)
|
||||
compiledSingularMaps = append(compiledSingularMaps, inf)
|
||||
}
|
||||
|
||||
for _, value := range irregularInflections {
|
||||
infs := []inflection{
|
||||
inflection{regexp: regexp.MustCompile(strings.ToUpper(value.singular) + "$"), replace: strings.ToUpper(value.plural)},
|
||||
inflection{regexp: regexp.MustCompile(strings.Title(value.singular) + "$"), replace: strings.Title(value.plural)},
|
||||
inflection{regexp: regexp.MustCompile(value.singular + "$"), replace: value.plural},
|
||||
}
|
||||
compiledPluralMaps = append(compiledPluralMaps, infs...)
|
||||
}
|
||||
|
||||
for _, value := range irregularInflections {
|
||||
infs := []inflection{
|
||||
inflection{regexp: regexp.MustCompile(strings.ToUpper(value.plural) + "$"), replace: strings.ToUpper(value.singular)},
|
||||
inflection{regexp: regexp.MustCompile(strings.Title(value.plural) + "$"), replace: strings.Title(value.singular)},
|
||||
inflection{regexp: regexp.MustCompile(value.plural + "$"), replace: value.singular},
|
||||
}
|
||||
compiledSingularMaps = append(compiledSingularMaps, infs...)
|
||||
}
|
||||
|
||||
for i := len(pluralInflections) - 1; i >= 0; i-- {
|
||||
value := pluralInflections[i]
|
||||
infs := []inflection{
|
||||
inflection{regexp: regexp.MustCompile(strings.ToUpper(value.find)), replace: strings.ToUpper(value.replace)},
|
||||
inflection{regexp: regexp.MustCompile(value.find), replace: value.replace},
|
||||
inflection{regexp: regexp.MustCompile("(?i)" + value.find), replace: value.replace},
|
||||
}
|
||||
compiledPluralMaps = append(compiledPluralMaps, infs...)
|
||||
}
|
||||
|
||||
for i := len(singularInflections) - 1; i >= 0; i-- {
|
||||
value := singularInflections[i]
|
||||
infs := []inflection{
|
||||
inflection{regexp: regexp.MustCompile(strings.ToUpper(value.find)), replace: strings.ToUpper(value.replace)},
|
||||
inflection{regexp: regexp.MustCompile(value.find), replace: value.replace},
|
||||
inflection{regexp: regexp.MustCompile("(?i)" + value.find), replace: value.replace},
|
||||
}
|
||||
compiledSingularMaps = append(compiledSingularMaps, infs...)
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
compile()
|
||||
}
|
||||
|
||||
// AddPlural adds a plural inflection
|
||||
func AddPlural(find, replace string) {
|
||||
pluralInflections = append(pluralInflections, Regular{find, replace})
|
||||
compile()
|
||||
}
|
||||
|
||||
// AddSingular adds a singular inflection
|
||||
func AddSingular(find, replace string) {
|
||||
singularInflections = append(singularInflections, Regular{find, replace})
|
||||
compile()
|
||||
}
|
||||
|
||||
// AddIrregular adds an irregular inflection
|
||||
func AddIrregular(singular, plural string) {
|
||||
irregularInflections = append(irregularInflections, Irregular{singular, plural})
|
||||
compile()
|
||||
}
|
||||
|
||||
// AddUncountable adds an uncountable inflection
|
||||
func AddUncountable(values ...string) {
|
||||
uncountableInflections = append(uncountableInflections, values...)
|
||||
compile()
|
||||
}
|
||||
|
||||
// GetPlural retrieves the plural inflection values
|
||||
func GetPlural() RegularSlice {
|
||||
plurals := make(RegularSlice, len(pluralInflections))
|
||||
copy(plurals, pluralInflections)
|
||||
return plurals
|
||||
}
|
||||
|
||||
// GetSingular retrieves the singular inflection values
|
||||
func GetSingular() RegularSlice {
|
||||
singulars := make(RegularSlice, len(singularInflections))
|
||||
copy(singulars, singularInflections)
|
||||
return singulars
|
||||
}
|
||||
|
||||
// GetIrregular retrieves the irregular inflection values
|
||||
func GetIrregular() IrregularSlice {
|
||||
irregular := make(IrregularSlice, len(irregularInflections))
|
||||
copy(irregular, irregularInflections)
|
||||
return irregular
|
||||
}
|
||||
|
||||
// GetUncountable retrieves the uncountable inflection values
|
||||
func GetUncountable() []string {
|
||||
uncountables := make([]string, len(uncountableInflections))
|
||||
copy(uncountables, uncountableInflections)
|
||||
return uncountables
|
||||
}
|
||||
|
||||
// SetPlural sets the plural inflections slice
|
||||
func SetPlural(inflections RegularSlice) {
|
||||
pluralInflections = inflections
|
||||
compile()
|
||||
}
|
||||
|
||||
// SetSingular sets the singular inflections slice
|
||||
func SetSingular(inflections RegularSlice) {
|
||||
singularInflections = inflections
|
||||
compile()
|
||||
}
|
||||
|
||||
// SetIrregular sets the irregular inflections slice
|
||||
func SetIrregular(inflections IrregularSlice) {
|
||||
irregularInflections = inflections
|
||||
compile()
|
||||
}
|
||||
|
||||
// SetUncountable sets the uncountable inflections slice
|
||||
func SetUncountable(inflections []string) {
|
||||
uncountableInflections = inflections
|
||||
compile()
|
||||
}
|
||||
|
||||
// Plural converts a word to its plural form
|
||||
func Plural(str string) string {
|
||||
for _, inflection := range compiledPluralMaps {
|
||||
if inflection.regexp.MatchString(str) {
|
||||
return inflection.regexp.ReplaceAllString(str, inflection.replace)
|
||||
}
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// Singular converts a word to its singular form
|
||||
func Singular(str string) string {
|
||||
for _, inflection := range compiledSingularMaps {
|
||||
if inflection.regexp.MatchString(str) {
|
||||
return inflection.regexp.ReplaceAllString(str, inflection.replace)
|
||||
}
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
box: golang
|
||||
|
||||
build:
|
||||
steps:
|
||||
- setup-go-workspace
|
||||
|
||||
# Gets the dependencies
|
||||
- script:
|
||||
name: go get
|
||||
code: |
|
||||
go get
|
||||
|
||||
# Build the project
|
||||
- script:
|
||||
name: go build
|
||||
code: |
|
||||
go build ./...
|
||||
|
||||
# Test the project
|
||||
- script:
|
||||
name: go test
|
||||
code: |
|
||||
go test ./...
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
guard 'gotest' do
|
||||
watch(%r{\.go$})
|
||||
end
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
## Now
|
||||
|
||||
Now is a time toolkit for golang
|
||||
|
||||
[](https://goreportcard.com/report/github.com/jinzhu/now)
|
||||
[](https://github.com/jinzhu/now/actions)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
|
||||
## Install
|
||||
|
||||
```
|
||||
go get -u github.com/jinzhu/now
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
Calculating time based on current time
|
||||
|
||||
```go
|
||||
import "github.com/jinzhu/now"
|
||||
|
||||
time.Now() // 2013-11-18 17:51:49.123456789 Mon
|
||||
|
||||
now.BeginningOfMinute() // 2013-11-18 17:51:00 Mon
|
||||
now.BeginningOfHour() // 2013-11-18 17:00:00 Mon
|
||||
now.BeginningOfDay() // 2013-11-18 00:00:00 Mon
|
||||
now.BeginningOfWeek() // 2013-11-17 00:00:00 Sun
|
||||
now.BeginningOfMonth() // 2013-11-01 00:00:00 Fri
|
||||
now.BeginningOfQuarter() // 2013-10-01 00:00:00 Tue
|
||||
now.BeginningOfYear() // 2013-01-01 00:00:00 Tue
|
||||
|
||||
now.EndOfMinute() // 2013-11-18 17:51:59.999999999 Mon
|
||||
now.EndOfHour() // 2013-11-18 17:59:59.999999999 Mon
|
||||
now.EndOfDay() // 2013-11-18 23:59:59.999999999 Mon
|
||||
now.EndOfWeek() // 2013-11-23 23:59:59.999999999 Sat
|
||||
now.EndOfMonth() // 2013-11-30 23:59:59.999999999 Sat
|
||||
now.EndOfQuarter() // 2013-12-31 23:59:59.999999999 Tue
|
||||
now.EndOfYear() // 2013-12-31 23:59:59.999999999 Tue
|
||||
|
||||
now.WeekStartDay = time.Monday // Set Monday as first day, default is Sunday
|
||||
now.EndOfWeek() // 2013-11-24 23:59:59.999999999 Sun
|
||||
```
|
||||
|
||||
Calculating time based on another time
|
||||
|
||||
```go
|
||||
t := time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.Now().Location())
|
||||
now.With(t).EndOfMonth() // 2013-02-28 23:59:59.999999999 Thu
|
||||
```
|
||||
|
||||
Calculating time based on configuration
|
||||
|
||||
```go
|
||||
location, err := time.LoadLocation("Asia/Shanghai")
|
||||
|
||||
myConfig := &now.Config{
|
||||
WeekStartDay: time.Monday,
|
||||
TimeLocation: location,
|
||||
TimeFormats: []string{"2006-01-02 15:04:05"},
|
||||
}
|
||||
|
||||
t := time.Date(2013, 11, 18, 17, 51, 49, 123456789, time.Now().Location()) // // 2013-11-18 17:51:49.123456789 Mon
|
||||
myConfig.With(t).BeginningOfWeek() // 2013-11-18 00:00:00 Mon
|
||||
|
||||
myConfig.Parse("2002-10-12 22:14:01") // 2002-10-12 22:14:01
|
||||
myConfig.Parse("2002-10-12 22:14") // returns error 'can't parse string as time: 2002-10-12 22:14'
|
||||
```
|
||||
|
||||
### Monday/Sunday
|
||||
|
||||
Don't be bothered with the `WeekStartDay` setting, you can use `Monday`, `Sunday`
|
||||
|
||||
```go
|
||||
now.Monday() // 2013-11-18 00:00:00 Mon
|
||||
now.Monday("17:44") // 2013-11-18 17:44:00 Mon
|
||||
now.Sunday() // 2013-11-24 00:00:00 Sun (Next Sunday)
|
||||
now.Sunday("18:19:24") // 2013-11-24 18:19:24 Sun (Next Sunday)
|
||||
now.EndOfSunday() // 2013-11-24 23:59:59.999999999 Sun (End of next Sunday)
|
||||
|
||||
t := time.Date(2013, 11, 24, 17, 51, 49, 123456789, time.Now().Location()) // 2013-11-24 17:51:49.123456789 Sun
|
||||
now.With(t).Monday() // 2013-11-18 00:00:00 Mon (Last Monday if today is Sunday)
|
||||
now.With(t).Monday("17:44") // 2013-11-18 17:44:00 Mon (Last Monday if today is Sunday)
|
||||
now.With(t).Sunday() // 2013-11-24 00:00:00 Sun (Beginning Of Today if today is Sunday)
|
||||
now.With(t).Sunday("18:19:24") // 2013-11-24 18:19:24 Sun (Beginning Of Today if today is Sunday)
|
||||
now.With(t).EndOfSunday() // 2013-11-24 23:59:59.999999999 Sun (End of Today if today is Sunday)
|
||||
```
|
||||
|
||||
### Parse String to Time
|
||||
|
||||
```go
|
||||
time.Now() // 2013-11-18 17:51:49.123456789 Mon
|
||||
|
||||
// Parse(string) (time.Time, error)
|
||||
t, err := now.Parse("2017") // 2017-01-01 00:00:00, nil
|
||||
t, err := now.Parse("2017-10") // 2017-10-01 00:00:00, nil
|
||||
t, err := now.Parse("2017-10-13") // 2017-10-13 00:00:00, nil
|
||||
t, err := now.Parse("1999-12-12 12") // 1999-12-12 12:00:00, nil
|
||||
t, err := now.Parse("1999-12-12 12:20") // 1999-12-12 12:20:00, nil
|
||||
t, err := now.Parse("1999-12-12 12:20:21") // 1999-12-12 12:20:21, nil
|
||||
t, err := now.Parse("10-13") // 2013-10-13 00:00:00, nil
|
||||
t, err := now.Parse("12:20") // 2013-11-18 12:20:00, nil
|
||||
t, err := now.Parse("12:20:13") // 2013-11-18 12:20:13, nil
|
||||
t, err := now.Parse("14") // 2013-11-18 14:00:00, nil
|
||||
t, err := now.Parse("99:99") // 2013-11-18 12:20:00, Can't parse string as time: 99:99
|
||||
|
||||
// MustParse must parse string to time or it will panic
|
||||
now.MustParse("2013-01-13") // 2013-01-13 00:00:00
|
||||
now.MustParse("02-17") // 2013-02-17 00:00:00
|
||||
now.MustParse("2-17") // 2013-02-17 00:00:00
|
||||
now.MustParse("8") // 2013-11-18 08:00:00
|
||||
now.MustParse("2002-10-12 22:14") // 2002-10-12 22:14:00
|
||||
now.MustParse("99:99") // panic: Can't parse string as time: 99:99
|
||||
```
|
||||
|
||||
Extend `now` to support more formats is quite easy, just update `now.TimeFormats` with other time layouts, e.g:
|
||||
|
||||
```go
|
||||
now.TimeFormats = append(now.TimeFormats, "02 Jan 2006 15:04")
|
||||
```
|
||||
|
||||
Please send me pull requests if you want a format to be supported officially
|
||||
|
||||
## Contributing
|
||||
|
||||
You can help to make the project better, check out [http://gorm.io/contribute.html](http://gorm.io/contribute.html) for things you can do.
|
||||
|
||||
# Author
|
||||
|
||||
**jinzhu**
|
||||
|
||||
* <http://github.com/jinzhu>
|
||||
* <wosmvp@gmail.com>
|
||||
* <http://twitter.com/zhangjinzhu>
|
||||
|
||||
## License
|
||||
|
||||
Released under the [MIT License](http://www.opensource.org/licenses/MIT).
|
||||
|
|
@ -0,0 +1,200 @@
|
|||
// Package now is a time toolkit for golang.
|
||||
//
|
||||
// More details README here: https://github.com/jinzhu/now
|
||||
//
|
||||
// import "github.com/jinzhu/now"
|
||||
//
|
||||
// now.BeginningOfMinute() // 2013-11-18 17:51:00 Mon
|
||||
// now.BeginningOfDay() // 2013-11-18 00:00:00 Mon
|
||||
// now.EndOfDay() // 2013-11-18 23:59:59.999999999 Mon
|
||||
package now
|
||||
|
||||
import "time"
|
||||
|
||||
// WeekStartDay set week start day, default is sunday
|
||||
var WeekStartDay = time.Sunday
|
||||
|
||||
// TimeFormats default time formats will be parsed as
|
||||
var TimeFormats = []string{
|
||||
"2006", "2006-1", "2006-1-2", "2006-1-2 15", "2006-1-2 15:4", "2006-1-2 15:4:5", "1-2",
|
||||
"15:4:5", "15:4", "15",
|
||||
"15:4:5 Jan 2, 2006 MST", "2006-01-02 15:04:05.999999999 -0700 MST", "2006-01-02T15:04:05Z0700", "2006-01-02T15:04:05Z07",
|
||||
"2006.1.2", "2006.1.2 15:04:05", "2006.01.02", "2006.01.02 15:04:05", "2006.01.02 15:04:05.999999999",
|
||||
"1/2/2006", "1/2/2006 15:4:5", "2006/01/02", "20060102", "2006/01/02 15:04:05",
|
||||
time.ANSIC, time.UnixDate, time.RubyDate, time.RFC822, time.RFC822Z, time.RFC850,
|
||||
time.RFC1123, time.RFC1123Z, time.RFC3339, time.RFC3339Nano,
|
||||
time.Kitchen, time.Stamp, time.StampMilli, time.StampMicro, time.StampNano,
|
||||
}
|
||||
|
||||
// Config configuration for now package
|
||||
type Config struct {
|
||||
WeekStartDay time.Weekday
|
||||
TimeLocation *time.Location
|
||||
TimeFormats []string
|
||||
}
|
||||
|
||||
// DefaultConfig default config
|
||||
var DefaultConfig *Config
|
||||
|
||||
// New initialize Now based on configuration
|
||||
func (config *Config) With(t time.Time) *Now {
|
||||
return &Now{Time: t, Config: config}
|
||||
}
|
||||
|
||||
// Parse parse string to time based on configuration
|
||||
func (config *Config) Parse(strs ...string) (time.Time, error) {
|
||||
if config.TimeLocation == nil {
|
||||
return config.With(time.Now()).Parse(strs...)
|
||||
} else {
|
||||
return config.With(time.Now().In(config.TimeLocation)).Parse(strs...)
|
||||
}
|
||||
}
|
||||
|
||||
// MustParse must parse string to time or will panic
|
||||
func (config *Config) MustParse(strs ...string) time.Time {
|
||||
if config.TimeLocation == nil {
|
||||
return config.With(time.Now()).MustParse(strs...)
|
||||
} else {
|
||||
return config.With(time.Now().In(config.TimeLocation)).MustParse(strs...)
|
||||
}
|
||||
}
|
||||
|
||||
// Now now struct
|
||||
type Now struct {
|
||||
time.Time
|
||||
*Config
|
||||
}
|
||||
|
||||
// With initialize Now with time
|
||||
func With(t time.Time) *Now {
|
||||
config := DefaultConfig
|
||||
if config == nil {
|
||||
config = &Config{
|
||||
WeekStartDay: WeekStartDay,
|
||||
TimeFormats: TimeFormats,
|
||||
}
|
||||
}
|
||||
|
||||
return &Now{Time: t, Config: config}
|
||||
}
|
||||
|
||||
// New initialize Now with time
|
||||
func New(t time.Time) *Now {
|
||||
return With(t)
|
||||
}
|
||||
|
||||
// BeginningOfMinute beginning of minute
|
||||
func BeginningOfMinute() time.Time {
|
||||
return With(time.Now()).BeginningOfMinute()
|
||||
}
|
||||
|
||||
// BeginningOfHour beginning of hour
|
||||
func BeginningOfHour() time.Time {
|
||||
return With(time.Now()).BeginningOfHour()
|
||||
}
|
||||
|
||||
// BeginningOfDay beginning of day
|
||||
func BeginningOfDay() time.Time {
|
||||
return With(time.Now()).BeginningOfDay()
|
||||
}
|
||||
|
||||
// BeginningOfWeek beginning of week
|
||||
func BeginningOfWeek() time.Time {
|
||||
return With(time.Now()).BeginningOfWeek()
|
||||
}
|
||||
|
||||
// BeginningOfMonth beginning of month
|
||||
func BeginningOfMonth() time.Time {
|
||||
return With(time.Now()).BeginningOfMonth()
|
||||
}
|
||||
|
||||
// BeginningOfQuarter beginning of quarter
|
||||
func BeginningOfQuarter() time.Time {
|
||||
return With(time.Now()).BeginningOfQuarter()
|
||||
}
|
||||
|
||||
// BeginningOfYear beginning of year
|
||||
func BeginningOfYear() time.Time {
|
||||
return With(time.Now()).BeginningOfYear()
|
||||
}
|
||||
|
||||
// EndOfMinute end of minute
|
||||
func EndOfMinute() time.Time {
|
||||
return With(time.Now()).EndOfMinute()
|
||||
}
|
||||
|
||||
// EndOfHour end of hour
|
||||
func EndOfHour() time.Time {
|
||||
return With(time.Now()).EndOfHour()
|
||||
}
|
||||
|
||||
// EndOfDay end of day
|
||||
func EndOfDay() time.Time {
|
||||
return With(time.Now()).EndOfDay()
|
||||
}
|
||||
|
||||
// EndOfWeek end of week
|
||||
func EndOfWeek() time.Time {
|
||||
return With(time.Now()).EndOfWeek()
|
||||
}
|
||||
|
||||
// EndOfMonth end of month
|
||||
func EndOfMonth() time.Time {
|
||||
return With(time.Now()).EndOfMonth()
|
||||
}
|
||||
|
||||
// EndOfQuarter end of quarter
|
||||
func EndOfQuarter() time.Time {
|
||||
return With(time.Now()).EndOfQuarter()
|
||||
}
|
||||
|
||||
// EndOfYear end of year
|
||||
func EndOfYear() time.Time {
|
||||
return With(time.Now()).EndOfYear()
|
||||
}
|
||||
|
||||
// Monday monday
|
||||
|
||||
func Monday(strs ...string) time.Time {
|
||||
return With(time.Now()).Monday(strs...)
|
||||
}
|
||||
|
||||
// Sunday sunday
|
||||
func Sunday(strs ...string) time.Time {
|
||||
return With(time.Now()).Sunday(strs...)
|
||||
}
|
||||
|
||||
// EndOfSunday end of sunday
|
||||
func EndOfSunday() time.Time {
|
||||
return With(time.Now()).EndOfSunday()
|
||||
}
|
||||
|
||||
// Quarter returns the yearly quarter
|
||||
func Quarter() uint {
|
||||
return With(time.Now()).Quarter()
|
||||
}
|
||||
|
||||
// Parse parse string to time
|
||||
func Parse(strs ...string) (time.Time, error) {
|
||||
return With(time.Now()).Parse(strs...)
|
||||
}
|
||||
|
||||
// ParseInLocation parse string to time in location
|
||||
func ParseInLocation(loc *time.Location, strs ...string) (time.Time, error) {
|
||||
return With(time.Now().In(loc)).Parse(strs...)
|
||||
}
|
||||
|
||||
// MustParse must parse string to time or will panic
|
||||
func MustParse(strs ...string) time.Time {
|
||||
return With(time.Now()).MustParse(strs...)
|
||||
}
|
||||
|
||||
// MustParseInLocation must parse string to time in location or will panic
|
||||
func MustParseInLocation(loc *time.Location, strs ...string) time.Time {
|
||||
return With(time.Now().In(loc)).MustParse(strs...)
|
||||
}
|
||||
|
||||
// Between check now between the begin, end time or not
|
||||
func Between(time1, time2 string) bool {
|
||||
return With(time.Now()).Between(time1, time2)
|
||||
}
|
||||
|
|
@ -0,0 +1,245 @@
|
|||
package now
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BeginningOfMinute beginning of minute
|
||||
func (now *Now) BeginningOfMinute() time.Time {
|
||||
return now.Truncate(time.Minute)
|
||||
}
|
||||
|
||||
// BeginningOfHour beginning of hour
|
||||
func (now *Now) BeginningOfHour() time.Time {
|
||||
y, m, d := now.Date()
|
||||
return time.Date(y, m, d, now.Time.Hour(), 0, 0, 0, now.Time.Location())
|
||||
}
|
||||
|
||||
// BeginningOfDay beginning of day
|
||||
func (now *Now) BeginningOfDay() time.Time {
|
||||
y, m, d := now.Date()
|
||||
return time.Date(y, m, d, 0, 0, 0, 0, now.Time.Location())
|
||||
}
|
||||
|
||||
// BeginningOfWeek beginning of week
|
||||
func (now *Now) BeginningOfWeek() time.Time {
|
||||
t := now.BeginningOfDay()
|
||||
weekday := int(t.Weekday())
|
||||
|
||||
if now.WeekStartDay != time.Sunday {
|
||||
weekStartDayInt := int(now.WeekStartDay)
|
||||
|
||||
if weekday < weekStartDayInt {
|
||||
weekday = weekday + 7 - weekStartDayInt
|
||||
} else {
|
||||
weekday = weekday - weekStartDayInt
|
||||
}
|
||||
}
|
||||
return t.AddDate(0, 0, -weekday)
|
||||
}
|
||||
|
||||
// BeginningOfMonth beginning of month
|
||||
func (now *Now) BeginningOfMonth() time.Time {
|
||||
y, m, _ := now.Date()
|
||||
return time.Date(y, m, 1, 0, 0, 0, 0, now.Location())
|
||||
}
|
||||
|
||||
// BeginningOfQuarter beginning of quarter
|
||||
func (now *Now) BeginningOfQuarter() time.Time {
|
||||
month := now.BeginningOfMonth()
|
||||
offset := (int(month.Month()) - 1) % 3
|
||||
return month.AddDate(0, -offset, 0)
|
||||
}
|
||||
|
||||
// BeginningOfHalf beginning of half year
|
||||
func (now *Now) BeginningOfHalf() time.Time {
|
||||
month := now.BeginningOfMonth()
|
||||
offset := (int(month.Month()) - 1) % 6
|
||||
return month.AddDate(0, -offset, 0)
|
||||
}
|
||||
|
||||
// BeginningOfYear BeginningOfYear beginning of year
|
||||
func (now *Now) BeginningOfYear() time.Time {
|
||||
y, _, _ := now.Date()
|
||||
return time.Date(y, time.January, 1, 0, 0, 0, 0, now.Location())
|
||||
}
|
||||
|
||||
// EndOfMinute end of minute
|
||||
func (now *Now) EndOfMinute() time.Time {
|
||||
return now.BeginningOfMinute().Add(time.Minute - time.Nanosecond)
|
||||
}
|
||||
|
||||
// EndOfHour end of hour
|
||||
func (now *Now) EndOfHour() time.Time {
|
||||
return now.BeginningOfHour().Add(time.Hour - time.Nanosecond)
|
||||
}
|
||||
|
||||
// EndOfDay end of day
|
||||
func (now *Now) EndOfDay() time.Time {
|
||||
y, m, d := now.Date()
|
||||
return time.Date(y, m, d, 23, 59, 59, int(time.Second-time.Nanosecond), now.Location())
|
||||
}
|
||||
|
||||
// EndOfWeek end of week
|
||||
func (now *Now) EndOfWeek() time.Time {
|
||||
return now.BeginningOfWeek().AddDate(0, 0, 7).Add(-time.Nanosecond)
|
||||
}
|
||||
|
||||
// EndOfMonth end of month
|
||||
func (now *Now) EndOfMonth() time.Time {
|
||||
return now.BeginningOfMonth().AddDate(0, 1, 0).Add(-time.Nanosecond)
|
||||
}
|
||||
|
||||
// EndOfQuarter end of quarter
|
||||
func (now *Now) EndOfQuarter() time.Time {
|
||||
return now.BeginningOfQuarter().AddDate(0, 3, 0).Add(-time.Nanosecond)
|
||||
}
|
||||
|
||||
// EndOfHalf end of half year
|
||||
func (now *Now) EndOfHalf() time.Time {
|
||||
return now.BeginningOfHalf().AddDate(0, 6, 0).Add(-time.Nanosecond)
|
||||
}
|
||||
|
||||
// EndOfYear end of year
|
||||
func (now *Now) EndOfYear() time.Time {
|
||||
return now.BeginningOfYear().AddDate(1, 0, 0).Add(-time.Nanosecond)
|
||||
}
|
||||
|
||||
// Monday monday
|
||||
/*
|
||||
func (now *Now) Monday() time.Time {
|
||||
t := now.BeginningOfDay()
|
||||
weekday := int(t.Weekday())
|
||||
if weekday == 0 {
|
||||
weekday = 7
|
||||
}
|
||||
return t.AddDate(0, 0, -weekday+1)
|
||||
}
|
||||
*/
|
||||
|
||||
func (now *Now) Monday(strs ...string) time.Time {
|
||||
var parseTime time.Time
|
||||
var err error
|
||||
if len(strs) > 0 {
|
||||
parseTime, err = now.Parse(strs...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
} else {
|
||||
parseTime = now.BeginningOfDay()
|
||||
}
|
||||
weekday := int(parseTime.Weekday())
|
||||
if weekday == 0 {
|
||||
weekday = 7
|
||||
}
|
||||
return parseTime.AddDate(0, 0, -weekday+1)
|
||||
}
|
||||
|
||||
func (now *Now) Sunday(strs ...string) time.Time {
|
||||
var parseTime time.Time
|
||||
var err error
|
||||
if len(strs) > 0 {
|
||||
parseTime, err = now.Parse(strs...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
} else {
|
||||
parseTime = now.BeginningOfDay()
|
||||
}
|
||||
weekday := int(parseTime.Weekday())
|
||||
if weekday == 0 {
|
||||
weekday = 7
|
||||
}
|
||||
return parseTime.AddDate(0, 0, (7 - weekday))
|
||||
}
|
||||
|
||||
// EndOfSunday end of sunday
|
||||
func (now *Now) EndOfSunday() time.Time {
|
||||
return New(now.Sunday()).EndOfDay()
|
||||
}
|
||||
|
||||
// Quarter returns the yearly quarter
|
||||
func (now *Now) Quarter() uint {
|
||||
return (uint(now.Month())-1)/3 + 1
|
||||
}
|
||||
|
||||
func (now *Now) parseWithFormat(str string, location *time.Location) (t time.Time, err error) {
|
||||
for _, format := range now.TimeFormats {
|
||||
t, err = time.ParseInLocation(format, str, location)
|
||||
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
err = errors.New("Can't parse string as time: " + str)
|
||||
return
|
||||
}
|
||||
|
||||
var hasTimeRegexp = regexp.MustCompile(`(\s+|^\s*|T)\d{1,2}((:\d{1,2})*|((:\d{1,2}){2}\.(\d{3}|\d{6}|\d{9})))(\s*$|[Z+-])`) // match 15:04:05, 15:04:05.000, 15:04:05.000000 15, 2017-01-01 15:04, 2021-07-20T00:59:10Z, 2021-07-20T00:59:10+08:00, 2021-07-20T00:00:10-07:00 etc
|
||||
var onlyTimeRegexp = regexp.MustCompile(`^\s*\d{1,2}((:\d{1,2})*|((:\d{1,2}){2}\.(\d{3}|\d{6}|\d{9})))\s*$`) // match 15:04:05, 15, 15:04:05.000, 15:04:05.000000, etc
|
||||
|
||||
// Parse parse string to time
|
||||
func (now *Now) Parse(strs ...string) (t time.Time, err error) {
|
||||
var (
|
||||
setCurrentTime bool
|
||||
parseTime []int
|
||||
currentLocation = now.Location()
|
||||
onlyTimeInStr = true
|
||||
currentTime = formatTimeToList(now.Time)
|
||||
)
|
||||
|
||||
for _, str := range strs {
|
||||
hasTimeInStr := hasTimeRegexp.MatchString(str) // match 15:04:05, 15
|
||||
onlyTimeInStr = hasTimeInStr && onlyTimeInStr && onlyTimeRegexp.MatchString(str)
|
||||
if t, err = now.parseWithFormat(str, currentLocation); err == nil {
|
||||
location := t.Location()
|
||||
parseTime = formatTimeToList(t)
|
||||
|
||||
for i, v := range parseTime {
|
||||
// Don't reset hour, minute, second if current time str including time
|
||||
if hasTimeInStr && i <= 3 {
|
||||
continue
|
||||
}
|
||||
|
||||
// If value is zero, replace it with current time
|
||||
if v == 0 {
|
||||
if setCurrentTime {
|
||||
parseTime[i] = currentTime[i]
|
||||
}
|
||||
} else {
|
||||
setCurrentTime = true
|
||||
}
|
||||
|
||||
// if current time only includes time, should change day, month to current time
|
||||
if onlyTimeInStr {
|
||||
if i == 4 || i == 5 {
|
||||
parseTime[i] = currentTime[i]
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t = time.Date(parseTime[6], time.Month(parseTime[5]), parseTime[4], parseTime[3], parseTime[2], parseTime[1], parseTime[0], location)
|
||||
currentTime = formatTimeToList(t)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MustParse must parse string to time or it will panic
|
||||
func (now *Now) MustParse(strs ...string) (t time.Time) {
|
||||
t, err := now.Parse(strs...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// Between check time between the begin, end time or not
|
||||
func (now *Now) Between(begin, end string) bool {
|
||||
beginTime := now.MustParse(begin)
|
||||
endTime := now.MustParse(end)
|
||||
return now.After(beginTime) && now.Before(endTime)
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
package now
|
||||
|
||||
import "time"
|
||||
|
||||
func formatTimeToList(t time.Time) []int {
|
||||
hour, min, sec := t.Clock()
|
||||
year, month, day := t.Date()
|
||||
return []int{t.Nanosecond(), sec, min, hour, day, int(month), year}
|
||||
}
|
||||
|
|
@ -0,0 +1 @@
|
|||
.DS_Store
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
Copyright (c) 2013 John Barton
|
||||
|
||||
MIT License
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining
|
||||
a copy of this software and associated documentation files (the
|
||||
"Software"), to deal in the Software without restriction, including
|
||||
without limitation the rights to use, copy, modify, merge, publish,
|
||||
distribute, sublicense, and/or sell copies of the Software, and to
|
||||
permit persons to whom the Software is furnished to do so, subject to
|
||||
the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be
|
||||
included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
|
|
@ -0,0 +1,188 @@
|
|||
# GoDotEnv  [](https://goreportcard.com/report/github.com/joho/godotenv)
|
||||
|
||||
A Go (golang) port of the Ruby dotenv project (which loads env vars from a .env file)
|
||||
|
||||
From the original Library:
|
||||
|
||||
> Storing configuration in the environment is one of the tenets of a twelve-factor app. Anything that is likely to change between deployment environments–such as resource handles for databases or credentials for external services–should be extracted from the code into environment variables.
|
||||
>
|
||||
> But it is not always practical to set environment variables on development machines or continuous integration servers where multiple projects are run. Dotenv load variables from a .env file into ENV when the environment is bootstrapped.
|
||||
|
||||
It can be used as a library (for loading in env for your own daemons etc) or as a bin command.
|
||||
|
||||
There is test coverage and CI for both linuxish and windows environments, but I make no guarantees about the bin version working on windows.
|
||||
|
||||
## Installation
|
||||
|
||||
As a library
|
||||
|
||||
```shell
|
||||
go get github.com/joho/godotenv
|
||||
```
|
||||
|
||||
or if you want to use it as a bin command
|
||||
```shell
|
||||
go get github.com/joho/godotenv/cmd/godotenv
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
Add your application configuration to your `.env` file in the root of your project:
|
||||
|
||||
```shell
|
||||
S3_BUCKET=YOURS3BUCKET
|
||||
SECRET_KEY=YOURSECRETKEYGOESHERE
|
||||
```
|
||||
|
||||
Then in your Go app you can do something like
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/joho/godotenv"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
err := godotenv.Load()
|
||||
if err != nil {
|
||||
log.Fatal("Error loading .env file")
|
||||
}
|
||||
|
||||
s3Bucket := os.Getenv("S3_BUCKET")
|
||||
secretKey := os.Getenv("SECRET_KEY")
|
||||
|
||||
// now do something with s3 or whatever
|
||||
}
|
||||
```
|
||||
|
||||
If you're even lazier than that, you can just take advantage of the autoload package which will read in `.env` on import
|
||||
|
||||
```go
|
||||
import _ "github.com/joho/godotenv/autoload"
|
||||
```
|
||||
|
||||
While `.env` in the project root is the default, you don't have to be constrained, both examples below are 100% legit
|
||||
|
||||
```go
|
||||
_ = godotenv.Load("somerandomfile")
|
||||
_ = godotenv.Load("filenumberone.env", "filenumbertwo.env")
|
||||
```
|
||||
|
||||
If you want to be really fancy with your env file you can do comments and exports (below is a valid env file)
|
||||
|
||||
```shell
|
||||
# I am a comment and that is OK
|
||||
SOME_VAR=someval
|
||||
FOO=BAR # comments at line end are OK too
|
||||
export BAR=BAZ
|
||||
```
|
||||
|
||||
Or finally you can do YAML(ish) style
|
||||
|
||||
```yaml
|
||||
FOO: bar
|
||||
BAR: baz
|
||||
```
|
||||
|
||||
as a final aside, if you don't want godotenv munging your env you can just get a map back instead
|
||||
|
||||
```go
|
||||
var myEnv map[string]string
|
||||
myEnv, err := godotenv.Read()
|
||||
|
||||
s3Bucket := myEnv["S3_BUCKET"]
|
||||
```
|
||||
|
||||
... or from an `io.Reader` instead of a local file
|
||||
|
||||
```go
|
||||
reader := getRemoteFile()
|
||||
myEnv, err := godotenv.Parse(reader)
|
||||
```
|
||||
|
||||
... or from a `string` if you so desire
|
||||
|
||||
```go
|
||||
content := getRemoteFileContent()
|
||||
myEnv, err := godotenv.Unmarshal(content)
|
||||
```
|
||||
|
||||
### Precedence & Conventions
|
||||
|
||||
Existing envs take precedence of envs that are loaded later.
|
||||
|
||||
The [convention](https://github.com/bkeepers/dotenv#what-other-env-files-can-i-use)
|
||||
for managing multiple environments (i.e. development, test, production)
|
||||
is to create an env named `{YOURAPP}_ENV` and load envs in this order:
|
||||
|
||||
```go
|
||||
env := os.Getenv("FOO_ENV")
|
||||
if "" == env {
|
||||
env = "development"
|
||||
}
|
||||
|
||||
godotenv.Load(".env." + env + ".local")
|
||||
if "test" != env {
|
||||
godotenv.Load(".env.local")
|
||||
}
|
||||
godotenv.Load(".env." + env)
|
||||
godotenv.Load() // The Original .env
|
||||
```
|
||||
|
||||
If you need to, you can also use `godotenv.Overload()` to defy this convention
|
||||
and overwrite existing envs instead of only supplanting them. Use with caution.
|
||||
|
||||
### Command Mode
|
||||
|
||||
Assuming you've installed the command as above and you've got `$GOPATH/bin` in your `$PATH`
|
||||
|
||||
```
|
||||
godotenv -f /some/path/to/.env some_command with some args
|
||||
```
|
||||
|
||||
If you don't specify `-f` it will fall back on the default of loading `.env` in `PWD`
|
||||
|
||||
### Writing Env Files
|
||||
|
||||
Godotenv can also write a map representing the environment to a correctly-formatted and escaped file
|
||||
|
||||
```go
|
||||
env, err := godotenv.Unmarshal("KEY=value")
|
||||
err := godotenv.Write(env, "./.env")
|
||||
```
|
||||
|
||||
... or to a string
|
||||
|
||||
```go
|
||||
env, err := godotenv.Unmarshal("KEY=value")
|
||||
content, err := godotenv.Marshal(env)
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are most welcome! The parser itself is pretty stupidly naive and I wouldn't be surprised if it breaks with edge cases.
|
||||
|
||||
*code changes without tests will not be accepted*
|
||||
|
||||
1. Fork it
|
||||
2. Create your feature branch (`git checkout -b my-new-feature`)
|
||||
3. Commit your changes (`git commit -am 'Added some feature'`)
|
||||
4. Push to the branch (`git push origin my-new-feature`)
|
||||
5. Create new Pull Request
|
||||
|
||||
## Releases
|
||||
|
||||
Releases should follow [Semver](http://semver.org/) though the first couple of releases are `v1` and `v1.1`.
|
||||
|
||||
Use [annotated tags for all releases](https://github.com/joho/godotenv/issues/30). Example `git tag -a v1.2.1`
|
||||
|
||||
## CI
|
||||
|
||||
Linux: [](https://travis-ci.org/joho/godotenv) Windows: [](https://ci.appveyor.com/project/joho/godotenv)
|
||||
|
||||
## Who?
|
||||
|
||||
The original library [dotenv](https://github.com/bkeepers/dotenv) was written by [Brandon Keepers](http://opensoul.org/), and this port was done by [John Barton](https://johnbarton.co/) based off the tests/fixtures in the original library.
|
||||
|
|
@ -0,0 +1,363 @@
|
|||
// Package godotenv is a go port of the ruby dotenv library (https://github.com/bkeepers/dotenv)
|
||||
//
|
||||
// Examples/readme can be found on the github page at https://github.com/joho/godotenv
|
||||
//
|
||||
// The TL;DR is that you make a .env file that looks something like
|
||||
//
|
||||
// SOME_ENV_VAR=somevalue
|
||||
//
|
||||
// and then in your go code you can call
|
||||
//
|
||||
// godotenv.Load()
|
||||
//
|
||||
// and all the env vars declared in .env will be available through os.Getenv("SOME_ENV_VAR")
|
||||
package godotenv
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const doubleQuoteSpecialChars = "\\\n\r\"!$`"
|
||||
|
||||
// Load will read your env file(s) and load them into ENV for this process.
|
||||
//
|
||||
// Call this function as close as possible to the start of your program (ideally in main)
|
||||
//
|
||||
// If you call Load without any args it will default to loading .env in the current path
|
||||
//
|
||||
// You can otherwise tell it which files to load (there can be more than one) like
|
||||
//
|
||||
// godotenv.Load("fileone", "filetwo")
|
||||
//
|
||||
// It's important to note that it WILL NOT OVERRIDE an env variable that already exists - consider the .env file to set dev vars or sensible defaults
|
||||
func Load(filenames ...string) (err error) {
|
||||
filenames = filenamesOrDefault(filenames)
|
||||
|
||||
for _, filename := range filenames {
|
||||
err = loadFile(filename, false)
|
||||
if err != nil {
|
||||
return // return early on a spazout
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Overload will read your env file(s) and load them into ENV for this process.
|
||||
//
|
||||
// Call this function as close as possible to the start of your program (ideally in main)
|
||||
//
|
||||
// If you call Overload without any args it will default to loading .env in the current path
|
||||
//
|
||||
// You can otherwise tell it which files to load (there can be more than one) like
|
||||
//
|
||||
// godotenv.Overload("fileone", "filetwo")
|
||||
//
|
||||
// It's important to note this WILL OVERRIDE an env variable that already exists - consider the .env file to forcefilly set all vars.
|
||||
func Overload(filenames ...string) (err error) {
|
||||
filenames = filenamesOrDefault(filenames)
|
||||
|
||||
for _, filename := range filenames {
|
||||
err = loadFile(filename, true)
|
||||
if err != nil {
|
||||
return // return early on a spazout
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Read all env (with same file loading semantics as Load) but return values as
|
||||
// a map rather than automatically writing values into env
|
||||
func Read(filenames ...string) (envMap map[string]string, err error) {
|
||||
filenames = filenamesOrDefault(filenames)
|
||||
envMap = make(map[string]string)
|
||||
|
||||
for _, filename := range filenames {
|
||||
individualEnvMap, individualErr := readFile(filename)
|
||||
|
||||
if individualErr != nil {
|
||||
err = individualErr
|
||||
return // return early on a spazout
|
||||
}
|
||||
|
||||
for key, value := range individualEnvMap {
|
||||
envMap[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Parse reads an env file from io.Reader, returning a map of keys and values.
|
||||
func Parse(r io.Reader) (envMap map[string]string, err error) {
|
||||
envMap = make(map[string]string)
|
||||
|
||||
var lines []string
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
lines = append(lines, scanner.Text())
|
||||
}
|
||||
|
||||
if err = scanner.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, fullLine := range lines {
|
||||
if !isIgnoredLine(fullLine) {
|
||||
var key, value string
|
||||
key, value, err = parseLine(fullLine, envMap)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
envMap[key] = value
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//Unmarshal reads an env file from a string, returning a map of keys and values.
|
||||
func Unmarshal(str string) (envMap map[string]string, err error) {
|
||||
return Parse(strings.NewReader(str))
|
||||
}
|
||||
|
||||
// Exec loads env vars from the specified filenames (empty map falls back to default)
|
||||
// then executes the cmd specified.
|
||||
//
|
||||
// Simply hooks up os.Stdin/err/out to the command and calls Run()
|
||||
//
|
||||
// If you want more fine grained control over your command it's recommended
|
||||
// that you use `Load()` or `Read()` and the `os/exec` package yourself.
|
||||
func Exec(filenames []string, cmd string, cmdArgs []string) error {
|
||||
Load(filenames...)
|
||||
|
||||
command := exec.Command(cmd, cmdArgs...)
|
||||
command.Stdin = os.Stdin
|
||||
command.Stdout = os.Stdout
|
||||
command.Stderr = os.Stderr
|
||||
return command.Run()
|
||||
}
|
||||
|
||||
// Write serializes the given environment and writes it to a file
|
||||
func Write(envMap map[string]string, filename string) error {
|
||||
content, err := Marshal(envMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
file, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
_, err = file.WriteString(content + "\n")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
file.Sync()
|
||||
return err
|
||||
}
|
||||
|
||||
// Marshal outputs the given environment as a dotenv-formatted environment file.
|
||||
// Each line is in the format: KEY="VALUE" where VALUE is backslash-escaped.
|
||||
func Marshal(envMap map[string]string) (string, error) {
|
||||
lines := make([]string, 0, len(envMap))
|
||||
for k, v := range envMap {
|
||||
if d, err := strconv.Atoi(v); err == nil {
|
||||
lines = append(lines, fmt.Sprintf(`%s=%d`, k, d))
|
||||
} else {
|
||||
lines = append(lines, fmt.Sprintf(`%s="%s"`, k, doubleQuoteEscape(v)))
|
||||
}
|
||||
}
|
||||
sort.Strings(lines)
|
||||
return strings.Join(lines, "\n"), nil
|
||||
}
|
||||
|
||||
func filenamesOrDefault(filenames []string) []string {
|
||||
if len(filenames) == 0 {
|
||||
return []string{".env"}
|
||||
}
|
||||
return filenames
|
||||
}
|
||||
|
||||
func loadFile(filename string, overload bool) error {
|
||||
envMap, err := readFile(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
currentEnv := map[string]bool{}
|
||||
rawEnv := os.Environ()
|
||||
for _, rawEnvLine := range rawEnv {
|
||||
key := strings.Split(rawEnvLine, "=")[0]
|
||||
currentEnv[key] = true
|
||||
}
|
||||
|
||||
for key, value := range envMap {
|
||||
if !currentEnv[key] || overload {
|
||||
os.Setenv(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func readFile(filename string) (envMap map[string]string, err error) {
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
return Parse(file)
|
||||
}
|
||||
|
||||
var exportRegex = regexp.MustCompile(`^\s*(?:export\s+)?(.*?)\s*$`)
|
||||
|
||||
func parseLine(line string, envMap map[string]string) (key string, value string, err error) {
|
||||
if len(line) == 0 {
|
||||
err = errors.New("zero length string")
|
||||
return
|
||||
}
|
||||
|
||||
// ditch the comments (but keep quoted hashes)
|
||||
if strings.Contains(line, "#") {
|
||||
segmentsBetweenHashes := strings.Split(line, "#")
|
||||
quotesAreOpen := false
|
||||
var segmentsToKeep []string
|
||||
for _, segment := range segmentsBetweenHashes {
|
||||
if strings.Count(segment, "\"") == 1 || strings.Count(segment, "'") == 1 {
|
||||
if quotesAreOpen {
|
||||
quotesAreOpen = false
|
||||
segmentsToKeep = append(segmentsToKeep, segment)
|
||||
} else {
|
||||
quotesAreOpen = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(segmentsToKeep) == 0 || quotesAreOpen {
|
||||
segmentsToKeep = append(segmentsToKeep, segment)
|
||||
}
|
||||
}
|
||||
|
||||
line = strings.Join(segmentsToKeep, "#")
|
||||
}
|
||||
|
||||
firstEquals := strings.Index(line, "=")
|
||||
firstColon := strings.Index(line, ":")
|
||||
splitString := strings.SplitN(line, "=", 2)
|
||||
if firstColon != -1 && (firstColon < firstEquals || firstEquals == -1) {
|
||||
//this is a yaml-style line
|
||||
splitString = strings.SplitN(line, ":", 2)
|
||||
}
|
||||
|
||||
if len(splitString) != 2 {
|
||||
err = errors.New("Can't separate key from value")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the key
|
||||
key = splitString[0]
|
||||
if strings.HasPrefix(key, "export") {
|
||||
key = strings.TrimPrefix(key, "export")
|
||||
}
|
||||
key = strings.TrimSpace(key)
|
||||
|
||||
key = exportRegex.ReplaceAllString(splitString[0], "$1")
|
||||
|
||||
// Parse the value
|
||||
value = parseValue(splitString[1], envMap)
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
singleQuotesRegex = regexp.MustCompile(`\A'(.*)'\z`)
|
||||
doubleQuotesRegex = regexp.MustCompile(`\A"(.*)"\z`)
|
||||
escapeRegex = regexp.MustCompile(`\\.`)
|
||||
unescapeCharsRegex = regexp.MustCompile(`\\([^$])`)
|
||||
)
|
||||
|
||||
func parseValue(value string, envMap map[string]string) string {
|
||||
|
||||
// trim
|
||||
value = strings.Trim(value, " ")
|
||||
|
||||
// check if we've got quoted values or possible escapes
|
||||
if len(value) > 1 {
|
||||
singleQuotes := singleQuotesRegex.FindStringSubmatch(value)
|
||||
|
||||
doubleQuotes := doubleQuotesRegex.FindStringSubmatch(value)
|
||||
|
||||
if singleQuotes != nil || doubleQuotes != nil {
|
||||
// pull the quotes off the edges
|
||||
value = value[1 : len(value)-1]
|
||||
}
|
||||
|
||||
if doubleQuotes != nil {
|
||||
// expand newlines
|
||||
value = escapeRegex.ReplaceAllStringFunc(value, func(match string) string {
|
||||
c := strings.TrimPrefix(match, `\`)
|
||||
switch c {
|
||||
case "n":
|
||||
return "\n"
|
||||
case "r":
|
||||
return "\r"
|
||||
default:
|
||||
return match
|
||||
}
|
||||
})
|
||||
// unescape characters
|
||||
value = unescapeCharsRegex.ReplaceAllString(value, "$1")
|
||||
}
|
||||
|
||||
if singleQuotes == nil {
|
||||
value = expandVariables(value, envMap)
|
||||
}
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
var expandVarRegex = regexp.MustCompile(`(\\)?(\$)(\()?\{?([A-Z0-9_]+)?\}?`)
|
||||
|
||||
func expandVariables(v string, m map[string]string) string {
|
||||
return expandVarRegex.ReplaceAllStringFunc(v, func(s string) string {
|
||||
submatch := expandVarRegex.FindStringSubmatch(s)
|
||||
|
||||
if submatch == nil {
|
||||
return s
|
||||
}
|
||||
if submatch[1] == "\\" || submatch[2] == "(" {
|
||||
return submatch[0][1:]
|
||||
} else if submatch[4] != "" {
|
||||
return m[submatch[4]]
|
||||
}
|
||||
return s
|
||||
})
|
||||
}
|
||||
|
||||
func isIgnoredLine(line string) bool {
|
||||
trimmedLine := strings.TrimSpace(line)
|
||||
return len(trimmedLine) == 0 || strings.HasPrefix(trimmedLine, "#")
|
||||
}
|
||||
|
||||
func doubleQuoteEscape(line string) string {
|
||||
for _, c := range doubleQuoteSpecialChars {
|
||||
toReplace := "\\" + string(c)
|
||||
if c == '\n' {
|
||||
toReplace = `\n`
|
||||
}
|
||||
if c == '\r' {
|
||||
toReplace = `\r`
|
||||
}
|
||||
line = strings.Replace(line, string(c), toReplace, -1)
|
||||
}
|
||||
return line
|
||||
}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
{
|
||||
"extends": [
|
||||
"config:base"
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
TODO*
|
||||
documents
|
||||
coverage.txt
|
||||
_book
|
||||
.idea
|
||||
vendor
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
# GORM MySQL Driver
|
||||
|
||||
## Quick Start
|
||||
|
||||
```go
|
||||
import (
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// https://github.com/go-sql-driver/mysql
|
||||
dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
|
||||
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
```go
|
||||
import (
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var datetimePrecision = 2
|
||||
|
||||
db, err := gorm.Open(mysql.New(mysql.Config{
|
||||
DSN: "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local", // data source name, refer https://github.com/go-sql-driver/mysql#dsn-data-source-name
|
||||
DefaultStringSize: 256, // add default size for string fields, by default, will use db type `longtext` for fields without size, not a primary key, no index defined and don't have default values
|
||||
DisableDatetimePrecision: true, // disable datetime precision support, which not supported before MySQL 5.6
|
||||
DefaultDatetimePrecision: &datetimePrecision, // default datetime precision
|
||||
DontSupportRenameIndex: true, // drop & create index when rename index, rename index not supported before MySQL 5.7, MariaDB
|
||||
DontSupportRenameColumn: true, // use change when rename column, rename rename not supported before MySQL 8, MariaDB
|
||||
SkipInitializeWithVersion: false, // smart configure based on used version
|
||||
}), &gorm.Config{})
|
||||
```
|
||||
|
||||
## Customized Driver
|
||||
|
||||
```go
|
||||
import (
|
||||
_ "example.com/my_mysql_driver"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
db, err := gorm.Open(mysql.New(mysql.Config{
|
||||
DriverName: "my_mysql_driver_name",
|
||||
DSN: "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local", // data source name, refer https://github.com/go-sql-driver/mysql#dsn-data-source-name
|
||||
})
|
||||
```
|
||||
|
||||
Checkout [https://gorm.io](https://gorm.io) for details.
|
||||
|
|
@ -0,0 +1,322 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
const indexSql = `
|
||||
SELECT
|
||||
TABLE_NAME,
|
||||
COLUMN_NAME,
|
||||
INDEX_NAME,
|
||||
NON_UNIQUE
|
||||
FROM
|
||||
information_schema.STATISTICS
|
||||
WHERE
|
||||
TABLE_SCHEMA = ?
|
||||
AND TABLE_NAME = ?
|
||||
ORDER BY
|
||||
INDEX_NAME,
|
||||
SEQ_IN_INDEX`
|
||||
|
||||
type Migrator struct {
|
||||
migrator.Migrator
|
||||
Dialector
|
||||
}
|
||||
|
||||
func (m Migrator) FullDataTypeOf(field *schema.Field) clause.Expr {
|
||||
expr := m.Migrator.FullDataTypeOf(field)
|
||||
|
||||
if value, ok := field.TagSettings["COMMENT"]; ok {
|
||||
expr.SQL += " COMMENT " + m.Dialector.Explain("?", value)
|
||||
}
|
||||
|
||||
return expr
|
||||
}
|
||||
|
||||
func (m Migrator) AlterColumn(value interface{}, field string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? MODIFY COLUMN ? ?",
|
||||
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field),
|
||||
).Error
|
||||
}
|
||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if !m.Dialector.DontSupportRenameColumn {
|
||||
return m.Migrator.RenameColumn(value, oldName, newName)
|
||||
}
|
||||
|
||||
var field *schema.Field
|
||||
if f := stmt.Schema.LookUpField(oldName); f != nil {
|
||||
oldName = f.DBName
|
||||
field = f
|
||||
}
|
||||
|
||||
if f := stmt.Schema.LookUpField(newName); f != nil {
|
||||
newName = f.DBName
|
||||
field = f
|
||||
}
|
||||
|
||||
if field != nil {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? CHANGE ? ? ?",
|
||||
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName},
|
||||
clause.Column{Name: newName}, m.FullDataTypeOf(field),
|
||||
).Error
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to look up field with name: %s", newName)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||
if !m.Dialector.DontSupportRenameIndex {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? RENAME INDEX ? TO ?",
|
||||
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
|
||||
).Error
|
||||
})
|
||||
}
|
||||
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
err := m.DropIndex(value, oldName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if idx := stmt.Schema.LookIndex(newName); idx == nil {
|
||||
if idx = stmt.Schema.LookIndex(oldName); idx != nil {
|
||||
opts := m.BuildIndexOptions(idx.Fields, stmt)
|
||||
values := []interface{}{clause.Column{Name: newName}, clause.Table{Name: stmt.Table}, opts}
|
||||
|
||||
createIndexSQL := "CREATE "
|
||||
if idx.Class != "" {
|
||||
createIndexSQL += idx.Class + " "
|
||||
}
|
||||
createIndexSQL += "INDEX ? ON ??"
|
||||
|
||||
if idx.Type != "" {
|
||||
createIndexSQL += " USING " + idx.Type
|
||||
}
|
||||
|
||||
return m.DB.Exec(createIndexSQL, values...).Error
|
||||
}
|
||||
}
|
||||
|
||||
return m.CreateIndex(value, newName)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (m Migrator) DropTable(values ...interface{}) error {
|
||||
values = m.ReorderModels(values, false)
|
||||
return m.DB.Connection(func(tx *gorm.DB) error {
|
||||
tx.Exec("SET FOREIGN_KEY_CHECKS = 0;")
|
||||
for i := len(values) - 1; i >= 0; i-- {
|
||||
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
|
||||
return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Exec("SET FOREIGN_KEY_CHECKS = 1;").Error
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
if chk != nil {
|
||||
return m.DB.Exec("ALTER TABLE ? DROP CHECK ?", clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}).Error
|
||||
}
|
||||
if constraint != nil {
|
||||
name = constraint.Name
|
||||
}
|
||||
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? DROP FOREIGN KEY ?", clause.Table{Name: table}, clause.Column{Name: name},
|
||||
).Error
|
||||
})
|
||||
}
|
||||
|
||||
// ColumnTypes column types return columnTypes,error
|
||||
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
columnTypes := make([]gorm.ColumnType, 0)
|
||||
err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
var (
|
||||
currentDatabase, table = m.CurrentSchema(stmt, stmt.Table)
|
||||
columnTypeSQL = "SELECT column_name, column_default, is_nullable = 'YES', data_type, character_maximum_length, column_type, column_key, extra, column_comment, numeric_precision, numeric_scale "
|
||||
rows, err = m.DB.Session(&gorm.Session{}).Table(table).Limit(1).Rows()
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rawColumnTypes, err := rows.ColumnTypes()
|
||||
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !m.DisableDatetimePrecision {
|
||||
columnTypeSQL += ", datetime_precision "
|
||||
}
|
||||
columnTypeSQL += "FROM information_schema.columns WHERE table_schema = ? AND table_name = ? ORDER BY ORDINAL_POSITION"
|
||||
|
||||
columns, rowErr := m.DB.Raw(columnTypeSQL, currentDatabase, table).Rows()
|
||||
if rowErr != nil {
|
||||
return rowErr
|
||||
}
|
||||
|
||||
defer columns.Close()
|
||||
|
||||
for columns.Next() {
|
||||
var (
|
||||
column migrator.ColumnType
|
||||
datetimePrecision sql.NullInt64
|
||||
extraValue sql.NullString
|
||||
columnKey sql.NullString
|
||||
values = []interface{}{
|
||||
&column.NameValue, &column.DefaultValueValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.ColumnTypeValue, &columnKey, &extraValue, &column.CommentValue, &column.DecimalSizeValue, &column.ScaleValue,
|
||||
}
|
||||
)
|
||||
|
||||
if !m.DisableDatetimePrecision {
|
||||
values = append(values, &datetimePrecision)
|
||||
}
|
||||
|
||||
if scanErr := columns.Scan(values...); scanErr != nil {
|
||||
return scanErr
|
||||
}
|
||||
|
||||
column.PrimaryKeyValue = sql.NullBool{Bool: false, Valid: true}
|
||||
column.UniqueValue = sql.NullBool{Bool: false, Valid: true}
|
||||
switch columnKey.String {
|
||||
case "PRI":
|
||||
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
|
||||
case "UNI":
|
||||
column.UniqueValue = sql.NullBool{Bool: true, Valid: true}
|
||||
}
|
||||
|
||||
if strings.Contains(extraValue.String, "auto_increment") {
|
||||
column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true}
|
||||
}
|
||||
|
||||
column.DefaultValueValue.String = strings.Trim(column.DefaultValueValue.String, "'")
|
||||
if m.Dialector.DontSupportNullAsDefaultValue {
|
||||
// rewrite mariadb default value like other version
|
||||
if column.DefaultValueValue.Valid && column.DefaultValueValue.String == "NULL" {
|
||||
column.DefaultValueValue.Valid = false
|
||||
column.DefaultValueValue.String = ""
|
||||
}
|
||||
}
|
||||
|
||||
if datetimePrecision.Valid {
|
||||
column.DecimalSizeValue = datetimePrecision
|
||||
}
|
||||
|
||||
for _, c := range rawColumnTypes {
|
||||
if c.Name() == column.NameValue.String {
|
||||
column.SQLColumnType = c
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
columnTypes = append(columnTypes, column)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return columnTypes, err
|
||||
}
|
||||
|
||||
func (m Migrator) CurrentDatabase() (name string) {
|
||||
baseName := m.Migrator.CurrentDatabase()
|
||||
m.DB.Raw(
|
||||
"SELECT SCHEMA_NAME from Information_schema.SCHEMATA where SCHEMA_NAME LIKE ? ORDER BY SCHEMA_NAME=? DESC,SCHEMA_NAME limit 1",
|
||||
baseName+"%", baseName).Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) GetTables() (tableList []string, err error) {
|
||||
err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
|
||||
Scan(&tableList).Error
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
|
||||
indexes := make([]gorm.Index, 0)
|
||||
err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
|
||||
result := make([]*Index, 0)
|
||||
schema, table := m.CurrentSchema(stmt, stmt.Table)
|
||||
scanErr := m.DB.Raw(indexSql, schema, table).Scan(&result).Error
|
||||
if scanErr != nil {
|
||||
return scanErr
|
||||
}
|
||||
indexMap := groupByIndexName(result)
|
||||
|
||||
for _, idx := range indexMap {
|
||||
tempIdx := &migrator.Index{
|
||||
TableName: idx[0].TableName,
|
||||
NameValue: idx[0].IndexName,
|
||||
PrimaryKeyValue: sql.NullBool{
|
||||
Bool: idx[0].IndexName == "PRIMARY",
|
||||
Valid: true,
|
||||
},
|
||||
UniqueValue: sql.NullBool{
|
||||
Bool: idx[0].NonUnique == 0,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
for _, x := range idx {
|
||||
tempIdx.ColumnList = append(tempIdx.ColumnList, x.ColumnName)
|
||||
}
|
||||
indexes = append(indexes, tempIdx)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return indexes, err
|
||||
}
|
||||
|
||||
// Index table index info
|
||||
type Index struct {
|
||||
TableName string `gorm:"column:TABLE_NAME"`
|
||||
ColumnName string `gorm:"column:COLUMN_NAME"`
|
||||
IndexName string `gorm:"column:INDEX_NAME"`
|
||||
NonUnique int32 `gorm:"column:NON_UNIQUE"`
|
||||
}
|
||||
|
||||
func groupByIndexName(indexList []*Index) map[string][]*Index {
|
||||
columnIndexMap := make(map[string][]*Index, len(indexList))
|
||||
for _, idx := range indexList {
|
||||
columnIndexMap[idx.IndexName] = append(columnIndexMap[idx.IndexName], idx)
|
||||
}
|
||||
return columnIndexMap
|
||||
}
|
||||
|
||||
func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (string, string) {
|
||||
if strings.Contains(table, ".") {
|
||||
if tables := strings.Split(table, `.`); len(tables) == 2 {
|
||||
return tables[0], tables[1]
|
||||
}
|
||||
}
|
||||
|
||||
return m.CurrentDatabase(), table
|
||||
}
|
||||
|
|
@ -0,0 +1,407 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
DriverName string
|
||||
ServerVersion string
|
||||
DSN string
|
||||
Conn gorm.ConnPool
|
||||
SkipInitializeWithVersion bool
|
||||
DefaultStringSize uint
|
||||
DefaultDatetimePrecision *int
|
||||
DisableDatetimePrecision bool
|
||||
DontSupportRenameIndex bool
|
||||
DontSupportRenameColumn bool
|
||||
DontSupportForShareClause bool
|
||||
DontSupportNullAsDefaultValue bool
|
||||
}
|
||||
|
||||
type Dialector struct {
|
||||
*Config
|
||||
}
|
||||
|
||||
var (
|
||||
// CreateClauses create clauses
|
||||
CreateClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
|
||||
// QueryClauses query clauses
|
||||
QueryClauses = []string{}
|
||||
// UpdateClauses update clauses
|
||||
UpdateClauses = []string{"UPDATE", "SET", "WHERE", "ORDER BY", "LIMIT"}
|
||||
// DeleteClauses delete clauses
|
||||
DeleteClauses = []string{"DELETE", "FROM", "WHERE", "ORDER BY", "LIMIT"}
|
||||
|
||||
defaultDatetimePrecision = 3
|
||||
)
|
||||
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &Dialector{Config: &Config{DSN: dsn}}
|
||||
}
|
||||
|
||||
func New(config Config) gorm.Dialector {
|
||||
return &Dialector{Config: &config}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Name() string {
|
||||
return "mysql"
|
||||
}
|
||||
|
||||
// NowFunc return now func
|
||||
func (dialector Dialector) NowFunc(n int) func() time.Time {
|
||||
return func() time.Time {
|
||||
round := time.Second / time.Duration(math.Pow10(n))
|
||||
return time.Now().Round(round)
|
||||
}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Apply(config *gorm.Config) error {
|
||||
if config.NowFunc == nil {
|
||||
if dialector.DefaultDatetimePrecision == nil {
|
||||
dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
|
||||
}
|
||||
|
||||
// while maintaining the readability of the code, separate the business logic from
|
||||
// the general part and leave it to the function to do it here.
|
||||
config.NowFunc = dialector.NowFunc(*dialector.DefaultDatetimePrecision)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
ctx := context.Background()
|
||||
|
||||
// register callbacks
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
CreateClauses: CreateClauses,
|
||||
QueryClauses: QueryClauses,
|
||||
UpdateClauses: UpdateClauses,
|
||||
DeleteClauses: DeleteClauses,
|
||||
})
|
||||
|
||||
if dialector.DriverName == "" {
|
||||
dialector.DriverName = "mysql"
|
||||
}
|
||||
|
||||
if dialector.DefaultDatetimePrecision == nil {
|
||||
dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
|
||||
}
|
||||
|
||||
if dialector.Conn != nil {
|
||||
db.ConnPool = dialector.Conn
|
||||
} else {
|
||||
db.ConnPool, err = sql.Open(dialector.DriverName, dialector.DSN)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if !dialector.Config.SkipInitializeWithVersion {
|
||||
err = db.ConnPool.QueryRowContext(ctx, "SELECT VERSION()").Scan(&dialector.ServerVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if strings.Contains(dialector.ServerVersion, "MariaDB") {
|
||||
dialector.Config.DontSupportRenameIndex = true
|
||||
dialector.Config.DontSupportRenameColumn = true
|
||||
dialector.Config.DontSupportForShareClause = true
|
||||
dialector.Config.DontSupportNullAsDefaultValue = true
|
||||
} else if strings.HasPrefix(dialector.ServerVersion, "5.6.") {
|
||||
dialector.Config.DontSupportRenameIndex = true
|
||||
dialector.Config.DontSupportRenameColumn = true
|
||||
dialector.Config.DontSupportForShareClause = true
|
||||
} else if strings.HasPrefix(dialector.ServerVersion, "5.7.") {
|
||||
dialector.Config.DontSupportRenameColumn = true
|
||||
dialector.Config.DontSupportForShareClause = true
|
||||
} else if strings.HasPrefix(dialector.ServerVersion, "5.") {
|
||||
dialector.Config.DisableDatetimePrecision = true
|
||||
dialector.Config.DontSupportRenameIndex = true
|
||||
dialector.Config.DontSupportRenameColumn = true
|
||||
dialector.Config.DontSupportForShareClause = true
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range dialector.ClauseBuilders() {
|
||||
db.ClauseBuilders[k] = v
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
// ClauseOnConflict for clause.ClauseBuilder ON CONFLICT key
|
||||
ClauseOnConflict = "ON CONFLICT"
|
||||
// ClauseValues for clause.ClauseBuilder VALUES key
|
||||
ClauseValues = "VALUES"
|
||||
// ClauseValues for clause.ClauseBuilder FOR key
|
||||
ClauseFor = "FOR"
|
||||
)
|
||||
|
||||
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
|
||||
clauseBuilders := map[string]clause.ClauseBuilder{
|
||||
ClauseOnConflict: func(c clause.Clause, builder clause.Builder) {
|
||||
onConflict, ok := c.Expression.(clause.OnConflict)
|
||||
if !ok {
|
||||
c.Build(builder)
|
||||
return
|
||||
}
|
||||
|
||||
builder.WriteString("ON DUPLICATE KEY UPDATE ")
|
||||
if len(onConflict.DoUpdates) == 0 {
|
||||
if s := builder.(*gorm.Statement).Schema; s != nil {
|
||||
var column clause.Column
|
||||
onConflict.DoNothing = false
|
||||
|
||||
if s.PrioritizedPrimaryField != nil {
|
||||
column = clause.Column{Name: s.PrioritizedPrimaryField.DBName}
|
||||
} else if len(s.DBNames) > 0 {
|
||||
column = clause.Column{Name: s.DBNames[0]}
|
||||
}
|
||||
|
||||
if column.Name != "" {
|
||||
onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for idx, assignment := range onConflict.DoUpdates {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
|
||||
builder.WriteQuoted(assignment.Column)
|
||||
builder.WriteByte('=')
|
||||
if column, ok := assignment.Value.(clause.Column); ok && column.Table == "excluded" {
|
||||
column.Table = ""
|
||||
builder.WriteString("VALUES(")
|
||||
builder.WriteQuoted(column)
|
||||
builder.WriteByte(')')
|
||||
} else {
|
||||
builder.AddVar(builder, assignment.Value)
|
||||
}
|
||||
}
|
||||
},
|
||||
ClauseValues: func(c clause.Clause, builder clause.Builder) {
|
||||
if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 {
|
||||
builder.WriteString("VALUES()")
|
||||
return
|
||||
}
|
||||
c.Build(builder)
|
||||
},
|
||||
}
|
||||
|
||||
if dialector.Config.DontSupportForShareClause {
|
||||
clauseBuilders[ClauseFor] = func(c clause.Clause, builder clause.Builder) {
|
||||
if values, ok := c.Expression.(clause.Locking); ok && strings.EqualFold(values.Strength, "SHARE") {
|
||||
builder.WriteString("LOCK IN SHARE MODE")
|
||||
return
|
||||
}
|
||||
c.Build(builder)
|
||||
}
|
||||
}
|
||||
|
||||
return clauseBuilders
|
||||
}
|
||||
|
||||
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
|
||||
return clause.Expr{SQL: "DEFAULT"}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{
|
||||
Migrator: migrator.Migrator{
|
||||
Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
},
|
||||
},
|
||||
Dialector: dialector,
|
||||
}
|
||||
}
|
||||
|
||||
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
||||
writer.WriteByte('?')
|
||||
}
|
||||
|
||||
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
var (
|
||||
underQuoted, selfQuoted bool
|
||||
continuousBacktick int8
|
||||
shiftDelimiter int8
|
||||
)
|
||||
|
||||
for _, v := range []byte(str) {
|
||||
switch v {
|
||||
case '`':
|
||||
continuousBacktick++
|
||||
if continuousBacktick == 2 {
|
||||
writer.WriteString("``")
|
||||
continuousBacktick = 0
|
||||
}
|
||||
case '.':
|
||||
if continuousBacktick > 0 || !selfQuoted {
|
||||
shiftDelimiter = 0
|
||||
underQuoted = false
|
||||
continuousBacktick = 0
|
||||
writer.WriteByte('`')
|
||||
}
|
||||
writer.WriteByte(v)
|
||||
continue
|
||||
default:
|
||||
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
|
||||
writer.WriteByte('`')
|
||||
underQuoted = true
|
||||
if selfQuoted = continuousBacktick > 0; selfQuoted {
|
||||
continuousBacktick -= 1
|
||||
}
|
||||
}
|
||||
|
||||
for ; continuousBacktick > 0; continuousBacktick -= 1 {
|
||||
writer.WriteString("``")
|
||||
}
|
||||
|
||||
writer.WriteByte(v)
|
||||
}
|
||||
shiftDelimiter++
|
||||
}
|
||||
|
||||
if continuousBacktick > 0 && !selfQuoted {
|
||||
writer.WriteString("``")
|
||||
}
|
||||
writer.WriteByte('`')
|
||||
}
|
||||
|
||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
return logger.ExplainSQL(sql, nil, `'`, vars...)
|
||||
}
|
||||
|
||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
return "boolean"
|
||||
case schema.Int, schema.Uint:
|
||||
return dialector.getSchemaIntAndUnitType(field)
|
||||
case schema.Float:
|
||||
return dialector.getSchemaFloatType(field)
|
||||
case schema.String:
|
||||
return dialector.getSchemaStringType(field)
|
||||
case schema.Time:
|
||||
return dialector.getSchemaTimeType(field)
|
||||
case schema.Bytes:
|
||||
return dialector.getSchemaBytesType(field)
|
||||
}
|
||||
|
||||
return string(field.DataType)
|
||||
}
|
||||
|
||||
func (dialector Dialector) getSchemaFloatType(field *schema.Field) string {
|
||||
if field.Precision > 0 {
|
||||
return fmt.Sprintf("decimal(%d, %d)", field.Precision, field.Scale)
|
||||
}
|
||||
|
||||
if field.Size <= 32 {
|
||||
return "float"
|
||||
}
|
||||
|
||||
return "double"
|
||||
}
|
||||
|
||||
func (dialector Dialector) getSchemaStringType(field *schema.Field) string {
|
||||
size := field.Size
|
||||
if size == 0 {
|
||||
if dialector.DefaultStringSize > 0 {
|
||||
size = int(dialector.DefaultStringSize)
|
||||
} else {
|
||||
hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != ""
|
||||
// TEXT, GEOMETRY or JSON column can't have a default value
|
||||
if field.PrimaryKey || field.HasDefaultValue || hasIndex {
|
||||
size = 191 // utf8mb4
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if size >= 65536 && size <= int(math.Pow(2, 24)) {
|
||||
return "mediumtext"
|
||||
}
|
||||
|
||||
if size > int(math.Pow(2, 24)) || size <= 0 {
|
||||
return "longtext"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
}
|
||||
|
||||
func (dialector Dialector) getSchemaTimeType(field *schema.Field) string {
|
||||
precision := ""
|
||||
if !dialector.DisableDatetimePrecision && field.Precision == 0 {
|
||||
field.Precision = *dialector.DefaultDatetimePrecision
|
||||
}
|
||||
|
||||
if field.Precision > 0 {
|
||||
precision = fmt.Sprintf("(%d)", field.Precision)
|
||||
}
|
||||
|
||||
if field.NotNull || field.PrimaryKey {
|
||||
return "datetime" + precision
|
||||
}
|
||||
return "datetime" + precision + " NULL"
|
||||
}
|
||||
|
||||
func (dialector Dialector) getSchemaBytesType(field *schema.Field) string {
|
||||
if field.Size > 0 && field.Size < 65536 {
|
||||
return fmt.Sprintf("varbinary(%d)", field.Size)
|
||||
}
|
||||
|
||||
if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) {
|
||||
return "mediumblob"
|
||||
}
|
||||
|
||||
return "longblob"
|
||||
}
|
||||
|
||||
func (dialector Dialector) getSchemaIntAndUnitType(field *schema.Field) string {
|
||||
sqlType := "bigint"
|
||||
switch {
|
||||
case field.Size <= 8:
|
||||
sqlType = "tinyint"
|
||||
case field.Size <= 16:
|
||||
sqlType = "smallint"
|
||||
case field.Size <= 24:
|
||||
sqlType = "mediumint"
|
||||
case field.Size <= 32:
|
||||
sqlType = "int"
|
||||
}
|
||||
|
||||
if field.DataType == schema.Uint {
|
||||
sqlType += " unsigned"
|
||||
}
|
||||
|
||||
if field.AutoIncrement {
|
||||
sqlType += " AUTO_INCREMENT"
|
||||
}
|
||||
|
||||
return sqlType
|
||||
}
|
||||
|
||||
func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error {
|
||||
return tx.Exec("SAVEPOINT " + name).Error
|
||||
}
|
||||
|
||||
func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
|
||||
return tx.Exec("ROLLBACK TO SAVEPOINT " + name).Error
|
||||
}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
TODO*
|
||||
documents
|
||||
coverage.txt
|
||||
_book
|
||||
.idea
|
||||
vendor
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
linters:
|
||||
enable:
|
||||
- cyclop
|
||||
- exportloopref
|
||||
- gocritic
|
||||
- gosec
|
||||
- ineffassign
|
||||
- misspell
|
||||
- prealloc
|
||||
- unconvert
|
||||
- unparam
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue